diff --git a/.clang-format b/.clang-format new file mode 100644 index 000000000..45232b80e --- /dev/null +++ b/.clang-format @@ -0,0 +1,161 @@ +--- +Language: Cpp +AlignAfterOpenBracket: Align +AlignArrayOfStructures: Left +AlignConsecutiveAssignments: AcrossComments +AlignConsecutiveBitFields: AcrossComments +AlignConsecutiveDeclarations: AcrossComments +AlignConsecutiveMacros: AcrossComments +# AlignConsecutiveShortCaseStatements: AcrossComments +AlignEscapedNewlines: Left # LeftWithLastLine +AlignOperands: Align +AlignTrailingComments: + Kind: Always + OverEmptyLines: 1 +AllowAllArgumentsOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: false +# AllowBreakBeforeNoexceptSpecifier: OnlyWithParen +AllowShortBlocksOnASingleLine: Never +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Inline +AllowShortIfStatementsOnASingleLine: Never +AllowShortLambdasOnASingleLine: Inline +AllowShortLoopsOnASingleLine: false +AlwaysBreakBeforeMultilineStrings: true +BinPackArguments: true +BinPackParameters: true # OnePerLine +BitFieldColonSpacing: Both +BreakBeforeBraces: Custom # Attach +BraceWrapping: + AfterCaseLabel: true + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + AfterExternBlock: false + BeforeCatch: false + BeforeElse: false + BeforeLambdaBody: false + BeforeWhile: false + IndentBraces: false + SplitEmptyFunction: false + SplitEmptyRecord: false + SplitEmptyNamespace: false +# BreakAdjacentStringLiterals: true +BreakAfterAttributes: Never +BreakBeforeBinaryOperators: None +BreakBeforeInlineASMColon: OnlyMultiline +BreakBeforeTernaryOperators: false +# BreakBinaryOperations: Never +BreakConstructorInitializers: AfterColon +# BreakFunctionDefinitionParameters: false +BreakInheritanceList: AfterComma +BreakStringLiterals: true +# BreakTemplateDeclarations: Yes +ColumnLimit: 120 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: false +DerivePointerAlignment: false +DisableFormat: false +EmptyLineBeforeAccessModifier: Leave +EmptyLineAfterAccessModifier: Never +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +IncludeBlocks: Regroup +IncludeCategories: + - Regex: '^<.*\.h>' + Priority: 1 + SortPriority: 0 + - Regex: '^<.*' + Priority: 2 + SortPriority: 0 + - Regex: '.*' + Priority: 3 + SortPriority: 0 +IncludeIsMainRegex: '([-_](test|unittest))?$' +IncludeIsMainSourceRegex: '' +IndentAccessModifiers: false +IndentCaseBlocks: true +IndentCaseLabels: true +IndentExternBlock: NoIndent +IndentGotoLabels: false +IndentPPDirectives: AfterHash +IndentWidth: 4 +IndentWrappedFunctionNames: false +InsertBraces: true # NOTE: may lead to incorrect formatting +InsertNewlineAtEOF: true +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: false +LambdaBodyIndentation: Signature +LineEnding: LF +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBinPackProtocolList: Auto +ObjCBlockIndentWidth: 4 +ObjCSpaceAfterProperty: true +ObjCSpaceBeforeProtocolList: true +PPIndentWidth: -1 +PackConstructorInitializers: CurrentLine +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Middle +QualifierAlignment: Left +#QualifierOrder: ['static', 'inline', 'friend', 'constexpr', 'const', 'volatile', 'type', 'restrict'] +RawStringFormats: + - Language: Cpp + Delimiters: + - cc + - CC + - cpp + - Cpp + - CPP + - 'c++' + - 'C++' + CanonicalDelimiter: '' +ReferenceAlignment: Middle +ReflowComments: false # IndentOnly +SeparateDefinitionBlocks: Always +SortIncludes: CaseInsensitive +SortUsingDeclarations: LexicographicNumeric +SpaceAfterCStyleCast: true +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyBlock: false +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 2 +SpacesInAngles: Never +SpacesInContainerLiterals: true +SpacesInLineCommentPrefix: + Minimum: 1 + Maximum: -1 +SpacesInParentheses: false +SpacesInSquareBrackets: false +SpaceBeforeSquareBrackets: false +Standard: c++17 +TabWidth: 4 +UseTab: Never +WhitespaceSensitiveMacros: ['STRINGIZE'] +... + diff --git a/.clang-tidy b/.clang-tidy index 952c0cca8..310c3d182 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -17,8 +17,10 @@ Checks: > -clang-analyzer-security.insecureAPI.DeprecatedOrUnsafeBufferHandling, performance-*, portability-*, + -portability-simd-intrinsics, misc-*, -misc-const-correctness, -misc-non-private-member-variables-in-classes, -misc-no-recursion, + -misc-use-anonymous-namespace, FormatStyle: none diff --git a/.devops/cpu.Dockerfile b/.devops/cpu.Dockerfile new file mode 100644 index 000000000..522ee8147 --- /dev/null +++ b/.devops/cpu.Dockerfile @@ -0,0 +1,92 @@ +ARG UBUNTU_VERSION=22.04 + +FROM ubuntu:$UBUNTU_VERSION AS build + +ARG TARGETARCH + +ARG GGML_CPU_ARM_ARCH=armv8-a + +RUN apt-get update && \ + apt-get install -y build-essential git cmake libcurl4-openssl-dev + +WORKDIR /app + +COPY . . + +RUN if [ "$TARGETARCH" = "amd64" ]; then \ + cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DLLAMA_CURL=ON -DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON; \ + elif [ "$TARGETARCH" = "arm64" ]; then \ + cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DLLAMA_CURL=ON -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=${GGML_CPU_ARM_ARCH}; \ + else \ + echo "Unsupported architecture"; \ + exit 1; \ + fi && \ + cmake --build build -j $(nproc) + +RUN mkdir -p /app/lib && \ + find build -name "*.so" -exec cp {} /app/lib \; + +RUN mkdir -p /app/full \ + && cp build/bin/* /app/full \ + && cp *.py /app/full \ + && cp -r gguf-py /app/full \ + && cp -r requirements /app/full \ + && cp requirements.txt /app/full \ + && cp .devops/tools.sh /app/full/tools.sh + +## Base image +FROM ubuntu:$UBUNTU_VERSION AS base + +RUN apt-get update \ + && apt-get install -y libgomp1 curl\ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + +COPY --from=build /app/lib/ /app + +### Full +FROM base AS full + +COPY --from=build /app/full /app + +WORKDIR /app + +RUN apt-get update \ + && apt-get install -y \ + git \ + python3 \ + python3-pip \ + && pip install --upgrade pip setuptools wheel \ + && pip install -r requirements.txt \ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + +ENTRYPOINT ["/app/tools.sh"] + +### Light, CLI only +FROM base AS light + +COPY --from=build /app/full/llama-cli /app + +WORKDIR /app + +ENTRYPOINT [ "/app/llama-cli" ] + +### Server, Server only +FROM base AS server + +ENV LLAMA_ARG_HOST=0.0.0.0 + +COPY --from=build /app/full/llama-server /app + +WORKDIR /app + +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + +ENTRYPOINT [ "/app/llama-server" ] diff --git a/.devops/cuda.Dockerfile b/.devops/cuda.Dockerfile new file mode 100644 index 000000000..974dd78a8 --- /dev/null +++ b/.devops/cuda.Dockerfile @@ -0,0 +1,94 @@ +ARG UBUNTU_VERSION=22.04 +# This needs to generally match the container host's environment. +ARG CUDA_VERSION=12.6.0 +# Target the CUDA build image +ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} + +ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION} + +FROM ${BASE_CUDA_DEV_CONTAINER} AS build + +# CUDA architecture to build for (defaults to all supported archs) +ARG CUDA_DOCKER_ARCH=default + +RUN apt-get update && \ + apt-get install -y build-essential cmake python3 python3-pip git libcurl4-openssl-dev libgomp1 + +WORKDIR /app + +COPY . . + +RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \ + export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \ + fi && \ + cmake -B build -DGGML_NATIVE=OFF -DGGML_CUDA=ON -DLLAMA_CURL=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \ + cmake --build build --config Release -j$(nproc) + +RUN mkdir -p /app/lib && \ + find build -name "*.so" -exec cp {} /app/lib \; + +RUN mkdir -p /app/full \ + && cp build/bin/* /app/full \ + && cp *.py /app/full \ + && cp -r gguf-py /app/full \ + && cp -r requirements /app/full \ + && cp requirements.txt /app/full \ + && cp .devops/tools.sh /app/full/tools.sh + +## Base image +FROM ${BASE_CUDA_RUN_CONTAINER} AS base + +RUN apt-get update \ + && apt-get install -y libgomp1 curl\ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + +COPY --from=build /app/lib/ /app + +### Full +FROM base AS full + +COPY --from=build /app/full /app + +WORKDIR /app + +RUN apt-get update \ + && apt-get install -y \ + git \ + python3 \ + python3-pip \ + && pip install --upgrade pip setuptools wheel \ + && pip install -r requirements.txt \ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + + +ENTRYPOINT ["/app/tools.sh"] + +### Light, CLI only +FROM base AS light + +COPY --from=build /app/full/llama-cli /app + +WORKDIR /app + +ENTRYPOINT [ "/app/llama-cli" ] + +### Server, Server only +FROM base AS server + +ENV LLAMA_ARG_HOST=0.0.0.0 + +COPY --from=build /app/full/llama-server /app + +WORKDIR /app + +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + +ENTRYPOINT [ "/app/llama-server" ] diff --git a/.devops/full-cuda.Dockerfile b/.devops/full-cuda.Dockerfile deleted file mode 100644 index d5acd35e2..000000000 --- a/.devops/full-cuda.Dockerfile +++ /dev/null @@ -1,33 +0,0 @@ -ARG UBUNTU_VERSION=22.04 -# This needs to generally match the container host's environment. -ARG CUDA_VERSION=12.6.0 -# Target the CUDA build image -ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} - -FROM ${BASE_CUDA_DEV_CONTAINER} AS build - -# CUDA architecture to build for (defaults to all supported archs) -ARG CUDA_DOCKER_ARCH=default - -RUN apt-get update && \ - apt-get install -y build-essential cmake python3 python3-pip git libcurl4-openssl-dev libgomp1 - -COPY requirements.txt requirements.txt -COPY requirements requirements - -RUN pip install --upgrade pip setuptools wheel \ - && pip install -r requirements.txt - -WORKDIR /app - -COPY . . - -# Use the default CUDA archs if not specified -RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \ - export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \ - fi && \ - cmake -B build -DGGML_CUDA=ON -DLLAMA_CURL=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \ - cmake --build build --config Release -j$(nproc) && \ - cp build/bin/* . - -ENTRYPOINT ["/app/.devops/tools.sh"] diff --git a/.devops/full-rocm.Dockerfile b/.devops/full-rocm.Dockerfile deleted file mode 100644 index 680d1cb92..000000000 --- a/.devops/full-rocm.Dockerfile +++ /dev/null @@ -1,50 +0,0 @@ -ARG UBUNTU_VERSION=22.04 - -# This needs to generally match the container host's environment. -ARG ROCM_VERSION=5.6 - -# Target the CUDA build image -ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete - -FROM ${BASE_ROCM_DEV_CONTAINER} AS build - -# Unless otherwise specified, we make a fat build. -# List from https://github.com/ggerganov/llama.cpp/pull/1087#issuecomment-1682807878 -# This is mostly tied to rocBLAS supported archs. -ARG ROCM_DOCKER_ARCH=\ - gfx803 \ - gfx900 \ - gfx906 \ - gfx908 \ - gfx90a \ - gfx1010 \ - gfx1030 \ - gfx1100 \ - gfx1101 \ - gfx1102 - -COPY requirements.txt requirements.txt -COPY requirements requirements - -RUN pip install --upgrade pip setuptools wheel \ - && pip install -r requirements.txt - -WORKDIR /app - -COPY . . - -# Set nvcc architecture -ENV GPU_TARGETS=${ROCM_DOCKER_ARCH} -# Enable ROCm -ENV GGML_HIPBLAS=1 -ENV CC=/opt/rocm/llvm/bin/clang -ENV CXX=/opt/rocm/llvm/bin/clang++ - -# Enable cURL -ENV LLAMA_CURL=1 -RUN apt-get update && \ - apt-get install -y libcurl4-openssl-dev - -RUN make -j$(nproc) - -ENTRYPOINT ["/app/.devops/tools.sh"] diff --git a/.devops/full.Dockerfile b/.devops/full.Dockerfile deleted file mode 100644 index 2a06f82b7..000000000 --- a/.devops/full.Dockerfile +++ /dev/null @@ -1,25 +0,0 @@ -ARG UBUNTU_VERSION=22.04 - -FROM ubuntu:$UBUNTU_VERSION AS build - -RUN apt-get update && \ - apt-get install -y build-essential python3 python3-pip git libcurl4-openssl-dev libgomp1 - -COPY requirements.txt requirements.txt -COPY requirements requirements - -RUN pip install --upgrade pip setuptools wheel \ - && pip install -r requirements.txt - -WORKDIR /app - -COPY . . - -ENV LLAMA_CURL=1 - - -RUN make -j$(nproc) - -ENV LC_ALL=C.utf8 - -ENTRYPOINT ["/app/.devops/tools.sh"] diff --git a/.devops/intel.Dockerfile b/.devops/intel.Dockerfile new file mode 100644 index 000000000..af783f5e9 --- /dev/null +++ b/.devops/intel.Dockerfile @@ -0,0 +1,91 @@ +ARG ONEAPI_VERSION=2025.0.0-0-devel-ubuntu22.04 + +## Build Image + +FROM intel/oneapi-basekit:$ONEAPI_VERSION AS build + +ARG GGML_SYCL_F16=OFF +RUN apt-get update && \ + apt-get install -y git libcurl4-openssl-dev + +WORKDIR /app + +COPY . . + +RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \ + echo "GGML_SYCL_F16 is set" \ + && export OPT_SYCL_F16="-DGGML_SYCL_F16=ON"; \ + fi && \ + echo "Building with dynamic libs" && \ + cmake -B build -DGGML_NATIVE=OFF -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_CURL=ON ${OPT_SYCL_F16} && \ + cmake --build build --config Release -j$(nproc) + +RUN mkdir -p /app/lib && \ + find build -name "*.so" -exec cp {} /app/lib \; + +RUN mkdir -p /app/full \ + && cp build/bin/* /app/full \ + && cp *.py /app/full \ + && cp -r gguf-py /app/full \ + && cp -r requirements /app/full \ + && cp requirements.txt /app/full \ + && cp .devops/tools.sh /app/full/tools.sh + +FROM intel/oneapi-basekit:$ONEAPI_VERSION AS base + +RUN apt-get update \ + && apt-get install -y libgomp1 curl\ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + +### Full +FROM base AS full + +COPY --from=build /app/lib/ /app +COPY --from=build /app/full /app + +WORKDIR /app + +RUN apt-get update \ + && apt-get install -y \ + git \ + python3 \ + python3-pip \ + && pip install --upgrade pip setuptools wheel \ + && pip install -r requirements.txt \ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + + +ENTRYPOINT ["/app/tools.sh"] + +### Light, CLI only +FROM base AS light + +COPY --from=build /app/lib/ /app +COPY --from=build /app/full/llama-cli /app + +WORKDIR /app + +ENTRYPOINT [ "/app/llama-cli" ] + +### Server, Server only +FROM base AS server + +ENV LLAMA_ARG_HOST=0.0.0.0 + +COPY --from=build /app/lib/ /app +COPY --from=build /app/full/llama-server /app + +WORKDIR /app + +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + +ENTRYPOINT [ "/app/llama-server" ] + diff --git a/.devops/llama-cli-cann.Dockerfile b/.devops/llama-cli-cann.Dockerfile index db5ba2f25..02dce501c 100644 --- a/.devops/llama-cli-cann.Dockerfile +++ b/.devops/llama-cli-cann.Dockerfile @@ -1,6 +1,6 @@ ARG ASCEND_VERSION=8.0.rc2.alpha003-910b-openeuler22.03-py3.8 -FROM cosdt/cann:$ASCEND_VERSION AS build +FROM ascendai/cann:$ASCEND_VERSION AS build WORKDIR /app @@ -22,11 +22,11 @@ ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/runtime/lib64/stub:$LD_LIBRARY_PATH RUN echo "Building with static libs" && \ source /usr/local/Ascend/ascend-toolkit/set_env.sh --force && \ - cmake -B build -DGGML_CANN=ON -DBUILD_SHARED_LIBS=OFF && \ + cmake -B build -DGGML_NATIVE=OFF -DGGML_CANN=ON -DBUILD_SHARED_LIBS=OFF && \ cmake --build build --config Release --target llama-cli # TODO: use image with NNRT -FROM cosdt/cann:$ASCEND_VERSION AS runtime +FROM ascendai/cann:$ASCEND_VERSION AS runtime COPY --from=build /app/build/bin/llama-cli /llama-cli ENV LC_ALL=C.utf8 diff --git a/.devops/llama-cli-cuda.Dockerfile b/.devops/llama-cli-cuda.Dockerfile deleted file mode 100644 index b75163b94..000000000 --- a/.devops/llama-cli-cuda.Dockerfile +++ /dev/null @@ -1,37 +0,0 @@ -ARG UBUNTU_VERSION=22.04 -# This needs to generally match the container host's environment. -ARG CUDA_VERSION=12.6.0 -# Target the CUDA build image -ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} -# Target the CUDA runtime image -ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION} - -FROM ${BASE_CUDA_DEV_CONTAINER} AS build - -# CUDA architecture to build for (defaults to all supported archs) -ARG CUDA_DOCKER_ARCH=default - -RUN apt-get update && \ - apt-get install -y build-essential git cmake - -WORKDIR /app - -COPY . . - -# Use the default CUDA archs if not specified -RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \ - export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \ - fi && \ - cmake -B build -DGGML_CUDA=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \ - cmake --build build --config Release --target llama-cli -j$(nproc) - -FROM ${BASE_CUDA_RUN_CONTAINER} AS runtime - -RUN apt-get update && \ - apt-get install -y libgomp1 - -COPY --from=build /app/build/ggml/src/libggml.so /libggml.so -COPY --from=build /app/build/src/libllama.so /libllama.so -COPY --from=build /app/build/bin/llama-cli /llama-cli - -ENTRYPOINT [ "/llama-cli" ] diff --git a/.devops/llama-cli-intel.Dockerfile b/.devops/llama-cli-intel.Dockerfile deleted file mode 100644 index 79dba06a7..000000000 --- a/.devops/llama-cli-intel.Dockerfile +++ /dev/null @@ -1,28 +0,0 @@ -ARG ONEAPI_VERSION=2024.1.1-devel-ubuntu22.04 - -FROM intel/oneapi-basekit:$ONEAPI_VERSION AS build - -ARG GGML_SYCL_F16=OFF -RUN apt-get update && \ - apt-get install -y git - -WORKDIR /app - -COPY . . - -RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \ - echo "GGML_SYCL_F16 is set" && \ - export OPT_SYCL_F16="-DGGML_SYCL_F16=ON"; \ - fi && \ - echo "Building with static libs" && \ - cmake -B build -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx \ - ${OPT_SYCL_F16} -DBUILD_SHARED_LIBS=OFF && \ - cmake --build build --config Release --target llama-cli - -FROM intel/oneapi-basekit:$ONEAPI_VERSION AS runtime - -COPY --from=build /app/build/bin/llama-cli /llama-cli - -ENV LC_ALL=C.utf8 - -ENTRYPOINT [ "/llama-cli" ] diff --git a/.devops/llama-cli-rocm.Dockerfile b/.devops/llama-cli-rocm.Dockerfile deleted file mode 100644 index c3d1ab067..000000000 --- a/.devops/llama-cli-rocm.Dockerfile +++ /dev/null @@ -1,45 +0,0 @@ -ARG UBUNTU_VERSION=22.04 - -# This needs to generally match the container host's environment. -ARG ROCM_VERSION=5.6 - -# Target the CUDA build image -ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete - -FROM ${BASE_ROCM_DEV_CONTAINER} AS build - -# Unless otherwise specified, we make a fat build. -# List from https://github.com/ggerganov/llama.cpp/pull/1087#issuecomment-1682807878 -# This is mostly tied to rocBLAS supported archs. -ARG ROCM_DOCKER_ARCH=\ - gfx803 \ - gfx900 \ - gfx906 \ - gfx908 \ - gfx90a \ - gfx1010 \ - gfx1030 \ - gfx1100 \ - gfx1101 \ - gfx1102 - -COPY requirements.txt requirements.txt -COPY requirements requirements - -RUN pip install --upgrade pip setuptools wheel \ - && pip install -r requirements.txt - -WORKDIR /app - -COPY . . - -# Set nvcc architecture -ENV GPU_TARGETS=${ROCM_DOCKER_ARCH} -# Enable ROCm -ENV GGML_HIPBLAS=1 -ENV CC=/opt/rocm/llvm/bin/clang -ENV CXX=/opt/rocm/llvm/bin/clang++ - -RUN make -j$(nproc) llama-cli - -ENTRYPOINT [ "/app/llama-cli" ] diff --git a/.devops/llama-cli-vulkan.Dockerfile b/.devops/llama-cli-vulkan.Dockerfile deleted file mode 100644 index 9b0dad8bf..000000000 --- a/.devops/llama-cli-vulkan.Dockerfile +++ /dev/null @@ -1,27 +0,0 @@ -ARG UBUNTU_VERSION=jammy - -FROM ubuntu:$UBUNTU_VERSION AS build - -# Install build tools -RUN apt update && apt install -y git build-essential cmake wget libgomp1 - -# Install Vulkan SDK -RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \ - wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list && \ - apt update -y && \ - apt-get install -y vulkan-sdk - -# Build it -WORKDIR /app -COPY . . -RUN cmake -B build -DGGML_VULKAN=1 && \ - cmake --build build --config Release --target llama-cli - -# Clean up -WORKDIR / -RUN cp /app/build/bin/llama-cli /llama-cli && \ - rm -rf /app - -ENV LC_ALL=C.utf8 - -ENTRYPOINT [ "/llama-cli" ] diff --git a/.devops/llama-cli.Dockerfile b/.devops/llama-cli.Dockerfile deleted file mode 100644 index 7f741aa46..000000000 --- a/.devops/llama-cli.Dockerfile +++ /dev/null @@ -1,23 +0,0 @@ -ARG UBUNTU_VERSION=22.04 - -FROM ubuntu:$UBUNTU_VERSION AS build - -RUN apt-get update && \ - apt-get install -y build-essential git - -WORKDIR /app - -COPY . . - -RUN make -j$(nproc) llama-cli - -FROM ubuntu:$UBUNTU_VERSION AS runtime - -RUN apt-get update && \ - apt-get install -y libgomp1 - -COPY --from=build /app/llama-cli /llama-cli - -ENV LC_ALL=C.utf8 - -ENTRYPOINT [ "/llama-cli" ] diff --git a/.devops/llama-server-cuda.Dockerfile b/.devops/llama-server-cuda.Dockerfile deleted file mode 100644 index a40e24205..000000000 --- a/.devops/llama-server-cuda.Dockerfile +++ /dev/null @@ -1,42 +0,0 @@ -ARG UBUNTU_VERSION=22.04 -# This needs to generally match the container host's environment. -ARG CUDA_VERSION=12.6.0 -# Target the CUDA build image -ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} -# Target the CUDA runtime image -ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION} - -FROM ${BASE_CUDA_DEV_CONTAINER} AS build - -# CUDA architecture to build for (defaults to all supported archs) -ARG CUDA_DOCKER_ARCH=default - -RUN apt-get update && \ - apt-get install -y build-essential git cmake libcurl4-openssl-dev - -WORKDIR /app - -COPY . . - -# Use the default CUDA archs if not specified -RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \ - export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \ - fi && \ - cmake -B build -DGGML_CUDA=ON -DLLAMA_CURL=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \ - cmake --build build --config Release --target llama-server -j$(nproc) - -FROM ${BASE_CUDA_RUN_CONTAINER} AS runtime - -RUN apt-get update && \ - apt-get install -y libcurl4-openssl-dev libgomp1 curl - -COPY --from=build /app/build/ggml/src/libggml.so /libggml.so -COPY --from=build /app/build/src/libllama.so /libllama.so -COPY --from=build /app/build/bin/llama-server /llama-server - -# Must be set to 0.0.0.0 so it can listen to requests from host machine -ENV LLAMA_ARG_HOST=0.0.0.0 - -HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] - -ENTRYPOINT [ "/llama-server" ] diff --git a/.devops/llama-server-intel.Dockerfile b/.devops/llama-server-intel.Dockerfile deleted file mode 100644 index 9c355b664..000000000 --- a/.devops/llama-server-intel.Dockerfile +++ /dev/null @@ -1,34 +0,0 @@ -ARG ONEAPI_VERSION=2024.1.1-devel-ubuntu22.04 - -FROM intel/oneapi-basekit:$ONEAPI_VERSION AS build - -ARG GGML_SYCL_F16=OFF -RUN apt-get update && \ - apt-get install -y git libcurl4-openssl-dev - -WORKDIR /app - -COPY . . - -RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \ - echo "GGML_SYCL_F16 is set" && \ - export OPT_SYCL_F16="-DGGML_SYCL_F16=ON"; \ - fi && \ - echo "Building with dynamic libs" && \ - cmake -B build -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_CURL=ON ${OPT_SYCL_F16} && \ - cmake --build build --config Release --target llama-server - -FROM intel/oneapi-basekit:$ONEAPI_VERSION AS runtime - -RUN apt-get update && \ - apt-get install -y libcurl4-openssl-dev curl - -COPY --from=build /app/build/bin/llama-server /llama-server - -ENV LC_ALL=C.utf8 -# Must be set to 0.0.0.0 so it can listen to requests from host machine -ENV LLAMA_ARG_HOST=0.0.0.0 - -HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] - -ENTRYPOINT [ "/llama-server" ] diff --git a/.devops/llama-server-rocm.Dockerfile b/.devops/llama-server-rocm.Dockerfile deleted file mode 100644 index fd0e19ad6..000000000 --- a/.devops/llama-server-rocm.Dockerfile +++ /dev/null @@ -1,54 +0,0 @@ -ARG UBUNTU_VERSION=22.04 - -# This needs to generally match the container host's environment. -ARG ROCM_VERSION=5.6 - -# Target the CUDA build image -ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete - -FROM ${BASE_ROCM_DEV_CONTAINER} AS build - -# Unless otherwise specified, we make a fat build. -# List from https://github.com/ggerganov/llama.cpp/pull/1087#issuecomment-1682807878 -# This is mostly tied to rocBLAS supported archs. -ARG ROCM_DOCKER_ARCH=\ - gfx803 \ - gfx900 \ - gfx906 \ - gfx908 \ - gfx90a \ - gfx1010 \ - gfx1030 \ - gfx1100 \ - gfx1101 \ - gfx1102 - -COPY requirements.txt requirements.txt -COPY requirements requirements - -RUN pip install --upgrade pip setuptools wheel \ - && pip install -r requirements.txt - -WORKDIR /app - -COPY . . - -# Set nvcc architecture -ENV GPU_TARGETS=${ROCM_DOCKER_ARCH} -# Enable ROCm -ENV GGML_HIPBLAS=1 -ENV CC=/opt/rocm/llvm/bin/clang -ENV CXX=/opt/rocm/llvm/bin/clang++ -# Must be set to 0.0.0.0 so it can listen to requests from host machine -ENV LLAMA_ARG_HOST=0.0.0.0 - -# Enable cURL -ENV LLAMA_CURL=1 -RUN apt-get update && \ - apt-get install -y libcurl4-openssl-dev curl - -RUN make -j$(nproc) llama-server - -HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] - -ENTRYPOINT [ "/app/llama-server" ] diff --git a/.devops/llama-server-vulkan.Dockerfile b/.devops/llama-server-vulkan.Dockerfile deleted file mode 100644 index 93c5e0c26..000000000 --- a/.devops/llama-server-vulkan.Dockerfile +++ /dev/null @@ -1,31 +0,0 @@ -ARG UBUNTU_VERSION=jammy - -FROM ubuntu:$UBUNTU_VERSION AS build - -# Install build tools -RUN apt update && apt install -y git build-essential cmake wget - -# Install Vulkan SDK and cURL -RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \ - wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list && \ - apt update -y && \ - apt-get install -y vulkan-sdk libcurl4-openssl-dev curl - -# Build it -WORKDIR /app -COPY . . -RUN cmake -B build -DGGML_VULKAN=1 -DLLAMA_CURL=1 && \ - cmake --build build --config Release --target llama-server - -# Clean up -WORKDIR / -RUN cp /app/build/bin/llama-server /llama-server && \ - rm -rf /app - -ENV LC_ALL=C.utf8 -# Must be set to 0.0.0.0 so it can listen to requests from host machine -ENV LLAMA_ARG_HOST=0.0.0.0 - -HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] - -ENTRYPOINT [ "/llama-server" ] diff --git a/.devops/llama-server.Dockerfile b/.devops/llama-server.Dockerfile deleted file mode 100644 index 02accc85e..000000000 --- a/.devops/llama-server.Dockerfile +++ /dev/null @@ -1,29 +0,0 @@ -ARG UBUNTU_VERSION=22.04 - -FROM ubuntu:$UBUNTU_VERSION AS build - -RUN apt-get update && \ - apt-get install -y build-essential git libcurl4-openssl-dev - -WORKDIR /app - -COPY . . - -ENV LLAMA_CURL=1 - -RUN make -j$(nproc) llama-server - -FROM ubuntu:$UBUNTU_VERSION AS runtime - -RUN apt-get update && \ - apt-get install -y libcurl4-openssl-dev libgomp1 curl - -COPY --from=build /app/llama-server /llama-server - -ENV LC_ALL=C.utf8 -# Must be set to 0.0.0.0 so it can listen to requests from host machine -ENV LLAMA_ARG_HOST=0.0.0.0 - -HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] - -ENTRYPOINT [ "/llama-server" ] diff --git a/.devops/musa.Dockerfile b/.devops/musa.Dockerfile new file mode 100644 index 000000000..bfd7fc1c1 --- /dev/null +++ b/.devops/musa.Dockerfile @@ -0,0 +1,108 @@ +ARG UBUNTU_VERSION=22.04 +# This needs to generally match the container host's environment. +ARG MUSA_VERSION=rc3.1.0 +# Target the MUSA build image +ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION} + +ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-runtime-ubuntu${UBUNTU_VERSION} + +FROM ${BASE_MUSA_DEV_CONTAINER} AS build + +# MUSA architecture to build for (defaults to all supported archs) +ARG MUSA_DOCKER_ARCH=default + +RUN apt-get update && \ + apt-get install -y \ + build-essential \ + cmake \ + python3 \ + python3-pip \ + git \ + libcurl4-openssl-dev \ + libgomp1 + +COPY requirements.txt requirements.txt +COPY requirements requirements + +RUN pip install --upgrade pip setuptools wheel \ + && pip install -r requirements.txt + +WORKDIR /app + +COPY . . + +# Use the default MUSA archs if not specified +RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \ + export CMAKE_ARGS="-DMUSA_ARCHITECTURES=${MUSA_DOCKER_ARCH}"; \ + fi && \ + cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON -DLLAMA_CURL=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \ + cmake --build build --config Release -j$(nproc) + +RUN mkdir -p /app/lib && \ + find build -name "*.so" -exec cp {} /app/lib \; + +RUN mkdir -p /app/full \ + && cp build/bin/* /app/full \ + && cp *.py /app/full \ + && cp -r gguf-py /app/full \ + && cp -r requirements /app/full \ + && cp requirements.txt /app/full \ + && cp .devops/tools.sh /app/full/tools.sh + +## Base image +FROM ${BASE_MUSA_RUN_CONTAINER} AS base + +RUN apt-get update \ + && apt-get install -y libgomp1 curl\ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + +COPY --from=build /app/lib/ /app + +### Full +FROM base AS full + +COPY --from=build /app/full /app + +WORKDIR /app + +RUN apt-get update \ + && apt-get install -y \ + git \ + python3 \ + python3-pip \ + && pip install --upgrade pip setuptools wheel \ + && pip install -r requirements.txt \ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + + +ENTRYPOINT ["/app/tools.sh"] + +### Light, CLI only +FROM base AS light + +COPY --from=build /app/full/llama-cli /app + +WORKDIR /app + +ENTRYPOINT [ "/app/llama-cli" ] + +### Server, Server only +FROM base AS server + +ENV LLAMA_ARG_HOST=0.0.0.0 + +COPY --from=build /app/full/llama-server /app + +WORKDIR /app + +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + +ENTRYPOINT [ "/app/llama-server" ] diff --git a/.devops/nix/package.nix b/.devops/nix/package.nix index 5d7d7ea5a..043c4364b 100644 --- a/.devops/nix/package.nix +++ b/.devops/nix/package.nix @@ -31,6 +31,7 @@ # Increases the runtime closure size by ~700M useMpi ? false, useRocm ? config.rocmSupport, + rocmGpuTargets ? builtins.concatStringsSep ";" rocmPackages.clr.gpuTargets, enableCurl ? true, useVulkan ? false, llamaVersion ? "0.0.0", # Arbitrary version, substituted by the flake @@ -126,9 +127,9 @@ effectiveStdenv.mkDerivation (finalAttrs: { }; postPatch = '' - substituteInPlace ./ggml/src/ggml-metal.m \ + substituteInPlace ./ggml/src/ggml-metal/ggml-metal.m \ --replace '[bundle pathForResource:@"ggml-metal" ofType:@"metal"];' "@\"$out/bin/ggml-metal.metal\";" - substituteInPlace ./ggml/src/ggml-metal.m \ + substituteInPlace ./ggml/src/ggml-metal/ggml-metal.m \ --replace '[bundle pathForResource:@"default" ofType:@"metallib"];' "@\"$out/bin/default.metallib\";" ''; @@ -173,7 +174,7 @@ effectiveStdenv.mkDerivation (finalAttrs: { (cmakeBool "GGML_NATIVE" false) (cmakeBool "GGML_BLAS" useBlas) (cmakeBool "GGML_CUDA" useCuda) - (cmakeBool "GGML_HIPBLAS" useRocm) + (cmakeBool "GGML_HIP" useRocm) (cmakeBool "GGML_METAL" useMetalKit) (cmakeBool "GGML_VULKAN" useVulkan) (cmakeBool "GGML_STATIC" enableStatic) @@ -188,7 +189,7 @@ effectiveStdenv.mkDerivation (finalAttrs: { ] ++ optionals useRocm [ (cmakeFeature "CMAKE_HIP_COMPILER" "${rocmPackages.llvm.clang}/bin/clang") - (cmakeFeature "CMAKE_HIP_ARCHITECTURES" (builtins.concatStringsSep ";" rocmPackages.clr.gpuTargets)) + (cmakeFeature "CMAKE_HIP_ARCHITECTURES" rocmGpuTargets) ] ++ optionals useMetalKit [ (lib.cmakeFeature "CMAKE_C_FLAGS" "-D__ARM_FEATURE_DOTPROD=1") diff --git a/.devops/nix/python-scripts.nix b/.devops/nix/python-scripts.nix index 392e9ffe4..56ea18278 100644 --- a/.devops/nix/python-scripts.nix +++ b/.devops/nix/python-scripts.nix @@ -34,7 +34,7 @@ let # server tests openai - behave + pytest prometheus-client ]; in diff --git a/.devops/rocm.Dockerfile b/.devops/rocm.Dockerfile new file mode 100644 index 000000000..a8088ea00 --- /dev/null +++ b/.devops/rocm.Dockerfile @@ -0,0 +1,113 @@ +ARG UBUNTU_VERSION=24.04 + +# This needs to generally match the container host's environment. +ARG ROCM_VERSION=6.3 +ARG AMDGPU_VERSION=6.3 + +# Target the CUDA build image +ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete + +### Build image +FROM ${BASE_ROCM_DEV_CONTAINER} AS build + +# Unless otherwise specified, we make a fat build. +# List from https://github.com/ggerganov/llama.cpp/pull/1087#issuecomment-1682807878 +# This is mostly tied to rocBLAS supported archs. +# gfx803, gfx900, gfx1032, gfx1101, gfx1102,not officialy supported +# gfx906 is deprecated +#check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.2.4/reference/system-requirements.html + +#ARG ROCM_DOCKER_ARCH='gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102' +ARG ROCM_DOCKER_ARCH=gfx1100 + +# Set nvcc architectured +ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH} +# Enable ROCm +# ENV CC=/opt/rocm/llvm/bin/clang +# ENV CXX=/opt/rocm/llvm/bin/clang++ + +RUN apt-get update \ + && apt-get install -y \ + build-essential \ + cmake \ + git \ + libcurl4-openssl-dev \ + curl \ + libgomp1 + +WORKDIR /app + +COPY . . + +RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \ + cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=$ROCM_DOCKER_ARCH -DCMAKE_BUILD_TYPE=Release -DLLAMA_CURL=ON \ + && cmake --build build --config Release -j$(nproc) + +RUN mkdir -p /app/lib \ + && find build -name "*.so" -exec cp {} /app/lib \; + +RUN mkdir -p /app/full \ + && cp build/bin/* /app/full \ + && cp *.py /app/full \ + && cp -r gguf-py /app/full \ + && cp -r requirements /app/full \ + && cp requirements.txt /app/full \ + && cp .devops/tools.sh /app/full/tools.sh + +## Base image +FROM ${BASE_ROCM_DEV_CONTAINER} AS base + +RUN apt-get update \ + && apt-get install -y libgomp1 curl\ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + +COPY --from=build /app/lib/ /app + +### Full +FROM base AS full + +COPY --from=build /app/full /app + +WORKDIR /app + +RUN apt-get update \ + && apt-get install -y \ + git \ + python3-pip \ + python3 \ + python3-wheel\ + && pip install --break-system-packages --upgrade setuptools \ + && pip install --break-system-packages -r requirements.txt \ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + +ENTRYPOINT ["/app/tools.sh"] + +### Light, CLI only +FROM base AS light + +COPY --from=build /app/full/llama-cli /app + +WORKDIR /app + +ENTRYPOINT [ "/app/llama-cli" ] + +### Server, Server only +FROM base AS server + +ENV LLAMA_ARG_HOST=0.0.0.0 + +COPY --from=build /app/full/llama-server /app + +WORKDIR /app + +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + +ENTRYPOINT [ "/app/llama-server" ] diff --git a/.devops/tools.sh b/.devops/tools.sh index 24dcfd350..41a6b1e55 100755 --- a/.devops/tools.sh +++ b/.devops/tools.sh @@ -8,28 +8,36 @@ arg1="$1" shift if [[ "$arg1" == '--convert' || "$arg1" == '-c' ]]; then - python3 ./convert_hf_to_gguf.py "$@" + exec python3 ./convert_hf_to_gguf.py "$@" elif [[ "$arg1" == '--quantize' || "$arg1" == '-q' ]]; then - ./llama-quantize "$@" + exec ./llama-quantize "$@" elif [[ "$arg1" == '--run' || "$arg1" == '-r' ]]; then - ./llama-cli "$@" + exec ./llama-cli "$@" +elif [[ "$arg1" == '--bench' || "$arg1" == '-b' ]]; then + exec ./llama-bench "$@" +elif [[ "$arg1" == '--perplexity' || "$arg1" == '-p' ]]; then + exec ./llama-perplexity "$@" elif [[ "$arg1" == '--all-in-one' || "$arg1" == '-a' ]]; then echo "Converting PTH to GGML..." - for i in `ls $1/$2/ggml-model-f16.bin*`; do + for i in $(ls $1/$2/ggml-model-f16.bin*); do if [ -f "${i/f16/q4_0}" ]; then echo "Skip model quantization, it already exists: ${i/f16/q4_0}" else echo "Converting PTH to GGML: $i into ${i/f16/q4_0}..." - ./llama-quantize "$i" "${i/f16/q4_0}" q4_0 + exec ./llama-quantize "$i" "${i/f16/q4_0}" q4_0 fi done elif [[ "$arg1" == '--server' || "$arg1" == '-s' ]]; then - ./llama-server "$@" + exec ./llama-server "$@" else echo "Unknown command: $arg1" echo "Available commands: " echo " --run (-r): Run a model previously converted into ggml" echo " ex: -m /models/7B/ggml-model-q4_0.bin -p \"Building a website can be done in 10 simple steps:\" -n 512" + echo " --bench (-b): Benchmark the performance of the inference for various parameters." + echo " ex: -m model.gguf" + echo " --perplexity (-p): Measure the perplexity of a model over a given text." + echo " ex: -m model.gguf -f file.txt" echo " --convert (-c): Convert a llama model into ggml" echo " ex: --outtype f16 \"/models/7B/\" " echo " --quantize (-q): Optimize with quantization process ggml" diff --git a/.devops/vulkan.Dockerfile b/.devops/vulkan.Dockerfile new file mode 100644 index 000000000..9064f3838 --- /dev/null +++ b/.devops/vulkan.Dockerfile @@ -0,0 +1,89 @@ +ARG UBUNTU_VERSION=24.04 + +FROM ubuntu:$UBUNTU_VERSION AS build + +# Install build tools +RUN apt update && apt install -y git build-essential cmake wget + +# Install Vulkan SDK and cURL +RUN wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add - && \ + wget -qO /etc/apt/sources.list.d/lunarg-vulkan-noble.list https://packages.lunarg.com/vulkan/lunarg-vulkan-noble.list && \ + apt update -y && \ + apt-get install -y vulkan-sdk libcurl4-openssl-dev curl + +# Build it +WORKDIR /app + +COPY . . + +RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=1 -DLLAMA_CURL=1 && \ + cmake --build build --config Release -j$(nproc) + +RUN mkdir -p /app/lib && \ + find build -name "*.so" -exec cp {} /app/lib \; + +RUN mkdir -p /app/full \ + && cp build/bin/* /app/full \ + && cp *.py /app/full \ + && cp -r gguf-py /app/full \ + && cp -r requirements /app/full \ + && cp requirements.txt /app/full \ + && cp .devops/tools.sh /app/full/tools.sh + +## Base image +FROM ubuntu:$UBUNTU_VERSION AS base + +RUN apt-get update \ + && apt-get install -y libgomp1 curl libvulkan-dev \ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + +COPY --from=build /app/lib/ /app + +### Full +FROM base AS full + +COPY --from=build /app/full /app + +WORKDIR /app + +RUN apt-get update \ + && apt-get install -y \ + git \ + python3 \ + python3-pip \ + python3-wheel \ + && pip install --break-system-packages --upgrade setuptools \ + && pip install --break-system-packages -r requirements.txt \ + && apt autoremove -y \ + && apt clean -y \ + && rm -rf /tmp/* /var/tmp/* \ + && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ + && find /var/cache -type f -delete + +ENTRYPOINT ["/app/tools.sh"] + +### Light, CLI only +FROM base AS light + +COPY --from=build /app/full/llama-cli /app + +WORKDIR /app + +ENTRYPOINT [ "/app/llama-cli" ] + +### Server, Server only +FROM base AS server + +ENV LLAMA_ARG_HOST=0.0.0.0 + +COPY --from=build /app/full/llama-server /app + +WORKDIR /app + +HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ] + +ENTRYPOINT [ "/app/llama-server" ] diff --git a/.dockerignore b/.dockerignore index 8916e2a66..064b7c7be 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,7 +1,7 @@ *.o *.a .cache/ -.git/ +# Do not ignore .git directory, otherwise the reported build number will always be 0 .github/ .gitignore .vs/ diff --git a/.editorconfig b/.editorconfig index f88f8da67..5d63d0a51 100644 --- a/.editorconfig +++ b/.editorconfig @@ -24,9 +24,27 @@ insert_final_newline = unset [examples/server/public/*] indent_size = 2 +[examples/server/public/deps_*] +trim_trailing_whitespace = unset +indent_style = unset +indent_size = unset + +[examples/server/deps_*] +trim_trailing_whitespace = unset +indent_style = unset +indent_size = unset + [examples/llama.swiftui/llama.swiftui.xcodeproj/*] indent_style = tab [examples/cvector-generator/*.txt] trim_trailing_whitespace = unset insert_final_newline = unset + +[models/templates/*.jinja] +indent_style = unset +indent_size = unset +end_of_line = unset +charset = unset +trim_trailing_whitespace = unset +insert_final_newline = unset diff --git a/.github/ISSUE_TEMPLATE/01-bug-low.yml b/.github/ISSUE_TEMPLATE/01-bug-low.yml deleted file mode 100644 index 54785854f..000000000 --- a/.github/ISSUE_TEMPLATE/01-bug-low.yml +++ /dev/null @@ -1,50 +0,0 @@ -name: Low Severity Bugs -description: Used to report low severity bugs in llama.cpp (e.g. cosmetic issues, non critical UI glitches) -title: "Bug: " -labels: ["bug-unconfirmed", "low severity"] -body: - - type: markdown - attributes: - value: | - Thanks for taking the time to fill out this bug report! - Please include information about your system, the steps to reproduce the bug, - and the version of llama.cpp that you are using. - If possible, please provide a minimal code example that reproduces the bug. - - type: textarea - id: what-happened - attributes: - label: What happened? - description: Also tell us, what did you expect to happen? - placeholder: Tell us what you see! - validations: - required: true - - type: textarea - id: version - attributes: - label: Name and Version - description: Which executable and which version of our software are you running? (use `--version` to get a version string) - placeholder: | - $./llama-cli --version - version: 2999 (42b4109e) - built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu - validations: - required: true - - type: dropdown - id: operating-system - attributes: - label: What operating system are you seeing the problem on? - multiple: true - options: - - Linux - - Mac - - Windows - - BSD - - Other? (Please let us know in description) - validations: - required: false - - type: textarea - id: logs - attributes: - label: Relevant log output - description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. - render: shell diff --git a/.github/ISSUE_TEMPLATE/010-bug-compilation.yml b/.github/ISSUE_TEMPLATE/010-bug-compilation.yml new file mode 100644 index 000000000..b85bf5741 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/010-bug-compilation.yml @@ -0,0 +1,87 @@ +name: Bug (compilation) +description: Something goes wrong when trying to compile llama.cpp. +title: "Compile bug: " +labels: ["bug-unconfirmed", "compilation"] +body: + - type: markdown + attributes: + value: > + Thanks for taking the time to fill out this bug report! + This issue template is intended for bug reports where the compilation of llama.cpp fails. + Before opening an issue, please confirm that the compilation still fails with `-DGGML_CCACHE=OFF`. + If the compilation succeeds with ccache disabled you should be able to permanently fix the issue + by clearing `~/.cache/ccache` (on Linux). + - type: textarea + id: commit + attributes: + label: Git commit + description: Which commit are you trying to compile? + placeholder: | + $git rev-parse HEAD + 84a07a17b1b08cf2b9747c633a2372782848a27f + validations: + required: true + - type: dropdown + id: operating-system + attributes: + label: Operating systems + description: Which operating systems do you know to be affected? + multiple: true + options: + - Linux + - Mac + - Windows + - BSD + - Other? (Please let us know in description) + validations: + required: true + - type: dropdown + id: backends + attributes: + label: GGML backends + description: Which GGML backends do you know to be affected? + options: [AMX, BLAS, CPU, CUDA, HIP, Kompute, Metal, Musa, RPC, SYCL, Vulkan] + multiple: true + validations: + required: true + - type: textarea + id: info + attributes: + label: Problem description & steps to reproduce + description: > + Please give us a summary of the problem and tell us how to reproduce it. + If you can narrow down the bug to specific compile flags, that information would be very much appreciated by us. + placeholder: > + I'm trying to compile llama.cpp with CUDA support on a fresh install of Ubuntu and get error XY. + Here are the exact commands that I used: ... + validations: + required: true + - type: textarea + id: first_bad_commit + attributes: + label: First Bad Commit + description: > + If the bug was not present on an earlier version: when did it start appearing? + If possible, please do a git bisect and identify the exact commit that introduced the bug. + validations: + required: false + - type: textarea + id: command + attributes: + label: Compile command + description: > + Please provide the exact command you used to compile llama.cpp. For example: `cmake -B ...`. + This will be automatically formatted into code, so no need for backticks. + render: shell + validations: + required: true + - type: textarea + id: logs + attributes: + label: Relevant log output + description: > + Please copy and paste any relevant log output, including any generated text. + This will be automatically formatted into code, so no need for backticks. + render: shell + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/011-bug-results.yml b/.github/ISSUE_TEMPLATE/011-bug-results.yml new file mode 100644 index 000000000..1ccef0793 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/011-bug-results.yml @@ -0,0 +1,101 @@ +name: Bug (model use) +description: Something goes wrong when using a model (in general, not specific to a single llama.cpp module). +title: "Eval bug: " +labels: ["bug-unconfirmed", "model evaluation"] +body: + - type: markdown + attributes: + value: > + Thanks for taking the time to fill out this bug report! + This issue template is intended for bug reports where the model evaluation results + (i.e. the generated text) are incorrect or llama.cpp crashes during model evaluation. + If you encountered the issue while using an external UI (e.g. ollama), + please reproduce your issue using one of the examples/binaries in this repository. + The `llama-cli` binary can be used for simple and reproducible model inference. + - type: textarea + id: version + attributes: + label: Name and Version + description: Which version of our software are you running? (use `--version` to get a version string) + placeholder: | + $./llama-cli --version + version: 2999 (42b4109e) + built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu + validations: + required: true + - type: dropdown + id: operating-system + attributes: + label: Operating systems + description: Which operating systems do you know to be affected? + multiple: true + options: + - Linux + - Mac + - Windows + - BSD + - Other? (Please let us know in description) + validations: + required: true + - type: dropdown + id: backends + attributes: + label: GGML backends + description: Which GGML backends do you know to be affected? + options: [AMX, BLAS, CPU, CUDA, HIP, Kompute, Metal, Musa, RPC, SYCL, Vulkan] + multiple: true + validations: + required: true + - type: textarea + id: hardware + attributes: + label: Hardware + description: Which CPUs/GPUs are you using? + placeholder: > + e.g. Ryzen 5950X + 2x RTX 4090 + validations: + required: true + - type: textarea + id: model + attributes: + label: Models + description: > + Which model(s) at which quantization were you using when encountering the bug? + If you downloaded a GGUF file off of Huggingface, please provide a link. + placeholder: > + e.g. Meta LLaMA 3.1 Instruct 8b q4_K_M + validations: + required: false + - type: textarea + id: info + attributes: + label: Problem description & steps to reproduce + description: > + Please give us a summary of the problem and tell us how to reproduce it. + If you can narrow down the bug to specific hardware, compile flags, or command line arguments, + that information would be very much appreciated by us. + placeholder: > + e.g. when I run llama-cli with -ngl 99 I get garbled outputs. + When I use -ngl 0 it works correctly. + Here are the exact commands that I used: ... + validations: + required: true + - type: textarea + id: first_bad_commit + attributes: + label: First Bad Commit + description: > + If the bug was not present on an earlier version: when did it start appearing? + If possible, please do a git bisect and identify the exact commit that introduced the bug. + validations: + required: false + - type: textarea + id: logs + attributes: + label: Relevant log output + description: > + Please copy and paste any relevant log output, including the command that you entered and any generated text. + This will be automatically formatted into code, so no need for backticks. + render: shell + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/019-bug-misc.yml b/.github/ISSUE_TEMPLATE/019-bug-misc.yml new file mode 100644 index 000000000..1904e31fd --- /dev/null +++ b/.github/ISSUE_TEMPLATE/019-bug-misc.yml @@ -0,0 +1,91 @@ +name: Bug (misc.) +description: Something is not working the way it should (and it's not covered by any of the above cases). +title: "Misc. bug: " +labels: ["bug-unconfirmed"] +body: + - type: markdown + attributes: + value: > + Thanks for taking the time to fill out this bug report! + This issue template is intended for miscellaneous bugs that don't fit into any other category. + If you encountered the issue while using an external UI (e.g. ollama), + please reproduce your issue using one of the examples/binaries in this repository. + - type: textarea + id: version + attributes: + label: Name and Version + description: Which version of our software is affected? (You can use `--version` to get a version string.) + placeholder: | + $./llama-cli --version + version: 2999 (42b4109e) + built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu + validations: + required: true + - type: dropdown + id: operating-system + attributes: + label: Operating systems + description: Which operating systems do you know to be affected? + multiple: true + options: + - Linux + - Mac + - Windows + - BSD + - Other? (Please let us know in description) + validations: + required: false + - type: dropdown + id: module + attributes: + label: Which llama.cpp modules do you know to be affected? + multiple: true + options: + - Documentation/Github + - libllama (core library) + - llama-cli + - llama-server + - llama-bench + - llama-quantize + - Python/Bash scripts + - Test code + - Other (Please specify in the next section) + validations: + required: false + - type: textarea + id: command + attributes: + label: Command line + description: > + Please provide the exact commands you entered, if applicable. For example: `llama-server -m ... -c ...`, `llama-cli -m ...`, etc. + This will be automatically formatted into code, so no need for backticks. + render: shell + validations: + required: false + - type: textarea + id: info + attributes: + label: Problem description & steps to reproduce + description: > + Please give us a summary of the problem and tell us how to reproduce it (if applicable). + validations: + required: true + - type: textarea + id: first_bad_commit + attributes: + label: First Bad Commit + description: > + If the bug was not present on an earlier version and it's not trivial to track down: when did it start appearing? + If possible, please do a git bisect and identify the exact commit that introduced the bug. + validations: + required: false + - type: textarea + id: logs + attributes: + label: Relevant log output + description: > + If applicable, please copy and paste any relevant log output, including any generated text. + This will be automatically formatted into code, so no need for backticks. + render: shell + validations: + required: false diff --git a/.github/ISSUE_TEMPLATE/02-bug-medium.yml b/.github/ISSUE_TEMPLATE/02-bug-medium.yml deleted file mode 100644 index a6285c6f0..000000000 --- a/.github/ISSUE_TEMPLATE/02-bug-medium.yml +++ /dev/null @@ -1,50 +0,0 @@ -name: Medium Severity Bug -description: Used to report medium severity bugs in llama.cpp (e.g. Malfunctioning Features but generally still useable) -title: "Bug: " -labels: ["bug-unconfirmed", "medium severity"] -body: - - type: markdown - attributes: - value: | - Thanks for taking the time to fill out this bug report! - Please include information about your system, the steps to reproduce the bug, - and the version of llama.cpp that you are using. - If possible, please provide a minimal code example that reproduces the bug. - - type: textarea - id: what-happened - attributes: - label: What happened? - description: Also tell us, what did you expect to happen? - placeholder: Tell us what you see! - validations: - required: true - - type: textarea - id: version - attributes: - label: Name and Version - description: Which executable and which version of our software are you running? (use `--version` to get a version string) - placeholder: | - $./llama-cli --version - version: 2999 (42b4109e) - built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu - validations: - required: true - - type: dropdown - id: operating-system - attributes: - label: What operating system are you seeing the problem on? - multiple: true - options: - - Linux - - Mac - - Windows - - BSD - - Other? (Please let us know in description) - validations: - required: false - - type: textarea - id: logs - attributes: - label: Relevant log output - description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. - render: shell diff --git a/.github/ISSUE_TEMPLATE/05-enhancement.yml b/.github/ISSUE_TEMPLATE/020-enhancement.yml similarity index 97% rename from .github/ISSUE_TEMPLATE/05-enhancement.yml rename to .github/ISSUE_TEMPLATE/020-enhancement.yml index 58fca7318..02dd4f575 100644 --- a/.github/ISSUE_TEMPLATE/05-enhancement.yml +++ b/.github/ISSUE_TEMPLATE/020-enhancement.yml @@ -1,5 +1,5 @@ name: Enhancement -description: Used to request enhancements for llama.cpp +description: Used to request enhancements for llama.cpp. title: "Feature Request: " labels: ["enhancement"] body: diff --git a/.github/ISSUE_TEMPLATE/03-bug-high.yml b/.github/ISSUE_TEMPLATE/03-bug-high.yml deleted file mode 100644 index ff816b937..000000000 --- a/.github/ISSUE_TEMPLATE/03-bug-high.yml +++ /dev/null @@ -1,50 +0,0 @@ -name: High Severity Bug -description: Used to report high severity bugs in llama.cpp (e.g. Malfunctioning features hindering important common workflow) -title: "Bug: " -labels: ["bug-unconfirmed", "high severity"] -body: - - type: markdown - attributes: - value: | - Thanks for taking the time to fill out this bug report! - Please include information about your system, the steps to reproduce the bug, - and the version of llama.cpp that you are using. - If possible, please provide a minimal code example that reproduces the bug. - - type: textarea - id: what-happened - attributes: - label: What happened? - description: Also tell us, what did you expect to happen? - placeholder: Tell us what you see! - validations: - required: true - - type: textarea - id: version - attributes: - label: Name and Version - description: Which executable and which version of our software are you running? (use `--version` to get a version string) - placeholder: | - $./llama-cli --version - version: 2999 (42b4109e) - built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu - validations: - required: true - - type: dropdown - id: operating-system - attributes: - label: What operating system are you seeing the problem on? - multiple: true - options: - - Linux - - Mac - - Windows - - BSD - - Other? (Please let us know in description) - validations: - required: false - - type: textarea - id: logs - attributes: - label: Relevant log output - description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. - render: shell diff --git a/.github/ISSUE_TEMPLATE/06-research.yml b/.github/ISSUE_TEMPLATE/030-research.yml similarity index 97% rename from .github/ISSUE_TEMPLATE/06-research.yml rename to .github/ISSUE_TEMPLATE/030-research.yml index 3ae4e9f8c..18975dbbf 100644 --- a/.github/ISSUE_TEMPLATE/06-research.yml +++ b/.github/ISSUE_TEMPLATE/030-research.yml @@ -1,5 +1,5 @@ name: Research -description: Track new technical research area +description: Track new technical research area. title: "Research: " labels: ["research 🔬"] body: diff --git a/.github/ISSUE_TEMPLATE/04-bug-critical.yml b/.github/ISSUE_TEMPLATE/04-bug-critical.yml deleted file mode 100644 index 7af42a80b..000000000 --- a/.github/ISSUE_TEMPLATE/04-bug-critical.yml +++ /dev/null @@ -1,50 +0,0 @@ -name: Critical Severity Bug -description: Used to report critical severity bugs in llama.cpp (e.g. Crashing, Corrupted, Dataloss) -title: "Bug: " -labels: ["bug-unconfirmed", "critical severity"] -body: - - type: markdown - attributes: - value: | - Thanks for taking the time to fill out this bug report! - Please include information about your system, the steps to reproduce the bug, - and the version of llama.cpp that you are using. - If possible, please provide a minimal code example that reproduces the bug. - - type: textarea - id: what-happened - attributes: - label: What happened? - description: Also tell us, what did you expect to happen? - placeholder: Tell us what you see! - validations: - required: true - - type: textarea - id: version - attributes: - label: Name and Version - description: Which executable and which version of our software are you running? (use `--version` to get a version string) - placeholder: | - $./llama-cli --version - version: 2999 (42b4109e) - built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu - validations: - required: true - - type: dropdown - id: operating-system - attributes: - label: What operating system are you seeing the problem on? - multiple: true - options: - - Linux - - Mac - - Windows - - BSD - - Other? (Please let us know in description) - validations: - required: false - - type: textarea - id: logs - attributes: - label: Relevant log output - description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. - render: shell diff --git a/.github/ISSUE_TEMPLATE/07-refactor.yml b/.github/ISSUE_TEMPLATE/040-refactor.yml similarity index 95% rename from .github/ISSUE_TEMPLATE/07-refactor.yml rename to .github/ISSUE_TEMPLATE/040-refactor.yml index 3a68d3d53..b6e6ab36d 100644 --- a/.github/ISSUE_TEMPLATE/07-refactor.yml +++ b/.github/ISSUE_TEMPLATE/040-refactor.yml @@ -1,5 +1,5 @@ name: Refactor (Maintainers) -description: Used to track refactoring opportunities +description: Used to track refactoring opportunities. title: "Refactor: " labels: ["refactor"] body: diff --git a/.github/labeler.yml b/.github/labeler.yml index 89436740d..1b47bc968 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -3,19 +3,18 @@ Kompute: - changed-files: - any-glob-to-any-file: - ggml/include/ggml-kompute.h - - ggml/src/ggml-kompute.cpp + - ggml/src/ggml-kompute/** - README-kompute.md Apple Metal: - changed-files: - any-glob-to-any-file: - ggml/include/ggml-metal.h - - ggml/src/ggml-metal.cpp + - ggml/src/ggml-metal/** - README-metal.md SYCL: - changed-files: - any-glob-to-any-file: - ggml/include/ggml-sycl.h - - ggml/src/ggml-sycl.cpp - ggml/src/ggml-sycl/** - docs/backend/SYCL.md - examples/sycl/** @@ -27,8 +26,8 @@ Nvidia GPU: Vulkan: - changed-files: - any-glob-to-any-file: - - ggml/ggml_vk_generate_shaders.py - - ggml/src/ggml-vulkan* + - ggml/include/ggml-vulkan.h + - ggml/src/ggml-vulkan/** documentation: - changed-files: - any-glob-to-any-file: @@ -75,11 +74,7 @@ server: ggml: - changed-files: - any-glob-to-any-file: - - ggml/include/ggml*.h - - ggml/src/ggml*.c - - ggml/src/ggml*.cpp - - ggml/src/ggml*.h - - ggml-cuda/** + - ggml/** nix: - changed-files: - any-glob-to-any-file: diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 997c6d9d0..d9f5bdc23 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,7 +1 @@ - - -- [x] I have read the [contributing guidelines](https://github.com/ggerganov/llama.cpp/blob/master/CONTRIBUTING.md) -- Self-reported review complexity: - - [ ] Low - - [ ] Medium - - [ ] High +*Make sure to read the [contributing guidelines](https://github.com/ggerganov/llama.cpp/blob/master/CONTRIBUTING.md) before submitting a PR* diff --git a/.github/workflows/bench.yml.disabled b/.github/workflows/bench.yml.disabled index bfdbb4ef5..1c8787ef7 100644 --- a/.github/workflows/bench.yml.disabled +++ b/.github/workflows/bench.yml.disabled @@ -27,10 +27,10 @@ on: push: branches: - master - paths: ['llama.cpp', 'ggml.c', 'ggml-backend.c', 'ggml-quants.c', '**/*.cu', 'examples/server/*.h*', 'examples/server/*.cpp'] + paths: ['llama.cpp', 'ggml.c', 'ggml-backend.cpp', 'ggml-quants.c', '**/*.cu', 'examples/server/*.h*', 'examples/server/*.cpp'] pull_request_target: types: [opened, synchronize, reopened] - paths: ['llama.cpp', 'ggml.c', 'ggml-backend.c', 'ggml-quants.c', '**/*.cu', 'examples/server/*.h*', 'examples/server/*.cpp'] + paths: ['llama.cpp', 'ggml.c', 'ggml-backend.cpp', 'ggml-quants.c', '**/*.cu', 'examples/server/*.h*', 'examples/server/*.cpp'] schedule: - cron: '04 2 * * *' diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index c36eaadfb..c02dd6a81 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,10 +19,18 @@ concurrency: group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} cancel-in-progress: true +# Fine-grant permission +# https://docs.github.com/en/actions/security-for-github-actions/security-guides/automatic-token-authentication#modifying-the-permissions-for-the-github_token +permissions: + contents: write # for creating release + env: BRANCH_NAME: ${{ github.head_ref || github.ref_name }} GGML_NLOOP: 3 GGML_N_THREADS: 1 + LLAMA_LOG_COLORS: 1 + LLAMA_LOG_PREFIX: 1 + LLAMA_LOG_TIMESTAMPS: 1 jobs: macOS-latest-cmake-arm64: @@ -35,6 +43,12 @@ jobs: with: fetch-depth: 0 + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: macOS-latest-cmake-arm64 + evict-old-files: 1d + - name: Dependencies id: depends continue-on-error: true @@ -47,7 +61,13 @@ jobs: sysctl -a mkdir build cd build - cmake -DLLAMA_FATAL_WARNINGS=ON -DGGML_METAL_EMBED_LIBRARY=ON -DLLAMA_CURL=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=OFF .. + cmake .. \ + -DCMAKE_BUILD_RPATH="@loader_path" \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DLLAMA_CURL=ON \ + -DGGML_METAL_USE_BF16=ON \ + -DGGML_METAL_EMBED_LIBRARY=ON \ + -DGGML_RPC=ON cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) - name: Test @@ -74,6 +94,7 @@ jobs: if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} run: | cp LICENSE ./build/bin/ + cp examples/run/linenoise.cpp/LICENSE ./build/bin/LICENSE.linenoise.cpp zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip ./build/bin/* - name: Upload artifacts @@ -84,7 +105,7 @@ jobs: name: llama-bin-macos-arm64.zip macOS-latest-cmake-x64: - runs-on: macos-12 + runs-on: macos-13 steps: - name: Clone @@ -93,6 +114,12 @@ jobs: with: fetch-depth: 0 + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: macOS-latest-cmake-x64 + evict-old-files: 1d + - name: Dependencies id: depends continue-on-error: true @@ -105,7 +132,12 @@ jobs: sysctl -a # Metal is disabled due to intermittent failures with Github runners not having a GPU: # https://github.com/ggerganov/llama.cpp/actions/runs/8635935781/job/23674807267#step:5:2313 - cmake -B build -DLLAMA_FATAL_WARNINGS=ON -DGGML_METAL=OFF -DLLAMA_CURL=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=OFF + cmake -B build \ + -DCMAKE_BUILD_RPATH="@loader_path" \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DLLAMA_CURL=ON \ + -DGGML_METAL=OFF \ + -DGGML_RPC=ON cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) - name: Test @@ -132,6 +164,7 @@ jobs: if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} run: | cp LICENSE ./build/bin/ + cp examples/run/linenoise.cpp/LICENSE ./build/bin/LICENSE.linenoise.cpp zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip ./build/bin/* - name: Upload artifacts @@ -141,68 +174,8 @@ jobs: path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip name: llama-bin-macos-x64.zip - ubuntu-focal-make: - runs-on: ubuntu-20.04 - env: - LLAMA_NODE_AVAILABLE: true - LLAMA_PYTHON_AVAILABLE: true - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential gcc-8 - - - uses: actions/setup-node@v4 - with: - node-version: "20" - - - uses: actions/setup-python@v5 - with: - python-version: "3.11" - - - name: Build - id: make_build - env: - LLAMA_FATAL_WARNINGS: 1 - run: | - CC=gcc-8 make -j $(nproc) - - - name: Test - id: make_test - run: | - CC=gcc-8 make tests -j $(nproc) - make test -j $(nproc) - - ubuntu-focal-make-curl: - runs-on: ubuntu-20.04 - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential gcc-8 libcurl4-openssl-dev - - - name: Build - id: make_build - env: - LLAMA_FATAL_WARNINGS: 1 - LLAMA_CURL: 1 - run: | - CC=gcc-8 make -j $(nproc) - - ubuntu-latest-cmake: - runs-on: ubuntu-latest + ubuntu-cpu-cmake: + runs-on: ubuntu-22.04 steps: - name: Clone @@ -211,6 +184,12 @@ jobs: with: fetch-depth: 0 + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-cpu-cmake + evict-old-files: 1d + - name: Dependencies id: depends run: | @@ -222,7 +201,10 @@ jobs: run: | mkdir build cd build - cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=OFF + cmake .. \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DLLAMA_CURL=ON \ + -DGGML_RPC=ON cmake --build . --config Release -j $(nproc) - name: Test @@ -260,6 +242,7 @@ jobs: if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} run: | cp LICENSE ./build/bin/ + cp examples/run/linenoise.cpp/LICENSE ./build/bin/LICENSE.linenoise.cpp zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.zip ./build/bin/* - name: Upload artifacts @@ -277,13 +260,19 @@ jobs: strategy: matrix: sanitizer: [ADDRESS, THREAD, UNDEFINED] - build_type: [Debug, Release] + build_type: [Debug] steps: - name: Clone id: checkout uses: actions/checkout@v4 + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-latest-cmake-sanitizer-${{ matrix.sanitizer }} + evict-old-files: 1d + - name: Dependencies id: depends run: | @@ -296,7 +285,10 @@ jobs: run: | mkdir build cd build - cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} + cmake .. \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} cmake --build . --config ${{ matrix.build_type }} -j $(nproc) - name: Build (no OpenMP) @@ -305,7 +297,11 @@ jobs: run: | mkdir build cd build - cmake .. -DLLAMA_FATAL_WARNINGS=ON -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} -DGGML_OPENMP=OFF + cmake .. \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ + -DGGML_OPENMP=OFF cmake --build . --config ${{ matrix.build_type }} -j $(nproc) - name: Test @@ -324,6 +320,12 @@ jobs: id: checkout uses: actions/checkout@v4 + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-latest-cmake-rpc + evict-old-files: 1d + - name: Dependencies id: depends run: | @@ -335,7 +337,8 @@ jobs: run: | mkdir build cd build - cmake -DGGML_RPC=ON .. + cmake .. \ + -DGGML_RPC=ON cmake --build . --config Release -j $(nproc) - name: Test @@ -352,22 +355,36 @@ jobs: id: checkout uses: actions/checkout@v4 + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-22-cmake-vulkan + evict-old-files: 1d + - name: Dependencies id: depends run: | wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo apt-key add - sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list sudo apt-get update -y - sudo apt-get install -y build-essential vulkan-sdk + sudo apt-get install -y build-essential mesa-vulkan-drivers vulkan-sdk - name: Build id: cmake_build run: | mkdir build cd build - cmake -DGGML_VULKAN=ON .. + cmake .. \ + -DGGML_VULKAN=ON cmake --build . --config Release -j $(nproc) + - name: Test + id: cmake_test + run: | + cd build + # This is using llvmpipe and runs slower than other backends + ctest -L main --verbose --timeout 1800 + ubuntu-22-cmake-hip: runs-on: ubuntu-22.04 container: rocm/dev-ubuntu-22.04:6.0.2 @@ -375,7 +392,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Dependencies id: depends @@ -383,25 +400,64 @@ jobs: sudo apt-get update sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-22-cmake-hip + evict-old-files: 1d + - name: Build with native CMake HIP support id: cmake_build run: | - cmake -B build -S . -DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" -DGGML_HIPBLAS=ON + cmake -B build -S . \ + -DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \ + -DGGML_HIP=ON cmake --build build --config Release -j $(nproc) - name: Build with legacy HIP support id: cmake_build_legacy_hip run: | - cmake -B build2 -S . -DCMAKE_C_COMPILER=hipcc -DCMAKE_CXX_COMPILER=hipcc -DGGML_HIPBLAS=ON + cmake -B build2 -S . \ + -DCMAKE_C_COMPILER=hipcc \ + -DCMAKE_CXX_COMPILER=hipcc \ + -DGGML_HIP=ON cmake --build build2 --config Release -j $(nproc) + ubuntu-22-cmake-musa: + runs-on: ubuntu-22.04 + container: mthreads/musa:rc3.1.0-devel-ubuntu22.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Dependencies + id: depends + run: | + apt-get update + apt-get install -y build-essential git cmake libcurl4-openssl-dev + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-22-cmake-musa + evict-old-files: 1d + + - name: Build with native CMake MUSA support + id: cmake_build + run: | + cmake -B build -S . \ + -DGGML_MUSA=ON + cmake --build build --config Release -j $(nproc) + ubuntu-22-cmake-sycl: runs-on: ubuntu-22.04 continue-on-error: true steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: add oneAPI to apt shell: bash @@ -427,13 +483,22 @@ jobs: id: checkout uses: actions/checkout@v4 + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-22-cmake-sycl + evict-old-files: 1d + - name: Build id: cmake_build run: | source /opt/intel/oneapi/setvars.sh mkdir build cd build - cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx .. + cmake .. \ + -DGGML_SYCL=ON \ + -DCMAKE_C_COMPILER=icx \ + -DCMAKE_CXX_COMPILER=icpx cmake --build . --config Release -j $(nproc) ubuntu-22-cmake-sycl-fp16: @@ -442,7 +507,7 @@ jobs: continue-on-error: true steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: add oneAPI to apt shell: bash @@ -468,85 +533,38 @@ jobs: id: checkout uses: actions/checkout@v4 + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-22-cmake-sycl-fp16 + evict-old-files: 1d + - name: Build id: cmake_build run: | source /opt/intel/oneapi/setvars.sh mkdir build cd build - cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON .. + cmake .. \ + -DGGML_SYCL=ON \ + -DCMAKE_C_COMPILER=icx \ + -DCMAKE_CXX_COMPILER=icpx \ + -DGGML_SYCL_F16=ON cmake --build . --config Release -j $(nproc) - # TODO: build with GGML_NO_METAL because test-backend-ops fail on "Apple Paravirtual device" and I don't know - # how to debug it. - # ref: https://github.com/ggerganov/llama.cpp/actions/runs/7131777249/job/19420981052#step:5:1124 - macOS-latest-make: - runs-on: macos-latest - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - - - name: Dependencies - id: depends - continue-on-error: true - run: | - brew update - - - name: Build - id: make_build - env: - LLAMA_FATAL_WARNINGS: 1 - run: | - GGML_NO_METAL=1 make -j $(sysctl -n hw.logicalcpu) - - - name: Test - id: make_test - run: | - GGML_NO_METAL=1 make tests -j $(sysctl -n hw.logicalcpu) - GGML_NO_METAL=1 make test -j $(sysctl -n hw.logicalcpu) - - # TODO: build with GGML_METAL=OFF because test-backend-ops fail on "Apple Paravirtual device" and I don't know - # how to debug it. - # ref: https://github.com/ggerganov/llama.cpp/actions/runs/7132125951/job/19422043567?pr=4359#step:5:6584 - # would be great if we fix these - macOS-latest-cmake: - runs-on: macos-latest - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - - - name: Dependencies - id: depends - continue-on-error: true - run: | - brew update - - - name: Build - id: cmake_build - run: | - sysctl -a - mkdir build - cd build - cmake -DLLAMA_FATAL_WARNINGS=ON -DGGML_METAL=OFF .. - cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) - - - name: Test - id: cmake_test - run: | - cd build - ctest -L main --verbose --timeout 900 - macOS-latest-cmake-ios: runs-on: macos-latest steps: - name: Clone id: checkout - uses: actions/checkout@v1 + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: macOS-latest-cmake-ios + evict-old-files: 1d - name: Dependencies id: depends @@ -561,6 +579,7 @@ jobs: mkdir build cd build cmake -G Xcode .. \ + -DGGML_METAL_USE_BF16=ON \ -DGGML_METAL_EMBED_LIBRARY=ON \ -DLLAMA_BUILD_EXAMPLES=OFF \ -DLLAMA_BUILD_TESTS=OFF \ @@ -576,7 +595,13 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v1 + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: macOS-latest-cmake-tvos + evict-old-files: 1d - name: Dependencies id: depends @@ -591,6 +616,7 @@ jobs: mkdir build cd build cmake -G Xcode .. \ + -DGGML_METAL_USE_BF16=ON \ -DGGML_METAL_EMBED_LIBRARY=ON \ -DLLAMA_BUILD_EXAMPLES=OFF \ -DLLAMA_BUILD_TESTS=OFF \ @@ -610,7 +636,13 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v1 + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: macOS-latest-swift + evict-old-files: 1d - name: Dependencies id: depends @@ -618,15 +650,26 @@ jobs: run: | brew update + - name: Build llama.cpp with CMake + id: cmake_build + run: | + sysctl -a + mkdir build + cd build + cmake -G Xcode .. \ + -DGGML_METAL_USE_BF16=ON \ + -DGGML_METAL_EMBED_LIBRARY=ON \ + -DLLAMA_BUILD_EXAMPLES=OFF \ + -DLLAMA_BUILD_TESTS=OFF \ + -DLLAMA_BUILD_SERVER=OFF \ + -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" + cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) + sudo cmake --install . --config Release + - name: xcodebuild for swift package id: xcodebuild run: | - xcodebuild -scheme llama -destination "${{ matrix.destination }}" - - - name: Build Swift Example - id: make_build_swift_example - run: | - make swift + xcodebuild -scheme llama-Package -destination "${{ matrix.destination }}" windows-msys2: runs-on: windows-latest @@ -642,6 +685,12 @@ jobs: - name: Clone uses: actions/checkout@v4 + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: windows-msys2 + evict-old-files: 1d + - name: Setup ${{ matrix.sys }} uses: msys2/setup-msys2@v2 with: @@ -649,25 +698,11 @@ jobs: msystem: ${{matrix.sys}} install: >- base-devel + git mingw-w64-${{matrix.env}}-toolchain mingw-w64-${{matrix.env}}-cmake mingw-w64-${{matrix.env}}-openblas - - name: Build using make - shell: msys2 {0} - run: | - make -j $(nproc) - - - name: Clean after building using make - shell: msys2 {0} - run: | - make clean - - - name: Build using make w/ OpenBLAS - shell: msys2 {0} - run: | - make GGML_OPENBLAS=1 -j $(nproc) - - name: Build using CMake shell: msys2 {0} run: | @@ -686,7 +721,7 @@ jobs: cmake --build build --config ${{ matrix.build }} -j $(nproc) windows-latest-cmake: - runs-on: windows-2019 + runs-on: windows-latest env: OPENBLAS_VERSION: 0.3.23 @@ -697,23 +732,25 @@ jobs: matrix: include: - build: 'noavx-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_AVX=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF -DBUILD_SHARED_LIBS=ON' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_AVX=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF' - build: 'avx2-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=ON' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON' - build: 'avx-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_AVX2=OFF -DBUILD_SHARED_LIBS=ON' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_AVX2=OFF' - build: 'avx512-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_AVX512=ON -DBUILD_SHARED_LIBS=ON' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_AVX512=ON' - build: 'openblas-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BLAS=ON -DBUILD_SHARED_LIBS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DBLAS_INCLUDE_DIRS="$env:RUNNER_TEMP/openblas/include" -DBLAS_LIBRARIES="$env:RUNNER_TEMP/openblas/lib/openblas.lib"' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DBLAS_INCLUDE_DIRS="$env:RUNNER_TEMP/openblas/include" -DBLAS_LIBRARIES="$env:RUNNER_TEMP/openblas/lib/openblas.lib"' - build: 'kompute-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_KOMPUTE=ON -DKOMPUTE_OPT_DISABLE_VULKAN_VERSION_CHECK=ON -DBUILD_SHARED_LIBS=ON' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_KOMPUTE=ON -DKOMPUTE_OPT_DISABLE_VULKAN_VERSION_CHECK=ON' - build: 'vulkan-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_VULKAN=ON -DBUILD_SHARED_LIBS=ON' + defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_VULKAN=ON' - build: 'llvm-arm64' - defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DBUILD_SHARED_LIBS=ON' + defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON' - build: 'msvc-arm64' - defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-msvc.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DBUILD_SHARED_LIBS=ON' + defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-msvc.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON' + - build: 'llvm-arm64-opencl-adreno' + defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" -DGGML_OPENCL=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON' steps: - name: Clone @@ -722,11 +759,17 @@ jobs: with: fetch-depth: 0 + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: windows-latest-cmake-${{ matrix.build }} + evict-old-files: 1d + - name: Clone Kompute submodule id: clone_kompute if: ${{ matrix.build == 'kompute-x64' }} run: | - git submodule update --init ggml/src/kompute + git submodule update --init ggml/src/ggml-kompute/kompute - name: Download OpenBLAS id: get_openblas @@ -755,6 +798,28 @@ jobs: run: | choco install ninja + - name: Install OpenCL Headers and Libs + id: install_opencl + if: ${{ matrix.build == 'llvm-arm64-opencl-adreno' }} + run: | + git clone https://github.com/KhronosGroup/OpenCL-Headers + cd OpenCL-Headers + mkdir build && cd build + cmake .. ` + -DBUILD_TESTING=OFF ` + -DOPENCL_HEADERS_BUILD_TESTING=OFF ` + -DOPENCL_HEADERS_BUILD_CXX_TESTS=OFF ` + -DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release" + cmake --build . --target install + git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader + cd OpenCL-ICD-Loader + mkdir build-arm64-release && cd build-arm64-release + cmake .. ` + -A arm64 ` + -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" ` + -DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release" + cmake --build . --target install --config release + - name: Build id: cmake_build run: | @@ -784,7 +849,7 @@ jobs: - name: Test id: cmake_test # not all machines have native AVX-512 - if: ${{ matrix.build != 'msvc-arm64' && matrix.build != 'llvm-arm64' && matrix.build != 'kompute-x64' && matrix.build != 'vulkan-x64' && (matrix.build != 'avx512-x64' || env.HAS_AVX512F == '1') }} + if: ${{ matrix.build != 'msvc-arm64' && matrix.build != 'llvm-arm64' && matrix.build != 'llvm-arm64-opencl-adreno' && matrix.build != 'kompute-x64' && matrix.build != 'vulkan-x64' && (matrix.build != 'avx512-x64' || env.HAS_AVX512F == '1') }} run: | cd build ctest -L main -C Release --verbose --timeout 900 @@ -820,6 +885,7 @@ jobs: if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} run: | Copy-Item LICENSE .\build\bin\Release\llama.cpp.txt + Copy-Item .\examples\run\linenoise.cpp\LICENSE .\build\bin\Release\linenoise.cpp.txt 7z a llama-${{ steps.tag.outputs.name }}-bin-win-${{ matrix.build }}.zip .\build\bin\Release\* - name: Upload artifacts @@ -829,12 +895,47 @@ jobs: path: llama-${{ steps.tag.outputs.name }}-bin-win-${{ matrix.build }}.zip name: llama-bin-win-${{ matrix.build }}.zip - windows-latest-cmake-cuda: + ubuntu-latest-cmake-cuda: + runs-on: ubuntu-latest + container: nvidia/cuda:12.6.2-devel-ubuntu24.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Install dependencies + env: + DEBIAN_FRONTEND: noninteractive + run: | + apt update + apt install -y cmake build-essential ninja-build libgomp1 git + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-latest-cmake-cuda + evict-old-files: 1d + + - name: Build with CMake + run: | + cmake -S . -B build -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_CUDA_ARCHITECTURES=89-real \ + -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DGGML_NATIVE=OFF \ + -DGGML_CUDA=ON + cmake --build build + + windows-2019-cmake-cuda: runs-on: windows-2019 strategy: matrix: - cuda: ['12.2.0', '11.7.1'] + cuda: ['12.4', '11.7'] build: ['cuda'] steps: @@ -842,24 +943,88 @@ jobs: id: checkout uses: actions/checkout@v4 with: - fetch-depth: 0 + fetch-depth: 0 - - name: Install CUDA toolkit - id: cuda-toolkit - uses: Jimver/cuda-toolkit@v0.2.15 + - name: Install ccache + uses: hendrikmuhs/ccache-action@v1.2.16 with: - cuda: ${{ matrix.cuda }} - method: 'network' - sub-packages: '["nvcc", "cudart", "cublas", "cublas_dev", "thrust", "visual_studio_integration"]' + key: ${{ github.job }}-${{ matrix.cuda }}-${{ matrix.build }} + evict-old-files: 1d + + - name: Install Cuda Toolkit 11.7 + if: ${{ matrix.cuda == '11.7' }} + run: | + mkdir -p "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" + choco install unzip -y + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-11.7.99-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-11.7.99-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-11.7.99-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-11.7.4.6-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-11.7.91-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-11.7.91-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-11.7.101-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-11.7.91-archive.zip" + unzip '*.zip' -d "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_cudart-windows-x86_64-11.7.99-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_nvcc-windows-x86_64-11.7.99-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_nvrtc-windows-x86_64-11.7.99-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\libcublas-windows-x86_64-11.7.4.6-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_nvtx-windows-x86_64-11.7.91-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\visual_studio_integration-windows-x86_64-11.7.91-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_nvprof-windows-x86_64-11.7.101-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_cccl-windows-x86_64-11.7.91-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + echo "CUDA_PATH_V11_7=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + + - name: Install Cuda Toolkit 12.4 + if: ${{ matrix.cuda == '12.4' }} + run: | + mkdir -p "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" + choco install unzip -y + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-12.4.127-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-12.4.131-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-12.4.127-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-12.4.5.8-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-12.4.127-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-12.4.127-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-12.4.127-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-12.4.127-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-12.4.127-archive.zip" + unzip '*.zip' -d "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_cudart-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvcc-windows-x86_64-12.4.131-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvrtc-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\libcublas-windows-x86_64-12.4.5.8-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvtx-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_profiler_api-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\visual_studio_integration-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvprof-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_cccl-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + echo "CUDA_PATH_V12_4=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + + - name: Install Ninja + id: install_ninja + run: | + choco install ninja - name: Build id: cmake_build + shell: cmd run: | - mkdir build - cd build - cmake .. -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_CUDA=ON -DBUILD_SHARED_LIBS=ON -DGGML_RPC=ON - cmake --build . --config Release -j $((${env:NUMBER_OF_PROCESSORS} - 1)) -t ggml - cmake --build . --config Release -j ${env:NUMBER_OF_PROCESSORS} + call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat" + cmake -S . -B build -G "Ninja Multi-Config" ^ + -DLLAMA_BUILD_SERVER=ON ^ + -DGGML_NATIVE=OFF ^ + -DGGML_CUDA=ON ^ + -DGGML_RPC=ON + set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 + cmake --build build --config Release -j %NINJA_JOBS% -t ggml + cmake --build build --config Release - name: Determine tag name id: tag @@ -888,10 +1053,12 @@ jobs: name: llama-bin-win-cu${{ matrix.cuda }}-x64.zip - name: Copy and pack Cuda runtime + if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} run: | - echo "Cuda install location: ${{steps.cuda-toolkit.outputs.CUDA_PATH}}" + echo "Cuda install location: ${{ env.CUDA_PATH }}" $dst='.\build\bin\cudart\' - robocopy "${{steps.cuda-toolkit.outputs.CUDA_PATH}}\bin" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll + robocopy "${{env.CUDA_PATH}}\bin" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll + robocopy "${{env.CUDA_PATH}}\lib" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll 7z a cudart-llama-bin-win-cu${{ matrix.cuda }}-x64.zip $dst\* - name: Upload Cuda runtime @@ -909,8 +1076,8 @@ jobs: shell: bash env: - WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/7dff44ba-e3af-4448-841c-0d616c8da6e7/w_BaseKit_p_2024.1.0.595_offline.exe - WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel + WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/b380d914-366b-4b77-a74a-05e3c38b3514/intel-oneapi-base-toolkit-2025.0.0.882_offline.exe + WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI" steps: - name: Clone @@ -919,8 +1086,15 @@ jobs: with: fetch-depth: 0 + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: windows-latest-cmake-sycl + evict-old-files: 1d + - name: Install - run: scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL + run: | + scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL - name: Build id: cmake_build @@ -939,24 +1113,33 @@ jobs: echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT fi - - name: Pack artifacts + - name: Build the release package id: pack_artifacts if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} run: | echo "cp oneAPI running time dll files in ${{ env.ONEAPI_ROOT }} to ./build/bin" - cp "${{ env.ONEAPI_ROOT }}/mkl/latest/bin/mkl_sycl_blas.4.dll" ./build/bin + + cp "${{ env.ONEAPI_ROOT }}/mkl/latest/bin/mkl_sycl_blas.5.dll" ./build/bin cp "${{ env.ONEAPI_ROOT }}/mkl/latest/bin/mkl_core.2.dll" ./build/bin cp "${{ env.ONEAPI_ROOT }}/mkl/latest/bin/mkl_tbb_thread.2.dll" ./build/bin - cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/pi_win_proxy_loader.dll" ./build/bin - cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/pi_level_zero.dll" ./build/bin - cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/sycl7.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_adapter_level_zero.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_adapter_opencl.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_loader.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_win_proxy_loader.dll" ./build/bin + + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/sycl8.dll" ./build/bin cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/svml_dispmd.dll" ./build/bin cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libmmd.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libiomp5md.dll" ./build/bin + + cp "${{ env.ONEAPI_ROOT }}/dnnl/latest/bin/dnnl.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/tbb/latest/bin/tbb12.dll" ./build/bin + echo "cp oneAPI running time dll files to ./build/bin done" 7z a llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip ./build/bin/* - - name: Upload artifacts + - name: Upload the release package if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} uses: actions/upload-artifact@v4 with: @@ -964,19 +1147,75 @@ jobs: name: llama-bin-win-sycl-x64.zip windows-latest-cmake-hip: + if: ${{ github.event.inputs.create_release != 'true' }} runs-on: windows-latest steps: - name: Clone id: checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install id: depends run: | $ErrorActionPreference = "Stop" write-host "Downloading AMD HIP SDK Installer" - Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" + Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" + write-host "Installing AMD HIP SDK" + Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait + write-host "Completed AMD HIP SDK installation" + + - name: Verify ROCm + id: verify + run: | + & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version + + - name: Install ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ${{ github.job }} + evict-old-files: 1d + + - name: Build + id: cmake_build + run: | + $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path) + $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}" + cmake -G "Unix Makefiles" -B build -S . ` + -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" ` + -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" ` + -DCMAKE_BUILD_TYPE=Release ` + -DGGML_HIP=ON ` + -DGGML_RPC=ON + cmake --build build -j ${env:NUMBER_OF_PROCESSORS} + + windows-latest-cmake-hip-release: + if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} + runs-on: windows-latest + + strategy: + matrix: + gpu_target: [gfx1100, gfx1101, gfx1030] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: windows-latest-cmake-hip-release + evict-old-files: 1d + + - name: Install + id: depends + run: | + $ErrorActionPreference = "Stop" + write-host "Downloading AMD HIP SDK Installer" + Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" write-host "Installing AMD HIP SDK" Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait write-host "Completed AMD HIP SDK installation" @@ -991,8 +1230,42 @@ jobs: run: | $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path) $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}" - cmake -G "Unix Makefiles" -B build -S . -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" -DGGML_HIPBLAS=ON - cmake --build build --config Release + cmake -G "Unix Makefiles" -B build -S . ` + -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" ` + -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" ` + -DCMAKE_BUILD_TYPE=Release ` + -DAMDGPU_TARGETS=${{ matrix.gpu_target }} ` + -DGGML_HIP=ON ` + -DGGML_RPC=ON + cmake --build build -j ${env:NUMBER_OF_PROCESSORS} + md "build\bin\rocblas\library\" + cp "${env:HIP_PATH}\bin\hipblas.dll" "build\bin\" + cp "${env:HIP_PATH}\bin\rocblas.dll" "build\bin\" + cp "${env:HIP_PATH}\bin\rocblas\library\*" "build\bin\rocblas\library\" + + - name: Determine tag name + id: tag + shell: bash + run: | + BUILD_NUMBER="$(git rev-list --count HEAD)" + SHORT_HASH="$(git rev-parse --short=7 HEAD)" + if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then + echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT + else + SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') + echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT + fi + + - name: Pack artifacts + id: pack_artifacts + run: | + 7z a llama-${{ steps.tag.outputs.name }}-bin-win-hip-x64-${{ matrix.gpu_target }}.zip .\build\bin\* + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-win-hip-x64-${{ matrix.gpu_target }}.zip + name: llama-bin-win-hip-x64-${{ matrix.gpu_target }}.zip ios-xcode-build: runs-on: macos-latest @@ -1001,6 +1274,29 @@ jobs: - name: Checkout code uses: actions/checkout@v4 + - name: Build + id: cmake_build + run: | + sysctl -a + mkdir build + cd build + cmake -G Xcode .. \ + -DGGML_METAL_USE_BF16=ON \ + -DGGML_METAL_EMBED_LIBRARY=ON \ + -DLLAMA_BUILD_EXAMPLES=OFF \ + -DLLAMA_BUILD_TESTS=OFF \ + -DLLAMA_BUILD_SERVER=OFF \ + -DCMAKE_SYSTEM_NAME=iOS \ + -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ + -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml + cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO + sudo cmake --install . --config Release + + - name: xcodebuild for swift package + id: xcodebuild + run: | + xcodebuild -scheme llama-Package -destination 'generic/platform=iOS' + - name: Build Xcode project run: xcodebuild -project examples/llama.swiftui/llama.swiftui.xcodeproj -scheme llama.swiftui -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' build @@ -1011,6 +1307,12 @@ jobs: - name: Clone uses: actions/checkout@v4 + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: android-build + evict-old-files: 1d + - name: Set up JDK uses: actions/setup-java@v3 with: @@ -1028,35 +1330,16 @@ jobs: ./gradlew build --no-daemon -# freeBSD-latest: -# runs-on: macos-12 -# steps: -# - name: Clone -# uses: actions/checkout@v4 -# -# - name: Build -# uses: cross-platform-actions/action@v0.19.0 -# with: -# operating_system: freebsd -# version: '13.2' -# hypervisor: 'qemu' -# run: | -# sudo pkg update -# sudo pkg install -y gmake automake autoconf pkgconf llvm15 openblas -# gmake CC=/usr/local/bin/clang15 CXX=/usr/local/bin/clang++15 -j `sysctl -n hw.ncpu` - release: if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} runs-on: ubuntu-latest needs: - - ubuntu-focal-make - - ubuntu-latest-cmake - - macOS-latest-make - - macOS-latest-cmake + - ubuntu-cpu-cmake - windows-latest-cmake - - windows-latest-cmake-cuda + - windows-2019-cmake-cuda + - windows-latest-cmake-hip-release - macOS-latest-cmake-arm64 - macOS-latest-cmake-x64 @@ -1067,6 +1350,12 @@ jobs: with: fetch-depth: 0 + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: release + evict-old-files: 1d + - name: Determine tag name id: tag shell: bash @@ -1092,7 +1381,7 @@ jobs: - name: Create release id: create_release - uses: anzz1/action-create-release@v1 + uses: ggml-org/action-create-release@v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: @@ -1312,3 +1601,37 @@ jobs: # popd # emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }} # make + + openEuler-latest-cmake-cann: + if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }} + defaults: + run: + shell: bash -el {0} + runs-on: ubuntu-24.04-arm + strategy: + matrix: + cann: + - '8.0.rc3.beta1-910b-openeuler22.03-py3.10' + device: + - 'ascend910b3' + build: + - 'Release' + container: ascendai/cann:${{ matrix.cann }} + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Dependencies + run: | + yum update -y + yum install -y git gcc gcc-c++ make cmake + + - name: Build + run: | + export LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/$(uname -m)-linux/devlib/:${LD_LIBRARY_PATH} + + cmake -S . -B build \ + -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DGGML_CANN=on \ + -DSOC_TYPE=${{ matrix.device }} + cmake --build build -j $(nproc) diff --git a/.github/workflows/close-issue.yml b/.github/workflows/close-issue.yml index 69c9f4f69..f63860d14 100644 --- a/.github/workflows/close-issue.yml +++ b/.github/workflows/close-issue.yml @@ -3,6 +3,11 @@ on: schedule: - cron: "42 0 * * *" +# Fine-grant permission +# https://docs.github.com/en/actions/security-for-github-actions/security-guides/automatic-token-authentication#modifying-the-permissions-for-the-github_token +permissions: + issues: write + jobs: close-issues: runs-on: ubuntu-latest diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 9044cd78b..6955a7dc8 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -10,48 +10,50 @@ name: Publish Docker image on: - #pull_request: - push: - branches: - - master - paths: ['.github/workflows/docker.yml', '.devops/*.Dockerfile', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal'] + workflow_dispatch: # allows manual triggering + schedule: + # Rebuild daily rather than on every push because it is expensive + - cron: '12 4 * * *' concurrency: group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} cancel-in-progress: true +# Fine-grant permission +# https://docs.github.com/en/actions/security-for-github-actions/security-guides/automatic-token-authentication#modifying-the-permissions-for-the-github_token +permissions: + packages: write + jobs: push_to_registry: name: Push Docker image to Docker Hub - #if: github.event.pull_request.draft == false - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 env: COMMIT_SHA: ${{ github.sha }} strategy: + fail-fast: false matrix: config: - - { tag: "light", dockerfile: ".devops/llama-cli.Dockerfile", platforms: "linux/amd64,linux/arm64" } - - { tag: "server", dockerfile: ".devops/llama-server.Dockerfile", platforms: "linux/amd64,linux/arm64" } - - { tag: "full", dockerfile: ".devops/full.Dockerfile", platforms: "linux/amd64,linux/arm64" } - - { tag: "light-cuda", dockerfile: ".devops/llama-cli-cuda.Dockerfile", platforms: "linux/amd64" } - - { tag: "server-cuda", dockerfile: ".devops/llama-server-cuda.Dockerfile", platforms: "linux/amd64" } - - { tag: "full-cuda", dockerfile: ".devops/full-cuda.Dockerfile", platforms: "linux/amd64" } + # Multi-stage build + - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, freediskspace: false} + - { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false} + - { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false} + - { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false} + - { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, freediskspace: false} # Note: the rocm images are failing due to a compiler error and are disabled until this is fixed to allow the workflow to complete - #- { tag: "light-rocm", dockerfile: ".devops/llama-cli-rocm.Dockerfile", platforms: "linux/amd64,linux/arm64" } - #- { tag: "server-rocm", dockerfile: ".devops/llama-server-rocm.Dockerfile", platforms: "linux/amd64,linux/arm64" } - #- { tag: "full-rocm", dockerfile: ".devops/full-rocm.Dockerfile", platforms: "linux/amd64,linux/arm64" } - - { tag: "light-intel", dockerfile: ".devops/llama-cli-intel.Dockerfile", platforms: "linux/amd64" } - - { tag: "server-intel", dockerfile: ".devops/llama-server-intel.Dockerfile", platforms: "linux/amd64" } + #- {tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, freediskspace: true } steps: - name: Check out the repo uses: actions/checkout@v4 + with: + fetch-depth: 0 # preserve git history, so we can determine the build number - name: Set up QEMU - uses: docker/setup-qemu-action@v2 + uses: docker/setup-qemu-action@v3 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v2 + uses: docker/setup-buildx-action@v3 - name: Log in to Docker Hub uses: docker/login-action@v2 @@ -60,9 +62,45 @@ jobs: username: ${{ github.repository_owner }} password: ${{ secrets.GITHUB_TOKEN }} - # https://github.com/jlumbroso/free-disk-space/tree/54081f138730dfa15788a46383842cd2f914a1be#example + - name: Determine tag name + id: tag + shell: bash + run: | + BUILD_NUMBER="$(git rev-list --count HEAD)" + SHORT_HASH="$(git rev-parse --short=7 HEAD)" + REPO_OWNER="${GITHUB_REPOSITORY_OWNER@L}" # to lower case + REPO_NAME="${{ github.event.repository.name }}" + + # determine tag name postfix (build number, commit hash) + if [[ "${{ env.GITHUB_BRANCH_NAME }}" == "master" ]]; then + TAG_POSTFIX="-b${BUILD_NUMBER}" + else + SAFE_NAME=$(echo "${{ env.GITHUB_BRANCH_NAME }}" | tr '/' '-') + TAG_POSTFIX="-${SAFE_NAME}-${SHORT_HASH}" + fi + # list all tags possible + if [[ "${{ matrix.config.tag }}" == "cpu" ]]; then + TYPE="" + else + TYPE="-${{ matrix.config.tag }}" + fi + PREFIX="ghcr.io/${REPO_OWNER}/${REPO_NAME}:" + FULLTAGS="${PREFIX}full${TYPE},${PREFIX}full${TYPE}${TAG_POSTFIX}" + LIGHTTAGS="${PREFIX}light${TYPE},${PREFIX}light${TYPE}${TAG_POSTFIX}" + SERVERTAGS="${PREFIX}server${TYPE},${PREFIX}server${TYPE}${TAG_POSTFIX}" + echo "full_output_tags=$FULLTAGS" >> $GITHUB_OUTPUT + echo "light_output_tags=$LIGHTTAGS" >> $GITHUB_OUTPUT + echo "server_output_tags=$SERVERTAGS" >> $GITHUB_OUTPUT + echo "full_output_tags=$FULLTAGS" # print out for debugging + echo "light_output_tags=$LIGHTTAGS" # print out for debugging + echo "server_output_tags=$SERVERTAGS" # print out for debugging + env: + GITHUB_BRANCH_NAME: ${{ github.head_ref || github.ref_name }} + GITHUB_REPOSITORY_OWNER: '${{ github.repository_owner }}' + - name: Free Disk Space (Ubuntu) - uses: jlumbroso/free-disk-space@main + if: ${{ matrix.config.free_disk_space == true }} + uses: ggml-org/free-disk-space@v1.3.1 with: # this might remove tools that are actually needed, # if set to "true" but frees about 6 GB @@ -77,31 +115,59 @@ jobs: docker-images: true swap-storage: true - - name: Determine tag name - id: tag - shell: bash - run: | - BUILD_NUMBER="$(git rev-list --count HEAD)" - SHORT_HASH="$(git rev-parse --short=7 HEAD)" - if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then - echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT - else - SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') - echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT - fi - - - name: Downcase github.repository_owner - run: | - echo "repository_owner_lowercase=${GITHUB_REPOSITORY_OWNER@L}" >> $GITHUB_ENV - env: - GITHUB_REPOSITORY_OWNER: '${{ github.repository_owner }}' - - - name: Build and push Docker image (tagged + versioned) - if: github.event_name == 'push' + - name: Build and push Full Docker image (tagged + versioned) + if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.full == true }} uses: docker/build-push-action@v6 with: context: . push: true platforms: ${{ matrix.config.platforms }} - tags: "ghcr.io/${{ env.repository_owner_lowercase }}/llama.cpp:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }},ghcr.io/${{ env.repository_owner_lowercase }}/llama.cpp:${{ matrix.config.tag }},ghcr.io/${{ env.repository_owner_lowercase }}/llama.cpp:${{ matrix.config.tag }}-${{ steps.tag.outputs.name }}" + # tag list is generated from step above + tags: ${{ steps.tag.outputs.full_output_tags }} file: ${{ matrix.config.dockerfile }} + target: full + provenance: false + # using github experimental cache + cache-from: type=gha + cache-to: type=gha,mode=max + # return to this if the experimental github cache is having issues + #cache-to: type=local,dest=/tmp/.buildx-cache + #cache-from: type=local,src=/tmp/.buildx-cache + + - name: Build and push Light Docker image (tagged + versioned) + if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.light == true }} + uses: docker/build-push-action@v6 + with: + context: . + push: true + platforms: ${{ matrix.config.platforms }} + # tag list is generated from step above + tags: ${{ steps.tag.outputs.light_output_tags }} + file: ${{ matrix.config.dockerfile }} + target: light + provenance: false + # using github experimental cache + cache-from: type=gha + cache-to: type=gha,mode=max + # return to this if the experimental github cache is having issues + #cache-to: type=local,dest=/tmp/.buildx-cache + #cache-from: type=local,src=/tmp/.buildx-cache + + - name: Build and push Server Docker image (tagged + versioned) + if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.server == true }} + uses: docker/build-push-action@v6 + with: + context: . + push: true + platforms: ${{ matrix.config.platforms }} + # tag list is generated from step above + tags: ${{ steps.tag.outputs.server_output_tags }} + file: ${{ matrix.config.dockerfile }} + target: server + provenance: false + # using github experimental cache + cache-from: type=gha + cache-to: type=gha,mode=max + # return to this if the experimental github cache is having issues + #cache-to: type=local,dest=/tmp/.buildx-cache + #cache-from: type=local,src=/tmp/.buildx-cache diff --git a/.github/workflows/editorconfig.yml b/.github/workflows/editorconfig.yml index ae86e9927..f02b7c219 100644 --- a/.github/workflows/editorconfig.yml +++ b/.github/workflows/editorconfig.yml @@ -23,5 +23,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: editorconfig-checker/action-editorconfig-checker@main + - uses: editorconfig-checker/action-editorconfig-checker@v2 + with: + version: v3.0.3 - run: editorconfig-checker diff --git a/.github/workflows/nix-ci-aarch64.yml b/.github/workflows/nix-ci-aarch64.yml deleted file mode 100644 index 4aa4b2379..000000000 --- a/.github/workflows/nix-ci-aarch64.yml +++ /dev/null @@ -1,65 +0,0 @@ -name: Nix aarch64 builds - -on: - workflow_dispatch: # allows manual triggering - schedule: - # Rebuild daily rather than on every push because QEMU is expensive (e.g. - # 1.5h instead of minutes with the cold cache). - # - # randint(0, 59), randint(0, 23) - - cron: '26 12 * * *' - # But also rebuild if we touched any of the Nix expressions: - push: - branches: - - master - paths: ['**/*.nix', 'flake.lock'] - pull_request: - types: [opened, synchronize, reopened] - paths: ['**/*.nix', 'flake.lock'] - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -jobs: - nix-build-aarch64: - runs-on: ubuntu-latest - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - name: Install QEMU - # Copy-paste from https://github.com/orgs/community/discussions/8305#discussioncomment-5888654 - run: | - sudo apt-get update - sudo apt-get install -y qemu-user-static qemu-system-aarch64 - sudo usermod -a -G kvm $USER - - name: Install Nix - uses: DeterminateSystems/nix-installer-action@v9 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - extra-conf: | - extra-platforms = aarch64-linux - extra-system-features = nixos-test kvm - extra-substituters = https://llama-cpp.cachix.org https://cuda-maintainers.cachix.org - extra-trusted-public-keys = llama-cpp.cachix.org-1:H75X+w83wUKTIPSO1KWy9ADUrzThyGs8P5tmAbkWhQc= cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E= - - uses: DeterminateSystems/magic-nix-cache-action@v2 - with: - upstream-cache: https://${{ matrix.cachixName }}.cachix.org - - name: Set-up cachix to push the results to - uses: cachix/cachix-action@v13 - with: - authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}' - name: llama-cpp - - name: Show all output paths - run: > - nix run github:nix-community/nix-eval-jobs - -- --gc-roots-dir gcroot - --flake - ".#packages.aarch64-linux" - - name: Build - run: > - nix run github:Mic92/nix-fast-build - -- --skip-cached --no-nom - --systems aarch64-linux - --flake - ".#checks.aarch64-linux" diff --git a/.github/workflows/nix-ci.yml b/.github/workflows/nix-ci.yml deleted file mode 100644 index 8955f38d0..000000000 --- a/.github/workflows/nix-ci.yml +++ /dev/null @@ -1,72 +0,0 @@ -name: Nix CI - -on: - workflow_dispatch: # allows manual triggering - push: - branches: - - master - pull_request: - types: [opened, synchronize, reopened] - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -jobs: - nix-eval: - strategy: - fail-fast: false - matrix: - os: [ ubuntu-latest, macos-latest ] - runs-on: ${{ matrix.os }} - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - name: Install Nix - uses: DeterminateSystems/nix-installer-action@v9 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - extra-conf: | - extra-substituters = https://llama-cpp.cachix.org https://cuda-maintainers.cachix.org - extra-trusted-public-keys = llama-cpp.cachix.org-1:H75X+w83wUKTIPSO1KWy9ADUrzThyGs8P5tmAbkWhQc= cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E= - - uses: DeterminateSystems/magic-nix-cache-action@v2 - with: - upstream-cache: https://${{ matrix.cachixName }}.cachix.org - - name: List all flake outputs - run: nix flake show --all-systems - - name: Show all output paths - run: > - nix run github:nix-community/nix-eval-jobs - -- --gc-roots-dir gcroot - --flake - ".#packages.$(nix eval --raw --impure --expr builtins.currentSystem)" - nix-build: - strategy: - fail-fast: false - matrix: - os: [ ubuntu-latest, macos-latest ] - runs-on: ${{ matrix.os }} - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - name: Install Nix - uses: DeterminateSystems/nix-installer-action@v9 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - extra-conf: | - extra-substituters = https://llama-cpp.cachix.org https://cuda-maintainers.cachix.org - extra-trusted-public-keys = llama-cpp.cachix.org-1:H75X+w83wUKTIPSO1KWy9ADUrzThyGs8P5tmAbkWhQc= cuda-maintainers.cachix.org-1:0dq3bujKpuEPMCX6U4WylrUDZ9JyUG0VpVZa7CNfq5E= - - uses: DeterminateSystems/magic-nix-cache-action@v2 - with: - upstream-cache: https://${{ matrix.cachixName }}.cachix.org - - name: Set-up cachix to push the results to - uses: cachix/cachix-action@v13 - with: - authToken: '${{ secrets.CACHIX_AUTH_TOKEN }}' - name: llama-cpp - - name: Build - run: > - nix run github:Mic92/nix-fast-build - -- --skip-cached --no-nom - --flake - ".#checks.$(nix eval --raw --impure --expr builtins.currentSystem)" diff --git a/.github/workflows/nix-flake-update.yml b/.github/workflows/nix-flake-update.yml deleted file mode 100644 index 3a6a96e26..000000000 --- a/.github/workflows/nix-flake-update.yml +++ /dev/null @@ -1,22 +0,0 @@ -name: update-flake-lock -on: - workflow_dispatch: - schedule: - - cron: '0 0 * * 0' # runs weekly on Sunday at 00:00 - -jobs: - lockfile: - runs-on: ubuntu-latest - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - name: Install Nix - uses: DeterminateSystems/nix-installer-action@main - - name: Update flake.lock - uses: DeterminateSystems/update-flake-lock@main - with: - pr-title: "nix: update flake.lock" - pr-labels: | - nix - pr-reviewers: philiptaron,SomeoneSerge - token: ${{ secrets.FLAKE_TOKEN }} diff --git a/.github/workflows/nix-publish-flake.yml b/.github/workflows/nix-publish-flake.yml deleted file mode 100644 index 2c3c1ebda..000000000 --- a/.github/workflows/nix-publish-flake.yml +++ /dev/null @@ -1,36 +0,0 @@ -# Make the flake discoverable on https://flakestry.dev and https://flakehub.com/flakes -name: "Publish a flake to flakestry & flakehub" -on: - push: - tags: - - "*" - workflow_dispatch: - inputs: - tag: - description: "The existing tag to publish" - type: "string" - required: true -jobs: - flakestry-publish: - runs-on: ubuntu-latest - permissions: - id-token: "write" - contents: "read" - steps: - - uses: flakestry/flakestry-publish@main - with: - version: "${{ inputs.tag || github.ref_name }}" - flakehub-publish: - runs-on: "ubuntu-latest" - permissions: - id-token: "write" - contents: "read" - steps: - - uses: "actions/checkout@v4" - with: - ref: "${{ (inputs.tag != null) && format('refs/tags/{0}', inputs.tag) || '' }}" - - uses: "DeterminateSystems/nix-installer-action@main" - - uses: "DeterminateSystems/flakehub-push@main" - with: - visibility: "public" - tag: "${{ inputs.tag }}" diff --git a/.github/workflows/python-lint.yml b/.github/workflows/python-lint.yml index a8d46f31d..ddfdf73b8 100644 --- a/.github/workflows/python-lint.yml +++ b/.github/workflows/python-lint.yml @@ -1,6 +1,13 @@ name: flake8 Lint -on: [push, pull_request] +on: + push: + branches: + - master + paths: ['.github/workflows/python-lint.yml', '**/*.py'] + pull_request: + types: [opened, synchronize, reopened] + paths: ['.github/workflows/python-lint.yml', '**/*.py'] concurrency: group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} diff --git a/.github/workflows/python-type-check.yml b/.github/workflows/python-type-check.yml index e5ff5e6d7..373bb6010 100644 --- a/.github/workflows/python-type-check.yml +++ b/.github/workflows/python-type-check.yml @@ -4,11 +4,13 @@ on: push: paths: - '.github/workflows/python-type-check.yml' + - 'pyrightconfig.json' - '**.py' - '**/requirements*.txt' pull_request: paths: - '.github/workflows/python-type-check.yml' + - 'pyrightconfig.json' - '**.py' - '**/requirements*.txt' @@ -33,6 +35,6 @@ jobs: - name: Type-check with Pyright uses: jakebailey/pyright-action@v2 with: - version: 1.1.370 + version: 1.1.382 level: warning warnings: true diff --git a/.github/workflows/server.yml b/.github/workflows/server.yml index 99feb28f2..0cbc3d640 100644 --- a/.github/workflows/server.yml +++ b/.github/workflows/server.yml @@ -20,6 +20,12 @@ on: types: [opened, synchronize, reopened] paths: ['.github/workflows/server.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'examples/server/**.*'] +env: + LLAMA_LOG_COLORS: 1 + LLAMA_LOG_PREFIX: 1 + LLAMA_LOG_TIMESTAMPS: 1 + LLAMA_LOG_VERBOSITY: 10 + concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true @@ -70,20 +76,26 @@ jobs: run: | pip install -r examples/server/tests/requirements.txt - - name: Verify server deps - id: verify_server_deps + # Setup nodejs (to be used for verifying bundled index.html) + - uses: actions/setup-node@v4 + with: + node-version: '22.11.0' + + - name: Verify bundled index.html + id: verify_server_index_html run: | git config --global --add safe.directory $(realpath .) - cd examples/server - git ls-files --others --modified + cd examples/server/webui git status - ./deps.sh + npm ci + npm run build git status - not_ignored_files="$(git ls-files --others --modified)" - echo "Modified files: ${not_ignored_files}" - if [ -n "${not_ignored_files}" ]; then - echo "Repository is dirty or server deps are not built as expected" - echo "${not_ignored_files}" + modified_files="$(git status -s)" + echo "Modified files: ${modified_files}" + if [ -n "${modified_files}" ]; then + echo "Repository is dirty or server/webui is not built as expected" + echo "Hint: You may need to follow Web UI build guide in server/README.md" + echo "${modified_files}" exit 1 fi @@ -100,9 +112,9 @@ jobs: -DGGML_OPENMP=OFF ; cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server - - name: Build - id: cmake_build - if: ${{ matrix.sanitizer != 'THREAD' }} + - name: Build (sanitizers) + id: cmake_build_sanitizers + if: ${{ matrix.sanitizer != '' && matrix.sanitizer != 'THREAD' }} run: | cmake -B build \ -DGGML_NATIVE=OFF \ @@ -112,18 +124,37 @@ jobs: -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ; cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server + - name: Build (sanitizers) + id: cmake_build + if: ${{ matrix.sanitizer == '' }} + run: | + cmake -B build \ + -DGGML_NATIVE=OFF \ + -DLLAMA_BUILD_SERVER=ON \ + -DLLAMA_CURL=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ; + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server + - name: Tests id: server_integration_tests + if: ${{ matrix.sanitizer == '' }} run: | cd examples/server/tests - PORT=8888 ./tests.sh + ./tests.sh + + - name: Tests (sanitizers) + id: server_integration_tests_sanitizers + if: ${{ matrix.sanitizer != '' }} + run: | + cd examples/server/tests + LLAMA_SANITIZE=1 ./tests.sh - name: Slow tests id: server_integration_tests_slow if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }} run: | cd examples/server/tests - PORT=8888 ./tests.sh --stop --no-skipped --no-capture --tags slow + SLOW_TESTS=1 ./tests.sh server-windows: @@ -173,11 +204,13 @@ jobs: if: ${{ !matrix.disabled_on_pr || !github.event.pull_request }} run: | cd examples/server/tests - behave.exe --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey' --tags llama.cpp + $env:PYTHONIOENCODING = ":replace" + pytest -v -x -m "not slow" - name: Slow tests id: server_integration_tests_slow if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }} run: | cd examples/server/tests - behave.exe --stop --no-skipped --no-capture --tags slow + $env:SLOW_TESTS = "1" + pytest -v -x diff --git a/.gitignore b/.gitignore index 1092d097a..694f36e04 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ *.a *.bat *.bin +*.d *.dll *.dot *.etag @@ -17,6 +18,7 @@ *.metallib *.o *.so +*.swp *.tmp # IDE / OS @@ -103,6 +105,10 @@ examples/server/*.mjs.hpp !examples/sycl/*.bat !examples/sycl/*.sh +# Server Web UI temporary files +node_modules +examples/server/webui/dist + # Python /.venv @@ -133,3 +139,7 @@ poetry.toml # Test models for lora adapters /lora-tests + +# Local scripts +/run-vim.sh +/run-chat.sh diff --git a/.gitmodules b/.gitmodules index 5861d59cb..23ce5ff05 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "kompute"] - path = ggml/src/kompute + path = ggml/src/ggml-kompute/kompute url = https://github.com/nomic-ai/kompute.git diff --git a/AUTHORS b/AUTHORS index 1bd36158a..2eb60806a 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,4 +1,4 @@ -# date: Wed Jun 26 19:36:34 EEST 2024 +# date: Thu Nov 28 20:46:15 EET 2024 # this file is auto-generated by scripts/gen-authors.sh 0cc4m @@ -7,6 +7,7 @@ 2f38b454 3ooabkhxtn <31479382+3ooabkhxtn@users.noreply.github.com> 44670 <44670@users.noreply.github.com> +65a <10104049+65a@users.noreply.github.com> AN Long AT Aarni Koskela @@ -19,20 +20,28 @@ Adithya Balaji AdithyanI Adrian Adrian Hesketh +Ahmad Tameem <113388789+Tameem-10xE@users.noreply.github.com> Ahmet Zeer AidanBeltonS <87009434+AidanBeltonS@users.noreply.github.com> +AidanBeltonS Aisuko +Akarshan Biswas Akarshan Biswas +Al Mochkin <14274697+amochkin@users.noreply.github.com> Albert Jin Alberto <57916483+albbus-stack@users.noreply.github.com> +Alberto Cabrera Pérez +Alberto Cabrera Pérez Alex Alex Azarov Alex Azarov Alex Klinkhamer Alex Klinkhamer Alex Nguyen +Alex O'Connell <35843486+acon96@users.noreply.github.com> Alex Petenchea Alex Renda +Alex Tuddenham <61622354+AlexsCode@users.noreply.github.com> Alex von Gluck IV Alexey Parfenov Ali Chraghi <63465728+alichraghi@users.noreply.github.com> @@ -45,18 +54,25 @@ AmirAli Mirian <37371367+amiralimi@users.noreply.github.com> Ananta Bastola Anas Ahouzi <112881240+aahouzi@users.noreply.github.com> András Salamon +Andreas (Andi) Kunar Andrei Andrew Canis Andrew Downing Andrew Duffy Andrew Godfrey +Andrew Minh Nguyen <40281306+amqdn@users.noreply.github.com> +Andy Salerno Andy Tai +Anthony Van de Gejuchte +Antonis Makropoulos Arik Poznanski +Armen Kaleshian Artem Artem Zinnatullin Artyom Lebedev Asbjørn Olling Ásgeir Bjarni Ingvarsson +Asghar Ghorbani Ashish <1856117+ashishdatta@users.noreply.github.com> Ashok Gelal <401055+ashokgelal@users.noreply.github.com> Ashraful Islam @@ -76,12 +92,16 @@ Ben Williams Benjamin Findley <39356821+Kartoffelsaft@users.noreply.github.com> Benjamin Lecaillon <84293038+blecaillon@users.noreply.github.com> Bernat Vadell +Bert Wagner Bingan <70050083+binganao@users.noreply.github.com> +Bjarke Viksøe <164612031+bviksoe@users.noreply.github.com> Bodo Graumann Bono Lv Borislav Stanimirov Branden Butler +Brandon Squizzato <35474886+bsquizz@users.noreply.github.com> Brian +Brian Cunnie Bruce MacDonald Bryan Honof CJ Pais @@ -90,32 +110,47 @@ Calvin Laurenson Cameron Cameron Kaiser Carolinabanana <140120812+Carolinabanana@users.noreply.github.com> +CarryFun <76023481+CarryFun@users.noreply.github.com> +Carsten Kragelund Jørgensen +CarterLi999 <664681047@qq.com> Casey Primozic Casey Primozic CausalLM <148736309+CausalLM@users.noreply.github.com> Cebtenzzre Chad Brewbaker +Changyeon Kim Chao Jiang +Charles Xu <63788048+chaxu01@users.noreply.github.com> +Charles Xu +Chen Xi +Chen Xi Cheng Shao +Chenguang Li <87689256+noemotiovon@users.noreply.github.com> Chris Elrod Chris Kuehl Christian Demsar Christian Demsar Christian Falch <875252+chrfalch@users.noreply.github.com> Christian Kögler +Christian Köhnenkamp Christian Zhou-Zheng <59622928+christianazinn@users.noreply.github.com> Clark Saben <76020733+csaben@users.noreply.github.com> Clint Herron +Conrad Kramer CrispStrobe <154636388+CrispStrobe@users.noreply.github.com> +Csaba Kecskemeti Cuong Trinh Manh DAN™ Damian Stewart +Dan Johansson <164997844+eddnjjn@users.noreply.github.com> +Dan Johansson Dane Madsen DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> Daniel Bevenius Daniel Drake Daniel Hiltgen Daniel Illescas Romero +Daniel Kleine <53251018+d-kleine@users.noreply.github.com> Daniele <57776841+daniandtheweb@users.noreply.github.com> DannyDaemonic Dat Quoc Nguyen <2412555+datquocnguyen@users.noreply.github.com> @@ -129,19 +164,28 @@ David Pflug David Renshaw David Sommers <12738+databyte@users.noreply.github.com> David Yang +DavidKorczynski Dawid Potocki Dawid Wysocki <62249621+TortillaZHawaii@users.noreply.github.com> Dean Deins +Denis Spasyuk <34203011+dspasyuk@users.noreply.github.com> +Derrick T. Woolworth Deven Mistry <31466137+deven367@users.noreply.github.com> +Dibakar Gope Didzis Gosko +Diego Devesa +Diogo Teles Sant'Anna Djip007 Don Mahurin DooWoong Lee (David) Doomsdayrs <38189170+Doomsdayrs@users.noreply.github.com> +Dou Xinpeng <15529241576@163.com> +Dou Xinpeng <81913537+Dou-Git@users.noreply.github.com> Douglas Hanley Dr. Tom Murphy VII Ph.D <499244+tom7@users.noreply.github.com> Ebey Abraham +Echo Nolan Ed Lee Ed Lepedus Eddie-Wang @@ -151,10 +195,13 @@ Elbios <141279586+Elbios@users.noreply.github.com> Elton Kola Engininja2 <139037756+Engininja2@users.noreply.github.com> Equim +Eric Curtin +Eric Curtin Eric Sommerlade Eric Zhang <34133756+EZForever@users.noreply.github.com> Erik Garrison Erik Scholz +Esko Toivonen Ettore Di Giacinto Evan Jones Evan Miller @@ -166,19 +213,26 @@ FK Fabian Fabio R. Sluzala Faez Shakil +Faisal Zaghloul +Faisal Zaghloul +Fan Shupei FantasyGmm <16450052+FantasyGmm@users.noreply.github.com> +Farbod Bijary <110523279+farbodbj@users.noreply.github.com> Fattire <528174+fat-tire@users.noreply.github.com> Felix Finn Voorhees Firat +FirstTimeEZ <179362031+FirstTimeEZ@users.noreply.github.com> Folko-Ven <71110216+Folko-Ven@users.noreply.github.com> Foul-Tarnished <107711110+Foul-Tarnished@users.noreply.github.com> Francisco Melo <43780565+francis2tm@users.noreply.github.com> Frank Mai FrankHB +Frankie Robertson Fred Douglas <43351173+fredlas@users.noreply.github.com> Frederik Vogel Gabe Goodhart +Gabe Goodhart GainLee Galunid Gary Linscott @@ -187,11 +241,13 @@ Gavin Zhao Genkagaku.GPT Georgi Gerganov Gilad S +Gilad S. <7817232+giladgd@users.noreply.github.com> Giuseppe Scrivano GiviMAD Govlzkoy Guillaume "Vermeille" Sanchez Guillaume Wenzek +Guoliang Hua <32868157+nbcsm@users.noreply.github.com> Guoteng <32697156+SolenoidWGT@users.noreply.github.com> Gustavo Rocha Dias <91472747+gustrd@users.noreply.github.com> Haggai Nuchi @@ -213,11 +269,14 @@ Hong Bo PENG Hongyu Ouyang <96765450+casavaca@users.noreply.github.com> Howard Su Hua Jiang +Huang Qi Huawei Lin Hugo Roussel +Huifeng Ou <79071290+ho2103@users.noreply.github.com> Ian Bull Ian Bull Ian Scrivener +Icecream95 Ido S IgnacioFDM Igor Okulist @@ -226,11 +285,15 @@ Ilya Kurdyukov <59548320+ilyakurdyukov@users.noreply.github.com> Ionoclast Laboratories Isaac McFadyen IsaacDynamo <61521674+IsaacDynamo@users.noreply.github.com> +Ivan +Ivan Filipov <159561759+vanaka11@users.noreply.github.com> Ivan Komarov Ivan Stepanov JH23X <165871467+JH23X@users.noreply.github.com> +Jack Mousseau Jack Mousseau JackJollimore <130917767+JackJollimore@users.noreply.github.com> +Jaeden Amero Jaemin Son Jag Chadha Jakub N @@ -243,10 +306,14 @@ Jannis Schönleber Jared Van Bortel Jared Van Bortel Jason McCartney +Jason Stillerman Jean-Christophe Hoelt Jean-Michaël Celerier Jed Fox +Jeff Bolz +Jeffrey Morgan Jeffrey Quesnelle +Jeroen Mostert Jesse Jojo Johnson Jeximo Jhen-Jie Hong @@ -258,6 +325,9 @@ Jiří Podivín <66251151+jpodivin@users.noreply.github.com> Jiří Sejkora Joan Fontanals Joan Fontanals +João Dinis Ferreira +Joe Eli McIlvain +Joe Todd Johan Johannes Gäßler Johannes Rudolph @@ -274,7 +344,9 @@ Joyce Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Judd Julius Arkenberg +Jun Hee Yoo Jun Jie <71215065+junnjiee16@users.noreply.github.com> +Junil Kim Junyang Lin Juraj Bednar Justin Parker @@ -292,12 +364,14 @@ Karthik Sethuraman Kasumi <90275229+kasumi-1@users.noreply.github.com> Kawrakow <48489457+ikawrakow@users.noreply.github.com> Keiichi Tabata +Keke Han Kenvix ⭐ Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com> Kevin Gibbons Kevin Ji <1146876+kevinji@users.noreply.github.com> Kevin Kwok Kevin Lo +Kevin Wang Kolen Cheung Konstantin Herud Konstantin Zhuravlyov @@ -315,22 +389,29 @@ LeonEricsson <70749762+LeonEricsson@users.noreply.github.com> Leonardo Neumann Li Tan Linwei Wang +Liu Jia <109258120+Septa2112@users.noreply.github.com> +Liu Jia LoganDark +Loïc Carrère LostRuins <39025047+LostRuins@users.noreply.github.com> Luciano Luo Tian Lyle Dean +M-A M. Yusuf Sarıgöz +Ma Mingfei Maarten ter Huurne Mack Straight Maël Kerbiriou MaggotHATE +Mahesh Madhav <67384846+heshpdx@users.noreply.github.com> Manuel <44313466+makuche@users.noreply.github.com> Marc Köhlbrugge Marco Matthies <71844+marcom@users.noreply.github.com> Marcus Dunn <51931484+MarcusDunn@users.noreply.github.com> Marian Cepok Mark Fairbairn +Mark Zhuang Marko Tasic Markus Tavenrath Martin Delille @@ -342,11 +423,15 @@ MasterYi1024 <39848311+MasterYi1024@users.noreply.github.com> Mateusz Charytoniuk Matheus C. França Matheus Gabriel Alves Silva +Mathieu Geli Mathieu Nayrolles +Mathijs Henquet Mathijs de Bruin Matt Clayton <156335168+mattjcly@users.noreply.github.com> Matt Pulver +Matt Stephenson Matteo Boschini <12133566+mbosc@users.noreply.github.com> +Matteo Mortari Mattheus Chediak Matthew Tejo Matvey Soloviev @@ -356,8 +441,10 @@ Maxime <672982+maximegmd@users.noreply.github.com> Maximilian Winter Meng Zhang Meng, Hengyu +Mengqing Cao Merrick Christensen Michael Coppola +Michael Francis Michael Hueschen Michael Kesper Michael Klimenko @@ -365,41 +452,57 @@ Michael Podvitskiy Michael Potter Michael de Gans Michaël de Vries +Michał Tuszyński Mihai Mike Mikko Juola Minsoo Cheong <54794500+mscheong01@users.noreply.github.com> +Minsoo Cheong Mirko185 Mirror Azure <54669636+MirrorAzure@users.noreply.github.com> +MistApproach <98988043+MistApproach@users.noreply.github.com> Miwa / Ensan <63481257+ensan-hcl@users.noreply.github.com> Mohammadreza Hendiani Mohammadreza Hendiani +Molly Sophia +MorganRO8 <47795945+MorganRO8@users.noreply.github.com> Murilo Santana Musab Gultekin Nam D. Tran <42194884+namtranase@users.noreply.github.com> Nathan Epstein +Natsu NawafAlansari <72708095+NawafAlansari@users.noreply.github.com> Nebula Neo Zhang <14088817+arthw@users.noreply.github.com> Neo Zhang Neo Zhang Jianyu Neuman Vong +Nexes the Old <124105151+Nexesenex@users.noreply.github.com> Nexesenex <124105151+Nexesenex@users.noreply.github.com> Niall Coates <1349685+Niall-@users.noreply.github.com> +Nicholai Tukanov +Nico Bosshard Nicolai Weitkemper Nicolás Pérez Nigel Bosch Niklas Korz +NikolaiLyssogor <59844691+NikolaiLyssogor@users.noreply.github.com> Nikolas <127742645+nneubacher@users.noreply.github.com> Nindaleth +OSecret <135510162+OLSecret@users.noreply.github.com> Oleksandr Nikitin Oleksii Maryshchenko Olivier Chafik Ondřej Čertík Ouadie EL FAROUKI +PAB +Pablo Duboue +Pascal Patry Patrice Ferlet Paul Tsochantaris +Pavel Zloi Pavol Rusnak +Paweł Wodnicki <151604+32bitmicro@users.noreply.github.com> Pedro Cuenca Peter Sugihara Phil H <5756783+phiharri@users.noreply.github.com> @@ -407,10 +510,15 @@ Philip Taron Phillip Kravtsov Pierre Alexandre SCHEMBRI Pierrick Hymbert +Pieter Ouwerkerk +Plamen Minev +Prashant Vithule <119530321+Vithulep@users.noreply.github.com> Przemysław Pawełczyk Qin Yue Chen <71813199+chenqiny@users.noreply.github.com> Qingyou Meng Qu Zongfu <43257352+yancaoweidaode@users.noreply.github.com> +R0CKSTAR +R0CKSTAR RJ Adriaansen Radoslav Gerganov Radosław Gryta @@ -419,11 +527,13 @@ Raj Hammeer Singh Hada Ralph Soika Rand Xie Randall Fitzgerald +Random Fly Reinforce-II Ren Xuancheng Rene Leonhardt <65483435+reneleonhardt@users.noreply.github.com> RhinoDevel Riceball LEE +Rich Dougherty Richard Kiss Richard Roberson Rick G <26732651+TheFlipbook@users.noreply.github.com> @@ -439,21 +549,30 @@ Robey Holderith Robyn Roger Meier Roland <14355895+rbur0425@users.noreply.github.com> +Romain Biessy Romain D <90720+Artefact2@users.noreply.github.com> Romain Neutron Roman Parykin Ron Evans Ron Jailall +Roni Ronny Brendel Ronsor Rowan Hart +Ruchira Hasaranga +Ruixin Huang <18860020911@163.com> Rune <43761327+Rune-AI@users.noreply.github.com> +RunningLeon +RunningLeon Ryan Landay Ryder Wishart Ryuei Rőczey Barnabás <31726601+An0nie@users.noreply.github.com> +SRHMorris <69468379+SRHMorris@users.noreply.github.com> +SXX SakuraUmi Salvador E. Tropea +Salvatore Mesoraca Sam Spilsbury Sami Farin <3876865+Safari77@users.noreply.github.com> Samuel Maynard @@ -463,23 +582,29 @@ Sebastián A SebastianApel <13675545+SebastianApel@users.noreply.github.com> Senemu <10880819+Senemu@users.noreply.github.com> Sergey Alirzaev +Sergio López Sergio López Sertaç Özercan <852750+sozercan@users.noreply.github.com> SeungWon Jeong <65549245+redlion0929@users.noreply.github.com> ShadovvBeast Shakhar Dasgupta +Shane A Shangning Xu <32517059+xushangning@users.noreply.github.com> +Shankar +Shanshan Shen <467638484@qq.com> Shijie <821898965@qq.com> Shintarou Okada Shouzheng Liu <61452103+lshzh-ww@users.noreply.github.com> Shouzheng Liu Shuichi Tsutsumi +Shupei Fan Sigbjørn Skjæret Simon Willison Siwen Yu Sky Yan Slaren <2141330+slaren@users.noreply.github.com> Slava Primenko +Small Grass Forest SoftwareRenderer <138734813+SoftwareRenderer@users.noreply.github.com> Someone Someone Serge @@ -491,12 +616,15 @@ Stefan Sydow Steffen Röcker Stephan Walter Stephen Nichols +Steve Bonds Steve Grubb Steven Prichard Steven Roussey Steward Garcia <57494570+FSSRepo@users.noreply.github.com> +StrangeBytesDev <141275258+StrangeBytesDev@users.noreply.github.com> Suaj Carrot <72162667+SuajCarrot@users.noreply.github.com> SuperUserNameMan +Sutou Kouhei Tai Duc Nguyen Taikono-Himazin Tameem <113388789+AhmadTameem@users.noreply.github.com> @@ -507,7 +635,9 @@ Theia Vogel Thérence <13496987+Royalphax@users.noreply.github.com> Thibault Terrasson Thomas Klausner +Thorsten Sommer Tim Miller +Tim Wang Timmy Knight Timothy Cronin <40186632+4imothy@users.noreply.github.com> Ting Lou @@ -517,24 +647,31 @@ Tom C Tom Jobbins <784313+TheBloke@users.noreply.github.com> Tomas Tomáš Pazdiora +Tony Wasserka <4840017+neobrain@users.noreply.github.com> Tristan Druyen Tristan Ross +Trivikram Kamat <16024985+trivikr@users.noreply.github.com> Tungsten842 <886724vf@anonaddy.me> Tungsten842 Tushar UEXTM.com <84163508+uextm@users.noreply.github.com> +Ujjawal Panchal <31011628+Ujjawal-K-Panchal@users.noreply.github.com> Ulrich Drepper Uzo Nweke Vaibhav Srivastav Val Kharitonov Valentin Konovalov Valentyn Bezshapkin <61702053+valentynbez@users.noreply.github.com> +Vali Malinoiu <0x4139@gmail.com> Victor Nogueira Victor Z. Peng +Viet-Anh NGUYEN (Andrew) +Vinesh Janarthanan <36610342+VJHack@users.noreply.github.com> Vlad Vladimir Vladimir Malyutin Vladimir Zorin +VoidIsVoid <343750470@qq.com> Volodymyr Vitvitskyi <72226+signalpillar@users.noreply.github.com> WangHaoranRobin <56047610+WangHaoranRobin@users.noreply.github.com> Weird Constructor @@ -551,15 +688,22 @@ Xiang (Kevin) Li Xiao-Yong Jin XiaotaoChen Xiaoyi Chen +Xie Yanbo Xingchen Song(宋星辰) +Xinpeng Dou <81913537+Dou-Git@users.noreply.github.com> Xuan Son Nguyen +Yaiko Yann Follet <131855179+YannFollet@users.noreply.github.com> Yaroslav Yazan Agha-Schrader Yiming Cui Yishuo Wang +Yoshi Suhara +Yoshi Suhara +Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Yueh-Po Peng <94939112+y10ab1@users.noreply.github.com> Yui +Yuri Khrustalev Yusuf Kağan Hanoğlu Yuval Peled <31162840+Yuval-Peled@users.noreply.github.com> ZHAOKAI WANG @@ -568,6 +712,8 @@ Zay <95888118+isaiahbjork@users.noreply.github.com> Zenix Zhang Peiyuan Zheng.Deng <32841220+dengzheng-cloud@users.noreply.github.com> +Zhenwei Jin <109658203+kylo5aby@users.noreply.github.com> +Zhiyuan Li ZhouYuChen Ziad Ben Hadj-Alouane Ziang Wu <97337387+ZiangWu-77@users.noreply.github.com> @@ -581,6 +727,7 @@ alexpinel <93524949+alexpinel@users.noreply.github.com> alonfaraj alwqx amd-lalithnc +amritahs-ibm andrijdavid anon998 <131767832+anon998@users.noreply.github.com> anzz1 @@ -588,14 +735,18 @@ apaz apcameron <37645737+apcameron@users.noreply.github.com> arch-btw <57669023+arch-btw@users.noreply.github.com> arcrank +ardfork <134447697+ardfork@users.noreply.github.com> arlo-phoenix <140345165+arlo-phoenix@users.noreply.github.com> at8u <129688334+at8u@users.noreply.github.com> automaticcat +awatuna <23447591+awatuna@users.noreply.github.com> +b4b4o bandoti <141645996+bandoti@users.noreply.github.com> beiller bhubbb <79117352+bhubbb@users.noreply.github.com> bmwl bobqianic <129547291+bobqianic@users.noreply.github.com> +brucepro bryanSwk <93190252+bryanSwk@users.noreply.github.com> bsilvereagle bssrdf @@ -614,10 +765,14 @@ cpumaxx <163466046+cpumaxx@users.noreply.github.com> crasm crasm daboe01 +daghanerdonmez <44506702+daghanerdonmez@users.noreply.github.com> +daminho <37615795+daminho@users.noreply.github.com> david raistrick ddh0 ddpasa <112642920+ddpasa@users.noreply.github.com> deepdiffuser <112834445+deepdiffuser@users.noreply.github.com> +devojony <61173062+devojony@users.noreply.github.com> +ditsuke divinity76 dm4 dotpy314 <33351922+dotpy314@users.noreply.github.com> @@ -629,14 +784,18 @@ ebraminio eiery <19350831+eiery@users.noreply.github.com> eric8607242 fairydreaming <166155368+fairydreaming@users.noreply.github.com> +fengerhu1 <2748250768@qq.com> fraxy-v <65565042+fraxy-v@users.noreply.github.com> github-actions[bot] gliptic goerch grahameth <96447521+grahameth@users.noreply.github.com> +gtygo gwjr <502526+gwjr@users.noreply.github.com> h-h-h-h <13482553+h-h-h-h@users.noreply.github.com> hankcs +haopeng <657407891@qq.com> +hipudding hoangmit hongbo.mo <352280764@qq.com> hopkins385 <98618192+hopkins385@users.noreply.github.com> @@ -649,12 +808,14 @@ hxer7963 hydai iSma iacore <74560659+iacore@users.noreply.github.com> +icppWorld <124377669+icppWorld@users.noreply.github.com> igarnier intelmatt <61025942+intelmatt@users.noreply.github.com> iohub jacobi petrucciani <8117202+jpetrucciani@users.noreply.github.com> jaime-m-p <167997752+jaime-m-p@users.noreply.github.com> jameswu2014 <545426914@qq.com> +jdomke <28772296+jdomke@users.noreply.github.com> jiez <373447296@qq.com> jneem joecryptotoo <80373433+joecryptotoo@users.noreply.github.com> @@ -677,28 +838,35 @@ klosax <131523366+klosax@users.noreply.github.com> kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> kunnis kuronekosaiko +kustaaya <58045274+kustaaya@users.noreply.github.com> kuvaus <22169537+kuvaus@users.noreply.github.com> kwin1412 <42286931+kwin1412@users.noreply.github.com> l3utterfly +laik ldwang le.chang leejet +leo-pony limitedAtonement liuwei-git <14815172+liuwei-git@users.noreply.github.com> lon <114724657+longregen@users.noreply.github.com> loonerin <132926317+loonerin@users.noreply.github.com> +ltoniazzi <61414566+ltoniazzi@users.noreply.github.com> luoyu-intel m3ndax maddes8cht <55592906+maddes8cht@users.noreply.github.com> makomk manikbhandari maor-ps <154728172+maor-ps@users.noreply.github.com> +matiaslin <45382001+matiaslin@users.noreply.github.com> +matteo mdrokz mgroeber9110 <45620825+mgroeber9110@users.noreply.github.com> minarchist mj-shifu <77107165+mj-shifu@users.noreply.github.com> mmyjona momonga <115213907+mmnga@users.noreply.github.com> +momonga <146910567+mmngays@users.noreply.github.com> moritzbrantner <31051084+moritzbrantner@users.noreply.github.com> mzcu nanahi <130121847+na-na-hi@users.noreply.github.com> @@ -716,8 +884,10 @@ omahs <73983677+omahs@users.noreply.github.com> oobabooga <112222186+oobabooga@users.noreply.github.com> opparco ostix360 <55257054+ostix360@users.noreply.github.com> +pculliton pengxin99 perserk +piDack <104877312+piDack@users.noreply.github.com> pmysl postmasters pudepiedj @@ -733,6 +903,7 @@ runfuture sandyiscool sasha0552 semidark +serhii-nakon <57632032+serhii-nakon@users.noreply.github.com> sharpHL <132747147+sharpHL@users.noreply.github.com> shibe2 singularity <12184989+singularity-s0@users.noreply.github.com> @@ -741,42 +912,55 @@ sjxx <63994076+ylsdamxssjxxdd@users.noreply.github.com> slaren <2141330+slaren@users.noreply.github.com> slaren snadampal <87143774+snadampal@users.noreply.github.com> +standby24x7 staviq stduhpf strawberrymelonpanda <152940198+strawberrymelonpanda@users.noreply.github.com> swittk takov751 <40316768+takov751@users.noreply.github.com> tarcey +tc-mb <157115220+tc-mb@users.noreply.github.com> texmex76 <40733439+texmex76@users.noreply.github.com> thement <40525767+thement@users.noreply.github.com> +thewh1teagle <61390950+thewh1teagle@users.noreply.github.com> tjohnman +toyer <2042519524@qq.com> tslmy ubik2 uint256_t uint256_t unbounded +uvos valiray <133289098+valiray@users.noreply.github.com> +vb vik viric vodkaslime <646329483@qq.com> vvhg1 <94630311+vvhg1@users.noreply.github.com> vxiiduu <73044267+vxiiduu@users.noreply.github.com> +wangshuai09 <391746016@qq.com> wbpxre150 <100937007+wbpxre150@users.noreply.github.com> whoreson <139810751+whoreson@users.noreply.github.com> woachk <24752637+woachk@users.noreply.github.com> wonjun Jang woodx <124784234+woodx9@users.noreply.github.com> +wwoodsTM <104587230+wwoodsTM@users.noreply.github.com> wzy <32936898+Freed-Wu@users.noreply.github.com> xaedes xaedes +xctan xloem <0xloem@gmail.com> yangli2 yuiseki +yuri@FreeBSD zakkor zhangkaihuo +zhentaoyu zhouwg <6889919+zhouwg@users.noreply.github.com> zhouwg zrm Ștefan-Gabriel Muscalu +杨朱 · Kiki 源文雨 <41315874+fumiama@users.noreply.github.com> +蕭澧邦 <45505768+shou692199@users.noreply.github.com> Нияз Гарифзянов <112617865+garrnizon@users.noreply.github.com> diff --git a/CMakeLists.txt b/CMakeLists.txt index a31320635..4c62d1788 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,6 +16,7 @@ endif() list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) set(LLAMA_STANDALONE ON) @@ -46,6 +47,13 @@ if (WIN32) add_compile_definitions(_CRT_SECURE_NO_WARNINGS) endif() +if (MSVC) + add_compile_options("$<$:/utf-8>") + add_compile_options("$<$:/utf-8>") + add_compile_options("$<$:/bigobj>") + add_compile_options("$<$:/bigobj>") +endif() + # # option list # @@ -62,6 +70,9 @@ option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF) option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF) option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF) +# utils +option(LLAMA_BUILD_COMMON "llama: build common utils library" ${LLAMA_STANDALONE}) + # extra artifacts option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) @@ -72,21 +83,19 @@ option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF) # Required for relocatable CMake package include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake) +include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/common.cmake) # override ggml options -set(GGML_SANITIZE_THREAD ${LLAMA_SANITIZE_THREAD}) -set(GGML_SANITIZE_ADDRESS ${LLAMA_SANITIZE_ADDRESS}) -set(GGML_SANITIZE_UNDEFINED ${LLAMA_SANITIZE_UNDEFINED}) -set(GGML_ALL_WARNINGS ${LLAMA_ALL_WARNINGS}) -set(GGML_FATAL_WARNINGS ${LLAMA_FATAL_WARNINGS}) +set(GGML_ALL_WARNINGS ${LLAMA_ALL_WARNINGS}) +set(GGML_FATAL_WARNINGS ${LLAMA_FATAL_WARNINGS}) # change the default for these ggml options if (NOT DEFINED GGML_LLAMAFILE) - set(GGML_LLAMAFILE ON) + set(GGML_LLAMAFILE_DEFAULT ON) endif() -if (NOT DEFINED GGML_CUDA_USE_GRAPHS) - set(GGML_CUDA_USE_GRAPHS ON) +if (NOT DEFINED GGML_CUDA_GRAPHS) + set(GGML_CUDA_GRAPHS_DEFAULT ON) endif() # transition helpers @@ -108,16 +117,62 @@ llama_option_depr(WARNING LLAMA_SYCL GGML_SYCL) llama_option_depr(WARNING LLAMA_SYCL_F16 GGML_SYCL_F16) llama_option_depr(WARNING LLAMA_CANN GGML_CANN) +if (NOT MSVC) + if (LLAMA_SANITIZE_THREAD) + message(STATUS "Using -fsanitize=thread") + + add_compile_options(-fsanitize=thread) + link_libraries (-fsanitize=thread) + endif() + + if (LLAMA_SANITIZE_ADDRESS) + message(STATUS "Using -fsanitize=address") + + add_compile_options(-fsanitize=address -fno-omit-frame-pointer) + link_libraries (-fsanitize=address) + endif() + + if (LLAMA_SANITIZE_UNDEFINED) + message(STATUS "Using -fsanitize=undefined") + + add_compile_options(-fsanitize=undefined) + link_libraries (-fsanitize=undefined) + endif() +endif() + # -# build the library +# 3rd-party # if (NOT TARGET ggml) add_subdirectory(ggml) # ... otherwise assume ggml is added by a parent CMakeLists.txt endif() + +# +# build the library +# + add_subdirectory(src) +# +# utils, programs, examples and tests +# + +if (LLAMA_BUILD_COMMON) + add_subdirectory(common) +endif() + +if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION) + include(CTest) + add_subdirectory(tests) +endif() + +if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_EXAMPLES) + add_subdirectory(examples) + add_subdirectory(pocs) +endif() + # # install # @@ -133,19 +188,14 @@ set(LLAMA_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location o set(LLAMA_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files") set(LLAMA_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files") +set(LLAMA_PUBLIC_HEADERS + ${CMAKE_CURRENT_SOURCE_DIR}/include/llama.h + ${CMAKE_CURRENT_SOURCE_DIR}/include/llama-cpp.h) -# At the moment some compile definitions are placed within the ggml/src -# directory but not exported on the `ggml` target. This could be improved by -# determining _precisely_ which defines are necessary for the llama-config -# package. -# -get_target_property(GGML_DIRECTORY ggml SOURCE_DIR) -get_directory_property(GGML_DIR_DEFINES DIRECTORY ${GGML_DIRECTORY} COMPILE_DEFINITIONS) -get_target_property(GGML_TARGET_DEFINES ggml COMPILE_DEFINITIONS) -set(GGML_TRANSIENT_DEFINES ${GGML_TARGET_DEFINES} ${GGML_DIR_DEFINES}) -get_target_property(GGML_LINK_LIBRARIES ggml LINK_LIBRARIES) +set_target_properties(llama + PROPERTIES + PUBLIC_HEADER "${LLAMA_PUBLIC_HEADERS}") -set_target_properties(llama PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/llama.h) install(TARGETS llama LIBRARY PUBLIC_HEADER) configure_package_config_file( @@ -183,19 +233,3 @@ configure_file(cmake/llama.pc.in install(FILES "${CMAKE_CURRENT_BINARY_DIR}/llama.pc" DESTINATION lib/pkgconfig) - -# -# programs, examples and tests -# - -add_subdirectory(common) - -if (LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION) - include(CTest) - add_subdirectory(tests) -endif () - -if (LLAMA_BUILD_EXAMPLES) - add_subdirectory(examples) - add_subdirectory(pocs) -endif() diff --git a/CMakePresets.json b/CMakePresets.json index d22ffa490..13bdd7907 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -24,11 +24,19 @@ "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.." } }, - { "name": "debug", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug" } }, - { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } }, - { "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } }, - { "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } }, - { "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } }, + { "name": "debug", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug" } }, + { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } }, + { "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } }, + { "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } }, + { "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } }, + { "name": "vulkan", "hidden": true, "cacheVariables": { "GGML_VULKAN": "ON" } }, + + { + "name": "x64-windows-llvm", "hidden": true, + "cacheVariables": { + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/x64-windows-llvm.cmake" + } + }, { "name": "arm64-windows-msvc", "hidden": true, @@ -48,21 +56,42 @@ } }, - { "name": "arm64-windows-llvm-debug" , "inherits": [ "base", "arm64-windows-llvm", "debug" ] }, - { "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg" ] }, - { "name": "arm64-windows-llvm+static-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg", "static" ] }, + { + "name": "arm64-apple-clang", "hidden": true, + "architecture": { "value": "arm64", "strategy": "external" }, + "toolset": { "value": "host=x64", "strategy": "external" }, + "cacheVariables": { + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-apple-clang.cmake" + } + }, - { "name": "arm64-windows-msvc-debug" , "inherits": [ "base", "arm64-windows-msvc", "debug" ] }, + { "name": "arm64-windows-llvm-debug", "inherits": [ "base", "arm64-windows-llvm", "debug" ] }, + { "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg" ] }, + { "name": "arm64-windows-llvm+static-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg", "static" ] }, + + { "name": "arm64-apple-clang-debug", "inherits": [ "base", "arm64-apple-clang", "debug" ] }, + { "name": "arm64-apple-clang-release", "inherits": [ "base", "arm64-apple-clang", "reldbg" ] }, + { "name": "arm64-apple-clang+static-release", "inherits": [ "base", "arm64-apple-clang", "reldbg", "static" ] }, + + { "name": "arm64-windows-msvc-debug", "inherits": [ "base", "arm64-windows-msvc", "debug" ] }, { "name": "arm64-windows-msvc-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg" ] }, { "name": "arm64-windows-msvc+static-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg", "static" ] }, - { "name": "x64-windows-msvc-debug" , "inherits": [ "base", "debug" ] }, + { "name": "x64-windows-llvm-debug", "inherits": [ "base", "x64-windows-llvm", "debug" ] }, + { "name": "x64-windows-llvm-release", "inherits": [ "base", "x64-windows-llvm", "release" ] }, + { "name": "x64-windows-llvm-reldbg", "inherits": [ "base", "x64-windows-llvm", "reldbg" ] }, + { "name": "x64-windows-llvm+static-release", "inherits": [ "base", "x64-windows-llvm", "reldbg", "static" ] }, + + { "name": "x64-windows-msvc-debug", "inherits": [ "base", "debug" ] }, { "name": "x64-windows-msvc-release", "inherits": [ "base", "reldbg" ] }, { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] }, - { "name": "x64-windows-sycl-debug" , "inherits": [ "sycl-base", "debug" ] }, + { "name": "x64-windows-sycl-debug", "inherits": [ "sycl-base", "debug" ] }, { "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] }, { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] }, - { "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] } + { "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] }, + + { "name": "x64-windows-vulkan-debug", "inherits": [ "base", "vulkan", "debug" ] }, + { "name": "x64-windows-vulkan-release", "inherits": [ "base", "vulkan", "release" ] } ] } diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 000000000..72d594b46 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1,11 @@ +# collaborators can optionally add themselves here to indicate their availability for reviewing related PRs + +/ci/ @ggerganov +/.devops/*.Dockerfile @ngxson +/examples/server/ @ngxson +/ggml/src/ggml-cuda/fattn* @JohannesGaessler +/ggml/src/ggml-cuda/mmq.* @JohannesGaessler +/ggml/src/ggml-cuda/mmv.* @JohannesGaessler +/ggml/src/ggml-cuda/mmvq.* @JohannesGaessler +/ggml/src/ggml-opt.cpp @JohannesGaessler +/ggml/src/gguf.cpp @JohannesGaessler diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a9e000e52..8d411982b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,29 +1,125 @@ # Pull requests (for contributors) - Test your changes: - - Using the commands in the [`tests`](tests) folder. For instance, running the `./tests/test-backend-ops` command tests different backend implementations of the GGML library - - Execute [the full CI locally on your machine](ci/README.md) before publishing -- Please rate the complexity of your PR (i.e. `Review Complexity : Low`, `Review Complexity : Medium`, `Review Complexity : High`). This makes it easier for maintainers to triage the PRs. - - The PR template has a series of review complexity checkboxes `[ ]` that [you can mark as](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/about-task-lists) `[X]` for your convenience -- Consider allowing write access to your branch for faster review + - Execute [the full CI locally on your machine](ci/README.md) before publishing + - Verify that the perplexity and the performance are not affected negatively by your changes (use `llama-perplexity` and `llama-bench`) + - If you modified the `ggml` source, run the `test-backend-ops` tool to check whether different backend implementations of the `ggml` operators produce consistent results (this requires access to at least two different `ggml` backends) + - If you modified a `ggml` operator or added a new one, add the corresponding test cases to `test-backend-ops` +- Consider allowing write access to your branch for faster reviews, as reviewers can push commits directly - If your PR becomes stale, don't hesitate to ping the maintainers in the comments # Pull requests (for collaborators) - Squash-merge PRs - Use the following format for the squashed commit title: ` : (#)`. For example: `utils : fix typo in utils.py (#1234)` -- Optionally, pick a `` from here: https://github.com/ggerganov/llama.cpp/wiki/Modules +- Optionally pick a `` from here: https://github.com/ggerganov/llama.cpp/wiki/Modules +- Consider adding yourself to [CODEOWNERS](CODEOWNERS) # Coding guidelines - Avoid adding third-party dependencies, extra files, extra headers, etc. - Always consider cross-compatibility with other operating systems and architectures -- Avoid fancy looking modern STL constructs, use basic `for` loops, avoid templates, keep it simple -- There are no strict rules for the code style, but try to follow the patterns in the code (indentation, spaces, etc.). Vertical alignment makes things more readable and easier to batch edit +- Avoid fancy-looking modern STL constructs, use basic `for` loops, avoid templates, keep it simple +- Vertical alignment makes things more readable and easier to batch edit - Clean-up any trailing whitespaces, use 4 spaces for indentation, brackets on the same line, `void * ptr`, `int & a` -- Naming usually optimizes for common prefix (see https://github.com/ggerganov/ggml/pull/302#discussion_r1243240963) +- Use sized integer types such as `int32_t` in the public API, e.g. `size_t` may also be appropriate for allocation sizes or byte offsets +- Declare structs with `struct foo {}` instead of `typedef struct foo {} foo` + - In C++ code omit optional `struct` and `enum` keyword whenever they are not necessary + ```cpp + // OK + llama_context * ctx; + const llama_rope_type rope_type; + + // not OK + struct llama_context * ctx; + const enum llama_rope_type rope_type; + ``` + + _(NOTE: this guideline is yet to be applied to the `llama.cpp` codebase. New code should follow this guideline.)_ + +- Try to follow the existing patterns in the code (indentation, spaces, etc.). In case of doubt use `clang-format` to format the added code +- For anything not covered in the current guidelines, refer to the [C++ Core Guidelines](https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines) - Tensors store data in row-major order. We refer to dimension 0 as columns, 1 as rows, 2 as matrices - Matrix multiplication is unconventional: [`C = ggml_mul_mat(ctx, A, B)`](https://github.com/ggerganov/llama.cpp/blob/880e352277fc017df4d5794f0c21c44e1eae2b84/ggml.h#L1058-L1064) means $C^T = A B^T \Leftrightarrow C = B A^T.$ ![matmul](media/matmul.png) +# Naming guidelines + +- Use `snake_case` for function, variable and type names +- Naming usually optimizes for longest common prefix (see https://github.com/ggerganov/ggml/pull/302#discussion_r1243240963) + + ```cpp + // not OK + int small_number; + int big_number; + + // OK + int number_small; + int number_big; + ``` + +- Enum values are always in upper case and prefixed with the enum name + + ```cpp + enum llama_vocab_type { + LLAMA_VOCAB_TYPE_NONE = 0, + LLAMA_VOCAB_TYPE_SPM = 1, + LLAMA_VOCAB_TYPE_BPE = 2, + LLAMA_VOCAB_TYPE_WPM = 3, + LLAMA_VOCAB_TYPE_UGM = 4, + LLAMA_VOCAB_TYPE_RWKV = 5, + }; + ``` + +- The general naming pattern is `_`, with `` being `_` + + ```cpp + llama_model_init(); // class: "llama_model", method: "init" + llama_sampler_chain_remove(); // class: "llama_sampler_chain", method: "remove" + llama_sampler_get_seed(); // class: "llama_sampler", method: "get_seed" + llama_set_embeddings(); // class: "llama_context", method: "set_embeddings" + llama_n_threads(); // class: "llama_context", method: "n_threads" + llama_adapter_lora_free(); // class: "llama_adapter_lora", method: "free" + ``` + + - The `get` `` can be omitted + - The `` can be omitted if not necessary + - The `_context` suffix of the `` is optional. Use it to disambiguate symbols when needed + - Use `init`/`free` for constructor/destructor `` + +- Use the `_t` suffix when a type is supposed to be opaque to the user - it's not relevant to them if it is a struct or anything else + + ```cpp + typedef struct llama_context * llama_context_t; + + enum llama_pooling_type llama_pooling_type(const llama_context_t ctx); + ``` + + _(NOTE: this guideline is yet to be applied to the `llama.cpp` codebase. New code should follow this guideline)_ + +- C/C++ filenames are all lowercase with dashes. Headers use the `.h` extension. Source files use the `.c` or `.cpp` extension +- Python filenames are all lowercase with underscores + +- _(TODO: abbreviations usage)_ + +# Preprocessor directives + +- _(TODO: add guidelines with examples and apply them to the codebase)_ + + ```cpp + #ifdef FOO + #endif // FOO + ``` + +# Documentation + +- Documentation is a community effort +- When you need to look into the source code to figure out how to use an API consider adding a short summary to the header file for future reference +- When you notice incorrect or outdated documentation, please update it + +# Resources + +The Github issues, PRs and discussions contain a lot of information that can be useful to get familiar with the codebase. For convenience, some of the more important information is referenced from Github projects: + +https://github.com/ggerganov/llama.cpp/projects diff --git a/Makefile b/Makefile index 6053bc17b..ef152d246 100644 --- a/Makefile +++ b/Makefile @@ -1,11 +1,13 @@ +ifndef LLAMA_MAKEFILE +$(error The Makefile build is deprecated. Use the CMake build instead. For more details, see https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) +endif + # Define the default target now so that it is always the first target BUILD_TARGETS = \ libllava.a \ - llama-baby-llama \ llama-batched \ llama-batched-bench \ llama-bench \ - llama-benchmark-matmult \ llama-cli \ llama-convert-llama2c-to-ggml \ llama-embedding \ @@ -20,6 +22,7 @@ BUILD_TARGETS = \ llama-infill \ llama-llava-cli \ llama-minicpmv-cli\ + llama-qwen2vl-cli\ llama-lookahead \ llama-lookup \ llama-lookup-create \ @@ -35,6 +38,8 @@ BUILD_TARGETS = \ llama-save-load-state \ llama-server \ llama-simple \ + llama-simple-chat \ + llama-run \ llama-speculative \ llama-tokenize \ llama-vdot \ @@ -47,15 +52,15 @@ TEST_TARGETS = \ tests/test-arg-parser \ tests/test-autorelease \ tests/test-backend-ops \ + tests/test-chat \ tests/test-chat-template \ tests/test-double-float \ - tests/test-grad0 \ tests/test-grammar-integration \ tests/test-grammar-parser \ tests/test-json-schema-to-grammar \ tests/test-llama-grammar \ + tests/test-log \ tests/test-model-load-cancel \ - tests/test-opt \ tests/test-quantize-fns \ tests/test-quantize-perf \ tests/test-rope \ @@ -63,11 +68,12 @@ TEST_TARGETS = \ tests/test-tokenizer-0 \ tests/test-tokenizer-1-bpe \ tests/test-tokenizer-1-spm +# tests/test-opt \ # Legacy build targets that were renamed in #7809, but should still be removed when the project is cleaned LEGACY_TARGETS_CLEAN = main quantize quantize-stats perplexity imatrix embedding vdot q8dot convert-llama2c-to-ggml \ simple batched batched-bench save-load-state server gguf gguf-split eval-callback llama-bench libllava.a llava-cli baby-llama \ - retrieval speculative infill tokenize benchmark-matmult parallel export-lora lookahead lookup passkey gritlm + retrieval speculative infill tokenize parallel export-lora lookahead lookup passkey gritlm # Legacy build targets that were renamed in #7809, but we want to build binaries that for them that output a deprecation warning if people try to use them. # We don't want to clutter things too much, so we only build replacements for the most commonly used binaries. @@ -93,11 +99,6 @@ GGML_METAL := 1 DEPRECATE_WARNING := 1 endif -ifdef LLAMA_OPENMP -GGML_OPENMP := 1 -DEPRECATE_WARNING := 1 -endif - ifdef LLAMA_RPC GGML_RPC := 1 DEPRECATE_WARNING := 1 @@ -148,6 +149,14 @@ GGML_NO_METAL := 1 DEPRECATE_WARNING := 1 endif +ifdef LLAMA_DISABLE_LOGS +REMOVE_WARNING := 1 +endif + +ifdef LLAMA_SERVER_VERBOSE +REMOVE_WARNING := 1 +endif + ifndef UNAME_S UNAME_S := $(shell uname -s) endif @@ -248,11 +257,11 @@ endif # Compile flags # -# keep standard at C11 and C++11 -MK_CPPFLAGS = -Iggml/include -Iggml/src -Iinclude -Isrc -Icommon +# keep standard at C11 and C++17 +MK_CPPFLAGS = -Iggml/include -Iggml/src -Iinclude -Isrc -Icommon -DGGML_USE_CPU MK_CFLAGS = -std=c11 -fPIC -MK_CXXFLAGS = -std=c++11 -fPIC -MK_NVCCFLAGS = -std=c++11 +MK_CXXFLAGS = -std=c++17 -fPIC +MK_NVCCFLAGS = -std=c++17 ifdef LLAMA_NO_CCACHE GGML_NO_CCACHE := 1 @@ -288,6 +297,7 @@ endif # some memory allocation are available on Linux through GNU extensions in libc ifeq ($(UNAME_S),Linux) MK_CPPFLAGS += -D_GNU_SOURCE + MK_LDFLAGS += -ldl endif # RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1, @@ -351,18 +361,14 @@ ifdef LLAMA_SANITIZE_UNDEFINED MK_LDFLAGS += -fsanitize=undefined -g endif -ifdef LLAMA_SERVER_VERBOSE - MK_CPPFLAGS += -DSERVER_VERBOSE=$(LLAMA_SERVER_VERBOSE) -endif - ifdef LLAMA_SERVER_SSL MK_CPPFLAGS += -DCPPHTTPLIB_OPENSSL_SUPPORT MK_LDFLAGS += -lssl -lcrypto endif -ifdef LLAMA_DISABLE_LOGS - MK_CPPFLAGS += -DLOG_DISABLE_LOGS -endif # LLAMA_DISABLE_LOGS +ifndef GGML_NO_CPU_AARCH64 + MK_CPPFLAGS += -DGGML_USE_CPU_AARCH64 +endif # warnings WARN_FLAGS = \ @@ -434,13 +440,17 @@ endif # TODO: probably these flags need to be tweaked on some architectures # feel free to update the Makefile for your architecture and send a pull request or issue -ifndef RISCV +ifndef RISCV_CROSS_COMPILE ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686 amd64)) # Use all CPU extensions that are available: MK_CFLAGS += -march=native -mtune=native HOST_CXXFLAGS += -march=native -mtune=native + # Usage AMX build test + #MK_CFLAGS += -march=graniterapids -mtune=graniterapids + #HOST_CXXFLAGS += -march=graniterapids -mtune=graniterapids + # Usage AVX-only #MK_CFLAGS += -mfma -mf16c -mavx #MK_CXXFLAGS += -mfma -mf16c -mavx @@ -514,7 +524,12 @@ ifneq ($(filter loongarch64%,$(UNAME_M)),) MK_CXXFLAGS += -mlasx endif -else +ifneq ($(filter riscv64%,$(UNAME_M)),) + MK_CFLAGS += -march=rv64gcv -mabi=lp64d + MK_CXXFLAGS += -march=rv64gcv -mabi=lp64d +endif + +else # RISC-V CROSS COMPILATION MK_CFLAGS += -march=rv64gcv -mabi=lp64d MK_CXXFLAGS += -march=rv64gcv -mabi=lp64d endif @@ -523,65 +538,62 @@ ifndef GGML_NO_ACCELERATE # Mac OS - include Accelerate framework. # `-framework Accelerate` works both with Apple Silicon and Mac Intel ifeq ($(UNAME_S),Darwin) - MK_CPPFLAGS += -DGGML_USE_ACCELERATE -DGGML_USE_BLAS - MK_CPPFLAGS += -DACCELERATE_NEW_LAPACK - MK_CPPFLAGS += -DACCELERATE_LAPACK_ILP64 - MK_LDFLAGS += -framework Accelerate - OBJ_GGML += ggml/src/ggml-blas.o + MK_CPPFLAGS += -DGGML_USE_ACCELERATE -DGGML_USE_BLAS -DGGML_BLAS_USE_ACCELERATE + MK_CPPFLAGS += -DACCELERATE_NEW_LAPACK + MK_CPPFLAGS += -DACCELERATE_LAPACK_ILP64 + MK_LDFLAGS += -framework Accelerate + OBJ_GGML_EXT += ggml/src/ggml-blas/ggml-blas.o endif endif # GGML_NO_ACCELERATE -ifdef GGML_MUSA - CC := clang - CXX := clang++ - GGML_CUDA := 1 - MK_CPPFLAGS += -DGGML_USE_MUSA -endif - ifndef GGML_NO_OPENMP MK_CPPFLAGS += -DGGML_USE_OPENMP MK_CFLAGS += -fopenmp MK_CXXFLAGS += -fopenmp - ifdef GGML_MUSA - MK_CPPFLAGS += -I/usr/lib/llvm-10/include/openmp - MK_LDFLAGS += -L/usr/lib/llvm-10/lib - endif # GGML_MUSA endif # GGML_NO_OPENMP ifdef GGML_OPENBLAS - MK_CPPFLAGS += -DGGML_USE_BLAS $(shell pkg-config --cflags-only-I openblas) - MK_CFLAGS += $(shell pkg-config --cflags-only-other openblas) - MK_LDFLAGS += $(shell pkg-config --libs openblas) - OBJ_GGML += ggml/src/ggml-blas.o + MK_CPPFLAGS += -DGGML_USE_BLAS $(shell pkg-config --cflags-only-I openblas) + MK_CFLAGS += $(shell pkg-config --cflags-only-other openblas) + MK_LDFLAGS += $(shell pkg-config --libs openblas) + OBJ_GGML_EXT += ggml/src/ggml-blas/ggml-blas.o endif # GGML_OPENBLAS ifdef GGML_OPENBLAS64 - MK_CPPFLAGS += -DGGML_USE_BLAS $(shell pkg-config --cflags-only-I openblas64) - MK_CFLAGS += $(shell pkg-config --cflags-only-other openblas64) - MK_LDFLAGS += $(shell pkg-config --libs openblas64) - OBJ_GGML += ggml/src/ggml-blas.o + MK_CPPFLAGS += -DGGML_USE_BLAS $(shell pkg-config --cflags-only-I openblas64) + MK_CFLAGS += $(shell pkg-config --cflags-only-other openblas64) + MK_LDFLAGS += $(shell pkg-config --libs openblas64) + OBJ_GGML_EXT += ggml/src/ggml-blas/ggml-blas.o endif # GGML_OPENBLAS64 ifdef GGML_BLIS - MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_BLIS -I/usr/local/include/blis -I/usr/include/blis - MK_LDFLAGS += -lblis -L/usr/local/lib - OBJ_GGML += ggml/src/ggml-blas.o + MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_BLIS -I/usr/local/include/blis -I/usr/include/blis + MK_LDFLAGS += -lblis -L/usr/local/lib + OBJ_GGML_EXT += ggml/src/ggml-blas/ggml-blas.o endif # GGML_BLIS ifdef GGML_NVPL - MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_NVPL -DNVPL_ILP64 -I/usr/local/include/nvpl_blas -I/usr/include/nvpl_blas - MK_LDFLAGS += -L/usr/local/lib -lnvpl_blas_core -lnvpl_blas_ilp64_gomp - OBJ_GGML += ggml/src/ggml-blas.o + MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_NVPL -DNVPL_ILP64 -I/usr/local/include/nvpl_blas -I/usr/include/nvpl_blas + MK_LDFLAGS += -L/usr/local/lib -lnvpl_blas_core -lnvpl_blas_ilp64_gomp + OBJ_GGML_EXT += ggml/src/ggml-blas/ggml-blas.o endif # GGML_NVPL ifndef GGML_NO_LLAMAFILE - MK_CPPFLAGS += -DGGML_USE_LLAMAFILE - OBJ_GGML += ggml/src/llamafile/sgemm.o + MK_CPPFLAGS += -DGGML_USE_LLAMAFILE + OBJ_GGML_EXT += ggml/src/ggml-cpu/llamafile/sgemm.o endif +ifndef GGML_NO_AMX + MK_CPPFLAGS += -DGGML_USE_AMX + OBJ_GGML_EXT += ggml/src/ggml-cpu/amx/amx.o ggml/src/ggml-cpu/amx/mmq.o +endif + +# only necessary for the CPU backend files +MK_CPPFLAGS += -Iggml/src/ggml-cpu + ifdef GGML_RPC - MK_CPPFLAGS += -DGGML_USE_RPC - OBJ_GGML += ggml/src/ggml-rpc.o + MK_CPPFLAGS += -DGGML_USE_RPC + OBJ_GGML_EXT += ggml/src/ggml-rpc.o endif # GGML_RPC OBJ_CUDA_TMPL = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-wmma*.cu)) @@ -596,41 +608,27 @@ else endif # GGML_CUDA_FA_ALL_QUANTS ifdef GGML_CUDA - ifdef GGML_MUSA - ifneq ('', '$(wildcard /opt/musa)') - CUDA_PATH ?= /opt/musa - else - CUDA_PATH ?= /usr/local/musa - endif - - MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include - MK_LDFLAGS += -lmusa -lmublas -lmusart -lpthread -ldl -lrt -L$(CUDA_PATH)/lib -L/usr/lib64 - MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_22 + ifneq ('', '$(wildcard /opt/cuda)') + CUDA_PATH ?= /opt/cuda else - ifneq ('', '$(wildcard /opt/cuda)') - CUDA_PATH ?= /opt/cuda - else - CUDA_PATH ?= /usr/local/cuda - endif + CUDA_PATH ?= /usr/local/cuda + endif - MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS - MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/lib/wsl/lib - MK_NVCCFLAGS += -use_fast_math - endif # GGML_MUSA + MK_CPPFLAGS += -DGGML_USE_CUDA -DGGML_CUDA_USE_GRAPHS -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include + MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/lib/wsl/lib + MK_NVCCFLAGS += -use_fast_math - OBJ_GGML += ggml/src/ggml-cuda.o - OBJ_GGML += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu)) - OBJ_GGML += $(OBJ_CUDA_TMPL) + OBJ_GGML_EXT += ggml/src/ggml-cuda/ggml-cuda.o + OBJ_GGML_EXT += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu)) + OBJ_GGML_EXT += $(OBJ_CUDA_TMPL) ifdef LLAMA_FATAL_WARNINGS MK_NVCCFLAGS += -Werror all-warnings endif # LLAMA_FATAL_WARNINGS -ifndef GGML_MUSA ifndef JETSON_EOL_MODULE_DETECT MK_NVCCFLAGS += --forward-unknown-to-host-compiler endif # JETSON_EOL_MODULE_DETECT -endif # GGML_MUSA ifdef LLAMA_DEBUG MK_NVCCFLAGS += -lineinfo @@ -643,11 +641,7 @@ endif # GGML_CUDA_DEBUG ifdef GGML_CUDA_NVCC NVCC = $(CCACHE) $(GGML_CUDA_NVCC) else - ifdef GGML_MUSA - NVCC = $(CCACHE) mcc - else - NVCC = $(CCACHE) nvcc - endif # GGML_MUSA + NVCC = $(CCACHE) nvcc endif # GGML_CUDA_NVCC ifdef CUDA_DOCKER_ARCH @@ -656,10 +650,6 @@ else ifndef CUDA_POWER_ARCH MK_NVCCFLAGS += -arch=native endif # CUDA_DOCKER_ARCH -ifdef GGML_CUDA_FORCE_DMMV - MK_NVCCFLAGS += -DGGML_CUDA_FORCE_DMMV -endif # GGML_CUDA_FORCE_DMMV - ifdef GGML_CUDA_FORCE_MMQ MK_NVCCFLAGS += -DGGML_CUDA_FORCE_MMQ endif # GGML_CUDA_FORCE_MMQ @@ -668,20 +658,6 @@ ifdef GGML_CUDA_FORCE_CUBLAS MK_NVCCFLAGS += -DGGML_CUDA_FORCE_CUBLAS endif # GGML_CUDA_FORCE_CUBLAS -ifdef GGML_CUDA_DMMV_X - MK_NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(GGML_CUDA_DMMV_X) -else - MK_NVCCFLAGS += -DGGML_CUDA_DMMV_X=32 -endif # GGML_CUDA_DMMV_X - -ifdef GGML_CUDA_MMV_Y - MK_NVCCFLAGS += -DGGML_CUDA_MMV_Y=$(GGML_CUDA_MMV_Y) -else ifdef GGML_CUDA_DMMV_Y - MK_NVCCFLAGS += -DGGML_CUDA_MMV_Y=$(GGML_CUDA_DMMV_Y) # for backwards compatibility -else - MK_NVCCFLAGS += -DGGML_CUDA_MMV_Y=1 -endif # GGML_CUDA_MMV_Y - ifdef GGML_CUDA_F16 MK_NVCCFLAGS += -DGGML_CUDA_F16 endif # GGML_CUDA_F16 @@ -690,12 +666,6 @@ ifdef GGML_CUDA_DMMV_F16 MK_NVCCFLAGS += -DGGML_CUDA_F16 endif # GGML_CUDA_DMMV_F16 -ifdef GGML_CUDA_KQUANTS_ITER - MK_NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(GGML_CUDA_KQUANTS_ITER) -else - MK_NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2 -endif - ifdef GGML_CUDA_PEER_MAX_BATCH_SIZE MK_NVCCFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=$(GGML_CUDA_PEER_MAX_BATCH_SIZE) else @@ -719,15 +689,9 @@ define NVCC_COMPILE $(NVCC) -I. -Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I/usr/local/cuda/targets/aarch64-linux/include -std=c++11 -O3 $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@ endef # NVCC_COMPILE else - ifdef GGML_MUSA -define NVCC_COMPILE - $(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -c $< -o $@ -endef # NVCC_COMPILE - else define NVCC_COMPILE $(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@ endef # NVCC_COMPILE - endif # GGML_MUSA endif # JETSON_EOL_MODULE_DETECT ggml/src/ggml-cuda/%.o: \ @@ -737,8 +701,8 @@ ggml/src/ggml-cuda/%.o: \ ggml/src/ggml-cuda/common.cuh $(NVCC_COMPILE) -ggml/src/ggml-cuda.o: \ - ggml/src/ggml-cuda.cu \ +ggml/src/ggml-cuda/ggml-cuda.o: \ + ggml/src/ggml-cuda/ggml-cuda.cu \ ggml/include/ggml-cuda.h \ ggml/include/ggml.h \ ggml/include/ggml-backend.h \ @@ -749,9 +713,9 @@ ggml/src/ggml-cuda.o: \ endif # GGML_CUDA ifdef GGML_VULKAN - MK_CPPFLAGS += -DGGML_USE_VULKAN - MK_LDFLAGS += $(shell pkg-config --libs vulkan) - OBJ_GGML += ggml/src/ggml-vulkan.o ggml/src/ggml-vulkan-shaders.o + MK_CPPFLAGS += -DGGML_USE_VULKAN + MK_LDFLAGS += $(shell pkg-config --libs vulkan) + OBJ_GGML_EXT += ggml/src/ggml-vulkan.o ggml/src/ggml-vulkan-shaders.o ifdef GGML_VULKAN_CHECK_RESULTS MK_CPPFLAGS += -DGGML_VULKAN_CHECK_RESULTS @@ -781,10 +745,10 @@ GLSLC_CMD = glslc _ggml_vk_genshaders_cmd = $(shell pwd)/vulkan-shaders-gen _ggml_vk_header = ggml/src/ggml-vulkan-shaders.hpp _ggml_vk_source = ggml/src/ggml-vulkan-shaders.cpp -_ggml_vk_input_dir = ggml/src/vulkan-shaders +_ggml_vk_input_dir = ggml/src/ggml-vulkan/vulkan-shaders _ggml_vk_shader_deps = $(echo $(_ggml_vk_input_dir)/*.comp) -ggml/src/ggml-vulkan.o: ggml/src/ggml-vulkan.cpp ggml/include/ggml-vulkan.h $(_ggml_vk_header) $(_ggml_vk_source) +ggml/src/ggml-vulkan.o: ggml/src/ggml-vulkan/ggml-vulkan.cpp ggml/include/ggml-vulkan.h $(_ggml_vk_header) $(_ggml_vk_source) $(CXX) $(CXXFLAGS) $(shell pkg-config --cflags vulkan) -c $< -o $@ $(_ggml_vk_header): $(_ggml_vk_source) @@ -796,12 +760,12 @@ $(_ggml_vk_source): $(_ggml_vk_shader_deps) vulkan-shaders-gen --target-hpp $(_ggml_vk_header) \ --target-cpp $(_ggml_vk_source) -vulkan-shaders-gen: ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp - $(CXX) $(CXXFLAGS) -o $@ $(LDFLAGS) ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +vulkan-shaders-gen: ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp + $(CXX) $(CXXFLAGS) -o $@ $(LDFLAGS) ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp endif # GGML_VULKAN -ifdef GGML_HIPBLAS +ifdef GGML_HIP ifeq ($(wildcard /opt/rocm),) ROCM_PATH ?= /usr AMDGPU_TARGETS ?= $(shell $(shell which amdgpu-arch)) @@ -810,11 +774,7 @@ ifdef GGML_HIPBLAS AMDGPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch) endif - GGML_CUDA_DMMV_X ?= 32 - GGML_CUDA_MMV_Y ?= 1 - GGML_CUDA_KQUANTS_ITER ?= 2 - - MK_CPPFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUDA + MK_CPPFLAGS += -DGGML_USE_HIP -DGGML_USE_CUDA ifdef GGML_HIP_UMA MK_CPPFLAGS += -DGGML_HIP_UMA @@ -827,13 +787,6 @@ endif # GGML_HIP_UMA HIPCC ?= $(CCACHE) $(ROCM_PATH)/bin/hipcc HIPFLAGS += $(addprefix --offload-arch=,$(AMDGPU_TARGETS)) - HIPFLAGS += -DGGML_CUDA_DMMV_X=$(GGML_CUDA_DMMV_X) - HIPFLAGS += -DGGML_CUDA_MMV_Y=$(GGML_CUDA_MMV_Y) - HIPFLAGS += -DK_QUANTS_PER_ITERATION=$(GGML_CUDA_KQUANTS_ITER) - -ifdef GGML_CUDA_FORCE_DMMV - HIPFLAGS += -DGGML_CUDA_FORCE_DMMV -endif # GGML_CUDA_FORCE_DMMV ifdef GGML_CUDA_FORCE_MMQ HIPFLAGS += -DGGML_CUDA_FORCE_MMQ @@ -847,12 +800,12 @@ ifdef GGML_CUDA_NO_PEER_COPY HIPFLAGS += -DGGML_CUDA_NO_PEER_COPY endif # GGML_CUDA_NO_PEER_COPY - OBJ_GGML += ggml/src/ggml-cuda.o - OBJ_GGML += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu)) - OBJ_GGML += $(OBJ_CUDA_TMPL) + OBJ_GGML_EXT += ggml/src/ggml-cuda/ggml-cuda.o + OBJ_GGML_EXT += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu)) + OBJ_GGML_EXT += $(OBJ_CUDA_TMPL) -ggml/src/ggml-cuda.o: \ - ggml/src/ggml-cuda.cu \ +ggml/src/ggml-cuda/ggml-cuda.o: \ + ggml/src/ggml-cuda/ggml-cuda.cu \ ggml/include/ggml-cuda.h \ ggml/include/ggml.h \ ggml/include/ggml-backend.h \ @@ -867,70 +820,173 @@ ggml/src/ggml-cuda/%.o: \ ggml/src/ggml-common.h \ ggml/src/ggml-cuda/common.cuh $(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $< -endif # GGML_HIPBLAS +endif # GGML_HIP + +ifdef GGML_MUSA + ifeq ($(wildcard /opt/musa),) + MUSA_PATH ?= /usr/local/musa + else + MUSA_PATH ?= /opt/musa + endif + MUSA_ARCHITECTURES ?= 21;22 + + MK_CPPFLAGS += -DGGML_USE_MUSA -DGGML_USE_CUDA + MK_LDFLAGS += -L$(MUSA_PATH)/lib -Wl,-rpath=$(MUSA_PATH)/lib + MK_LDFLAGS += -lmusa -lmusart -lmublas + + ifndef GGML_NO_OPENMP + # For Ubuntu Focal + MK_CPPFLAGS += -I/usr/lib/llvm-10/include/openmp + MK_LDFLAGS += -L/usr/lib/llvm-10/lib + # For Ubuntu Jammy + MK_CPPFLAGS += -I/usr/lib/llvm-14/lib/clang/14.0.0/include + MK_LDFLAGS += -L/usr/lib/llvm-14/lib + endif # GGML_NO_OPENMP + + CC := $(MUSA_PATH)/bin/clang + CXX := $(MUSA_PATH)/bin/clang++ + MCC := $(CCACHE) $(MUSA_PATH)/bin/mcc + + MUSAFLAGS = -x musa -mtgpu + MUSAFLAGS += $(foreach arch,$(subst ;, ,$(MUSA_ARCHITECTURES)),--cuda-gpu-arch=mp_$(arch)) + +ifdef GGML_CUDA_FORCE_MMQ + MUSAFLAGS += -DGGML_CUDA_FORCE_MMQ +endif # GGML_CUDA_FORCE_MMQ + +ifdef GGML_CUDA_FORCE_CUBLAS + MUSAFLAGS += -DGGML_CUDA_FORCE_CUBLAS +endif # GGML_CUDA_FORCE_CUBLAS + +ifdef GGML_CUDA_F16 + MUSAFLAGS += -DGGML_CUDA_F16 +endif # GGML_CUDA_F16 + +ifdef GGML_CUDA_DMMV_F16 + MUSAFLAGS += -DGGML_CUDA_F16 +endif # GGML_CUDA_DMMV_F16 + +ifdef GGML_CUDA_PEER_MAX_BATCH_SIZE + MUSAFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=$(GGML_CUDA_PEER_MAX_BATCH_SIZE) +else + MUSAFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 +endif # GGML_CUDA_PEER_MAX_BATCH_SIZE + +ifdef GGML_CUDA_NO_PEER_COPY + MUSAFLAGS += -DGGML_CUDA_NO_PEER_COPY +endif # GGML_CUDA_NO_PEER_COPY + +ifdef GGML_CUDA_FA_ALL_QUANTS + MUSAFLAGS += -DGGML_CUDA_FA_ALL_QUANTS +endif # GGML_CUDA_FA_ALL_QUANTS + + OBJ_GGML_EXT += ggml/src/ggml-cuda/ggml-cuda.o + OBJ_GGML_EXT += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu)) + OBJ_GGML_EXT += $(OBJ_CUDA_TMPL) + +ggml/src/ggml-cuda/ggml-cuda.o: \ + ggml/src/ggml-cuda/ggml-cuda.cu \ + ggml/include/ggml-cuda.h \ + ggml/include/ggml.h \ + ggml/include/ggml-backend.h \ + ggml/src/ggml-backend-impl.h \ + ggml/src/ggml-common.h \ + $(wildcard ggml/src/ggml-cuda/*.cuh) + $(MCC) $(CXXFLAGS) $(MUSAFLAGS) -c -o $@ $< + +ggml/src/ggml-cuda/%.o: \ + ggml/src/ggml-cuda/%.cu \ + ggml/include/ggml.h \ + ggml/src/ggml-common.h \ + ggml/src/ggml-cuda/common.cuh + $(MCC) $(CXXFLAGS) $(MUSAFLAGS) -c -o $@ $< +endif # GGML_MUSA ifdef GGML_METAL - MK_CPPFLAGS += -DGGML_USE_METAL - MK_LDFLAGS += -framework Foundation -framework Metal -framework MetalKit - OBJ_GGML += ggml/src/ggml-metal.o + MK_CPPFLAGS += -DGGML_USE_METAL + MK_LDFLAGS += -framework Foundation -framework Metal -framework MetalKit + OBJ_GGML_EXT += ggml/src/ggml-metal/ggml-metal.o + +ifdef GGML_METAL_USE_BF16 + MK_CPPFLAGS += -DGGML_METAL_USE_BF16 +endif # GGML_METAL_USE_BF16 ifdef GGML_METAL_NDEBUG MK_CPPFLAGS += -DGGML_METAL_NDEBUG endif ifdef GGML_METAL_EMBED_LIBRARY - MK_CPPFLAGS += -DGGML_METAL_EMBED_LIBRARY - OBJ_GGML += ggml/src/ggml-metal-embed.o + MK_CPPFLAGS += -DGGML_METAL_EMBED_LIBRARY + OBJ_GGML_EXT += ggml/src/ggml-metal-embed.o endif endif # GGML_METAL ifdef GGML_METAL -ggml/src/ggml-metal.o: \ - ggml/src/ggml-metal.m \ +ggml/src/ggml-metal/ggml-metal.o: \ + ggml/src/ggml-metal/ggml-metal.m \ + ggml/src/ggml-metal/ggml-metal-impl.h \ ggml/include/ggml-metal.h \ ggml/include/ggml.h $(CC) $(CFLAGS) -c $< -o $@ ifdef GGML_METAL_EMBED_LIBRARY ggml/src/ggml-metal-embed.o: \ - ggml/src/ggml-metal.metal \ + ggml/src/ggml-metal/ggml-metal.metal \ + ggml/src/ggml-metal/ggml-metal-impl.h \ ggml/src/ggml-common.h @echo "Embedding Metal library" - @sed -e '/#include "ggml-common.h"/r ggml/src/ggml-common.h' -e '/#include "ggml-common.h"/d' < ggml/src/ggml-metal.metal > ggml/src/ggml-metal-embed.metal + @sed -e '/__embed_ggml-common.h__/r ggml/src/ggml-common.h' -e '/__embed_ggml-common.h__/d' < ggml/src/ggml-metal/ggml-metal.metal > ggml/src/ggml-metal/ggml-metal-embed.metal.tmp + @sed -e '/#include "ggml-metal-impl.h"/r ggml/src/ggml-metal/ggml-metal-impl.h' -e '/#include "ggml-metal-impl.h"/d' < ggml/src/ggml-metal/ggml-metal-embed.metal.tmp > ggml/src/ggml-metal/ggml-metal-embed.metal $(eval TEMP_ASSEMBLY=$(shell mktemp -d)) - @echo ".section __DATA, __ggml_metallib" > $(TEMP_ASSEMBLY)/ggml-metal-embed.s - @echo ".globl _ggml_metallib_start" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s - @echo "_ggml_metallib_start:" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s - @echo ".incbin \"ggml/src/ggml-metal-embed.metal\"" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s - @echo ".globl _ggml_metallib_end" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s - @echo "_ggml_metallib_end:" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s + @echo ".section __DATA, __ggml_metallib" > $(TEMP_ASSEMBLY)/ggml-metal-embed.s + @echo ".globl _ggml_metallib_start" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s + @echo "_ggml_metallib_start:" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s + @echo ".incbin \"ggml/src/ggml-metal/ggml-metal-embed.metal\"" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s + @echo ".globl _ggml_metallib_end" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s + @echo "_ggml_metallib_end:" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s $(CC) $(CFLAGS) -c $(TEMP_ASSEMBLY)/ggml-metal-embed.s -o $@ @rm -f ${TEMP_ASSEMBLY}/ggml-metal-embed.s @rmdir ${TEMP_ASSEMBLY} endif endif # GGML_METAL -OBJ_GGML += \ - ggml/src/ggml.o \ - ggml/src/ggml-alloc.o \ - ggml/src/ggml-backend.o \ - ggml/src/ggml-quants.o \ - ggml/src/ggml-aarch64.o +DIR_GGML = ggml +DIR_LLAMA = src +DIR_COMMON = common + +OBJ_GGML = \ + $(DIR_GGML)/src/ggml.o \ + $(DIR_GGML)/src/ggml-alloc.o \ + $(DIR_GGML)/src/ggml-backend.o \ + $(DIR_GGML)/src/ggml-backend-reg.o \ + $(DIR_GGML)/src/ggml-opt.o \ + $(DIR_GGML)/src/ggml-quants.o \ + $(DIR_GGML)/src/ggml-threading.o \ + $(DIR_GGML)/src/ggml-cpu/ggml-cpu.o \ + $(DIR_GGML)/src/ggml-cpu/ggml-cpu_cpp.o \ + $(DIR_GGML)/src/ggml-cpu/ggml-cpu-aarch64.o \ + $(DIR_GGML)/src/ggml-cpu/ggml-cpu-hbm.o \ + $(DIR_GGML)/src/ggml-cpu/ggml-cpu-quants.o \ + $(DIR_GGML)/src/ggml-cpu/ggml-cpu-traits.o \ + $(OBJ_GGML_EXT) OBJ_LLAMA = \ - src/llama.o \ - src/llama-vocab.o \ - src/llama-grammar.o \ - src/llama-sampling.o \ - src/unicode.o \ - src/unicode-data.o + $(DIR_LLAMA)/llama.o \ + $(DIR_LLAMA)/llama-vocab.o \ + $(DIR_LLAMA)/llama-grammar.o \ + $(DIR_LLAMA)/llama-sampling.o \ + $(DIR_LLAMA)/unicode.o \ + $(DIR_LLAMA)/unicode-data.o OBJ_COMMON = \ - common/common.o \ - common/console.o \ - common/ngram-cache.o \ - common/sampling.o \ - common/train.o \ - common/build-info.o \ - common/json-schema-to-grammar.o + $(DIR_COMMON)/common.o \ + $(DIR_COMMON)/arg.o \ + $(DIR_COMMON)/log.o \ + $(DIR_COMMON)/console.o \ + $(DIR_COMMON)/ngram-cache.o \ + $(DIR_COMMON)/sampling.o \ + $(DIR_COMMON)/speculative.o \ + $(DIR_COMMON)/chat.o \ + $(DIR_COMMON)/build-info.o \ + $(DIR_COMMON)/json-schema-to-grammar.o OBJ_ALL = $(OBJ_GGML) $(OBJ_LLAMA) $(OBJ_COMMON) @@ -986,7 +1042,6 @@ $(info I CXX: $(shell $(CXX) --version | head -n 1)) ifdef GGML_CUDA $(info I NVCC: $(shell $(NVCC) --version | tail -n 1)) CUDA_VERSION := $(shell $(NVCC) --version | grep -oP 'release (\K[0-9]+\.[0-9])') -ifndef GGML_MUSA ifeq ($(shell awk -v "v=$(CUDA_VERSION)" 'BEGIN { print (v < 11.7) }'),1) ifndef CUDA_DOCKER_ARCH @@ -996,7 +1051,6 @@ endif # CUDA_POWER_ARCH endif # CUDA_DOCKER_ARCH endif # eq ($(shell echo "$(CUDA_VERSION) < 11.7" | bc),1) -endif # GGML_MUSA endif # GGML_CUDA $(info ) @@ -1021,202 +1075,90 @@ $(info - LLAMA_NO_CCACHE) $(info ) endif +ifdef REMOVE_WARNING +$(info !!! REMOVAL WARNING !!!) +$(info The following LLAMA_ options have been removed and are no longer supported) +$(info - LLAMA_DISABLE_LOGS (https://github.com/ggerganov/llama.cpp/pull/9418)) +$(info - LLAMA_SERVER_VERBOSE (https://github.com/ggerganov/llama.cpp/pull/9418)) +$(info ) +endif + # # Build libraries # -# ggml +# Libraries +LIB_GGML = libggml.so +LIB_GGML_S = libggml.a -ggml/src/ggml.o: \ - ggml/src/ggml.c \ - ggml/include/ggml.h - $(CC) $(CFLAGS) -c $< -o $@ +LIB_LLAMA = libllama.so +LIB_LLAMA_S = libllama.a -ggml/src/ggml-alloc.o: \ - ggml/src/ggml-alloc.c \ - ggml/include/ggml.h \ - ggml/include/ggml-alloc.h - $(CC) $(CFLAGS) -c $< -o $@ +LIB_COMMON = libcommon.so +LIB_COMMON_S = libcommon.a -ggml/src/ggml-backend.o: \ - ggml/src/ggml-backend.c \ - ggml/include/ggml.h \ - ggml/include/ggml-backend.h - $(CC) $(CFLAGS) -c $< -o $@ +# Targets +BUILD_TARGETS += $(LIB_GGML) $(LIB_GGML_S) $(LIB_LLAMA) $(LIB_LLAMA_S) $(LIB_COMMON) $(LIB_COMMON_S) -ggml/src/ggml-quants.o: \ - ggml/src/ggml-quants.c \ - ggml/include/ggml.h \ - ggml/src/ggml-quants.h \ - ggml/src/ggml-common.h - $(CC) $(CFLAGS) -c $< -o $@ +# Dependency files +DEP_FILES = $(OBJ_GGML:.o=.d) $(OBJ_LLAMA:.o=.d) $(OBJ_COMMON:.o=.d) -ggml/src/ggml-aarch64.o: \ - ggml/src/ggml-aarch64.c \ - ggml/include/ggml.h \ - ggml/src/ggml-aarch64.h \ - ggml/src/ggml-common.h - $(CC) $(CFLAGS) -c $< -o $@ +# Default target +all: $(BUILD_TARGETS) -ggml/src/ggml-blas.o: \ - ggml/src/ggml-blas.cpp \ - ggml/include/ggml-blas.h - $(CXX) $(CXXFLAGS) -c $< -o $@ +# force c++ build for source file that have same name as c file +# Note: need this exception because `ggml-cpu.c` and `ggml-cpu.cpp` both produce the same obj/dep files +$(DIR_GGML)/%_cpp.o: $(DIR_GGML)/%.cpp + $(CXX) $(CXXFLAGS) -MMD -c $< -o $@ -ifndef GGML_NO_LLAMAFILE -ggml/src/llamafile/sgemm.o: \ - ggml/src/llamafile/sgemm.cpp \ - ggml/src/llamafile/sgemm.h \ - ggml/include/ggml.h - $(CXX) $(CXXFLAGS) -c $< -o $@ -endif # GGML_NO_LLAMAFILE +# Rules for building object files +$(DIR_GGML)/%.o: $(DIR_GGML)/%.c + $(CC) $(CFLAGS) -MMD -c $< -o $@ -ifdef GGML_RPC -ggml/src/ggml-rpc.o: \ - ggml/src/ggml-rpc.cpp \ - ggml/include/ggml-rpc.h - $(CXX) $(CXXFLAGS) -c $< -o $@ -endif # GGML_RPC +$(DIR_GGML)/%.o: $(DIR_GGML)/%.cpp + $(CXX) $(CXXFLAGS) -MMD -c $< -o $@ -$(LIB_GGML): \ - $(OBJ_GGML) +$(DIR_LLAMA)/%.o: $(DIR_LLAMA)/%.cpp + $(CXX) $(CXXFLAGS) -MMD -c $< -o $@ + +$(DIR_COMMON)/%.o: $(DIR_COMMON)/%.cpp + $(CXX) $(CXXFLAGS) -MMD -c $< -o $@ + +# Rules for building libraries +$(LIB_GGML): $(OBJ_GGML) $(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS) -$(LIB_GGML_S): \ - $(OBJ_GGML) +$(LIB_GGML_S): $(OBJ_GGML) ar rcs $(LIB_GGML_S) $^ -# llama - -src/unicode.o: \ - src/unicode.cpp \ - src/unicode.h - $(CXX) $(CXXFLAGS) -c $< -o $@ - -src/unicode-data.o: \ - src/unicode-data.cpp \ - src/unicode-data.h - $(CXX) $(CXXFLAGS) -c $< -o $@ - -src/llama.o: \ - src/llama.cpp \ - src/llama-impl.h \ - src/llama-vocab.h \ - src/llama-grammar.h \ - src/llama-sampling.h \ - src/unicode.h \ - include/llama.h \ - ggml/include/ggml-cuda.h \ - ggml/include/ggml-metal.h \ - ggml/include/ggml.h \ - ggml/include/ggml-alloc.h \ - ggml/include/ggml-backend.h - $(CXX) $(CXXFLAGS) -c $< -o $@ - -src/llama-vocab.o: \ - src/llama-vocab.cpp \ - src/llama-vocab.h \ - src/llama-impl.h \ - include/llama.h - $(CXX) $(CXXFLAGS) -c $< -o $@ - -src/llama-grammar.o: \ - src/llama-grammar.cpp \ - src/llama-grammar.h \ - src/llama-impl.h \ - src/llama-vocab.h \ - src/llama-sampling.h \ - include/llama.h - $(CXX) $(CXXFLAGS) -c $< -o $@ - -src/llama-sampling.o: \ - src/llama-sampling.cpp \ - src/llama-sampling.h \ - src/llama-impl.h \ - include/llama.h - $(CXX) $(CXXFLAGS) -c $< -o $@ - -$(LIB_LLAMA): \ - $(OBJ_LLAMA) \ - $(LIB_GGML) +$(LIB_LLAMA): $(OBJ_LLAMA) $(LIB_GGML) $(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS) -$(LIB_LLAMA_S): \ - $(OBJ_LLAMA) +$(LIB_LLAMA_S): $(OBJ_LLAMA) ar rcs $(LIB_LLAMA_S) $^ -# common - -common/common.o: \ - common/common.cpp \ - common/common.h \ - common/console.h \ - common/sampling.h \ - common/json.hpp \ - common/json-schema-to-grammar.h \ - include/llama.h - $(CXX) $(CXXFLAGS) -c $< -o $@ - -common/sampling.o: \ - common/sampling.cpp \ - common/sampling.h \ - include/llama.h - $(CXX) $(CXXFLAGS) -c $< -o $@ - -common/console.o: \ - common/console.cpp \ - common/console.h - $(CXX) $(CXXFLAGS) -c $< -o $@ - -common/json-schema-to-grammar.o: \ - common/json-schema-to-grammar.cpp \ - common/json-schema-to-grammar.h - $(CXX) $(CXXFLAGS) -c $< -o $@ - -common/train.o: \ - common/train.cpp \ - common/train.h - $(CXX) $(CXXFLAGS) -c $< -o $@ - -common/ngram-cache.o: \ - common/ngram-cache.cpp \ - common/ngram-cache.h - $(CXX) $(CXXFLAGS) -c $< -o $@ - -$(LIB_COMMON): \ - $(OBJ_COMMON) \ - $(LIB_LLAMA) \ - $(LIB_GGML) +$(LIB_COMMON): $(OBJ_COMMON) $(LIB_LLAMA) $(LIB_GGML) $(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS) -$(LIB_COMMON_S): \ - $(OBJ_COMMON) +$(LIB_COMMON_S): $(OBJ_COMMON) ar rcs $(LIB_COMMON_S) $^ -clean: - rm -vrf *.dot $(BUILD_TARGETS) $(TEST_TARGETS) - rm -rvf src/*.o - rm -rvf tests/*.o - rm -rvf examples/*.o - rm -rvf common/*.o - rm -rvf *.a - rm -rvf *.dll - rm -rvf *.so - rm -rvf *.dot - rm -rvf ggml/*.a - rm -rvf ggml/*.dll - rm -rvf ggml/*.so - rm -vrf ggml/src/*.o - rm -rvf ggml/src/llamafile/*.o - rm -rvf common/build-info.cpp - rm -vrf ggml/src/ggml-metal-embed.metal - rm -vrf ggml/src/ggml-cuda/*.o - rm -vrf ggml/src/ggml-cuda/template-instances/*.o - rm -rvf $(BUILD_TARGETS) - rm -rvf $(TEST_TARGETS) - rm -f vulkan-shaders-gen ggml/src/ggml-vulkan-shaders.hpp ggml/src/ggml-vulkan-shaders.cpp - rm -rvf $(LEGACY_TARGETS_CLEAN) - find examples pocs -type f -name "*.o" -delete +# Include dependency files +-include $(DEP_FILES) + +# Clean generated server assets +clean-server-assets: + find examples/server -type f -name "*.js.hpp" -delete + find examples/server -type f -name "*.mjs.hpp" -delete + find examples/server -type f -name "*.css.hpp" -delete + find examples/server -type f -name "*.html.hpp" -delete + +# Clean rule +clean: clean-server-assets + rm -vrf $(BUILD_TARGETS) $(TEST_TARGETS) + rm -rvf *.a *.dll *.so *.dot + find ggml src common tests examples pocs -type f -name "*.o" -delete + find ggml src common tests examples pocs -type f -name "*.d" -delete # # Examples @@ -1242,11 +1184,21 @@ llama-infill: examples/infill/infill.cpp \ $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) +llama-run: examples/run/run.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + llama-simple: examples/simple/simple.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) +llama-simple-chat: examples/simple-chat/simple-chat.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + llama-tokenize: examples/tokenize/tokenize.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) @@ -1335,16 +1287,11 @@ llama-cvector-generator: examples/cvector-generator/cvector-generator.cpp \ $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) llama-convert-llama2c-to-ggml: examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp \ - $(OBJ_GGML) $(OBJ_LLAMA) - $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) - $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) - -llama-bench: examples/llama-bench/llama-bench.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) -llama-baby-llama: examples/baby-llama/baby-llama.cpp \ +llama-bench: examples/llama-bench/llama-bench.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) @@ -1414,29 +1361,19 @@ llama-server: \ examples/server/server.cpp \ examples/server/utils.hpp \ examples/server/httplib.h \ - examples/server/colorthemes.css.hpp \ - examples/server/style.css.hpp \ - examples/server/theme-beeninorder.css.hpp \ - examples/server/theme-ketivah.css.hpp \ - examples/server/theme-mangotango.css.hpp \ - examples/server/theme-playground.css.hpp \ - examples/server/theme-polarnight.css.hpp \ - examples/server/theme-snowstorm.css.hpp \ examples/server/index.html.hpp \ - examples/server/index-new.html.hpp \ - examples/server/index.js.hpp \ - examples/server/completion.js.hpp \ - examples/server/system-prompts.js.hpp \ - examples/server/prompt-formats.js.hpp \ - examples/server/json-schema-to-grammar.mjs.hpp \ + examples/server/loading.html.hpp \ + common/chat.cpp \ + common/chat.hpp \ + common/chat-template.hpp \ common/json.hpp \ - common/stb_image.h \ + common/minja.hpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2) # Portable equivalent of `cd examples/server/public && xxd -i $(notdir $<) ../$(notdir $<).hpp`: -examples/server/%.hpp: examples/server/public/% Makefile +examples/server/%.hpp: examples/server/public/% FORCE Makefile @( export NAME=$(subst .,_,$(subst -,_,$(notdir $<))) && \ echo "unsigned char $${NAME}[] = {" && \ cat $< | od -v -t x1 -An | sed -E 's/([0-9a-fA-F]+)/0x\1, /g' && \ @@ -1448,7 +1385,6 @@ llama-gen-docs: examples/gen-docs/gen-docs.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) - ./llama-gen-docs libllava.a: examples/llava/llava.cpp \ examples/llava/llava.h \ @@ -1475,6 +1411,14 @@ llama-minicpmv-cli: examples/llava/minicpmv-cli.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) $< $(filter-out %.h $<,$^) -o $@ $(LDFLAGS) -Wno-cast-qual +llama-qwen2vl-cli: examples/llava/qwen2vl-cli.cpp \ + examples/llava/llava.cpp \ + examples/llava/llava.h \ + examples/llava/clip.cpp \ + examples/llava/clip.h \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) $< $(filter-out %.h $<,$^) -o $@ $(LDFLAGS) -Wno-cast-qual + ifeq ($(UNAME_S),Darwin) swift: examples/batched.swift (cd examples/batched.swift; make build) @@ -1497,16 +1441,6 @@ common/build-info.o: common/build-info.cpp tests: $(TEST_TARGETS) -llama-benchmark-matmult: examples/benchmark/benchmark-matmult.cpp \ - $(OBJ_GGML) common/build-info.o - $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) - $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) - -run-benchmark-matmult: llama-benchmark-matmult - ./$@ - -.PHONY: run-benchmark-matmult swift - tests/test-arg-parser: tests/test-arg-parser.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) @@ -1517,6 +1451,11 @@ tests/test-llama-grammar: tests/test-llama-grammar.cpp \ $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) +tests/test-log: tests/test-log.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + tests/test-grammar-parser: tests/test-grammar-parser.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) @@ -1536,9 +1475,9 @@ tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \ $(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) -tests/test-grad0: tests/test-grad0.cpp \ - $(OBJ_GGML) - $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) +tests/test-chat: tests/test-chat.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) tests/test-opt: tests/test-opt.cpp \ @@ -1622,7 +1561,7 @@ llama-q8dot: pocs/vdot/q8dot.cpp ggml/src/ggml.o \ # Deprecated binaries that we want to keep around long enough for people to migrate to the new filenames, then these can be removed. # # Mark legacy binary targets as .PHONY so that they are always checked. -.PHONY: main quantize perplexity embedding server +.PHONY: FORCE main quantize perplexity embedding server # Define the object file target examples/deprecation-warning/deprecation-warning.o: examples/deprecation-warning/deprecation-warning.cpp diff --git a/Package.swift b/Package.swift index 1d90b47bf..01c996d24 100644 --- a/Package.swift +++ b/Package.swift @@ -2,48 +2,6 @@ import PackageDescription -var sources = [ - "src/llama.cpp", - "src/llama-vocab.cpp", - "src/llama-grammar.cpp", - "src/llama-sampling.cpp", - "src/unicode.cpp", - "src/unicode-data.cpp", - "ggml/src/ggml.c", - "ggml/src/ggml-alloc.c", - "ggml/src/ggml-backend.c", - "ggml/src/ggml-quants.c", - "ggml/src/ggml-aarch64.c", -] - -var resources: [Resource] = [] -var linkerSettings: [LinkerSetting] = [] -var cSettings: [CSetting] = [ - .unsafeFlags(["-Wno-shorten-64-to-32", "-O3", "-DNDEBUG"]), - .unsafeFlags(["-fno-objc-arc"]), - // NOTE: NEW_LAPACK will required iOS version 16.4+ - // We should consider add this in the future when we drop support for iOS 14 - // (ref: ref: https://developer.apple.com/documentation/accelerate/1513264-cblas_sgemm?language=objc) - // .define("ACCELERATE_NEW_LAPACK"), - // .define("ACCELERATE_LAPACK_ILP64") -] - -#if canImport(Darwin) -sources.append("ggml/src/ggml-metal.m") -resources.append(.process("ggml/src/ggml-metal.metal")) -linkerSettings.append(.linkedFramework("Accelerate")) -cSettings.append( - contentsOf: [ - .define("GGML_USE_ACCELERATE"), - .define("GGML_USE_METAL") - ] -) -#endif - -#if os(Linux) - cSettings.append(.define("_GNU_SOURCE")) -#endif - let package = Package( name: "llama", platforms: [ @@ -56,24 +14,6 @@ let package = Package( .library(name: "llama", targets: ["llama"]), ], targets: [ - .target( - name: "llama", - path: ".", - exclude: [ - "cmake", - "examples", - "scripts", - "models", - "tests", - "CMakeLists.txt", - "Makefile" - ], - sources: sources, - resources: resources, - publicHeadersPath: "spm-headers", - cSettings: cSettings, - linkerSettings: linkerSettings - ) - ], - cxxLanguageStandard: .cxx11 + .systemLibrary(name: "llama", pkgConfig: "llama"), + ] ) diff --git a/README.md b/README.md index e30ab0c8c..d40309875 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,6 @@ [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) [![Server](https://github.com/ggerganov/llama.cpp/actions/workflows/server.yml/badge.svg)](https://github.com/ggerganov/llama.cpp/actions/workflows/server.yml) -[![Conan Center](https://shields.io/conan/v/llama-cpp)](https://conan.io/center/llama-cpp) [Roadmap](https://github.com/users/ggerganov/projects/7) / [Project status](https://github.com/ggerganov/llama.cpp/discussions/3471) / [Manifesto](https://github.com/ggerganov/llama.cpp/discussions/205) / [ggml](https://github.com/ggerganov/ggml) @@ -17,31 +16,40 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) ## Hot topics -- *add hot topics here* +- **How to use [MTLResidencySet](https://developer.apple.com/documentation/metal/mtlresidencyset?language=objc) to keep the GPU memory active?** https://github.com/ggerganov/llama.cpp/pull/11427 +- **VS Code extension for FIM completions:** https://github.com/ggml-org/llama.vscode +- Universal tool call support in `llama-server`: https://github.com/ggerganov/llama.cpp/pull/9639 +- Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim +- Introducing GGUF-my-LoRA https://github.com/ggerganov/llama.cpp/discussions/10123 +- Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggerganov/llama.cpp/discussions/9669 +- Hugging Face GGUF editor: [discussion](https://github.com/ggerganov/llama.cpp/discussions/9268) | [tool](https://huggingface.co/spaces/CISCai/gguf-editor) ---- ## Description The main goal of `llama.cpp` is to enable LLM inference with minimal setup and state-of-the-art performance on a wide -variety of hardware - locally and in the cloud. +range of hardware - locally and in the cloud. - Plain C/C++ implementation without any dependencies - Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks -- AVX, AVX2 and AVX512 support for x86 architectures +- AVX, AVX2, AVX512 and AMX support for x86 architectures - 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use -- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP) +- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads MTT GPUs via MUSA) - Vulkan and SYCL backend support - CPU+GPU hybrid inference to partially accelerate models larger than the total VRAM capacity -Since its [inception](https://github.com/ggerganov/llama.cpp/issues/33#issuecomment-1465108022), the project has -improved significantly thanks to many contributions. It is the main playground for developing new features for the -[ggml](https://github.com/ggerganov/ggml) library. +The `llama.cpp` project is the main playground for developing new features for the [ggml](https://github.com/ggerganov/ggml) library. -**Supported models:** +
+Models Typically finetunes of the base models below are supported as well. +Instructions for adding support for new models: [HOWTO-add-model.md](docs/development/HOWTO-add-model.md) + +#### Text-only + - [X] LLaMA 🦙 - [x] LLaMA 2 🦙🦙 - [x] LLaMA 3 🦙🦙🦙 @@ -65,6 +73,7 @@ Typically finetunes of the base models below are supported as well. - [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen) - [x] [PLaMo-13B](https://github.com/ggerganov/llama.cpp/pull/3557) - [x] [Phi models](https://huggingface.co/models?search=microsoft/phi) +- [x] [PhiMoE](https://github.com/ggerganov/llama.cpp/pull/11003) - [x] [GPT-2](https://huggingface.co/gpt2) - [x] [Orion 14B](https://github.com/ggerganov/llama.cpp/pull/5118) - [x] [InternLM2](https://huggingface.co/models?search=internlm2) @@ -77,6 +86,8 @@ Typically finetunes of the base models below are supported as well. - [x] [SEA-LION](https://huggingface.co/models?search=sea-lion) - [x] [GritLM-7B](https://huggingface.co/GritLM/GritLM-7B) + [GritLM-8x7B](https://huggingface.co/GritLM/GritLM-8x7B) - [x] [OLMo](https://allenai.org/olmo) +- [x] [OLMo 2](https://allenai.org/olmo) +- [x] [OLMoE](https://huggingface.co/allenai/OLMoE-1B-7B-0924) - [x] [Granite models](https://huggingface.co/collections/ibm-granite/granite-code-models-6624c5cec322e4c148c8b330) - [x] [GPT-NeoX](https://github.com/EleutherAI/gpt-neox) + [Pythia](https://github.com/EleutherAI/pythia) - [x] [Snowflake-Arctic MoE](https://huggingface.co/collections/Snowflake/arctic-66290090abe542894a5ac520) @@ -89,10 +100,13 @@ Typically finetunes of the base models below are supported as well. - [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966) - [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct) - [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a) +- [x] [Jais](https://huggingface.co/inceptionai/jais-13b-chat) +- [x] [Bielik-11B-v2.3](https://huggingface.co/collections/speakleash/bielik-11b-v23-66ee813238d9b526a072408a) +- [x] [RWKV-6](https://github.com/BlinkDL/RWKV-LM) +- [x] [QRWKV-6](https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1) +- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct) -(instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md)) - -**Multimodal models:** +#### Multimodal - [x] [LLaVA 1.5 models](https://huggingface.co/collections/liuhaotian/llava-15-653aac15d994e992e2677a7e), [LLaVA 1.6 models](https://huggingface.co/collections/liuhaotian/llava-16-65b9e40155f60fd046a5ccf2) - [x] [BakLLaVA](https://huggingface.co/models?search=SkunkworksAI/Bakllava) @@ -103,13 +117,18 @@ Typically finetunes of the base models below are supported as well. - [x] [Mini CPM](https://huggingface.co/models?search=MiniCPM) - [x] [Moondream](https://huggingface.co/vikhyatk/moondream2) - [x] [Bunny](https://github.com/BAAI-DCAI/Bunny) +- [x] [Qwen2-VL](https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d) -**Bindings:** +
+ +
+Bindings - Python: [abetlen/llama-cpp-python](https://github.com/abetlen/llama-cpp-python) - Go: [go-skynet/go-llama.cpp](https://github.com/go-skynet/go-llama.cpp) - Node.js: [withcatai/node-llama-cpp](https://github.com/withcatai/node-llama-cpp) - JS/TS (llama.cpp server client): [lgrammel/modelfusion](https://modelfusion.dev/integration/model-provider/llamacpp) +- JS/TS (Programmable Prompt Engine CLI): [offline-ai/cli](https://github.com/offline-ai/cli) - JavaScript/Wasm (works in browser): [tangledgroup/llama-cpp-wasm](https://github.com/tangledgroup/llama-cpp-wasm) - Typescript/Wasm (nicer API, available on npm): [ngxson/wllama](https://github.com/ngxson/wllama) - Ruby: [yoshoku/llama_cpp.rb](https://github.com/yoshoku/llama_cpp.rb) @@ -117,347 +136,374 @@ Typically finetunes of the base models below are supported as well. - Rust (nicer API): [mdrokz/rust-llama.cpp](https://github.com/mdrokz/rust-llama.cpp) - Rust (more direct bindings): [utilityai/llama-cpp-rs](https://github.com/utilityai/llama-cpp-rs) - C#/.NET: [SciSharp/LLamaSharp](https://github.com/SciSharp/LLamaSharp) +- C#/VB.NET (more features - community license): [LM-Kit.NET](https://docs.lm-kit.com/lm-kit-net/index.html) - Scala 3: [donderom/llm4s](https://github.com/donderom/llm4s) - Clojure: [phronmophobic/llama.clj](https://github.com/phronmophobic/llama.clj) - React Native: [mybigday/llama.rn](https://github.com/mybigday/llama.rn) - Java: [kherud/java-llama.cpp](https://github.com/kherud/java-llama.cpp) - Zig: [deins/llama.cpp.zig](https://github.com/Deins/llama.cpp.zig) - Flutter/Dart: [netdur/llama_cpp_dart](https://github.com/netdur/llama_cpp_dart) +- Flutter: [xuegao-tzx/Fllama](https://github.com/xuegao-tzx/Fllama) - PHP (API bindings and features built on top of llama.cpp): [distantmagic/resonance](https://github.com/distantmagic/resonance) [(more info)](https://github.com/ggerganov/llama.cpp/pull/6326) - Guile Scheme: [guile_llama_cpp](https://savannah.nongnu.org/projects/guile-llama-cpp) +- Swift [srgtuszy/llama-cpp-swift](https://github.com/srgtuszy/llama-cpp-swift) +- Swift [ShenghaiWang/SwiftLlama](https://github.com/ShenghaiWang/SwiftLlama) -**UI:** +
-Unless otherwise noted these projects are open-source with permissive licensing: - -- [MindWorkAI/AI-Studio](https://github.com/MindWorkAI/AI-Studio) (FSL-1.1-MIT) -- [iohub/collama](https://github.com/iohub/coLLaMA) -- [janhq/jan](https://github.com/janhq/jan) (AGPL) -- [nat/openplayground](https://github.com/nat/openplayground) -- [Faraday](https://faraday.dev/) (proprietary) -- [LMStudio](https://lmstudio.ai/) (proprietary) -- [Layla](https://play.google.com/store/apps/details?id=com.laylalite) (proprietary) -- [ramalama](https://github.com/containers/ramalama) (MIT) -- [LocalAI](https://github.com/mudler/LocalAI) (MIT) -- [LostRuins/koboldcpp](https://github.com/LostRuins/koboldcpp) (AGPL) -- [Mozilla-Ocho/llamafile](https://github.com/Mozilla-Ocho/llamafile) -- [nomic-ai/gpt4all](https://github.com/nomic-ai/gpt4all) -- [ollama/ollama](https://github.com/ollama/ollama) -- [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) (AGPL) -- [psugihara/FreeChat](https://github.com/psugihara/FreeChat) -- [cztomsik/ava](https://github.com/cztomsik/ava) (MIT) -- [ptsochantaris/emeltal](https://github.com/ptsochantaris/emeltal) -- [pythops/tenere](https://github.com/pythops/tenere) (AGPL) -- [RAGNA Desktop](https://ragna.app/) (proprietary) -- [RecurseChat](https://recurse.chat/) (proprietary) -- [semperai/amica](https://github.com/semperai/amica) -- [withcatai/catai](https://github.com/withcatai/catai) -- [Mobile-Artificial-Intelligence/maid](https://github.com/Mobile-Artificial-Intelligence/maid) (MIT) -- [Msty](https://msty.app) (proprietary) -- [LLMFarm](https://github.com/guinmoon/LLMFarm?tab=readme-ov-file) (MIT) -- [KanTV](https://github.com/zhouwg/kantv?tab=readme-ov-file)(Apachev2.0 or later) -- [Dot](https://github.com/alexpinel/Dot) (GPL) -- [MindMac](https://mindmac.app) (proprietary) -- [KodiBot](https://github.com/firatkiral/kodibot) (GPL) -- [eva](https://github.com/ylsdamxssjxxdd/eva) (MIT) -- [AI Sublime Text plugin](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (MIT) -- [AIKit](https://github.com/sozercan/aikit) (MIT) -- [LARS - The LLM & Advanced Referencing Solution](https://github.com/abgulati/LARS) (AGPL) +
+UIs *(to have a project listed here, it should clearly state that it depends on `llama.cpp`)* -**Tools:** +- [AI Sublime Text plugin](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (MIT) +- [cztomsik/ava](https://github.com/cztomsik/ava) (MIT) +- [Dot](https://github.com/alexpinel/Dot) (GPL) +- [eva](https://github.com/ylsdamxssjxxdd/eva) (MIT) +- [iohub/collama](https://github.com/iohub/coLLaMA) (Apache-2.0) +- [janhq/jan](https://github.com/janhq/jan) (AGPL) +- [KanTV](https://github.com/zhouwg/kantv?tab=readme-ov-file) (Apache-2.0) +- [KodiBot](https://github.com/firatkiral/kodibot) (GPL) +- [llama.vim](https://github.com/ggml-org/llama.vim) (MIT) +- [LARS](https://github.com/abgulati/LARS) (AGPL) +- [Llama Assistant](https://github.com/vietanhdev/llama-assistant) (GPL) +- [LLMFarm](https://github.com/guinmoon/LLMFarm?tab=readme-ov-file) (MIT) +- [LLMUnity](https://github.com/undreamai/LLMUnity) (MIT) +- [LMStudio](https://lmstudio.ai/) (proprietary) +- [LocalAI](https://github.com/mudler/LocalAI) (MIT) +- [LostRuins/koboldcpp](https://github.com/LostRuins/koboldcpp) (AGPL) +- [MindMac](https://mindmac.app) (proprietary) +- [MindWorkAI/AI-Studio](https://github.com/MindWorkAI/AI-Studio) (FSL-1.1-MIT) +- [Mobile-Artificial-Intelligence/maid](https://github.com/Mobile-Artificial-Intelligence/maid) (MIT) +- [Mozilla-Ocho/llamafile](https://github.com/Mozilla-Ocho/llamafile) (Apache-2.0) +- [nat/openplayground](https://github.com/nat/openplayground) (MIT) +- [nomic-ai/gpt4all](https://github.com/nomic-ai/gpt4all) (MIT) +- [ollama/ollama](https://github.com/ollama/ollama) (MIT) +- [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) (AGPL) +- [PocketPal AI](https://github.com/a-ghorbani/pocketpal-ai) (MIT) +- [psugihara/FreeChat](https://github.com/psugihara/FreeChat) (MIT) +- [ptsochantaris/emeltal](https://github.com/ptsochantaris/emeltal) (MIT) +- [pythops/tenere](https://github.com/pythops/tenere) (AGPL) +- [ramalama](https://github.com/containers/ramalama) (MIT) +- [semperai/amica](https://github.com/semperai/amica) (MIT) +- [withcatai/catai](https://github.com/withcatai/catai) (MIT) + +
+ +
+Tools - [akx/ggify](https://github.com/akx/ggify) – download PyTorch models from HuggingFace Hub and convert them to GGML +- [akx/ollama-dl](https://github.com/akx/ollama-dl) – download models from the Ollama library to be used directly with llama.cpp - [crashr/gppm](https://github.com/crashr/gppm) – launch llama.cpp instances utilizing NVIDIA Tesla P40 or P100 GPUs with reduced idle power consumption - [gpustack/gguf-parser](https://github.com/gpustack/gguf-parser-go/tree/main/cmd/gguf-parser) - review/check the GGUF file and estimate the memory usage +- [Styled Lines](https://marketplace.unity.com/packages/tools/generative-ai/styled-lines-llama-cpp-model-292902) (proprietary licensed, async wrapper of inference part for game development in Unity3d with pre-built Mobile and Web platform wrappers and a model example) -**Infrastructure:** +
+ +
+Infrastructure - [Paddler](https://github.com/distantmagic/paddler) - Stateful load balancer custom-tailored for llama.cpp - [GPUStack](https://github.com/gpustack/gpustack) - Manage GPU clusters for running LLMs +- [llama_cpp_canister](https://github.com/onicai/llama_cpp_canister) - llama.cpp as a smart contract on the Internet Computer, using WebAssembly +- [llama-swap](https://github.com/mostlygeek/llama-swap) - transparent proxy that adds automatic model switching with llama-server +- [Kalavai](https://github.com/kalavai-net/kalavai-client) - Crowdsource end to end LLM deployment at any scale + +
+ +
+Games -**Games:** - [Lucy's Labyrinth](https://github.com/MorganRO8/Lucys_Labyrinth) - A simple maze game where agents controlled by an AI model will try to trick you. -## Demo - -
-Typical run using LLaMA v2 13B on M2 Ultra - -``` -$ make -j && ./llama-cli -m models/llama-13b-v2/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -I llama.cpp build info: -I UNAME_S: Darwin -I UNAME_P: arm -I UNAME_M: arm64 -I CFLAGS: -I. -O3 -std=c11 -fPIC -DNDEBUG -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes -pthread -DGGML_USE_K_QUANTS -DGGML_USE_ACCELERATE -I CXXFLAGS: -I. -I./common -O3 -std=c++11 -fPIC -DNDEBUG -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar -pthread -DGGML_USE_K_QUANTS -I LDFLAGS: -framework Accelerate -I CC: Apple clang version 14.0.3 (clang-1403.0.22.14.1) -I CXX: Apple clang version 14.0.3 (clang-1403.0.22.14.1) - -make: Nothing to be done for `default'. -main: build = 1041 (cf658ad) -main: seed = 1692823051 -llama_model_loader: loaded meta data with 16 key-value pairs and 363 tensors from models/llama-13b-v2/ggml-model-q4_0.gguf (version GGUF V1 (latest)) -llama_model_loader: - type f32: 81 tensors -llama_model_loader: - type q4_0: 281 tensors -llama_model_loader: - type q6_K: 1 tensors -llm_load_print_meta: format = GGUF V1 (latest) -llm_load_print_meta: arch = llama -llm_load_print_meta: vocab type = SPM -llm_load_print_meta: n_vocab = 32000 -llm_load_print_meta: n_merges = 0 -llm_load_print_meta: n_ctx_train = 4096 -llm_load_print_meta: n_ctx = 512 -llm_load_print_meta: n_embd = 5120 -llm_load_print_meta: n_head = 40 -llm_load_print_meta: n_head_kv = 40 -llm_load_print_meta: n_layer = 40 -llm_load_print_meta: n_rot = 128 -llm_load_print_meta: n_gqa = 1 -llm_load_print_meta: f_norm_eps = 1.0e-05 -llm_load_print_meta: f_norm_rms_eps = 1.0e-05 -llm_load_print_meta: n_ff = 13824 -llm_load_print_meta: freq_base = 10000.0 -llm_load_print_meta: freq_scale = 1 -llm_load_print_meta: model type = 13B -llm_load_print_meta: model ftype = mostly Q4_0 -llm_load_print_meta: model size = 13.02 B -llm_load_print_meta: general.name = LLaMA v2 -llm_load_print_meta: BOS token = 1 '' -llm_load_print_meta: EOS token = 2 '' -llm_load_print_meta: UNK token = 0 '' -llm_load_print_meta: LF token = 13 '<0x0A>' -llm_load_tensors: ggml ctx size = 0.11 MB -llm_load_tensors: mem required = 7024.01 MB (+ 400.00 MB per state) -................................................................................................... -llama_new_context_with_model: kv self size = 400.00 MB -llama_new_context_with_model: compute buffer total size = 75.41 MB - -system_info: n_threads = 16 / 24 | AVX = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 0 | NEON = 1 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | VSX = 0 | -sampling: repeat_last_n = 64, repeat_penalty = 1.100000, presence_penalty = 0.000000, frequency_penalty = 0.000000, top_k = 40, tfs_z = 1.000000, top_p = 0.950000, typical_p = 1.000000, temp = 0.800000, mirostat = 0, mirostat_lr = 0.100000, mirostat_ent = 5.000000 -generate: n_ctx = 512, n_batch = 512, n_predict = 400, n_keep = 0 - - - Building a website can be done in 10 simple steps: -Step 1: Find the right website platform. -Step 2: Choose your domain name and hosting plan. -Step 3: Design your website layout. -Step 4: Write your website content and add images. -Step 5: Install security features to protect your site from hackers or spammers -Step 6: Test your website on multiple browsers, mobile devices, operating systems etc… -Step 7: Test it again with people who are not related to you personally – friends or family members will work just fine! -Step 8: Start marketing and promoting the website via social media channels or paid ads -Step 9: Analyze how many visitors have come to your site so far, what type of people visit more often than others (e.g., men vs women) etc… -Step 10: Continue to improve upon all aspects mentioned above by following trends in web design and staying up-to-date on new technologies that can enhance user experience even further! -How does a Website Work? -A website works by having pages, which are made of HTML code. This code tells your computer how to display the content on each page you visit – whether it’s an image or text file (like PDFs). In order for someone else’s browser not only be able but also want those same results when accessing any given URL; some additional steps need taken by way of programming scripts that will add functionality such as making links clickable! -The most common type is called static HTML pages because they remain unchanged over time unless modified manually (either through editing files directly or using an interface such as WordPress). They are usually served up via HTTP protocols – this means anyone can access them without having any special privileges like being part of a group who is allowed into restricted areas online; however, there may still exist some limitations depending upon where one lives geographically speaking. -How to -llama_print_timings: load time = 576.45 ms -llama_print_timings: sample time = 283.10 ms / 400 runs ( 0.71 ms per token, 1412.91 tokens per second) -llama_print_timings: prompt eval time = 599.83 ms / 19 tokens ( 31.57 ms per token, 31.68 tokens per second) -llama_print_timings: eval time = 24513.59 ms / 399 runs ( 61.44 ms per token, 16.28 tokens per second) -llama_print_timings: total time = 25431.49 ms -``` -
-
-Demo of running both LLaMA-7B and whisper.cpp on a single M1 Pro MacBook - -And here is another demo of running both LLaMA-7B and [whisper.cpp](https://github.com/ggerganov/whisper.cpp) on a single M1 Pro MacBook: - -https://user-images.githubusercontent.com/1991296/224442907-7693d4be-acaa-4e01-8b4f-add84093ffff.mp4 - -
- -## Usage - -Here are the end-to-end binary build and model conversion steps for most supported models. - -### Basic usage - -Firstly, you need to get the binary. There are different methods that you can follow: -- Method 1: Clone this repository and build locally, see [how to build](./docs/build.md) -- Method 2: If you are using MacOS or Linux, you can install llama.cpp via [brew, flox or nix](./docs/install.md) -- Method 3: Use a Docker image, see [documentation for Docker](./docs/docker.md) -- Method 4: Download pre-built binary from [releases](https://github.com/ggerganov/llama.cpp/releases) - -You can run a basic completion using this command: - -```bash -llama-cli -m your_model.gguf -p "I believe the meaning of life is" -n 128 - -# Output: -# I believe the meaning of life is to find your own truth and to live in accordance with it. For me, this means being true to myself and following my passions, even if they don't align with societal expectations. I think that's what I love about yoga – it's not just a physical practice, but a spiritual one too. It's about connecting with yourself, listening to your inner voice, and honoring your own unique journey. -``` - -See [this page](./examples/main/README.md) for a full list of parameters. - -### Conversation mode - -If you want a more ChatGPT-like experience, you can run in conversation mode by passing `-cnv` as a parameter: - -```bash -llama-cli -m your_model.gguf -p "You are a helpful assistant" -cnv - -# Output: -# > hi, who are you? -# Hi there! I'm your helpful assistant! I'm an AI-powered chatbot designed to assist and provide information to users like you. I'm here to help answer your questions, provide guidance, and offer support on a wide range of topics. I'm a friendly and knowledgeable AI, and I'm always happy to help with anything you need. What's on your mind, and how can I assist you today? -# -# > what is 1+1? -# Easy peasy! The answer to 1+1 is... 2! -``` - -By default, the chat template will be taken from the input model. If you want to use another chat template, pass `--chat-template NAME` as a parameter. See the list of [supported templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) - -```bash -./llama-cli -m your_model.gguf -p "You are a helpful assistant" -cnv --chat-template chatml -``` - -You can also use your own template via in-prefix, in-suffix and reverse-prompt parameters: - -```bash -./llama-cli -m your_model.gguf -p "You are a helpful assistant" -cnv --in-prefix 'User: ' --reverse-prompt 'User:' -``` - -### Web server - -[llama.cpp web server](./examples/server/README.md) is a lightweight [OpenAI API](https://github.com/openai/openai-openapi) compatible HTTP server that can be used to serve local models and easily connect them to existing clients. - -Example usage: - -```bash -./llama-server -m your_model.gguf --port 8080 - -# Basic web UI can be accessed via browser: http://localhost:8080 -# Chat completion endpoint: http://localhost:8080/v1/chat/completions -``` - -### Interactive mode - -> [!NOTE] -> If you prefer basic usage, please consider using conversation mode instead of interactive mode - -In this mode, you can always interrupt generation by pressing Ctrl+C and entering one or more lines of text, which will be converted into tokens and appended to the current context. You can also specify a *reverse prompt* with the parameter `-r "reverse prompt string"`. This will result in user input being prompted whenever the exact tokens of the reverse prompt string are encountered in the generation. A typical use is to use a prompt that makes LLaMA emulate a chat between multiple users, say Alice and Bob, and pass `-r "Alice:"`. - -Here is an example of a few-shot interaction, invoked with the command - -```bash -# default arguments using a 7B model -./examples/chat.sh - -# advanced chat with a 13B model -./examples/chat-13B.sh - -# custom arguments using a 13B model -./llama-cli -m ./models/13B/ggml-model-q4_0.gguf -n 256 --repeat_penalty 1.0 --color -i -r "User:" -f prompts/chat-with-bob.txt -``` - -Note the use of `--color` to distinguish between user input and generated text. Other parameters are explained in more detail in the [README](examples/main/README.md) for the `llama-cli` example program. - -![image](https://user-images.githubusercontent.com/1991296/224575029-2af3c7dc-5a65-4f64-a6bb-517a532aea38.png) - -### Persistent Interaction - -The prompt, user inputs, and model generations can be saved and resumed across calls to `./llama-cli` by leveraging `--prompt-cache` and `--prompt-cache-all`. The `./examples/chat-persistent.sh` script demonstrates this with support for long-running, resumable chat sessions. To use this example, you must provide a file to cache the initial chat prompt and a directory to save the chat session, and may optionally provide the same variables as `chat-13B.sh`. The same prompt cache can be reused for new chat sessions. Note that both prompt cache and chat directory are tied to the initial prompt (`PROMPT_TEMPLATE`) and the model file. - -```bash -# Start a new chat -PROMPT_CACHE_FILE=chat.prompt.bin CHAT_SAVE_DIR=./chat/default ./examples/chat-persistent.sh - -# Resume that chat -PROMPT_CACHE_FILE=chat.prompt.bin CHAT_SAVE_DIR=./chat/default ./examples/chat-persistent.sh - -# Start a different chat with the same prompt/model -PROMPT_CACHE_FILE=chat.prompt.bin CHAT_SAVE_DIR=./chat/another ./examples/chat-persistent.sh - -# Different prompt cache for different prompt/model -PROMPT_TEMPLATE=./prompts/chat-with-bob.txt PROMPT_CACHE_FILE=bob.prompt.bin \ - CHAT_SAVE_DIR=./chat/bob ./examples/chat-persistent.sh -``` - -### Constrained output with grammars - -`llama.cpp` supports grammars to constrain model output. For example, you can force the model to output JSON only: - -```bash -./llama-cli -m ./models/13B/ggml-model-q4_0.gguf -n 256 --grammar-file grammars/json.gbnf -p 'Request: schedule a call at 8pm; Command:' -``` - -The `grammars/` folder contains a handful of sample grammars. To write your own, check out the [GBNF Guide](./grammars/README.md). - -For authoring more complex JSON grammars, you can also check out https://grammar.intrinsiclabs.ai/, a browser app that lets you write TypeScript interfaces which it compiles to GBNF grammars that you can save for local use. Note that the app is built and maintained by members of the community, please file any issues or FRs on [its repo](http://github.com/intrinsiclabsai/gbnfgen) and not this one. - -## Build - -Please refer to [Build llama.cpp locally](./docs/build.md) - ## Supported backends | Backend | Target devices | | --- | --- | -| [Metal](./docs/build.md#metal-build) | Apple Silicon | -| [BLAS](./docs/build.md#blas-build) | All | -| [BLIS](./docs/backend/BLIS.md) | All | -| [SYCL](./docs/backend/SYCL.md) | Intel and Nvidia GPU | -| [MUSA](./docs/build.md#musa) | Moore Threads GPU | -| [CUDA](./docs/build.md#cuda) | Nvidia GPU | -| [hipBLAS](./docs/build.md#hipblas) | AMD GPU | -| [Vulkan](./docs/build.md#vulkan) | GPU | -| [CANN](./docs/build.md#cann) | Ascend NPU | +| [Metal](docs/build.md#metal-build) | Apple Silicon | +| [BLAS](docs/build.md#blas-build) | All | +| [BLIS](docs/backend/BLIS.md) | All | +| [SYCL](docs/backend/SYCL.md) | Intel and Nvidia GPU | +| [MUSA](docs/build.md#musa) | Moore Threads MTT GPU | +| [CUDA](docs/build.md#cuda) | Nvidia GPU | +| [HIP](docs/build.md#hip) | AMD GPU | +| [Vulkan](docs/build.md#vulkan) | GPU | +| [CANN](docs/build.md#cann) | Ascend NPU | -## Tools +## Building the project -### Prepare and Quantize +The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](include/llama.h). +The project also includes many example programs and tools using the `llama` library. The examples range from simple, minimal code snippets to sophisticated sub-projects such as an OpenAI-compatible HTTP server. Possible methods for obtaining the binaries: -> [!NOTE] -> You can use the [GGUF-my-repo](https://huggingface.co/spaces/ggml-org/gguf-my-repo) space on Hugging Face to quantise your model weights without any setup too. It is synced from `llama.cpp` main every 6 hours. +- Clone this repository and build locally, see [how to build](docs/build.md) +- On MacOS or Linux, install `llama.cpp` via [brew, flox or nix](docs/install.md) +- Use a Docker image, see [documentation for Docker](docs/docker.md) +- Download pre-built binaries from [releases](https://github.com/ggerganov/llama.cpp/releases) -To obtain the official LLaMA 2 weights please see the Obtaining and using the Facebook LLaMA 2 model section. There is also a large selection of pre-quantized `gguf` models available on Hugging Face. +## Obtaining and quantizing models -Note: `convert.py` has been moved to `examples/convert_legacy_llama.py` and shouldn't be used for anything other than `Llama/Llama2/Mistral` models and their derivatives. -It does not support LLaMA 3, you can use `convert_hf_to_gguf.py` with LLaMA 3 downloaded from Hugging Face. +The [Hugging Face](https://huggingface.co) platform hosts a [number of LLMs](https://huggingface.co/models?library=gguf&sort=trending) compatible with `llama.cpp`: -To learn more about quantizing model, [read this documentation](./examples/quantize/README.md) +- [Trending](https://huggingface.co/models?library=gguf&sort=trending) +- [LLaMA](https://huggingface.co/models?sort=trending&search=llama+gguf) -### Perplexity (measuring model quality) +You can either manually download the GGUF file or directly use any `llama.cpp`-compatible models from Hugging Face by using this CLI argument: `-hf /[:quant]` -You can use the `perplexity` example to measure perplexity over a given prompt (lower perplexity is better). -For more information, see [https://huggingface.co/docs/transformers/perplexity](https://huggingface.co/docs/transformers/perplexity). +After downloading a model, use the CLI tools to run it locally - see below. + +`llama.cpp` requires the model to be stored in the [GGUF](https://github.com/ggerganov/ggml/blob/master/docs/gguf.md) file format. Models in other data formats can be converted to GGUF using the `convert_*.py` Python scripts in this repo. + +The Hugging Face platform provides a variety of online tools for converting, quantizing and hosting models with `llama.cpp`: + +- Use the [GGUF-my-repo space](https://huggingface.co/spaces/ggml-org/gguf-my-repo) to convert to GGUF format and quantize model weights to smaller sizes +- Use the [GGUF-my-LoRA space](https://huggingface.co/spaces/ggml-org/gguf-my-lora) to convert LoRA adapters to GGUF format (more info: https://github.com/ggerganov/llama.cpp/discussions/10123) +- Use the [GGUF-editor space](https://huggingface.co/spaces/CISCai/gguf-editor) to edit GGUF meta data in the browser (more info: https://github.com/ggerganov/llama.cpp/discussions/9268) +- Use the [Inference Endpoints](https://ui.endpoints.huggingface.co/) to directly host `llama.cpp` in the cloud (more info: https://github.com/ggerganov/llama.cpp/discussions/9669) + +To learn more about model quantization, [read this documentation](examples/quantize/README.md) + +## [`llama-cli`](examples/main) + +#### A CLI tool for accessing and experimenting with most of `llama.cpp`'s functionality. + +-
+ Run in conversation mode + + Models with a built-in chat template will automatically activate conversation mode. If this doesn't occur, you can manually enable it by adding `-cnv` and specifying a suitable chat template with `--chat-template NAME` + + ```bash + llama-cli -m model.gguf + + # > hi, who are you? + # Hi there! I'm your helpful assistant! I'm an AI-powered chatbot designed to assist and provide information to users like you. I'm here to help answer your questions, provide guidance, and offer support on a wide range of topics. I'm a friendly and knowledgeable AI, and I'm always happy to help with anything you need. What's on your mind, and how can I assist you today? + # + # > what is 1+1? + # Easy peasy! The answer to 1+1 is... 2! + ``` + +
+ +-
+ Run in conversation mode with custom chat template + + ```bash + # use the "chatml" template (use -h to see the list of supported templates) + llama-cli -m model.gguf -cnv --chat-template chatml + + # use a custom template + llama-cli -m model.gguf -cnv --in-prefix 'User: ' --reverse-prompt 'User:' + ``` + +
+ +-
+ Run simple text completion + + To disable conversation mode explicitly, use `-no-cnv` + + ```bash + llama-cli -m model.gguf -p "I believe the meaning of life is" -n 128 -no-cnv + + # I believe the meaning of life is to find your own truth and to live in accordance with it. For me, this means being true to myself and following my passions, even if they don't align with societal expectations. I think that's what I love about yoga – it's not just a physical practice, but a spiritual one too. It's about connecting with yourself, listening to your inner voice, and honoring your own unique journey. + ``` + +
+ +-
+ Constrain the output with a custom grammar + + ```bash + llama-cli -m model.gguf -n 256 --grammar-file grammars/json.gbnf -p 'Request: schedule a call at 8pm; Command:' + + # {"appointmentTime": "8pm", "appointmentDetails": "schedule a a call"} + ``` + + The [grammars/](grammars/) folder contains a handful of sample grammars. To write your own, check out the [GBNF Guide](grammars/README.md). + + For authoring more complex JSON grammars, check out https://grammar.intrinsiclabs.ai/ + +
+ + +## [`llama-server`](examples/server) + +#### A lightweight, [OpenAI API](https://github.com/openai/openai-openapi) compatible, HTTP server for serving LLMs. + +-
+ Start a local HTTP server with default configuration on port 8080 + + ```bash + llama-server -m model.gguf --port 8080 + + # Basic web UI can be accessed via browser: http://localhost:8080 + # Chat completion endpoint: http://localhost:8080/v1/chat/completions + ``` + +
+ +-
+ Support multiple-users and parallel decoding + + ```bash + # up to 4 concurrent requests, each with 4096 max context + llama-server -m model.gguf -c 16384 -np 4 + ``` + +
+ +-
+ Enable speculative decoding + + ```bash + # the draft.gguf model should be a small variant of the target model.gguf + llama-server -m model.gguf -md draft.gguf + ``` + +
+ +-
+ Serve an embedding model + + ```bash + # use the /embedding endpoint + llama-server -m model.gguf --embedding --pooling cls -ub 8192 + ``` + +
+ +-
+ Serve a reranking model + + ```bash + # use the /reranking endpoint + llama-server -m model.gguf --reranking + ``` + +
+ +-
+ Constrain all outputs with a grammar + + ```bash + # custom grammar + llama-server -m model.gguf --grammar-file grammar.gbnf + + # JSON + llama-server -m model.gguf --grammar-file grammars/json.gbnf + ``` + +
+ + +## [`llama-perplexity`](examples/perplexity) + +#### A tool for measuring the perplexity [^1][^2] (and other quality metrics) of a model over a given text. + +-
+ Measure the perplexity over a text file + + ```bash + llama-perplexity -m model.gguf -f file.txt + + # [1]15.2701,[2]5.4007,[3]5.3073,[4]6.2965,[5]5.8940,[6]5.6096,[7]5.7942,[8]4.9297, ... + # Final estimate: PPL = 5.4007 +/- 0.67339 + ``` + +
+ +-
+ Measure KL divergence + + ```bash + # TODO + ``` + +
+ +[^1]: [examples/perplexity/README.md](./examples/perplexity/README.md) +[^2]: [https://huggingface.co/docs/transformers/perplexity](https://huggingface.co/docs/transformers/perplexity) + +## [`llama-bench`](examples/llama-bench) + +#### Benchmark the performance of the inference for various parameters. + +-
+ Run default benchmark + + ```bash + llama-bench -m model.gguf + + # Output: + # | model | size | params | backend | threads | test | t/s | + # | ------------------- | ---------: | ---------: | ---------- | ------: | ------------: | -------------------: | + # | qwen2 1.5B Q4_0 | 885.97 MiB | 1.54 B | Metal,BLAS | 16 | pp512 | 5765.41 ± 20.55 | + # | qwen2 1.5B Q4_0 | 885.97 MiB | 1.54 B | Metal,BLAS | 16 | tg128 | 197.71 ± 0.81 | + # + # build: 3e0ba0e60 (4229) + ``` + +
+ +## [`llama-run`](examples/run) + +#### A comprehensive example for running `llama.cpp` models. Useful for inferencing. Used with RamaLama [^3]. + +-
+ Run a model with a specific prompt (by default it's pulled from Ollama registry) + + ```bash + llama-run granite-code + ``` + +
+ +[^3]: [RamaLama](https://github.com/containers/ramalama) + +## [`llama-simple`](examples/simple) + +#### A minimal example for implementing apps with `llama.cpp`. Useful for developers. + +-
+ Basic text completion + + ```bash + llama-simple -m model.gguf + + # Hello my name is Kaitlyn and I am a 16 year old girl. I am a junior in high school and I am currently taking a class called "The Art of + ``` + +
-To learn more how to measure perplexity using llama.cpp, [read this documentation](./examples/perplexity/README.md) ## Contributing - Contributors can open PRs - Collaborators can push to branches in the `llama.cpp` repo and merge PRs into the `master` branch - Collaborators will be invited based on contributions -- Any help with managing issues and PRs is very appreciated! +- Any help with managing issues, PRs and projects is very appreciated! - See [good first issues](https://github.com/ggerganov/llama.cpp/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) for tasks suitable for first contributions - Read the [CONTRIBUTING.md](CONTRIBUTING.md) for more information - Make sure to read this: [Inference at the edge](https://github.com/ggerganov/llama.cpp/discussions/205) - A bit of backstory for those who are interested: [Changelog podcast](https://changelog.com/podcast/532) -## Other documentations +## Other documentation -- [main (cli)](./examples/main/README.md) -- [server](./examples/server/README.md) -- [jeopardy](./examples/jeopardy/README.md) -- [GBNF grammars](./grammars/README.md) +- [main (cli)](examples/main/README.md) +- [server](examples/server/README.md) +- [GBNF grammars](grammars/README.md) -**Development documentations** +#### Development documentation -- [How to build](./docs/build.md) -- [Running on Docker](./docs/docker.md) -- [Build on Android](./docs/android.md) -- [Performance troubleshooting](./docs/development/token_generation_performance_tips.md) +- [How to build](docs/build.md) +- [Running on Docker](docs/docker.md) +- [Build on Android](docs/android.md) +- [Performance troubleshooting](docs/development/token_generation_performance_tips.md) - [GGML tips & tricks](https://github.com/ggerganov/llama.cpp/wiki/GGML-Tips-&-Tricks) -**Seminal papers and background on the models** +#### Seminal papers and background on the models If your issue is with model generation quality, then please at least scan the following links and papers to understand the limitations of LLaMA models. This is especially important when choosing an appropriate model size and appreciating both the significant and subtle differences between LLaMA models and ChatGPT: - LLaMA: @@ -468,3 +514,6 @@ If your issue is with model generation quality, then please at least scan the fo - GPT-3.5 / InstructGPT / ChatGPT: - [Aligning language models to follow instructions](https://openai.com/research/instruction-following) - [Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155) + +#### References + diff --git a/Sources/llama/llama.h b/Sources/llama/llama.h new file mode 100644 index 000000000..41725880e --- /dev/null +++ b/Sources/llama/llama.h @@ -0,0 +1,4 @@ +#pragma once + +#include + diff --git a/Sources/llama/module.modulemap b/Sources/llama/module.modulemap new file mode 100644 index 000000000..d010555b1 --- /dev/null +++ b/Sources/llama/module.modulemap @@ -0,0 +1,5 @@ +module llama [system] { + header "llama.h" + link "llama" + export * +} diff --git a/ci/run.sh b/ci/run.sh index 751bb0a02..77c32ce00 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -1,4 +1,4 @@ -#/bin/bash +#!/bin/bash # # sample usage: # @@ -39,7 +39,7 @@ SRC=`pwd` CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON" if [ ! -z ${GG_BUILD_METAL} ]; then - CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON" + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON -DGGML_METAL_USE_BF16=ON" fi if [ ! -z ${GG_BUILD_CUDA} ]; then @@ -53,7 +53,7 @@ if [ ! -z ${GG_BUILD_SYCL} ]; then exit 1 fi - CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_SYCL=1 DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON" + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_SYCL=1 -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON" fi if [ ! -z ${GG_BUILD_VULKAN} ]; then @@ -326,36 +326,36 @@ function gg_run_open_llama_7b_v2 { ./bin/llama-quantize ${model_f16} ${model_q5_k} q5_k ./bin/llama-quantize ${model_f16} ${model_q6_k} q6_k - (time ./bin/llama-cli --model ${model_f16} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log - (time ./bin/llama-cli --model ${model_q8_0} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log - (time ./bin/llama-cli --model ${model_q4_0} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log - (time ./bin/llama-cli --model ${model_q4_1} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log - (time ./bin/llama-cli --model ${model_q5_0} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log - (time ./bin/llama-cli --model ${model_q5_1} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log - (time ./bin/llama-cli --model ${model_q2_k} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log - (time ./bin/llama-cli --model ${model_q3_k} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log - (time ./bin/llama-cli --model ${model_q4_k} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log - (time ./bin/llama-cli --model ${model_q5_k} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log - (time ./bin/llama-cli --model ${model_q6_k} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log + (time ./bin/llama-cli -no-cnv --model ${model_f16} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-cli -no-cnv --model ${model_q8_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log + (time ./bin/llama-cli -no-cnv --model ${model_q2_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q3_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q6_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log - (time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log - (time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log - (time ./bin/llama-perplexity --model ${model_q4_0} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log - (time ./bin/llama-perplexity --model ${model_q4_1} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log - (time ./bin/llama-perplexity --model ${model_q5_0} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log - (time ./bin/llama-perplexity --model ${model_q5_1} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log - (time ./bin/llama-perplexity --model ${model_q2_k} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log - (time ./bin/llama-perplexity --model ${model_q3_k} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log - (time ./bin/llama-perplexity --model ${model_q4_k} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log - (time ./bin/llama-perplexity --model ${model_q5_k} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log - (time ./bin/llama-perplexity --model ${model_q6_k} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log + (time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + (time ./bin/llama-perplexity --model ${model_q4_0} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log + (time ./bin/llama-perplexity --model ${model_q4_1} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log + (time ./bin/llama-perplexity --model ${model_q5_0} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log + (time ./bin/llama-perplexity --model ${model_q5_1} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log + (time ./bin/llama-perplexity --model ${model_q2_k} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log + (time ./bin/llama-perplexity --model ${model_q3_k} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log + (time ./bin/llama-perplexity --model ${model_q4_k} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log + (time ./bin/llama-perplexity --model ${model_q5_k} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log + (time ./bin/llama-perplexity --model ${model_q6_k} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log - (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log + (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log - (time ./bin/llama-save-load-state -ngl 10 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state -fa -ngl 10 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state -ngl 99 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state -fa -ngl 99 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state--model ${model_q4_0} -ngl 10 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state--model ${model_q4_0} -ngl 10 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state--model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state--model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log function check_ppl { qnt="$1" @@ -460,34 +460,34 @@ function gg_run_pythia_1_4b { ./bin/llama-quantize ${model_f16} ${model_q5_k} q5_k ./bin/llama-quantize ${model_f16} ${model_q6_k} q6_k - (time ./bin/llama-cli --model ${model_f16} -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log - (time ./bin/llama-cli --model ${model_q8_0} -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log - (time ./bin/llama-cli --model ${model_q4_0} -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log - (time ./bin/llama-cli --model ${model_q4_1} -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log - (time ./bin/llama-cli --model ${model_q5_0} -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log - (time ./bin/llama-cli --model ${model_q5_1} -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log - (time ./bin/llama-cli --model ${model_q2_k} -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log - (time ./bin/llama-cli --model ${model_q3_k} -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log - (time ./bin/llama-cli --model ${model_q4_k} -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log - (time ./bin/llama-cli --model ${model_q5_k} -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log - (time ./bin/llama-cli --model ${model_q6_k} -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log + (time ./bin/llama-cli -no-cnv --model ${model_f16} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-cli -no-cnv --model ${model_q8_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_1} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_0} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_1} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log + (time ./bin/llama-cli -no-cnv --model ${model_q2_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q3_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q6_k} -ngl 99 -c 0 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log - (time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test_60} -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log - (time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test_60} -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log - (time ./bin/llama-perplexity --model ${model_q4_0} -f ${wiki_test_60} -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log - (time ./bin/llama-perplexity --model ${model_q4_1} -f ${wiki_test_60} -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log - (time ./bin/llama-perplexity --model ${model_q5_0} -f ${wiki_test_60} -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log - (time ./bin/llama-perplexity --model ${model_q5_1} -f ${wiki_test_60} -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log - (time ./bin/llama-perplexity --model ${model_q2_k} -f ${wiki_test_60} -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log - (time ./bin/llama-perplexity --model ${model_q3_k} -f ${wiki_test_60} -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log - (time ./bin/llama-perplexity --model ${model_q4_k} -f ${wiki_test_60} -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log - (time ./bin/llama-perplexity --model ${model_q5_k} -f ${wiki_test_60} -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log - (time ./bin/llama-perplexity --model ${model_q6_k} -f ${wiki_test_60} -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log + (time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + (time ./bin/llama-perplexity --model ${model_q4_0} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log + (time ./bin/llama-perplexity --model ${model_q4_1} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log + (time ./bin/llama-perplexity --model ${model_q5_0} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log + (time ./bin/llama-perplexity --model ${model_q5_1} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log + (time ./bin/llama-perplexity --model ${model_q2_k} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log + (time ./bin/llama-perplexity --model ${model_q3_k} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log + (time ./bin/llama-perplexity --model ${model_q4_k} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log + (time ./bin/llama-perplexity --model ${model_q5_k} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log + (time ./bin/llama-perplexity --model ${model_q6_k} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log - (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test_60} -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log + (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log - (time ./bin/llama-save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state -fa --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log function check_ppl { qnt="$1" @@ -591,36 +591,36 @@ function gg_run_pythia_2_8b { ./bin/llama-quantize ${model_f16} ${model_q5_k} q5_k ./bin/llama-quantize ${model_f16} ${model_q6_k} q6_k - (time ./bin/llama-cli --model ${model_f16} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log - (time ./bin/llama-cli --model ${model_q8_0} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log - (time ./bin/llama-cli --model ${model_q4_0} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log - (time ./bin/llama-cli --model ${model_q4_1} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log - (time ./bin/llama-cli --model ${model_q5_0} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log - (time ./bin/llama-cli --model ${model_q5_1} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log - (time ./bin/llama-cli --model ${model_q2_k} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log - (time ./bin/llama-cli --model ${model_q3_k} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log - (time ./bin/llama-cli --model ${model_q4_k} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log - (time ./bin/llama-cli --model ${model_q5_k} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log - (time ./bin/llama-cli --model ${model_q6_k} -t 1 -ngl 999 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log + (time ./bin/llama-cli -no-cnv --model ${model_f16} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-cli -no-cnv --model ${model_q8_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_0} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_1} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log + (time ./bin/llama-cli -no-cnv --model ${model_q2_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q3_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q4_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q5_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log + (time ./bin/llama-cli -no-cnv --model ${model_q6_k} -t 1 -ngl 99 -c 0 -s 1234 -n 256 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log - (time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log - (time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log - (time ./bin/llama-perplexity --model ${model_q4_0} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log - (time ./bin/llama-perplexity --model ${model_q4_1} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log - (time ./bin/llama-perplexity --model ${model_q5_0} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log - (time ./bin/llama-perplexity --model ${model_q5_1} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log - (time ./bin/llama-perplexity --model ${model_q2_k} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log - (time ./bin/llama-perplexity --model ${model_q3_k} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log - (time ./bin/llama-perplexity --model ${model_q4_k} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log - (time ./bin/llama-perplexity --model ${model_q5_k} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log - (time ./bin/llama-perplexity --model ${model_q6_k} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log + (time ./bin/llama-perplexity --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-perplexity --model ${model_q8_0} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + (time ./bin/llama-perplexity --model ${model_q4_0} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_0.log + (time ./bin/llama-perplexity --model ${model_q4_1} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_1.log + (time ./bin/llama-perplexity --model ${model_q5_0} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_0.log + (time ./bin/llama-perplexity --model ${model_q5_1} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_1.log + (time ./bin/llama-perplexity --model ${model_q2_k} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q2_k.log + (time ./bin/llama-perplexity --model ${model_q3_k} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q3_k.log + (time ./bin/llama-perplexity --model ${model_q4_k} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q4_k.log + (time ./bin/llama-perplexity --model ${model_q5_k} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log + (time ./bin/llama-perplexity --model ${model_q6_k} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log - (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log + (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log - (time ./bin/llama-save-load-state -ngl 10 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state -fa -ngl 10 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state -ngl 99 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state -fa -ngl 99 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log function check_ppl { qnt="$1" @@ -706,12 +706,88 @@ function gg_run_embd_bge_small { ./bin/llama-quantize ${model_f16} ${model_q8_0} q8_0 - (time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log - (time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + (time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log set +e } +function gg_sum_embd_bge_small { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'BGE Small (BERT):\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)" + gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)" +} + +# rerank_tiny + +function gg_run_rerank_tiny { + cd ${SRC} + + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/tokenizer_config.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/special_tokens_map.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/resolve/main/pytorch_model.bin + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/sentence_bert_config.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/vocab.txt + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/modules.json + gg_wget models-mnt/rerank-tiny/ https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/config.json + + gg_wget models-mnt/rerank-tiny/1_Pooling https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/raw/main/1_Pooling/config.json + + path_models="../models-mnt/rerank-tiny" + + rm -rf build-ci-release && mkdir build-ci-release && cd build-ci-release + + set -e + + (time cmake -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log + (time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log + + python3 ../convert_hf_to_gguf.py ${path_models} --outfile ${path_models}/ggml-model-f16.gguf + + model_f16="${path_models}/ggml-model-f16.gguf" + + # for this model, the SEP token is "" + (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?hi\nwhat is panda?it's a bear\nwhat is panda?The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log + + # sample output + # rerank score 0: 0.029 + # rerank score 1: 0.029 + # rerank score 2: 0.135 + + # check that the score is in the range [$3, $4] + function check_score { + qnt="$1" + score=$(echo "$2" | grep -oE "[0-9]+\.[0-9]+" | tail -n 1) + + if [ $(echo "$score < $3" | bc) -eq 1 ] || [ $(echo "$score > $4" | bc) -eq 1 ]; then + printf ' - %s @ %s (FAIL: score not in range [%s, %s])\n' "$qnt" "$score" "$3" "$4" + return 20 + fi + + printf ' - %s @ %s OK\n' "$qnt" "$score" + return 0 + } + + check_score "rerank score 0" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 0")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log + check_score "rerank score 1" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 1")" "0.00" "0.05" | tee -a $OUT/${ci}-rk-f16.log + check_score "rerank score 2" "$(cat $OUT/${ci}-rk-f16.log | grep "rerank score 2")" "0.10" "0.30" | tee -a $OUT/${ci}-rk-f16.log + + set +e +} + +function gg_sum_rerank_tiny { + gg_printf '### %s\n\n' "${ci}" + + gg_printf 'Rerank Tiny (Jina):\n' + gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" + gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-rk-f16.log)" +} + function gg_check_build_requirements { if ! command -v cmake &> /dev/null; then gg_printf 'cmake not found, please install' @@ -726,17 +802,11 @@ function gg_check_build_requirements { fi } -function gg_sum_embd_bge_small { - gg_printf '### %s\n\n' "${ci}" - - gg_printf 'BGE Small (BERT):\n' - gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)" - gg_printf '- f16: \n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-f16.log)" - gg_printf '- q8_0:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q8_0.log)" -} - ## main +export LLAMA_LOG_PREFIX=1 +export LLAMA_LOG_TIMESTAMPS=1 + if [ -z ${GG_BUILD_LOW_PERF} ]; then # Create symlink: ./llama.cpp/models-mnt -> $MNT/models/models-mnt rm -rf ${SRC}/models-mnt @@ -745,7 +815,10 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then ln -sfn ${mnt_models} ${SRC}/models-mnt # Create a fresh python3 venv and enter it - python3 -m venv "$MNT/venv" + if ! python3 -m venv "$MNT/venv"; then + echo "Error: Failed to create Python virtual environment at $MNT/venv." + exit 1 + fi source "$MNT/venv/bin/activate" pip install -r ${SRC}/requirements.txt --disable-pip-version-check @@ -759,6 +832,7 @@ test $ret -eq 0 && gg_run ctest_release if [ -z ${GG_BUILD_LOW_PERF} ]; then test $ret -eq 0 && gg_run embd_bge_small + test $ret -eq 0 && gg_run rerank_tiny if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then test $ret -eq 0 && gg_run test_scripts_debug diff --git a/cmake/arm64-apple-clang.cmake b/cmake/arm64-apple-clang.cmake new file mode 100644 index 000000000..5fcd2882a --- /dev/null +++ b/cmake/arm64-apple-clang.cmake @@ -0,0 +1,16 @@ +set( CMAKE_SYSTEM_NAME Darwin ) +set( CMAKE_SYSTEM_PROCESSOR arm64 ) + +set( target arm64-apple-darwin-macho ) + +set( CMAKE_C_COMPILER clang ) +set( CMAKE_CXX_COMPILER clang++ ) + +set( CMAKE_C_COMPILER_TARGET ${target} ) +set( CMAKE_CXX_COMPILER_TARGET ${target} ) + +set( arch_c_flags "-march=armv8.4-a -fvectorize -ffp-model=fast -fno-finite-math-only" ) +set( warn_c_flags "-Wno-format -Wno-unused-variable -Wno-unused-function" ) + +set( CMAKE_C_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}" ) +set( CMAKE_CXX_FLAGS_INIT "${arch_c_flags} ${warn_c_flags}" ) diff --git a/cmake/build-info.cmake b/cmake/build-info.cmake index ea3dc55c8..c1a456e17 100644 --- a/cmake/build-info.cmake +++ b/cmake/build-info.cmake @@ -44,7 +44,7 @@ if(MSVC) set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME}) else() execute_process( - COMMAND sh -c "$@ --version | head -1" _ ${CMAKE_C_COMPILER} + COMMAND sh -c "\"$@\" --version | head -1" _ ${CMAKE_C_COMPILER} OUTPUT_VARIABLE OUT OUTPUT_STRIP_TRAILING_WHITESPACE ) diff --git a/cmake/common.cmake b/cmake/common.cmake new file mode 100644 index 000000000..0f54871e4 --- /dev/null +++ b/cmake/common.cmake @@ -0,0 +1,33 @@ +function(llama_add_compile_flags) + if (LLAMA_FATAL_WARNINGS) + if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") + list(APPEND C_FLAGS -Werror) + list(APPEND CXX_FLAGS -Werror) + elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + add_compile_options(/WX) + endif() + endif() + + if (LLAMA_ALL_WARNINGS) + if (NOT MSVC) + list(APPEND C_FLAGS -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes + -Werror=implicit-int -Werror=implicit-function-declaration) + + list(APPEND CXX_FLAGS -Wmissing-declarations -Wmissing-noreturn) + + list(APPEND WARNING_FLAGS -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function) + + list(APPEND C_FLAGS ${WARNING_FLAGS}) + list(APPEND CXX_FLAGS ${WARNING_FLAGS}) + + ggml_get_flags(${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}) + + add_compile_options("$<$:${C_FLAGS};${GF_C_FLAGS}>" + "$<$:${CXX_FLAGS};${GF_CXX_FLAGS}>") + else() + # todo : msvc + set(C_FLAGS "" PARENT_SCOPE) + set(CXX_FLAGS "" PARENT_SCOPE) + endif() + endif() +endfunction() diff --git a/cmake/llama-config.cmake.in b/cmake/llama-config.cmake.in index f072b76a3..90cbec5b6 100644 --- a/cmake/llama-config.cmake.in +++ b/cmake/llama-config.cmake.in @@ -3,88 +3,28 @@ set(LLAMA_BUILD_COMMIT @LLAMA_BUILD_COMMIT@) set(LLAMA_BUILD_NUMBER @LLAMA_BUILD_NUMBER@) set(LLAMA_SHARED_LIB @BUILD_SHARED_LIBS@) -set(GGML_BLAS @GGML_BLAS@) -set(GGML_CUDA @GGML_CUDA@) -set(GGML_METAL @GGML_METAL@) -set(GGML_HIPBLAS @GGML_HIPBLAS@) -set(GGML_ACCELERATE @GGML_ACCELERATE@) -set(GGML_VULKAN @GGML_VULKAN@) -set(GGML_VULKAN_CHECK_RESULTS @GGML_VULKAN_CHECK_RESULTS@) -set(GGML_VULKAN_DEBUG @GGML_VULKAN_DEBUG@) -set(GGML_VULKAN_MEMORY_DEBUG @GGML_VULKAN_MEMORY_DEBUG@) -set(GGML_VULKAN_VALIDATE @GGML_VULKAN_VALIDATE@) -set(GGML_SYCL @GGML_SYCL@) -set(GGML_OPENMP @GGML_OPENMP@) - @PACKAGE_INIT@ set_and_check(LLAMA_INCLUDE_DIR "@PACKAGE_LLAMA_INCLUDE_INSTALL_DIR@") set_and_check(LLAMA_LIB_DIR "@PACKAGE_LLAMA_LIB_INSTALL_DIR@") set_and_check(LLAMA_BIN_DIR "@PACKAGE_LLAMA_BIN_INSTALL_DIR@") -# Ensure transient dependencies satisfied - -find_package(Threads REQUIRED) - -if (APPLE AND GGML_ACCELERATE) - find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED) -endif() - -if (GGML_BLAS) - find_package(BLAS REQUIRED) -endif() - -if (GGML_CUDA) - find_package(CUDAToolkit REQUIRED) -endif() - -if (GGML_METAL) - find_library(FOUNDATION_LIBRARY Foundation REQUIRED) - find_library(METAL_FRAMEWORK Metal REQUIRED) - find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) -endif() - -if (GGML_VULKAN) - find_package(Vulkan REQUIRED) -endif() - -if (GGML_HIPBLAS) - find_package(hip REQUIRED) - find_package(hipblas REQUIRED) - find_package(rocblas REQUIRED) -endif() - -if (GGML_SYCL) - find_package(IntelSYCL REQUIRED) - find_package(MKL REQUIRED) -endif() - -if (GGML_OPENMP) - find_package(OpenMP REQUIRED) -endif() - - -find_library(ggml_LIBRARY ggml - REQUIRED - HINTS ${LLAMA_LIB_DIR}) +find_package(ggml REQUIRED HINTS ${LLAMA_LIB_DIR}/cmake) find_library(llama_LIBRARY llama REQUIRED - HINTS ${LLAMA_LIB_DIR}) - -set(_llama_link_deps "${ggml_LIBRARY}" "@GGML_LINK_LIBRARIES@") -set(_llama_transient_defines "@GGML_TRANSIENT_DEFINES@") + HINTS ${LLAMA_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH +) add_library(llama UNKNOWN IMPORTED) - set_target_properties(llama PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${LLAMA_INCLUDE_DIR}" - INTERFACE_LINK_LIBRARIES "${_llama_link_deps}" - INTERFACE_COMPILE_DEFINITIONS "${_llama_transient_defines}" + INTERFACE_LINK_LIBRARIES "ggml::ggml;ggml::ggml-base;" IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" IMPORTED_LOCATION "${llama_LIBRARY}" - INTERFACE_COMPILE_FEATURES cxx_std_11 - POSITION_INDEPENDENT_CODE ON ) + INTERFACE_COMPILE_FEATURES c_std_90 + POSITION_INDEPENDENT_CODE ON) check_required_components(Llama) diff --git a/cmake/llama.pc.in b/cmake/llama.pc.in index 326acbb61..0b2b6bcfa 100644 --- a/cmake/llama.pc.in +++ b/cmake/llama.pc.in @@ -6,5 +6,5 @@ includedir=${prefix}/include Name: llama Description: Port of Facebook's LLaMA model in C/C++ Version: @PROJECT_VERSION@ -Libs: -L${libdir} -lllama +Libs: -L${libdir} -lggml -lggml-base -lllama Cflags: -I${includedir} diff --git a/cmake/x64-windows-llvm.cmake b/cmake/x64-windows-llvm.cmake new file mode 100644 index 000000000..0603d738f --- /dev/null +++ b/cmake/x64-windows-llvm.cmake @@ -0,0 +1,11 @@ +set( CMAKE_SYSTEM_NAME Windows ) +set( CMAKE_SYSTEM_PROCESSOR x86_64 ) + +set( CMAKE_C_COMPILER clang ) +set( CMAKE_CXX_COMPILER clang++ ) + +set( arch_c_flags "-march=native" ) + +set( CMAKE_C_FLAGS_INIT "${arch_c_flags}" ) +set( CMAKE_CXX_FLAGS_INIT "${arch_c_flags}" ) + diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 2c72793b8..72f0915c1 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -2,6 +2,8 @@ find_package(Threads REQUIRED) +llama_add_compile_flags() + # Build info header # @@ -51,19 +53,27 @@ endif() set(TARGET common) add_library(${TARGET} STATIC + arg.cpp + arg.h base64.hpp - common.h + chat.cpp + chat.hpp + chat-template.hpp common.cpp - sampling.h - sampling.cpp - console.h + common.h console.cpp - json.hpp + console.h json-schema-to-grammar.cpp - train.h - train.cpp - ngram-cache.h + json.hpp + log.cpp + log.h + minja.hpp ngram-cache.cpp + ngram-cache.h + sampling.cpp + sampling.h + speculative.cpp + speculative.h ) if (BUILD_SHARED_LIBS) @@ -75,12 +85,12 @@ set(LLAMA_COMMON_EXTRA_LIBS build_info) # Use curl to download model url if (LLAMA_CURL) find_package(CURL REQUIRED) - add_definitions(-DLLAMA_USE_CURL) + target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL) include_directories(${CURL_INCLUDE_DIRS}) find_library(CURL_LIBRARY curl REQUIRED) set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY}) endif () target_include_directories(${TARGET} PUBLIC .) -target_compile_features (${TARGET} PUBLIC cxx_std_11) +target_compile_features (${TARGET} PUBLIC cxx_std_17) target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads) diff --git a/common/arg.cpp b/common/arg.cpp new file mode 100644 index 000000000..f5e9b294f --- /dev/null +++ b/common/arg.cpp @@ -0,0 +1,2315 @@ +#include "arg.h" + +#include "log.h" +#include "sampling.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "json-schema-to-grammar.h" + +using json = nlohmann::ordered_json; + +common_arg & common_arg::set_examples(std::initializer_list examples) { + this->examples = std::move(examples); + return *this; +} + +common_arg & common_arg::set_excludes(std::initializer_list excludes) { + this->excludes = std::move(excludes); + return *this; +} + +common_arg & common_arg::set_env(const char * env) { + help = help + "\n(env: " + env + ")"; + this->env = env; + return *this; +} + +common_arg & common_arg::set_sparam() { + is_sparam = true; + return *this; +} + +bool common_arg::in_example(enum llama_example ex) { + return examples.find(ex) != examples.end(); +} + +bool common_arg::is_exclude(enum llama_example ex) { + return excludes.find(ex) != excludes.end(); +} + +bool common_arg::get_value_from_env(std::string & output) { + if (env == nullptr) return false; + char * value = std::getenv(env); + if (value) { + output = value; + return true; + } + return false; +} + +bool common_arg::has_value_from_env() { + return env != nullptr && std::getenv(env); +} + +static std::vector break_str_into_lines(std::string input, size_t max_char_per_line) { + std::vector result; + std::istringstream iss(input); + std::string line; + auto add_line = [&](const std::string& l) { + if (l.length() <= max_char_per_line) { + result.push_back(l); + } else { + std::istringstream line_stream(l); + std::string word, current_line; + while (line_stream >> word) { + if (current_line.length() + !current_line.empty() + word.length() > max_char_per_line) { + if (!current_line.empty()) result.push_back(current_line); + current_line = word; + } else { + current_line += (!current_line.empty() ? " " : "") + word; + } + } + if (!current_line.empty()) result.push_back(current_line); + } + }; + while (std::getline(iss, line)) { + add_line(line); + } + return result; +} + +std::string common_arg::to_string() { + // params for printing to console + const static int n_leading_spaces = 40; + const static int n_char_per_line_help = 70; // TODO: detect this based on current console + std::string leading_spaces(n_leading_spaces, ' '); + + std::ostringstream ss; + for (const auto arg : args) { + if (arg == args.front()) { + if (args.size() == 1) { + ss << arg; + } else { + // first arg is usually abbreviation, we need padding to make it more beautiful + auto tmp = std::string(arg) + ", "; + auto spaces = std::string(std::max(0, 7 - (int)tmp.size()), ' '); + ss << tmp << spaces; + } + } else { + ss << arg << (arg != args.back() ? ", " : ""); + } + } + if (value_hint) ss << " " << value_hint; + if (value_hint_2) ss << " " << value_hint_2; + if (ss.tellp() > n_leading_spaces - 3) { + // current line is too long, add new line + ss << "\n" << leading_spaces; + } else { + // padding between arg and help, same line + ss << std::string(leading_spaces.size() - ss.tellp(), ' '); + } + const auto help_lines = break_str_into_lines(help, n_char_per_line_help); + for (const auto & line : help_lines) { + ss << (&line == &help_lines.front() ? "" : leading_spaces) << line << "\n"; + } + return ss.str(); +} + +// +// utils +// + +static void common_params_handle_model_default( + std::string & model, + const std::string & model_url, + std::string & hf_repo, + std::string & hf_file, + const std::string & hf_token, + const std::string & model_default) { + if (!hf_repo.empty()) { + // short-hand to avoid specifying --hf-file -> default it to --model + if (hf_file.empty()) { + if (model.empty()) { + auto auto_detected = common_get_hf_file(hf_repo, hf_token); + if (auto_detected.first.empty() || auto_detected.second.empty()) { + exit(1); // built without CURL, error message already printed + } + hf_repo = auto_detected.first; + hf_file = auto_detected.second; + } else { + hf_file = model; + } + } + // make sure model path is present (for caching purposes) + if (model.empty()) { + // this is to avoid different repo having same file name, or same file name in different subdirs + std::string filename = hf_repo + "_" + hf_file; + // to make sure we don't have any slashes in the filename + string_replace_all(filename, "/", "_"); + model = fs_get_cache_file(filename); + } + } else if (!model_url.empty()) { + if (model.empty()) { + auto f = string_split(model_url, '#').front(); + f = string_split(f, '?').front(); + model = fs_get_cache_file(string_split(f, '/').back()); + } + } else if (model.empty()) { + model = model_default; + } +} + +const std::vector kv_cache_types = { + GGML_TYPE_F32, + GGML_TYPE_F16, + GGML_TYPE_BF16, + GGML_TYPE_Q8_0, + GGML_TYPE_Q4_0, + GGML_TYPE_Q4_1, + GGML_TYPE_IQ4_NL, + GGML_TYPE_Q5_0, + GGML_TYPE_Q5_1, +}; + +static ggml_type kv_cache_type_from_str(const std::string & s) { + for (const auto & type : kv_cache_types) { + if (ggml_type_name(type) == s) { + return type; + } + } + throw std::runtime_error("Unsupported cache type: " + s); +} + +static std::string get_all_kv_cache_types() { + std::ostringstream msg; + for (const auto & type : kv_cache_types) { + msg << ggml_type_name(type) << (&type == &kv_cache_types.back() ? "" : ", "); + } + return msg.str(); +} + +// +// CLI argument parsing functions +// + +static bool common_params_parse_ex(int argc, char ** argv, common_params_context & ctx_arg) { + std::string arg; + const std::string arg_prefix = "--"; + common_params & params = ctx_arg.params; + + std::unordered_map arg_to_options; + for (auto & opt : ctx_arg.options) { + for (const auto & arg : opt.args) { + arg_to_options[arg] = &opt; + } + } + + // handle environment variables + for (auto & opt : ctx_arg.options) { + std::string value; + if (opt.get_value_from_env(value)) { + try { + if (opt.handler_void && (value == "1" || value == "true")) { + opt.handler_void(params); + } + if (opt.handler_int) { + opt.handler_int(params, std::stoi(value)); + } + if (opt.handler_string) { + opt.handler_string(params, value); + continue; + } + } catch (std::exception & e) { + throw std::invalid_argument(string_format( + "error while handling environment variable \"%s\": %s\n\n", opt.env, e.what())); + } + } + } + + // handle command line arguments + auto check_arg = [&](int i) { + if (i+1 >= argc) { + throw std::invalid_argument("expected value for argument"); + } + }; + + for (int i = 1; i < argc; i++) { + const std::string arg_prefix = "--"; + + std::string arg = argv[i]; + if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { + std::replace(arg.begin(), arg.end(), '_', '-'); + } + if (arg_to_options.find(arg) == arg_to_options.end()) { + throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str())); + } + auto opt = *arg_to_options[arg]; + if (opt.has_value_from_env()) { + fprintf(stderr, "warn: %s environment variable is set, but will be overwritten by command line argument %s\n", opt.env, arg.c_str()); + } + try { + if (opt.handler_void) { + opt.handler_void(params); + continue; + } + + // arg with single value + check_arg(i); + std::string val = argv[++i]; + if (opt.handler_int) { + opt.handler_int(params, std::stoi(val)); + continue; + } + if (opt.handler_string) { + opt.handler_string(params, val); + continue; + } + + // arg with 2 values + check_arg(i); + std::string val2 = argv[++i]; + if (opt.handler_str_str) { + opt.handler_str_str(params, val, val2); + continue; + } + } catch (std::exception & e) { + throw std::invalid_argument(string_format( + "error while handling argument \"%s\": %s\n\n" + "usage:\n%s\n\nto show complete usage, run with -h", + arg.c_str(), e.what(), arg_to_options[arg]->to_string().c_str())); + } + } + + postprocess_cpu_params(params.cpuparams, nullptr); + postprocess_cpu_params(params.cpuparams_batch, ¶ms.cpuparams); + + postprocess_cpu_params(params.speculative.cpuparams, ¶ms.cpuparams); + postprocess_cpu_params(params.speculative.cpuparams_batch, ¶ms.cpuparams_batch); + + if (params.prompt_cache_all && (params.interactive || params.interactive_first)) { + throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); + } + + // TODO: refactor model params in a common struct + common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token, DEFAULT_MODEL_PATH); + common_params_handle_model_default(params.speculative.model, params.speculative.model_url, params.speculative.hf_repo, params.speculative.hf_file, params.hf_token, ""); + common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token, ""); + + if (params.escape) { + string_process_escapes(params.prompt); + string_process_escapes(params.input_prefix); + string_process_escapes(params.input_suffix); + for (auto & antiprompt : params.antiprompt) { + string_process_escapes(antiprompt); + } + for (auto & seq_breaker : params.sampling.dry_sequence_breakers) { + string_process_escapes(seq_breaker); + } + } + + if (!params.kv_overrides.empty()) { + params.kv_overrides.emplace_back(); + params.kv_overrides.back().key[0] = 0; + } + + if (params.reranking && params.embedding) { + throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both"); + } + + if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) { + throw std::runtime_error(string_format( + "error: the supplied chat template is not supported: %s%s\n", + params.chat_template.c_str(), + params.use_jinja ? "" : "\nnote: llama.cpp was started without --jinja, we only support commonly used templates" + )); + } + + return true; +} + +static void common_params_print_usage(common_params_context & ctx_arg) { + auto print_options = [](std::vector & options) { + for (common_arg * opt : options) { + printf("%s", opt->to_string().c_str()); + } + }; + + std::vector common_options; + std::vector sparam_options; + std::vector specific_options; + for (auto & opt : ctx_arg.options) { + // in case multiple LLAMA_EXAMPLE_* are set, we prioritize the LLAMA_EXAMPLE_* matching current example + if (opt.is_sparam) { + sparam_options.push_back(&opt); + } else if (opt.in_example(ctx_arg.ex)) { + specific_options.push_back(&opt); + } else { + common_options.push_back(&opt); + } + } + printf("----- common params -----\n\n"); + print_options(common_options); + printf("\n\n----- sampling params -----\n\n"); + print_options(sparam_options); + // TODO: maybe convert enum llama_example to string + printf("\n\n----- example-specific params -----\n\n"); + print_options(specific_options); +} + +static std::vector parse_device_list(const std::string & value) { + std::vector devices; + auto dev_names = string_split(value, ','); + if (dev_names.empty()) { + throw std::invalid_argument("no devices specified"); + } + if (dev_names.size() == 1 && dev_names[0] == "none") { + devices.push_back(nullptr); + } else { + for (const auto & device : dev_names) { + auto * dev = ggml_backend_dev_by_name(device.c_str()); + if (!dev || ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) { + throw std::invalid_argument(string_format("invalid device: %s", device.c_str())); + } + devices.push_back(dev); + } + devices.push_back(nullptr); + } + return devices; +} + +static void add_rpc_devices(std::string servers) { + auto rpc_servers = string_split(servers, ','); + if (rpc_servers.empty()) { + throw std::invalid_argument("no RPC servers specified"); + } + ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC"); + if (!rpc_reg) { + throw std::invalid_argument("failed to find RPC backend"); + } + typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint); + ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device"); + if (!ggml_backend_rpc_add_device_fn) { + throw std::invalid_argument("failed to find RPC device add function"); + } + for (const auto & server : rpc_servers) { + ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str()); + if (dev) { + ggml_backend_device_register(dev); + } else { + throw std::invalid_argument("failed to register RPC device"); + } + } +} + +bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) { + auto ctx_arg = common_params_parser_init(params, ex, print_usage); + const common_params params_org = ctx_arg.params; // the example can modify the default params + + try { + if (!common_params_parse_ex(argc, argv, ctx_arg)) { + ctx_arg.params = params_org; + return false; + } + if (ctx_arg.params.usage) { + common_params_print_usage(ctx_arg); + if (ctx_arg.print_usage) { + ctx_arg.print_usage(argc, argv); + } + exit(0); + } + } catch (const std::invalid_argument & ex) { + fprintf(stderr, "%s\n", ex.what()); + ctx_arg.params = params_org; + return false; + } + + return true; +} + +static std::string list_builtin_chat_templates() { + std::vector supported_tmpl; + int32_t res = llama_chat_builtin_templates(nullptr, 0); + supported_tmpl.resize(res); + res = llama_chat_builtin_templates(supported_tmpl.data(), supported_tmpl.size()); + std::ostringstream msg; + for (auto & tmpl : supported_tmpl) { + msg << tmpl << (&tmpl == &supported_tmpl.back() ? "" : ", "); + } + return msg.str(); +} + +common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) { + // load dynamic backends + ggml_backend_load_all(); + + common_params_context ctx_arg(params); + ctx_arg.print_usage = print_usage; + ctx_arg.ex = ex; + + std::string sampler_type_chars; + std::string sampler_type_names; + for (const auto & sampler : params.sampling.samplers) { + sampler_type_chars += common_sampler_type_to_chr(sampler); + sampler_type_names += common_sampler_type_to_str(sampler) + ";"; + } + sampler_type_names.pop_back(); + + + /** + * filter options by example + * rules: + * - all examples inherit options from LLAMA_EXAMPLE_COMMON + * - if LLAMA_EXAMPLE_* is set (other than COMMON), we only show the option in the corresponding example + * - if both {LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_*,} are set, we will prioritize the LLAMA_EXAMPLE_* matching current example + */ + auto add_opt = [&](common_arg arg) { + if ((arg.in_example(ex) || arg.in_example(LLAMA_EXAMPLE_COMMON)) && !arg.is_exclude(ex)) { + ctx_arg.options.push_back(std::move(arg)); + } + }; + + + add_opt(common_arg( + {"-h", "--help", "--usage"}, + "print usage and exit", + [](common_params & params) { + params.usage = true; + } + )); + add_opt(common_arg( + {"--version"}, + "show version and build info", + [](common_params &) { + fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT); + fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET); + exit(0); + } + )); + add_opt(common_arg( + {"--verbose-prompt"}, + string_format("print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false"), + [](common_params & params) { + params.verbose_prompt = true; + } + )); + add_opt(common_arg( + {"--no-display-prompt"}, + string_format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"), + [](common_params & params) { + params.display_prompt = false; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"-co", "--color"}, + string_format("colorise output to distinguish prompt and user input from generations (default: %s)", params.use_color ? "true" : "false"), + [](common_params & params) { + params.use_color = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL, LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP})); + add_opt(common_arg( + {"-t", "--threads"}, "N", + string_format("number of threads to use during generation (default: %d)", params.cpuparams.n_threads), + [](common_params & params, int value) { + params.cpuparams.n_threads = value; + if (params.cpuparams.n_threads <= 0) { + params.cpuparams.n_threads = std::thread::hardware_concurrency(); + } + } + ).set_env("LLAMA_ARG_THREADS")); + add_opt(common_arg( + {"-tb", "--threads-batch"}, "N", + "number of threads to use during batch and prompt processing (default: same as --threads)", + [](common_params & params, int value) { + params.cpuparams_batch.n_threads = value; + if (params.cpuparams_batch.n_threads <= 0) { + params.cpuparams_batch.n_threads = std::thread::hardware_concurrency(); + } + } + )); + add_opt(common_arg( + {"-C", "--cpu-mask"}, "M", + "CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: \"\")", + [](common_params & params, const std::string & mask) { + params.cpuparams.mask_valid = true; + if (!parse_cpu_mask(mask, params.cpuparams.cpumask)) { + throw std::invalid_argument("invalid cpumask"); + } + } + )); + add_opt(common_arg( + {"-Cr", "--cpu-range"}, "lo-hi", + "range of CPUs for affinity. Complements --cpu-mask", + [](common_params & params, const std::string & range) { + params.cpuparams.mask_valid = true; + if (!parse_cpu_range(range, params.cpuparams.cpumask)) { + throw std::invalid_argument("invalid range"); + } + } + )); + add_opt(common_arg( + {"--cpu-strict"}, "<0|1>", + string_format("use strict CPU placement (default: %u)\n", (unsigned) params.cpuparams.strict_cpu), + [](common_params & params, const std::string & value) { + params.cpuparams.strict_cpu = std::stoul(value); + } + )); + add_opt(common_arg( + {"--prio"}, "N", + string_format("set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.cpuparams.priority), + [](common_params & params, int prio) { + if (prio < 0 || prio > 3) { + throw std::invalid_argument("invalid value"); + } + params.cpuparams.priority = (enum ggml_sched_priority) prio; + } + )); + add_opt(common_arg( + {"--poll"}, "<0...100>", + string_format("use polling level to wait for work (0 - no polling, default: %u)\n", (unsigned) params.cpuparams.poll), + [](common_params & params, const std::string & value) { + params.cpuparams.poll = std::stoul(value); + } + )); + add_opt(common_arg( + {"-Cb", "--cpu-mask-batch"}, "M", + "CPU affinity mask: arbitrarily long hex. Complements cpu-range-batch (default: same as --cpu-mask)", + [](common_params & params, const std::string & mask) { + params.cpuparams_batch.mask_valid = true; + if (!parse_cpu_mask(mask, params.cpuparams_batch.cpumask)) { + throw std::invalid_argument("invalid cpumask"); + } + } + )); + add_opt(common_arg( + {"-Crb", "--cpu-range-batch"}, "lo-hi", + "ranges of CPUs for affinity. Complements --cpu-mask-batch", + [](common_params & params, const std::string & range) { + params.cpuparams_batch.mask_valid = true; + if (!parse_cpu_range(range, params.cpuparams_batch.cpumask)) { + throw std::invalid_argument("invalid range"); + } + } + )); + add_opt(common_arg( + {"--cpu-strict-batch"}, "<0|1>", + "use strict CPU placement (default: same as --cpu-strict)", + [](common_params & params, int value) { + params.cpuparams_batch.strict_cpu = value; + } + )); + add_opt(common_arg( + {"--prio-batch"}, "N", + string_format("set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.cpuparams_batch.priority), + [](common_params & params, int prio) { + if (prio < 0 || prio > 3) { + throw std::invalid_argument("invalid value"); + } + params.cpuparams_batch.priority = (enum ggml_sched_priority) prio; + } + )); + add_opt(common_arg( + {"--poll-batch"}, "<0|1>", + "use polling to wait for work (default: same as --poll)", + [](common_params & params, int value) { + params.cpuparams_batch.poll = value; + } + )); + add_opt(common_arg( + {"-lcs", "--lookup-cache-static"}, "FNAME", + "path to static lookup cache to use for lookup decoding (not updated by generation)", + [](common_params & params, const std::string & value) { + params.lookup_cache_static = value; + } + ).set_examples({LLAMA_EXAMPLE_LOOKUP})); + add_opt(common_arg( + {"-lcd", "--lookup-cache-dynamic"}, "FNAME", + "path to dynamic lookup cache to use for lookup decoding (updated by generation)", + [](common_params & params, const std::string & value) { + params.lookup_cache_dynamic = value; + } + ).set_examples({LLAMA_EXAMPLE_LOOKUP})); + add_opt(common_arg( + {"-c", "--ctx-size"}, "N", + string_format("size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx), + [](common_params & params, int value) { + params.n_ctx = value; + } + ).set_env("LLAMA_ARG_CTX_SIZE")); + add_opt(common_arg( + {"-n", "--predict", "--n-predict"}, "N", + string_format("number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict), + [](common_params & params, int value) { + params.n_predict = value; + } + ).set_env("LLAMA_ARG_N_PREDICT")); + add_opt(common_arg( + {"-b", "--batch-size"}, "N", + string_format("logical maximum batch size (default: %d)", params.n_batch), + [](common_params & params, int value) { + params.n_batch = value; + } + ).set_env("LLAMA_ARG_BATCH")); + add_opt(common_arg( + {"-ub", "--ubatch-size"}, "N", + string_format("physical maximum batch size (default: %d)", params.n_ubatch), + [](common_params & params, int value) { + params.n_ubatch = value; + } + ).set_env("LLAMA_ARG_UBATCH")); + add_opt(common_arg( + {"--keep"}, "N", + string_format("number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep), + [](common_params & params, int value) { + params.n_keep = value; + } + )); + add_opt(common_arg( + {"--no-context-shift"}, + string_format("disables context shift on inifinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"), + [](common_params & params) { + params.ctx_shift = false; + } + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT")); + add_opt(common_arg( + {"--chunks"}, "N", + string_format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks), + [](common_params & params, int value) { + params.n_chunks = value; + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL})); + add_opt(common_arg( + {"-fa", "--flash-attn"}, + string_format("enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled"), + [](common_params & params) { + params.flash_attn = true; + } + ).set_env("LLAMA_ARG_FLASH_ATTN")); + add_opt(common_arg( + {"-p", "--prompt"}, "PROMPT", + ex == LLAMA_EXAMPLE_MAIN + ? "prompt to start generation with\nif -cnv is set, this will be used as system prompt" + : "prompt to start generation with", + [](common_params & params, const std::string & value) { + params.prompt = value; + } + ).set_excludes({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--no-perf"}, + string_format("disable internal libllama performance timings (default: %s)", params.no_perf ? "true" : "false"), + [](common_params & params) { + params.no_perf = true; + params.sampling.no_perf = true; + } + ).set_env("LLAMA_ARG_NO_PERF")); + add_opt(common_arg( + {"-f", "--file"}, "FNAME", + "a file containing the prompt (default: none)", + [](common_params & params, const std::string & value) { + std::ifstream file(value); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); + } + // store the external file name in params + params.prompt_file = value; + std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.prompt)); + if (!params.prompt.empty() && params.prompt.back() == '\n') { + params.prompt.pop_back(); + } + } + ).set_excludes({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--in-file"}, "FNAME", + "an input file (repeat to specify multiple files)", + [](common_params & params, const std::string & value) { + std::ifstream file(value); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); + } + params.in_files.push_back(value); + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX})); + add_opt(common_arg( + {"-bf", "--binary-file"}, "FNAME", + "binary file containing the prompt (default: none)", + [](common_params & params, const std::string & value) { + std::ifstream file(value, std::ios::binary); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); + } + // store the external file name in params + params.prompt_file = value; + std::ostringstream ss; + ss << file.rdbuf(); + params.prompt = ss.str(); + fprintf(stderr, "Read %zu bytes from binary file %s\n", params.prompt.size(), value.c_str()); + } + ).set_excludes({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"-e", "--escape"}, + string_format("process escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\) (default: %s)", params.escape ? "true" : "false"), + [](common_params & params) { + params.escape = true; + } + )); + add_opt(common_arg( + {"--no-escape"}, + "do not process escape sequences", + [](common_params & params) { + params.escape = false; + } + )); + add_opt(common_arg( + {"-ptc", "--print-token-count"}, "N", + string_format("print token count every N tokens (default: %d)", params.n_print), + [](common_params & params, int value) { + params.n_print = value; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"--prompt-cache"}, "FNAME", + "file to cache prompt state for faster startup (default: none)", + [](common_params & params, const std::string & value) { + params.path_prompt_cache = value; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"--prompt-cache-all"}, + "if specified, saves user input and generations to cache as well\n", + [](common_params & params) { + params.prompt_cache_all = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"--prompt-cache-ro"}, + "if specified, uses the prompt cache but does not update it", + [](common_params & params) { + params.prompt_cache_ro = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"-r", "--reverse-prompt"}, "PROMPT", + "halt generation at PROMPT, return control in interactive mode\n", + [](common_params & params, const std::string & value) { + params.antiprompt.emplace_back(value); + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"-sp", "--special"}, + string_format("special tokens output enabled (default: %s)", params.special ? "true" : "false"), + [](common_params & params) { + params.special = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"-cnv", "--conversation"}, + "run in conversation mode:\n" + "- does not print special tokens and suffix/prefix\n" + "- interactive mode is also enabled\n" + "(default: auto enabled if chat template is available)", + [](common_params & params) { + params.conversation_mode = COMMON_CONVERSATION_MODE_ENABLED; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"-no-cnv", "--no-conversation"}, + "force disable conversation mode (default: false)", + [](common_params & params) { + params.conversation_mode = COMMON_CONVERSATION_MODE_DISABLED; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"-i", "--interactive"}, + string_format("run in interactive mode (default: %s)", params.interactive ? "true" : "false"), + [](common_params & params) { + params.interactive = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"-if", "--interactive-first"}, + string_format("run in interactive mode and wait for input right away (default: %s)", params.interactive_first ? "true" : "false"), + [](common_params & params) { + params.interactive_first = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"-mli", "--multiline-input"}, + "allows you to write or paste multiple lines without ending each in '\\'", + [](common_params & params) { + params.multiline_input = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"--in-prefix-bos"}, + "prefix BOS to user inputs, preceding the `--in-prefix` string", + [](common_params & params) { + params.input_prefix_bos = true; + params.enable_chat_template = false; + } + ).set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"--in-prefix"}, "STRING", + "string to prefix user inputs with (default: empty)", + [](common_params & params, const std::string & value) { + params.input_prefix = value; + params.enable_chat_template = false; + } + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL})); + add_opt(common_arg( + {"--in-suffix"}, "STRING", + "string to suffix after user inputs with (default: empty)", + [](common_params & params, const std::string & value) { + params.input_suffix = value; + params.enable_chat_template = false; + } + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL})); + add_opt(common_arg( + {"--no-warmup"}, + "skip warming up the model with an empty run", + [](common_params & params) { + params.warmup = false; + } + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING})); + add_opt(common_arg( + {"--spm-infill"}, + string_format( + "use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: %s)", + params.spm_infill ? "enabled" : "disabled" + ), + [](common_params & params) { + params.spm_infill = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_INFILL})); + add_opt(common_arg( + {"--samplers"}, "SAMPLERS", + string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()), + [](common_params & params, const std::string & value) { + const auto sampler_names = string_split(value, ';'); + params.sampling.samplers = common_sampler_types_from_names(sampler_names, true); + } + ).set_sparam()); + add_opt(common_arg( + {"-s", "--seed"}, "SEED", + string_format("RNG seed (default: %d, use random seed for %d)", params.sampling.seed, LLAMA_DEFAULT_SEED), + [](common_params & params, const std::string & value) { + params.sampling.seed = std::stoul(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--sampling-seq", "--sampler-seq"}, "SEQUENCE", + string_format("simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str()), + [](common_params & params, const std::string & value) { + params.sampling.samplers = common_sampler_types_from_chars(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--ignore-eos"}, + "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)", + [](common_params & params) { + params.sampling.ignore_eos = true; + } + ).set_sparam()); + add_opt(common_arg( + {"--temp"}, "N", + string_format("temperature (default: %.1f)", (double)params.sampling.temp), + [](common_params & params, const std::string & value) { + params.sampling.temp = std::stof(value); + params.sampling.temp = std::max(params.sampling.temp, 0.0f); + } + ).set_sparam()); + add_opt(common_arg( + {"--top-k"}, "N", + string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k), + [](common_params & params, int value) { + params.sampling.top_k = value; + } + ).set_sparam()); + add_opt(common_arg( + {"--top-p"}, "N", + string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p), + [](common_params & params, const std::string & value) { + params.sampling.top_p = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--min-p"}, "N", + string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p), + [](common_params & params, const std::string & value) { + params.sampling.min_p = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--xtc-probability"}, "N", + string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability), + [](common_params & params, const std::string & value) { + params.sampling.xtc_probability = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--xtc-threshold"}, "N", + string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold), + [](common_params & params, const std::string & value) { + params.sampling.xtc_threshold = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--typical"}, "N", + string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sampling.typ_p), + [](common_params & params, const std::string & value) { + params.sampling.typ_p = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--repeat-last-n"}, "N", + string_format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sampling.penalty_last_n), + [](common_params & params, int value) { + if (value < -1) { + throw std::runtime_error(string_format("error: invalid repeat-last-n = %d\n", value)); + } + params.sampling.penalty_last_n = value; + params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n); + } + ).set_sparam()); + add_opt(common_arg( + {"--repeat-penalty"}, "N", + string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat), + [](common_params & params, const std::string & value) { + params.sampling.penalty_repeat = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--presence-penalty"}, "N", + string_format("repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_present), + [](common_params & params, const std::string & value) { + params.sampling.penalty_present = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--frequency-penalty"}, "N", + string_format("repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_freq), + [](common_params & params, const std::string & value) { + params.sampling.penalty_freq = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--dry-multiplier"}, "N", + string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sampling.dry_multiplier), + [](common_params & params, const std::string & value) { + params.sampling.dry_multiplier = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--dry-base"}, "N", + string_format("set DRY sampling base value (default: %.2f)", (double)params.sampling.dry_base), + [](common_params & params, const std::string & value) { + float potential_base = std::stof(value); + if (potential_base >= 1.0f) + { + params.sampling.dry_base = potential_base; + } + } + ).set_sparam()); + add_opt(common_arg( + {"--dry-allowed-length"}, "N", + string_format("set allowed length for DRY sampling (default: %d)", params.sampling.dry_allowed_length), + [](common_params & params, int value) { + params.sampling.dry_allowed_length = value; + } + ).set_sparam()); + add_opt(common_arg( + {"--dry-penalty-last-n"}, "N", + string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sampling.dry_penalty_last_n), + [](common_params & params, int value) { + if (value < -1) { + throw std::runtime_error(string_format("error: invalid dry-penalty-last-n = %d\n", value)); + } + params.sampling.dry_penalty_last_n = value; + } + ).set_sparam()); + add_opt(common_arg( + {"--dry-sequence-breaker"}, "STRING", + string_format("add sequence breaker for DRY sampling, clearing out default breakers (%s) in the process; use \"none\" to not use any sequence breakers\n", + params.sampling.dry_sequence_breakers.empty() ? "none" : + std::accumulate(std::next(params.sampling.dry_sequence_breakers.begin()), + params.sampling.dry_sequence_breakers.end(), + std::string("'") + (params.sampling.dry_sequence_breakers[0] == "\n" ? "\\n" : params.sampling.dry_sequence_breakers[0]) + "'", + [](const std::string& a, const std::string& b) { + std::string formatted_b = (b == "\n") ? "\\n" : b; + return a + ", '" + formatted_b + "'"; + }).c_str()), + [](common_params & params, const std::string & value) { + static bool defaults_cleared = false; + + if (!defaults_cleared) { + params.sampling.dry_sequence_breakers.clear(); + defaults_cleared = true; + } + + if (value == "none") { + params.sampling.dry_sequence_breakers.clear(); + } else { + params.sampling.dry_sequence_breakers.emplace_back(value); + } + } + ).set_sparam()); + add_opt(common_arg( + {"--dynatemp-range"}, "N", + string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sampling.dynatemp_range), + [](common_params & params, const std::string & value) { + params.sampling.dynatemp_range = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--dynatemp-exp"}, "N", + string_format("dynamic temperature exponent (default: %.1f)", (double)params.sampling.dynatemp_exponent), + [](common_params & params, const std::string & value) { + params.sampling.dynatemp_exponent = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--mirostat"}, "N", + string_format("use Mirostat sampling.\nTop K, Nucleus and Locally Typical samplers are ignored if used.\n" + "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat), + [](common_params & params, int value) { + params.sampling.mirostat = value; + } + ).set_sparam()); + add_opt(common_arg( + {"--mirostat-lr"}, "N", + string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta), + [](common_params & params, const std::string & value) { + params.sampling.mirostat_eta = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"--mirostat-ent"}, "N", + string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau), + [](common_params & params, const std::string & value) { + params.sampling.mirostat_tau = std::stof(value); + } + ).set_sparam()); + add_opt(common_arg( + {"-l", "--logit-bias"}, "TOKEN_ID(+/-)BIAS", + "modifies the likelihood of token appearing in the completion,\n" + "i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n" + "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'", + [](common_params & params, const std::string & value) { + std::stringstream ss(value); + llama_token key; + char sign; + std::string value_str; + try { + if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) { + const float bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f); + params.sampling.logit_bias.push_back({key, bias}); + } else { + throw std::invalid_argument("invalid input format"); + } + } catch (const std::exception&) { + throw std::invalid_argument("invalid input format"); + } + } + ).set_sparam()); + add_opt(common_arg( + {"--grammar"}, "GRAMMAR", + string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()), + [](common_params & params, const std::string & value) { + params.sampling.grammar = value; + } + ).set_sparam()); + add_opt(common_arg( + {"--grammar-file"}, "FNAME", + "file to read grammar from", + [](common_params & params, const std::string & value) { + std::ifstream file(value); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); + } + std::copy( + std::istreambuf_iterator(file), + std::istreambuf_iterator(), + std::back_inserter(params.sampling.grammar) + ); + } + ).set_sparam()); + add_opt(common_arg( + {"-j", "--json-schema"}, "SCHEMA", + "JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead", + [](common_params & params, const std::string & value) { + params.sampling.grammar = json_schema_to_grammar(json::parse(value)); + } + ).set_sparam()); + add_opt(common_arg( + {"--pooling"}, "{none,mean,cls,last,rank}", + "pooling type for embeddings, use model default if unspecified", + [](common_params & params, const std::string & value) { + /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } + else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } + else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } + else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; } + else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; } + else { throw std::invalid_argument("invalid value"); } + } + ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING")); + add_opt(common_arg( + {"--attention"}, "{causal,non-causal}", + "attention type for embeddings, use model default if unspecified", + [](common_params & params, const std::string & value) { + /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; } + else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL; } + else { throw std::invalid_argument("invalid value"); } + } + ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); + add_opt(common_arg( + {"--rope-scaling"}, "{none,linear,yarn}", + "RoPE frequency scaling method, defaults to linear unless specified by the model", + [](common_params & params, const std::string & value) { + /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; } + else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; } + else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; } + else { throw std::invalid_argument("invalid value"); } + } + ).set_env("LLAMA_ARG_ROPE_SCALING_TYPE")); + add_opt(common_arg( + {"--rope-scale"}, "N", + "RoPE context scaling factor, expands context by a factor of N", + [](common_params & params, const std::string & value) { + params.rope_freq_scale = 1.0f / std::stof(value); + } + ).set_env("LLAMA_ARG_ROPE_SCALE")); + add_opt(common_arg( + {"--rope-freq-base"}, "N", + "RoPE base frequency, used by NTK-aware scaling (default: loaded from model)", + [](common_params & params, const std::string & value) { + params.rope_freq_base = std::stof(value); + } + ).set_env("LLAMA_ARG_ROPE_FREQ_BASE")); + add_opt(common_arg( + {"--rope-freq-scale"}, "N", + "RoPE frequency scaling factor, expands context by a factor of 1/N", + [](common_params & params, const std::string & value) { + params.rope_freq_scale = std::stof(value); + } + ).set_env("LLAMA_ARG_ROPE_FREQ_SCALE")); + add_opt(common_arg( + {"--yarn-orig-ctx"}, "N", + string_format("YaRN: original context size of model (default: %d = model training context size)", params.yarn_orig_ctx), + [](common_params & params, int value) { + params.yarn_orig_ctx = value; + } + ).set_env("LLAMA_ARG_YARN_ORIG_CTX")); + add_opt(common_arg( + {"--yarn-ext-factor"}, "N", + string_format("YaRN: extrapolation mix factor (default: %.1f, 0.0 = full interpolation)", (double)params.yarn_ext_factor), + [](common_params & params, const std::string & value) { + params.yarn_ext_factor = std::stof(value); + } + ).set_env("LLAMA_ARG_YARN_EXT_FACTOR")); + add_opt(common_arg( + {"--yarn-attn-factor"}, "N", + string_format("YaRN: scale sqrt(t) or attention magnitude (default: %.1f)", (double)params.yarn_attn_factor), + [](common_params & params, const std::string & value) { + params.yarn_attn_factor = std::stof(value); + } + ).set_env("LLAMA_ARG_YARN_ATTN_FACTOR")); + add_opt(common_arg( + {"--yarn-beta-slow"}, "N", + string_format("YaRN: high correction dim or alpha (default: %.1f)", (double)params.yarn_beta_slow), + [](common_params & params, const std::string & value) { + params.yarn_beta_slow = std::stof(value); + } + ).set_env("LLAMA_ARG_YARN_BETA_SLOW")); + add_opt(common_arg( + {"--yarn-beta-fast"}, "N", + string_format("YaRN: low correction dim or beta (default: %.1f)", (double)params.yarn_beta_fast), + [](common_params & params, const std::string & value) { + params.yarn_beta_fast = std::stof(value); + } + ).set_env("LLAMA_ARG_YARN_BETA_FAST")); + add_opt(common_arg( + {"-gan", "--grp-attn-n"}, "N", + string_format("group-attention factor (default: %d)", params.grp_attn_n), + [](common_params & params, int value) { + params.grp_attn_n = value; + } + ).set_env("LLAMA_ARG_GRP_ATTN_N").set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_PASSKEY})); + add_opt(common_arg( + {"-gaw", "--grp-attn-w"}, "N", + string_format("group-attention width (default: %d)", params.grp_attn_w), + [](common_params & params, int value) { + params.grp_attn_w = value; + } + ).set_env("LLAMA_ARG_GRP_ATTN_W").set_examples({LLAMA_EXAMPLE_MAIN})); + add_opt(common_arg( + {"-dkvc", "--dump-kv-cache"}, + "verbose print of the KV cache", + [](common_params & params) { + params.dump_kv_cache = true; + } + )); + add_opt(common_arg( + {"-nkvo", "--no-kv-offload"}, + "disable KV offload", + [](common_params & params) { + params.no_kv_offload = true; + } + ).set_env("LLAMA_ARG_NO_KV_OFFLOAD")); + add_opt(common_arg( + {"-ctk", "--cache-type-k"}, "TYPE", + string_format( + "KV cache data type for K\n" + "allowed values: %s\n" + "(default: %s)", + get_all_kv_cache_types().c_str(), + ggml_type_name(params.cache_type_k) + ), + [](common_params & params, const std::string & value) { + params.cache_type_k = kv_cache_type_from_str(value); + } + ).set_env("LLAMA_ARG_CACHE_TYPE_K")); + add_opt(common_arg( + {"-ctv", "--cache-type-v"}, "TYPE", + string_format( + "KV cache data type for V\n" + "allowed values: %s\n" + "(default: %s)", + get_all_kv_cache_types().c_str(), + ggml_type_name(params.cache_type_v) + ), + [](common_params & params, const std::string & value) { + params.cache_type_v = kv_cache_type_from_str(value); + } + ).set_env("LLAMA_ARG_CACHE_TYPE_V")); + add_opt(common_arg( + {"--perplexity", "--all-logits"}, + string_format("return logits for all tokens in the batch (default: %s)", params.logits_all ? "true" : "false"), + [](common_params & params) { + params.logits_all = true; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"--hellaswag"}, + "compute HellaSwag score over random tasks from datafile supplied with -f", + [](common_params & params) { + params.hellaswag = true; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"--hellaswag-tasks"}, "N", + string_format("number of tasks to use when computing the HellaSwag score (default: %zu)", params.hellaswag_tasks), + [](common_params & params, int value) { + params.hellaswag_tasks = value; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"--winogrande"}, + "compute Winogrande score over random tasks from datafile supplied with -f", + [](common_params & params) { + params.winogrande = true; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"--winogrande-tasks"}, "N", + string_format("number of tasks to use when computing the Winogrande score (default: %zu)", params.winogrande_tasks), + [](common_params & params, int value) { + params.winogrande_tasks = value; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"--multiple-choice"}, + "compute multiple choice score over random tasks from datafile supplied with -f", + [](common_params & params) { + params.multiple_choice = true; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"--multiple-choice-tasks"}, "N", + string_format("number of tasks to use when computing the multiple choice score (default: %zu)", params.multiple_choice_tasks), + [](common_params & params, int value) { + params.multiple_choice_tasks = value; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"--kl-divergence"}, + "computes KL-divergence to logits provided via --kl-divergence-base", + [](common_params & params) { + params.kl_divergence = true; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"--save-all-logits", "--kl-divergence-base"}, "FNAME", + "set logits file", + [](common_params & params, const std::string & value) { + params.logits_file = value; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"--ppl-stride"}, "N", + string_format("stride for perplexity calculation (default: %d)", params.ppl_stride), + [](common_params & params, int value) { + params.ppl_stride = value; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"--ppl-output-type"}, "<0|1>", + string_format("output type for perplexity calculation (default: %d)", params.ppl_output_type), + [](common_params & params, int value) { + params.ppl_output_type = value; + } + ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); + add_opt(common_arg( + {"-dt", "--defrag-thold"}, "N", + string_format("KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold), + [](common_params & params, const std::string & value) { + params.defrag_thold = std::stof(value); + } + ).set_env("LLAMA_ARG_DEFRAG_THOLD")); + add_opt(common_arg( + {"-np", "--parallel"}, "N", + string_format("number of parallel sequences to decode (default: %d)", params.n_parallel), + [](common_params & params, int value) { + params.n_parallel = value; + } + ).set_env("LLAMA_ARG_N_PARALLEL")); + add_opt(common_arg( + {"-ns", "--sequences"}, "N", + string_format("number of sequences to decode (default: %d)", params.n_sequences), + [](common_params & params, int value) { + params.n_sequences = value; + } + ).set_examples({LLAMA_EXAMPLE_PARALLEL})); + add_opt(common_arg( + {"-cb", "--cont-batching"}, + string_format("enable continuous batching (a.k.a dynamic batching) (default: %s)", params.cont_batching ? "enabled" : "disabled"), + [](common_params & params) { + params.cont_batching = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CONT_BATCHING")); + add_opt(common_arg( + {"-nocb", "--no-cont-batching"}, + "disable continuous batching", + [](common_params & params) { + params.cont_batching = false; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_CONT_BATCHING")); + add_opt(common_arg( + {"--mmproj"}, "FILE", + "path to a multimodal projector file for LLaVA. see examples/llava/README.md", + [](common_params & params, const std::string & value) { + params.mmproj = value; + } + ).set_examples({LLAMA_EXAMPLE_LLAVA})); + add_opt(common_arg( + {"--image"}, "FILE", + "path to an image file. use with multimodal models. Specify multiple times for batching", + [](common_params & params, const std::string & value) { + params.image.emplace_back(value); + } + ).set_examples({LLAMA_EXAMPLE_LLAVA})); + if (llama_supports_rpc()) { + add_opt(common_arg( + {"--rpc"}, "SERVERS", + "comma separated list of RPC servers", + [](common_params & params, const std::string & value) { + add_rpc_devices(value); + GGML_UNUSED(params); + } + ).set_env("LLAMA_ARG_RPC")); + } + add_opt(common_arg( + {"--mlock"}, + "force system to keep model in RAM rather than swapping or compressing", + [](common_params & params) { + params.use_mlock = true; + } + ).set_env("LLAMA_ARG_MLOCK")); + add_opt(common_arg( + {"--no-mmap"}, + "do not memory-map model (slower load but may reduce pageouts if not using mlock)", + [](common_params & params) { + params.use_mmap = false; + } + ).set_env("LLAMA_ARG_NO_MMAP")); + add_opt(common_arg( + {"--numa"}, "TYPE", + "attempt optimizations that help on some NUMA systems\n" + "- distribute: spread execution evenly over all nodes\n" + "- isolate: only spawn threads on CPUs on the node that execution started on\n" + "- numactl: use the CPU map provided by numactl\n" + "if run without this previously, it is recommended to drop the system page cache before using this\n" + "see https://github.com/ggerganov/llama.cpp/issues/1437", + [](common_params & params, const std::string & value) { + /**/ if (value == "distribute" || value == "") { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; } + else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; } + else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; } + else { throw std::invalid_argument("invalid value"); } + } + ).set_env("LLAMA_ARG_NUMA")); + add_opt(common_arg( + {"-dev", "--device"}, "", + "comma-separated list of devices to use for offloading (none = don't offload)\n" + "use --list-devices to see a list of available devices", + [](common_params & params, const std::string & value) { + params.devices = parse_device_list(value); + } + ).set_env("LLAMA_ARG_DEVICE")); + add_opt(common_arg( + {"--list-devices"}, + "print list of available devices and exit", + [](common_params &) { + printf("Available devices:\n"); + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + auto * dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) { + size_t free, total; + ggml_backend_dev_memory(dev, &free, &total); + printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024); + } + } + exit(0); + } + )); + add_opt(common_arg( + {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N", + "number of layers to store in VRAM", + [](common_params & params, int value) { + params.n_gpu_layers = value; + if (!llama_supports_gpu_offload()) { + fprintf(stderr, "warning: no usable GPU found, --gpu-layers option will be ignored\n"); + fprintf(stderr, "warning: one possible reason is that llama.cpp was compiled without GPU support\n"); + fprintf(stderr, "warning: consult docs/build.md for compilation instructions\n"); + } + } + ).set_env("LLAMA_ARG_N_GPU_LAYERS")); + add_opt(common_arg( + {"-sm", "--split-mode"}, "{none,layer,row}", + "how to split the model across multiple GPUs, one of:\n" + "- none: use one GPU only\n" + "- layer (default): split layers and KV across GPUs\n" + "- row: split rows across GPUs", + [](common_params & params, const std::string & value) { + std::string arg_next = value; + if (arg_next == "none") { + params.split_mode = LLAMA_SPLIT_MODE_NONE; + } else if (arg_next == "layer") { + params.split_mode = LLAMA_SPLIT_MODE_LAYER; + } else if (arg_next == "row") { + params.split_mode = LLAMA_SPLIT_MODE_ROW; + } else { + throw std::invalid_argument("invalid value"); + } + if (!llama_supports_gpu_offload()) { + fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting the split mode has no effect.\n"); + } + } + ).set_env("LLAMA_ARG_SPLIT_MODE")); + add_opt(common_arg( + {"-ts", "--tensor-split"}, "N0,N1,N2,...", + "fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1", + [](common_params & params, const std::string & value) { + std::string arg_next = value; + + // split string by , and / + const std::regex regex{ R"([,/]+)" }; + std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 }; + std::vector split_arg{ it, {} }; + if (split_arg.size() >= llama_max_devices()) { + throw std::invalid_argument( + string_format("got %d input configs, but system only has %d devices", (int)split_arg.size(), (int)llama_max_devices()) + ); + } + for (size_t i = 0; i < llama_max_devices(); ++i) { + if (i < split_arg.size()) { + params.tensor_split[i] = std::stof(split_arg[i]); + } else { + params.tensor_split[i] = 0.0f; + } + } + if (!llama_supports_gpu_offload()) { + fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting a tensor split has no effect.\n"); + } + } + ).set_env("LLAMA_ARG_TENSOR_SPLIT")); + add_opt(common_arg( + {"-mg", "--main-gpu"}, "INDEX", + string_format("the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: %d)", params.main_gpu), + [](common_params & params, int value) { + params.main_gpu = value; + if (!llama_supports_gpu_offload()) { + fprintf(stderr, "warning: llama.cpp was compiled without support for GPU offload. Setting the main GPU has no effect.\n"); + } + } + ).set_env("LLAMA_ARG_MAIN_GPU")); + add_opt(common_arg( + {"--check-tensors"}, + string_format("check model tensor data for invalid values (default: %s)", params.check_tensors ? "true" : "false"), + [](common_params & params) { + params.check_tensors = true; + } + )); + add_opt(common_arg( + {"--override-kv"}, "KEY=TYPE:VALUE", + "advanced option to override model metadata by key. may be specified multiple times.\n" + "types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false", + [](common_params & params, const std::string & value) { + if (!string_parse_kv_override(value.c_str(), params.kv_overrides)) { + throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", value.c_str())); + } + } + )); + add_opt(common_arg( + {"--lora"}, "FNAME", + "path to LoRA adapter (can be repeated to use multiple adapters)", + [](common_params & params, const std::string & value) { + params.lora_adapters.push_back({ std::string(value), 1.0, nullptr }); + } + // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg + ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA})); + add_opt(common_arg( + {"--lora-scaled"}, "FNAME", "SCALE", + "path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)", + [](common_params & params, const std::string & fname, const std::string & scale) { + params.lora_adapters.push_back({ fname, std::stof(scale), nullptr }); + } + // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg + ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA})); + add_opt(common_arg( + {"--control-vector"}, "FNAME", + "add a control vector\nnote: this argument can be repeated to add multiple control vectors", + [](common_params & params, const std::string & value) { + params.control_vectors.push_back({ 1.0f, value, }); + } + )); + add_opt(common_arg( + {"--control-vector-scaled"}, "FNAME", "SCALE", + "add a control vector with user defined scaling SCALE\n" + "note: this argument can be repeated to add multiple scaled control vectors", + [](common_params & params, const std::string & fname, const std::string & scale) { + params.control_vectors.push_back({ std::stof(scale), fname }); + } + )); + add_opt(common_arg( + {"--control-vector-layer-range"}, "START", "END", + "layer range to apply the control vector(s) to, start and end inclusive", + [](common_params & params, const std::string & start, const std::string & end) { + params.control_vector_layer_start = std::stoi(start); + params.control_vector_layer_end = std::stoi(end); + } + )); + add_opt(common_arg( + {"-a", "--alias"}, "STRING", + "set alias for model name (to be used by REST API)", + [](common_params & params, const std::string & value) { + params.model_alias = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ALIAS")); + add_opt(common_arg( + {"-m", "--model"}, "FNAME", + ex == LLAMA_EXAMPLE_EXPORT_LORA + ? std::string("model path from which to load base model") + : string_format( + "model path (default: `models/$filename` with filename from `--hf-file` " + "or `--model-url` if set, otherwise %s)", DEFAULT_MODEL_PATH + ), + [](common_params & params, const std::string & value) { + params.model = value; + } + ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}).set_env("LLAMA_ARG_MODEL")); + add_opt(common_arg( + {"-mu", "--model-url"}, "MODEL_URL", + "model download url (default: unused)", + [](common_params & params, const std::string & value) { + params.model_url = value; + } + ).set_env("LLAMA_ARG_MODEL_URL")); + add_opt(common_arg( + {"-hf", "-hfr", "--hf-repo"}, "/[:quant]", + "Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n" + "example: unsloth/phi-4-GGUF:q4_k_m\n" + "(default: unused)", + [](common_params & params, const std::string & value) { + params.hf_repo = value; + } + ).set_env("LLAMA_ARG_HF_REPO")); + add_opt(common_arg( + {"-hfd", "-hfrd", "--hf-repo-draft"}, "/[:quant]", + "Same as --hf-repo, but for the draft model (default: unused)", + [](common_params & params, const std::string & value) { + params.speculative.hf_repo = value; + } + ).set_env("LLAMA_ARG_HFD_REPO")); + add_opt(common_arg( + {"-hff", "--hf-file"}, "FILE", + "Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)", + [](common_params & params, const std::string & value) { + params.hf_file = value; + } + ).set_env("LLAMA_ARG_HF_FILE")); + add_opt(common_arg( + {"-hfv", "-hfrv", "--hf-repo-v"}, "/[:quant]", + "Hugging Face model repository for the vocoder model (default: unused)", + [](common_params & params, const std::string & value) { + params.vocoder.hf_repo = value; + } + ).set_env("LLAMA_ARG_HF_REPO_V")); + add_opt(common_arg( + {"-hffv", "--hf-file-v"}, "FILE", + "Hugging Face model file for the vocoder model (default: unused)", + [](common_params & params, const std::string & value) { + params.vocoder.hf_file = value; + } + ).set_env("LLAMA_ARG_HF_FILE_V")); + add_opt(common_arg( + {"-hft", "--hf-token"}, "TOKEN", + "Hugging Face access token (default: value from HF_TOKEN environment variable)", + [](common_params & params, const std::string & value) { + params.hf_token = value; + } + ).set_env("HF_TOKEN")); + add_opt(common_arg( + {"--context-file"}, "FNAME", + "file to load context from (repeat to specify multiple files)", + [](common_params & params, const std::string & value) { + std::ifstream file(value, std::ios::binary); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); + } + params.context_files.push_back(value); + } + ).set_examples({LLAMA_EXAMPLE_RETRIEVAL})); + add_opt(common_arg( + {"--chunk-size"}, "N", + string_format("minimum length of embedded text chunks (default: %d)", params.chunk_size), + [](common_params & params, int value) { + params.chunk_size = value; + } + ).set_examples({LLAMA_EXAMPLE_RETRIEVAL})); + add_opt(common_arg( + {"--chunk-separator"}, "STRING", + string_format("separator between chunks (default: '%s')", params.chunk_separator.c_str()), + [](common_params & params, const std::string & value) { + params.chunk_separator = value; + } + ).set_examples({LLAMA_EXAMPLE_RETRIEVAL})); + add_opt(common_arg( + {"--junk"}, "N", + string_format("number of times to repeat the junk text (default: %d)", params.n_junk), + [](common_params & params, int value) { + params.n_junk = value; + } + ).set_examples({LLAMA_EXAMPLE_PASSKEY})); + add_opt(common_arg( + {"--pos"}, "N", + string_format("position of the passkey in the junk text (default: %d)", params.i_pos), + [](common_params & params, int value) { + params.i_pos = value; + } + ).set_examples({LLAMA_EXAMPLE_PASSKEY})); + add_opt(common_arg( + {"-o", "--output", "--output-file"}, "FNAME", + string_format("output file (default: '%s')", + ex == LLAMA_EXAMPLE_EXPORT_LORA + ? params.lora_outfile.c_str() + : ex == LLAMA_EXAMPLE_CVECTOR_GENERATOR + ? params.cvector_outfile.c_str() + : params.out_file.c_str()), + [](common_params & params, const std::string & value) { + params.out_file = value; + params.cvector_outfile = value; + params.lora_outfile = value; + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA})); + add_opt(common_arg( + {"-ofreq", "--output-frequency"}, "N", + string_format("output the imatrix every N iterations (default: %d)", params.n_out_freq), + [](common_params & params, int value) { + params.n_out_freq = value; + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX})); + add_opt(common_arg( + {"--save-frequency"}, "N", + string_format("save an imatrix copy every N iterations (default: %d)", params.n_save_freq), + [](common_params & params, int value) { + params.n_save_freq = value; + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX})); + add_opt(common_arg( + {"--process-output"}, + string_format("collect data for the output tensor (default: %s)", params.process_output ? "true" : "false"), + [](common_params & params) { + params.process_output = true; + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX})); + add_opt(common_arg( + {"--no-ppl"}, + string_format("do not compute perplexity (default: %s)", params.compute_ppl ? "true" : "false"), + [](common_params & params) { + params.compute_ppl = false; + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX})); + add_opt(common_arg( + {"--chunk", "--from-chunk"}, "N", + string_format("start processing the input from chunk N (default: %d)", params.i_chunk), + [](common_params & params, int value) { + params.i_chunk = value; + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX})); + add_opt(common_arg( + {"-pps"}, + string_format("is the prompt shared across parallel sequences (default: %s)", params.is_pp_shared ? "true" : "false"), + [](common_params & params) { + params.is_pp_shared = true; + } + ).set_examples({LLAMA_EXAMPLE_BENCH})); + add_opt(common_arg( + {"-npp"}, "n0,n1,...", + "number of prompt tokens", + [](common_params & params, const std::string & value) { + auto p = string_split(value, ','); + params.n_pp.insert(params.n_pp.end(), p.begin(), p.end()); + } + ).set_examples({LLAMA_EXAMPLE_BENCH})); + add_opt(common_arg( + {"-ntg"}, "n0,n1,...", + "number of text generation tokens", + [](common_params & params, const std::string & value) { + auto p = string_split(value, ','); + params.n_tg.insert(params.n_tg.end(), p.begin(), p.end()); + } + ).set_examples({LLAMA_EXAMPLE_BENCH})); + add_opt(common_arg( + {"-npl"}, "n0,n1,...", + "number of parallel prompts", + [](common_params & params, const std::string & value) { + auto p = string_split(value, ','); + params.n_pl.insert(params.n_pl.end(), p.begin(), p.end()); + } + ).set_examples({LLAMA_EXAMPLE_BENCH})); + add_opt(common_arg( + {"--embd-normalize"}, "N", + string_format("normalisation for embeddings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize), + [](common_params & params, int value) { + params.embd_normalize = value; + } + ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); + add_opt(common_arg( + {"--embd-output-format"}, "FORMAT", + "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix", + [](common_params & params, const std::string & value) { + params.embd_out = value; + } + ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); + add_opt(common_arg( + {"--embd-separator"}, "STRING", + "separator of embeddings (default \\n) for example \"<#sep#>\"", + [](common_params & params, const std::string & value) { + params.embd_sep = value; + } + ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); + add_opt(common_arg( + {"--host"}, "HOST", + string_format("ip address to listen (default: %s)", params.hostname.c_str()), + [](common_params & params, const std::string & value) { + params.hostname = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_HOST")); + add_opt(common_arg( + {"--port"}, "PORT", + string_format("port to listen (default: %d)", params.port), + [](common_params & params, int value) { + params.port = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_PORT")); + add_opt(common_arg( + {"--path"}, "PATH", + string_format("path to serve static files from (default: %s)", params.public_path.c_str()), + [](common_params & params, const std::string & value) { + params.public_path = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_STATIC_PATH")); + add_opt(common_arg( + {"--no-webui"}, + string_format("Disable the Web UI (default: %s)", params.webui ? "enabled" : "disabled"), + [](common_params & params) { + params.webui = false; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_WEBUI")); + add_opt(common_arg( + {"--embedding", "--embeddings"}, + string_format("restrict to only support embedding use case; use only with dedicated embedding models (default: %s)", params.embedding ? "enabled" : "disabled"), + [](common_params & params) { + params.embedding = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS")); + add_opt(common_arg( + {"--reranking", "--rerank"}, + string_format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"), + [](common_params & params) { + params.reranking = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING")); + add_opt(common_arg( + {"--api-key"}, "KEY", + "API key to use for authentication (default: none)", + [](common_params & params, const std::string & value) { + params.api_keys.push_back(value); + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_API_KEY")); + add_opt(common_arg( + {"--api-key-file"}, "FNAME", + "path to file containing API keys (default: none)", + [](common_params & params, const std::string & value) { + std::ifstream key_file(value); + if (!key_file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); + } + std::string key; + while (std::getline(key_file, key)) { + if (!key.empty()) { + params.api_keys.push_back(key); + } + } + key_file.close(); + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--ssl-key-file"}, "FNAME", + "path to file a PEM-encoded SSL private key", + [](common_params & params, const std::string & value) { + params.ssl_file_key = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_KEY_FILE")); + add_opt(common_arg( + {"--ssl-cert-file"}, "FNAME", + "path to file a PEM-encoded SSL certificate", + [](common_params & params, const std::string & value) { + params.ssl_file_cert = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_CERT_FILE")); + add_opt(common_arg( + {"-to", "--timeout"}, "N", + string_format("server read/write timeout in seconds (default: %d)", params.timeout_read), + [](common_params & params, int value) { + params.timeout_read = value; + params.timeout_write = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TIMEOUT")); + add_opt(common_arg( + {"--threads-http"}, "N", + string_format("number of threads used to process HTTP requests (default: %d)", params.n_threads_http), + [](common_params & params, int value) { + params.n_threads_http = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_THREADS_HTTP")); + add_opt(common_arg( + {"--cache-reuse"}, "N", + string_format("min chunk size to attempt reusing from the cache via KV shifting (default: %d)", params.n_cache_reuse), + [](common_params & params, int value) { + params.n_cache_reuse = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CACHE_REUSE")); + add_opt(common_arg( + {"--metrics"}, + string_format("enable prometheus compatible metrics endpoint (default: %s)", params.endpoint_metrics ? "enabled" : "disabled"), + [](common_params & params) { + params.endpoint_metrics = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_METRICS")); + add_opt(common_arg( + {"--slots"}, + string_format("enable slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"), + [](common_params & params) { + params.endpoint_slots = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_SLOTS")); + add_opt(common_arg( + {"--props"}, + string_format("enable changing global properties via POST /props (default: %s)", params.endpoint_props ? "enabled" : "disabled"), + [](common_params & params) { + params.endpoint_props = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_PROPS")); + add_opt(common_arg( + {"--no-slots"}, + "disables slots monitoring endpoint", + [](common_params & params) { + params.endpoint_slots = false; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_ENDPOINT_SLOTS")); + add_opt(common_arg( + {"--slot-save-path"}, "PATH", + "path to save slot kv cache (default: disabled)", + [](common_params & params, const std::string & value) { + params.slot_save_path = value; + // if doesn't end with DIRECTORY_SEPARATOR, add it + if (!params.slot_save_path.empty() && params.slot_save_path[params.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) { + params.slot_save_path += DIRECTORY_SEPARATOR; + } + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--jinja"}, + "use jinja template for chat (default: disabled)", + [](common_params & params) { + params.use_jinja = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA")); + add_opt(common_arg( + {"--chat-template"}, "JINJA_TEMPLATE", + string_format( + "set custom jinja chat template (default: template taken from model's metadata)\n" + "if suffix/prefix are specified, template will be disabled\n" + "only commonly used templates are accepted (unless --jinja is set before this flag):\n" + "list of built-in templates:\n%s", list_builtin_chat_templates().c_str() + ), + [](common_params & params, const std::string & value) { + params.chat_template = value; + } + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE")); + add_opt(common_arg( + {"--chat-template-file"}, "JINJA_TEMPLATE_FILE", + string_format( + "set custom jinja chat template file (default: template taken from model's metadata)\n" + "if suffix/prefix are specified, template will be disabled\n" + "only commonly used templates are accepted (unless --jinja is set before this flag):\n" + "list of built-in templates:\n%s", list_builtin_chat_templates().c_str() + ), + [](common_params & params, const std::string & value) { + std::ifstream file(value); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); + } + std::copy( + std::istreambuf_iterator(file), + std::istreambuf_iterator(), + std::back_inserter(params.chat_template)); + } + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE")); + add_opt(common_arg( + {"-sps", "--slot-prompt-similarity"}, "SIMILARITY", + string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity), + [](common_params & params, const std::string & value) { + params.slot_prompt_similarity = std::stof(value); + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--lora-init-without-apply"}, + string_format("load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"), + [](common_params & params) { + params.lora_init_without_apply = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--simple-io"}, + "use basic IO for better compatibility in subprocesses and limited consoles", + [](common_params & params) { + params.simple_io = true; + } + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL})); + add_opt(common_arg( + {"--positive-file"}, "FNAME", + string_format("positive prompts file, one prompt per line (default: '%s')", params.cvector_positive_file.c_str()), + [](common_params & params, const std::string & value) { + params.cvector_positive_file = value; + } + ).set_examples({LLAMA_EXAMPLE_CVECTOR_GENERATOR})); + add_opt(common_arg( + {"--negative-file"}, "FNAME", + string_format("negative prompts file, one prompt per line (default: '%s')", params.cvector_negative_file.c_str()), + [](common_params & params, const std::string & value) { + params.cvector_negative_file = value; + } + ).set_examples({LLAMA_EXAMPLE_CVECTOR_GENERATOR})); + add_opt(common_arg( + {"--pca-batch"}, "N", + string_format("batch size used for PCA. Larger batch runs faster, but uses more memory (default: %d)", params.n_pca_batch), + [](common_params & params, int value) { + params.n_pca_batch = value; + } + ).set_examples({LLAMA_EXAMPLE_CVECTOR_GENERATOR})); + add_opt(common_arg( + {"--pca-iter"}, "N", + string_format("number of iterations used for PCA (default: %d)", params.n_pca_iterations), + [](common_params & params, int value) { + params.n_pca_iterations = value; + } + ).set_examples({LLAMA_EXAMPLE_CVECTOR_GENERATOR})); + add_opt(common_arg( + {"--method"}, "{pca, mean}", + "dimensionality reduction method to be used (default: pca)", + [](common_params & params, const std::string & value) { + /**/ if (value == "pca") { params.cvector_dimre_method = DIMRE_METHOD_PCA; } + else if (value == "mean") { params.cvector_dimre_method = DIMRE_METHOD_MEAN; } + else { throw std::invalid_argument("invalid value"); } + } + ).set_examples({LLAMA_EXAMPLE_CVECTOR_GENERATOR})); + add_opt(common_arg( + {"--output-format"}, "{md,jsonl}", + "output format for batched-bench results (default: md)", + [](common_params & params, const std::string & value) { + /**/ if (value == "jsonl") { params.batched_bench_output_jsonl = true; } + else if (value == "md") { params.batched_bench_output_jsonl = false; } + else { std::invalid_argument("invalid value"); } + } + ).set_examples({LLAMA_EXAMPLE_BENCH})); + add_opt(common_arg( + {"--log-disable"}, + "Log disable", + [](common_params &) { + common_log_pause(common_log_main()); + } + )); + add_opt(common_arg( + {"--log-file"}, "FNAME", + "Log to file", + [](common_params &, const std::string & value) { + common_log_set_file(common_log_main(), value.c_str()); + } + )); + add_opt(common_arg( + {"--log-colors"}, + "Enable colored logging", + [](common_params &) { + common_log_set_colors(common_log_main(), true); + } + ).set_env("LLAMA_LOG_COLORS")); + add_opt(common_arg( + {"-v", "--verbose", "--log-verbose"}, + "Set verbosity level to infinity (i.e. log all messages, useful for debugging)", + [](common_params & params) { + params.verbosity = INT_MAX; + common_log_set_verbosity_thold(INT_MAX); + } + )); + add_opt(common_arg( + {"-lv", "--verbosity", "--log-verbosity"}, "N", + "Set the verbosity threshold. Messages with a higher verbosity will be ignored.", + [](common_params & params, int value) { + params.verbosity = value; + common_log_set_verbosity_thold(value); + } + ).set_env("LLAMA_LOG_VERBOSITY")); + add_opt(common_arg( + {"--log-prefix"}, + "Enable prefx in log messages", + [](common_params &) { + common_log_set_prefix(common_log_main(), true); + } + ).set_env("LLAMA_LOG_PREFIX")); + add_opt(common_arg( + {"--log-timestamps"}, + "Enable timestamps in log messages", + [](common_params &) { + common_log_set_timestamps(common_log_main(), true); + } + ).set_env("LLAMA_LOG_TIMESTAMPS")); + + // speculative parameters + add_opt(common_arg( + {"-td", "--threads-draft"}, "N", + "number of threads to use during generation (default: same as --threads)", + [](common_params & params, int value) { + params.speculative.cpuparams.n_threads = value; + if (params.speculative.cpuparams.n_threads <= 0) { + params.speculative.cpuparams.n_threads = std::thread::hardware_concurrency(); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"-tbd", "--threads-batch-draft"}, "N", + "number of threads to use during batch and prompt processing (default: same as --threads-draft)", + [](common_params & params, int value) { + params.speculative.cpuparams_batch.n_threads = value; + if (params.speculative.cpuparams_batch.n_threads <= 0) { + params.speculative.cpuparams_batch.n_threads = std::thread::hardware_concurrency(); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"-Cd", "--cpu-mask-draft"}, "M", + "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)", + [](common_params & params, const std::string & mask) { + params.speculative.cpuparams.mask_valid = true; + if (!parse_cpu_mask(mask, params.speculative.cpuparams.cpumask)) { + throw std::invalid_argument("invalid cpumask"); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"-Crd", "--cpu-range-draft"}, "lo-hi", + "Ranges of CPUs for affinity. Complements --cpu-mask-draft", + [](common_params & params, const std::string & range) { + params.speculative.cpuparams.mask_valid = true; + if (!parse_cpu_range(range, params.speculative.cpuparams.cpumask)) { + throw std::invalid_argument("invalid range"); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--cpu-strict-draft"}, "<0|1>", + "Use strict CPU placement for draft model (default: same as --cpu-strict)", + [](common_params & params, int value) { + params.speculative.cpuparams.strict_cpu = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--prio-draft"}, "N", + string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.speculative.cpuparams.priority), + [](common_params & params, int prio) { + if (prio < 0 || prio > 3) { + throw std::invalid_argument("invalid value"); + } + params.speculative.cpuparams.priority = (enum ggml_sched_priority) prio; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--poll-draft"}, "<0|1>", + "Use polling to wait for draft model work (default: same as --poll])", + [](common_params & params, int value) { + params.speculative.cpuparams.poll = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"-Cbd", "--cpu-mask-batch-draft"}, "M", + "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)", + [](common_params & params, const std::string & mask) { + params.speculative.cpuparams_batch.mask_valid = true; + if (!parse_cpu_mask(mask, params.speculative.cpuparams_batch.cpumask)) { + throw std::invalid_argument("invalid cpumask"); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"-Crbd", "--cpu-range-batch-draft"}, "lo-hi", + "Ranges of CPUs for affinity. Complements --cpu-mask-draft-batch)", + [](common_params & params, const std::string & range) { + params.speculative.cpuparams_batch.mask_valid = true; + if (!parse_cpu_range(range, params.speculative.cpuparams_batch.cpumask)) { + throw std::invalid_argument("invalid cpumask"); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--cpu-strict-batch-draft"}, "<0|1>", + "Use strict CPU placement for draft model (default: --cpu-strict-draft)", + [](common_params & params, int value) { + params.speculative.cpuparams_batch.strict_cpu = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--prio-batch-draft"}, "N", + string_format("set draft process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.speculative.cpuparams_batch.priority), + [](common_params & params, int prio) { + if (prio < 0 || prio > 3) { + throw std::invalid_argument("invalid value"); + } + params.speculative.cpuparams_batch.priority = (enum ggml_sched_priority) prio; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--poll-batch-draft"}, "<0|1>", + "Use polling to wait for draft model work (default: --poll-draft)", + [](common_params & params, int value) { + params.speculative.cpuparams_batch.poll = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + add_opt(common_arg( + {"--draft-max", "--draft", "--draft-n"}, "N", + string_format("number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max), + [](common_params & params, int value) { + params.speculative.n_max = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_DRAFT_MAX")); + add_opt(common_arg( + {"--draft-min", "--draft-n-min"}, "N", + string_format("minimum number of draft tokens to use for speculative decoding (default: %d)", params.speculative.n_min), + [](common_params & params, int value) { + params.speculative.n_min = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_DRAFT_MIN")); + add_opt(common_arg( + {"--draft-p-split"}, "P", + string_format("speculative decoding split probability (default: %.1f)", (double)params.speculative.p_split), + [](common_params & params, const std::string & value) { + params.speculative.p_split = std::stof(value); + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}).set_env("LLAMA_ARG_DRAFT_P_SPLIT")); + add_opt(common_arg( + {"--draft-p-min"}, "P", + string_format("minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min), + [](common_params & params, const std::string & value) { + params.speculative.p_min = std::stof(value); + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_DRAFT_P_MIN")); + add_opt(common_arg( + {"-cd", "--ctx-size-draft"}, "N", + string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx), + [](common_params & params, int value) { + params.speculative.n_ctx = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CTX_SIZE_DRAFT")); + add_opt(common_arg( + {"-devd", "--device-draft"}, "", + "comma-separated list of devices to use for offloading the draft model (none = don't offload)\n" + "use --list-devices to see a list of available devices", + [](common_params & params, const std::string & value) { + params.speculative.devices = parse_device_list(value); + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"-ngld", "--gpu-layers-draft", "--n-gpu-layers-draft"}, "N", + "number of layers to store in VRAM for the draft model", + [](common_params & params, int value) { + params.speculative.n_gpu_layers = value; + if (!llama_supports_gpu_offload()) { + fprintf(stderr, "warning: no usable GPU found, --gpu-layers-draft option will be ignored\n"); + fprintf(stderr, "warning: one possible reason is that llama.cpp was compiled without GPU support\n"); + fprintf(stderr, "warning: consult docs/build.md for compilation instructions\n"); + } + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_N_GPU_LAYERS_DRAFT")); + add_opt(common_arg( + {"-md", "--model-draft"}, "FNAME", + "draft model for speculative decoding (default: unused)", + [](common_params & params, const std::string & value) { + params.speculative.model = value; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT")); + + add_opt(common_arg( + {"-mv", "--model-vocoder"}, "FNAME", + "vocoder model for audio generation (default: unused)", + [](common_params & params, const std::string & value) { + params.vocoder.model = value; + } + ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--tts-use-guide-tokens"}, + "Use guide tokens to improve TTS word recall", + [](common_params & params) { + params.vocoder.use_guide_tokens = true; + } + ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER})); + + // model-specific + add_opt(common_arg( + {"--tts-oute-default"}, + string_format("use default OuteTTS models (note: can download weights from the internet)"), + [](common_params & params) { + params.hf_repo = "OuteAI/OuteTTS-0.2-500M-GGUF"; + params.hf_file = "OuteTTS-0.2-500M-Q8_0.gguf"; + params.vocoder.hf_repo = "ggml-org/WavTokenizer"; + params.vocoder.hf_file = "WavTokenizer-Large-75-F16.gguf"; + } + ).set_examples({LLAMA_EXAMPLE_TTS})); + + return ctx_arg; +} diff --git a/common/arg.h b/common/arg.h new file mode 100644 index 000000000..49ab8667b --- /dev/null +++ b/common/arg.h @@ -0,0 +1,80 @@ +#pragma once + +#include "common.h" + +#include +#include +#include + +// +// CLI argument parsing +// + +struct common_arg { + std::set examples = {LLAMA_EXAMPLE_COMMON}; + std::set excludes = {}; + std::vector args; + const char * value_hint = nullptr; // help text or example for arg value + const char * value_hint_2 = nullptr; // for second arg value + const char * env = nullptr; + std::string help; + bool is_sparam = false; // is current arg a sampling param? + void (*handler_void) (common_params & params) = nullptr; + void (*handler_string) (common_params & params, const std::string &) = nullptr; + void (*handler_str_str)(common_params & params, const std::string &, const std::string &) = nullptr; + void (*handler_int) (common_params & params, int) = nullptr; + + common_arg( + const std::initializer_list & args, + const char * value_hint, + const std::string & help, + void (*handler)(common_params & params, const std::string &) + ) : args(args), value_hint(value_hint), help(help), handler_string(handler) {} + + common_arg( + const std::initializer_list & args, + const char * value_hint, + const std::string & help, + void (*handler)(common_params & params, int) + ) : args(args), value_hint(value_hint), help(help), handler_int(handler) {} + + common_arg( + const std::initializer_list & args, + const std::string & help, + void (*handler)(common_params & params) + ) : args(args), help(help), handler_void(handler) {} + + // support 2 values for arg + common_arg( + const std::initializer_list & args, + const char * value_hint, + const char * value_hint_2, + const std::string & help, + void (*handler)(common_params & params, const std::string &, const std::string &) + ) : args(args), value_hint(value_hint), value_hint_2(value_hint_2), help(help), handler_str_str(handler) {} + + common_arg & set_examples(std::initializer_list examples); + common_arg & set_excludes(std::initializer_list excludes); + common_arg & set_env(const char * env); + common_arg & set_sparam(); + bool in_example(enum llama_example ex); + bool is_exclude(enum llama_example ex); + bool get_value_from_env(std::string & output); + bool has_value_from_env(); + std::string to_string(); +}; + +struct common_params_context { + enum llama_example ex = LLAMA_EXAMPLE_COMMON; + common_params & params; + std::vector options; + void(*print_usage)(int, char **) = nullptr; + common_params_context(common_params & params) : params(params) {} +}; + +// parse input arguments from CLI +// if one argument has invalid value, it will automatically display usage of the specific argument (and not the full usage message) +bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); + +// function to be used by test-arg-parser +common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); diff --git a/common/chat-template.hpp b/common/chat-template.hpp new file mode 100644 index 000000000..75ba5d938 --- /dev/null +++ b/common/chat-template.hpp @@ -0,0 +1,366 @@ +/* + Copyright 2024 Google LLC + + Use of this source code is governed by an MIT-style + license that can be found in the LICENSE file or at + https://opensource.org/licenses/MIT. +*/ +// SPDX-License-Identifier: MIT +#pragma once + +#include "minja.hpp" +#include +#include +#include + +using json = nlohmann::ordered_json; + +namespace minja { + +struct chat_template_caps { + bool supports_tools = false; + bool supports_tool_calls = false; + bool supports_tool_responses = false; + bool supports_system_role = false; + bool supports_parallel_tool_calls = false; + bool supports_tool_call_id = false; + // meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object. + // Most other templates (and OpenAI's API) expect the arguments object to be stringified. + bool requires_object_arguments = false; + // CohereForAI/c4ai-command-r-plus simple variant + bool requires_non_null_content = false; + // MiniMaxAI/MiniMax-Text-01 special + bool requires_typed_content = false; +}; + +class chat_template { + + private: + chat_template_caps caps_; + std::string source_; + std::string bos_token_; + std::string eos_token_; + std::shared_ptr template_root_; + + std::string try_raw_render( + const nlohmann::ordered_json & messages, + const nlohmann::ordered_json & tools, + bool add_generation_prompt, + const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const + { + try { + auto prompt = apply(messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false); + // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str()); + return prompt; + } catch (const std::exception & e) { + // fprintf(stderr, "try_raw_render error: %s\n", e.what()); + return ""; + } + } + + public: + + chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token) + : source_(source), bos_token_(bos_token), eos_token_(eos_token) + { + template_root_ = minja::Parser::parse(source_, { + /* .trim_blocks = */ true, + /* .lstrip_blocks = */ true, + /* .keep_trailing_newline = */ false, + }); + + auto contains = [](const std::string & haystack, const std::string & needle) { + return haystack.find(needle) != std::string::npos; + }; + + const std::string user_needle = ""; + const std::string sys_needle = ""; + const json dummy_str_user_msg = {{"role", "user"}, {"content", user_needle}}; + const json dummy_typed_user_msg = {{"role", "user"}, {"content", json::array({{{"type", "text"}, {"text", user_needle}}})}}; + + caps_.requires_typed_content = + !contains(try_raw_render(json::array({dummy_str_user_msg}), {}, false), user_needle) + && contains(try_raw_render(json::array({dummy_typed_user_msg}), {}, false), user_needle); + + const auto dummy_user_msg = caps_.requires_typed_content + ? dummy_typed_user_msg + : dummy_str_user_msg; + const json needle_system_msg = { + {"role", "system"}, + {"content", caps_.requires_typed_content ? json::array({{{"type", "text"}, {"text", sys_needle}}}) : json(sys_needle)}, + }; + + caps_.supports_system_role = contains(try_raw_render({needle_system_msg, dummy_user_msg,}, {}, false), sys_needle); + + auto out = try_raw_render(json::array({ + dummy_user_msg + }), json::array({ + { + {"name", "some_tool"}, + {"type", "function"}, + {"function", { + {"name", "some_tool"}, + {"description", "Some tool."}, + {"parameters", { + {"type", "object"}, + {"properties", { + {"arg", { + {"type", "string"}, + {"description", "Some argument."}, + }}, + }}, + {"required", json::array({ "arg" })}, + }}, + }}, + }, + }), false); + caps_.supports_tools = contains(out, "some_tool"); + + auto make_tool_calls_msg = [&](const json & tool_calls) { + return json { + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", tool_calls}, + }; + }; + auto make_tool_call = [](const std::string & tool_name, const json & arguments) { + return json { + {"id", "call_1___"}, + {"type", "function"}, + {"function", { + {"arguments", arguments}, + {"name", tool_name}, + }}, + }; + }; + const json dummy_args_obj {{"argument_needle", "print('Hello, World!')"}}; + + // Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want. + out = try_raw_render(json::array({ + dummy_user_msg, + make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj.dump())})), + }), {}, false); + auto tool_call_renders_str_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':"); + out = try_raw_render(json::array({ + dummy_user_msg, + make_tool_calls_msg(json::array({make_tool_call("ipython", dummy_args_obj)})), + }), {}, false); + auto tool_call_renders_obj_arguments = contains(out, "\"argument_needle\":") || contains(out, "'argument_needle':"); + + caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments; + caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments; + auto out_empty = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", ""}}}), {}, false); + auto out_null = try_raw_render(json::array({dummy_user_msg, {{"role", "assistant"}, {"content", nullptr}}}), {}, false); + caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle); + + if (caps_.supports_tool_calls) { + auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump()); + auto tc1 = make_tool_call("test_tool1", dummy_args); + auto tc2 = make_tool_call("test_tool2", dummy_args); + auto out = try_raw_render(json::array({ + dummy_user_msg, + make_tool_calls_msg(json::array({tc1, tc2})), + }), {}, false); + caps_.supports_parallel_tool_calls = contains(out, "test_tool1") && contains(out, "test_tool2"); + + out = try_raw_render(json::array({ + dummy_user_msg, + make_tool_calls_msg(json::array({tc1})), + { + {"role", "tool"}, + {"name", "test_tool1"}, + {"content", "Some response!"}, + {"tool_call_id", "call_911_"}, + } + }), {}, false); + caps_.supports_tool_responses = contains(out, "Some response!"); + caps_.supports_tool_call_id = contains(out, "call_911_"); + } + } + + const std::string & source() const { return source_; } + const std::string & bos_token() const { return bos_token_; } + const std::string & eos_token() const { return eos_token_; } + const chat_template_caps & original_caps() const { return caps_; } + + std::string apply( + const nlohmann::ordered_json & messages, + const nlohmann::ordered_json & tools, + bool add_generation_prompt, + const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(), + bool adjust_inputs = true) const + { + json actual_messages; + + auto needs_adjustments = adjust_inputs && (false + || !caps_.supports_system_role + || !caps_.supports_tools + || !caps_.supports_tool_responses + || !caps_.supports_tool_calls + || caps_.requires_object_arguments + || caps_.requires_typed_content + ); + if (needs_adjustments) { + actual_messages = json::array(); + + auto add_message = [&](const json & msg) { + if (caps_.requires_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) { + actual_messages.push_back({ + {"role", msg.at("role")}, + {"content", {{ + {"type", "text"}, + {"text", msg.at("content")}, + }}}, + }); + } else { + actual_messages.push_back(msg); + } + }; + + std::string pending_system; + auto flush_sys = [&]() { + if (!pending_system.empty()) { + add_message({ + {"role", "user"}, + {"content", pending_system}, + }); + pending_system.clear(); + } + }; + auto needs_tools_in_system = !tools.is_null() && tools.size() > 0 && !caps_.supports_tools; + + for (const auto & message_ : needs_tools_in_system ? add_system(messages, "Available tools: " + tools.dump(2)) : messages) { + auto message = message_; + if (!message.contains("role") || !message.contains("content")) { + throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump()); + } + std::string role = message.at("role"); + + if (message.contains("tool_calls")) { + if (caps_.requires_object_arguments || !caps_.supports_tool_calls) { + for (auto & tool_call : message.at("tool_calls")) { + if (tool_call["type"] == "function") { + auto & function = tool_call.at("function"); + auto & arguments = function.at("arguments"); + if (arguments.is_string()) { + try { + arguments = json::parse(arguments.get()); + } catch (const std::exception & ecvt) { + fprintf(stderr, "Failed to parse arguments: %s\n", ecvt.what()); + } + } + } + } + } + if (!caps_.supports_tool_calls) { + auto content = message.at("content"); + auto tool_calls = json::array(); + for (const auto & tool_call : message.at("tool_calls")) { + if (tool_call.at("type") != "function") { + continue; + } + const auto & function = tool_call.at("function"); + auto tc = json { + {"name", function.at("name")}, + {"arguments", function.at("arguments")}, + }; + if (tool_call.contains("id")) { + tc["id"] = tool_call["id"]; + } + tool_calls.push_back(tc); + } + auto obj = json { + {"tool_calls", tool_calls}, + }; + if (!content.is_null() && content != "") { + obj["content"] = content; + } + message["content"] = obj.dump(2); + message.erase("tool_calls"); + } + } + if (!caps_.supports_tool_responses && role == "tool") { + message["role"] = "user"; + auto obj = json { + {"tool_response", { + {"tool", message.at("name")}, + {"content", message.at("content")}, + }}, + }; + if (message.contains("tool_call_id")) { + obj["tool_response"]["tool_call_id"] = message.at("tool_call_id"); + } + message["content"] = obj.dump(2); + message.erase("name"); + } + + if (!message["content"].is_null() && !caps_.supports_system_role) { + std::string content = message.at("content"); + if (role == "system") { + if (!pending_system.empty()) pending_system += "\n"; + pending_system += content; + continue; + } else { + if (role == "user") { + if (!pending_system.empty()) { + message["content"] = pending_system + (content.empty() ? "" : "\n" + content); + pending_system.clear(); + } + } else { + flush_sys(); + } + } + } + add_message(message); + } + if (!caps_.supports_system_role) { + flush_sys(); + } + } else { + actual_messages = messages; + } + + auto context = minja::Context::make(json({ + {"messages", actual_messages}, + {"add_generation_prompt", add_generation_prompt}, + {"bos_token", bos_token_}, + {"eos_token", eos_token_}, + })); + + if (!tools.is_null()) { + auto tools_val = minja::Value(tools); + context->set("tools", tools_val); + } + if (!extra_context.is_null()) { + for (auto & kv : extra_context.items()) { + minja::Value val(kv.value()); + context->set(kv.key(), val); + } + } + + auto ret = template_root_->render(context); + // fprintf(stderr, "actual_messages: %s\n", actual_messages.dump(2).c_str()); + // fprintf(stderr, "apply: %s\n\n", ret.c_str()); + return ret; + } + + static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) { + json messages_with_system = messages; + + if (messages_with_system.size() > 0 && messages_with_system[0].at("role") == "system") { + std::string existing_system = messages_with_system.at(0).at("content"); + messages_with_system[0] = json { + {"role", "system"}, + {"content", existing_system + "\n" + system_prompt}, + }; + } else { + messages_with_system.insert(messages_with_system.begin(), json { + {"role", "system"}, + {"content", system_prompt}, + }); + } + return messages_with_system; + } +}; + +} // namespace minja diff --git a/common/chat.cpp b/common/chat.cpp new file mode 100644 index 000000000..d9a654892 --- /dev/null +++ b/common/chat.cpp @@ -0,0 +1,848 @@ +#include "chat.hpp" +#include "chat-template.hpp" +#include "json-schema-to-grammar.h" +#include "log.h" +#include "minja.hpp" + +std::string common_chat_format_name(common_chat_format format) { + switch (format) { + case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only"; + case COMMON_CHAT_FORMAT_GENERIC: return "Generic"; + case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo"; + case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x"; + case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools"; + case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1"; + case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2"; + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2"; + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1"; + case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro"; + default: + throw std::runtime_error("Unknown chat format"); + } +} + +const common_grammar_options grammar_options { + /* .dotall = */ false, + /* .compact_spaces = */ false, + // /* .compact_spaces = */ true, +}; + +static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) { + // // https://json.nlohmann.me/features/parsing/sax_interface/ + struct json_error_locator : public nlohmann::json_sax { + std::size_t position; + bool found_error; + + json_error_locator() : position(0), found_error(false) {} + + bool parse_error(std::size_t position, const std::string &, const json::exception &) override { + this->position = position - 1; + this->found_error = true; + return false; + } + bool null() override { return true; } + bool boolean(bool) override { return true; } + bool number_integer(number_integer_t) override { return true; } + bool number_unsigned(number_unsigned_t) override { return true; } + bool number_float(number_float_t, const string_t &) override { return true; } + bool string(string_t &) override { return true; } + bool binary(binary_t &) override { return true; } + bool start_object(std::size_t) override { return true; } + bool key(string_t &) override { return true; } + bool end_object() override { return true; } + bool start_array(std::size_t) override { return true; } + bool end_array() override { return true; } + }; + json_error_locator err_loc; + json::sax_parse(it, end, &err_loc); + + std::string::const_iterator temptative_end; + if (err_loc.found_error) { + temptative_end = it + err_loc.position; + } else { + temptative_end = end; + } + std::string json_sub {it, temptative_end}; + try { + out = json::parse(json_sub); + it = temptative_end; + return true; + } catch (const std::exception &) { + return false; + } +} + + +/** + * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. + * Aggregates the prefix, suffix and in-between text into the content. + */ +static common_chat_msg parse_json_tool_calls( + const std::string& input, + const std::optional & trigger_opt, + const std::regex & function_regex, + const std::regex & close_regex) { + std::smatch match; + + common_chat_msg result; + result.role = "assistant"; + + + auto end = input.end(); + auto it = input.begin(); + + if (trigger_opt) { + if (!std::regex_search(it, end, match, *trigger_opt)) { + result.content = input; + return result; + } + result.content = match.prefix().str(); + it = match.suffix().first; + } + + while (it != end) { + std::sregex_iterator rend; + std::sregex_iterator rit(it, end, function_regex); + if (rit == rend) { + fprintf(stderr, "No more tool calls found\n"); + result.content += std::string(it, end); + break; + } + auto name = rit->str(1); + result.content += std::string(it, rit->prefix().second); + it = rit->suffix().first; + + json arguments; + if (!parse_json(it, end, arguments)) { + throw std::runtime_error("Failed to parse json tool call arguments"); + } + if (!std::regex_search(it, end, match, close_regex)) { + throw std::runtime_error("Malformed input, missing closing pattern"); + } + it = match.suffix().first; + result.tool_calls.push_back({name, arguments.is_string() ? arguments.get() : arguments.dump(), /* id= */ ""}); + } + return result; +} + +static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) { + auto content_end = input.find(prefix); + size_t tc_start = std::string::npos; + + common_chat_msg result; + result.role = "assistant"; + const auto process_tool_calls = [&](const json & tool_calls) { + for (const auto & tool_call : tool_calls) { + const auto & arguments = tool_call["arguments"]; + result.tool_calls.push_back({ + tool_call["name"], + arguments.is_string() ? arguments.get() : arguments.dump(), + tool_call.contains("id") ? tool_call["id"] : "", + }); + } + }; + if (content_end == std::string::npos) { + result.content = input; + } else { + tc_start = content_end + prefix.size() - rstrip_prefix; + result.content = input.substr(0, content_end); + auto tool_calls = json::parse(input.substr(tc_start)); + process_tool_calls(tool_calls); + } + return result; +} + +static void foreach_function(const json & tools, const std::function & fn) { + for (const auto & tool : tools) { + if (!tool.contains("type") || tool["type"] != "function" || !tool.contains("function")) { + LOG_INF("Skipping tool without function: %s", tool.dump(2).c_str()); + continue; + } + fn(tool); + } +} + +static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + common_chat_params data; + + auto tool_call_schemas = json::array(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool["function"]; + auto tool_schema = json { + {"type", "object"}, + {"properties", { + {"name", { + {"type", "string"}, + {"const", function["name"]}, + }}, + {"arguments", function["parameters"]}, + }}, + {"required", json::array({"name", "arguments"})}, + }; + if (function.contains("description")) { + tool_schema["description"] = function["description"]; + } + if (inputs.parallel_tool_calls) { + tool_schema["properties"]["id"] = { + {"type", "string"}, + {"minLength", 4}, + }; + tool_schema["required"].push_back("id"); + } + tool_call_schemas.emplace_back(tool_schema); + }); + const auto tool_call = + inputs.parallel_tool_calls + ? json { + {"type", "object"}, + {"properties", { + {"tool_calls", { + {"type", "array"}, + {"items", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json { + {"anyOf", tool_call_schemas}, + }}, + {"minItems", 1}, + }}, + }}, + {"required", json::array({"tool_calls"})}, + } + : json { + {"type", "object"}, + {"properties", { + {"tool_call", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json { + {"anyOf", tool_call_schemas}, + }}, + }}, + {"required", json::array({"tool_call"})}, + }; + const auto schema = + inputs.tool_choice != "required" + ? json { + {"anyOf", json::array({ + tool_call, + { + {"type", "object"}, + {"properties", { + {"response", inputs.json_schema.is_null() + ? json {{"type", "string"}} + : inputs.json_schema + }, + }}, + {"required", json::array({"response"})}, + }, + })} + } + : tool_call; + + data.grammar_lazy = false; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + builder.add_schema("root", schema); + }, grammar_options); + + auto tweaked_messages = common_chat_template::add_system( + inputs.messages, + "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request"); + + data.prompt = tmpl.apply(tweaked_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.format = COMMON_CHAT_FORMAT_GENERIC; + return data; +} +static common_chat_msg common_chat_parse_generic(const std::string & input) { + json data = json::parse(input); + common_chat_msg result; + result.role = "assistant"; + if (data.contains("tool_calls")) { + for (const auto & tool_call : data["tool_calls"]) { + result.tool_calls.push_back({ + tool_call["name"], + tool_call["arguments"].dump(), + tool_call.contains("id") ? tool_call["id"] : "", + }); + } + } else if (data.contains("tool_call")) { + result.tool_calls.push_back({ + data["tool_call"]["name"], + data["tool_call"]["arguments"].dump(), + /* id= */ "", + }); + } else if (data.contains("response")) { + const auto & response = data["response"]; + result.content = response.is_string() ? response.get() : response.dump(2); + } + return result; +} + +static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + common_chat_params data; + data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + auto schemas = json::array(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool["function"]; + schemas.push_back({ + {"type", "object"}, + {"properties", { + // Important note: the model is probably trained to take a JSON stringified arguments value. + // It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object. + {"name", { + {"type", "string"}, + {"const", function["name"]}, + }}, + {"arguments", function["parameters"]}, + {"id", { + {"type", "string"}, + // Nemo's template expects a 9-character alphanumeric ID. + {"pattern", "^[a-zA-Z0-9]{9}$"}, + }}, + }}, + {"required", json::array({"name", "arguments", "id"})}, + }); + }); + auto schema = json { + {"type", "array"}, + {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, + {"minItems", 1}, + }; + if (!inputs.parallel_tool_calls) { + schema["maxItems"] = 1; + } + builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema)); + }, grammar_options); + data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true}); + data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO; + return data; +} +static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input) { + return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]"); +} + +static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector & expected_properties) { + if (!parameters.is_object() || !parameters.contains("type") || parameters["type"] != "object" || !parameters.contains("properties") || !parameters.contains("required")) { + throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties"); + } + const auto & parameters_properties = parameters.at("properties"); + const auto & parameters_required = parameters.at("required"); + for (const auto & prop : expected_properties) { + if (!parameters_properties.contains(prop)) { + throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); + } + if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) { + throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); + } + } + if (parameters_properties.size() != expected_properties.size()) { + throw std::runtime_error("Parameters of tool " + name + " must only have these properties:" + string_join(expected_properties, ", ")); + } +} + +static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) { + auto builtin_tools = json::array(); + common_chat_params data; + data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + + auto handle_builtin_tool = [&](const std::string & name, const json & parameters) { + if (name == "wolfram_alpha") { + // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py + expect_tool_parameters(name, parameters, {"query"}); + } else if (name == "web_search" || name == "brave_search") { + // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py + expect_tool_parameters(name, parameters, {"query"}); + } else if (name == "python" || name == "code_interpreter") { + // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py + expect_tool_parameters(name, parameters, {"code"}); + } else { + return false; + } + + std::vector kvs; + for (const auto & [key, value] : parameters.at("properties").items()) { + kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); + } + + tool_rules.push_back( + builder.add_rule( + name + "-call", + "\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\"")); + builtin_tools.push_back(name); + + return true; + }; + + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + builder.resolve_refs(parameters); + + // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime + if (allow_python_tag_builtin_tools) { + handle_builtin_tool(name, parameters); + } + tool_rules.push_back( + builder.add_rule( + name + "-call", + "\"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) " + "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + + builder.add_schema(name + "-args", parameters) + + " \"}\"")); + data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true}); + }); + data.grammar_triggers.push_back({"{\"name\":", /* .at_start = */ true}); + data.grammar_triggers.push_back({"{\"type\": \"function\"", /* .at_start = */ true}); + if (!builtin_tools.empty()) { + data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); + } + builder.add_rule("root", string_join(tool_rules, " | ")); + }, grammar_options); + data.additional_stops.push_back("<|eom_id|>"); + data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, { + {"tools_in_user_message", false}, + {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools}, + }); + data.format = allow_python_tag_builtin_tools && !builtin_tools.empty() + ? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS + : COMMON_CHAT_FORMAT_LLAMA_3_X; + return data; +} +static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) { + // TODO: tighten & simplify the parser, don't accept leading text context. + static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": "); + static std::regex close_regex("\\}"); + static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)"); + + if (with_builtin_tools) { + std::smatch match; + if (std::regex_match(input, match, builtin_call_regex)) { + auto name = match[1].str(); + auto raw_args = match[2].str(); + + // TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing. + auto it_eq = raw_args.find('='); + auto arg_name = raw_args.substr(0, it_eq); + auto arg_value_str = raw_args.substr(it_eq + 1); + auto arg_value = json::parse(arg_value_str); + + return { + /* .role = */ "assistant", + /* .content = */ match.prefix().str(), + /* .tool_calls = */ { + { + /* .name = */ match[1], + /* .arguments = */ (json { + {arg_name, arg_value}, + }).dump(), + /* .id = */ "", + }, + }, + }; + } + } + return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); +} + +static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + common_chat_params data; + data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + auto args_rule = builder.add_schema(name + "-args", parameters); + tool_rules.push_back(builder.add_rule(name + "-call", + "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\"")); + }); + data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false}); + builder.add_rule("root", "\"<|tool▁calls▁begin|>\" (" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " space"); + }, grammar_options); + data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1; + return data; +} +static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input) { + static std::regex trigger_regex("<|tool▁calls▁begin|>"); + static std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n"); + static std::regex close_regex("```<|tool▁call▁end|>"); + return parse_json_tool_calls(input, trigger_regex, function_regex, close_regex); +} + +static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + fprintf(stderr, "%s\n", __func__); + common_chat_params data; + data.prompt = tmpl.apply(inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, { + {"datetime", "Jan 29 2025 13:00:00 GMT"}, + {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))}, + }, /* adjust_inputs= */ false); + if (!inputs.tools.is_null() && !inputs.tools.empty()) { + data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + auto schemas = json::array(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool["function"]; + schemas.push_back({ + {"type", "object"}, + {"properties", { + {"name", { + {"type", "string"}, + {"const", function["name"]}, + }}, + {"arguments", function["parameters"]}, + }}, + {"required", json::array({"name", "arguments", "id"})}, + }); + }); + auto schema = json { + {"type", "array"}, + {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, + {"minItems", 1}, + }; + if (!inputs.parallel_tool_calls) { + schema["maxItems"] = 1; + } + builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema)); + }, grammar_options); + data.grammar_triggers.push_back({" functools[", /* .at_start = */ false}); + data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2; + } else { + data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + } + return data; +} +static common_chat_msg common_chat_parse_firefunction_v2(const std::string & input) { + return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1); +} + +static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... + // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar + common_chat_params data; + data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2; + if (!inputs.tools.is_null() && !inputs.tools.empty()) { + data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector first_tool_rules; + std::vector subsequent_tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + auto args_rule = builder.add_schema(name + "-args", parameters); + first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule)); + subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule)); + data.grammar_triggers.push_back({name, /* .at_start = */ true}); + data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false}); + }); + auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space"; + if (inputs.parallel_tool_calls) { + auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space"; + builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*"); + } else { + builder.add_rule("root", first_rule); + } + + }, grammar_options); + } + return data; +} + +static bool consume(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) { + auto expected_it = expected.begin(); + auto tmp_it = it; + while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) { + ++tmp_it; + ++expected_it; + } + if (expected_it == expected.end()) { + it = tmp_it; + return true; + } + return false; +} + +static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) { + static std::regex function_regex(R"((?:>>>)?(\w+)\n)"); + static std::regex close_regex(R"($|(?=>>>))"); + + std::string content; + auto it = input.begin(); + const auto end = input.end(); + + if (consume(it, end, "all\n")) { + std::smatch match; + if (std::regex_search(it, end, match, function_regex)) { + auto fun_it = match.prefix().second; + content = std::string(it, fun_it); + it = fun_it; + } else { + common_chat_msg res; + res.role = "assistant"; + res.content = std::string(it, end); + return res; + } + } + // TODO: tighten & simplify. + auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex); + res.content = content; + return res; +} + +static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt + common_chat_params data; + json tools = inputs.tools.is_null() ? inputs.tools : json::array(); + std::string python_code_argument_name; + auto has_raw_python = false; + + data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool["function"]; + const auto & parameters = function["parameters"]; + std::string name = function["name"]; + if (name == "python" || name == "ipython") { + if (!parameters.contains("type")) { + throw std::runtime_error("Missing type in python tool"); + } + has_raw_python = true; + auto type = parameters.at("type"); + if (type == "object") { + auto properties = parameters.at("properties"); + for (auto it = properties.begin(); it != properties.end(); ++it) { + if (it.value().at("type") == "string") { + if (!python_code_argument_name.empty()) { + throw std::runtime_error("Multiple string arguments found in python tool"); + } + python_code_argument_name = it.key(); + } + } + if (python_code_argument_name.empty()) { + throw std::runtime_error("No string argument found in python tool"); + } + } else if (type != "string") { + throw std::runtime_error("Invalid type in python tool: " + type.dump()); + } + } + tool_rules.push_back(builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\" space")); + }); + if (has_raw_python) { + tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*")); + data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); + } + auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space"; + builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); + data.grammar_triggers.push_back({"([\s\S\n]*)$)"); + std::smatch match; + if (std::regex_search(input, match, python_tag_regex)) { + auto code = match[1].str(); + return { + /* .role = */ "assistant", + /* .content = */ match.prefix().str(), + /* .tool_calls = */ { + { + /* .name = */ "python", + /* .arguments = */ (json {{"code", code}}).dump(), + /* .id = */ "", + }, + } + }; + } + static std::regex function_regex(R"()"); + static std::regex close_regex(R"()"); + // TODO: tighten & simplify. + return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); +} + +static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + common_chat_params data; + // (content)?({"name": "foo", "arguments": {"a": 1}})* + data.grammar_lazy = inputs.tool_choice != "required"; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + builder.resolve_refs(parameters); + tool_rules.push_back(builder.add_schema(name + "-call", { + {"type", "object"}, + {"properties", json { + {"name", json {{"const", name}}}, + {"arguments", parameters}, + }}, + {"required", json::array({"name", "arguments"})}, + })); + }); + auto tool_call = "\"\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"\" space"; + builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); + data.grammar_triggers.push_back({"", /* .at_start = */ false}); + // Not really a trigger but need to print this special token to get a successful parse. + data.grammar_triggers.push_back({"", /* .at_start = */ false}); + }, grammar_options); + + data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO; + return data; +} +static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) { + try { + std::regex start_pattern(R"([\n\s]*)"); + std::regex middle_pattern(R"([\n\s]*[\n\s]*)"); + std::regex end_pattern(R"([\n\s]*[\n\s]*$)"); + + auto end = input.end(); + std::sregex_iterator rend; + std::sregex_iterator rit(input.begin(), end, start_pattern); + if (rit == rend) { + return { + /* .role = */ "assistant", + /* .content = */ input, + /* .tool_calls = */ {}, + }; + } + + common_chat_msg result; + result.role = "assistant"; + result.content = rit->prefix(); + + auto it = rit->suffix().first; + while (it != end) { + json call; + if (!parse_json(it, end, call)) { + throw std::runtime_error("Failed to parse json tool call"); + } + const auto & arguments = call["arguments"]; + result.tool_calls.push_back({ + call["name"], + arguments.dump(), + // arguments.is_string() ? arguments.get() : arguments.dump(), + /* id= */ "", + }); + rit = {it, end, middle_pattern}; + if (rit != rend) { + it = rit->suffix().first; + } else { + rit = {it, end, end_pattern}; + if (rit == rend) { + throw std::runtime_error("Malformed input, missing "); + } + break; + } + } + return result; + } catch (const std::exception & e) { + return { + /* .role = */ "assistant", + /* .content = */ input, + /* .tool_calls = */ {}, + }; + } +} + +static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + common_chat_params data; + data.prompt = tmpl.apply(inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); + data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + data.grammar_lazy = false; + if (!inputs.json_schema.is_null()) { + if (!inputs.grammar.empty()) { + throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); + } + data.grammar = json_schema_to_grammar(inputs.json_schema); + } else { + data.grammar = inputs.grammar.empty(); + } + return data; +} + +common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { + auto has_tools = !inputs.tools.is_null() && inputs.tool_choice != "none"; + LOG_DBG("[%s] has_tools=%s\n", __func__, has_tools ? "true" : "false"); + + if (has_tools && !inputs.grammar.empty()) { + throw std::runtime_error("Cannot specify grammar with tools"); + } + + const auto & src = tmpl.source(); + if (src.find(">>>all") != std::string::npos) { + // Functionary prepends "all\n" to plain content outputs, so we use the parser no matter when + return common_chat_params_init_functionary_v3_2(tmpl, inputs); + } + if (src.find(" functools[") != std::string::npos) { + // Firefunction v2 requires datetime and functions in the context, even w/o tools. + return common_chat_params_init_firefunction_v2(tmpl, inputs); + } + + if (!has_tools) { + return common_chat_params_init_without_tools(tmpl, inputs); + } + + if (src.find("") != std::string::npos) { + return common_chat_params_init_hermes_2_pro(tmpl, inputs); + } + if (src.find("<|start_header_id|>") != std::string::npos + && src.find("ipython<|end_header_id|>") != std::string::npos) { + auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos; + return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools); + } + if (src.find("<|tool▁calls▁begin|>") != std::string::npos) { + return common_chat_params_init_deepseek_r1(tmpl, inputs); + } + if (src.find("[TOOL_CALLS]") != std::string::npos) { + return common_chat_params_init_mistral_nemo(tmpl, inputs); + } + return common_chat_params_init_generic(tmpl, inputs); +} + +static common_chat_msg common_chat_parse_content_only(const std::string & input) { + return { + /* .role = */ "assistant", + /* .content = */ input, + /* .tool_calls = */ {}, + }; +} + +common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) { + switch (format) { + case COMMON_CHAT_FORMAT_CONTENT_ONLY: + return common_chat_parse_content_only(input); + case COMMON_CHAT_FORMAT_GENERIC: + return common_chat_parse_generic(input); + case COMMON_CHAT_FORMAT_MISTRAL_NEMO: + return common_chat_parse_mistral_nemo(input); + case COMMON_CHAT_FORMAT_LLAMA_3_X: + return common_chat_parse_llama_3_1(input); + case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: + return common_chat_parse_llama_3_1(input, /* with_builtin_tools= */ true); + case COMMON_CHAT_FORMAT_DEEPSEEK_R1: + return common_chat_parse_deepseek_r1(input); + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: + return common_chat_parse_functionary_v3_2(input); + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: + return common_chat_parse_functionary_v3_1_llama_3_1(input); + case COMMON_CHAT_FORMAT_HERMES_2_PRO: + return common_chat_parse_hermes_2_pro(input); + case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: + return common_chat_parse_firefunction_v2(input); + default: + throw std::runtime_error("Unsupported format: " + common_chat_format_name(format)); + } +} diff --git a/common/chat.hpp b/common/chat.hpp new file mode 100644 index 000000000..ca165aa13 --- /dev/null +++ b/common/chat.hpp @@ -0,0 +1,50 @@ +// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers. + +#pragma once + +#include "common.h" +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +struct common_chat_inputs { + json messages; + json tools; + json tool_choice; + json json_schema; + bool parallel_tool_calls; + bool stream; + std::string grammar; + bool add_generation_prompt = true; +}; + +enum common_chat_format { + COMMON_CHAT_FORMAT_CONTENT_ONLY, + COMMON_CHAT_FORMAT_GENERIC, + COMMON_CHAT_FORMAT_MISTRAL_NEMO, + COMMON_CHAT_FORMAT_LLAMA_3_X, + COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, + COMMON_CHAT_FORMAT_DEEPSEEK_R1, + COMMON_CHAT_FORMAT_FIREFUNCTION_V2, + COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, + COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, + COMMON_CHAT_FORMAT_HERMES_2_PRO, + + COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats +}; + +struct common_chat_params { + common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + json prompt; + std::string grammar; + bool grammar_lazy = false; + std::vector grammar_triggers; + std::vector additional_stops; +}; + +struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params); +std::string common_chat_format_name(common_chat_format format); +common_chat_msg common_chat_parse( const std::string & input, common_chat_format format); diff --git a/common/common.cpp b/common/common.cpp index 9228eafca..6c81d18f9 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2,30 +2,38 @@ #define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING #endif +#include "ggml.h" +#include "gguf.h" + #include "common.h" +#include "log.h" // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT #include "json.hpp" #include "json-schema-to-grammar.h" #include "llama.h" +#include "chat.hpp" +#include "chat-template.hpp" #include #include +#include #include #include #include #include #include +#include #include #include #include #include #include #include +#include #include #include #include -#include #if defined(__APPLE__) && defined(__MACH__) #include @@ -49,7 +57,6 @@ #if defined(LLAMA_USE_CURL) #include #include -#include #include #endif @@ -57,23 +64,33 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif -#if (defined(GGML_USE_CUDA) || defined(GGML_USE_SYCL)) -#define GGML_USE_CUDA_SYCL -#endif - -#if (defined(GGML_USE_CUDA) || defined(GGML_USE_SYCL)) || defined(GGML_USE_VULKAN) -#define GGML_USE_CUDA_SYCL_VULKAN -#endif - #if defined(LLAMA_USE_CURL) #ifdef __linux__ #include #elif defined(_WIN32) -#define PATH_MAX MAX_PATH +# if !defined(PATH_MAX) +# define PATH_MAX MAX_PATH +# endif #else #include #endif #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 + +// +// CURL utils +// + +using curl_ptr = std::unique_ptr; + +// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one +struct curl_slist_ptr { + struct curl_slist * ptr = nullptr; + ~curl_slist_ptr() { + if (ptr) { + curl_slist_free_all(ptr); + } + } +}; #endif // LLAMA_USE_CURL using json = nlohmann::ordered_json; @@ -235,7 +252,7 @@ bool set_process_priority(enum ggml_sched_priority prio) { } if (!SetPriorityClass(GetCurrentProcess(), p)) { - fprintf(stderr, "warn: failed to set process priority class %d : (%d)\n", prio, (int) GetLastError()); + LOG_WRN("failed to set process priority class %d : (%d)\n", prio, (int) GetLastError()); return false; } @@ -260,7 +277,7 @@ bool set_process_priority(enum ggml_sched_priority prio) { } if (!setpriority(PRIO_PROCESS, 0, p)) { - fprintf(stderr, "warn: failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno); + LOG_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno); return false; } return true; @@ -272,53 +289,6 @@ bool set_process_priority(enum ggml_sched_priority prio) { // CLI argument parsing // -#ifdef __GNUC__ -#ifdef __MINGW32__ -#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) -#else -#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) -#endif -#else -#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) -#endif - -LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2) -static std::string format(const char * fmt, ...) { - va_list ap; - va_list ap2; - va_start(ap, fmt); - va_copy(ap2, ap); - int size = vsnprintf(NULL, 0, fmt, ap); - GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT - std::vector buf(size + 1); - int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); - GGML_ASSERT(size2 == size); - va_end(ap2); - va_end(ap); - return std::string(buf.data(), size); -} - -static void gpt_params_handle_model_default(gpt_params & params) { - if (!params.hf_repo.empty()) { - // short-hand to avoid specifying --hf-file -> default it to --model - if (params.hf_file.empty()) { - if (params.model.empty()) { - throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n"); - } - params.hf_file = params.model; - } else if (params.model.empty()) { - params.model = fs_get_cache_file(string_split(params.hf_file, '/').back()); - } - } else if (!params.model_url.empty()) { - if (params.model.empty()) { - auto f = string_split(params.model_url, '#').front(); - f = string_split(f, '?').front(); - params.model = fs_get_cache_file(string_split(f, '/').back()); - } - } else if (params.model.empty()) { - params.model = DEFAULT_MODEL_PATH; - } -} void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model) { int32_t n_set = 0; @@ -340,158 +310,14 @@ void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model) if (n_set && n_set < cpuparams.n_threads) { // Not enough set bits, may experience performance issues. - fprintf(stderr, "warn: Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads); + LOG_WRN("Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads); } } -bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params, std::vector & options) { - std::string arg; - const std::string arg_prefix = "--"; - gpt_sampler_params & sparams = params.sparams; - - std::unordered_map arg_to_options; - for (auto & opt : options) { - for (const auto & arg : opt.args) { - arg_to_options[arg] = &opt; - } - } - - // handle environment variables - for (auto & opt : options) { - std::string value; - if (opt.get_value_from_env(value)) { - try { - if (opt.handler_void && (value == "1" || value == "true")) { - opt.handler_void(params); - } - if (opt.handler_int) { - opt.handler_int(params, std::stoi(value)); - } - if (opt.handler_string) { - opt.handler_string(params, value); - continue; - } - } catch (std::exception & e) { - throw std::invalid_argument(format( - "error while handling environment variable \"%s\": %s\n\n", opt.env, e.what())); - } - } - } - - // handle command line arguments - auto check_arg = [&](int i) { - if (i+1 >= argc) { - throw std::invalid_argument("expected value for argument"); - } - }; - - for (int i = 1; i < argc; i++) { - const std::string arg_prefix = "--"; - - std::string arg = argv[i]; - if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { - std::replace(arg.begin(), arg.end(), '_', '-'); - } - if (arg_to_options.find(arg) == arg_to_options.end()) { - throw std::invalid_argument(format("error: invalid argument: %s", arg.c_str())); - } - auto opt = *arg_to_options[arg]; - if (opt.has_value_from_env()) { - fprintf(stderr, "warn: %s environment variable is set, but will be overwritten by command line argument %s\n", opt.env, arg.c_str()); - } - try { - if (opt.handler_void) { - opt.handler_void(params); - continue; - } - - // arg with single value - check_arg(i); - std::string val = argv[++i]; - if (opt.handler_int) { - opt.handler_int(params, std::stoi(val)); - continue; - } - if (opt.handler_string) { - opt.handler_string(params, val); - continue; - } - - // arg with 2 values - check_arg(i); - std::string val2 = argv[++i]; - if (opt.handler_str_str) { - opt.handler_str_str(params, val, val2); - continue; - } - } catch (std::exception & e) { - throw std::invalid_argument(format( - "error while handling argument \"%s\": %s\n\n" - "usage:\n%s\n\nto show complete usage, run with -h", - arg.c_str(), e.what(), arg_to_options[arg]->to_string().c_str())); - } - } - - postprocess_cpu_params(params.cpuparams, nullptr); - postprocess_cpu_params(params.cpuparams_batch, ¶ms.cpuparams); - postprocess_cpu_params(params.draft_cpuparams, ¶ms.cpuparams); - postprocess_cpu_params(params.draft_cpuparams_batch, ¶ms.cpuparams_batch); - - if (params.prompt_cache_all && (params.interactive || params.interactive_first)) { - throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); - } - - gpt_params_handle_model_default(params); - - if (params.escape) { - string_process_escapes(params.prompt); - string_process_escapes(params.input_prefix); - string_process_escapes(params.input_suffix); - for (auto & antiprompt : params.antiprompt) { - string_process_escapes(antiprompt); - } - } - - if (!params.kv_overrides.empty()) { - params.kv_overrides.emplace_back(); - params.kv_overrides.back().key[0] = 0; - } - - if (sparams.seed == LLAMA_DEFAULT_SEED) { - sparams.seed = time(NULL); - } - - return true; -} - -bool gpt_params_parse(int argc, char ** argv, gpt_params & params, std::vector & options) { - const auto params_org = params; // the example can modify the default params - - try { - if (!gpt_params_parse_ex(argc, argv, params, options)) { - params = params_org; - return false; - } - if (params.usage) { - gpt_params_print_usage(params, options); - if (params.print_usage) { - params.print_usage(argc, argv); - } - exit(0); - } - } catch (const std::invalid_argument & ex) { - fprintf(stderr, "%s\n", ex.what()); - params = params_org; - return false; - } - - return true; -} - bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THREADS]) { size_t dash_loc = range.find('-'); if (dash_loc == std::string::npos) { - fprintf(stderr, "Format of CPU range is invalid! Expected []-[].\n"); + LOG_ERR("Format of CPU range is invalid! Expected []-[].\n"); return false; } @@ -503,7 +329,7 @@ bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THRE } else { start_i = std::stoull(range.substr(0, dash_loc)); if (start_i >= GGML_MAX_N_THREADS) { - fprintf(stderr, "Start index out of bounds!\n"); + LOG_ERR("Start index out of bounds!\n"); return false; } } @@ -513,7 +339,7 @@ bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THRE } else { end_i = std::stoull(range.substr(dash_loc + 1)); if (end_i >= GGML_MAX_N_THREADS) { - fprintf(stderr, "End index out of bounds!\n"); + LOG_ERR("End index out of bounds!\n"); return false; } } @@ -548,7 +374,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD } else if (c >= 'A' && c <= 'F') { id -= 'A' - 10; } else { - fprintf(stderr, "Invalid hex character '%c' at position %d\n", c, int32_t(i)); + LOG_ERR("Invalid hex character '%c' at position %d\n", c, int32_t(i)); return false; } @@ -561,1703 +387,23 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD return true; } -static std::vector break_str_into_lines(std::string input, size_t max_char_per_line) { - std::vector result; - std::istringstream iss(input); - std::string line; - auto add_line = [&](const std::string& l) { - if (l.length() <= max_char_per_line) { - result.push_back(l); - } else { - std::istringstream line_stream(l); - std::string word, current_line; - while (line_stream >> word) { - if (current_line.length() + !current_line.empty() + word.length() > max_char_per_line) { - if (!current_line.empty()) result.push_back(current_line); - current_line = word; - } else { - current_line += (!current_line.empty() ? " " : "") + word; - } - } - if (!current_line.empty()) result.push_back(current_line); +void common_init() { + llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) { + if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) { + common_log_add(common_log_main(), level, "%s", text); } - }; - while (std::getline(iss, line)) { - add_line(line); - } - return result; -} + }, NULL); -std::string llama_arg::to_string() { - // params for printing to console - const static int n_leading_spaces = 40; - const static int n_char_per_line_help = 70; // TODO: detect this based on current console - std::string leading_spaces(n_leading_spaces, ' '); - - std::ostringstream ss; - for (const auto arg : args) { - if (arg == args.front()) { - if (args.size() == 1) { - ss << arg; - } else { - // first arg is usually abbreviation, we need padding to make it more beautiful - auto tmp = std::string(arg) + ", "; - ss << format("%-7s", tmp.c_str()); - } - } else { - ss << arg << (arg != args.back() ? ", " : ""); - } - } - if (value_hint) ss << " " << value_hint; - if (value_hint_2) ss << " " << value_hint_2; - if (ss.tellp() > n_leading_spaces - 3) { - // current line is too long, add new line - ss << "\n" << leading_spaces; - } else { - // padding between arg and help, same line - ss << std::string(leading_spaces.size() - ss.tellp(), ' '); - } - const auto help_lines = break_str_into_lines(help, n_char_per_line_help); - for (const auto & line : help_lines) { - ss << (&line == &help_lines.front() ? "" : leading_spaces) << line << "\n"; - } - return ss.str(); -} - -void gpt_params_print_usage(gpt_params & params, std::vector & options) { - auto print_options = [](std::vector & options) { - for (llama_arg * opt : options) { - printf("%s", opt->to_string().c_str()); - } - }; - - std::vector common_options; - std::vector specific_options; - for (auto & opt : options) { - // in case multiple LLAMA_EXAMPLE_* are set, we prioritize the LLAMA_EXAMPLE_* matching current example - if (opt.in_example(params.curr_ex)) { - specific_options.push_back(&opt); - } else { - common_options.push_back(&opt); - } - } - printf("----- common options -----\n\n"); - print_options(common_options); - // TODO: maybe convert enum llama_example to string - printf("\n\n----- example-specific options -----\n\n"); - print_options(specific_options); -} - -std::vector gpt_params_parser_init(gpt_params & params, llama_example ex) { - return gpt_params_parser_init(params, ex, nullptr); -} - -std::vector gpt_params_parser_init(gpt_params & params, llama_example ex, std::function print_usage) { - std::vector options; - params.print_usage = print_usage; - params.curr_ex = ex; - - std::string sampler_type_chars; - std::string sampler_type_names; - for (const auto & sampler : params.sparams.samplers) { - sampler_type_chars += gpt_sampler_type_to_chr(sampler); - sampler_type_names += gpt_sampler_type_to_str(sampler) + ";"; - } - sampler_type_names.pop_back(); - - - /** - * filter options by example - * rules: - * - all examples inherit options from LLAMA_EXAMPLE_COMMON - * - if LLAMA_EXAMPLE_* is set (other than COMMON), we only show the option in the corresponding example - * - if both {LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_*,} are set, we will prioritize the LLAMA_EXAMPLE_* matching current example - */ - std::unordered_set seen_args; - auto add_opt = [&](llama_arg arg) { - if (arg.in_example(ex) || arg.in_example(LLAMA_EXAMPLE_COMMON)) { - // make sure there is no argument duplications - for (const auto & a : arg.args) { - if (seen_args.find(a) == seen_args.end()) { - seen_args.insert(a); - } else { - throw std::runtime_error(format("found duplicated argument in source code: %s", a)); - } - } - options.push_back(std::move(arg)); - } - }; - - - add_opt(llama_arg( - {"-h", "--help", "--usage"}, - "print usage and exit", - [](gpt_params & params) { - params.usage = true; - } - )); - add_opt(llama_arg( - {"--version"}, - "show version and build info", - [](gpt_params &) { - fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT); - fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET); - exit(0); - } - )); - add_opt(llama_arg( - {"-v", "--verbose"}, - "print verbose information", - [](gpt_params & params) { - params.verbosity = 1; - } - )); - add_opt(llama_arg( - {"--verbosity"}, "N", - format("set specific verbosity level (default: %d)", params.verbosity), - [](gpt_params & params, int value) { - params.verbosity = value; - } - )); - add_opt(llama_arg( - {"--verbose-prompt"}, - format("print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false"), - [](gpt_params & params) { - params.verbose_prompt = true; - } - ).set_examples({LLAMA_EXAMPLE_MAIN})); - add_opt(llama_arg( - {"--no-display-prompt"}, - format("don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false"), - [](gpt_params & params) { - params.display_prompt = false; - } - ).set_examples({LLAMA_EXAMPLE_MAIN})); - add_opt(llama_arg( - {"-co", "--color"}, - format("colorise output to distinguish prompt and user input from generations (default: %s)", params.use_color ? "true" : "false"), - [](gpt_params & params) { - params.use_color = true; - } - ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL})); - add_opt(llama_arg( - {"-s", "--seed"}, "SEED", - format("RNG seed (default: %d, use random seed for < 0)", params.sparams.seed), - [](gpt_params & params, const std::string & value) { - params.sparams.seed = std::stoul(value); - } - )); - add_opt(llama_arg( - {"-t", "--threads"}, "N", - format("number of threads to use during generation (default: %d)", params.cpuparams.n_threads), - [](gpt_params & params, int value) { - params.cpuparams.n_threads = value; - if (params.cpuparams.n_threads <= 0) { - params.cpuparams.n_threads = std::thread::hardware_concurrency(); - } - } - ).set_env("LLAMA_ARG_THREADS")); - add_opt(llama_arg( - {"-tb", "--threads-batch"}, "N", - "number of threads to use during batch and prompt processing (default: same as --threads)", - [](gpt_params & params, int value) { - params.cpuparams_batch.n_threads = value; - if (params.cpuparams_batch.n_threads <= 0) { - params.cpuparams_batch.n_threads = std::thread::hardware_concurrency(); - } - } - )); - add_opt(llama_arg( - {"-td", "--threads-draft"}, "N", - "number of threads to use during generation (default: same as --threads)", - [](gpt_params & params, int value) { - params.draft_cpuparams.n_threads = value; - if (params.draft_cpuparams.n_threads <= 0) { - params.draft_cpuparams.n_threads = std::thread::hardware_concurrency(); - } - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(llama_arg( - {"-tbd", "--threads-batch-draft"}, "N", - "number of threads to use during batch and prompt processing (default: same as --threads-draft)", - [](gpt_params & params, int value) { - params.draft_cpuparams_batch.n_threads = value; - if (params.draft_cpuparams_batch.n_threads <= 0) { - params.draft_cpuparams_batch.n_threads = std::thread::hardware_concurrency(); - } - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(llama_arg( - {"-C", "--cpu-mask"}, "M", - "CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: \"\")", - [](gpt_params & params, const std::string & value) { - std::string mask = value; - params.cpuparams.mask_valid = true; - if (!parse_cpu_mask(mask, params.cpuparams.cpumask)) { - throw std::invalid_argument("invalid cpumask"); - } - } - )); - add_opt(llama_arg( - {"-Cr", "--cpu-range"}, "lo-hi", - "range of CPUs for affinity. Complements --cpu-mask", - [](gpt_params & params, const std::string & value) { - std::string range = value; - params.cpuparams.mask_valid = true; - if (!parse_cpu_range(range, params.cpuparams.cpumask)) { - throw std::invalid_argument("invalid range"); - } - } - )); - add_opt(llama_arg( - {"--cpu-strict"}, "<0|1>", - format("use strict CPU placement (default: %u)\n", (unsigned) params.cpuparams.strict_cpu), - [](gpt_params & params, const std::string & value) { - params.cpuparams.strict_cpu = std::stoul(value); - } - )); - add_opt(llama_arg( - {"--poll"}, "<0...100>", - format("use polling level to wait for work (0 - no polling, default: %u)\n", (unsigned) params.cpuparams.poll), - [](gpt_params & params, const std::string & value) { - params.cpuparams.poll = std::stoul(value); - } - )); - add_opt(llama_arg( - {"-Cb", "--cpu-mask-batch"}, "M", - "CPU affinity mask: arbitrarily long hex. Complements cpu-range-batch (default: same as --cpu-mask)", - [](gpt_params & params, const std::string & value) { - std::string mask = value; - params.cpuparams_batch.mask_valid = true; - if (!parse_cpu_mask(mask, params.cpuparams_batch.cpumask)) { - throw std::invalid_argument("invalid cpumask"); - } - } - )); - add_opt(llama_arg( - {"-Crb", "--cpu-range-batch"}, "lo-hi", - "ranges of CPUs for affinity. Complements --cpu-mask-batch", - [](gpt_params & params, const std::string & value) { - std::string range = value; - params.cpuparams_batch.mask_valid = true; - if (!parse_cpu_range(range, params.cpuparams_batch.cpumask)) { - throw std::invalid_argument("invalid range"); - } - } - )); - add_opt(llama_arg( - {"--cpu-strict-batch"}, "<0|1>", - "use strict CPU placement (default: same as --cpu-strict)", - [](gpt_params & params, int value) { - params.cpuparams_batch.strict_cpu = value; - } - )); - add_opt(llama_arg( - {"--poll-batch"}, "<0|1>", - "use polling to wait for work (default: same as --poll)", - [](gpt_params & params, int value) { - params.cpuparams_batch.poll = value; - } - )); - add_opt(llama_arg( - {"-Cd", "--cpu-mask-draft"}, "M", - "Draft model CPU affinity mask. Complements cpu-range-draft (default: same as --cpu-mask)", - [](gpt_params & params, const std::string & value) { - std::string mask = value; - params.draft_cpuparams.mask_valid = true; - if (!parse_cpu_mask(mask, params.draft_cpuparams.cpumask)) { - throw std::invalid_argument("invalid cpumask"); - } - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(llama_arg( - {"-Crd", "--cpu-range-draft"}, "lo-hi", - "Ranges of CPUs for affinity. Complements --cpu-mask-draft", - [](gpt_params & params, const std::string & value) { - std::string range = value; - params.draft_cpuparams.mask_valid = true; - if (!parse_cpu_range(range, params.draft_cpuparams.cpumask)) { - throw std::invalid_argument("invalid range"); - } - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(llama_arg( - {"--cpu-strict-draft"}, "<0|1>", - "Use strict CPU placement for draft model (default: same as --cpu-strict)", - [](gpt_params & params, int value) { - params.draft_cpuparams.strict_cpu = value; - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(llama_arg( - {"--poll-draft"}, "<0|1>", - "Use polling to wait for draft model work (default: same as --poll])", - [](gpt_params & params, int value) { - params.draft_cpuparams.poll = value; - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(llama_arg( - {"-Crbd", "--cpu-range-batch-draft"}, "lo-hi", - "Ranges of CPUs for affinity. Complements --cpu-mask-draft-batch)", - [](gpt_params & params, const std::string & value) { - std::string range = value; - params.draft_cpuparams_batch.mask_valid = true; - if (!parse_cpu_range(range, params.draft_cpuparams_batch.cpumask)) { - throw std::invalid_argument("invalid cpumask"); - } - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(llama_arg( - {"--cpu-strict-batch-draft"}, "<0|1>", - "Use strict CPU placement for draft model (default: --cpu-strict-draft)", - [](gpt_params & params, int value) { - params.draft_cpuparams_batch.strict_cpu = value; - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(llama_arg( - {"--poll-batch-draft"}, "<0|1>", - "Use polling to wait for draft model work (default: --poll-draft)", - [](gpt_params & params, int value) { - params.draft_cpuparams_batch.poll = value; - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(llama_arg( - {"--draft"}, "N", - format("number of tokens to draft for speculative decoding (default: %d)", params.n_draft), - [](gpt_params & params, int value) { - params.n_draft = value; - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(llama_arg( - {"-ps", "--p-split"}, "N", - format("speculative decoding split probability (default: %.1f)", (double)params.p_split), - [](gpt_params & params, const std::string & value) { - params.p_split = std::stof(value); - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(llama_arg( - {"-lcs", "--lookup-cache-static"}, "FNAME", - "path to static lookup cache to use for lookup decoding (not updated by generation)", - [](gpt_params & params, const std::string & value) { - params.lookup_cache_static = value; - } - )); - add_opt(llama_arg( - {"-lcd", "--lookup-cache-dynamic"}, "FNAME", - "path to dynamic lookup cache to use for lookup decoding (updated by generation)", - [](gpt_params & params, const std::string & value) { - params.lookup_cache_dynamic = value; - } - )); - add_opt(llama_arg( - {"-c", "--ctx-size"}, "N", - format("size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx), - [](gpt_params & params, int value) { - params.n_ctx = value; - } - ).set_env("LLAMA_ARG_CTX_SIZE")); - add_opt(llama_arg( - {"-n", "--predict", "--n-predict"}, "N", - format("number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict), - [](gpt_params & params, int value) { - params.n_predict = value; - } - ).set_env("LLAMA_ARG_N_PREDICT")); - add_opt(llama_arg( - {"-b", "--batch-size"}, "N", - format("logical maximum batch size (default: %d)", params.n_batch), - [](gpt_params & params, int value) { - params.n_batch = value; - } - ).set_env("LLAMA_ARG_BATCH")); - add_opt(llama_arg( - {"-ub", "--ubatch-size"}, "N", - format("physical maximum batch size (default: %d)", params.n_ubatch), - [](gpt_params & params, int value) { - params.n_ubatch = value; - } - ).set_env("LLAMA_ARG_UBATCH")); - add_opt(llama_arg( - {"--keep"}, "N", - format("number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep), - [](gpt_params & params, int value) { - params.n_keep = value; - } - )); - add_opt(llama_arg( - {"--chunks"}, "N", - format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks), - [](gpt_params & params, int value) { - params.n_chunks = value; - } - )); - add_opt(llama_arg( - {"-fa", "--flash-attn"}, - format("enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled"), - [](gpt_params & params) { - params.flash_attn = true; - } - ).set_env("LLAMA_ARG_FLASH_ATTN")); - add_opt(llama_arg( - {"-p", "--prompt"}, "PROMPT", - ex == LLAMA_EXAMPLE_MAIN - ? "prompt to start generation with\nif -cnv is set, this will be used as system prompt" - : "prompt to start generation with", - [](gpt_params & params, const std::string & value) { - params.prompt = value; - } - )); - add_opt(llama_arg( - {"-f", "--file"}, "FNAME", - "a file containing the prompt (default: none)", - [](gpt_params & params, const std::string & value) { - std::ifstream file(value); - if (!file) { - throw std::runtime_error(format("error: failed to open file '%s'\n", value.c_str())); - } - // store the external file name in params - params.prompt_file = value; - std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.prompt)); - if (!params.prompt.empty() && params.prompt.back() == '\n') { - params.prompt.pop_back(); - } - } - )); - add_opt(llama_arg( - {"--in-file"}, "FNAME", - "an input file (repeat to specify multiple files)", - [](gpt_params & params, const std::string & value) { - std::ifstream file(value); - if (!file) { - throw std::runtime_error(format("error: failed to open file '%s'\n", value.c_str())); - } - params.in_files.push_back(value); - } - )); - add_opt(llama_arg( - {"-bf", "--binary-file"}, "FNAME", - "binary file containing the prompt (default: none)", - [](gpt_params & params, const std::string & value) { - std::ifstream file(value, std::ios::binary); - if (!file) { - throw std::runtime_error(format("error: failed to open file '%s'\n", value.c_str())); - } - // store the external file name in params - params.prompt_file = value; - std::ostringstream ss; - ss << file.rdbuf(); - params.prompt = ss.str(); - fprintf(stderr, "Read %zu bytes from binary file %s\n", params.prompt.size(), value.c_str()); - } - )); - add_opt(llama_arg( - {"-e", "--escape"}, - format("process escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\) (default: %s)", params.escape ? "true" : "false"), - [](gpt_params & params) { - params.escape = true; - } - )); - add_opt(llama_arg( - {"--no-escape"}, - "do not process escape sequences", - [](gpt_params & params) { - params.escape = false; - } - )); - add_opt(llama_arg( - {"-ptc", "--print-token-count"}, "N", - format("print token count every N tokens (default: %d)", params.n_print), - [](gpt_params & params, int value) { - params.n_print = value; - } - ).set_examples({LLAMA_EXAMPLE_MAIN})); - add_opt(llama_arg( - {"--prompt-cache"}, "FNAME", - "file to cache prompt state for faster startup (default: none)", - [](gpt_params & params, const std::string & value) { - params.path_prompt_cache = value; - } - ).set_examples({LLAMA_EXAMPLE_MAIN})); - add_opt(llama_arg( - {"--prompt-cache-all"}, - "if specified, saves user input and generations to cache as well\n", - [](gpt_params & params) { - params.prompt_cache_all = true; - } - ).set_examples({LLAMA_EXAMPLE_MAIN})); - add_opt(llama_arg( - {"--prompt-cache-ro"}, - "if specified, uses the prompt cache but does not update it", - [](gpt_params & params) { - params.prompt_cache_ro = true; - } - ).set_examples({LLAMA_EXAMPLE_MAIN})); - add_opt(llama_arg( - {"-r", "--reverse-prompt"}, "PROMPT", - "halt generation at PROMPT, return control in interactive mode\n", - [](gpt_params & params, const std::string & value) { - params.antiprompt.emplace_back(value); - } - ).set_examples({LLAMA_EXAMPLE_MAIN})); - add_opt(llama_arg( - {"-sp", "--special"}, - format("special tokens output enabled (default: %s)", params.special ? "true" : "false"), - [](gpt_params & params) { - params.special = true; - } - ).set_examples({LLAMA_EXAMPLE_MAIN})); - add_opt(llama_arg( - {"-cnv", "--conversation"}, - format( - "run in conversation mode:\n" - "- does not print special tokens and suffix/prefix\n" - "- interactive mode is also enabled\n" - "(default: %s)", - params.conversation ? "true" : "false" - ), - [](gpt_params & params) { - params.conversation = true; - } - ).set_examples({LLAMA_EXAMPLE_MAIN})); - add_opt(llama_arg( - {"-i", "--interactive"}, - format("run in interactive mode (default: %s)", params.interactive ? "true" : "false"), - [](gpt_params & params) { - params.interactive = true; - } - ).set_examples({LLAMA_EXAMPLE_INFILL})); - add_opt(llama_arg( - {"-if", "--interactive-first"}, - format("run in interactive mode and wait for input right away (default: %s)", params.interactive_first ? "true" : "false"), - [](gpt_params & params) { - params.interactive_first = true; - } - ).set_examples({LLAMA_EXAMPLE_INFILL})); - add_opt(llama_arg( - {"-mli", "--multiline-input"}, - "allows you to write or paste multiple lines without ending each in '\\'", - [](gpt_params & params) { - params.multiline_input = true; - } - ).set_examples({LLAMA_EXAMPLE_INFILL})); - add_opt(llama_arg( - {"--in-prefix-bos"}, - "prefix BOS to user inputs, preceding the `--in-prefix` string", - [](gpt_params & params) { - params.input_prefix_bos = true; - params.enable_chat_template = false; - } - ).set_examples({LLAMA_EXAMPLE_INFILL})); - add_opt(llama_arg( - {"--in-prefix"}, "STRING", - "string to prefix user inputs with (default: empty)", - [](gpt_params & params, const std::string & value) { - params.input_prefix = value; - params.enable_chat_template = false; - } - ).set_examples({LLAMA_EXAMPLE_INFILL})); - add_opt(llama_arg( - {"--in-suffix"}, "STRING", - "string to suffix after user inputs with (default: empty)", - [](gpt_params & params, const std::string & value) { - params.input_suffix = value; - params.enable_chat_template = false; - } - ).set_examples({LLAMA_EXAMPLE_INFILL})); - add_opt(llama_arg( - {"--no-warmup"}, - "skip warming up the model with an empty run", - [](gpt_params & params) { - params.warmup = false; - } - ).set_examples({LLAMA_EXAMPLE_MAIN})); - add_opt(llama_arg( - {"--spm-infill"}, - format( - "use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: %s)", - params.spm_infill ? "enabled" : "disabled" - ), - [](gpt_params & params) { - params.spm_infill = true; - } - ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_INFILL})); - add_opt(llama_arg( - {"--samplers"}, "SAMPLERS", - format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()), - [](gpt_params & params, const std::string & value) { - const auto sampler_names = string_split(value, ';'); - params.sparams.samplers = gpt_sampler_types_from_names(sampler_names, true); - } - )); - add_opt(llama_arg( - {"--sampling-seq"}, "SEQUENCE", - format("simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str()), - [](gpt_params & params, const std::string & value) { - params.sparams.samplers = gpt_sampler_types_from_chars(value); - } - )); - add_opt(llama_arg( - {"--ignore-eos"}, - "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)", - [](gpt_params & params) { - params.sparams.ignore_eos = true; - } - )); - add_opt(llama_arg( - {"--penalize-nl"}, - format("penalize newline tokens (default: %s)", params.sparams.penalize_nl ? "true" : "false"), - [](gpt_params & params) { - params.sparams.penalize_nl = true; - } - )); - add_opt(llama_arg( - {"--temp"}, "N", - format("temperature (default: %.1f)", (double)params.sparams.temp), - [](gpt_params & params, const std::string & value) { - params.sparams.temp = std::stof(value); - params.sparams.temp = std::max(params.sparams.temp, 0.0f); - } - )); - add_opt(llama_arg( - {"--top-k"}, "N", - format("top-k sampling (default: %d, 0 = disabled)", params.sparams.top_k), - [](gpt_params & params, int value) { - params.sparams.top_k = value; - } - )); - add_opt(llama_arg( - {"--top-p"}, "N", - format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sparams.top_p), - [](gpt_params & params, const std::string & value) { - params.sparams.top_p = std::stof(value); - } - )); - add_opt(llama_arg( - {"--min-p"}, "N", - format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sparams.min_p), - [](gpt_params & params, const std::string & value) { - params.sparams.min_p = std::stof(value); - } - )); - add_opt(llama_arg( - {"--tfs"}, "N", - format("tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)params.sparams.tfs_z), - [](gpt_params & params, const std::string & value) { - params.sparams.tfs_z = std::stof(value); - } - )); - add_opt(llama_arg( - {"--typical"}, "N", - format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p), - [](gpt_params & params, const std::string & value) { - params.sparams.typ_p = std::stof(value); - } - )); - add_opt(llama_arg( - {"--repeat-last-n"}, "N", - format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sparams.penalty_last_n), - [](gpt_params & params, int value) { - params.sparams.penalty_last_n = value; - params.sparams.n_prev = std::max(params.sparams.n_prev, params.sparams.penalty_last_n); - } - )); - add_opt(llama_arg( - {"--repeat-penalty"}, "N", - format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sparams.penalty_repeat), - [](gpt_params & params, const std::string & value) { - params.sparams.penalty_repeat = std::stof(value); - } - )); - add_opt(llama_arg( - {"--presence-penalty"}, "N", - format("repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)params.sparams.penalty_present), - [](gpt_params & params, const std::string & value) { - params.sparams.penalty_present = std::stof(value); - } - )); - add_opt(llama_arg( - {"--frequency-penalty"}, "N", - format("repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)params.sparams.penalty_freq), - [](gpt_params & params, const std::string & value) { - params.sparams.penalty_freq = std::stof(value); - } - )); - add_opt(llama_arg( - {"--dynatemp-range"}, "N", - format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sparams.dynatemp_range), - [](gpt_params & params, const std::string & value) { - params.sparams.dynatemp_range = std::stof(value); - } - )); - add_opt(llama_arg( - {"--dynatemp-exp"}, "N", - format("dynamic temperature exponent (default: %.1f)", (double)params.sparams.dynatemp_exponent), - [](gpt_params & params, const std::string & value) { - params.sparams.dynatemp_exponent = std::stof(value); - } - )); - add_opt(llama_arg( - {"--mirostat"}, "N", - format("use Mirostat sampling.\nTop K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n" - "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sparams.mirostat), - [](gpt_params & params, int value) { - params.sparams.mirostat = value; - } - )); - add_opt(llama_arg( - {"--mirostat-lr"}, "N", - format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sparams.mirostat_eta), - [](gpt_params & params, const std::string & value) { - params.sparams.mirostat_eta = std::stof(value); - } - )); - add_opt(llama_arg( - {"--mirostat-ent"}, "N", - format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sparams.mirostat_tau), - [](gpt_params & params, const std::string & value) { - params.sparams.mirostat_tau = std::stof(value); - } - )); - add_opt(llama_arg( - {"-l", "--logit-bias"}, "TOKEN_ID(+/-)BIAS", - "modifies the likelihood of token appearing in the completion,\n" - "i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n" - "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'", - [](gpt_params & params, const std::string & value) { - std::stringstream ss(value); - llama_token key; - char sign; - std::string value_str; - try { - if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) { - const float bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f); - params.sparams.logit_bias.push_back({key, bias}); - } else { - throw std::invalid_argument("invalid input format"); - } - } catch (const std::exception&) { - throw std::invalid_argument("invalid input format"); - } - } - )); - add_opt(llama_arg( - {"--grammar"}, "GRAMMAR", - format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sparams.grammar.c_str()), - [](gpt_params & params, const std::string & value) { - params.sparams.grammar = value; - } - )); - add_opt(llama_arg( - {"--grammar-file"}, "FNAME", - "file to read grammar from", - [](gpt_params & params, const std::string & value) { - std::ifstream file(value); - if (!file) { - throw std::runtime_error(format("error: failed to open file '%s'\n", value.c_str())); - } - std::copy( - std::istreambuf_iterator(file), - std::istreambuf_iterator(), - std::back_inserter(params.sparams.grammar) - ); - } - )); - add_opt(llama_arg( - {"-j", "--json-schema"}, "SCHEMA", - "JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead", - [](gpt_params & params, const std::string & value) { - params.sparams.grammar = json_schema_to_grammar(json::parse(value)); - } - )); - add_opt(llama_arg( - {"--pooling"}, "{none,mean,cls,last}", - "pooling type for embeddings, use model default if unspecified", - [](gpt_params & params, const std::string & value) { - /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } - else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } - else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } - else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; } - else { throw std::invalid_argument("invalid value"); } - } - ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); - add_opt(llama_arg( - {"--attention"}, "{causal,non,causal}", - "attention type for embeddings, use model default if unspecified", - [](gpt_params & params, const std::string & value) { - /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; } - else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL; } - else { throw std::invalid_argument("invalid value"); } - } - ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); - add_opt(llama_arg( - {"--rope-scaling"}, "{none,linear,yarn}", - "RoPE frequency scaling method, defaults to linear unless specified by the model", - [](gpt_params & params, const std::string & value) { - /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; } - else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; } - else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; } - else { throw std::invalid_argument("invalid value"); } - } - )); - add_opt(llama_arg( - {"--rope-scale"}, "N", - "RoPE context scaling factor, expands context by a factor of N", - [](gpt_params & params, const std::string & value) { - params.rope_freq_scale = 1.0f / std::stof(value); - } - )); - add_opt(llama_arg( - {"--rope-freq-base"}, "N", - "RoPE base frequency, used by NTK-aware scaling (default: loaded from model)", - [](gpt_params & params, const std::string & value) { - params.rope_freq_base = std::stof(value); - } - )); - add_opt(llama_arg( - {"--rope-freq-scale"}, "N", - "RoPE frequency scaling factor, expands context by a factor of 1/N", - [](gpt_params & params, const std::string & value) { - params.rope_freq_scale = std::stof(value); - } - )); - add_opt(llama_arg( - {"--yarn-orig-ctx"}, "N", - format("YaRN: original context size of model (default: %d = model training context size)", params.yarn_orig_ctx), - [](gpt_params & params, int value) { - params.yarn_orig_ctx = value; - } - )); - add_opt(llama_arg( - {"--yarn-ext-factor"}, "N", - format("YaRN: extrapolation mix factor (default: %.1f, 0.0 = full interpolation)", (double)params.yarn_ext_factor), - [](gpt_params & params, const std::string & value) { - params.yarn_ext_factor = std::stof(value); - } - )); - add_opt(llama_arg( - {"--yarn-attn-factor"}, "N", - format("YaRN: scale sqrt(t) or attention magnitude (default: %.1f)", (double)params.yarn_attn_factor), - [](gpt_params & params, const std::string & value) { - params.yarn_attn_factor = std::stof(value); - } - )); - add_opt(llama_arg( - {"--yarn-beta-slow"}, "N", - format("YaRN: high correction dim or alpha (default: %.1f)", (double)params.yarn_beta_slow), - [](gpt_params & params, const std::string & value) { - params.yarn_beta_slow = std::stof(value); - } - )); - add_opt(llama_arg( - {"--yarn-beta-fast"}, "N", - format("YaRN: low correction dim or beta (default: %.1f)", (double)params.yarn_beta_fast), - [](gpt_params & params, const std::string & value) { - params.yarn_beta_fast = std::stof(value); - } - )); - add_opt(llama_arg( - {"-gan", "--grp-attn-n"}, "N", - format("group-attention factor (default: %d)", params.grp_attn_n), - [](gpt_params & params, int value) { - params.grp_attn_n = value; - } - )); - add_opt(llama_arg( - {"-gaw", "--grp-attn-w"}, "N", - format("group-attention width (default: %.1f)", (double)params.grp_attn_w), - [](gpt_params & params, int value) { - params.grp_attn_w = value; - } - )); - add_opt(llama_arg( - {"-dkvc", "--dump-kv-cache"}, - "verbose print of the KV cache", - [](gpt_params & params) { - params.dump_kv_cache = true; - } - )); - add_opt(llama_arg( - {"-nkvo", "--no-kv-offload"}, - "disable KV offload", - [](gpt_params & params) { - params.no_kv_offload = true; - } - )); - add_opt(llama_arg( - {"-ctk", "--cache-type-k"}, "TYPE", - format("KV cache data type for K (default: %s)", params.cache_type_k.c_str()), - [](gpt_params & params, const std::string & value) { - // TODO: get the type right here - params.cache_type_k = value; - } - )); - add_opt(llama_arg( - {"-ctv", "--cache-type-v"}, "TYPE", - format("KV cache data type for V (default: %s)", params.cache_type_v.c_str()), - [](gpt_params & params, const std::string & value) { - // TODO: get the type right here - params.cache_type_v = value; - } - )); - add_opt(llama_arg( - {"--all-logits"}, - format("return logits for all tokens in the batch (default: %s)", params.logits_all ? "true" : "false"), - [](gpt_params & params) { - params.logits_all = true; - } - ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); - add_opt(llama_arg( - {"--hellaswag"}, - "compute HellaSwag score over random tasks from datafile supplied with -f", - [](gpt_params & params) { - params.hellaswag = true; - } - ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); - add_opt(llama_arg( - {"--hellaswag-tasks"}, "N", - format("number of tasks to use when computing the HellaSwag score (default: %zu)", params.hellaswag_tasks), - [](gpt_params & params, int value) { - params.hellaswag_tasks = value; - } - ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); - add_opt(llama_arg( - {"--winogrande"}, - "compute Winogrande score over random tasks from datafile supplied with -f", - [](gpt_params & params) { - params.winogrande = true; - } - ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); - add_opt(llama_arg( - {"--winogrande-tasks"}, "N", - format("number of tasks to use when computing the Winogrande score (default: %zu)", params.winogrande_tasks), - [](gpt_params & params, int value) { - params.winogrande_tasks = value; - } - ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); - add_opt(llama_arg( - {"--multiple-choice"}, - "compute multiple choice score over random tasks from datafile supplied with -f", - [](gpt_params & params) { - params.multiple_choice = true; - } - ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); - add_opt(llama_arg( - {"--multiple-choice-tasks"}, "N", - format("number of tasks to use when computing the multiple choice score (default: %zu)", params.multiple_choice_tasks), - [](gpt_params & params, int value) { - params.multiple_choice_tasks = value; - } - ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); - add_opt(llama_arg( - {"--kl-divergence"}, - "computes KL-divergence to logits provided via --kl-divergence-base", - [](gpt_params & params) { - params.kl_divergence = true; - } - ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); - add_opt(llama_arg( - {"--ppl-stride"}, "N", - format("stride for perplexity calculation (default: %d)", params.ppl_stride), - [](gpt_params & params, int value) { - params.ppl_stride = value; - } - ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); - add_opt(llama_arg( - {"--ppl-output-type"}, "<0|1>", - format("output type for perplexity calculation (default: %d)", params.ppl_output_type), - [](gpt_params & params, int value) { - params.ppl_output_type = value; - } - ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); - add_opt(llama_arg( - {"-dt", "--defrag-thold"}, "N", - format("KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold), - [](gpt_params & params, const std::string & value) { - params.defrag_thold = std::stof(value); - } - ).set_env("LLAMA_ARG_DEFRAG_THOLD")); - add_opt(llama_arg( - {"-np", "--parallel"}, "N", - format("number of parallel sequences to decode (default: %d)", params.n_parallel), - [](gpt_params & params, int value) { - params.n_parallel = value; - } - )); - add_opt(llama_arg( - {"-ns", "--sequences"}, "N", - format("number of sequences to decode (default: %d)", params.n_sequences), - [](gpt_params & params, int value) { - params.n_sequences = value; - } - )); - add_opt(llama_arg( - {"-cb", "--cont-batching"}, - format("enable continuous batching (a.k.a dynamic batching) (default: %s)", params.cont_batching ? "enabled" : "disabled"), - [](gpt_params & params) { - params.cont_batching = true; - } - ).set_env("LLAMA_ARG_CONT_BATCHING")); - add_opt(llama_arg( - {"-nocb", "--no-cont-batching"}, - "disable continuous batching", - [](gpt_params & params) { - params.cont_batching = false; - } - ).set_env("LLAMA_ARG_NO_CONT_BATCHING")); - add_opt(llama_arg( - {"--mmproj"}, "FILE", - "path to a multimodal projector file for LLaVA. see examples/llava/README.md", - [](gpt_params & params, const std::string & value) { - params.mmproj = value; - } - ).set_examples({LLAMA_EXAMPLE_LLAVA})); - add_opt(llama_arg( - {"--image"}, "FILE", - "path to an image file. use with multimodal models. Specify multiple times for batching", - [](gpt_params & params, const std::string & value) { - params.image.emplace_back(value); - } - ).set_examples({LLAMA_EXAMPLE_LLAVA})); -#ifdef GGML_USE_RPC - add_opt(llama_arg( - {"--rpc"}, "SERVERS", - "comma separated list of RPC servers", - [](gpt_params & params, const std::string & value) { - params.rpc_servers = value; - } - )); +#ifdef NDEBUG + const char * build_type = ""; +#else + const char * build_type = " (debug)"; #endif - add_opt(llama_arg( - {"--mlock"}, - "force system to keep model in RAM rather than swapping or compressing", - [](gpt_params & params) { - params.use_mlock = true; - } - )); - add_opt(llama_arg( - {"--no-mmap"}, - "do not memory-map model (slower load but may reduce pageouts if not using mlock)", - [](gpt_params & params) { - params.use_mmap = false; - } - )); - add_opt(llama_arg( - {"--numa"}, "TYPE", - "attempt optimizations that help on some NUMA systems\n" - "- distribute: spread execution evenly over all nodes\n" - "- isolate: only spawn threads on CPUs on the node that execution started on\n" - "- numactl: use the CPU map provided by numactl\n" - "if run without this previously, it is recommended to drop the system page cache before using this\n" - "see https://github.com/ggerganov/llama.cpp/issues/1437", - [](gpt_params & params, const std::string & value) { - /**/ if (value == "distribute" || value == "") { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; } - else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; } - else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; } - else { throw std::invalid_argument("invalid value"); } - } - )); - add_opt(llama_arg( - {"-ngl", "--gpu-layers"}, "N", - "number of layers to store in VRAM", - [](gpt_params & params, int value) { - params.n_gpu_layers = value; - if (!llama_supports_gpu_offload()) { - fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers option will be ignored\n"); - fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); - } - } - ).set_env("LLAMA_ARG_N_GPU_LAYERS")); - add_opt(llama_arg( - {"-ngld", "--gpu-layers-draft"}, "N", - "number of layers to store in VRAM for the draft model", - [](gpt_params & params, int value) { - params.n_gpu_layers_draft = value; - if (!llama_supports_gpu_offload()) { - fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers-draft option will be ignored\n"); - fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); - } - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(llama_arg( - {"-sm", "--split-mode"}, "{none,layer,row}", - "how to split the model across multiple GPUs, one of:\n" - "- none: use one GPU only\n" - "- layer (default): split layers and KV across GPUs\n" - "- row: split rows across GPUs", - [](gpt_params & params, const std::string & value) { - std::string arg_next = value; - if (arg_next == "none") { - params.split_mode = LLAMA_SPLIT_MODE_NONE; - } else if (arg_next == "layer") { - params.split_mode = LLAMA_SPLIT_MODE_LAYER; - } - else if (arg_next == "row") { -#ifdef GGML_USE_SYCL - fprintf(stderr, "warning: The split mode value:[row] is not supported by llama.cpp with SYCL. It's developing.\nExit!\n"); - exit(1); -#endif // GGML_USE_SYCL - params.split_mode = LLAMA_SPLIT_MODE_ROW; - } - else { - throw std::invalid_argument("invalid value"); - } -#ifndef GGML_USE_CUDA_SYCL_VULKAN - fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL/Vulkan. Setting the split mode has no effect.\n"); -#endif // GGML_USE_CUDA_SYCL_VULKAN - } - )); - add_opt(llama_arg( - {"-ts", "--tensor-split"}, "N0,N1,N2,...", - "fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1", - [](gpt_params & params, const std::string & value) { - std::string arg_next = value; - // split string by , and / - const std::regex regex{ R"([,/]+)" }; - std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 }; - std::vector split_arg{ it, {} }; - if (split_arg.size() >= llama_max_devices()) { - throw std::invalid_argument( - format("got %d input configs, but system only has %d devices", (int)split_arg.size(), (int)llama_max_devices()) - ); - } - for (size_t i = 0; i < llama_max_devices(); ++i) { - if (i < split_arg.size()) { - params.tensor_split[i] = std::stof(split_arg[i]); - } else { - params.tensor_split[i] = 0.0f; - } - } -#ifndef GGML_USE_CUDA_SYCL_VULKAN - fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL/Vulkan. Setting a tensor split has no effect.\n"); -#endif // GGML_USE_CUDA_SYCL_VULKAN - } - )); - add_opt(llama_arg( - {"-mg", "--main-gpu"}, "INDEX", - format("the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: %d)", params.main_gpu), - [](gpt_params & params, int value) { - params.main_gpu = value; -#ifndef GGML_USE_CUDA_SYCL_VULKAN - fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL/Vulkan. Setting the main GPU has no effect.\n"); -#endif // GGML_USE_CUDA_SYCL_VULKAN - } - )); - add_opt(llama_arg( - {"--check-tensors"}, - format("check model tensor data for invalid values (default: %s)", params.check_tensors ? "true" : "false"), - [](gpt_params & params) { - params.check_tensors = true; - } - )); - add_opt(llama_arg( - {"--override-kv"}, "KEY=TYPE:VALUE", - "advanced option to override model metadata by key. may be specified multiple times.\n" - "types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false", - [](gpt_params & params, const std::string & value) { - if (!string_parse_kv_override(value.c_str(), params.kv_overrides)) { - throw std::runtime_error(format("error: Invalid type for KV override: %s\n", value.c_str())); - } - } - )); - add_opt(llama_arg( - {"--lora"}, "FNAME", - "path to LoRA adapter (can be repeated to use multiple adapters)", - [](gpt_params & params, const std::string & value) { - params.lora_adapters.push_back({ std::string(value), 1.0 }); - } - ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA})); - add_opt(llama_arg( - {"--lora-scaled"}, "FNAME", "SCALE", - "path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)", - [](gpt_params & params, const std::string & fname, const std::string & scale) { - params.lora_adapters.push_back({ fname, std::stof(scale) }); - } - ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA})); - add_opt(llama_arg( - {"--control-vector"}, "FNAME", - "add a control vector\nnote: this argument can be repeated to add multiple control vectors", - [](gpt_params & params, const std::string & value) { - params.control_vectors.push_back({ 1.0f, value, }); - } - )); - add_opt(llama_arg( - {"--control-vector-scaled"}, "FNAME", "SCALE", - "add a control vector with user defined scaling SCALE\n" - "note: this argument can be repeated to add multiple scaled control vectors", - [](gpt_params & params, const std::string & fname, const std::string & scale) { - params.control_vectors.push_back({ std::stof(scale), fname }); - } - )); - add_opt(llama_arg( - {"--control-vector-layer-range"}, "START", "END", - "layer range to apply the control vector(s) to, start and end inclusive", - [](gpt_params & params, const std::string & start, const std::string & end) { - params.control_vector_layer_start = std::stoi(start); - params.control_vector_layer_end = std::stoi(end); - } - )); - add_opt(llama_arg( - {"-a", "--alias"}, "STRING", - "set alias for model name (to be used by REST API)", - [](gpt_params & params, const std::string & value) { - params.model_alias = value; - } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL")); - add_opt(llama_arg( - {"-m", "--model"}, "FNAME", - ex == LLAMA_EXAMPLE_EXPORT_LORA - ? std::string("model path from which to load base model") - : format( - "model path (default: `models/$filename` with filename from `--hf-file` " - "or `--model-url` if set, otherwise %s)", DEFAULT_MODEL_PATH - ), - [](gpt_params & params, const std::string & value) { - params.model = value; - } - ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}).set_env("LLAMA_ARG_MODEL")); - add_opt(llama_arg( - {"-md", "--model-draft"}, "FNAME", - "draft model for speculative decoding (default: unused)", - [](gpt_params & params, const std::string & value) { - params.model_draft = value; - } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); - add_opt(llama_arg( - {"-mu", "--model-url"}, "MODEL_URL", - "model download url (default: unused)", - [](gpt_params & params, const std::string & value) { - params.model_url = value; - } - ).set_env("LLAMA_ARG_MODEL_URL")); - add_opt(llama_arg( - {"-hfr", "--hf-repo"}, "REPO", - "Hugging Face model repository (default: unused)", - [](gpt_params & params, const std::string & value) { - params.hf_repo = value; - } - ).set_env("LLAMA_ARG_HF_REPO")); - add_opt(llama_arg( - {"-hff", "--hf-file"}, "FILE", - "Hugging Face model file (default: unused)", - [](gpt_params & params, const std::string & value) { - params.hf_file = value; - } - ).set_env("LLAMA_ARG_HF_FILE")); - add_opt(llama_arg( - {"-hft", "--hf-token"}, "TOKEN", - "Hugging Face access token (default: value from HF_TOKEN environment variable)", - [](gpt_params & params, const std::string & value) { - params.hf_token = value; - } - ).set_env("HF_TOKEN")); - add_opt(llama_arg( - {"--context-file"}, "FNAME", - "file to load context from (repeat to specify multiple files)", - [](gpt_params & params, const std::string & value) { - std::ifstream file(value, std::ios::binary); - if (!file) { - throw std::runtime_error(format("error: failed to open file '%s'\n", value.c_str())); - } - params.context_files.push_back(value); - } - ).set_examples({LLAMA_EXAMPLE_RETRIEVAL})); - add_opt(llama_arg( - {"--chunk-size"}, "N", - format("minimum length of embedded text chunks (default: %d)", params.chunk_size), - [](gpt_params & params, int value) { - params.chunk_size = value; - } - ).set_examples({LLAMA_EXAMPLE_RETRIEVAL})); - add_opt(llama_arg( - {"--chunk-separator"}, "STRING", - format("separator between chunks (default: '%s')", params.chunk_separator.c_str()), - [](gpt_params & params, const std::string & value) { - params.chunk_separator = value; - } - ).set_examples({LLAMA_EXAMPLE_RETRIEVAL})); - add_opt(llama_arg( - {"--junk"}, "N", - format("number of times to repeat the junk text (default: %d)", params.n_junk), - [](gpt_params & params, int value) { - params.n_junk = value; - } - ).set_examples({LLAMA_EXAMPLE_PASSKEY})); - add_opt(llama_arg( - {"--pos"}, "N", - format("position of the passkey in the junk text (default: %d)", params.i_pos), - [](gpt_params & params, int value) { - params.i_pos = value; - } - ).set_examples({LLAMA_EXAMPLE_PASSKEY})); - add_opt(llama_arg( - {"-o", "--output"}, "FNAME", - format("output file (default: '%s')", - ex == LLAMA_EXAMPLE_EXPORT_LORA - ? params.lora_outfile.c_str() - : ex == LLAMA_EXAMPLE_CVECTOR_GENERATOR - ? params.cvector_outfile.c_str() - : params.out_file.c_str()), - [](gpt_params & params, const std::string & value) { - params.out_file = value; - params.cvector_outfile = value; - params.lora_outfile = value; - } - ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA})); - add_opt(llama_arg( - {"-ofreq", "--output-frequency"}, "N", - format("output the imatrix every N iterations (default: %d)", params.n_out_freq), - [](gpt_params & params, int value) { - params.n_out_freq = value; - } - ).set_examples({LLAMA_EXAMPLE_IMATRIX})); - add_opt(llama_arg( - {"--save-frequency"}, "N", - format("save an imatrix copy every N iterations (default: %d)", params.n_save_freq), - [](gpt_params & params, int value) { - params.n_save_freq = value; - } - ).set_examples({LLAMA_EXAMPLE_IMATRIX})); - add_opt(llama_arg( - {"--process-output"}, - format("collect data for the output tensor (default: %s)", params.process_output ? "true" : "false"), - [](gpt_params & params) { - params.process_output = true; - } - ).set_examples({LLAMA_EXAMPLE_IMATRIX})); - add_opt(llama_arg( - {"--no-ppl"}, - format("do not compute perplexity (default: %s)", params.compute_ppl ? "true" : "false"), - [](gpt_params & params) { - params.compute_ppl = false; - } - ).set_examples({LLAMA_EXAMPLE_IMATRIX})); - add_opt(llama_arg( - {"--chunk"}, "N", - format("start processing the input from chunk N (default: %d)", params.i_chunk), - [](gpt_params & params, int value) { - params.i_chunk = value; - } - ).set_examples({LLAMA_EXAMPLE_IMATRIX})); - add_opt(llama_arg( - {"-pps"}, - format("is the prompt shared across parallel sequences (default: %s)", params.is_pp_shared ? "true" : "false"), - [](gpt_params & params) { - params.is_pp_shared = true; - } - ).set_examples({LLAMA_EXAMPLE_BENCH})); - add_opt(llama_arg( - {"-npp"}, "n0,n1,...", - "number of prompt tokens", - [](gpt_params & params, const std::string & value) { - auto p = string_split(value, ','); - params.n_pp.insert(params.n_pp.end(), p.begin(), p.end()); - } - ).set_examples({LLAMA_EXAMPLE_BENCH})); - add_opt(llama_arg( - {"-ntg"}, "n0,n1,...", - "number of text generation tokens", - [](gpt_params & params, const std::string & value) { - auto p = string_split(value, ','); - params.n_tg.insert(params.n_tg.end(), p.begin(), p.end()); - } - ).set_examples({LLAMA_EXAMPLE_BENCH})); - add_opt(llama_arg( - {"-npl"}, "n0,n1,...", - "number of parallel prompts", - [](gpt_params & params, const std::string & value) { - auto p = string_split(value, ','); - params.n_pl.insert(params.n_pl.end(), p.begin(), p.end()); - } - ).set_examples({LLAMA_EXAMPLE_BENCH})); - add_opt(llama_arg( - {"--embd-normalize"}, "N", - format("normalisation for embendings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize), - [](gpt_params & params, int value) { - params.embd_normalize = value; - } - ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); - add_opt(llama_arg( - {"--embd-output-format"}, "FORMAT", - "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix", - [](gpt_params & params, const std::string & value) { - params.embd_out = value; - } - ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); - add_opt(llama_arg( - {"--embd-separator"}, "STRING", - "separator of embendings (default \\n) for example \"<#sep#>\"", - [](gpt_params & params, const std::string & value) { - params.embd_sep = value; - } - ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); - add_opt(llama_arg( - {"--host"}, "HOST", - format("ip address to listen (default: %s)", params.hostname.c_str()), - [](gpt_params & params, const std::string & value) { - params.hostname = value; - } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_HOST")); - add_opt(llama_arg( - {"--port"}, "PORT", - format("port to listen (default: %d)", params.port), - [](gpt_params & params, int value) { - params.port = value; - } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_PORT")); - add_opt(llama_arg( - {"--path"}, "PATH", - format("path to serve static files from (default: %s)", params.public_path.c_str()), - [](gpt_params & params, const std::string & value) { - params.public_path = value; - } - ).set_examples({LLAMA_EXAMPLE_SERVER})); - add_opt(llama_arg( - {"--embedding", "--embeddings"}, - format("restrict to only support embedding use case; use only with dedicated embedding models (default: %s)", params.embedding ? "enabled" : "disabled"), - [](gpt_params & params) { - params.embedding = true; - } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS")); - add_opt(llama_arg( - {"--api-key"}, "KEY", - "API key to use for authentication (default: none)", - [](gpt_params & params, const std::string & value) { - params.api_keys.push_back(value); - } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_API_KEY")); - add_opt(llama_arg( - {"--api-key-file"}, "FNAME", - "path to file containing API keys (default: none)", - [](gpt_params & params, const std::string & value) { - std::ifstream key_file(value); - if (!key_file) { - throw std::runtime_error(format("error: failed to open file '%s'\n", value.c_str())); - } - std::string key; - while (std::getline(key_file, key)) { - if (!key.empty()) { - params.api_keys.push_back(key); - } - } - key_file.close(); - } - ).set_examples({LLAMA_EXAMPLE_SERVER})); - add_opt(llama_arg( - {"--ssl-key-file"}, "FNAME", - "path to file a PEM-encoded SSL private key", - [](gpt_params & params, const std::string & value) { - params.ssl_file_key = value; - } - ).set_examples({LLAMA_EXAMPLE_SERVER})); - add_opt(llama_arg( - {"--ssl-cert-file"}, "FNAME", - "path to file a PEM-encoded SSL certificate", - [](gpt_params & params, const std::string & value) { - params.ssl_file_cert = value; - } - ).set_examples({LLAMA_EXAMPLE_SERVER})); - add_opt(llama_arg( - {"--timeout"}, "N", - format("server read/write timeout in seconds (default: %d)", params.timeout_read), - [](gpt_params & params, int value) { - params.timeout_read = value; - params.timeout_write = value; - } - ).set_examples({LLAMA_EXAMPLE_SERVER})); - add_opt(llama_arg( - {"--threads-http"}, "N", - format("number of threads used to process HTTP requests (default: %d)", params.n_threads_http), - [](gpt_params & params, int value) { - params.n_threads_http = value; - } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_THREADS_HTTP")); - add_opt(llama_arg( - {"-spf", "--system-prompt-file"}, "FNAME", - "set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications", - [](gpt_params & params, const std::string & value) { - std::ifstream file(value); - if (!file) { - throw std::runtime_error(format("error: failed to open file '%s'\n", value.c_str())); - } - std::string system_prompt; - std::copy( - std::istreambuf_iterator(file), - std::istreambuf_iterator(), - std::back_inserter(system_prompt) - ); - params.system_prompt = system_prompt; - } - ).set_examples({LLAMA_EXAMPLE_SERVER})); - add_opt(llama_arg( - {"--log-format"}, "{text, json}", - "log output format: json or text (default: json)", - [](gpt_params & params, const std::string & value) { - if (value == "json") { - params.log_json = true; - } else if (value == "text") { - params.log_json = false; - } else { - throw std::invalid_argument("invalid value"); - } - } - ).set_examples({LLAMA_EXAMPLE_SERVER})); - add_opt(llama_arg( - {"--metrics"}, - format("enable prometheus compatible metrics endpoint (default: %s)", params.endpoint_metrics ? "enabled" : "disabled"), - [](gpt_params & params) { - params.endpoint_metrics = true; - } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_METRICS")); - add_opt(llama_arg( - {"--no-slots"}, - format("disables slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"), - [](gpt_params & params) { - params.endpoint_slots = false; - } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_ENDPOINT_SLOTS")); - add_opt(llama_arg( - {"--slot-save-path"}, "PATH", - "path to save slot kv cache (default: disabled)", - [](gpt_params & params, const std::string & value) { - params.slot_save_path = value; - // if doesn't end with DIRECTORY_SEPARATOR, add it - if (!params.slot_save_path.empty() && params.slot_save_path[params.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) { - params.slot_save_path += DIRECTORY_SEPARATOR; - } - } - ).set_examples({LLAMA_EXAMPLE_SERVER})); - add_opt(llama_arg( - {"--chat-template"}, "JINJA_TEMPLATE", - "set custom jinja chat template (default: template taken from model's metadata)\n" - "if suffix/prefix are specified, template will be disabled\n" - "only commonly used templates are accepted:\nhttps://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template", - [](gpt_params & params, const std::string & value) { - if (!llama_chat_verify_template(value)) { - throw std::runtime_error(format( - "error: the supplied chat template is not supported: %s\n" - "note: llama.cpp does not use jinja parser, we only support commonly used templates\n", - value.c_str() - )); - } - params.chat_template = value; - } - ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE")); - add_opt(llama_arg( - {"-sps", "--slot-prompt-similarity"}, "SIMILARITY", - format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity), - [](gpt_params & params, const std::string & value) { - params.slot_prompt_similarity = std::stof(value); - } - ).set_examples({LLAMA_EXAMPLE_SERVER})); - add_opt(llama_arg( - {"--lora-init-without-apply"}, - format("load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"), - [](gpt_params & params) { - params.lora_init_without_apply = true; - } - ).set_examples({LLAMA_EXAMPLE_SERVER})); - add_opt(llama_arg( - {"--simple-io"}, - "use basic IO for better compatibility in subprocesses and limited consoles", - [](gpt_params & params) { - params.simple_io = true; - } - ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL})); - add_opt(llama_arg( - {"-ld", "--logdir"}, "LOGDIR", - "path under which to save YAML logs (no logging if unset)", - [](gpt_params & params, const std::string & value) { - params.logdir = value; - - if (params.logdir.back() != DIRECTORY_SEPARATOR) { - params.logdir += DIRECTORY_SEPARATOR; - } - } - )); - add_opt(llama_arg( - {"--positive-file"}, "FNAME", - format("positive prompts file, one prompt per line (default: '%s')", params.cvector_positive_file.c_str()), - [](gpt_params & params, const std::string & value) { - params.cvector_positive_file = value; - } - ).set_examples({LLAMA_EXAMPLE_CVECTOR_GENERATOR})); - add_opt(llama_arg( - {"--negative-file"}, "FNAME", - format("negative prompts file, one prompt per line (default: '%s')", params.cvector_negative_file.c_str()), - [](gpt_params & params, const std::string & value) { - params.cvector_negative_file = value; - } - ).set_examples({LLAMA_EXAMPLE_CVECTOR_GENERATOR})); - add_opt(llama_arg( - {"--pca-batch"}, "N", - format("batch size used for PCA. Larger batch runs faster, but uses more memory (default: %d)", params.n_pca_batch), - [](gpt_params & params, int value) { - params.n_pca_batch = value; - } - ).set_examples({LLAMA_EXAMPLE_CVECTOR_GENERATOR})); - add_opt(llama_arg( - {"--pca-iter"}, "N", - format("number of iterations used for PCA (default: %d)", params.n_pca_iterations), - [](gpt_params & params, int value) { - params.n_pca_iterations = value; - } - ).set_examples({LLAMA_EXAMPLE_CVECTOR_GENERATOR})); - add_opt(llama_arg( - {"--method"}, "{pca, mean}", - "dimensionality reduction method to be used (default: pca)", - [](gpt_params & params, const std::string & value) { - /**/ if (value == "pca") { params.cvector_dimre_method = DIMRE_METHOD_PCA; } - else if (value == "mean") { params.cvector_dimre_method = DIMRE_METHOD_MEAN; } - else { throw std::invalid_argument("invalid value"); } - } - ).set_examples({LLAMA_EXAMPLE_CVECTOR_GENERATOR})); - add_opt(llama_arg( - {"--output-format"}, "{md,jsonl}", - "output format for batched-bench results (default: md)", - [](gpt_params & params, const std::string & value) { - /**/ if (value == "jsonl") { params.batched_bench_output_jsonl = true; } - else if (value == "md") { params.batched_bench_output_jsonl = false; } - else { std::invalid_argument("invalid value"); } - } - ).set_examples({LLAMA_EXAMPLE_BENCH})); -#ifndef LOG_DISABLE_LOGS - // TODO: make this looks less weird - add_opt(llama_arg( - {"--log-test"}, - "Log test", - [](gpt_params &) { log_param_single_parse("--log-test"); } - )); - add_opt(llama_arg( - {"--log-disable"}, - "Log disable", - [](gpt_params &) { log_param_single_parse("--log-disable"); } - )); - add_opt(llama_arg( - {"--log-enable"}, - "Log enable", - [](gpt_params &) { log_param_single_parse("--log-enable"); } - )); - add_opt(llama_arg( - {"--log-new"}, - "Log new", - [](gpt_params &) { log_param_single_parse("--log-new"); } - )); - add_opt(llama_arg( - {"--log-append"}, - "Log append", - [](gpt_params &) { log_param_single_parse("--log-append"); } - )); - add_opt(llama_arg( - {"--log-file"}, "FNAME", - "Log file", - [](gpt_params &, const std::string & value) { log_param_pair_parse(false, "--log-file", value); } - )); -#endif // LOG_DISABLE_LOGS - - return options; + LOG_INF("build: %d (%s) with %s for %s%s\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT, LLAMA_COMPILER, LLAMA_BUILD_TARGET, build_type); } -std::string gpt_params_get_system_info(const gpt_params & params) { +std::string common_params_get_system_info(const common_params & params) { std::ostringstream os; os << "system_info: n_threads = " << params.cpuparams.n_threads; @@ -2279,17 +425,19 @@ std::string gpt_params_get_system_info(const gpt_params & params) { // String utils // -std::vector string_split(std::string input, char separator) { - std::vector parts; - size_t separator_pos = input.find(separator); - while (separator_pos != std::string::npos) { - std::string part = input.substr(0, separator_pos); - parts.emplace_back(part); - input = input.substr(separator_pos + 1); - separator_pos = input.find(separator); - } - parts.emplace_back(input); - return parts; +std::string string_format(const char * fmt, ...) { + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), size); } std::string string_strip(const std::string & str) { @@ -2337,6 +485,136 @@ void string_replace_all(std::string & s, const std::string & search, const std:: s = std::move(builder); } +std::string string_join(const std::vector & values, const std::string & separator) { + std::ostringstream result; + for (size_t i = 0; i < values.size(); ++i) { + if (i > 0) { + result << separator; + } + result << values[i]; + } + return result.str(); +} + +std::vector string_split(const std::string & str, const std::string & delimiter) { + std::vector parts; + size_t start = 0; + size_t end = str.find(delimiter); + + while (end != std::string::npos) { + parts.push_back(str.substr(start, end - start)); + start = end + delimiter.length(); + end = str.find(delimiter, start); + } + + parts.push_back(str.substr(start)); + + return parts; +} + +std::string string_repeat(const std::string & str, size_t n) { + if (n == 0) { + return ""; + } + + std::string result; + result.reserve(str.length() * n); + + for (size_t i = 0; i < n; ++i) { + result += str; + } + + return result; +} + +std::string string_from(bool value) { + return value ? "true" : "false"; +} + +std::string string_from(const std::vector & values) { + std::stringstream buf; + + buf << "[ "; + bool first = true; + for (auto e : values) { + if (first) { + first = false; + } else { + buf << ", "; + } + buf << std::to_string(e); + } + buf << " ]"; + + return buf.str(); +} + +std::string string_from(const struct llama_context * ctx, const std::vector & tokens) { + std::stringstream buf; + + buf << "[ "; + + bool first = true; + for (const auto & token : tokens) { + if (!first) { + buf << ", "; + } else { + first = false; + } + + auto detokenized = common_token_to_piece(ctx, token); + + detokenized.erase( + std::remove_if( + detokenized.begin(), + detokenized.end(), + [](const unsigned char c) { return !std::isprint(c); }), + detokenized.end()); + + buf << "'" << detokenized << "'" + << ":" << std::to_string(token); + } + + buf << " ]"; + + return buf.str(); +} + +std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) { + std::stringstream buf; + + buf << "[ "; + + bool first = true; + for (int i = 0; i < batch.n_tokens; ++i) { + if (!first) { + buf << ", "; + } else { + first = false; + } + + auto detokenized = common_token_to_piece(ctx, batch.token[i]); + + detokenized.erase( + std::remove_if( + detokenized.begin(), + detokenized.end(), + [](const unsigned char c) { return !std::isprint(c); }), + detokenized.end()); + + buf << "\n" << std::to_string(i) + << ", token '" << detokenized << "'" + << ", pos " << std::to_string(batch.pos[i]) + << ", n_seq_id " << std::to_string(batch.n_seq_id[i]) + << ", seq_id " << std::to_string(batch.seq_id[i][0]) + << ", logits " << std::to_string(batch.logits[i]); + } + + buf << " ]"; + + return buf.str(); +} + void string_process_escapes(std::string & input) { std::size_t input_len = input.length(); std::size_t output_idx = 0; @@ -2377,7 +655,7 @@ void string_process_escapes(std::string & input) { bool string_parse_kv_override(const char * data, std::vector & overrides) { const char * sep = strchr(data, '='); if (sep == nullptr || sep - data >= 128) { - fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data); + LOG_ERR("%s: malformed KV override '%s'\n", __func__, data); return false; } llama_model_kv_override kvo; @@ -2400,20 +678,20 @@ bool string_parse_kv_override(const char * data, std::vector 127) { - fprintf(stderr, "%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data); + LOG_ERR("%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data); return false; } strncpy(kvo.val_str, sep, 127); kvo.val_str[127] = '\0'; } else { - fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data); + LOG_ERR("%s: invalid type for KV override '%s'\n", __func__, data); return false; } overrides.emplace_back(std::move(kvo)); @@ -2440,7 +718,17 @@ bool fs_validate_filename(const std::string & filename) { std::u32string filename_utf32; try { +#if defined(__clang__) + // disable C++17 deprecation warning for std::codecvt_utf8 +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wdeprecated-declarations" +#endif std::wstring_convert, char32_t> converter; + +#if defined(__clang__) +# pragma clang diagnostic pop +#endif + filename_utf32 = converter.from_bytes(filename); // If the reverse conversion mismatches, it means overlong UTF-8 sequences were used, @@ -2610,87 +898,143 @@ std::string fs_get_cache_file(const std::string & filename) { // // Model utils // -struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { - llama_init_result iparams; - auto mparams = llama_model_params_from_gpt_params(params); +struct common_init_result common_init_from_params(common_params & params) { + common_init_result iparams; + auto mparams = common_model_params_to_llama(params); llama_model * model = nullptr; if (!params.hf_repo.empty() && !params.hf_file.empty()) { - model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams); + model = common_load_model_from_hf(params.hf_repo, params.hf_file, params.model, params.hf_token, mparams); } else if (!params.model_url.empty()) { - model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams); + model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams); } else { - model = llama_load_model_from_file(params.model.c_str(), mparams); + model = llama_model_load_from_file(params.model.c_str(), mparams); } if (model == NULL) { - fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); + LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.c_str()); return iparams; } - auto cparams = llama_context_params_from_gpt_params(params); + const llama_vocab * vocab = llama_model_get_vocab(model); - llama_context * lctx = llama_new_context_with_model(model, cparams); + if (params.reranking) { + bool ok = true; + + if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__); + ok = false; + } + + if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__); + ok = false; + } + + if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__); + ok = false; + } + + if (!ok) { + llama_model_free(model); + + return iparams; + } + } + + auto cparams = common_context_params_to_llama(params); + + llama_context * lctx = llama_init_from_model(model, cparams); if (lctx == NULL) { - fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); - llama_free_model(model); + LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str()); + llama_model_free(model); return iparams; } + if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) { + LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__); + params.ctx_shift = false; + } + if (!params.control_vectors.empty()) { if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1; - if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model); + if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_model_n_layer(model); - const auto cvec = llama_control_vector_load(params.control_vectors); + const auto cvec = common_control_vector_load(params.control_vectors); if (cvec.n_embd == -1) { llama_free(lctx); - llama_free_model(model); + llama_model_free(model); + return iparams; } - int err = llama_control_vector_apply(lctx, - cvec.data.data(), - cvec.data.size(), - cvec.n_embd, - params.control_vector_layer_start, - params.control_vector_layer_end); + int err = llama_apply_adapter_cvec( + lctx, + cvec.data.data(), + cvec.data.size(), + cvec.n_embd, + params.control_vector_layer_start, + params.control_vector_layer_end); if (err) { llama_free(lctx); - llama_free_model(model); + llama_model_free(model); + return iparams; } } // load and optionally apply lora adapters for (auto & la : params.lora_adapters) { - llama_lora_adapter_container loaded_la; - loaded_la.path = la.path; - loaded_la.scale = la.scale; - loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str()); - if (loaded_la.adapter == nullptr) { - fprintf(stderr, "%s: error: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); + llama_adapter_lora_ptr lora; + lora.reset(llama_adapter_lora_init(model, la.path.c_str())); + if (lora == nullptr) { + LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); llama_free(lctx); - llama_free_model(model); + llama_model_free(model); return iparams; } - iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters - } - if (!params.lora_init_without_apply) { - llama_lora_adapters_apply(lctx, iparams.lora_adapters); + + la.ptr = lora.get(); + iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters } - if (params.sparams.ignore_eos && llama_token_eos(model) == -1) { - fprintf(stderr, "%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__); - params.sparams.ignore_eos = false; + if (!params.lora_init_without_apply) { + common_set_adapter_lora(lctx, params.lora_adapters); + } + + if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__); + params.sampling.ignore_eos = false; + } + + if (params.sampling.ignore_eos) { + for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) { + if (llama_vocab_is_eog(vocab, i)) { + LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY); + params.sampling.logit_bias.push_back({i, -INFINITY}); + } + } + } + + if (params.sampling.penalty_last_n == -1) { + LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); + params.sampling.penalty_last_n = llama_n_ctx(lctx); + } + + if (params.sampling.dry_penalty_last_n == -1) { + LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); + params.sampling.dry_penalty_last_n = llama_n_ctx(lctx); } if (params.warmup) { - LOG("warming up the model with an empty run\n"); + LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__); std::vector tmp; - llama_token bos = llama_token_bos(model); - llama_token eos = llama_token_eos(model); + llama_token bos = llama_vocab_bos(vocab); + llama_token eos = llama_vocab_eos(vocab); + // some models (e.g. T5) don't have a BOS token if (bos != LLAMA_TOKEN_NULL) { tmp.push_back(bos); @@ -2703,43 +1047,46 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { } if (llama_model_has_encoder(model)) { - llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0)); + llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size())); llama_token decoder_start_token_id = llama_model_decoder_start_token(model); - if (decoder_start_token_id == -1) { + if (decoder_start_token_id == LLAMA_TOKEN_NULL) { decoder_start_token_id = bos; } tmp.clear(); tmp.push_back(decoder_start_token_id); } if (llama_model_has_decoder(model)) { - llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); + llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch))); } llama_kv_cache_clear(lctx); llama_synchronize(lctx); - llama_perf_reset(lctx, LLAMA_PERF_TYPE_CONTEXT); + llama_perf_context_reset(lctx); } - iparams.model = model; - iparams.context = lctx; + iparams.model.reset(model); + iparams.context.reset(lctx); + return iparams; } -void llama_lora_adapters_apply(struct llama_context * ctx, std::vector & lora_adapters) { - llama_lora_adapter_clear(ctx); - for (auto & la : lora_adapters) { +void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora) { + llama_clear_adapter_lora(ctx); + for (auto & la : lora) { if (la.scale != 0.0f) { - llama_lora_adapter_set(ctx, la.adapter, la.scale); + llama_set_adapter_lora(ctx, la.ptr, la.scale); } } } -struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) { +struct llama_model_params common_model_params_to_llama(common_params & params) { auto mparams = llama_model_default_params(); + if (!params.devices.empty()) { + mparams.devices = params.devices.data(); + } if (params.n_gpu_layers != -1) { mparams.n_gpu_layers = params.n_gpu_layers; } - mparams.rpc_servers = params.rpc_servers.c_str(); mparams.main_gpu = params.main_gpu; mparams.split_mode = params.split_mode; mparams.tensor_split = params.tensor_split; @@ -2756,36 +1103,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & return mparams; } -static ggml_type kv_cache_type_from_str(const std::string & s) { - if (s == "f32") { - return GGML_TYPE_F32; - } - if (s == "f16") { - return GGML_TYPE_F16; - } - if (s == "q8_0") { - return GGML_TYPE_Q8_0; - } - if (s == "q4_0") { - return GGML_TYPE_Q4_0; - } - if (s == "q4_1") { - return GGML_TYPE_Q4_1; - } - if (s == "iq4_nl") { - return GGML_TYPE_IQ4_NL; - } - if (s == "q5_0") { - return GGML_TYPE_Q5_0; - } - if (s == "q5_1") { - return GGML_TYPE_Q5_1; - } - - throw std::runtime_error("Invalid cache type: " + s); -} - -struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { +struct llama_context_params common_context_params_to_llama(const common_params & params) { auto cparams = llama_context_default_params(); cparams.n_ctx = params.n_ctx; @@ -2794,7 +1112,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.n_ubatch = params.n_ubatch; cparams.n_threads = params.cpuparams.n_threads; cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ? - params.cpuparams.n_threads : params.cpuparams_batch.n_threads; + params.cpuparams.n_threads : params.cpuparams_batch.n_threads; cparams.logits_all = params.logits_all; cparams.embeddings = params.embedding; cparams.rope_scaling_type = params.rope_scaling_type; @@ -2812,9 +1130,15 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.offload_kqv = !params.no_kv_offload; cparams.flash_attn = params.flash_attn; + cparams.no_perf = params.no_perf; - cparams.type_k = kv_cache_type_from_str(params.cache_type_k); - cparams.type_v = kv_cache_type_from_str(params.cache_type_v); + if (params.reranking) { + cparams.embeddings = true; + cparams.pooling_type = LLAMA_POOLING_TYPE_RANK; + } + + cparams.type_k = params.cache_type_k; + cparams.type_v = params.cache_type_v; return cparams; } @@ -2837,17 +1161,38 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p #ifdef LLAMA_USE_CURL -static bool starts_with(const std::string & str, const std::string & prefix) { - // While we wait for C++20's std::string::starts_with... - return str.rfind(prefix, 0) == 0; +#define CURL_MAX_RETRY 3 +#define CURL_RETRY_DELAY_SECONDS 2 + +static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) { + int remaining_attempts = max_attempts; + + while (remaining_attempts > 0) { + LOG_INF("%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts); + + CURLcode res = curl_easy_perform(curl); + if (res == CURLE_OK) { + return true; + } + + int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000; + LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay); + + remaining_attempts--; + std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay)); + } + + LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts); + + return false; } -static bool llama_download_file(const std::string & url, const std::string & path, const std::string & hf_token) { - +static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) { // Initialize libcurl - std::unique_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_slist_ptr http_headers; if (!curl) { - fprintf(stderr, "%s: error initializing libcurl\n", __func__); + LOG_ERR("%s: error initializing libcurl\n", __func__); return false; } @@ -2859,11 +1204,9 @@ static bool llama_download_file(const std::string & url, const std::string & pat // Check if hf-token or bearer-token was specified if (!hf_token.empty()) { - std::string auth_header = "Authorization: Bearer "; - auth_header += hf_token.c_str(); - struct curl_slist *http_headers = NULL; - http_headers = curl_slist_append(http_headers, auth_header.c_str()); - curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers); + std::string auth_header = "Authorization: Bearer " + hf_token; + http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); } #if defined(_WIN32) @@ -2873,8 +1216,7 @@ static bool llama_download_file(const std::string & url, const std::string & pat #endif // Check if the file already exists locally - struct stat model_file_info; - auto file_exists = (stat(path.c_str(), &model_file_info) == 0); + auto file_exists = std::filesystem::exists(path); // If the file exists, check its JSON metadata companion file. std::string metadata_path = path + ".json"; @@ -2888,11 +1230,11 @@ static bool llama_download_file(const std::string & url, const std::string & pat if (metadata_in.good()) { try { metadata_in >> metadata; - fprintf(stderr, "%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str()); + LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str()); if (metadata.contains("url") && metadata.at("url").is_string()) { auto previous_url = metadata.at("url").get(); if (previous_url != url) { - fprintf(stderr, "%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str()); + LOG_ERR("%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str()); return false; } } @@ -2903,24 +1245,26 @@ static bool llama_download_file(const std::string & url, const std::string & pat last_modified = metadata.at("lastModified"); } } catch (const nlohmann::json::exception & e) { - fprintf(stderr, "%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); + LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); return false; } } } else { - fprintf(stderr, "%s: no previous model file found %s\n", __func__, path.c_str()); + LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); } // Send a HEAD request to retrieve the etag and last-modified headers - struct llama_load_model_from_url_headers { + struct common_load_model_from_url_headers { std::string etag; std::string last_modified; }; - llama_load_model_from_url_headers headers; + + common_load_model_from_url_headers headers; + { typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *); auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t { - llama_load_model_from_url_headers *headers = (llama_load_model_from_url_headers *) userdata; + common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata; static std::regex header_regex("([^:]+): (.*)\r\n"); static std::regex etag_regex("ETag", std::regex_constants::icase); @@ -2945,9 +1289,8 @@ static bool llama_download_file(const std::string & url, const std::string & pat curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast(header_callback)); curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); - CURLcode res = curl_easy_perform(curl.get()); - if (res != CURLE_OK) { - fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res)); + bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS); + if (!was_perform_successful) { return false; } @@ -2957,26 +1300,26 @@ static bool llama_download_file(const std::string & url, const std::string & pat // HEAD not supported, we don't know if the file has changed // force trigger downloading force_download = true; - fprintf(stderr, "%s: HEAD invalid http status code received: %ld\n", __func__, http_code); + LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code); } } bool should_download = !file_exists || force_download; if (!should_download) { if (!etag.empty() && etag != headers.etag) { - fprintf(stderr, "%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str()); + LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str()); should_download = true; } else if (!last_modified.empty() && last_modified != headers.last_modified) { - fprintf(stderr, "%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str()); + LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str()); should_download = true; } } if (should_download) { std::string path_temporary = path + ".downloadInProgress"; if (file_exists) { - fprintf(stderr, "%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); + LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); if (remove(path.c_str()) != 0) { - fprintf(stderr, "%s: unable to delete file: %s\n", __func__, path.c_str()); + LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); return false; } } @@ -2991,7 +1334,7 @@ static bool llama_download_file(const std::string & url, const std::string & pat std::unique_ptr outfile(fopen(path_temporary.c_str(), "wb")); if (!outfile) { - fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path.c_str()); + LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str()); return false; } @@ -3022,18 +1365,17 @@ static bool llama_download_file(const std::string & url, const std::string & pat }; // start the download - fprintf(stderr, "%s: downloading from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__, - llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str()); - auto res = curl_easy_perform(curl.get()); - if (res != CURLE_OK) { - fprintf(stderr, "%s: curl_easy_perform() failed: %s\n", __func__, curl_easy_strerror(res)); + LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__, + llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str()); + bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS); + if (!was_perform_successful) { return false; } long http_code = 0; curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code); if (http_code < 200 || http_code >= 400) { - fprintf(stderr, "%s: invalid http status code received: %ld\n", __func__, http_code); + LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code); return false; } @@ -3047,10 +1389,10 @@ static bool llama_download_file(const std::string & url, const std::string & pat {"lastModified", headers.last_modified} }); std::ofstream(metadata_path) << metadata.dump(4); - fprintf(stderr, "%s: file metadata saved: %s\n", __func__, metadata_path.c_str()); + LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str()); if (rename(path_temporary.c_str(), path.c_str()) != 0) { - fprintf(stderr, "%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); + LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); return false; } } @@ -3058,18 +1400,18 @@ static bool llama_download_file(const std::string & url, const std::string & pat return true; } -struct llama_model * llama_load_model_from_url( - const char * model_url, - const char * path_model, - const char * hf_token, +struct llama_model * common_load_model_from_url( + const std::string & model_url, + const std::string & local_path, + const std::string & hf_token, const struct llama_model_params & params) { // Basic validation of the model_url - if (!model_url || strlen(model_url) == 0) { - fprintf(stderr, "%s: invalid model_url\n", __func__); + if (model_url.empty()) { + LOG_ERR("%s: invalid model_url\n", __func__); return NULL; } - if (!llama_download_file(model_url, path_model, hf_token)) { + if (!common_download_file(model_url, local_path, hf_token)) { return NULL; } @@ -3080,9 +1422,9 @@ struct llama_model * llama_load_model_from_url( /*.no_alloc = */ true, /*.ctx = */ NULL, }; - auto * ctx_gguf = gguf_init_from_file(path_model, gguf_params); + auto * ctx_gguf = gguf_init_from_file(local_path.c_str(), gguf_params); if (!ctx_gguf) { - fprintf(stderr, "\n%s: failed to load input GGUF from %s\n", __func__, path_model); + LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, local_path.c_str()); return NULL; } @@ -3101,15 +1443,13 @@ struct llama_model * llama_load_model_from_url( // Verify the first split file format // and extract split URL and PATH prefixes { - if (!llama_split_prefix(split_prefix, sizeof(split_prefix), path_model, 0, n_split)) { - fprintf(stderr, "\n%s: unexpected model file name: %s" - " n_split=%d\n", __func__, path_model, n_split); + if (!llama_split_prefix(split_prefix, sizeof(split_prefix), local_path.c_str(), 0, n_split)) { + LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, local_path.c_str(), n_split); return NULL; } - if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url, 0, n_split)) { - fprintf(stderr, "\n%s: unexpected model url: %s" - " n_split=%d\n", __func__, model_url, n_split); + if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url.c_str(), 0, n_split)) { + LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model_url.c_str(), n_split); return NULL; } } @@ -3124,7 +1464,7 @@ struct llama_model * llama_load_model_from_url( char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0}; llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split); - return llama_download_file(split_url, split_path, hf_token); + return common_download_file(split_url, split_path, hf_token); }, idx)); } @@ -3136,14 +1476,14 @@ struct llama_model * llama_load_model_from_url( } } - return llama_load_model_from_file(path_model, params); + return llama_model_load_from_file(local_path.c_str(), params); } -struct llama_model * llama_load_model_from_hf( - const char * repo, - const char * model, - const char * path_model, - const char * hf_token, +struct llama_model * common_load_model_from_hf( + const std::string & repo, + const std::string & remote_path, + const std::string & local_path, + const std::string & hf_token, const struct llama_model_params & params) { // construct hugging face model url: // @@ -3157,48 +1497,129 @@ struct llama_model * llama_load_model_from_hf( std::string model_url = "https://huggingface.co/"; model_url += repo; model_url += "/resolve/main/"; - model_url += model; + model_url += remote_path; - return llama_load_model_from_url(model_url.c_str(), path_model, hf_token, params); + return common_load_model_from_url(model_url, local_path, hf_token, params); +} + +/** + * Allow getting the HF file from the HF repo with tag (like ollama), for example: + * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4 + * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M + * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s + * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo) + * + * Return pair of (with "repo" already having tag removed) + * + * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files. + */ +std::pair common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & hf_token) { + auto parts = string_split(hf_repo_with_tag, ':'); + std::string tag = parts.size() > 1 ? parts.back() : "latest"; + std::string hf_repo = parts[0]; + if (string_split(hf_repo, '/').size() != 2) { + throw std::invalid_argument("error: invalid HF repo format, expected /[:quant]\n"); + } + + // fetch model info from Hugging Face Hub API + json model_info; + curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_slist_ptr http_headers; + std::string res_str; + std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag; + curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); + typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data); + auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t { + static_cast(data)->append((char * ) ptr, size * nmemb); + return size * nmemb; + }; + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str); +#if defined(_WIN32) + curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); +#endif + if (!hf_token.empty()) { + std::string auth_header = "Authorization: Bearer " + hf_token; + http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); + } + // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response + http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); + http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json"); + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); + + CURLcode res = curl_easy_perform(curl.get()); + + if (res != CURLE_OK) { + throw std::runtime_error("error: cannot make GET request to HF API"); + } + + long res_code; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code); + if (res_code == 200) { + model_info = json::parse(res_str); + } else if (res_code == 401) { + throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token"); + } else { + throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str())); + } + + // check response + if (!model_info.contains("ggufFile")) { + throw std::runtime_error("error: model does not have ggufFile"); + } + json & gguf_file = model_info.at("ggufFile"); + if (!gguf_file.contains("rfilename")) { + throw std::runtime_error("error: ggufFile does not have rfilename"); + } + + return std::make_pair(hf_repo, gguf_file.at("rfilename")); } #else -struct llama_model * llama_load_model_from_url( - const char * /*model_url*/, - const char * /*path_model*/, - const char * /*hf_token*/, +struct llama_model * common_load_model_from_url( + const std::string & /*model_url*/, + const std::string & /*local_path*/, + const std::string & /*hf_token*/, const struct llama_model_params & /*params*/) { - fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); + LOG_WRN("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); return nullptr; } -struct llama_model * llama_load_model_from_hf( - const char * /*repo*/, - const char * /*model*/, - const char * /*path_model*/, - const char * /*hf_token*/, +struct llama_model * common_load_model_from_hf( + const std::string & /*repo*/, + const std::string & /*remote_path*/, + const std::string & /*local_path*/, + const std::string & /*hf_token*/, const struct llama_model_params & /*params*/) { - fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__); + LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__); return nullptr; } +std::pair common_get_hf_file(const std::string &, const std::string &) { + LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__); + return std::make_pair("", ""); +} + #endif // LLAMA_USE_CURL // // Batch utils // -void llama_batch_clear(struct llama_batch & batch) { +void common_batch_clear(struct llama_batch & batch) { batch.n_tokens = 0; } -void llama_batch_add( +void common_batch_add( struct llama_batch & batch, llama_token id, llama_pos pos, const std::vector & seq_ids, bool logits) { + GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded"); + batch.token [batch.n_tokens] = id; batch.pos [batch.n_tokens] = pos; batch.n_seq_id[batch.n_tokens] = seq_ids.size(); @@ -3210,30 +1631,92 @@ void llama_batch_add( batch.n_tokens++; } +// +// Token utils +// + +size_t common_lcp(const llama_tokens & a, const llama_tokens & b) { + size_t i; + for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} + + return i; +} + +size_t common_lcs(const llama_tokens & a, const llama_tokens & b) { + // check for empty sequences + if (a.empty() || b.empty()) { + return 0; + } + + // get the lengths of the input sequences + size_t a_len = a.size(); + size_t b_len = b.size(); + + // initialize the maximum length of the longest common subsequence (LCS) + size_t max_length = 0; + + // use two rows instead of a 2D matrix to optimize space + std::vector prev_row(b_len + 1, 0); + std::vector curr_row(b_len + 1, 0); + + // iterate through the elements of a + for (size_t i = 1; i <= a_len; i++) { + // iterate through the elements of b + for (size_t j = 1; j <= b_len; j++) { + // if elements at the current positions match + if (a[i - 1] == b[j - 1]) { + // if it's the first element of either sequences, set LCS length to 1 + if (i == 1 || j == 1) { + curr_row[j] = 1; + } else { + // increment LCS length by 1 compared to the previous element + curr_row[j] = prev_row[j - 1] + 1; + } + + // update max_length if necessary + if (curr_row[j] > max_length) { + max_length = curr_row[j]; + } + } else { + // reset LCS length if elements don't match + curr_row[j] = 0; + } + } + + // update the previous row for the next iteration + prev_row = curr_row; + } + + // return the maximum length of the LCS + return max_length; +} + // // Vocab utils // -std::vector llama_tokenize( +std::vector common_tokenize( const struct llama_context * ctx, const std::string & text, bool add_special, bool parse_special) { - return llama_tokenize(llama_get_model(ctx), text, add_special, parse_special); + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + return common_tokenize(vocab, text, add_special, parse_special); } -std::vector llama_tokenize( - const struct llama_model * model, +std::vector common_tokenize( + const struct llama_vocab * vocab, const std::string & text, bool add_special, bool parse_special) { // upper limit for the number of tokens int n_tokens = text.length() + 2 * add_special; std::vector result(n_tokens); - n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); + n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); + int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); GGML_ASSERT(check == -n_tokens); } else { result.resize(n_tokens); @@ -3241,13 +1724,19 @@ std::vector llama_tokenize( return result; } -std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { +std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + return common_token_to_piece(vocab, token, special); +} + +std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token token, bool special) { std::string piece; piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' - const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); + const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); if (n_chars < 0) { piece.resize(-n_chars); - int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); + int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); GGML_ASSERT(check == -n_chars); } else { @@ -3257,13 +1746,19 @@ std::string llama_token_to_piece(const struct llama_context * ctx, llama_token t return piece; } -std::string llama_detokenize(llama_context * ctx, const std::vector & tokens, bool special) { +std::string common_detokenize(const struct llama_context * ctx, const std::vector & tokens, bool special) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + return common_detokenize(vocab, tokens, special); +} + +std::string common_detokenize(const struct llama_vocab * vocab, const std::vector & tokens, bool special) { std::string text; text.resize(std::max(text.capacity(), tokens.size())); - int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + int32_t n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); if (n_chars < 0) { text.resize(-n_chars); - n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization } @@ -3277,92 +1772,162 @@ std::string llama_detokenize(llama_context * ctx, const std::vector // Chat template utils // -bool llama_chat_verify_template(const std::string & tmpl) { +bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { + if (use_jinja) { + try { + auto chat_template = common_chat_template(tmpl, "", ""); + common_chat_inputs inputs; + inputs.messages = json::array({{ + {"role", "user"}, + {"content", "test"}, + }}); + common_chat_params_init(chat_template, inputs); + return true; + } catch (const std::exception & e) { + LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what()); + return false; + } + } llama_chat_message chat[] = {{"user", "test"}}; - int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0); + const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0); return res >= 0; } -std::string llama_chat_apply_template(const struct llama_model * model, - const std::string & tmpl, - const std::vector & msgs, - bool add_ass) { +std::string common_chat_apply_template( + const common_chat_template & tmpl, + const std::vector & msgs, + bool add_ass, + bool use_jinja) { + if (use_jinja) { + auto messages = json::array(); + for (const auto & msg : msgs) { + messages.push_back({{"role", msg.role}, {"content", msg.content}}); + } + common_chat_inputs inputs; + inputs.messages = messages; + inputs.add_generation_prompt = add_ass; + return common_chat_params_init(tmpl, inputs).prompt; + } + int alloc_size = 0; - bool fallback = false; // indicate if we must fallback to default chatml std::vector chat; - for (auto & msg : msgs) { + for (const auto & msg : msgs) { chat.push_back({msg.role.c_str(), msg.content.c_str()}); alloc_size += (msg.role.size() + msg.content.size()) * 1.25; } - const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str(); std::vector buf(alloc_size); // run the first time to get the total output length - int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size()); // error: chat template is not supported if (res < 0) { - if (ptr_tmpl != nullptr) { - // if the custom "tmpl" is not supported, we throw an error - // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() - throw std::runtime_error("this custom template is not supported"); - } else { - // If the built-in template is not supported, we default to chatml - res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size()); - fallback = true; - } + // if the custom "tmpl" is not supported, we throw an error + // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() + throw std::runtime_error("this custom template is not supported"); } // if it turns out that our buffer is too small, we resize it if ((size_t) res > buf.size()) { buf.resize(res); - res = llama_chat_apply_template( - fallback ? nullptr : model, - fallback ? "chatml" : ptr_tmpl, - chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size()); } std::string formatted_chat(buf.data(), res); return formatted_chat; } -std::string llama_chat_format_single(const struct llama_model * model, - const std::string & tmpl, - const std::vector & past_msg, - const llama_chat_msg & new_msg, - bool add_ass) { +std::string common_chat_format_single( + const common_chat_template & tmpl, + const std::vector & past_msg, + const common_chat_msg & new_msg, + bool add_ass, + bool use_jinja) { std::ostringstream ss; - auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false); - std::vector chat_new(past_msg); + auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja); + std::vector chat_new(past_msg); // if the past_msg ends with a newline, we must preserve it in the formatted version if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { ss << "\n"; }; // format chat with new_msg chat_new.push_back(new_msg); - auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass); + auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja); // get the diff part ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); return ss.str(); } -std::string llama_chat_format_example(const struct llama_model * model, - const std::string & tmpl) { - std::vector msgs = { - {"system", "You are a helpful assistant"}, - {"user", "Hello"}, - {"assistant", "Hi there"}, - {"user", "How are you?"}, +std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) { + std::vector msgs = { + {"system", "You are a helpful assistant", {}}, + {"user", "Hello", {}}, + {"assistant", "Hi there", {}}, + {"user", "How are you?", {}}, + }; + return common_chat_apply_template(tmpl, msgs, true, use_jinja); +} + +common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) +{ + auto vocab = llama_model_get_vocab(model); + std::string default_template_src = chat_template_override; + std::string template_tool_use_src = chat_template_override; + bool has_explicit_template = !chat_template_override.empty(); + if (chat_template_override.empty()) { + auto str = llama_model_chat_template(model, /* name */ nullptr); + if (str) { + default_template_src = str; + has_explicit_template = true; + } + str = llama_model_chat_template(model, /* name */ "tool_use"); + if (str) { + template_tool_use_src = str; + has_explicit_template = true; + } + } + if (default_template_src.empty() || default_template_src == "chatml") { + if (!template_tool_use_src.empty()) { + default_template_src = template_tool_use_src; + } else { + default_template_src = R"( + {%- for message in messages -%} + {{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}} + {%- endfor -%} + {%- if add_generation_prompt -%} + {{- "<|im_start|>assistant\n" -}} + {%- endif -%} + )"; + } + } + const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) { + if (token == LLAMA_TOKEN_NULL) { + if (default_template_src.find(jinja_variable_name) != std::string::npos + || template_tool_use_src.find(jinja_variable_name) != std::string::npos) { + LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name); + } + return std::string(); + } else { + return common_token_to_piece(vocab, token, true); + } + }; + auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); + auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); + return { + has_explicit_template, + std::make_unique(default_template_src, token_bos, token_eos), + template_tool_use_src.empty() + ? nullptr + : std::make_unique(template_tool_use_src, token_bos, token_eos) }; - return llama_chat_apply_template(model, tmpl, msgs, true); } // // KV cache utils // -void llama_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) { +void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) { static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+"; printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d", @@ -3385,7 +1950,7 @@ void llama_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) { printf("\n=== Done dumping\n"); } -void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size) { +void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size) { static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n", @@ -3437,7 +2002,7 @@ void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_siz // Embedding utils // -void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm) { +void common_embd_normalize(const float * inp, float * out, int n, int embd_norm) { double sum = 0.0; switch (embd_norm) { @@ -3446,7 +2011,9 @@ void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm) break; case 0: // max absolute for (int i = 0; i < n; i++) { - if (sum < std::abs(inp[i])) sum = std::abs(inp[i]); + if (sum < std::abs(inp[i])) { + sum = std::abs(inp[i]); + } } sum /= 32760.0; // make an int16 range break; @@ -3471,7 +2038,7 @@ void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm) } } -float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n){ +float common_embd_similarity_cos(const float * embd1, const float * embd2, int n){ double sum = 0.0; double sum1 = 0.0; double sum2 = 0.0; @@ -3497,8 +2064,8 @@ float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n) // Control vector utils // -static llama_control_vector_data llama_control_vector_load_one(const llama_control_vector_load_info & load_info) { - llama_control_vector_data result = { -1, {} }; +static common_control_vector_data common_control_vector_load_one(const common_control_vector_load_info & load_info) { + common_control_vector_data result = { -1, {} }; ggml_context * ctx = nullptr; struct gguf_init_params meta_gguf_params = { @@ -3507,13 +2074,13 @@ static llama_control_vector_data llama_control_vector_load_one(const llama_contr }; struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params); if (!ctx_gguf) { - fprintf(stderr, "%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str()); + LOG_ERR("%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str()); return result; } int32_t n_tensors = gguf_get_n_tensors(ctx_gguf); if (n_tensors == 0) { - fprintf(stderr, "%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str()); + LOG_WRN("%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str()); } for (int i = 0; i < n_tensors; i++) { @@ -3531,23 +2098,23 @@ static llama_control_vector_data llama_control_vector_load_one(const llama_contr } } if (layer_idx < 0) { - fprintf(stderr, "%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str()); + LOG_ERR("%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str()); result.n_embd = -1; break; } else if (layer_idx == 0) { - fprintf(stderr, "%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str()); + LOG_ERR("%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str()); result.n_embd = -1; break; } struct ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str()); if (tensor->type != GGML_TYPE_F32) { - fprintf(stderr, "%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str()); + LOG_ERR("%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str()); result.n_embd = -1; break; } if (ggml_n_dims(tensor) != 1) { - fprintf(stderr, "%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str()); + LOG_ERR("%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str()); result.n_embd = -1; break; } @@ -3555,7 +2122,7 @@ static llama_control_vector_data llama_control_vector_load_one(const llama_contr if (result.n_embd == -1) { result.n_embd = ggml_nelements(tensor); } else if (ggml_nelements(tensor) != result.n_embd) { - fprintf(stderr, "%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str()); + LOG_ERR("%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str()); result.n_embd = -1; break; } @@ -3572,7 +2139,7 @@ static llama_control_vector_data llama_control_vector_load_one(const llama_contr } if (result.n_embd == -1) { - fprintf(stderr, "%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str()); + LOG_WRN("%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str()); result.data.clear(); } @@ -3582,18 +2149,18 @@ static llama_control_vector_data llama_control_vector_load_one(const llama_contr return result; } -llama_control_vector_data llama_control_vector_load(const std::vector & load_infos) { - llama_control_vector_data result = { -1, {} }; +common_control_vector_data common_control_vector_load(const std::vector & load_infos) { + common_control_vector_data result = { -1, {} }; for (const auto & info : load_infos) { - auto cur = llama_control_vector_load_one(info); + auto cur = common_control_vector_load_one(info); if (cur.n_embd == -1) { result.n_embd = -1; break; } if (result.n_embd != -1 && result.n_embd != cur.n_embd) { - fprintf(stderr, "%s: control vectors in %s does not match previous dimensions\n", __func__, info.fname.c_str()); + LOG_ERR("%s: control vectors in %s does not match previous dimensions\n", __func__, info.fname.c_str()); result.n_embd = -1; break; } @@ -3609,217 +2176,10 @@ llama_control_vector_data llama_control_vector_load(const std::vector & data) { - if (data.empty()) { - fprintf(stream, "%s:\n", prop_name); - return; - } - - fprintf(stream, "%s: [", prop_name); - for (size_t i = 0; i < data.size() - 1; ++i) { - fprintf(stream, "%e, ", data[i]); - } - fprintf(stream, "%e]\n", data.back()); -} - -void yaml_dump_vector_int(FILE * stream, const char * prop_name, const std::vector & data) { - if (data.empty()) { - fprintf(stream, "%s:\n", prop_name); - return; - } - - fprintf(stream, "%s: [", prop_name); - for (size_t i = 0; i < data.size() - 1; ++i) { - fprintf(stream, "%d, ", data[i]); - } - fprintf(stream, "%d]\n", data.back()); -} - -void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const char * data) { - std::string data_str(data == NULL ? "" : data); - - if (data_str.empty()) { - fprintf(stream, "%s:\n", prop_name); - return; - } - - size_t pos_start = 0; - size_t pos_found = 0; - - if (std::isspace(data_str[0]) || std::isspace(data_str.back())) { - data_str = std::regex_replace(data_str, std::regex("\n"), "\\n"); - data_str = std::regex_replace(data_str, std::regex("\""), "\\\""); - data_str = std::regex_replace(data_str, std::regex(R"(\\[^n"])"), R"(\$&)"); - data_str = "\"" + data_str + "\""; - fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); - return; - } - - if (data_str.find('\n') == std::string::npos) { - fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); - return; - } - - fprintf(stream, "%s: |\n", prop_name); - while ((pos_found = data_str.find('\n', pos_start)) != std::string::npos) { - fprintf(stream, " %s\n", data_str.substr(pos_start, pos_found-pos_start).c_str()); - pos_start = pos_found + 1; - } -} - -void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx, - const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc) { - const auto & sparams = params.sparams; - - fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT); - fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER); - fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false"); - fprintf(stream, "cpu_has_avx: %s\n", ggml_cpu_has_avx() ? "true" : "false"); - fprintf(stream, "cpu_has_avx_vnni: %s\n", ggml_cpu_has_avx_vnni() ? "true" : "false"); - fprintf(stream, "cpu_has_avx2: %s\n", ggml_cpu_has_avx2() ? "true" : "false"); - fprintf(stream, "cpu_has_avx512: %s\n", ggml_cpu_has_avx512() ? "true" : "false"); - fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false"); - fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false"); - fprintf(stream, "cpu_has_cuda: %s\n", ggml_cpu_has_cuda() ? "true" : "false"); - fprintf(stream, "cpu_has_vulkan: %s\n", ggml_cpu_has_vulkan() ? "true" : "false"); - fprintf(stream, "cpu_has_kompute: %s\n", ggml_cpu_has_kompute() ? "true" : "false"); - fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false"); - fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false"); - fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false"); - fprintf(stream, "cpu_has_sve: %s\n", ggml_cpu_has_sve() ? "true" : "false"); - fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false"); - fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false"); - fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false"); - fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false"); - fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false"); - fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? "true" : "false"); - fprintf(stream, "cpu_has_matmul_int8: %s\n", ggml_cpu_has_matmul_int8() ? "true" : "false"); - -#ifdef NDEBUG - fprintf(stream, "debug: false\n"); -#else - fprintf(stream, "debug: true\n"); -#endif // NDEBUG - - fprintf(stream, "model_desc: %s\n", model_desc); - fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx))); - -#ifdef __OPTIMIZE__ - fprintf(stream, "optimize: true\n"); -#else - fprintf(stream, "optimize: false\n"); -#endif // __OPTIMIZE__ - - fprintf(stream, "time: %s\n", timestamp.c_str()); - - fprintf(stream, "\n"); - fprintf(stream, "###############\n"); - fprintf(stream, "# User Inputs #\n"); - fprintf(stream, "###############\n"); - fprintf(stream, "\n"); - - fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str()); - fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch); - fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks); - fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); - fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); - fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); - fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); - fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq); - yaml_dump_string_multiline(stream, "grammar", sparams.grammar.c_str()); - fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n"); - fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false"); - fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); - fprintf(stream, "ignore_eos: %s # default: false\n", sparams.ignore_eos ? "true" : "false"); - - yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str()); - fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false"); - yaml_dump_string_multiline(stream, "in_suffix", params.input_prefix.c_str()); - fprintf(stream, "interactive: %s # default: false\n", params.interactive ? "true" : "false"); - fprintf(stream, "interactive_first: %s # default: false\n", params.interactive_first ? "true" : "false"); - fprintf(stream, "keep: %d # default: 0\n", params.n_keep); - fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str()); - - fprintf(stream, "logit_bias:\n"); - for (const auto & logit_bias : sparams.logit_bias) { - fprintf(stream, " %d: %f", logit_bias.token, logit_bias.bias); - } - - fprintf(stream, "lora:\n"); - for (auto & la : params.lora_adapters) { - if (la.scale == 1.0f) { - fprintf(stream, " - %s\n", la.path.c_str()); - } - } - fprintf(stream, "lora_scaled:\n"); - for (auto & la : params.lora_adapters) { - if (la.scale != 1.0f) { - fprintf(stream, " - %s: %f\n", la.path.c_str(), la.scale); - } - } - fprintf(stream, "lora_init_without_apply: %s # default: false\n", params.lora_init_without_apply ? "true" : "false"); - fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); - fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep); - fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat); - fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau); - fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta); - fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false"); - fprintf(stream, "model: %s # default: %s\n", params.model.c_str(), DEFAULT_MODEL_PATH); - fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str()); - fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false"); - fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers); - fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict); - fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs); - fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false"); - fprintf(stream, "penalize_nl: %s # default: false\n", sparams.penalize_nl ? "true" : "false"); - fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); - fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); - fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present); - yaml_dump_string_multiline(stream, "prompt", params.prompt.c_str()); - fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str()); - fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false"); - fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false"); - yaml_dump_vector_int(stream, "prompt_tokens", prompt_tokens); - fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat); - - fprintf(stream, "reverse_prompt:\n"); - for (std::string ap : params.antiprompt) { - size_t pos = 0; - while ((pos = ap.find('\n', pos)) != std::string::npos) { - ap.replace(pos, 1, "\\n"); - pos += 1; - } - - fprintf(stream, " - %s\n", ap.c_str()); - } - - fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base); - fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale); - fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); - fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); - fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); - fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); - - const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); - yaml_dump_vector_float(stream, "tensor_split", tensor_split_vector); - - fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z); - fprintf(stream, "threads: %d # default: %u\n", params.cpuparams.n_threads, std::thread::hardware_concurrency()); - fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); - fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); - fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); - fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p); - fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); - fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false"); -} diff --git a/common/common.h b/common/common.h index d10ec6235..1b9920689 100644 --- a/common/common.h +++ b/common/common.h @@ -2,22 +2,11 @@ #pragma once -#include "llama.h" +#include "llama-cpp.h" -#include "sampling.h" - -#define LOG_NO_FILE_LINE_FUNCTION -#include "log.h" - -#include #include #include -#include -#include -#include -#include -#include -#include +#include #ifdef _WIN32 #define DIRECTORY_SEPARATOR '\\' @@ -35,32 +24,41 @@ #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" -struct llama_lora_adapter_info { +struct common_adapter_lora_info { std::string path; float scale; + + struct llama_adapter_lora * ptr; }; -struct llama_lora_adapter_container : llama_lora_adapter_info { - struct llama_lora_adapter * adapter; -}; +using llama_tokens = std::vector; // build info extern int LLAMA_BUILD_NUMBER; -extern char const * LLAMA_COMMIT; -extern char const * LLAMA_COMPILER; -extern char const * LLAMA_BUILD_TARGET; +extern const char * LLAMA_COMMIT; +extern const char * LLAMA_COMPILER; +extern const char * LLAMA_BUILD_TARGET; -struct llama_control_vector_load_info; +struct common_control_vector_load_info; // // CPU utils // +struct cpu_params { + int n_threads = -1; + bool cpumask[GGML_MAX_N_THREADS] = {false}; // CPU affinity mask. + bool mask_valid = false; // Default: any CPU + enum ggml_sched_priority priority = GGML_SCHED_PRIO_NORMAL; // Scheduling prio : (0 - normal, 1 - medium, 2 - high, 3 - realtime) + bool strict_cpu = false; // Use strict CPU placement + uint32_t poll = 50; // Polling (busywait) level (0 - no polling, 100 - mostly polling) +}; + int32_t cpu_get_num_physical_cores(); int32_t cpu_get_num_math(); // -// CLI argument parsing +// Common params // enum llama_example { @@ -78,42 +76,139 @@ enum llama_example { LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_LLAVA, + LLAMA_EXAMPLE_LOOKUP, + LLAMA_EXAMPLE_PARALLEL, + LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_COUNT, }; +enum common_sampler_type { + COMMON_SAMPLER_TYPE_NONE = 0, + COMMON_SAMPLER_TYPE_DRY = 1, + COMMON_SAMPLER_TYPE_TOP_K = 2, + COMMON_SAMPLER_TYPE_TOP_P = 3, + COMMON_SAMPLER_TYPE_MIN_P = 4, + //COMMON_SAMPLER_TYPE_TFS_Z = 5, + COMMON_SAMPLER_TYPE_TYPICAL_P = 6, + COMMON_SAMPLER_TYPE_TEMPERATURE = 7, + COMMON_SAMPLER_TYPE_XTC = 8, + COMMON_SAMPLER_TYPE_INFILL = 9, + COMMON_SAMPLER_TYPE_PENALTIES = 10, +}; + // dimensionality reduction methods, used by cvector-generator enum dimre_method { DIMRE_METHOD_PCA, DIMRE_METHOD_MEAN, }; -struct cpu_params { - int n_threads = -1; - bool cpumask[GGML_MAX_N_THREADS] = {false}; // CPU affinity mask. - bool mask_valid = false; // Default: any CPU - enum ggml_sched_priority priority = GGML_SCHED_PRIO_NORMAL; // Scheduling prio : (0 - normal, 1 - medium, 2 - high, 3 - realtime) - bool strict_cpu = false; // Use strict CPU placement - uint32_t poll = 50; // Polling (busywait) level (0 - no polling, 100 - mostly polling) +enum common_conversation_mode { + COMMON_CONVERSATION_MODE_DISABLED = 0, + COMMON_CONVERSATION_MODE_ENABLED = 1, + COMMON_CONVERSATION_MODE_AUTO = 2, }; -struct gpt_params { - enum llama_example curr_ex = LLAMA_EXAMPLE_COMMON; +struct common_grammar_trigger { + std::string word; + bool at_start; +}; +// sampling parameters +struct common_params_sampling { + uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler + + int32_t n_prev = 64; // number of previous tokens to remember + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. + int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens + int32_t top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float min_p = 0.05f; // 0.0 = disabled + float xtc_probability = 0.00f; // 0.0 = disabled + float xtc_threshold = 0.10f; // > 0.5 disables XTC + float typ_p = 1.00f; // typical_p, 1.0 = disabled + float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities + float dynatemp_range = 0.00f; // 0.0 = disabled + float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler + int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat = 1.00f; // 1.0 = disabled + float penalty_freq = 0.00f; // 0.0 = disabled + float penalty_present = 0.00f; // 0.0 = disabled + float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition: + float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length) + int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty + int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate + bool ignore_eos = false; + bool no_perf = false; // disable performance metrics + bool timing_per_token = false; + + std::vector dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY + + + std::vector samplers = { + COMMON_SAMPLER_TYPE_PENALTIES, + COMMON_SAMPLER_TYPE_DRY, + COMMON_SAMPLER_TYPE_TOP_K, + COMMON_SAMPLER_TYPE_TYPICAL_P, + COMMON_SAMPLER_TYPE_TOP_P, + COMMON_SAMPLER_TYPE_MIN_P, + COMMON_SAMPLER_TYPE_XTC, + COMMON_SAMPLER_TYPE_TEMPERATURE, + }; + + std::string grammar; // optional BNF-like grammar to constrain sampling + bool grammar_lazy = false; + std::vector grammar_trigger_words; // optional trigger words to trigger lazy grammar + std::vector grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens. + + std::vector logit_bias; // logit biases to apply + + // print the parameters into a string + std::string print() const; +}; + +struct common_params_speculative { + std::vector devices; // devices to use for offloading + + int32_t n_ctx = 0; // draft context size + int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding + int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding + int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) + float p_split = 0.1f; // speculative decoding split probability + float p_min = 0.9f; // minimum speculative decoding probability (greedy) + + struct cpu_params cpuparams; + struct cpu_params cpuparams_batch; + + std::string hf_repo = ""; // HF repo // NOLINT + std::string hf_file = ""; // HF file // NOLINT + + std::string model = ""; // draft model for speculative decoding // NOLINT + std::string model_url = ""; // model url to download // NOLINT +}; + +struct common_params_vocoder { + std::string hf_repo = ""; // HF repo // NOLINT + std::string hf_file = ""; // HF file // NOLINT + + std::string model = ""; // model path // NOLINT + std::string model_url = ""; // model url to download // NOLINT + + bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT +}; + +struct common_params { int32_t n_predict = -1; // new tokens to predict - int32_t n_ctx = 0; // context size + int32_t n_ctx = 4096; // context size int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS) int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_draft = 5; // number of tokens to draft during speculative decoding int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) int32_t n_parallel = 1; // number of parallel sequences to decode int32_t n_sequences = 1; // number of sequences to decode - float p_split = 0.1f; // speculative decoding split probability - int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) - int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) - int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors - float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs int32_t grp_attn_n = 1; // group-attention factor int32_t grp_attn_w = 512; // group-attention width int32_t n_print = -1; // print token count every n tokens (-1 = disabled) @@ -124,51 +219,56 @@ struct gpt_params { float yarn_beta_fast = 32.0f; // YaRN low correction dim float yarn_beta_slow = 1.0f; // YaRN high correction dim int32_t yarn_orig_ctx = 0; // YaRN original context length - float defrag_thold = -1.0f; // KV cache defragmentation threshold + float defrag_thold = 0.1f; // KV cache defragmentation threshold + + // offload params + std::vector devices; // devices to use for offloading + + int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) + int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors + float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs + + enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs struct cpu_params cpuparams; struct cpu_params cpuparams_batch; - struct cpu_params draft_cpuparams; - struct cpu_params draft_cpuparams_batch; ggml_backend_sched_eval_callback cb_eval = nullptr; void * cb_eval_user_data = nullptr; ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED; - enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings - struct gpt_sampler_params sparams; + struct common_params_sampling sampling; + struct common_params_speculative speculative; + struct common_params_vocoder vocoder; - std::string model = ""; // model path - std::string model_draft = ""; // draft model for speculative decoding - std::string model_alias = "unknown"; // model alias - std::string model_url = ""; // model url to download - std::string hf_token = ""; // HF token - std::string hf_repo = ""; // HF repo - std::string hf_file = ""; // HF file - std::string prompt = ""; - std::string prompt_file = ""; // store the external prompt file name - std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state - std::string input_prefix = ""; // string to prefix user inputs with - std::string input_suffix = ""; // string to suffix user inputs with - std::string logdir = ""; // directory in which to save YAML log files - std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding - std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding - std::string logits_file = ""; // file for saving *all* logits - std::string rpc_servers = ""; // comma separated list of RPC servers + std::string model = ""; // model path // NOLINT + std::string model_alias = ""; // model alias // NOLINT + std::string model_url = ""; // model url to download // NOLINT + std::string hf_token = ""; // HF token // NOLINT + std::string hf_repo = ""; // HF repo // NOLINT + std::string hf_file = ""; // HF file // NOLINT + std::string prompt = ""; // NOLINT + std::string prompt_file = ""; // store the external prompt file name // NOLINT + std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT + std::string input_prefix = ""; // string to prefix user inputs with // NOLINT + std::string input_suffix = ""; // string to suffix user inputs with // NOLINT + std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT + std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT + std::string logits_file = ""; // file for saving *all* logits // NOLINT std::vector in_files; // all input files std::vector antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector kv_overrides; - bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply) - std::vector lora_adapters; // lora adapter path with user defined scale + bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply) + std::vector lora_adapters; // lora adapter path with user defined scale - std::vector control_vectors; // control vector with user defined scale + std::vector control_vectors; // control vector with user defined scale int32_t verbosity = 0; int32_t control_vector_layer_start = -1; // layer range for control vector @@ -189,13 +289,11 @@ struct gpt_params { bool kl_divergence = false; // compute KL divergence - std::function print_usage = nullptr; // print example-specific usage and example bool usage = false; // print usage bool use_color = false; // use color to distinguish generations and inputs bool special = false; // enable special token output bool interactive = false; // interactive mode bool interactive_first = false; // wait for user input immediately - bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix) bool prompt_cache_all = false; // save user input and generations to prompt cache bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it @@ -204,6 +302,8 @@ struct gpt_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention + bool no_perf = false; // disable performance metrics + bool ctx_shift = true; // context shift on inifinite text generation bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool logits_all = false; // return logits for all tokens in the batch @@ -211,43 +311,49 @@ struct gpt_params { bool use_mlock = false; // use mlock to keep model in memory bool verbose_prompt = false; // print prompt tokens before generation bool display_prompt = true; // print prompt before generation - bool infill = false; // use infill mode bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes bool no_kv_offload = false; // disable KV offloading bool warmup = true; // warmup run bool check_tensors = false; // validate tensor data - std::string cache_type_k = "f16"; // KV cache data type for the K - std::string cache_type_v = "f16"; // KV cache data type for the V + ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K + ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V + + common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO; // multimodal models (see examples/llava) - std::string mmproj = ""; // path to multimodal projector + std::string mmproj = ""; // path to multimodal projector // NOLINT std::vector image; // path to image file(s) // embedding bool embedding = false; // get only sentence embedding - int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) + int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix - std::string embd_sep = "\n"; // separator of embendings + std::string embd_sep = "\n"; // separator of embeddings + bool reranking = false; // enable reranking support on server // server params int32_t port = 8080; // server listens on this network port int32_t timeout_read = 600; // http read timeout in seconds int32_t timeout_write = timeout_read; // http write timeout in seconds - int n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) + int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) + int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting std::string hostname = "127.0.0.1"; - std::string public_path = ""; - std::string chat_template = ""; - std::string system_prompt = ""; + std::string public_path = ""; // NOLINT + std::string chat_template = ""; // NOLINT + bool use_jinja = false; // NOLINT bool enable_chat_template = true; std::vector api_keys; - std::string ssl_file_key = ""; - std::string ssl_file_cert = ""; + std::string ssl_file_key = ""; // NOLINT + std::string ssl_file_cert = ""; // NOLINT - bool endpoint_slots = true; + // "advanced" endpoints are disabled by default for better security + bool webui = true; + bool endpoint_slots = false; + bool endpoint_props = false; // only control POST requests, not GET bool endpoint_metrics = false; bool log_json = false; @@ -300,112 +406,46 @@ struct gpt_params { bool batched_bench_output_jsonl = false; }; -struct llama_arg { - std::set examples = {LLAMA_EXAMPLE_COMMON}; - std::vector args; - const char * value_hint = nullptr; // help text or example for arg value - const char * value_hint_2 = nullptr; // for second arg value - const char * env = nullptr; - std::string help; - void (*handler_void) (gpt_params & params) = nullptr; - void (*handler_string) (gpt_params & params, const std::string &) = nullptr; - void (*handler_str_str)(gpt_params & params, const std::string &, const std::string &) = nullptr; - void (*handler_int) (gpt_params & params, int) = nullptr; +// call once at the start of a program if it uses libcommon +// initializes the logging system and prints info about the build +void common_init(); - llama_arg( - const std::initializer_list & args, - const char * value_hint, - const std::string & help, - void (*handler)(gpt_params & params, const std::string &) - ) : args(args), value_hint(value_hint), help(help), handler_string(handler) {} +std::string common_params_get_system_info(const common_params & params); - llama_arg( - const std::initializer_list & args, - const char * value_hint, - const std::string & help, - void (*handler)(gpt_params & params, int) - ) : args(args), value_hint(value_hint), help(help), handler_int(handler) {} - - llama_arg( - const std::initializer_list & args, - const std::string & help, - void (*handler)(gpt_params & params) - ) : args(args), help(help), handler_void(handler) {} - - // support 2 values for arg - llama_arg( - const std::initializer_list & args, - const char * value_hint, - const char * value_hint_2, - const std::string & help, - void (*handler)(gpt_params & params, const std::string &, const std::string &) - ) : args(args), value_hint(value_hint), value_hint_2(value_hint_2), help(help), handler_str_str(handler) {} - - llama_arg & set_examples(std::initializer_list examples) { - this->examples = std::move(examples); - return *this; - } - - llama_arg & set_env(const char * env) { - help = help + "\n(env: " + env + ")"; - this->env = env; - return *this; - } - - bool in_example(enum llama_example ex) { - return examples.find(ex) != examples.end(); - } - - bool get_value_from_env(std::string & output) const { - if (env == nullptr) return false; - char * value = std::getenv(env); - if (value) { - output = value; - return true; - } - return false; - } - - bool has_value_from_env() const { - return env != nullptr && std::getenv(env); - } - - std::string to_string(); -}; - -// initialize list of options (arguments) that can be used by the current example -std::vector gpt_params_parser_init(gpt_params & params, llama_example ex); -// optionally, we can provide "print_usage" to print example usage -std::vector gpt_params_parser_init(gpt_params & params, llama_example ex, std::function print_usage); - -// parse input arguments from CLI -// if one argument has invalid value, it will automatically display usage of the specific argument (and not the full usage message) -bool gpt_params_parse (int argc, char ** argv, gpt_params & params, std::vector & options); -bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params, std::vector & options); - -// print full usage message; it will be called internally by gpt_params_parse() if "-h" is set -void gpt_params_print_usage(gpt_params & params, std::vector & options); - -std::string gpt_params_get_system_info(const gpt_params & params); - -bool parse_cpu_range(const std::string& range, bool(&boolmask)[GGML_MAX_N_THREADS]); -bool parse_cpu_mask(const std::string& mask, bool(&boolmask)[GGML_MAX_N_THREADS]); -void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model = nullptr); +bool parse_cpu_range(const std::string & range, bool(&boolmask)[GGML_MAX_N_THREADS]); +bool parse_cpu_mask(const std::string & mask, bool(&boolmask)[GGML_MAX_N_THREADS]); +void postprocess_cpu_params(cpu_params & cpuparams, const cpu_params * role_model = nullptr); bool set_process_priority(enum ggml_sched_priority prio); // // String utils // -std::vector string_split(std::string input, char separator); +#ifdef __GNUC__ +#ifdef __MINGW32__ +#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif +#else +#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) +#endif + +LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2) +std::string string_format(const char * fmt, ...); std::string string_strip(const std::string & str); std::string string_get_sortable_timestamp(); +std::string string_join(const std::vector & values, const std::string & separator); +std::vector string_split(const std::string & str, const std::string & delimiter); +std::string string_repeat(const std::string & str, size_t n); + void string_replace_all(std::string & s, const std::string & search, const std::string & replace); template static std::vector string_split(const std::string & str, char delim) { + static_assert(!std::is_same::value, "Please use the specialized version for std::string"); std::vector values; std::istringstream str_stream(str); std::string token; @@ -418,9 +458,40 @@ static std::vector string_split(const std::string & str, char delim) { return values; } +template<> +std::vector string_split(const std::string & input, char separator) +{ + std::vector parts; + size_t begin_pos = 0; + size_t separator_pos = input.find(separator); + while (separator_pos != std::string::npos) { + std::string part = input.substr(begin_pos, separator_pos - begin_pos); + parts.emplace_back(part); + begin_pos = separator_pos + 1; + separator_pos = input.find(separator, begin_pos); + } + parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos)); + return parts; +} + +static bool string_starts_with(const std::string & str, + const std::string & prefix) { // While we wait for C++20's std::string::starts_with... + return str.rfind(prefix, 0) == 0; +} + +static bool string_ends_with(const std::string & str, + const std::string & suffix) { // While we wait for C++20's std::string::ends_with... + return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; +} + bool string_parse_kv_override(const char * data, std::vector & overrides); void string_process_escapes(std::string & input); +std::string string_from(bool value); +std::string string_from(const std::vector & values); +std::string string_from(const struct llama_context * ctx, const std::vector & tokens); +std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch); + // // Filesystem utils // @@ -435,65 +506,103 @@ std::string fs_get_cache_file(const std::string & filename); // Model utils // -struct llama_init_result { - struct llama_model * model = nullptr; - struct llama_context * context = nullptr; - std::vector lora_adapters; +// note: defines object's lifetime +struct common_init_result { + llama_model_ptr model; + llama_context_ptr context; + + std::vector lora; }; -struct llama_init_result llama_init_from_gpt_params(gpt_params & params); +struct common_init_result common_init_from_params(common_params & params); -struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params); -struct llama_context_params llama_context_params_from_gpt_params (const gpt_params & params); +struct llama_model_params common_model_params_to_llama ( common_params & params); +struct llama_context_params common_context_params_to_llama(const common_params & params); struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params); -struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params); -struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params); +struct llama_model * common_load_model_from_url( + const std::string & model_url, + const std::string & local_path, + const std::string & hf_token, + const struct llama_model_params & params); + +struct llama_model * common_load_model_from_hf( + const std::string & repo, + const std::string & remote_path, + const std::string & local_path, + const std::string & hf_token, + const struct llama_model_params & params); + +std::pair common_get_hf_file( + const std::string & hf_repo_with_tag, + const std::string & hf_token); // clear LoRA adapters from context, then apply new list of adapters -void llama_lora_adapters_apply(struct llama_context * ctx, std::vector & lora_adapters); +void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora); +// // Batch utils +// -void llama_batch_clear(struct llama_batch & batch); +void common_batch_clear(struct llama_batch & batch); -void llama_batch_add( +void common_batch_add( struct llama_batch & batch, llama_token id, llama_pos pos, const std::vector & seq_ids, bool logits); +// +// Token utils +// + +// longest common prefix +size_t common_lcp(const llama_tokens & a, const llama_tokens & b); + +// longet common subsequence +size_t common_lcs(const llama_tokens & a, const llama_tokens & b); + // // Vocab utils // // tokenizes a string into a vector of tokens // should work similar to Python's `tokenizer.encode` -std::vector llama_tokenize( +std::vector common_tokenize( const struct llama_context * ctx, const std::string & text, bool add_special, bool parse_special = false); -std::vector llama_tokenize( - const struct llama_model * model, +std::vector common_tokenize( + const struct llama_vocab * vocab, const std::string & text, bool add_special, bool parse_special = false); // tokenizes a token into a piece, optionally renders special/control tokens // should work similar to Python's `tokenizer.id_to_piece` -std::string llama_token_to_piece( +std::string common_token_to_piece( const struct llama_context * ctx, llama_token token, bool special = true); +std::string common_token_to_piece( + const struct llama_vocab * vocab, + llama_token token, + bool special = true); + // detokenizes a vector of tokens into a string // should work similar to Python's `tokenizer.decode` // optionally renders special/control tokens -std::string llama_detokenize( - llama_context * ctx, +std::string common_detokenize( + const struct llama_context * ctx, + const std::vector & tokens, + bool special = true); + +std::string common_detokenize( + const struct llama_vocab * vocab, const std::vector & tokens, bool special = true); @@ -501,64 +610,88 @@ std::string llama_detokenize( // Chat template utils // +struct common_tool_call { + std::string name; + std::string arguments; + std::string id; +}; + // same with llama_chat_message, but uses std::string -struct llama_chat_msg { +struct common_chat_msg { std::string role; std::string content; + std::vector tool_calls; }; // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid -bool llama_chat_verify_template(const std::string & tmpl); +bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); + +namespace minja { + class chat_template; +} + +typedef minja::chat_template common_chat_template; + +struct common_chat_templates { + bool has_explicit_template; // Model had builtin template or template overridde was specified. + std::unique_ptr template_default; // always set (defaults to chatml) + std::unique_ptr template_tool_use; +}; // CPP wrapper for llama_chat_apply_template // If the built-in template is not supported, we default to chatml // If the custom "tmpl" is not supported, we throw an error -std::string llama_chat_apply_template(const struct llama_model * model, - const std::string & tmpl, - const std::vector & chat, - bool add_ass); +std::string common_chat_apply_template( + const common_chat_template & tmpl, + const std::vector & chat, + bool add_ass, + bool use_jinja); // Format single message, while taking into account the position of that message in chat history -std::string llama_chat_format_single(const struct llama_model * model, - const std::string & tmpl, - const std::vector & past_msg, - const llama_chat_msg & new_msg, - bool add_ass); +std::string common_chat_format_single( + const common_chat_template & tmpl, + const std::vector & past_msg, + const common_chat_msg & new_msg, + bool add_ass, + bool use_jinja); // Returns an example of formatted chat -std::string llama_chat_format_example(const struct llama_model * model, - const std::string & tmpl); +std::string common_chat_format_example( + const common_chat_template & tmpl, bool use_jinja); + +common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override); // // KV cache utils // // Dump the KV cache view with the number of sequences per cell. -void llama_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80); +void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80); // Dump the KV cache view showing individual sequences in each cell (long output). -void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40); +void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40); // // Embedding utils // -void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2); +// TODO: repace embd_norm with an enum +void common_embd_normalize(const float * inp, float * out, int n, int embd_norm); -float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n); +float common_embd_similarity_cos(const float * embd1, const float * embd2, int n); // // Control vector utils // -struct llama_control_vector_data { +struct common_control_vector_data { int n_embd; // stores data for layers [1, n_layer] where n_layer = data.size() / n_embd std::vector data; }; -struct llama_control_vector_load_info { +struct common_control_vector_load_info { float strength; std::string fname; @@ -566,24 +699,16 @@ struct llama_control_vector_load_info { // Load control vectors, scale each by strength, and add them together. // On error, returns {-1, empty} -llama_control_vector_data llama_control_vector_load(const std::vector & load_infos); +common_control_vector_data common_control_vector_load(const std::vector & load_infos); // // Split utils // -static const char * const LLM_KV_SPLIT_NO = "split.no"; -static const char * const LLM_KV_SPLIT_COUNT = "split.count"; -static const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"; +namespace { -// -// YAML utils -// +const char * const LLM_KV_SPLIT_NO = "split.no"; +const char * const LLM_KV_SPLIT_COUNT = "split.count"; +const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"; -void yaml_dump_vector_float (FILE * stream, const char * prop_name, const std::vector & data); -void yaml_dump_vector_int (FILE * stream, const char * prop_name, const std::vector & data); -void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const char * data); - -void yaml_dump_non_result_info( - FILE * stream, const gpt_params & params, const llama_context * lctx, - const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc); +} diff --git a/common/console.cpp b/common/console.cpp index f65cbc6ed..078a8d678 100644 --- a/common/console.cpp +++ b/common/console.cpp @@ -94,6 +94,9 @@ namespace console { simple_io = true; } } + if (simple_io) { + _setmode(_fileno(stdin), _O_U8TEXT); + } #else // POSIX-specific console initialization if (!simple_io) { diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 881eb49e3..1f47e313e 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -1,4 +1,6 @@ #include "json-schema-to-grammar.h" +#include "common.h" + #include #include #include @@ -11,11 +13,6 @@ using json = nlohmann::ordered_json; -template -static std::string join(Iterator begin, Iterator end, const std::string & separator); - -static std::string repeat(const std::string & str, size_t n); - static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") { auto has_max = max_items != std::numeric_limits::max(); @@ -128,8 +125,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & if (sub_len > 0) { auto from_sub = from.substr(i + 1); auto to_sub = to.substr(i + 1); - auto sub_zeros = repeat("0", sub_len); - auto sub_nines = repeat("9", sub_len); + auto sub_zeros = string_repeat("0", sub_len); + auto sub_nines = string_repeat("9", sub_len); auto to_reached = false; out << "("; @@ -188,8 +185,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & auto max_digits = max_s.length(); for (auto digits = min_digits; digits < max_digits; digits++) { - uniform_range(min_s, repeat("9", digits)); - min_s = "1" + repeat("0", digits); + uniform_range(min_s, string_repeat("9", digits)); + min_s = "1" + string_repeat("0", digits); out << " | "; } uniform_range(min_s, max_s); @@ -318,49 +315,6 @@ std::unordered_map GRAMMAR_LITERAL_ESCAPES = { std::unordered_set NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; std::unordered_set ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'}; -template -std::string join(Iterator begin, Iterator end, const std::string & separator) { - std::ostringstream result; - if (begin != end) { - result << *begin; - for (Iterator it = begin + 1; it != end; ++it) { - result << separator << *it; - } - } - return result.str(); -} - -static std::vector split(const std::string & str, const std::string & delimiter) { - std::vector tokens; - size_t start = 0; - size_t end = str.find(delimiter); - - while (end != std::string::npos) { - tokens.push_back(str.substr(start, end - start)); - start = end + delimiter.length(); - end = str.find(delimiter, start); - } - - tokens.push_back(str.substr(start)); - - return tokens; -} - -static std::string repeat(const std::string & str, size_t n) { - if (n == 0) { - return ""; - } - - std::string result; - result.reserve(str.length() * n); - - for (size_t i = 0; i < n; ++i) { - result += str; - } - - return result; -} - static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function & replacement) { std::smatch match; std::string result; @@ -389,6 +343,7 @@ static std::string format_literal(const std::string & literal) { class SchemaConverter { private: + friend std::string build_grammar(const std::function & cb, const common_grammar_options & options); std::function _fetch_json; bool _dotall; std::map _rules; @@ -418,7 +373,7 @@ private: for (size_t i = 0; i < alt_schemas.size(); i++) { rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i))); } - return join(rules.begin(), rules.end(), " | "); + return string_join(rules, " | "); } std::string _visit_pattern(const std::string & pattern, const std::string & name) { @@ -481,7 +436,7 @@ private: for (const auto & item : ret) { results.push_back(to_rule(item)); } - return std::make_pair(join(results.begin(), results.end(), " "), false); + return std::make_pair(string_join(results, " "), false); }; while (i < length) { @@ -539,7 +494,7 @@ private: } curly_brackets += '}'; i++; - auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ","); + auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ","); int min_times = 0; int max_times = std::numeric_limits::max(); try { @@ -611,7 +566,7 @@ private: } return join_seq(); }; - return _add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space"); + return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space"); } /* @@ -809,10 +764,11 @@ private: public: SchemaConverter( const std::function & fetch_json, - bool dotall) + bool dotall, + bool compact_spaces) : _fetch_json(fetch_json), _dotall(dotall) { - _rules["space"] = SPACE_RULE; + _rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE; } void resolve_refs(json & schema, const std::string & url) { @@ -854,7 +810,7 @@ public: return; } std::string pointer = ref.substr(ref.find('#') + 1); - std::vector tokens = split(pointer, "/"); + std::vector tokens = string_split(pointer, "/"); for (size_t i = 1; i < tokens.size(); ++i) { std::string sel = tokens[i]; if (target.is_null() || !target.contains(sel)) { @@ -905,7 +861,7 @@ public: for (const auto & v : schema["enum"]) { enum_values.push_back(_generate_constant_rule(v)); } - return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space"); + return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space"); } else if ((schema_type.is_null() || schema_type == "object") && (schema.contains("properties") || (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) { @@ -1019,10 +975,10 @@ public: void check_errors() { if (!_errors.empty()) { - throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n")); + throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n")); } if (!_warnings.empty()) { - fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str()); + fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str()); } } @@ -1036,10 +992,27 @@ public: }; std::string json_schema_to_grammar(const json & schema) { - SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false); - auto copy = schema; - converter.resolve_refs(copy, "input"); - converter.visit(copy, ""); + return build_grammar([&](const common_grammar_builder & callbacks) { + auto copy = schema; + callbacks.resolve_refs(copy); + callbacks.add_schema("", copy); + }); +} + +std::string build_grammar(const std::function & cb, const common_grammar_options & options) { + SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces); + common_grammar_builder builder { + /* .add_rule = */ [&](const std::string & name, const std::string & rule) { + return converter._add_rule(name, rule); + }, + /* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) { + return converter.visit(schema, name == "root" ? "" : name); + }, + /* .resolve_refs = */ [&](nlohmann::ordered_json & schema) { + converter.resolve_refs(schema, ""); + } + }; + cb(builder); converter.check_errors(); return converter.format_grammar(); } diff --git a/common/json-schema-to-grammar.h b/common/json-schema-to-grammar.h index 41623b346..ba4112cb9 100644 --- a/common/json-schema-to-grammar.h +++ b/common/json-schema-to-grammar.h @@ -5,4 +5,17 @@ #define JSON_ASSERT GGML_ASSERT #include "json.hpp" -std::string json_schema_to_grammar(const nlohmann::ordered_json& schema); +std::string json_schema_to_grammar(const nlohmann::ordered_json & schema); + +struct common_grammar_builder { + std::function add_rule; + std::function add_schema; + std::function resolve_refs; +}; + +struct common_grammar_options { + bool dotall = false; + bool compact_spaces = false; +}; + +std::string build_grammar(const std::function & cb, const common_grammar_options & options = {}); diff --git a/common/log.cpp b/common/log.cpp new file mode 100644 index 000000000..04c7c0ed1 --- /dev/null +++ b/common/log.cpp @@ -0,0 +1,401 @@ +#include "log.h" + +#include +#include +#include +#include +#include +#include +#include + +int common_log_verbosity_thold = LOG_DEFAULT_LLAMA; + +void common_log_set_verbosity_thold(int verbosity) { + common_log_verbosity_thold = verbosity; +} + +#define LOG_COL_DEFAULT "\033[0m" +#define LOG_COL_BOLD "\033[1m" +#define LOG_COL_RED "\033[31m" +#define LOG_COL_GREEN "\033[32m" +#define LOG_COL_YELLOW "\033[33m" +#define LOG_COL_BLUE "\033[34m" +#define LOG_COL_MAGENTA "\033[35m" +#define LOG_COL_CYAN "\033[36m" +#define LOG_COL_WHITE "\033[37m" + +static int64_t t_us() { + return std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); +} + +// colors +enum common_log_col : int { + COMMON_LOG_COL_DEFAULT = 0, + COMMON_LOG_COL_BOLD, + COMMON_LOG_COL_RED, + COMMON_LOG_COL_GREEN, + COMMON_LOG_COL_YELLOW, + COMMON_LOG_COL_BLUE, + COMMON_LOG_COL_MAGENTA, + COMMON_LOG_COL_CYAN, + COMMON_LOG_COL_WHITE, +}; + +// disable colors by default +static std::vector g_col = { + "", + "", + "", + "", + "", + "", + "", + "", + "", +}; + +struct common_log_entry { + enum ggml_log_level level; + + bool prefix; + + int64_t timestamp; + + std::vector msg; + + // signals the worker thread to stop + bool is_end; + + void print(FILE * file = nullptr) const { + FILE * fcur = file; + if (!fcur) { + // stderr displays DBG messages only when their verbosity level is not higher than the threshold + // these messages will still be logged to a file + if (level == GGML_LOG_LEVEL_DEBUG && common_log_verbosity_thold < LOG_DEFAULT_DEBUG) { + return; + } + + fcur = stdout; + + if (level != GGML_LOG_LEVEL_NONE) { + fcur = stderr; + } + } + + if (level != GGML_LOG_LEVEL_NONE && level != GGML_LOG_LEVEL_CONT && prefix) { + if (timestamp) { + // [M.s.ms.us] + fprintf(fcur, "%s%d.%02d.%03d.%03d%s ", + g_col[COMMON_LOG_COL_BLUE], + (int) (timestamp / 1000000 / 60), + (int) (timestamp / 1000000 % 60), + (int) (timestamp / 1000 % 1000), + (int) (timestamp % 1000), + g_col[COMMON_LOG_COL_DEFAULT]); + } + + switch (level) { + case GGML_LOG_LEVEL_INFO: fprintf(fcur, "%sI %s", g_col[COMMON_LOG_COL_GREEN], g_col[COMMON_LOG_COL_DEFAULT]); break; + case GGML_LOG_LEVEL_WARN: fprintf(fcur, "%sW %s", g_col[COMMON_LOG_COL_MAGENTA], "" ); break; + case GGML_LOG_LEVEL_ERROR: fprintf(fcur, "%sE %s", g_col[COMMON_LOG_COL_RED], "" ); break; + case GGML_LOG_LEVEL_DEBUG: fprintf(fcur, "%sD %s", g_col[COMMON_LOG_COL_YELLOW], "" ); break; + default: + break; + } + } + + fprintf(fcur, "%s", msg.data()); + + if (level == GGML_LOG_LEVEL_WARN || level == GGML_LOG_LEVEL_ERROR || level == GGML_LOG_LEVEL_DEBUG) { + fprintf(fcur, "%s", g_col[COMMON_LOG_COL_DEFAULT]); + } + + fflush(fcur); + } +}; + +struct common_log { + // default capacity - will be expanded if needed + common_log() : common_log(256) {} + + common_log(size_t capacity) { + file = nullptr; + prefix = false; + timestamps = false; + running = false; + t_start = t_us(); + + // initial message size - will be expanded if longer messages arrive + entries.resize(capacity); + for (auto & entry : entries) { + entry.msg.resize(256); + } + + head = 0; + tail = 0; + + resume(); + } + + ~common_log() { + pause(); + if (file) { + fclose(file); + } + } + +private: + std::mutex mtx; + std::thread thrd; + std::condition_variable cv; + + FILE * file; + + bool prefix; + bool timestamps; + bool running; + + int64_t t_start; + + // ring buffer of entries + std::vector entries; + size_t head; + size_t tail; + + // worker thread copies into this + common_log_entry cur; + +public: + void add(enum ggml_log_level level, const char * fmt, va_list args) { + std::lock_guard lock(mtx); + + if (!running) { + // discard messages while the worker thread is paused + return; + } + + auto & entry = entries[tail]; + + { + // cannot use args twice, so make a copy in case we need to expand the buffer + va_list args_copy; + va_copy(args_copy, args); + +#if 1 + const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args); + if (n >= entry.msg.size()) { + entry.msg.resize(n + 1); + vsnprintf(entry.msg.data(), entry.msg.size(), fmt, args_copy); + } +#else + // hack for bolding arguments + + std::stringstream ss; + for (int i = 0; fmt[i] != 0; i++) { + if (fmt[i] == '%') { + ss << LOG_COL_BOLD; + while (fmt[i] != ' ' && fmt[i] != ')' && fmt[i] != ']' && fmt[i] != 0) ss << fmt[i++]; + ss << LOG_COL_DEFAULT; + if (fmt[i] == 0) break; + } + ss << fmt[i]; + } + const size_t n = vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args); + if (n >= entry.msg.size()) { + entry.msg.resize(n + 1); + vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args_copy); + } +#endif + } + + entry.level = level; + entry.prefix = prefix; + entry.timestamp = 0; + if (timestamps) { + entry.timestamp = t_us() - t_start; + } + entry.is_end = false; + + tail = (tail + 1) % entries.size(); + if (tail == head) { + // expand the buffer + std::vector new_entries(2*entries.size()); + + size_t new_tail = 0; + + do { + new_entries[new_tail] = std::move(entries[head]); + + head = (head + 1) % entries.size(); + new_tail = (new_tail + 1); + } while (head != tail); + + head = 0; + tail = new_tail; + + for (size_t i = tail; i < new_entries.size(); i++) { + new_entries[i].msg.resize(256); + } + + entries = std::move(new_entries); + } + + cv.notify_one(); + } + + void resume() { + std::lock_guard lock(mtx); + + if (running) { + return; + } + + running = true; + + thrd = std::thread([this]() { + while (true) { + { + std::unique_lock lock(mtx); + cv.wait(lock, [this]() { return head != tail; }); + + cur = entries[head]; + + head = (head + 1) % entries.size(); + } + + if (cur.is_end) { + break; + } + + cur.print(); // stdout and stderr + + if (file) { + cur.print(file); + } + } + }); + } + + void pause() { + { + std::lock_guard lock(mtx); + + if (!running) { + return; + } + + running = false; + + // push an entry to signal the worker thread to stop + { + auto & entry = entries[tail]; + entry.is_end = true; + + tail = (tail + 1) % entries.size(); + } + + cv.notify_one(); + } + + thrd.join(); + } + + void set_file(const char * path) { + pause(); + + if (file) { + fclose(file); + } + + if (path) { + file = fopen(path, "w"); + } else { + file = nullptr; + } + + resume(); + } + + void set_colors(bool colors) { + pause(); + + if (colors) { + g_col[COMMON_LOG_COL_DEFAULT] = LOG_COL_DEFAULT; + g_col[COMMON_LOG_COL_BOLD] = LOG_COL_BOLD; + g_col[COMMON_LOG_COL_RED] = LOG_COL_RED; + g_col[COMMON_LOG_COL_GREEN] = LOG_COL_GREEN; + g_col[COMMON_LOG_COL_YELLOW] = LOG_COL_YELLOW; + g_col[COMMON_LOG_COL_BLUE] = LOG_COL_BLUE; + g_col[COMMON_LOG_COL_MAGENTA] = LOG_COL_MAGENTA; + g_col[COMMON_LOG_COL_CYAN] = LOG_COL_CYAN; + g_col[COMMON_LOG_COL_WHITE] = LOG_COL_WHITE; + } else { + for (size_t i = 0; i < g_col.size(); i++) { + g_col[i] = ""; + } + } + + resume(); + } + + void set_prefix(bool prefix) { + std::lock_guard lock(mtx); + + this->prefix = prefix; + } + + void set_timestamps(bool timestamps) { + std::lock_guard lock(mtx); + + this->timestamps = timestamps; + } +}; + +// +// public API +// + +struct common_log * common_log_init() { + return new common_log; +} + +struct common_log * common_log_main() { + static struct common_log log; + + return &log; +} + +void common_log_pause(struct common_log * log) { + log->pause(); +} + +void common_log_resume(struct common_log * log) { + log->resume(); +} + +void common_log_free(struct common_log * log) { + delete log; +} + +void common_log_add(struct common_log * log, enum ggml_log_level level, const char * fmt, ...) { + va_list args; + va_start(args, fmt); + log->add(level, fmt, args); + va_end(args); +} + +void common_log_set_file(struct common_log * log, const char * file) { + log->set_file(file); +} + +void common_log_set_colors(struct common_log * log, bool colors) { + log->set_colors(colors); +} + +void common_log_set_prefix(struct common_log * log, bool prefix) { + log->set_prefix(prefix); +} + +void common_log_set_timestamps(struct common_log * log, bool timestamps) { + log->set_timestamps(timestamps); +} diff --git a/common/log.h b/common/log.h index 1bc5328ce..66605cc69 100644 --- a/common/log.h +++ b/common/log.h @@ -1,724 +1,92 @@ #pragma once -#include -#include -#include -#include -#include -#include -#include -#include +#include "ggml.h" // for ggml_log_level -// -------------------------------- -// -// Basic usage: -// -// -------- -// -// The LOG() and LOG_TEE() macros are ready to go by default -// they do not require any initialization. -// -// LOGLN() and LOG_TEELN() are variants which automatically -// include \n character at the end of the log string. -// -// LOG() behaves exactly like printf, by default writing to a logfile. -// LOG_TEE() additionally, prints to the screen too ( mimics Unix tee command ). -// -// Default logfile is named -// "llama..log" -// Default LOG_TEE() secondary output target is -// stderr -// -// Logs can be dynamically disabled or enabled using functions: -// log_disable() -// and -// log_enable() -// -// A log target can be changed with: -// log_set_target( string ) -// creating and opening, or re-opening a file by string filename -// or -// log_set_target( FILE* ) -// allowing to point at stderr, stdout, or any valid FILE* file handler. -// -// -------- -// -// End of Basic usage. -// -// -------------------------------- - -// Specifies a log target. -// default uses log_handler() with "llama.log" log file -// this can be changed, by defining LOG_TARGET -// like so: -// -// #define LOG_TARGET (a valid FILE*) -// #include "log.h" -// -// or it can be simply redirected to stdout or stderr -// like so: -// -// #define LOG_TARGET stderr -// #include "log.h" -// -// The log target can also be redirected to a different function -// like so: -// -// #define LOG_TARGET log_handler_different() -// #include "log.h" -// -// FILE* log_handler_different() -// { -// return stderr; -// } -// -// or: -// -// #define LOG_TARGET log_handler_another_one("somelog.log") -// #include "log.h" -// -// FILE* log_handler_another_one(char*filename) -// { -// static FILE* logfile = nullptr; -// (...) -// if( !logfile ) -// { -// fopen(...) -// } -// (...) -// return logfile -// } -// -#ifndef LOG_TARGET - #define LOG_TARGET log_handler() -#endif - -#ifndef LOG_TEE_TARGET - #define LOG_TEE_TARGET stderr -#endif - -// Utility for synchronizing log configuration state -// since std::optional was introduced only in c++17 -enum LogTriState -{ - LogTriStateSame, - LogTriStateFalse, - LogTriStateTrue -}; - -// Utility to obtain "pid" like unique process id and use it when creating log files. -inline std::string log_get_pid() -{ - static std::string pid; - if (pid.empty()) - { - // std::this_thread::get_id() is the most portable way of obtaining a "process id" - // it's not the same as "pid" but is unique enough to solve multiple instances - // trying to write to the same log. - std::stringstream ss; - ss << std::this_thread::get_id(); - pid = ss.str(); - } - - return pid; -} - -// Utility function for generating log file names with unique id based on thread id. -// invocation with log_filename_generator( "llama", "log" ) creates a string "llama..log" -// where the number is a runtime id of the current thread. - -#define log_filename_generator(log_file_basename, log_file_extension) log_filename_generator_impl(LogTriStateSame, log_file_basename, log_file_extension) - -// INTERNAL, DO NOT USE -inline std::string log_filename_generator_impl(LogTriState multilog, const std::string & log_file_basename, const std::string & log_file_extension) -{ - static bool _multilog = false; - - if (multilog != LogTriStateSame) - { - _multilog = multilog == LogTriStateTrue; - } - - std::stringstream buf; - - buf << log_file_basename; - if (_multilog) - { - buf << "."; - buf << log_get_pid(); - } - buf << "."; - buf << log_file_extension; - - return buf.str(); -} - -#ifndef LOG_DEFAULT_FILE_NAME - #define LOG_DEFAULT_FILE_NAME log_filename_generator("llama", "log") -#endif - -// Utility for turning #define values into string literals -// so we can have a define for stderr and -// we can print "stderr" instead of literal stderr, etc. -#define LOG_STRINGIZE1(s) #s -#define LOG_STRINGIZE(s) LOG_STRINGIZE1(s) - -#define LOG_TEE_TARGET_STRING LOG_STRINGIZE(LOG_TEE_TARGET) - -// Allows disabling timestamps. -// in order to disable, define LOG_NO_TIMESTAMPS -// like so: -// -// #define LOG_NO_TIMESTAMPS -// #include "log.h" -// -#ifndef LOG_NO_TIMESTAMPS - #ifndef _MSC_VER - #define LOG_TIMESTAMP_FMT "[%" PRIu64 "] " - #define LOG_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count() - #else - #define LOG_TIMESTAMP_FMT "[%" PRIu64 "] " - #define LOG_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count() - #endif +#ifndef __GNUC__ +# define LOG_ATTRIBUTE_FORMAT(...) +#elif defined(__MINGW32__) +# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) #else - #define LOG_TIMESTAMP_FMT "%s" - #define LOG_TIMESTAMP_VAL ,"" +# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) #endif -#ifdef LOG_TEE_TIMESTAMPS - #ifndef _MSC_VER - #define LOG_TEE_TIMESTAMP_FMT "[%" PRIu64 "] " - #define LOG_TEE_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count() - #else - #define LOG_TEE_TIMESTAMP_FMT "[%" PRIu64 "] " - #define LOG_TEE_TIMESTAMP_VAL , (std::chrono::duration_cast>(std::chrono::system_clock::now().time_since_epoch())).count() - #endif -#else - #define LOG_TEE_TIMESTAMP_FMT "%s" - #define LOG_TEE_TIMESTAMP_VAL ,"" -#endif +#define LOG_DEFAULT_DEBUG 1 +#define LOG_DEFAULT_LLAMA 0 -// Allows disabling file/line/function prefix -// in order to disable, define LOG_NO_FILE_LINE_FUNCTION -// like so: +// needed by the LOG_TMPL macro to avoid computing log arguments if the verbosity lower +// set via common_log_set_verbosity() +extern int common_log_verbosity_thold; + +void common_log_set_verbosity_thold(int verbosity); // not thread-safe + +// the common_log uses an internal worker thread to print/write log messages +// when the worker thread is paused, incoming log messages are discarded +struct common_log; + +struct common_log * common_log_init(); +struct common_log * common_log_main(); // singleton, automatically destroys itself on exit +void common_log_pause (struct common_log * log); // pause the worker thread, not thread-safe +void common_log_resume(struct common_log * log); // resume the worker thread, not thread-safe +void common_log_free (struct common_log * log); + +LOG_ATTRIBUTE_FORMAT(3, 4) +void common_log_add(struct common_log * log, enum ggml_log_level level, const char * fmt, ...); + +// defaults: file = NULL, colors = false, prefix = false, timestamps = false // -// #define LOG_NO_FILE_LINE_FUNCTION -// #include "log.h" +// regular log output: // -#ifndef LOG_NO_FILE_LINE_FUNCTION - #ifndef _MSC_VER - #define LOG_FLF_FMT "[%24s:%5d][%24s] " - #define LOG_FLF_VAL , __FILE__, __LINE__, __FUNCTION__ - #else - #define LOG_FLF_FMT "[%24s:%5ld][%24s] " - #define LOG_FLF_VAL , __FILE__, (long)__LINE__, __FUNCTION__ - #endif -#else - #define LOG_FLF_FMT "%s" - #define LOG_FLF_VAL ,"" -#endif - -#ifdef LOG_TEE_FILE_LINE_FUNCTION - #ifndef _MSC_VER - #define LOG_TEE_FLF_FMT "[%24s:%5d][%24s] " - #define LOG_TEE_FLF_VAL , __FILE__, __LINE__, __FUNCTION__ - #else - #define LOG_TEE_FLF_FMT "[%24s:%5ld][%24s] " - #define LOG_TEE_FLF_VAL , __FILE__, (long)__LINE__, __FUNCTION__ - #endif -#else - #define LOG_TEE_FLF_FMT "%s" - #define LOG_TEE_FLF_VAL ,"" -#endif - -// INTERNAL, DO NOT USE -// USE LOG() INSTEAD +// ggml_backend_metal_log_allocated_size: allocated buffer, size = 6695.84 MiB, ( 6695.91 / 21845.34) +// llm_load_tensors: ggml ctx size = 0.27 MiB +// llm_load_tensors: offloading 32 repeating layers to GPU +// llm_load_tensors: offloading non-repeating layers to GPU // -#if !defined(_MSC_VER) || defined(__INTEL_LLVM_COMPILER) || defined(__clang__) - #define LOG_IMPL(str, ...) \ - do { \ - if (LOG_TARGET != nullptr) \ - { \ - fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL, __VA_ARGS__); \ - fflush(LOG_TARGET); \ - } \ +// with prefix = true, timestamps = true, the log output will look like this: +// +// 0.00.035.060 D ggml_backend_metal_log_allocated_size: allocated buffer, size = 6695.84 MiB, ( 6695.91 / 21845.34) +// 0.00.035.064 I llm_load_tensors: ggml ctx size = 0.27 MiB +// 0.00.090.578 I llm_load_tensors: offloading 32 repeating layers to GPU +// 0.00.090.579 I llm_load_tensors: offloading non-repeating layers to GPU +// +// I - info (stdout, V = 0) +// W - warning (stderr, V = 0) +// E - error (stderr, V = 0) +// D - debug (stderr, V = LOG_DEFAULT_DEBUG) +// + +void common_log_set_file (struct common_log * log, const char * file); // not thread-safe +void common_log_set_colors (struct common_log * log, bool colors); // not thread-safe +void common_log_set_prefix (struct common_log * log, bool prefix); // whether to output prefix to each log +void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix + +// helper macros for logging +// use these to avoid computing log arguments if the verbosity of the log is higher than the threshold +// +// for example: +// +// LOG_DBG("this is a debug message: %d\n", expensive_function()); +// +// this will avoid calling expensive_function() if LOG_DEFAULT_DEBUG > common_log_verbosity_thold +// + +#define LOG_TMPL(level, verbosity, ...) \ + do { \ + if ((verbosity) <= common_log_verbosity_thold) { \ + common_log_add(common_log_main(), (level), __VA_ARGS__); \ + } \ } while (0) -#else - #define LOG_IMPL(str, ...) \ - do { \ - if (LOG_TARGET != nullptr) \ - { \ - fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL "", ##__VA_ARGS__); \ - fflush(LOG_TARGET); \ - } \ - } while (0) -#endif -// INTERNAL, DO NOT USE -// USE LOG_TEE() INSTEAD -// -#if !defined(_MSC_VER) || defined(__INTEL_LLVM_COMPILER) || defined(__clang__) - #define LOG_TEE_IMPL(str, ...) \ - do { \ - if (LOG_TARGET != nullptr) \ - { \ - fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL, __VA_ARGS__); \ - fflush(LOG_TARGET); \ - } \ - if (LOG_TARGET != nullptr && LOG_TARGET != stdout && LOG_TARGET != stderr && LOG_TEE_TARGET != nullptr) \ - { \ - fprintf(LOG_TEE_TARGET, LOG_TEE_TIMESTAMP_FMT LOG_TEE_FLF_FMT str "%s" LOG_TEE_TIMESTAMP_VAL LOG_TEE_FLF_VAL, __VA_ARGS__); \ - fflush(LOG_TEE_TARGET); \ - } \ - } while (0) -#else - #define LOG_TEE_IMPL(str, ...) \ - do { \ - if (LOG_TARGET != nullptr) \ - { \ - fprintf(LOG_TARGET, LOG_TIMESTAMP_FMT LOG_FLF_FMT str "%s" LOG_TIMESTAMP_VAL LOG_FLF_VAL "", ##__VA_ARGS__); \ - fflush(LOG_TARGET); \ - } \ - if (LOG_TARGET != nullptr && LOG_TARGET != stdout && LOG_TARGET != stderr && LOG_TEE_TARGET != nullptr) \ - { \ - fprintf(LOG_TEE_TARGET, LOG_TEE_TIMESTAMP_FMT LOG_TEE_FLF_FMT str "%s" LOG_TEE_TIMESTAMP_VAL LOG_TEE_FLF_VAL "", ##__VA_ARGS__); \ - fflush(LOG_TEE_TARGET); \ - } \ - } while (0) -#endif +#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, 0, __VA_ARGS__) +#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__) -// The '\0' as a last argument, is a trick to bypass the silly -// "warning: ISO C++11 requires at least one argument for the "..." in a variadic macro" -// so we can have a single macro which can be called just like printf. +#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, 0, __VA_ARGS__) +#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, 0, __VA_ARGS__) +#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, 0, __VA_ARGS__) +#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_DEFAULT_DEBUG, __VA_ARGS__) +#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, 0, __VA_ARGS__) -// Main LOG macro. -// behaves like printf, and supports arguments the exact same way. -// -#if !defined(_MSC_VER) || defined(__clang__) - #define LOG(...) LOG_IMPL(__VA_ARGS__, "") -#else - #define LOG(str, ...) LOG_IMPL("%s" str, "", ##__VA_ARGS__, "") -#endif - -// Main TEE macro. -// does the same as LOG -// and -// simultaneously writes stderr. -// -// Secondary target can be changed just like LOG_TARGET -// by defining LOG_TEE_TARGET -// -#if !defined(_MSC_VER) || defined(__clang__) - #define LOG_TEE(...) LOG_TEE_IMPL(__VA_ARGS__, "") -#else - #define LOG_TEE(str, ...) LOG_TEE_IMPL("%s" str, "", ##__VA_ARGS__, "") -#endif - -// LOG macro variants with auto endline. -#if !defined(_MSC_VER) || defined(__clang__) - #define LOGLN(...) LOG_IMPL(__VA_ARGS__, "\n") - #define LOG_TEELN(...) LOG_TEE_IMPL(__VA_ARGS__, "\n") -#else - #define LOGLN(str, ...) LOG_IMPL("%s" str, "", ##__VA_ARGS__, "\n") - #define LOG_TEELN(str, ...) LOG_TEE_IMPL("%s" str, "", ##__VA_ARGS__, "\n") -#endif - -// INTERNAL, DO NOT USE -inline FILE *log_handler1_impl(bool change = false, LogTriState append = LogTriStateSame, LogTriState disable = LogTriStateSame, const std::string & filename = LOG_DEFAULT_FILE_NAME, FILE *target = nullptr) -{ - static bool _initialized = false; - static bool _append = false; - static bool _disabled = filename.empty() && target == nullptr; - static std::string log_current_filename{filename}; - static FILE *log_current_target{target}; - static FILE *logfile = nullptr; - - if (change) - { - if (append != LogTriStateSame) - { - _append = append == LogTriStateTrue; - return logfile; - } - - if (disable == LogTriStateTrue) - { - // Disable primary target - _disabled = true; - } - // If previously disabled, only enable, and keep previous target - else if (disable == LogTriStateFalse) - { - _disabled = false; - } - // Otherwise, process the arguments - else if (log_current_filename != filename || log_current_target != target) - { - _initialized = false; - } - } - - if (_disabled) - { - // Log is disabled - return nullptr; - } - - if (_initialized) - { - // with fallback in case something went wrong - return logfile ? logfile : stderr; - } - - // do the (re)initialization - if (target != nullptr) - { - if (logfile != nullptr && logfile != stdout && logfile != stderr) - { - fclose(logfile); - } - - log_current_filename = LOG_DEFAULT_FILE_NAME; - log_current_target = target; - - logfile = target; - } - else - { - if (log_current_filename != filename) - { - if (logfile != nullptr && logfile != stdout && logfile != stderr) - { - fclose(logfile); - } - } - - logfile = fopen(filename.c_str(), _append ? "a" : "w"); - } - - if (!logfile) - { - // Verify whether the file was opened, otherwise fallback to stderr - logfile = stderr; - - fprintf(stderr, "Failed to open logfile '%s' with error '%s'\n", filename.c_str(), std::strerror(errno)); - fflush(stderr); - - // At this point we let the init flag be to true below, and let the target fallback to stderr - // otherwise we would repeatedly fopen() which was already unsuccessful - } - - _initialized = true; - - return logfile ? logfile : stderr; -} - -// INTERNAL, DO NOT USE -inline FILE *log_handler2_impl(bool change = false, LogTriState append = LogTriStateSame, LogTriState disable = LogTriStateSame, FILE *target = nullptr, const std::string & filename = LOG_DEFAULT_FILE_NAME) -{ - return log_handler1_impl(change, append, disable, filename, target); -} - -// Disables logs entirely at runtime. -// Makes LOG() and LOG_TEE() produce no output, -// until enabled back. -#define log_disable() log_disable_impl() - -// INTERNAL, DO NOT USE -inline FILE *log_disable_impl() -{ - return log_handler1_impl(true, LogTriStateSame, LogTriStateTrue); -} - -// Enables logs at runtime. -#define log_enable() log_enable_impl() - -// INTERNAL, DO NOT USE -inline FILE *log_enable_impl() -{ - return log_handler1_impl(true, LogTriStateSame, LogTriStateFalse); -} - -// Sets target fir logs, either by a file name or FILE* pointer (stdout, stderr, or any valid FILE*) -#define log_set_target(target) log_set_target_impl(target) - -// INTERNAL, DO NOT USE -inline FILE *log_set_target_impl(const std::string & filename) { return log_handler1_impl(true, LogTriStateSame, LogTriStateSame, filename); } -inline FILE *log_set_target_impl(FILE *target) { return log_handler2_impl(true, LogTriStateSame, LogTriStateSame, target); } - -// INTERNAL, DO NOT USE -inline FILE *log_handler() { return log_handler1_impl(); } - -// Enable or disable creating separate log files for each run. -// can ONLY be invoked BEFORE first log use. -#define log_multilog(enable) log_filename_generator_impl((enable) ? LogTriStateTrue : LogTriStateFalse, "", "") -// Enable or disable append mode for log file. -// can ONLY be invoked BEFORE first log use. -#define log_append(enable) log_append_impl(enable) -// INTERNAL, DO NOT USE -inline FILE *log_append_impl(bool enable) -{ - return log_handler1_impl(true, enable ? LogTriStateTrue : LogTriStateFalse, LogTriStateSame); -} - -inline void log_test() -{ - log_disable(); - LOG("01 Hello World to nobody, because logs are disabled!\n"); - log_enable(); - LOG("02 Hello World to default output, which is \"%s\" ( Yaaay, arguments! )!\n", LOG_STRINGIZE(LOG_TARGET)); - LOG_TEE("03 Hello World to **both** default output and " LOG_TEE_TARGET_STRING "!\n"); - log_set_target(stderr); - LOG("04 Hello World to stderr!\n"); - LOG_TEE("05 Hello World TEE with double printing to stderr prevented!\n"); - log_set_target(LOG_DEFAULT_FILE_NAME); - LOG("06 Hello World to default log file!\n"); - log_set_target(stdout); - LOG("07 Hello World to stdout!\n"); - log_set_target(LOG_DEFAULT_FILE_NAME); - LOG("08 Hello World to default log file again!\n"); - log_disable(); - LOG("09 Hello World _1_ into the void!\n"); - log_enable(); - LOG("10 Hello World back from the void ( you should not see _1_ in the log or the output )!\n"); - log_disable(); - log_set_target("llama.anotherlog.log"); - LOG("11 Hello World _2_ to nobody, new target was selected but logs are still disabled!\n"); - log_enable(); - LOG("12 Hello World this time in a new file ( you should not see _2_ in the log or the output )?\n"); - log_set_target("llama.yetanotherlog.log"); - LOG("13 Hello World this time in yet new file?\n"); - log_set_target(log_filename_generator("llama_autonamed", "log")); - LOG("14 Hello World in log with generated filename!\n"); -#ifdef _MSC_VER - LOG_TEE("15 Hello msvc TEE without arguments\n"); - LOG_TEE("16 Hello msvc TEE with (%d)(%s) arguments\n", 1, "test"); - LOG_TEELN("17 Hello msvc TEELN without arguments\n"); - LOG_TEELN("18 Hello msvc TEELN with (%d)(%s) arguments\n", 1, "test"); - LOG("19 Hello msvc LOG without arguments\n"); - LOG("20 Hello msvc LOG with (%d)(%s) arguments\n", 1, "test"); - LOGLN("21 Hello msvc LOGLN without arguments\n"); - LOGLN("22 Hello msvc LOGLN with (%d)(%s) arguments\n", 1, "test"); -#endif -} - -inline bool log_param_single_parse(const std::string & param) -{ - if ( param == "--log-test") - { - log_test(); - return true; - } - - if ( param == "--log-disable") - { - log_disable(); - return true; - } - - if ( param == "--log-enable") - { - log_enable(); - return true; - } - - if (param == "--log-new") - { - log_multilog(true); - return true; - } - - if (param == "--log-append") - { - log_append(true); - return true; - } - - return false; -} - -inline bool log_param_pair_parse(bool check_but_dont_parse, const std::string & param, const std::string & next = std::string()) -{ - if ( param == "--log-file") - { - if (!check_but_dont_parse) - { - log_set_target(log_filename_generator(next.empty() ? "unnamed" : next, "log")); - } - - return true; - } - - return false; -} - -inline void log_print_usage() -{ - printf("log options:\n"); - /* format - printf(" -h, --help show this help message and exit\n");*/ - /* spacing - printf("__-param----------------Description\n");*/ - printf(" --log-test Run simple logging test\n"); - printf(" --log-disable Disable trace logs\n"); - printf(" --log-enable Enable trace logs\n"); - printf(" --log-file Specify a log filename (without extension)\n"); - printf(" --log-new Create a separate new log file on start. " - "Each log file will have unique name: \"..log\"\n"); - printf(" --log-append Don't truncate the old log file.\n"); - printf("\n"); -} - -#define log_dump_cmdline(argc, argv) log_dump_cmdline_impl(argc, argv) - -// INTERNAL, DO NOT USE -inline void log_dump_cmdline_impl(int argc, char **argv) -{ - std::stringstream buf; - for (int i = 0; i < argc; ++i) - { - if (std::string(argv[i]).find(' ') != std::string::npos) - { - buf << " \"" << argv[i] <<"\""; - } - else - { - buf << " " << argv[i]; - } - } - LOGLN("Cmd:%s", buf.str().c_str()); -} - -#define log_tostr(var) log_var_to_string_impl(var).c_str() - -inline std::string log_var_to_string_impl(bool var) -{ - return var ? "true" : "false"; -} - -inline std::string log_var_to_string_impl(std::string var) -{ - return var; -} - -inline std::string log_var_to_string_impl(const std::vector & var) -{ - std::stringstream buf; - buf << "[ "; - bool first = true; - for (auto e : var) - { - if (first) - { - first = false; - } - else - { - buf << ", "; - } - buf << std::to_string(e); - } - buf << " ]"; - - return buf.str(); -} - -template -inline std::string LOG_TOKENS_TOSTR_PRETTY(const C & ctx, const T & tokens) -{ - std::stringstream buf; - buf << "[ "; - - bool first = true; - for (const auto & token : tokens) - { - if (!first) { - buf << ", "; - } else { - first = false; - } - - auto detokenized = llama_token_to_piece(ctx, token); - - detokenized.erase( - std::remove_if( - detokenized.begin(), - detokenized.end(), - [](const unsigned char c) { return !std::isprint(c); }), - detokenized.end()); - - buf - << "'" << detokenized << "'" - << ":" << std::to_string(token); - } - buf << " ]"; - - return buf.str(); -} - -template -inline std::string LOG_BATCH_TOSTR_PRETTY(const C & ctx, const B & batch) -{ - std::stringstream buf; - buf << "[ "; - - bool first = true; - for (int i = 0; i < batch.n_tokens; ++i) - { - if (!first) { - buf << ", "; - } else { - first = false; - } - - auto detokenized = llama_token_to_piece(ctx, batch.token[i]); - - detokenized.erase( - std::remove_if( - detokenized.begin(), - detokenized.end(), - [](const unsigned char c) { return !std::isprint(c); }), - detokenized.end()); - - buf - << "\n" << std::to_string(i) - << ":token '" << detokenized << "'" - << ":pos " << std::to_string(batch.pos[i]) - << ":n_seq_id " << std::to_string(batch.n_seq_id[i]) - << ":seq_id " << std::to_string(batch.seq_id[i][0]) - << ":logits " << std::to_string(batch.logits[i]); - } - buf << " ]"; - - return buf.str(); -} - -#ifdef LOG_DISABLE_LOGS - -#undef LOG -#define LOG(...) // dummy stub -#undef LOGLN -#define LOGLN(...) // dummy stub - -#undef LOG_TEE -#define LOG_TEE(...) fprintf(stderr, __VA_ARGS__) // convert to normal fprintf - -#undef LOG_TEELN -#define LOG_TEELN(...) fprintf(stderr, __VA_ARGS__) // convert to normal fprintf - -#undef LOG_DISABLE -#define LOG_DISABLE() // dummy stub - -#undef LOG_ENABLE -#define LOG_ENABLE() // dummy stub - -#undef LOG_ENABLE -#define LOG_ENABLE() // dummy stub - -#undef LOG_SET_TARGET -#define LOG_SET_TARGET(...) // dummy stub - -#undef LOG_DUMP_CMDLINE -#define LOG_DUMP_CMDLINE(...) // dummy stub - -#endif // LOG_DISABLE_LOGS +#define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__) +#define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__) +#define LOG_ERRV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, verbosity, __VA_ARGS__) +#define LOG_DBGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, verbosity, __VA_ARGS__) +#define LOG_CNTV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_CONT, verbosity, __VA_ARGS__) diff --git a/common/minja.hpp b/common/minja.hpp new file mode 100644 index 000000000..f0e80fd7c --- /dev/null +++ b/common/minja.hpp @@ -0,0 +1,2819 @@ +/* + Copyright 2024 Google LLC + + Use of this source code is governed by an MIT-style + license that can be found in the LICENSE file or at + https://opensource.org/licenses/MIT. +*/ +// SPDX-License-Identifier: MIT +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +namespace minja { + +class Context; + +struct Options { + bool trim_blocks; // removes the first newline after a block + bool lstrip_blocks; // removes leading whitespace on the line of the block + bool keep_trailing_newline; // don't remove last newline +}; + +struct ArgumentsValue; + +inline std::string normalize_newlines(const std::string & s) { +#ifdef _WIN32 + static const std::regex nl_regex("\r\n"); + return std::regex_replace(s, nl_regex, "\n"); +#else + return s; +#endif +} + +/* Values that behave roughly like in Python. */ +class Value : public std::enable_shared_from_this { +public: + using CallableType = std::function &, ArgumentsValue &)>; + using FilterType = std::function &, ArgumentsValue &)>; + +private: + using ObjectType = nlohmann::ordered_map; // Only contains primitive keys + using ArrayType = std::vector; + + std::shared_ptr array_; + std::shared_ptr object_; + std::shared_ptr callable_; + json primitive_; + + Value(const std::shared_ptr & array) : array_(array) {} + Value(const std::shared_ptr & object) : object_(object) {} + Value(const std::shared_ptr & callable) : object_(std::make_shared()), callable_(callable) {} + + /* Python-style string repr */ + static void dump_string(const json & primitive, std::ostringstream & out, char string_quote = '\'') { + if (!primitive.is_string()) throw std::runtime_error("Value is not a string: " + primitive.dump()); + auto s = primitive.dump(); + if (string_quote == '"' || s.find('\'') != std::string::npos) { + out << s; + return; + } + // Reuse json dump, just changing string quotes + out << string_quote; + for (size_t i = 1, n = s.size() - 1; i < n; ++i) { + if (s[i] == '\\' && s[i + 1] == '"') { + out << '"'; + i++; + } else if (s[i] == string_quote) { + out << '\\' << string_quote; + } else { + out << s[i]; + } + } + out << string_quote; + } + void dump(std::ostringstream & out, int indent = -1, int level = 0, bool to_json = false) const { + auto print_indent = [&](int level) { + if (indent > 0) { + out << "\n"; + for (int i = 0, n = level * indent; i < n; ++i) out << ' '; + } + }; + auto print_sub_sep = [&]() { + out << ','; + if (indent < 0) out << ' '; + else print_indent(level + 1); + }; + + auto string_quote = to_json ? '"' : '\''; + + if (is_null()) out << "null"; + else if (array_) { + out << "["; + print_indent(level + 1); + for (size_t i = 0; i < array_->size(); ++i) { + if (i) print_sub_sep(); + (*array_)[i].dump(out, indent, level + 1, to_json); + } + print_indent(level); + out << "]"; + } else if (object_) { + out << "{"; + print_indent(level + 1); + for (auto begin = object_->begin(), it = begin; it != object_->end(); ++it) { + if (it != begin) print_sub_sep(); + if (it->first.is_string()) { + dump_string(it->first, out, string_quote); + } else { + out << string_quote << it->first.dump() << string_quote; + } + out << ": "; + it->second.dump(out, indent, level + 1, to_json); + } + print_indent(level); + out << "}"; + } else if (callable_) { + throw std::runtime_error("Cannot dump callable to JSON"); + } else if (is_boolean() && !to_json) { + out << (this->to_bool() ? "True" : "False"); + } else if (is_string() && !to_json) { + dump_string(primitive_, out, string_quote); + } else { + out << primitive_.dump(); + } + } + +public: + Value() {} + Value(const bool& v) : primitive_(v) {} + Value(const int64_t & v) : primitive_(v) {} + Value(const double& v) : primitive_(v) {} + Value(const std::nullptr_t &) {} + Value(const std::string & v) : primitive_(v) {} + Value(const char * v) : primitive_(std::string(v)) {} + + Value(const json & v) { + if (v.is_object()) { + auto object = std::make_shared(); + for (auto it = v.begin(); it != v.end(); ++it) { + (*object)[it.key()] = it.value(); + } + object_ = std::move(object); + } else if (v.is_array()) { + auto array = std::make_shared(); + for (const auto& item : v) { + array->push_back(Value(item)); + } + array_ = array; + } else { + primitive_ = v; + } + } + + std::vector keys() { + if (!object_) throw std::runtime_error("Value is not an object: " + dump()); + std::vector res; + for (const auto& item : *object_) { + res.push_back(item.first); + } + return res; + } + + size_t size() const { + if (is_object()) return object_->size(); + if (is_array()) return array_->size(); + if (is_string()) return primitive_.get().length(); + throw std::runtime_error("Value is not an array or object: " + dump()); + } + + static Value array(const std::vector values = {}) { + auto array = std::make_shared(); + for (const auto& item : values) { + array->push_back(item); + } + return Value(array); + } + static Value object(const std::shared_ptr object = std::make_shared()) { + return Value(object); + } + static Value callable(const CallableType & callable) { + return Value(std::make_shared(callable)); + } + + void insert(size_t index, const Value& v) { + if (!array_) + throw std::runtime_error("Value is not an array: " + dump()); + array_->insert(array_->begin() + index, v); + } + void push_back(const Value& v) { + if (!array_) + throw std::runtime_error("Value is not an array: " + dump()); + array_->push_back(v); + } + Value pop(const Value& index) { + if (is_array()) { + if (array_->empty()) + throw std::runtime_error("pop from empty list"); + if (index.is_null()) { + auto ret = array_->back(); + array_->pop_back(); + return ret; + } else if (!index.is_number_integer()) { + throw std::runtime_error("pop index must be an integer: " + index.dump()); + } else { + auto i = index.get(); + if (i < 0 || i >= static_cast(array_->size())) + throw std::runtime_error("pop index out of range: " + index.dump()); + auto it = array_->begin() + (i < 0 ? array_->size() + i : i); + auto ret = *it; + array_->erase(it); + return ret; + } + } else if (is_object()) { + if (!index.is_hashable()) + throw std::runtime_error("Unashable type: " + index.dump()); + auto it = object_->find(index.primitive_); + if (it == object_->end()) + throw std::runtime_error("Key not found: " + index.dump()); + auto ret = it->second; + object_->erase(it); + return ret; + } else { + throw std::runtime_error("Value is not an array or object: " + dump()); + } + } + Value get(const Value& key) { + if (array_) { + if (!key.is_number_integer()) { + return Value(); + } + auto index = key.get(); + return array_->at(index < 0 ? array_->size() + index : index); + } else if (object_) { + if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + auto it = object_->find(key.primitive_); + if (it == object_->end()) return Value(); + return it->second; + } + return Value(); + } + void set(const Value& key, const Value& value) { + if (!object_) throw std::runtime_error("Value is not an object: " + dump()); + if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + (*object_)[key.primitive_] = value; + } + Value call(const std::shared_ptr & context, ArgumentsValue & args) const { + if (!callable_) throw std::runtime_error("Value is not callable: " + dump()); + return (*callable_)(context, args); + } + + bool is_object() const { return !!object_; } + bool is_array() const { return !!array_; } + bool is_callable() const { return !!callable_; } + bool is_null() const { return !object_ && !array_ && primitive_.is_null() && !callable_; } + bool is_boolean() const { return primitive_.is_boolean(); } + bool is_number_integer() const { return primitive_.is_number_integer(); } + bool is_number_float() const { return primitive_.is_number_float(); } + bool is_number() const { return primitive_.is_number(); } + bool is_string() const { return primitive_.is_string(); } + bool is_iterable() const { return is_array() || is_object() || is_string(); } + + bool is_primitive() const { return !array_ && !object_ && !callable_; } + bool is_hashable() const { return is_primitive(); } + + bool empty() const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_string()) return primitive_.empty(); + if (is_array()) return array_->empty(); + if (is_object()) return object_->empty(); + return false; + } + + void for_each(const std::function & callback) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (array_) { + for (auto& item : *array_) { + callback(item); + } + } else if (object_) { + for (auto & item : *object_) { + Value key(item.first); + callback(key); + } + } else if (is_string()) { + for (char c : primitive_.get()) { + auto val = Value(std::string(1, c)); + callback(val); + } + } else { + throw std::runtime_error("Value is not iterable: " + dump()); + } + } + + bool to_bool() const { + if (is_null()) return false; + if (is_boolean()) return get(); + if (is_number()) return get() != 0; + if (is_string()) return !get().empty(); + if (is_array()) return !empty(); + return true; + } + + int64_t to_int() const { + if (is_null()) return 0; + if (is_boolean()) return get() ? 1 : 0; + if (is_number()) return static_cast(get()); + if (is_string()) { + try { + return std::stol(get()); + } catch (const std::exception &) { + return 0; + } + } + return 0; + } + + bool operator<(const Value & other) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_number() && other.is_number()) return get() < other.get(); + if (is_string() && other.is_string()) return get() < other.get(); + throw std::runtime_error("Cannot compare values: " + dump() + " < " + other.dump()); + } + bool operator>=(const Value & other) const { return !(*this < other); } + + bool operator>(const Value & other) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_number() && other.is_number()) return get() > other.get(); + if (is_string() && other.is_string()) return get() > other.get(); + throw std::runtime_error("Cannot compare values: " + dump() + " > " + other.dump()); + } + bool operator<=(const Value & other) const { return !(*this > other); } + + bool operator==(const Value & other) const { + if (callable_ || other.callable_) { + if (callable_.get() != other.callable_.get()) return false; + } + if (array_) { + if (!other.array_) return false; + if (array_->size() != other.array_->size()) return false; + for (size_t i = 0; i < array_->size(); ++i) { + if (!(*array_)[i].to_bool() || !(*other.array_)[i].to_bool() || (*array_)[i] != (*other.array_)[i]) return false; + } + return true; + } else if (object_) { + if (!other.object_) return false; + if (object_->size() != other.object_->size()) return false; + for (const auto& item : *object_) { + if (!item.second.to_bool() || !other.object_->count(item.first) || item.second != other.object_->at(item.first)) return false; + } + return true; + } else { + return primitive_ == other.primitive_; + } + } + bool operator!=(const Value & other) const { return !(*this == other); } + + bool contains(const char * key) const { return contains(std::string(key)); } + bool contains(const std::string & key) const { + if (array_) { + return false; + } else if (object_) { + return object_->find(key) != object_->end(); + } else { + throw std::runtime_error("contains can only be called on arrays and objects: " + dump()); + } + } + bool contains(const Value & value) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (array_) { + for (const auto& item : *array_) { + if (item.to_bool() && item == value) return true; + } + return false; + } else if (object_) { + if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump()); + return object_->find(value.primitive_) != object_->end(); + } else { + throw std::runtime_error("contains can only be called on arrays and objects: " + dump()); + } + } + void erase(size_t index) { + if (!array_) throw std::runtime_error("Value is not an array: " + dump()); + array_->erase(array_->begin() + index); + } + void erase(const std::string & key) { + if (!object_) throw std::runtime_error("Value is not an object: " + dump()); + object_->erase(key); + } + const Value& at(const Value & index) const { + return const_cast(this)->at(index); + } + Value& at(const Value & index) { + if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + if (is_array()) return array_->at(index.get()); + if (is_object()) return object_->at(index.primitive_); + throw std::runtime_error("Value is not an array or object: " + dump()); + } + const Value& at(size_t index) const { + return const_cast(this)->at(index); + } + Value& at(size_t index) { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_array()) return array_->at(index); + if (is_object()) return object_->at(index); + throw std::runtime_error("Value is not an array or object: " + dump()); + } + + template + T get(const std::string & key, T default_value) const { + if (!contains(key)) return default_value; + return at(key).get(); + } + + template + T get() const { + if (is_primitive()) return primitive_.get(); + throw std::runtime_error("get not defined for this value type: " + dump()); + } + + std::string dump(int indent=-1, bool to_json=false) const { + std::ostringstream out; + dump(out, indent, 0, to_json); + return out.str(); + } + + Value operator-() const { + if (is_number_integer()) + return -get(); + else + return -get(); + } + std::string to_str() const { + if (is_string()) return get(); + if (is_number_integer()) return std::to_string(get()); + if (is_number_float()) return std::to_string(get()); + if (is_boolean()) return get() ? "True" : "False"; + if (is_null()) return "None"; + return dump(); + } + Value operator+(const Value& rhs) const { + if (is_string() || rhs.is_string()) { + return to_str() + rhs.to_str(); + } else if (is_number_integer() && rhs.is_number_integer()) { + return get() + rhs.get(); + } else if (is_array() && rhs.is_array()) { + auto res = Value::array(); + for (const auto& item : *array_) res.push_back(item); + for (const auto& item : *rhs.array_) res.push_back(item); + return res; + } else { + return get() + rhs.get(); + } + } + Value operator-(const Value& rhs) const { + if (is_number_integer() && rhs.is_number_integer()) + return get() - rhs.get(); + else + return get() - rhs.get(); + } + Value operator*(const Value& rhs) const { + if (is_string() && rhs.is_number_integer()) { + std::ostringstream out; + for (int64_t i = 0, n = rhs.get(); i < n; ++i) { + out << to_str(); + } + return out.str(); + } + else if (is_number_integer() && rhs.is_number_integer()) + return get() * rhs.get(); + else + return get() * rhs.get(); + } + Value operator/(const Value& rhs) const { + if (is_number_integer() && rhs.is_number_integer()) + return get() / rhs.get(); + else + return get() / rhs.get(); + } + Value operator%(const Value& rhs) const { + return get() % rhs.get(); + } +}; + +struct ArgumentsValue { + std::vector args; + std::vector> kwargs; + + bool has_named(const std::string & name) { + for (const auto & p : kwargs) { + if (p.first == name) return true; + } + return false; + } + + Value get_named(const std::string & name) { + for (const auto & [key, value] : kwargs) { + if (key == name) return value; + } + return Value(); + } + + bool empty() { + return args.empty() && kwargs.empty(); + } + + void expectArgs(const std::string & method_name, const std::pair & pos_count, const std::pair & kw_count) { + if (args.size() < pos_count.first || args.size() > pos_count.second || kwargs.size() < kw_count.first || kwargs.size() > kw_count.second) { + std::ostringstream out; + out << method_name << " must have between " << pos_count.first << " and " << pos_count.second << " positional arguments and between " << kw_count.first << " and " << kw_count.second << " keyword arguments"; + throw std::runtime_error(out.str()); + } + } +}; + +template <> +inline json Value::get() const { + if (is_primitive()) return primitive_; + if (is_null()) return json(); + if (array_) { + std::vector res; + for (const auto& item : *array_) { + res.push_back(item.get()); + } + return res; + } + if (object_) { + json res = json::object(); + for (const auto& [key, value] : *object_) { + if (key.is_string()) { + res[key.get()] = value.get(); + } else if (key.is_primitive()) { + res[key.dump()] = value.get(); + } else { + throw std::runtime_error("Invalid key type for conversion to JSON: " + key.dump()); + } + } + if (is_callable()) { + res["__callable__"] = true; + } + return res; + } + throw std::runtime_error("get not defined for this value type: " + dump()); +} + +} // namespace minja + +namespace std { + template <> + struct hash { + size_t operator()(const minja::Value & v) const { + if (!v.is_hashable()) + throw std::runtime_error("Unsupported type for hashing: " + v.dump()); + return std::hash()(v.get()); + } + }; +} // namespace std + +namespace minja { + +static std::string error_location_suffix(const std::string & source, size_t pos) { + auto get_line = [&](size_t line) { + auto start = source.begin(); + for (size_t i = 1; i < line; ++i) { + start = std::find(start, source.end(), '\n') + 1; + } + auto end = std::find(start, source.end(), '\n'); + return std::string(start, end); + }; + auto start = source.begin(); + auto end = source.end(); + auto it = start + pos; + auto line = std::count(start, it, '\n') + 1; + auto max_line = std::count(start, end, '\n') + 1; + auto col = pos - std::string(start, it).rfind('\n'); + std::ostringstream out; + out << " at row " << line << ", column " << col << ":\n"; + if (line > 1) out << get_line(line - 1) << "\n"; + out << get_line(line) << "\n"; + out << std::string(col - 1, ' ') << "^\n"; + if (line < max_line) out << get_line(line + 1) << "\n"; + + return out.str(); +} + +class Context : public std::enable_shared_from_this { + protected: + Value values_; + std::shared_ptr parent_; + public: + Context(Value && values, const std::shared_ptr & parent = nullptr) : values_(std::move(values)), parent_(parent) { + if (!values_.is_object()) throw std::runtime_error("Context values must be an object: " + values_.dump()); + } + virtual ~Context() {} + + static std::shared_ptr builtins(); + static std::shared_ptr make(Value && values, const std::shared_ptr & parent = builtins()); + + std::vector keys() { + return values_.keys(); + } + virtual Value get(const Value & key) { + if (values_.contains(key)) return values_.at(key); + if (parent_) return parent_->get(key); + return Value(); + } + virtual Value & at(const Value & key) { + if (values_.contains(key)) return values_.at(key); + if (parent_) return parent_->at(key); + throw std::runtime_error("Undefined variable: " + key.dump()); + } + virtual bool contains(const Value & key) { + if (values_.contains(key)) return true; + if (parent_) return parent_->contains(key); + return false; + } + virtual void set(const Value & key, const Value & value) { + values_.set(key, value); + } +}; + +struct Location { + std::shared_ptr source; + size_t pos; +}; + +class Expression { +protected: + virtual Value do_evaluate(const std::shared_ptr & context) const = 0; +public: + using Parameters = std::vector>>; + + Location location; + + Expression(const Location & location) : location(location) {} + virtual ~Expression() = default; + + Value evaluate(const std::shared_ptr & context) const { + try { + return do_evaluate(context); + } catch (const std::exception & e) { + std::ostringstream out; + out << e.what(); + if (location.source) out << error_location_suffix(*location.source, location.pos); + throw std::runtime_error(out.str()); + } + } +}; + +class VariableExpr : public Expression { + std::string name; +public: + VariableExpr(const Location & location, const std::string& n) + : Expression(location), name(n) {} + std::string get_name() const { return name; } + Value do_evaluate(const std::shared_ptr & context) const override { + if (!context->contains(name)) { + return Value(); + } + return context->at(name); + } +}; + +static void destructuring_assign(const std::vector & var_names, const std::shared_ptr & context, Value& item) { + if (var_names.size() == 1) { + Value name(var_names[0]); + context->set(name, item); + } else { + if (!item.is_array() || item.size() != var_names.size()) { + throw std::runtime_error("Mismatched number of variables and items in destructuring assignment"); + } + for (size_t i = 0; i < var_names.size(); ++i) { + context->set(var_names[i], item.at(i)); + } + } +} + +enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline }; + +class TemplateToken { +public: + enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter }; + + static std::string typeToString(Type t) { + switch (t) { + case Type::Text: return "text"; + case Type::Expression: return "expression"; + case Type::If: return "if"; + case Type::Else: return "else"; + case Type::Elif: return "elif"; + case Type::EndIf: return "endif"; + case Type::For: return "for"; + case Type::EndFor: return "endfor"; + case Type::Set: return "set"; + case Type::EndSet: return "endset"; + case Type::Comment: return "comment"; + case Type::Macro: return "macro"; + case Type::EndMacro: return "endmacro"; + case Type::Filter: return "filter"; + case Type::EndFilter: return "endfilter"; + case Type::Generation: return "generation"; + case Type::EndGeneration: return "endgeneration"; + } + return "Unknown"; + } + + TemplateToken(Type type, const Location & location, SpaceHandling pre, SpaceHandling post) : type(type), location(location), pre_space(pre), post_space(post) {} + virtual ~TemplateToken() = default; + + Type type; + Location location; + SpaceHandling pre_space = SpaceHandling::Keep; + SpaceHandling post_space = SpaceHandling::Keep; +}; + +struct TextTemplateToken : public TemplateToken { + std::string text; + TextTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Text, location, pre, post), text(t) {} +}; + +struct ExpressionTemplateToken : public TemplateToken { + std::shared_ptr expr; + ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {} +}; + +struct IfTemplateToken : public TemplateToken { + std::shared_ptr condition; + IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {} +}; + +struct ElifTemplateToken : public TemplateToken { + std::shared_ptr condition; + ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {} +}; + +struct ElseTemplateToken : public TemplateToken { + ElseTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Else, location, pre, post) {} +}; + +struct EndIfTemplateToken : public TemplateToken { + EndIfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, location, pre, post) {} +}; + +struct MacroTemplateToken : public TemplateToken { + std::shared_ptr name; + Expression::Parameters params; + MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && n, Expression::Parameters && p) + : TemplateToken(Type::Macro, location, pre, post), name(std::move(n)), params(std::move(p)) {} +}; + +struct EndMacroTemplateToken : public TemplateToken { + EndMacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndMacro, location, pre, post) {} +}; + +struct FilterTemplateToken : public TemplateToken { + std::shared_ptr filter; + FilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && filter) + : TemplateToken(Type::Filter, location, pre, post), filter(std::move(filter)) {} +}; + +struct EndFilterTemplateToken : public TemplateToken { + EndFilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFilter, location, pre, post) {} +}; + +struct ForTemplateToken : public TemplateToken { + std::vector var_names; + std::shared_ptr iterable; + std::shared_ptr condition; + bool recursive; + ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector & vns, std::shared_ptr && iter, + std::shared_ptr && c, bool r) + : TemplateToken(Type::For, location, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {} +}; + +struct EndForTemplateToken : public TemplateToken { + EndForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, location, pre, post) {} +}; + +struct GenerationTemplateToken : public TemplateToken { + GenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Generation, location, pre, post) {} +}; + +struct EndGenerationTemplateToken : public TemplateToken { + EndGenerationTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndGeneration, location, pre, post) {} +}; + +struct SetTemplateToken : public TemplateToken { + std::string ns; + std::vector var_names; + std::shared_ptr value; + SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector & vns, std::shared_ptr && v) + : TemplateToken(Type::Set, location, pre, post), ns(ns), var_names(vns), value(std::move(v)) {} +}; + +struct EndSetTemplateToken : public TemplateToken { + EndSetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndSet, location, pre, post) {} +}; + +struct CommentTemplateToken : public TemplateToken { + std::string text; + CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, location, pre, post), text(t) {} +}; + +class TemplateNode { + Location location_; +protected: + virtual void do_render(std::ostringstream & out, const std::shared_ptr & context) const = 0; + +public: + TemplateNode(const Location & location) : location_(location) {} + void render(std::ostringstream & out, const std::shared_ptr & context) const { + try { + do_render(out, context); + } catch (const std::exception & e) { + std::ostringstream err; + err << e.what(); + if (location_.source) err << error_location_suffix(*location_.source, location_.pos); + throw std::runtime_error(err.str()); + } + } + const Location & location() const { return location_; } + virtual ~TemplateNode() = default; + std::string render(const std::shared_ptr & context) const { + std::ostringstream out; + render(out, context); + return out.str(); + } +}; + +class SequenceNode : public TemplateNode { + std::vector> children; +public: + SequenceNode(const Location & location, std::vector> && c) + : TemplateNode(location), children(std::move(c)) {} + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + for (const auto& child : children) child->render(out, context); + } +}; + +class TextNode : public TemplateNode { + std::string text; +public: + TextNode(const Location & location, const std::string& t) : TemplateNode(location), text(t) {} + void do_render(std::ostringstream & out, const std::shared_ptr &) const override { + out << text; + } +}; + +class ExpressionNode : public TemplateNode { + std::shared_ptr expr; +public: + ExpressionNode(const Location & location, std::shared_ptr && e) : TemplateNode(location), expr(std::move(e)) {} + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + if (!expr) throw std::runtime_error("ExpressionNode.expr is null"); + auto result = expr->evaluate(context); + if (result.is_string()) { + out << result.get(); + } else if (result.is_boolean()) { + out << (result.get() ? "True" : "False"); + } else if (!result.is_null()) { + out << result.dump(); + } + } +}; + +class IfNode : public TemplateNode { + std::vector, std::shared_ptr>> cascade; +public: + IfNode(const Location & location, std::vector, std::shared_ptr>> && c) + : TemplateNode(location), cascade(std::move(c)) {} + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + for (const auto& branch : cascade) { + auto enter_branch = true; + if (branch.first) { + enter_branch = branch.first->evaluate(context).to_bool(); + } + if (enter_branch) { + if (!branch.second) throw std::runtime_error("IfNode.cascade.second is null"); + branch.second->render(out, context); + return; + } + } + } +}; + +class ForNode : public TemplateNode { + std::vector var_names; + std::shared_ptr iterable; + std::shared_ptr condition; + std::shared_ptr body; + bool recursive; + std::shared_ptr else_body; +public: + ForNode(const Location & location, std::vector && var_names, std::shared_ptr && iterable, + std::shared_ptr && condition, std::shared_ptr && body, bool recursive, std::shared_ptr && else_body) + : TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {} + + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + // https://jinja.palletsprojects.com/en/3.0.x/templates/#for + if (!iterable) throw std::runtime_error("ForNode.iterable is null"); + if (!body) throw std::runtime_error("ForNode.body is null"); + + auto iterable_value = iterable->evaluate(context); + Value::CallableType loop_function; + + std::function visit = [&](Value& iter) { + auto filtered_items = Value::array(); + if (!iter.is_null()) { + if (!iterable_value.is_iterable()) { + throw std::runtime_error("For loop iterable must be iterable: " + iterable_value.dump()); + } + iterable_value.for_each([&](Value & item) { + destructuring_assign(var_names, context, item); + if (!condition || condition->evaluate(context).to_bool()) { + filtered_items.push_back(item); + } + }); + } + if (filtered_items.empty()) { + if (else_body) { + else_body->render(out, context); + } + } else { + auto loop = recursive ? Value::callable(loop_function) : Value::object(); + loop.set("length", (int64_t) filtered_items.size()); + + size_t cycle_index = 0; + loop.set("cycle", Value::callable([&](const std::shared_ptr &, ArgumentsValue & args) { + if (args.args.empty() || !args.kwargs.empty()) { + throw std::runtime_error("cycle() expects at least 1 positional argument and no named arg"); + } + auto item = args.args[cycle_index]; + cycle_index = (cycle_index + 1) % args.args.size(); + return item; + })); + auto loop_context = Context::make(Value::object(), context); + loop_context->set("loop", loop); + for (size_t i = 0, n = filtered_items.size(); i < n; ++i) { + auto & item = filtered_items.at(i); + destructuring_assign(var_names, loop_context, item); + loop.set("index", (int64_t) i + 1); + loop.set("index0", (int64_t) i); + loop.set("revindex", (int64_t) (n - i)); + loop.set("revindex0", (int64_t) (n - i - 1)); + loop.set("length", (int64_t) n); + loop.set("first", i == 0); + loop.set("last", i == (n - 1)); + loop.set("previtem", i > 0 ? filtered_items.at(i - 1) : Value()); + loop.set("nextitem", i < n - 1 ? filtered_items.at(i + 1) : Value()); + body->render(out, loop_context); + } + } + }; + + if (recursive) { + loop_function = [&](const std::shared_ptr &, ArgumentsValue & args) { + if (args.args.size() != 1 || !args.kwargs.empty() || !args.args[0].is_array()) { + throw std::runtime_error("loop() expects exactly 1 positional iterable argument"); + } + auto & items = args.args[0]; + visit(items); + return Value(); + }; + } + + visit(iterable_value); + } +}; + +class MacroNode : public TemplateNode { + std::shared_ptr name; + Expression::Parameters params; + std::shared_ptr body; + std::unordered_map named_param_positions; +public: + MacroNode(const Location & location, std::shared_ptr && n, Expression::Parameters && p, std::shared_ptr && b) + : TemplateNode(location), name(std::move(n)), params(std::move(p)), body(std::move(b)) { + for (size_t i = 0; i < params.size(); ++i) { + const auto & name = params[i].first; + if (!name.empty()) { + named_param_positions[name] = i; + } + } + } + void do_render(std::ostringstream &, const std::shared_ptr & macro_context) const override { + if (!name) throw std::runtime_error("MacroNode.name is null"); + if (!body) throw std::runtime_error("MacroNode.body is null"); + auto callable = Value::callable([&](const std::shared_ptr & context, ArgumentsValue & args) { + auto call_context = macro_context; + std::vector param_set(params.size(), false); + for (size_t i = 0, n = args.args.size(); i < n; i++) { + auto & arg = args.args[i]; + if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name()); + param_set[i] = true; + auto & param_name = params[i].first; + call_context->set(param_name, arg); + } + for (auto & [arg_name, value] : args.kwargs) { + auto it = named_param_positions.find(arg_name); + if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name); + + call_context->set(arg_name, value); + param_set[it->second] = true; + } + // Set default values for parameters that were not passed + for (size_t i = 0, n = params.size(); i < n; i++) { + if (!param_set[i] && params[i].second != nullptr) { + auto val = params[i].second->evaluate(context); + call_context->set(params[i].first, val); + } + } + return body->render(call_context); + }); + macro_context->set(name->get_name(), callable); + } +}; + +class FilterNode : public TemplateNode { + std::shared_ptr filter; + std::shared_ptr body; + +public: + FilterNode(const Location & location, std::shared_ptr && f, std::shared_ptr && b) + : TemplateNode(location), filter(std::move(f)), body(std::move(b)) {} + + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + if (!filter) throw std::runtime_error("FilterNode.filter is null"); + if (!body) throw std::runtime_error("FilterNode.body is null"); + auto filter_value = filter->evaluate(context); + if (!filter_value.is_callable()) { + throw std::runtime_error("Filter must be a callable: " + filter_value.dump()); + } + std::string rendered_body = body->render(context); + + ArgumentsValue filter_args = {{Value(rendered_body)}, {}}; + auto result = filter_value.call(context, filter_args); + out << result.to_str(); + } +}; + +class SetNode : public TemplateNode { + std::string ns; + std::vector var_names; + std::shared_ptr value; +public: + SetNode(const Location & location, const std::string & ns, const std::vector & vns, std::shared_ptr && v) + : TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)) {} + void do_render(std::ostringstream &, const std::shared_ptr & context) const override { + if (!value) throw std::runtime_error("SetNode.value is null"); + if (!ns.empty()) { + if (var_names.size() != 1) { + throw std::runtime_error("Namespaced set only supports a single variable name"); + } + auto & name = var_names[0]; + auto ns_value = context->get(ns); + if (!ns_value.is_object()) throw std::runtime_error("Namespace '" + ns + "' is not an object"); + ns_value.set(name, this->value->evaluate(context)); + } else { + auto val = value->evaluate(context); + destructuring_assign(var_names, context, val); + } + } +}; + +class SetTemplateNode : public TemplateNode { + std::string name; + std::shared_ptr template_value; +public: + SetTemplateNode(const Location & location, const std::string & name, std::shared_ptr && tv) + : TemplateNode(location), name(name), template_value(std::move(tv)) {} + void do_render(std::ostringstream &, const std::shared_ptr & context) const override { + if (!template_value) throw std::runtime_error("SetTemplateNode.template_value is null"); + Value value { template_value->render(context) }; + context->set(name, value); + } +}; + +class IfExpr : public Expression { + std::shared_ptr condition; + std::shared_ptr then_expr; + std::shared_ptr else_expr; +public: + IfExpr(const Location & location, std::shared_ptr && c, std::shared_ptr && t, std::shared_ptr && e) + : Expression(location), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!condition) throw std::runtime_error("IfExpr.condition is null"); + if (!then_expr) throw std::runtime_error("IfExpr.then_expr is null"); + if (condition->evaluate(context).to_bool()) { + return then_expr->evaluate(context); + } + if (else_expr) { + return else_expr->evaluate(context); + } + return nullptr; + } +}; + +class LiteralExpr : public Expression { + Value value; +public: + LiteralExpr(const Location & location, const Value& v) + : Expression(location), value(v) {} + Value do_evaluate(const std::shared_ptr &) const override { return value; } +}; + +class ArrayExpr : public Expression { + std::vector> elements; +public: + ArrayExpr(const Location & location, std::vector> && e) + : Expression(location), elements(std::move(e)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + auto result = Value::array(); + for (const auto& e : elements) { + if (!e) throw std::runtime_error("Array element is null"); + result.push_back(e->evaluate(context)); + } + return result; + } +}; + +class DictExpr : public Expression { + std::vector, std::shared_ptr>> elements; +public: + DictExpr(const Location & location, std::vector, std::shared_ptr>> && e) + : Expression(location), elements(std::move(e)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + auto result = Value::object(); + for (const auto& [key, value] : elements) { + if (!key) throw std::runtime_error("Dict key is null"); + if (!value) throw std::runtime_error("Dict value is null"); + result.set(key->evaluate(context), value->evaluate(context)); + } + return result; + } +}; + +class SliceExpr : public Expression { +public: + std::shared_ptr start, end; + SliceExpr(const Location & location, std::shared_ptr && s, std::shared_ptr && e) + : Expression(location), start(std::move(s)), end(std::move(e)) {} + Value do_evaluate(const std::shared_ptr &) const override { + throw std::runtime_error("SliceExpr not implemented"); + } +}; + +class SubscriptExpr : public Expression { + std::shared_ptr base; + std::shared_ptr index; +public: + SubscriptExpr(const Location & location, std::shared_ptr && b, std::shared_ptr && i) + : Expression(location), base(std::move(b)), index(std::move(i)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!base) throw std::runtime_error("SubscriptExpr.base is null"); + if (!index) throw std::runtime_error("SubscriptExpr.index is null"); + auto target_value = base->evaluate(context); + if (auto slice = dynamic_cast(index.get())) { + auto start = slice->start ? slice->start->evaluate(context).get() : 0; + auto end = slice->end ? slice->end->evaluate(context).get() : (int64_t) target_value.size(); + if (target_value.is_string()) { + std::string s = target_value.get(); + if (start < 0) start = s.size() + start; + if (end < 0) end = s.size() + end; + return s.substr(start, end - start); + } else if (target_value.is_array()) { + if (start < 0) start = target_value.size() + start; + if (end < 0) end = target_value.size() + end; + auto result = Value::array(); + for (auto i = start; i < end; ++i) { + result.push_back(target_value.at(i)); + } + return result; + } else { + throw std::runtime_error(target_value.is_null() ? "Cannot subscript null" : "Subscripting only supported on arrays and strings"); + } + } else { + auto index_value = index->evaluate(context); + if (target_value.is_null()) { + if (auto t = dynamic_cast(base.get())) { + throw std::runtime_error("'" + t->get_name() + "' is " + (context->contains(t->get_name()) ? "null" : "not defined")); + } + throw std::runtime_error("Trying to access property '" + index_value.dump() + "' on null!"); + } + return target_value.get(index_value); + } + } +}; + +class UnaryOpExpr : public Expression { +public: + enum class Op { Plus, Minus, LogicalNot, Expansion, ExpansionDict }; + std::shared_ptr expr; + Op op; + UnaryOpExpr(const Location & location, std::shared_ptr && e, Op o) + : Expression(location), expr(std::move(e)), op(o) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!expr) throw std::runtime_error("UnaryOpExpr.expr is null"); + auto e = expr->evaluate(context); + switch (op) { + case Op::Plus: return e; + case Op::Minus: return -e; + case Op::LogicalNot: return !e.to_bool(); + case Op::Expansion: + case Op::ExpansionDict: + throw std::runtime_error("Expansion operator is only supported in function calls and collections"); + + } + throw std::runtime_error("Unknown unary operator"); + } +}; + +class BinaryOpExpr : public Expression { +public: + enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot }; +private: + std::shared_ptr left; + std::shared_ptr right; + Op op; +public: + BinaryOpExpr(const Location & location, std::shared_ptr && l, std::shared_ptr && r, Op o) + : Expression(location), left(std::move(l)), right(std::move(r)), op(o) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!left) throw std::runtime_error("BinaryOpExpr.left is null"); + if (!right) throw std::runtime_error("BinaryOpExpr.right is null"); + auto l = left->evaluate(context); + + auto do_eval = [&](const Value & l) -> Value { + if (op == Op::Is || op == Op::IsNot) { + auto t = dynamic_cast(right.get()); + if (!t) throw std::runtime_error("Right side of 'is' operator must be a variable"); + + auto eval = [&]() { + const auto & name = t->get_name(); + if (name == "none") return l.is_null(); + if (name == "boolean") return l.is_boolean(); + if (name == "integer") return l.is_number_integer(); + if (name == "float") return l.is_number_float(); + if (name == "number") return l.is_number(); + if (name == "string") return l.is_string(); + if (name == "mapping") return l.is_object(); + if (name == "iterable") return l.is_iterable(); + if (name == "sequence") return l.is_array(); + if (name == "defined") return !l.is_null(); + throw std::runtime_error("Unknown type for 'is' operator: " + name); + }; + auto value = eval(); + return Value(op == Op::Is ? value : !value); + } + + if (op == Op::And) { + if (!l.to_bool()) return Value(false); + return right->evaluate(context).to_bool(); + } else if (op == Op::Or) { + if (l.to_bool()) return l; + return right->evaluate(context); + } + + auto r = right->evaluate(context); + switch (op) { + case Op::StrConcat: return l.to_str() + r.to_str(); + case Op::Add: return l + r; + case Op::Sub: return l - r; + case Op::Mul: return l * r; + case Op::Div: return l / r; + case Op::MulMul: return std::pow(l.get(), r.get()); + case Op::DivDiv: return l.get() / r.get(); + case Op::Mod: return l.get() % r.get(); + case Op::Eq: return l == r; + case Op::Ne: return l != r; + case Op::Lt: return l < r; + case Op::Gt: return l > r; + case Op::Le: return l <= r; + case Op::Ge: return l >= r; + case Op::In: return (r.is_array() || r.is_object()) && r.contains(l); + case Op::NotIn: return !(r.is_array() && r.contains(l)); + default: break; + } + throw std::runtime_error("Unknown binary operator"); + }; + + if (l.is_callable()) { + return Value::callable([l, do_eval](const std::shared_ptr & context, ArgumentsValue & args) { + auto ll = l.call(context, args); + return do_eval(ll); //args[0].second); + }); + } else { + return do_eval(l); + } + } +}; + +struct ArgumentsExpression { + std::vector> args; + std::vector>> kwargs; + + ArgumentsValue evaluate(const std::shared_ptr & context) const { + ArgumentsValue vargs; + for (const auto& arg : this->args) { + if (auto un_expr = std::dynamic_pointer_cast(arg)) { + if (un_expr->op == UnaryOpExpr::Op::Expansion) { + auto array = un_expr->expr->evaluate(context); + if (!array.is_array()) { + throw std::runtime_error("Expansion operator only supported on arrays"); + } + array.for_each([&](Value & value) { + vargs.args.push_back(value); + }); + continue; + } else if (un_expr->op == UnaryOpExpr::Op::ExpansionDict) { + auto dict = un_expr->expr->evaluate(context); + if (!dict.is_object()) { + throw std::runtime_error("ExpansionDict operator only supported on objects"); + } + dict.for_each([&](const Value & key) { + vargs.kwargs.push_back({key.get(), dict.at(key)}); + }); + continue; + } + } + vargs.args.push_back(arg->evaluate(context)); + } + for (const auto& [name, value] : this->kwargs) { + vargs.kwargs.push_back({name, value->evaluate(context)}); + } + return vargs; + } +}; + +static std::string strip(const std::string & s) { + auto start = s.find_first_not_of(" \t\n\r"); + if (start == std::string::npos) return ""; + auto end = s.find_last_not_of(" \t\n\r"); + return s.substr(start, end - start + 1); +} + +static std::string html_escape(const std::string & s) { + std::string result; + result.reserve(s.size()); + for (const auto & c : s) { + switch (c) { + case '&': result += "&"; break; + case '<': result += "<"; break; + case '>': result += ">"; break; + case '"': result += """; break; + case '\'': result += "'"; break; + default: result += c; break; + } + } + return result; +} + +class MethodCallExpr : public Expression { + std::shared_ptr object; + std::shared_ptr method; + ArgumentsExpression args; +public: + MethodCallExpr(const Location & location, std::shared_ptr && obj, std::shared_ptr && m, ArgumentsExpression && a) + : Expression(location), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!object) throw std::runtime_error("MethodCallExpr.object is null"); + if (!method) throw std::runtime_error("MethodCallExpr.method is null"); + auto obj = object->evaluate(context); + auto vargs = args.evaluate(context); + if (obj.is_null()) { + throw std::runtime_error("Trying to call method '" + method->get_name() + "' on null"); + } + if (obj.is_array()) { + if (method->get_name() == "append") { + vargs.expectArgs("append method", {1, 1}, {0, 0}); + obj.push_back(vargs.args[0]); + return Value(); + } else if (method->get_name() == "pop") { + vargs.expectArgs("pop method", {0, 1}, {0, 0}); + return obj.pop(vargs.args.empty() ? Value() : vargs.args[0]); + } else if (method->get_name() == "insert") { + vargs.expectArgs("insert method", {2, 2}, {0, 0}); + auto index = vargs.args[0].get(); + if (index < 0 || index > (int64_t) obj.size()) throw std::runtime_error("Index out of range for insert method"); + obj.insert(index, vargs.args[1]); + return Value(); + } + } else if (obj.is_object()) { + if (method->get_name() == "items") { + vargs.expectArgs("items method", {0, 0}, {0, 0}); + auto result = Value::array(); + for (const auto& key : obj.keys()) { + result.push_back(Value::array({key, obj.at(key)})); + } + return result; + } else if (method->get_name() == "pop") { + vargs.expectArgs("pop method", {1, 1}, {0, 0}); + return obj.pop(vargs.args[0]); + } else if (method->get_name() == "get") { + vargs.expectArgs("get method", {1, 2}, {0, 0}); + auto key = vargs.args[0]; + if (vargs.args.size() == 1) { + return obj.contains(key) ? obj.at(key) : Value(); + } else { + return obj.contains(key) ? obj.at(key) : vargs.args[1]; + } + } else if (obj.contains(method->get_name())) { + auto callable = obj.at(method->get_name()); + if (!callable.is_callable()) { + throw std::runtime_error("Property '" + method->get_name() + "' is not callable"); + } + return callable.call(context, vargs); + } + } else if (obj.is_string()) { + auto str = obj.get(); + if (method->get_name() == "strip") { + vargs.expectArgs("strip method", {0, 0}, {0, 0}); + return Value(strip(str)); + } else if (method->get_name() == "endswith") { + vargs.expectArgs("endswith method", {1, 1}, {0, 0}); + auto suffix = vargs.args[0].get(); + return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); + } else if (method->get_name() == "title") { + vargs.expectArgs("title method", {0, 0}, {0, 0}); + auto res = str; + for (size_t i = 0, n = res.size(); i < n; ++i) { + if (i == 0 || std::isspace(res[i - 1])) res[i] = std::toupper(res[i]); + else res[i] = std::tolower(res[i]); + } + return res; + } + } + throw std::runtime_error("Unknown method: " + method->get_name()); + } +}; + +class CallExpr : public Expression { +public: + std::shared_ptr object; + ArgumentsExpression args; + CallExpr(const Location & location, std::shared_ptr && obj, ArgumentsExpression && a) + : Expression(location), object(std::move(obj)), args(std::move(a)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!object) throw std::runtime_error("CallExpr.object is null"); + auto obj = object->evaluate(context); + if (!obj.is_callable()) { + throw std::runtime_error("Object is not callable: " + obj.dump(2)); + } + auto vargs = args.evaluate(context); + return obj.call(context, vargs); + } +}; + +class FilterExpr : public Expression { + std::vector> parts; +public: + FilterExpr(const Location & location, std::vector> && p) + : Expression(location), parts(std::move(p)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + Value result; + bool first = true; + for (const auto& part : parts) { + if (!part) throw std::runtime_error("FilterExpr.part is null"); + if (first) { + first = false; + result = part->evaluate(context); + } else { + if (auto ce = dynamic_cast(part.get())) { + auto target = ce->object->evaluate(context); + ArgumentsValue args = ce->args.evaluate(context); + args.args.insert(args.args.begin(), result); + result = target.call(context, args); + } else { + auto callable = part->evaluate(context); + ArgumentsValue args; + args.args.insert(args.args.begin(), result); + result = callable.call(context, args); + } + } + } + return result; + } + + void prepend(std::shared_ptr && e) { + parts.insert(parts.begin(), std::move(e)); + } +}; + +class Parser { +private: + using CharIterator = std::string::const_iterator; + + std::shared_ptr template_str; + CharIterator start, end, it; + Options options; + + Parser(const std::shared_ptr& template_str, const Options & options) : template_str(template_str), options(options) { + if (!template_str) throw std::runtime_error("Template string is null"); + start = it = this->template_str->begin(); + end = this->template_str->end(); + } + + bool consumeSpaces(SpaceHandling space_handling = SpaceHandling::Strip) { + if (space_handling == SpaceHandling::Strip) { + while (it != end && std::isspace(*it)) ++it; + } + return true; + } + + std::unique_ptr parseString() { + auto doParse = [&](char quote) -> std::unique_ptr { + if (it == end || *it != quote) return nullptr; + std::string result; + bool escape = false; + for (++it; it != end; ++it) { + if (escape) { + escape = false; + switch (*it) { + case 'n': result += '\n'; break; + case 'r': result += '\r'; break; + case 't': result += '\t'; break; + case 'b': result += '\b'; break; + case 'f': result += '\f'; break; + case '\\': result += '\\'; break; + default: + if (*it == quote) { + result += quote; + } else { + result += *it; + } + break; + } + } else if (*it == '\\') { + escape = true; + } else if (*it == quote) { + ++it; + return std::make_unique(std::move(result)); + } else { + result += *it; + } + } + return nullptr; + }; + + consumeSpaces(); + if (it == end) return nullptr; + if (*it == '"') return doParse('"'); + if (*it == '\'') return doParse('\''); + return nullptr; + } + + json parseNumber(CharIterator& it, const CharIterator& end) { + auto before = it; + consumeSpaces(); + auto start = it; + bool hasDecimal = false; + bool hasExponent = false; + + if (it != end && (*it == '-' || *it == '+')) ++it; + + while (it != end) { + if (std::isdigit(*it)) { + ++it; + } else if (*it == '.') { + if (hasDecimal) throw std::runtime_error("Multiple decimal points"); + hasDecimal = true; + ++it; + } else if (it != start && (*it == 'e' || *it == 'E')) { + if (hasExponent) throw std::runtime_error("Multiple exponents"); + hasExponent = true; + ++it; + } else { + break; + } + } + if (start == it) { + it = before; + return json(); // No valid characters found + } + + std::string str(start, it); + try { + return json::parse(str); + } catch (json::parse_error& e) { + throw std::runtime_error("Failed to parse number: '" + str + "' (" + std::string(e.what()) + ")"); + return json(); + } + } + + /** integer, float, bool, string */ + std::shared_ptr parseConstant() { + auto start = it; + consumeSpaces(); + if (it == end) return nullptr; + if (*it == '"' || *it == '\'') { + auto str = parseString(); + if (str) return std::make_shared(*str); + } + static std::regex prim_tok(R"(true\b|True\b|false\b|False\b|None\b)"); + auto token = consumeToken(prim_tok); + if (!token.empty()) { + if (token == "true" || token == "True") return std::make_shared(true); + if (token == "false" || token == "False") return std::make_shared(false); + if (token == "None") return std::make_shared(nullptr); + throw std::runtime_error("Unknown constant token: " + token); + } + + auto number = parseNumber(it, end); + if (!number.is_null()) return std::make_shared(number); + + it = start; + return nullptr; + } + + class expression_parsing_error : public std::runtime_error { + const CharIterator it; + public: + expression_parsing_error(const std::string & message, const CharIterator & it) + : std::runtime_error(message), it(it) {} + size_t get_pos(const CharIterator & begin) const { + return std::distance(begin, it); + } + }; + + bool peekSymbols(const std::vector & symbols) const { + for (const auto & symbol : symbols) { + if (std::distance(it, end) >= (int64_t) symbol.size() && std::string(it, it + symbol.size()) == symbol) { + return true; + } + } + return false; + } + + std::vector consumeTokenGroups(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + std::smatch match; + if (std::regex_search(it, end, match, regex) && match.position() == 0) { + it += match[0].length(); + std::vector ret; + for (size_t i = 0, n = match.size(); i < n; ++i) { + ret.push_back(match[i].str()); + } + return ret; + } + it = start; + return {}; + } + std::string consumeToken(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + std::smatch match; + if (std::regex_search(it, end, match, regex) && match.position() == 0) { + it += match[0].length(); + return match[0].str(); + } + it = start; + return ""; + } + + std::string consumeToken(const std::string & token, SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + if (std::distance(it, end) >= (int64_t) token.size() && std::string(it, it + token.size()) == token) { + it += token.size(); + return token; + } + it = start; + return ""; + } + + std::shared_ptr parseExpression(bool allow_if_expr = true) { + auto left = parseLogicalOr(); + if (it == end) return left; + + if (!allow_if_expr) return left; + + static std::regex if_tok(R"(if\b)"); + if (consumeToken(if_tok).empty()) { + return left; + } + + auto location = get_location(); + auto [condition, else_expr] = parseIfExpression(); + return std::make_shared(location, std::move(condition), std::move(left), std::move(else_expr)); + } + + Location get_location() const { + return {template_str, (size_t) std::distance(start, it)}; + } + + std::pair, std::shared_ptr> parseIfExpression() { + auto condition = parseLogicalOr(); + if (!condition) throw std::runtime_error("Expected condition expression"); + + static std::regex else_tok(R"(else\b)"); + std::shared_ptr else_expr; + if (!consumeToken(else_tok).empty()) { + else_expr = parseExpression(); + if (!else_expr) throw std::runtime_error("Expected 'else' expression"); + } + return std::pair(std::move(condition), std::move(else_expr)); + } + + std::shared_ptr parseLogicalOr() { + auto left = parseLogicalAnd(); + if (!left) throw std::runtime_error("Expected left side of 'logical or' expression"); + + static std::regex or_tok(R"(or\b)"); + auto location = get_location(); + while (!consumeToken(or_tok).empty()) { + auto right = parseLogicalAnd(); + if (!right) throw std::runtime_error("Expected right side of 'or' expression"); + left = std::make_shared(location, std::move(left), std::move(right), BinaryOpExpr::Op::Or); + } + return left; + } + + std::shared_ptr parseLogicalNot() { + static std::regex not_tok(R"(not\b)"); + auto location = get_location(); + + if (!consumeToken(not_tok).empty()) { + auto sub = parseLogicalNot(); + if (!sub) throw std::runtime_error("Expected expression after 'not' keyword"); + return std::make_shared(location, std::move(sub), UnaryOpExpr::Op::LogicalNot); + } + return parseLogicalCompare(); + } + + std::shared_ptr parseLogicalAnd() { + auto left = parseLogicalNot(); + if (!left) throw std::runtime_error("Expected left side of 'logical and' expression"); + + static std::regex and_tok(R"(and\b)"); + auto location = get_location(); + while (!consumeToken(and_tok).empty()) { + auto right = parseLogicalNot(); + if (!right) throw std::runtime_error("Expected right side of 'and' expression"); + left = std::make_shared(location, std::move(left), std::move(right), BinaryOpExpr::Op::And); + } + return left; + } + + std::shared_ptr parseLogicalCompare() { + auto left = parseStringConcat(); + if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression"); + + static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not[\r\n\s]+in\b)"); + static std::regex not_tok(R"(not\b)"); + std::string op_str; + while (!(op_str = consumeToken(compare_tok)).empty()) { + auto location = get_location(); + if (op_str == "is") { + auto negated = !consumeToken(not_tok).empty(); + + auto identifier = parseIdentifier(); + if (!identifier) throw std::runtime_error("Expected identifier after 'is' keyword"); + + return std::make_shared( + left->location, + std::move(left), std::move(identifier), + negated ? BinaryOpExpr::Op::IsNot : BinaryOpExpr::Op::Is); + } + auto right = parseStringConcat(); + if (!right) throw std::runtime_error("Expected right side of 'logical compare' expression"); + BinaryOpExpr::Op op; + if (op_str == "==") op = BinaryOpExpr::Op::Eq; + else if (op_str == "!=") op = BinaryOpExpr::Op::Ne; + else if (op_str == "<") op = BinaryOpExpr::Op::Lt; + else if (op_str == ">") op = BinaryOpExpr::Op::Gt; + else if (op_str == "<=") op = BinaryOpExpr::Op::Le; + else if (op_str == ">=") op = BinaryOpExpr::Op::Ge; + else if (op_str == "in") op = BinaryOpExpr::Op::In; + else if (op_str.substr(0, 3) == "not") op = BinaryOpExpr::Op::NotIn; + else throw std::runtime_error("Unknown comparison operator: " + op_str); + left = std::make_shared(get_location(), std::move(left), std::move(right), op); + } + return left; + } + + Expression::Parameters parseParameters() { + consumeSpaces(); + if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in param list"); + + Expression::Parameters result; + + while (it != end) { + if (!consumeToken(")").empty()) { + return result; + } + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in call args"); + + if (auto ident = dynamic_cast(expr.get())) { + if (!consumeToken("=").empty()) { + auto value = parseExpression(); + if (!value) throw std::runtime_error("Expected expression in for named arg"); + result.emplace_back(ident->get_name(), std::move(value)); + } else { + result.emplace_back(ident->get_name(), nullptr); + } + } else { + result.emplace_back(std::string(), std::move(expr)); + } + if (consumeToken(",").empty()) { + if (consumeToken(")").empty()) { + throw std::runtime_error("Expected closing parenthesis in call args"); + } + return result; + } + } + throw std::runtime_error("Expected closing parenthesis in call args"); + } + + ArgumentsExpression parseCallArgs() { + consumeSpaces(); + if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in call args"); + + ArgumentsExpression result; + + while (it != end) { + if (!consumeToken(")").empty()) { + return result; + } + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in call args"); + + if (auto ident = dynamic_cast(expr.get())) { + if (!consumeToken("=").empty()) { + auto value = parseExpression(); + if (!value) throw std::runtime_error("Expected expression in for named arg"); + result.kwargs.emplace_back(ident->get_name(), std::move(value)); + } else { + result.args.emplace_back(std::move(expr)); + } + } else { + result.args.emplace_back(std::move(expr)); + } + if (consumeToken(",").empty()) { + if (consumeToken(")").empty()) { + throw std::runtime_error("Expected closing parenthesis in call args"); + } + return result; + } + } + throw std::runtime_error("Expected closing parenthesis in call args"); + } + + std::shared_ptr parseIdentifier() { + static std::regex ident_regex(R"((?!(?:not|is|and|or|del)\b)[a-zA-Z_]\w*)"); + auto location = get_location(); + auto ident = consumeToken(ident_regex); + if (ident.empty()) + return nullptr; + return std::make_shared(location, ident); + } + + std::shared_ptr parseStringConcat() { + auto left = parseMathPow(); + if (!left) throw std::runtime_error("Expected left side of 'string concat' expression"); + + static std::regex concat_tok(R"(~(?!\}))"); + if (!consumeToken(concat_tok).empty()) { + auto right = parseLogicalAnd(); + if (!right) throw std::runtime_error("Expected right side of 'string concat' expression"); + left = std::make_shared(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::StrConcat); + } + return left; + } + + std::shared_ptr parseMathPow() { + auto left = parseMathPlusMinus(); + if (!left) throw std::runtime_error("Expected left side of 'math pow' expression"); + + while (!consumeToken("**").empty()) { + auto right = parseMathPlusMinus(); + if (!right) throw std::runtime_error("Expected right side of 'math pow' expression"); + left = std::make_shared(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::MulMul); + } + return left; + } + + std::shared_ptr parseMathPlusMinus() { + static std::regex plus_minus_tok(R"(\+|-(?![}%#]\}))"); + + auto left = parseMathMulDiv(); + if (!left) throw std::runtime_error("Expected left side of 'math plus/minus' expression"); + std::string op_str; + while (!(op_str = consumeToken(plus_minus_tok)).empty()) { + auto right = parseMathMulDiv(); + if (!right) throw std::runtime_error("Expected right side of 'math plus/minus' expression"); + auto op = op_str == "+" ? BinaryOpExpr::Op::Add : BinaryOpExpr::Op::Sub; + left = std::make_shared(get_location(), std::move(left), std::move(right), op); + } + return left; + } + + std::shared_ptr parseMathMulDiv() { + auto left = parseMathUnaryPlusMinus(); + if (!left) throw std::runtime_error("Expected left side of 'math mul/div' expression"); + + static std::regex mul_div_tok(R"(\*\*?|//?|%(?!\}))"); + std::string op_str; + while (!(op_str = consumeToken(mul_div_tok)).empty()) { + auto right = parseMathUnaryPlusMinus(); + if (!right) throw std::runtime_error("Expected right side of 'math mul/div' expression"); + auto op = op_str == "*" ? BinaryOpExpr::Op::Mul + : op_str == "**" ? BinaryOpExpr::Op::MulMul + : op_str == "/" ? BinaryOpExpr::Op::Div + : op_str == "//" ? BinaryOpExpr::Op::DivDiv + : BinaryOpExpr::Op::Mod; + left = std::make_shared(get_location(), std::move(left), std::move(right), op); + } + + if (!consumeToken("|").empty()) { + auto expr = parseMathMulDiv(); + if (auto filter = dynamic_cast(expr.get())) { + filter->prepend(std::move(left)); + return expr; + } else { + std::vector> parts; + parts.emplace_back(std::move(left)); + parts.emplace_back(std::move(expr)); + return std::make_shared(get_location(), std::move(parts)); + } + } + return left; + } + + std::shared_ptr call_func(const std::string & name, ArgumentsExpression && args) const { + return std::make_shared(get_location(), std::make_shared(get_location(), name), std::move(args)); + } + + std::shared_ptr parseMathUnaryPlusMinus() { + static std::regex unary_plus_minus_tok(R"(\+|-(?![}%#]\}))"); + auto op_str = consumeToken(unary_plus_minus_tok); + auto expr = parseExpansion(); + if (!expr) throw std::runtime_error("Expected expr of 'unary plus/minus/expansion' expression"); + + if (!op_str.empty()) { + auto op = op_str == "+" ? UnaryOpExpr::Op::Plus : UnaryOpExpr::Op::Minus; + return std::make_shared(get_location(), std::move(expr), op); + } + return expr; + } + + std::shared_ptr parseExpansion() { + static std::regex expansion_tok(R"(\*\*?)"); + auto op_str = consumeToken(expansion_tok); + auto expr = parseValueExpression(); + if (op_str.empty()) return expr; + if (!expr) throw std::runtime_error("Expected expr of 'expansion' expression"); + return std::make_shared(get_location(), std::move(expr), op_str == "*" ? UnaryOpExpr::Op::Expansion : UnaryOpExpr::Op::ExpansionDict); + } + + std::shared_ptr parseValueExpression() { + auto parseValue = [&]() -> std::shared_ptr { + auto location = get_location(); + auto constant = parseConstant(); + if (constant) return std::make_shared(location, *constant); + + static std::regex null_regex(R"(null\b)"); + if (!consumeToken(null_regex).empty()) return std::make_shared(location, Value()); + + auto identifier = parseIdentifier(); + if (identifier) return identifier; + + auto braced = parseBracedExpressionOrArray(); + if (braced) return braced; + + auto array = parseArray(); + if (array) return array; + + auto dictionary = parseDictionary(); + if (dictionary) return dictionary; + + throw std::runtime_error("Expected value expression"); + }; + + auto value = parseValue(); + + while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) { + if (!consumeToken("[").empty()) { + std::shared_ptr index; + if (!consumeToken(":").empty()) { + auto slice_end = parseExpression(); + index = std::make_shared(slice_end->location, nullptr, std::move(slice_end)); + } else { + auto slice_start = parseExpression(); + if (!consumeToken(":").empty()) { + consumeSpaces(); + if (peekSymbols({ "]" })) { + index = std::make_shared(slice_start->location, std::move(slice_start), nullptr); + } else { + auto slice_end = parseExpression(); + index = std::make_shared(slice_start->location, std::move(slice_start), std::move(slice_end)); + } + } else { + index = std::move(slice_start); + } + } + if (!index) throw std::runtime_error("Empty index in subscript"); + if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript"); + + value = std::make_shared(value->location, std::move(value), std::move(index)); + } else if (!consumeToken(".").empty()) { + auto identifier = parseIdentifier(); + if (!identifier) throw std::runtime_error("Expected identifier in subscript"); + + consumeSpaces(); + if (peekSymbols({ "(" })) { + auto callParams = parseCallArgs(); + value = std::make_shared(identifier->location, std::move(value), std::move(identifier), std::move(callParams)); + } else { + auto key = std::make_shared(identifier->location, Value(identifier->get_name())); + value = std::make_shared(identifier->location, std::move(value), std::move(key)); + } + } + consumeSpaces(); + } + + if (peekSymbols({ "(" })) { + auto location = get_location(); + auto callParams = parseCallArgs(); + value = std::make_shared(location, std::move(value), std::move(callParams)); + } + return value; + } + + std::shared_ptr parseBracedExpressionOrArray() { + if (consumeToken("(").empty()) return nullptr; + + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in braced expression"); + + if (!consumeToken(")").empty()) { + return expr; // Drop the parentheses + } + + std::vector> tuple; + tuple.emplace_back(std::move(expr)); + + while (it != end) { + if (consumeToken(",").empty()) throw std::runtime_error("Expected comma in tuple"); + auto next = parseExpression(); + if (!next) throw std::runtime_error("Expected expression in tuple"); + tuple.push_back(std::move(next)); + + if (!consumeToken(")").empty()) { + return std::make_shared(get_location(), std::move(tuple)); + } + } + throw std::runtime_error("Expected closing parenthesis"); + } + + std::shared_ptr parseArray() { + if (consumeToken("[").empty()) return nullptr; + + std::vector> elements; + if (!consumeToken("]").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } + auto first_expr = parseExpression(); + if (!first_expr) throw std::runtime_error("Expected first expression in array"); + elements.push_back(std::move(first_expr)); + + while (it != end) { + if (!consumeToken(",").empty()) { + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in array"); + elements.push_back(std::move(expr)); + } else if (!consumeToken("]").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } else { + throw std::runtime_error("Expected comma or closing bracket in array"); + } + } + throw std::runtime_error("Expected closing bracket"); + } + + std::shared_ptr parseDictionary() { + if (consumeToken("{").empty()) return nullptr; + + std::vector, std::shared_ptr>> elements; + if (!consumeToken("}").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } + + auto parseKeyValuePair = [&]() { + auto key = parseExpression(); + if (!key) throw std::runtime_error("Expected key in dictionary"); + if (consumeToken(":").empty()) throw std::runtime_error("Expected colon betweek key & value in dictionary"); + auto value = parseExpression(); + if (!value) throw std::runtime_error("Expected value in dictionary"); + elements.emplace_back(std::pair(std::move(key), std::move(value))); + }; + + parseKeyValuePair(); + + while (it != end) { + if (!consumeToken(",").empty()) { + parseKeyValuePair(); + } else if (!consumeToken("}").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } else { + throw std::runtime_error("Expected comma or closing brace in dictionary"); + } + } + throw std::runtime_error("Expected closing brace"); + } + + SpaceHandling parsePreSpace(const std::string& s) const { + if (s == "-") + return SpaceHandling::Strip; + return SpaceHandling::Keep; + } + + SpaceHandling parsePostSpace(const std::string& s) const { + if (s == "-") return SpaceHandling::Strip; + return SpaceHandling::Keep; + } + + using TemplateTokenVector = std::vector>; + using TemplateTokenIterator = TemplateTokenVector::const_iterator; + + std::vector parseVarNames() { + static std::regex varnames_regex(R"(((?:\w+)(?:[\r\n\s]*,[\r\n\s]*(?:\w+))*)[\r\n\s]*)"); + + std::vector group; + if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names"); + std::vector varnames; + std::istringstream iss(group[1]); + std::string varname; + while (std::getline(iss, varname, ',')) { + varnames.push_back(strip(varname)); + } + return varnames; + } + + std::runtime_error unexpected(const TemplateToken & token) const { + return std::runtime_error("Unexpected " + TemplateToken::typeToString(token.type) + + error_location_suffix(*template_str, token.location.pos)); + } + std::runtime_error unterminated(const TemplateToken & token) const { + return std::runtime_error("Unterminated " + TemplateToken::typeToString(token.type) + + error_location_suffix(*template_str, token.location.pos)); + } + + TemplateTokenVector tokenize() { + static std::regex comment_tok(R"(\{#([-~]?)(.*?)([-~]?)#\})"); + static std::regex expr_open_regex(R"(\{\{([-~])?)"); + static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)"); + static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter)\b)"); + static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)"); + static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})"); + static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})"); + + TemplateTokenVector tokens; + std::vector group; + std::string text; + std::smatch match; + + try { + while (it != end) { + auto location = get_location(); + + if (!(group = consumeTokenGroups(comment_tok, SpaceHandling::Keep)).empty()) { + auto pre_space = parsePreSpace(group[1]); + auto content = group[2]; + auto post_space = parsePostSpace(group[3]); + tokens.push_back(std::make_unique(location, pre_space, post_space, content)); + } else if (!(group = consumeTokenGroups(expr_open_regex, SpaceHandling::Keep)).empty()) { + auto pre_space = parsePreSpace(group[1]); + auto expr = parseExpression(); + + if ((group = consumeTokenGroups(expr_close_regex)).empty()) { + throw std::runtime_error("Expected closing expression tag"); + } + + auto post_space = parsePostSpace(group[1]); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(expr))); + } else if (!(group = consumeTokenGroups(block_open_regex, SpaceHandling::Keep)).empty()) { + auto pre_space = parsePreSpace(group[1]); + + std::string keyword; + + auto parseBlockClose = [&]() -> SpaceHandling { + if ((group = consumeTokenGroups(block_close_regex)).empty()) throw std::runtime_error("Expected closing block tag"); + return parsePostSpace(group[1]); + }; + + if ((keyword = consumeToken(block_keyword_tok)).empty()) throw std::runtime_error("Expected block keyword"); + + if (keyword == "if") { + auto condition = parseExpression(); + if (!condition) throw std::runtime_error("Expected condition in if block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(condition))); + } else if (keyword == "elif") { + auto condition = parseExpression(); + if (!condition) throw std::runtime_error("Expected condition in elif block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(condition))); + } else if (keyword == "else") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "endif") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "for") { + static std::regex recursive_tok(R"(recursive\b)"); + static std::regex if_tok(R"(if\b)"); + + auto varnames = parseVarNames(); + static std::regex in_tok(R"(in\b)"); + if (consumeToken(in_tok).empty()) throw std::runtime_error("Expected 'in' keyword in for block"); + auto iterable = parseExpression(/* allow_if_expr = */ false); + if (!iterable) throw std::runtime_error("Expected iterable in for block"); + + std::shared_ptr condition; + if (!consumeToken(if_tok).empty()) { + condition = parseExpression(); + } + auto recursive = !consumeToken(recursive_tok).empty(); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(varnames), std::move(iterable), std::move(condition), recursive)); + } else if (keyword == "endfor") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "generation") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "endgeneration") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "set") { + static std::regex namespaced_var_regex(R"((\w+)[\s\n\r]*\.[\s\n\r]*(\w+))"); + + std::string ns; + std::vector var_names; + std::shared_ptr value; + if (!(group = consumeTokenGroups(namespaced_var_regex)).empty()) { + ns = group[1]; + var_names.push_back(group[2]); + + if (consumeToken("=").empty()) throw std::runtime_error("Expected equals sign in set block"); + + value = parseExpression(); + if (!value) throw std::runtime_error("Expected value in set block"); + } else { + var_names = parseVarNames(); + + if (!consumeToken("=").empty()) { + value = parseExpression(); + if (!value) throw std::runtime_error("Expected value in set block"); + } + } + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, ns, var_names, std::move(value))); + } else if (keyword == "endset") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "macro") { + auto macroname = parseIdentifier(); + if (!macroname) throw std::runtime_error("Expected macro name in macro block"); + auto params = parseParameters(); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(macroname), std::move(params))); + } else if (keyword == "endmacro") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "filter") { + auto filter = parseExpression(); + if (!filter) throw std::runtime_error("Expected expression in filter block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(filter))); + } else if (keyword == "endfilter") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else { + throw std::runtime_error("Unexpected block: " + keyword); + } + } else if (std::regex_search(it, end, match, non_text_open_regex)) { + auto text_end = it + match.position(); + text = std::string(it, text_end); + it = text_end; + tokens.push_back(std::make_unique(location, SpaceHandling::Keep, SpaceHandling::Keep, text)); + } else { + text = std::string(it, end); + it = end; + tokens.push_back(std::make_unique(location, SpaceHandling::Keep, SpaceHandling::Keep, text)); + } + } + return tokens; + } catch (const std::exception & e) { + throw std::runtime_error(e.what() + error_location_suffix(*template_str, std::distance(start, it))); + } + } + + std::shared_ptr parseTemplate( + const TemplateTokenIterator & begin, + TemplateTokenIterator & it, + const TemplateTokenIterator & end, + bool fully = false) const { + std::vector> children; + while (it != end) { + const auto start = it; + const auto & token = *(it++); + if (auto if_token = dynamic_cast(token.get())) { + std::vector, std::shared_ptr>> cascade; + cascade.emplace_back(std::move(if_token->condition), parseTemplate(begin, it, end)); + + while (it != end && (*it)->type == TemplateToken::Type::Elif) { + auto elif_token = dynamic_cast((*(it++)).get()); + cascade.emplace_back(std::move(elif_token->condition), parseTemplate(begin, it, end)); + } + + if (it != end && (*it)->type == TemplateToken::Type::Else) { + cascade.emplace_back(nullptr, parseTemplate(begin, ++it, end)); + } + if (it == end || (*(it++))->type != TemplateToken::Type::EndIf) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared(token->location, std::move(cascade))); + } else if (auto for_token = dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + auto else_body = std::shared_ptr(); + if (it != end && (*it)->type == TemplateToken::Type::Else) { + else_body = parseTemplate(begin, ++it, end); + } + if (it == end || (*(it++))->type != TemplateToken::Type::EndFor) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body))); + } else if (dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndGeneration) { + throw unterminated(**start); + } + // Treat as a no-op, as our scope is templates for inference, not training (`{% generation %}` wraps generated tokens for masking). + children.emplace_back(std::move(body)); + } else if (auto text_token = dynamic_cast(token.get())) { + SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep; + SpaceHandling post_space = it != end ? (*it)->pre_space : SpaceHandling::Keep; + + auto text = text_token->text; + if (post_space == SpaceHandling::Strip) { + static std::regex trailing_space_regex(R"((\s|\r|\n)+$)"); + text = std::regex_replace(text, trailing_space_regex, ""); + } else if (options.lstrip_blocks && it != end) { + auto i = text.size(); + while (i > 0 && (text[i - 1] == ' ' || text[i - 1] == '\t')) i--; + if ((i == 0 && (it - 1) == begin) || (i > 0 && text[i - 1] == '\n')) { + text.resize(i); + } + } + if (pre_space == SpaceHandling::Strip) { + static std::regex leading_space_regex(R"(^(\s|\r|\n)+)"); + text = std::regex_replace(text, leading_space_regex, ""); + } else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast((*(it - 2)).get())) { + if (text.length() > 0 && text[0] == '\n') { + text.erase(0, 1); + } + } + if (it == end && !options.keep_trailing_newline) { + auto i = text.size(); + if (i > 0 && text[i - 1] == '\n') { + i--; + if (i > 0 && text[i - 1] == '\r') i--; + text.resize(i); + } + } + children.emplace_back(std::make_shared(token->location, text)); + } else if (auto expr_token = dynamic_cast(token.get())) { + children.emplace_back(std::make_shared(token->location, std::move(expr_token->expr))); + } else if (auto set_token = dynamic_cast(token.get())) { + if (set_token->value) { + children.emplace_back(std::make_shared(token->location, set_token->ns, set_token->var_names, std::move(set_token->value))); + } else { + auto value_template = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndSet) { + throw unterminated(**start); + } + if (!set_token->ns.empty()) throw std::runtime_error("Namespaced set not supported in set with template value"); + if (set_token->var_names.size() != 1) throw std::runtime_error("Structural assignment not supported in set with template value"); + auto & name = set_token->var_names[0]; + children.emplace_back(std::make_shared(token->location, name, std::move(value_template))); + } + } else if (auto macro_token = dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndMacro) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body))); + } else if (auto filter_token = dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared(token->location, std::move(filter_token->filter), std::move(body))); + } else if (dynamic_cast(token.get())) { + // Ignore comments + } else if (dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get())) { + it--; // unconsume the token + break; // exit the loop + } else { + throw unexpected(**(it-1)); + } + } + if (fully && it != end) { + throw unexpected(**it); + } + if (children.empty()) { + return std::make_shared(Location { template_str, 0 }, std::string()); + } else if (children.size() == 1) { + return std::move(children[0]); + } else { + return std::make_shared(children[0]->location(), std::move(children)); + } + } + +public: + + static std::shared_ptr parse(const std::string& template_str, const Options & options) { + Parser parser(std::make_shared(normalize_newlines(template_str)), options); + auto tokens = parser.tokenize(); + TemplateTokenIterator begin = tokens.begin(); + auto it = begin; + TemplateTokenIterator end = tokens.end(); + return parser.parseTemplate(begin, it, end, /* full= */ true); + } +}; + +static Value simple_function(const std::string & fn_name, const std::vector & params, const std::function &, Value & args)> & fn) { + std::map named_positions; + for (size_t i = 0, n = params.size(); i < n; i++) named_positions[params[i]] = i; + + return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) -> Value { + auto args_obj = Value::object(); + std::vector provided_args(params.size()); + for (size_t i = 0, n = args.args.size(); i < n; i++) { + auto & arg = args.args[i]; + if (i < params.size()) { + args_obj.set(params[i], arg); + provided_args[i] = true; + } else { + throw std::runtime_error("Too many positional params for " + fn_name); + } + } + for (auto & [name, value] : args.kwargs) { + auto named_pos_it = named_positions.find(name); + if (named_pos_it == named_positions.end()) { + throw std::runtime_error("Unknown argument " + name + " for function " + fn_name); + } + provided_args[named_pos_it->second] = true; + args_obj.set(name, value); + } + return fn(context, args_obj); + }); +} + +inline std::shared_ptr Context::builtins() { + auto globals = Value::object(); + + globals.set("raise_exception", simple_function("raise_exception", { "message" }, [](const std::shared_ptr &, Value & args) -> Value { + throw std::runtime_error(args.at("message").get()); + })); + globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr &, Value & args) { + return Value(args.at("value").dump(args.get("indent", -1), /* tojson= */ true)); + })); + globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr &, Value & args) { + auto items = Value::array(); + if (args.contains("object")) { + auto & obj = args.at("object"); + if (obj.is_string()) { + auto json_obj = json::parse(obj.get()); + for (const auto & kv : json_obj.items()) { + items.push_back(Value::array({kv.key(), kv.value()})); + } + } else if (!obj.is_null()) { + for (auto & key : obj.keys()) { + items.push_back(Value::array({key, obj.at(key)})); + } + } + } + return items; + })); + globals.set("last", simple_function("last", { "items" }, [](const std::shared_ptr &, Value & args) { + auto items = args.at("items"); + if (!items.is_array()) throw std::runtime_error("object is not a list"); + if (items.size() == 0) return Value(); + return items.at(items.size() - 1); + })); + globals.set("trim", simple_function("trim", { "text" }, [](const std::shared_ptr &, Value & args) { + auto & text = args.at("text"); + return text.is_null() ? text : Value(strip(text.get())); + })); + globals.set("lower", simple_function("lower", { "text" }, [](const std::shared_ptr &, Value & args) { + auto text = args.at("text"); + if (text.is_null()) return text; + std::string res; + auto str = text.get(); + std::transform(str.begin(), str.end(), std::back_inserter(res), ::tolower); + return Value(res); + })); + globals.set("default", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { + args.expectArgs("default", {2, 3}, {0, 1}); + auto & value = args.args[0]; + auto & default_value = args.args[1]; + bool boolean = false; + if (args.args.size() == 3) { + boolean = args.args[2].get(); + } else { + Value bv = args.get_named("boolean"); + if (!bv.is_null()) { + boolean = bv.get(); + } + } + return boolean ? (value.to_bool() ? value : default_value) : value.is_null() ? default_value : value; + })); + auto escape = simple_function("escape", { "text" }, [](const std::shared_ptr &, Value & args) { + return Value(html_escape(args.at("text").get())); + }); + globals.set("e", escape); + globals.set("escape", escape); + globals.set("joiner", simple_function("joiner", { "sep" }, [](const std::shared_ptr &, Value & args) { + auto sep = args.get("sep", ""); + auto first = std::make_shared(true); + return simple_function("", {}, [sep, first](const std::shared_ptr &, const Value &) -> Value { + if (*first) { + *first = false; + return ""; + } + return sep; + }); + return Value(html_escape(args.at("text").get())); + })); + globals.set("count", simple_function("count", { "items" }, [](const std::shared_ptr &, Value & args) { + return Value((int64_t) args.at("items").size()); + })); + globals.set("dictsort", simple_function("dictsort", { "value" }, [](const std::shared_ptr &, Value & args) { + if (args.size() != 1) throw std::runtime_error("dictsort expects exactly 1 argument (TODO: fix implementation)"); + auto & value = args.at("value"); + auto keys = value.keys(); + std::sort(keys.begin(), keys.end()); + auto res = Value::array(); + for (auto & key : keys) { + res.push_back(Value::array({key, value.at(key)})); + } + return res; + })); + globals.set("join", simple_function("join", { "items", "d" }, [](const std::shared_ptr &, Value & args) { + auto do_join = [](Value & items, const std::string & sep) { + std::ostringstream oss; + auto first = true; + for (size_t i = 0, n = items.size(); i < n; ++i) { + if (first) first = false; + else oss << sep; + oss << items.at(i).to_str(); + } + return Value(oss.str()); + }; + auto sep = args.get("d", ""); + if (args.contains("items")) { + auto & items = args.at("items"); + return do_join(items, sep); + } else { + return simple_function("", {"items"}, [sep, do_join](const std::shared_ptr &, Value & args) { + auto & items = args.at("items"); + if (!items.to_bool() || !items.is_array()) throw std::runtime_error("join expects an array for items, got: " + items.dump()); + return do_join(items, sep); + }); + } + })); + globals.set("namespace", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { + auto ns = Value::object(); + args.expectArgs("namespace", {0, 0}, {0, (std::numeric_limits::max)()}); + for (auto & [name, value] : args.kwargs) { + ns.set(name, value); + } + return ns; + })); + auto equalto = simple_function("equalto", { "expected", "actual" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("actual") == args.at("expected"); + }); + globals.set("equalto", equalto); + globals.set("==", equalto); + globals.set("length", simple_function("length", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { + auto & items = args.at("items"); + return (int64_t) items.size(); + })); + globals.set("safe", simple_function("safe", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("value").to_str(); + })); + globals.set("string", simple_function("string", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("value").to_str(); + })); + globals.set("int", simple_function("int", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("value").to_int(); + })); + globals.set("list", simple_function("list", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { + auto & items = args.at("items"); + if (!items.is_array()) throw std::runtime_error("object is not iterable"); + return items; + })); + globals.set("unique", simple_function("unique", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { + auto & items = args.at("items"); + if (!items.is_array()) throw std::runtime_error("object is not iterable"); + std::unordered_set seen; + auto result = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto pair = seen.insert(items.at(i)); + if (pair.second) { + result.push_back(items.at(i)); + } + } + return result; + })); + auto make_filter = [](const Value & filter, Value & extra_args) -> Value { + return simple_function("", { "value" }, [=](const std::shared_ptr & context, Value & args) { + auto & value = args.at("value"); + ArgumentsValue actual_args; + actual_args.args.emplace_back(value); + for (size_t i = 0, n = extra_args.size(); i < n; i++) { + actual_args.args.emplace_back(extra_args.at(i)); + } + return filter.call(context, actual_args); + }); + }; + auto select_or_reject = [make_filter](bool is_select) { + return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { + args.expectArgs(is_select ? "select" : "reject", {2, (std::numeric_limits::max)()}, {0, 0}); + auto & items = args.args[0]; + auto filter_fn = context->get(args.args[1]); + if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); + + auto filter_args = Value::array(); + for (size_t i = 2, n = args.args.size(); i < n; i++) { + filter_args.push_back(args.args[i]); + } + auto filter = make_filter(filter_fn, filter_args); + + auto res = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto & item = items.at(i); + ArgumentsValue filter_args; + filter_args.args.emplace_back(item); + auto pred_res = filter.call(context, filter_args); + if (pred_res.to_bool() == (is_select ? true : false)) { + res.push_back(item); + } + } + return res; + }); + }; + globals.set("select", select_or_reject(/* is_select= */ true)); + globals.set("reject", select_or_reject(/* is_select= */ false)); + globals.set("map", Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { + auto res = Value::array(); + if (args.args.size() == 1 && + ((args.has_named("attribute") && args.kwargs.size() == 1) || (args.has_named("default") && args.kwargs.size() == 2))) { + auto & items = args.args[0]; + auto attr_name = args.get_named("attribute"); + auto default_value = args.get_named("default"); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto & item = items.at(i); + auto attr = item.get(attr_name); + res.push_back(attr.is_null() ? default_value : attr); + } + } else if (args.kwargs.empty() && args.args.size() >= 2) { + auto fn = context->get(args.args[1]); + if (fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); + ArgumentsValue filter_args { {Value()}, {} }; + for (size_t i = 2, n = args.args.size(); i < n; i++) { + filter_args.args.emplace_back(args.args[i]); + } + for (size_t i = 0, n = args.args[0].size(); i < n; i++) { + auto & item = args.args[0].at(i); + filter_args.args[0] = item; + res.push_back(fn.call(context, filter_args)); + } + } else { + throw std::runtime_error("Invalid or unsupported arguments for map"); + } + return res; + })); + globals.set("indent", simple_function("indent", { "text", "indent", "first" }, [](const std::shared_ptr &, Value & args) { + auto text = args.at("text").get(); + auto first = args.get("first", false); + std::string out; + std::string indent(args.get("indent", 0), ' '); + std::istringstream iss(text); + std::string line; + auto is_first = true; + while (std::getline(iss, line, '\n')) { + auto needs_indent = !is_first || first; + if (is_first) is_first = false; + else out += "\n"; + if (needs_indent) out += indent; + out += line; + } + if (!text.empty() && text.back() == '\n') out += "\n"; + return out; + })); + auto select_or_reject_attr = [](bool is_select) { + return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { + args.expectArgs(is_select ? "selectattr" : "rejectattr", {2, (std::numeric_limits::max)()}, {0, 0}); + auto & items = args.args[0]; + if (items.is_null()) + return Value::array(); + auto attr_name = args.args[1].get(); + + bool has_test = false; + Value test_fn; + ArgumentsValue test_args {{Value()}, {}}; + if (args.args.size() >= 3) { + has_test = true; + test_fn = context->get(args.args[2]); + if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump()); + for (size_t i = 3, n = args.args.size(); i < n; i++) { + test_args.args.emplace_back(args.args[i]); + } + test_args.kwargs = args.kwargs; + } + + auto res = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto & item = items.at(i); + auto attr = item.get(attr_name); + if (has_test) { + test_args.args[0] = attr; + if (test_fn.call(context, test_args).to_bool() == (is_select ? true : false)) { + res.push_back(item); + } + } else { + res.push_back(attr); + } + } + return res; + }); + }; + globals.set("selectattr", select_or_reject_attr(/* is_select= */ true)); + globals.set("rejectattr", select_or_reject_attr(/* is_select= */ false)); + globals.set("range", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { + std::vector startEndStep(3); + std::vector param_set(3); + if (args.args.size() == 1) { + startEndStep[1] = args.args[0].get(); + param_set[1] = true; + } else { + for (size_t i = 0; i < args.args.size(); i++) { + auto & arg = args.args[i]; + auto v = arg.get(); + startEndStep[i] = v; + param_set[i] = true; + } + } + for (auto & [name, value] : args.kwargs) { + size_t i; + if (name == "start") i = 0; + else if (name == "end") i = 1; + else if (name == "step") i = 2; + else throw std::runtime_error("Unknown argument " + name + " for function range"); + + if (param_set[i]) { + throw std::runtime_error("Duplicate argument " + name + " for function range"); + } + startEndStep[i] = value.get(); + param_set[i] = true; + } + if (!param_set[1]) { + throw std::runtime_error("Missing required argument 'end' for function range"); + } + int64_t start = param_set[0] ? startEndStep[0] : 0; + int64_t end = startEndStep[1]; + int64_t step = param_set[2] ? startEndStep[2] : 1; + + auto res = Value::array(); + if (step > 0) { + for (int64_t i = start; i < end; i += step) { + res.push_back(Value(i)); + } + } else { + for (int64_t i = start; i > end; i += step) { + res.push_back(Value(i)); + } + } + return res; + })); + + return std::make_shared(std::move(globals)); +} + +inline std::shared_ptr Context::make(Value && values, const std::shared_ptr & parent) { + return std::make_shared(values.is_null() ? Value::object() : std::move(values), parent); +} + +} // namespace minja diff --git a/common/ngram-cache.cpp b/common/ngram-cache.cpp index 3ca112ef1..a057ae45f 100644 --- a/common/ngram-cache.cpp +++ b/common/ngram-cache.cpp @@ -2,10 +2,13 @@ #include "common.h" #include "log.h" +#include #include +#include #include +#include -void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, int ngram_max, +void common_ngram_cache_update(common_ngram_cache & ngram_cache, int ngram_min, int ngram_max, std::vector & inp, int nnew, bool print_progress) { const int64_t t_start_ms = ggml_time_ms(); const int64_t inp_size = inp.size(); @@ -17,16 +20,16 @@ void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, in const int64_t i_start = std::max(inp_size - nnew, ngram_size); for (int64_t i = i_start; i < inp_size; ++i) { const int64_t ngram_start = i - ngram_size; - llama_ngram ngram(&inp[ngram_start], ngram_size); + common_ngram ngram(&inp[ngram_start], ngram_size); const llama_token token = inp[i]; - llama_ngram_cache::iterator part_it = ngram_cache.find(ngram); + common_ngram_cache::iterator part_it = ngram_cache.find(ngram); if (part_it == ngram_cache.end()) { - llama_ngram_cache_part part; + common_ngram_cache_part part; part.emplace(token, 1); ngram_cache.emplace(ngram, part); } else { - llama_ngram_cache_part::iterator token_count_it = part_it->second.find(token); + common_ngram_cache_part::iterator token_count_it = part_it->second.find(token); if (token_count_it == part_it->second.end()) { part_it->second.emplace(token, 1); } else { @@ -59,16 +62,16 @@ constexpr int draft_min_sample_size_strict[LLAMA_NGRAM_MAX] = { 4, 3, 2, 2}; constexpr int draft_min_percent_strict[LLAMA_NGRAM_MAX] = {75, 66, 66, 66}; // Helper function that tries to draft a token from only the static ngram cache: -static llama_token try_draft(llama_ngram_cache & nc_static, const llama_ngram ngram_static) { - llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static); +static llama_token try_draft(common_ngram_cache & nc_static, const common_ngram ngram_static) { + common_ngram_cache::iterator part_static_it = nc_static.find(ngram_static); if (part_static_it == nc_static.end()) { - return -1; + return LLAMA_TOKEN_NULL; } - const llama_ngram_cache_part part_static = part_static_it->second; + const common_ngram_cache_part part_static = part_static_it->second; int max_count_static = 0; int sum_count_static = 0; - llama_token max_token = -1; + llama_token max_token = LLAMA_TOKEN_NULL; for (std::pair token_count_static : part_static) { const llama_token token = token_count_static.first; @@ -82,39 +85,39 @@ static llama_token try_draft(llama_ngram_cache & nc_static, const llama_ngram ng } if (sum_count_static < draft_min_sample_size_lax[LLAMA_NGRAM_STATIC-1]) { - return -1; + return LLAMA_TOKEN_NULL; } if (100*max_count_static < draft_min_percent_lax[LLAMA_NGRAM_STATIC-1]*sum_count_static) { - return -1; + return LLAMA_TOKEN_NULL; } return max_token; } // Try to draft a token from primary cache (context/dynamic), validate with static cache: static llama_token try_draft( - llama_ngram_cache & nc_primary, const std::vector & ngrams_primary, llama_ngram_cache_part & part_static, + common_ngram_cache & nc_primary, const std::vector & ngrams_primary, common_ngram_cache_part & part_static, const int * min_sample_size, const int * min_percent) { - llama_token drafted_token = -1; + llama_token drafted_token = LLAMA_TOKEN_NULL; - for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == -1; --i) { - const llama_ngram ngram_primary = ngrams_primary[i]; + for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == LLAMA_TOKEN_NULL; --i) { + const common_ngram ngram_primary = ngrams_primary[i]; - llama_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary); + common_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary); if (part_primary_it == nc_primary.end()) { continue; } - const llama_ngram_cache_part part_primary = part_primary_it->second; + const common_ngram_cache_part part_primary = part_primary_it->second; int max_count_primary = 0; int max_count_static = 0; int sum_count_primary = 0; - llama_token max_token = -1; + llama_token max_token = LLAMA_TOKEN_NULL; for (std::pair token_count_primary : part_primary) { const llama_token token = token_count_primary.first; - llama_ngram_cache_part::iterator token_count_static_it = part_static.find(token); + common_ngram_cache_part::iterator token_count_static_it = part_static.find(token); const int32_t count_primary = token_count_primary.second; const int32_t count_static = token_count_static_it != part_static.end() ? 100*token_count_static_it->second : 1; @@ -139,9 +142,9 @@ static llama_token try_draft( return drafted_token; } -void llama_ngram_cache_draft( +void common_ngram_cache_draft( std::vector & inp, std::vector & draft, int n_draft, int ngram_min, int ngram_max, - llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static + common_ngram_cache & nc_context, common_ngram_cache & nc_dynamic, common_ngram_cache & nc_static ) { GGML_ASSERT(draft.size() == 1); const int inp_size = inp.size(); @@ -151,40 +154,40 @@ void llama_ngram_cache_draft( } while ((int) draft.size()-1 < n_draft) { - llama_token drafted_token = -1; + llama_token drafted_token = LLAMA_TOKEN_NULL; const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1; - llama_ngram ngram_static; + common_ngram ngram_static; for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) { ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j); } - llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static); - llama_ngram_cache_part part_static; + common_ngram_cache::iterator part_static_it = nc_static.find(ngram_static); + common_ngram_cache_part part_static; if (part_static_it != nc_static.end()) { part_static = part_static_it->second; } // cd = context + dynamic - std::vector ngrams_cd; + std::vector ngrams_cd; for (int ngram_size_cd = ngram_min; ngram_size_cd <= ngram_max; ++ngram_size_cd) { const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1; - llama_ngram ngram_cd; + common_ngram ngram_cd; for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) { ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j); } ngrams_cd.push_back(ngram_cd); } - if (drafted_token == -1) { + if (drafted_token == LLAMA_TOKEN_NULL) { drafted_token = try_draft(nc_context, ngrams_cd, part_static, draft_min_sample_size_lax, draft_min_percent_lax); } - if (drafted_token == -1) { + if (drafted_token == LLAMA_TOKEN_NULL) { drafted_token = try_draft(nc_dynamic, ngrams_cd, part_static, draft_min_sample_size_strict, draft_min_percent_strict); } - if (drafted_token == -1) { + if (drafted_token == LLAMA_TOKEN_NULL) { drafted_token = try_draft(nc_static, ngram_static); } - if (drafted_token == -1) { + if (drafted_token == LLAMA_TOKEN_NULL) { break; } @@ -193,16 +196,16 @@ void llama_ngram_cache_draft( } } -void llama_ngram_cache_save(llama_ngram_cache & ngram_cache, std::string & filename) { +void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename) { std::ofstream file_out(filename, std::ios::binary); - for (std::pair item : ngram_cache) { - const llama_ngram ngram = item.first; - llama_ngram_cache_part token_counts = item.second; + for (std::pair item : ngram_cache) { + const common_ngram ngram = item.first; + common_ngram_cache_part token_counts = item.second; GGML_ASSERT(!token_counts.empty()); const int32_t ntokens = token_counts.size(); GGML_ASSERT(ntokens > 0); - file_out.write(reinterpret_cast(&ngram), sizeof(llama_ngram)); + file_out.write(reinterpret_cast(&ngram), sizeof(common_ngram)); file_out.write(reinterpret_cast(&ntokens), sizeof(int32_t)); for (std::pair item2 : token_counts) { const llama_token token = item2.first; @@ -216,14 +219,14 @@ void llama_ngram_cache_save(llama_ngram_cache & ngram_cache, std::string & filen } -llama_ngram_cache llama_ngram_cache_load(std::string & filename) { +common_ngram_cache common_ngram_cache_load(std::string & filename) { std::ifstream hashmap_file(filename, std::ios::binary); if (!hashmap_file) { throw std::ifstream::failure("Unable to open file " + filename); } - llama_ngram_cache ngram_cache; + common_ngram_cache ngram_cache; - llama_ngram ngram; + common_ngram ngram; int32_t ntokens; llama_token token; int32_t count; @@ -232,11 +235,11 @@ llama_ngram_cache llama_ngram_cache_load(std::string & filename) { char * ntokensc = reinterpret_cast(&ntokens); char * tokenc = reinterpret_cast(&token); char * countc = reinterpret_cast(&count); - while(hashmap_file.read(ngramc, sizeof(llama_ngram))) { + while(hashmap_file.read(ngramc, sizeof(common_ngram))) { GGML_ASSERT(!hashmap_file.eof()); GGML_ASSERT(hashmap_file.read(ntokensc, sizeof(int32_t))); GGML_ASSERT(ntokens > 0); - llama_ngram_cache_part token_counts; + common_ngram_cache_part token_counts; for (int i = 0; i < ntokens; ++i) { GGML_ASSERT(!hashmap_file.eof()); @@ -254,12 +257,12 @@ llama_ngram_cache llama_ngram_cache_load(std::string & filename) { return ngram_cache; } -void llama_ngram_cache_merge(llama_ngram_cache & ngram_cache_target, llama_ngram_cache & ngram_cache_add) { - for (std::pair ngram_part : ngram_cache_add) { - const llama_ngram ngram = ngram_part.first; - llama_ngram_cache_part part = ngram_part.second; +void common_ngram_cache_merge(common_ngram_cache & ngram_cache_target, common_ngram_cache & ngram_cache_add) { + for (std::pair ngram_part : ngram_cache_add) { + const common_ngram ngram = ngram_part.first; + common_ngram_cache_part part = ngram_part.second; - llama_ngram_cache::iterator part_merged_it = ngram_cache_target.find(ngram); + common_ngram_cache::iterator part_merged_it = ngram_cache_target.find(ngram); if (part_merged_it == ngram_cache_target.end()) { ngram_cache_target.emplace(ngram, part); continue; @@ -270,7 +273,7 @@ void llama_ngram_cache_merge(llama_ngram_cache & ngram_cache_target, llama_ngram const int32_t count = token_count.second; GGML_ASSERT(count > 0); - llama_ngram_cache_part::iterator token_count_merged_it = part_merged_it->second.find(token); + common_ngram_cache_part::iterator token_count_merged_it = part_merged_it->second.find(token); if (token_count_merged_it == part_merged_it->second.end()) { part_merged_it->second.emplace(token, count); continue; diff --git a/common/ngram-cache.h b/common/ngram-cache.h index ab4c9b376..dfe012abe 100644 --- a/common/ngram-cache.h +++ b/common/ngram-cache.h @@ -12,22 +12,22 @@ // Data structures to map n-grams to empirical token probabilities: -struct llama_ngram { +struct common_ngram { llama_token tokens[LLAMA_NGRAM_MAX]; - llama_ngram() { + common_ngram() { for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) { - tokens[i] = -1; + tokens[i] = LLAMA_TOKEN_NULL; } } - llama_ngram(const llama_token * input, const int ngram_size) { + common_ngram(const llama_token * input, const int ngram_size) { for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) { - tokens[i] = i < ngram_size ? input[i] : -1; + tokens[i] = i < ngram_size ? input[i] : LLAMA_TOKEN_NULL; } } - bool operator==(const llama_ngram & other) const { + bool operator==(const common_ngram & other) const { for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) { if (tokens[i] != other.tokens[i]) { return false; @@ -37,28 +37,28 @@ struct llama_ngram { } }; -struct llama_token_hash_function { +struct common_token_hash_function { size_t operator()(const llama_token token) const { // see https://probablydance.com/2018/06/16/fibonacci-hashing-the-optimization-that-the-world-forgot-or-a-better-alternative-to-integer-modulo/ return token * 11400714819323198485llu; } }; -struct llama_ngram_hash_function { - size_t operator()(const llama_ngram & ngram) const { - size_t hash = llama_token_hash_function{}(ngram.tokens[0]); +struct common_ngram_hash_function { + size_t operator()(const common_ngram & ngram) const { + size_t hash = common_token_hash_function{}(ngram.tokens[0]); for (int i = 1; i < LLAMA_NGRAM_MAX; ++i) { - hash ^= llama_token_hash_function{}(ngram.tokens[i]); + hash ^= common_token_hash_function{}(ngram.tokens[i]); } return hash; } }; // token -> number of times token has been seen -typedef std::unordered_map llama_ngram_cache_part; +typedef std::unordered_map common_ngram_cache_part; // n-gram -> empirical distribution of following tokens -typedef std::unordered_map llama_ngram_cache; +typedef std::unordered_map common_ngram_cache; // Update an ngram cache with tokens. @@ -70,8 +70,8 @@ typedef std::unordered_map & inp_data, int nnew, bool print_progress); +void common_ngram_cache_update( + common_ngram_cache & ngram_cache, int ngram_min, int ngram_max, std::vector & inp_data, int nnew, bool print_progress); // Try to draft tokens from ngram caches. // inp: the tokens generated so far. @@ -81,21 +81,21 @@ void llama_ngram_cache_update( // nc_context: ngram cache based on current context. // nc_dynamic: ngram cache based on previous user generations. // nc_static: ngram cache generated from a large text corpus, used for validation. -void llama_ngram_cache_draft( +void common_ngram_cache_draft( std::vector & inp, std::vector & draft, int n_draft, int ngram_min, int ngram_max, - llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static); + common_ngram_cache & nc_context, common_ngram_cache & nc_dynamic, common_ngram_cache & nc_static); // Save an ngram cache to a file. // ngram_cache: the ngram cache to save. // filename: the path under which to save the ngram cache. -void llama_ngram_cache_save(llama_ngram_cache & ngram_cache, std::string & filename); +void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename); -// Load an ngram cache saved with llama_ngram_cache_save. +// Load an ngram cache saved with common_ngram_cache_save. // filename: the path from which to load the ngram cache. // returns: an ngram cache containing the information saved to filename. -llama_ngram_cache llama_ngram_cache_load(std::string & filename); +common_ngram_cache common_ngram_cache_load(std::string & filename); // Merge two ngram caches. // ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add. // ngram_cache_add: the ngram cache to add to ngram_cache_target. -void llama_ngram_cache_merge(llama_ngram_cache & ngram_cache_target, llama_ngram_cache & ngram_cache_add); +void common_ngram_cache_merge(common_ngram_cache & ngram_cache_target, common_ngram_cache & ngram_cache_add); diff --git a/common/sampling.cpp b/common/sampling.cpp index 7806b77e0..bc7e49fdb 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -2,6 +2,9 @@ #include "common.h" +#include +#include + // the ring buffer works similarly to std::deque, but with a fixed capacity // TODO: deduplicate with llama-impl.h template @@ -95,8 +98,8 @@ struct ring_buffer { std::vector data; }; -struct gpt_sampler { - gpt_sampler_params params; +struct common_sampler { + common_params_sampling params; struct llama_sampler * grmr; struct llama_sampler * chain; @@ -110,7 +113,10 @@ struct gpt_sampler { void set_logits(struct llama_context * ctx, int idx) { const auto * logits = llama_get_logits_ith(ctx, idx); - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_vocab = llama_vocab_n_tokens(vocab); cur.resize(n_vocab); @@ -122,28 +128,41 @@ struct gpt_sampler { } }; -std::string gpt_sampler_params::print() const { +std::string common_params_sampling::print() const { char result[1024]; snprintf(result, sizeof(result), "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" - "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n" + "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" + "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n" "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", penalty_last_n, penalty_repeat, penalty_freq, penalty_present, - top_k, tfs_z, top_p, min_p, typ_p, temp, + dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, + top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp, mirostat, mirostat_eta, mirostat_tau); return std::string(result); } -struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) { +struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) { + const llama_vocab * vocab = llama_model_get_vocab(model); + llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); - lparams.no_perf = false; // TODO: control via params + lparams.no_perf = params.no_perf; - auto * result = new gpt_sampler { + std::vector trigger_words; + trigger_words.reserve(params.grammar_trigger_words.size()); + for (const auto & str : params.grammar_trigger_words) { + trigger_words.push_back(str.word.c_str()); + } + auto * result = new common_sampler { /* .params = */ params, - /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"), + /* .grmr = */ params.grammar_lazy + ? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root", + trigger_words.data(), trigger_words.size(), + params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size()) + : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"), /* .chain = */ llama_sampler_chain_init(lparams), /* .prev = */ ring_buffer(std::max(32, params.n_prev)), /* .cur = */ {}, @@ -152,68 +171,67 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st llama_sampler_chain_add(result->chain, llama_sampler_init_logit_bias( - llama_n_vocab(model), + llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data())); - llama_sampler_chain_add(result->chain, - llama_sampler_init_penalties( - llama_n_vocab (model), - llama_token_eos(model), - llama_token_nl (model), - params.penalty_last_n, - params.penalty_repeat, - params.penalty_freq, - params.penalty_present, - params.penalize_nl, - params.ignore_eos)); + if (params.mirostat == 0) { + for (const auto & cnstr : params.samplers) { + switch (cnstr) { + case COMMON_SAMPLER_TYPE_DRY: + { + std::vector c_breakers; + c_breakers.reserve(params.dry_sequence_breakers.size()); + for (const auto & str : params.dry_sequence_breakers) { + c_breakers.push_back(str.c_str()); + } - if (params.temp > 0.0f) { - if (params.mirostat == 0) { - for (const auto & cnstr : params.samplers) { - switch (cnstr) { - case GPT_SAMPLER_TYPE_TOP_K: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); - break; - case GPT_SAMPLER_TYPE_TOP_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); - break; - case GPT_SAMPLER_TYPE_MIN_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); - break; - case GPT_SAMPLER_TYPE_TFS_Z: - llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep)); - break; - case GPT_SAMPLER_TYPE_TYPICAL_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); - break; - case GPT_SAMPLER_TYPE_TEMPERATURE: - llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); - break; - default: - GGML_ASSERT(false && "unknown sampler type"); - } + llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); + } + break; + case COMMON_SAMPLER_TYPE_TOP_K: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); + break; + case COMMON_SAMPLER_TYPE_TOP_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_MIN_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_XTC: + llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); + break; + case COMMON_SAMPLER_TYPE_TYPICAL_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_TEMPERATURE: + llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + break; + case COMMON_SAMPLER_TYPE_INFILL: + llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab)); + break; + case COMMON_SAMPLER_TYPE_PENALTIES: + llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + break; + default: + GGML_ASSERT(false && "unknown sampler type"); } - llama_sampler_chain_add(result->chain, llama_sampler_init_softmax()); - llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); - } else if (params.mirostat == 1) { - llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); - llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); - } else if (params.mirostat == 2) { - llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); - llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); - } else { - GGML_ASSERT(false && "unknown mirostat version"); } + llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); + } else if (params.mirostat == 1) { + llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); + llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); + } else if (params.mirostat == 2) { + llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); + llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); } else { - llama_sampler_chain_add(result->chain, llama_sampler_init_softmax()); - llama_sampler_chain_add(result->chain, llama_sampler_init_greedy()); + GGML_ASSERT(false && "unknown mirostat version"); } return result; } -void gpt_sampler_free(struct gpt_sampler * gsmpl) { +void common_sampler_free(struct common_sampler * gsmpl) { if (gsmpl) { llama_sampler_free(gsmpl->grmr); @@ -223,7 +241,7 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) { } } -void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar) { +void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) { if (accept_grammar) { llama_sampler_accept(gsmpl->grmr, token); } @@ -233,14 +251,14 @@ void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool acce gsmpl->prev.push_back(token); } -void gpt_sampler_reset(struct gpt_sampler * gsmpl) { +void common_sampler_reset(struct common_sampler * gsmpl) { llama_sampler_reset(gsmpl->grmr); llama_sampler_reset(gsmpl->chain); } -struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) { - return new gpt_sampler { +struct common_sampler * common_sampler_clone(common_sampler * gsmpl) { + return new common_sampler { /* .params = */ gsmpl->params, /* .grmr = */ llama_sampler_clone(gsmpl->grmr), /* .chain = */ llama_sampler_clone(gsmpl->chain), @@ -250,18 +268,18 @@ struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) { }; } -void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl) { +void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) { // TODO: measure grammar performance if (gsmpl) { - llama_perf_print(gsmpl->chain, LLAMA_PERF_TYPE_SAMPLER_CHAIN); + llama_perf_sampler_print(gsmpl->chain); } if (ctx) { - llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); + llama_perf_context_print(ctx); } } -llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { +llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { gsmpl->set_logits(ctx, idx); auto & grmr = gsmpl->grmr; @@ -307,18 +325,61 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context return cur_p.data[cur_p.selected].id; } +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first) { + GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); + + std::vector result; + result.reserve(idxs.size()); + + size_t i = 0; + for (; i < draft.size(); i++) { + const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); + + common_sampler_accept(gsmpl, id, true); + + result.push_back(id); + + if (draft[i] != id) { + break; + } + } + + if (i == draft.size()) { + const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); + + common_sampler_accept(gsmpl, id, true); + + result.push_back(id); + } + + return result; +} + +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) { + std::vector idxs(draft.size() + 1); + for (size_t i = 0; i < idxs.size(); ++i) { + idxs[i] = i; + } + + return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first); +} + +uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { + return llama_sampler_get_seed(gsmpl->chain); +} + // helpers -llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) { +llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) { return &gsmpl->cur_p; } -llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) { +llama_token common_sampler_last(const struct common_sampler * gsmpl) { return gsmpl->prev.rat(0); } -std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) { - std::string result = "\tlogits "; +std::string common_sampler_print(const struct common_sampler * gsmpl) { + std::string result = "logits "; for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) { const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i); @@ -328,7 +389,7 @@ std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) { return result; } -std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) { +std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx_main, int n) { n = std::min(n, (int) gsmpl->prev.size()); if (n <= 0) { @@ -343,63 +404,70 @@ std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen"); - result += llama_token_to_piece(ctx_main, id); + result += common_token_to_piece(ctx_main, id); } return result; } -char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr) { +char common_sampler_type_to_chr(enum common_sampler_type cnstr) { switch (cnstr) { - case GPT_SAMPLER_TYPE_TOP_K: return 'k'; - case GPT_SAMPLER_TYPE_TFS_Z: return 'f'; - case GPT_SAMPLER_TYPE_TYPICAL_P: return 'y'; - case GPT_SAMPLER_TYPE_TOP_P: return 'p'; - case GPT_SAMPLER_TYPE_MIN_P: return 'm'; - case GPT_SAMPLER_TYPE_TEMPERATURE: return 't'; + case COMMON_SAMPLER_TYPE_DRY: return 'd'; + case COMMON_SAMPLER_TYPE_TOP_K: return 'k'; + case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y'; + case COMMON_SAMPLER_TYPE_TOP_P: return 'p'; + case COMMON_SAMPLER_TYPE_MIN_P: return 'm'; + case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't'; + case COMMON_SAMPLER_TYPE_XTC: return 'x'; + case COMMON_SAMPLER_TYPE_INFILL: return 'i'; + case COMMON_SAMPLER_TYPE_PENALTIES: return 'e'; default : return '?'; } } -std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr) { +std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { switch (cnstr) { - case GPT_SAMPLER_TYPE_TOP_K: return "top_k"; - case GPT_SAMPLER_TYPE_TFS_Z: return "tfs_z"; - case GPT_SAMPLER_TYPE_TYPICAL_P: return "typ_p"; - case GPT_SAMPLER_TYPE_TOP_P: return "top_p"; - case GPT_SAMPLER_TYPE_MIN_P: return "min_p"; - case GPT_SAMPLER_TYPE_TEMPERATURE: return "temperature"; + case COMMON_SAMPLER_TYPE_DRY: return "dry"; + case COMMON_SAMPLER_TYPE_TOP_K: return "top_k"; + case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p"; + case COMMON_SAMPLER_TYPE_TOP_P: return "top_p"; + case COMMON_SAMPLER_TYPE_MIN_P: return "min_p"; + case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature"; + case COMMON_SAMPLER_TYPE_XTC: return "xtc"; + case COMMON_SAMPLER_TYPE_INFILL: return "infill"; + case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties"; default : return ""; } } -std::vector gpt_sampler_types_from_names(const std::vector & names, bool allow_alt_names) { - std::unordered_map sampler_canonical_name_map { - { "top_k", GPT_SAMPLER_TYPE_TOP_K }, - { "top_p", GPT_SAMPLER_TYPE_TOP_P }, - { "typ_p", GPT_SAMPLER_TYPE_TYPICAL_P }, - { "min_p", GPT_SAMPLER_TYPE_MIN_P }, - { "tfs_z", GPT_SAMPLER_TYPE_TFS_Z }, - { "temperature", GPT_SAMPLER_TYPE_TEMPERATURE }, +std::vector common_sampler_types_from_names(const std::vector & names, bool allow_alt_names) { + std::unordered_map sampler_canonical_name_map { + { "dry", COMMON_SAMPLER_TYPE_DRY }, + { "top_k", COMMON_SAMPLER_TYPE_TOP_K }, + { "top_p", COMMON_SAMPLER_TYPE_TOP_P }, + { "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P }, + { "min_p", COMMON_SAMPLER_TYPE_MIN_P }, + { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE }, + { "xtc", COMMON_SAMPLER_TYPE_XTC }, + { "infill", COMMON_SAMPLER_TYPE_INFILL }, + { "penalties", COMMON_SAMPLER_TYPE_PENALTIES }, }; // since samplers names are written multiple ways // make it ready for both system names and input names - std::unordered_map sampler_alt_name_map { - { "top-k", GPT_SAMPLER_TYPE_TOP_K }, - { "top-p", GPT_SAMPLER_TYPE_TOP_P }, - { "nucleus", GPT_SAMPLER_TYPE_TOP_P }, - { "typical-p", GPT_SAMPLER_TYPE_TYPICAL_P }, - { "typical", GPT_SAMPLER_TYPE_TYPICAL_P }, - { "typ-p", GPT_SAMPLER_TYPE_TYPICAL_P }, - { "typ", GPT_SAMPLER_TYPE_TYPICAL_P }, - { "min-p", GPT_SAMPLER_TYPE_MIN_P }, - { "tfs-z", GPT_SAMPLER_TYPE_TFS_Z }, - { "tfs", GPT_SAMPLER_TYPE_TFS_Z }, - { "temp", GPT_SAMPLER_TYPE_TEMPERATURE }, + std::unordered_map sampler_alt_name_map { + { "top-k", COMMON_SAMPLER_TYPE_TOP_K }, + { "top-p", COMMON_SAMPLER_TYPE_TOP_P }, + { "nucleus", COMMON_SAMPLER_TYPE_TOP_P }, + { "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P }, + { "typical", COMMON_SAMPLER_TYPE_TYPICAL_P }, + { "typ-p", COMMON_SAMPLER_TYPE_TYPICAL_P }, + { "typ", COMMON_SAMPLER_TYPE_TYPICAL_P }, + { "min-p", COMMON_SAMPLER_TYPE_MIN_P }, + { "temp", COMMON_SAMPLER_TYPE_TEMPERATURE }, }; - std::vector samplers; + std::vector samplers; samplers.reserve(names.size()); for (const auto & name : names) { @@ -419,17 +487,20 @@ std::vector gpt_sampler_types_from_names(const std::vector gpt_sampler_types_from_chars(const std::string & chars) { - std::unordered_map sampler_name_map { - { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_K), GPT_SAMPLER_TYPE_TOP_K }, - { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TFS_Z), GPT_SAMPLER_TYPE_TFS_Z }, - { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TYPICAL_P), GPT_SAMPLER_TYPE_TYPICAL_P }, - { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_P), GPT_SAMPLER_TYPE_TOP_P }, - { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_MIN_P), GPT_SAMPLER_TYPE_MIN_P }, - { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TEMPERATURE), GPT_SAMPLER_TYPE_TEMPERATURE } +std::vector common_sampler_types_from_chars(const std::string & chars) { + std::unordered_map sampler_name_map = { + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_DRY), COMMON_SAMPLER_TYPE_DRY }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES }, }; - std::vector samplers; + std::vector samplers; samplers.reserve(chars.size()); for (const auto & c : chars) { diff --git a/common/sampling.h b/common/sampling.h index 654e0c513..348911b18 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -2,62 +2,12 @@ #include "llama.h" +#include "common.h" + #include #include -enum gpt_sampler_type { - GPT_SAMPLER_TYPE_NONE = 0, - GPT_SAMPLER_TYPE_TOP_K = 1, - GPT_SAMPLER_TYPE_TOP_P = 2, - GPT_SAMPLER_TYPE_MIN_P = 3, - GPT_SAMPLER_TYPE_TFS_Z = 4, - GPT_SAMPLER_TYPE_TYPICAL_P = 5, - GPT_SAMPLER_TYPE_TEMPERATURE = 6, -}; - -// sampling parameters -struct gpt_sampler_params { - uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler - - int32_t n_prev = 64; // number of previous tokens to remember - int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. - int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens - int32_t top_k = 40; // <= 0 to use vocab size - float top_p = 0.95f; // 1.0 = disabled - float min_p = 0.05f; // 0.0 = disabled - float tfs_z = 1.00f; // 1.0 = disabled - float typ_p = 1.00f; // typical_p, 1.0 = disabled - float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities - float dynatemp_range = 0.00f; // 0.0 = disabled - float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler - int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float penalty_repeat = 1.00f; // 1.0 = disabled - float penalty_freq = 0.00f; // 0.0 = disabled - float penalty_present = 0.00f; // 0.0 = disabled - int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau = 5.00f; // target entropy - float mirostat_eta = 0.10f; // learning rate - bool penalize_nl = false; // consider newlines as a repeatable token - bool ignore_eos = false; - - std::vector samplers = { - GPT_SAMPLER_TYPE_TOP_K, - GPT_SAMPLER_TYPE_TFS_Z, - GPT_SAMPLER_TYPE_TYPICAL_P, - GPT_SAMPLER_TYPE_TOP_P, - GPT_SAMPLER_TYPE_MIN_P, - GPT_SAMPLER_TYPE_TEMPERATURE - }; - - std::string grammar; // optional BNF-like grammar to constrain sampling - - std::vector logit_bias; // logit biases to apply - - // print the parameters into a string - std::string print() const; -}; - -// gpt_sampler extends llama_sampler with additional functionality: +// common_sampler extends llama_sampler with additional functionality: // // - grammar support // - custom sampler logic based on the parameters @@ -73,30 +23,30 @@ struct gpt_sampler_params { // token in order to verify if it fits the grammar. And only if the token doesn't fit the grammar, the // grammar constraints are applied to the full vocabulary and the token is resampled. // -// The gpt_sampler also maintains a container with the last accepted tokens. In the future, this can +// The common_sampler also maintains a container with the last accepted tokens. In the future, this can // be moved into the core llama library. // -// For convenience, the gpt_sampler also maintains a container with the current candidate tokens. +// For convenience, the common_sampler also maintains a container with the current candidate tokens. // This can be used to access the probabilities of the rest of the non-sampled tokens. // // TODO: measure grammar performance // -struct gpt_sampler; +struct common_sampler; // llama_sampler API overloads -struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params); +struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params); -void gpt_sampler_free(struct gpt_sampler * gsmpl); +void common_sampler_free(struct common_sampler * gsmpl); // if accept_grammar is true, the token is accepted both by the sampling chain and the grammar -void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar); -void gpt_sampler_reset (struct gpt_sampler * gsmpl); -struct gpt_sampler * gpt_sampler_clone (struct gpt_sampler * gsmpl); +void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar); +void common_sampler_reset (struct common_sampler * gsmpl); +struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl); // arguments can be nullptr to skip printing -void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl); +void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl); // extended sampling implementation: // @@ -108,24 +58,47 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * // if grammar_first is true, the grammar is applied before the samplers (slower) // useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar // -llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false); +llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false); + +// generalized version of common_sampler_sample +// +// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match +// if the sampler disagrees at some point, we stop and return the accepted tokens up to now +// +// common_sampler_sample_n(gsmpl, ctx, { idx }, {}); +// +// is equivalent to +// +// common_sampler_sample(gsmpl, ctx, idx); +// common_sampler_accept(gsmpl, token, true); +// +// requires: idxs.size() == draft.size() + 1 +// +// returns at least 1 token, up to idxs.size() +// +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first = false); + +// assume idxs == [ 0, 1, 2, ..., draft.size() ] +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false); + +uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); // helpers // access the internal list of current candidate tokens -llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl); +llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl); // get the last accepted token -llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl); +llama_token common_sampler_last(const struct common_sampler * gsmpl); // print the sampler chain into a string -std::string gpt_sampler_print(const struct gpt_sampler * gsmpl); +std::string common_sampler_print(const struct common_sampler * gsmpl); // get a string representation of the last accepted tokens -std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n); +std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx, int n); -char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr); -std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr); +char common_sampler_type_to_chr(enum common_sampler_type cnstr); +std::string common_sampler_type_to_str(enum common_sampler_type cnstr); -std::vector gpt_sampler_types_from_names(const std::vector & names, bool allow_alt_names); -std::vector gpt_sampler_types_from_chars(const std::string & chars); +std::vector common_sampler_types_from_names(const std::vector & names, bool allow_alt_names); +std::vector common_sampler_types_from_chars(const std::string & chars); diff --git a/common/speculative.cpp b/common/speculative.cpp new file mode 100644 index 000000000..318e96ea3 --- /dev/null +++ b/common/speculative.cpp @@ -0,0 +1,277 @@ +#include "speculative.h" + +#include "log.h" +#include "common.h" +#include "sampling.h" + +#include + +#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 +#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 + +struct common_speculative { + struct llama_context * ctx; + struct common_sampler * smpl; + + llama_batch batch; + llama_tokens prompt; +}; + +struct common_speculative * common_speculative_init( + struct llama_context * ctx_dft) { + auto * result = new common_speculative { + /* .ctx = */ ctx_dft, + /* .smpl = */ nullptr, + /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), + /* .prompt = */ {}, + }; + + // TODO: optimize or pass from outside? +#if 0 + { + common_params_sampling params; + params.no_perf = false; + + params.top_k = 40; + params.top_p = 0.9; + + params.samplers = { + COMMON_SAMPLER_TYPE_TOP_K, + COMMON_SAMPLER_TYPE_TOP_P, + COMMON_SAMPLER_TYPE_INFILL, + }; + + result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); + } +#else + { + common_params_sampling params; + params.no_perf = false; + + params.top_k = 10; + + params.samplers = { + COMMON_SAMPLER_TYPE_TOP_K, + }; + + result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); + } +#endif + + return result; +} + +void common_speculative_free(struct common_speculative * spec) { + if (spec == nullptr) { + return; + } + + common_sampler_free(spec->smpl); + + llama_batch_free(spec->batch); + + delete spec; +} + +bool common_speculative_are_compatible( + const struct llama_context * ctx_tgt, + const struct llama_context * ctx_dft) { + const struct llama_model * model_tgt = llama_get_model(ctx_tgt); + const struct llama_model * model_dft = llama_get_model(ctx_dft); + + const struct llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); + const struct llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); + + const bool vocab_type_tgt = llama_vocab_type(vocab_tgt); + LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt); + + const bool vocab_type_dft = llama_vocab_type(vocab_dft); + LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft); + + if (vocab_type_tgt != vocab_type_dft) { + LOG_ERR("%s: draft model vocab type must match target model to use speculation but " + "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt); + return false; + } + + if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) || + llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) || + llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) || + llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) { + LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__); + LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_tgt)); + LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft)); + return false; + } + + { + const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt); + const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft); + + const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft); + + if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { + LOG_ERR("%s: draft model vocab must closely match target model to use speculation but " + "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", + __func__, n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); + return false; + } + + for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { + const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i); + const char * token_text_dft = llama_vocab_get_text(vocab_dft, i); + if (std::strcmp(token_text_tgt, token_text_dft) != 0) { + LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but " + "token %d content differs - target '%s', draft '%s'\n", __func__, i, + common_token_to_piece(ctx_tgt, i).c_str(), + common_token_to_piece(ctx_dft, i).c_str()); + return false; + } + } + } + + return true; +} + +llama_tokens common_speculative_gen_draft( + struct common_speculative * spec, + struct common_speculative_params params, + const llama_tokens & prompt_tgt, + llama_token id_last) { + auto & batch = spec->batch; + auto & ctx = spec->ctx; + auto & smpl = spec->smpl; + auto & prompt = spec->prompt; + + int reuse_i = 0; + int reuse_n = 0; + + const int n_ctx = llama_n_ctx(ctx) - params.n_draft; + + const int i_start = std::max(0, (int) prompt_tgt.size() - n_ctx); + + // reuse as much as possible from the old draft context + // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt + for (int i = 0; i < (int) prompt.size(); ++i) { + int cur = 0; + while (i_start + cur < (int) prompt_tgt.size() && + i + cur < (int) prompt.size() && + prompt_tgt[i_start + cur] == prompt[i + cur]) { + cur++; + } + + if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) { + reuse_i = i; + reuse_n = cur; + } + } + + LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt.size()); + + llama_tokens result; + result.reserve(params.n_draft); + + if (reuse_n == 0) { + llama_kv_cache_clear(ctx); + + prompt.clear(); + } else { + // this happens when a previous draft has been discarded (for example, due to being too small), but the + // target model agreed with it. in this case, we simply pass back the previous results to save compute + if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) { + for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) { + result.push_back(prompt[i]); + + if (params.n_draft <= (int) result.size()) { + break; + } + } + + return result; + } + + if (reuse_i > 0) { + llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i); + llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i); + + prompt.erase(prompt.begin(), prompt.begin() + reuse_i); + } + + if (reuse_n < (int) prompt.size()) { + llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1); + + prompt.erase(prompt.begin() + reuse_n, prompt.end()); + } + } + + // prepare a batch to evaluate any new tokens in the prompt + common_batch_clear(batch); + + for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) { + //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]); + common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false); + + prompt.push_back(prompt_tgt[i]); + } + + // we should rarely end-up here during normal decoding + if (batch.n_tokens > 0) { + //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); + + llama_decode(ctx, batch); + } + + const llama_pos n_past = prompt.size(); + + LOG_DBG("%s: n_past = %d\n", __func__, n_past); + + common_batch_clear(batch); + common_batch_add (batch, id_last, n_past, { 0 }, true); + + prompt.push_back(id_last); + + //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str()); + + llama_decode(ctx, batch); + + common_sampler_reset(smpl); + + // sample n_draft tokens from the draft model + for (int i = 0; i < params.n_draft; ++i) { + common_batch_clear(batch); + + common_sampler_sample(smpl, ctx, 0, true); + + const auto * cur_p = common_sampler_get_candidates(smpl); + + for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str()); + } + + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; + + // only collect very high-confidence draft tokens + if (cur_p->data[0].p < params.p_min) { + break; + } + + common_sampler_accept(smpl, id, true); + + result.push_back(id); + + if (params.n_draft <= (int) result.size()) { + break; + } + + common_batch_add(batch, id, n_past + i + 1, { 0 }, true); + + // evaluate the drafted tokens on the draft model + llama_decode(ctx, batch); + + prompt.push_back(id); + } + + return result; +} diff --git a/common/speculative.h b/common/speculative.h new file mode 100644 index 000000000..50ec03446 --- /dev/null +++ b/common/speculative.h @@ -0,0 +1,28 @@ +#pragma once + +#include "llama.h" +#include "common.h" + +struct common_speculative; + +struct common_speculative_params { + int n_draft = 16; // max drafted tokens + int n_reuse = 256; + + float p_min = 0.9f; // min probabiliy required to accept a token in the draft +}; + +struct common_speculative * common_speculative_init(struct llama_context * ctx_dft); + +void common_speculative_free(struct common_speculative * spec); + +bool common_speculative_are_compatible( + const struct llama_context * ctx_tgt, + const struct llama_context * ctx_dft); + +// sample up to n_draft tokens and add them to the batch using the draft model +llama_tokens common_speculative_gen_draft( + struct common_speculative * spec, + struct common_speculative_params params, + const llama_tokens & prompt, + llama_token id_last); diff --git a/common/train.cpp b/common/train.cpp deleted file mode 100644 index fef1e57c9..000000000 --- a/common/train.cpp +++ /dev/null @@ -1,1513 +0,0 @@ -#include "train.h" -#include "common.h" - -#include -#include -#include - -struct random_normal_distribution { - std::mt19937 gen; - std::normal_distribution rd; - float min; - float max; -}; - -struct random_uniform_distribution { - std::mt19937 gen; - std::uniform_real_distribution rd; -}; - -struct train_state * init_train_state() { - struct train_state * state = new struct train_state; - state->train_its = 0; - state->train_samples = 0; - state->train_tokens = 0; - state->train_epochs = 0; - state->shuffle_samples_hash = 0; - state->shuffle_sample_count = 0; - state->shuffle_next_sample = 0; - state->shuffle_rng_state_current = ""; - state->shuffle_rng_state_next = ""; - - state->opt = new struct ggml_opt_context; - state->opt->ctx = NULL; - state->opt->params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM); - state->opt->params.graph_size = LLAMA_TRAIN_MAX_NODES; - state->opt->loss_after = 0.0f; - - return state; -} - -void free_train_state(struct train_state * state) { - delete state->opt; - delete state; -} - -struct random_normal_distribution * init_random_normal_distribution( - int seed, float mean, float std, float min, float max -) { - struct random_normal_distribution * rnd = (struct random_normal_distribution *) malloc(sizeof(struct random_normal_distribution)); - rnd->gen = std::mt19937(seed); - rnd->rd = std::normal_distribution{mean, std}; - rnd->min = min; - rnd->max = max; - return rnd; -} - -struct random_uniform_distribution * init_random_uniform_distribution(int seed, float min, float max) { - struct random_uniform_distribution * rnd = (struct random_uniform_distribution *) malloc(sizeof(struct random_uniform_distribution)); - rnd->gen = std::mt19937(seed); - rnd->rd = std::uniform_real_distribution{min, max}; - return rnd; -} - -void free_random_normal_distribution (struct random_normal_distribution * rnd) { - free(rnd); -} - -void free_random_uniform_distribution(struct random_uniform_distribution * rnd) { - free(rnd); -} - -struct ggml_tensor * randomize_tensor_normal(struct ggml_tensor * tensor, struct random_normal_distribution * rnd) { - float scale = 1.0f; // xavier - switch (ggml_n_dims(tensor)) { - case 1: - scale /= sqrtf((float) tensor->ne[0]); - for (int i0 = 0; i0 < tensor->ne[0]; i0++) { - float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]); - *dst = scale * frand_normal(rnd); - } - break; - case 2: - scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]); - for (int i1 = 0; i1 < tensor->ne[1]; i1++) { - for (int i0 = 0; i0 < tensor->ne[0]; i0++) { - float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]); - *dst = scale * frand_normal(rnd); - } - } - break; - case 3: - scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]); - for (int i2 = 0; i2 < tensor->ne[2]; i2++) { - for (int i1 = 0; i1 < tensor->ne[1]; i1++) { - for (int i0 = 0; i0 < tensor->ne[0]; i0++) { - float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]); - *dst = scale * frand_normal(rnd); - } - } - } - break; - case 4: - scale /= sqrtf((float) tensor->ne[0]+tensor->ne[1]); - for (int i3 = 0; i3 < tensor->ne[3]; i3++) { - for (int i2 = 0; i2 < tensor->ne[2]; i2++) { - for (int i1 = 0; i1 < tensor->ne[1]; i1++) { - for (int i0 = 0; i0 < tensor->ne[0]; i0++) { - float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]); - *dst = scale * frand_normal(rnd); - } - } - } - } - break; - default: - die("Unsupported tensor->n_dims"); - }; - return tensor; -} - -struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd) { - switch (ggml_n_dims(tensor)) { - case 1: - for (int i0 = 0; i0 < tensor->ne[0]; i0++) { - float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0]); - *dst = frand_uniform(rnd); - } - break; - case 2: - for (int i1 = 0; i1 < tensor->ne[1]; i1++) { - for (int i0 = 0; i0 < tensor->ne[0]; i0++) { - float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]); - *dst = frand_uniform(rnd); - } - } - break; - case 3: - for (int i2 = 0; i2 < tensor->ne[2]; i2++) { - for (int i1 = 0; i1 < tensor->ne[1]; i1++) { - for (int i0 = 0; i0 < tensor->ne[0]; i0++) { - float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2]); - *dst = frand_uniform(rnd); - } - } - } - break; - case 4: - for (int i3 = 0; i3 < tensor->ne[3]; i3++) { - for (int i2 = 0; i2 < tensor->ne[2]; i2++) { - for (int i1 = 0; i1 < tensor->ne[1]; i1++) { - for (int i0 = 0; i0 < tensor->ne[0]; i0++) { - float * dst = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]); - *dst = frand_uniform(rnd); - } - } - } - } - break; - default: - die("Unsupported tensor->n_dims"); - }; - return tensor; -} - -float frand() { - return (float)rand()/((float)(RAND_MAX) + 1.0f); -} - -float frand_normal(struct random_normal_distribution * rnd) { - return fclamp(rnd->rd(rnd->gen), rnd->min, rnd->max); -} - -float frand_uniform(struct random_uniform_distribution * rnd) { - return rnd->rd(rnd->gen); -} - -int clamp(const int v, const int min, const int max) { - return ((v < min) ? (min) : (v > max) ? (max) : v); -} - -float fclamp(const float v, const float min, const float max) { - return ((v < min) ? (min) : (v > max) ? (max) : v); -} - -void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) { - GGML_ASSERT(tensor->ne[0] == ne0); - GGML_ASSERT(tensor->ne[1] == 1); - GGML_ASSERT(tensor->ne[2] == 1); - GGML_ASSERT(tensor->ne[3] == 1); -} - -void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) { - GGML_ASSERT(tensor->ne[0] == ne0); - GGML_ASSERT(tensor->ne[1] == ne1); - GGML_ASSERT(tensor->ne[2] == 1); - GGML_ASSERT(tensor->ne[3] == 1); -} - -void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) { - GGML_ASSERT(tensor->ne[0] == ne0); - GGML_ASSERT(tensor->ne[1] == ne1); - GGML_ASSERT(tensor->ne[2] == ne2); - GGML_ASSERT(tensor->ne[3] == 1); -} - -void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) { - GGML_ASSERT(tensor->ne[0] == ne0); - GGML_ASSERT(tensor->ne[1] == ne1); - GGML_ASSERT(tensor->ne[2] == ne2); - GGML_ASSERT(tensor->ne[3] == ne3); -} - -int64_t get_example_targets_batch( - struct llama_context * lctx, - struct ggml_tensor * tokens_input, - struct ggml_tensor * target_probs, - int64_t example_id, - const size_t * samples_offs, - const size_t * samples_begin, - const size_t * samples_size, - size_t samples_count, - const llama_token * train_data, - size_t n_train_data, - bool separate_with_eos, - bool separate_with_bos, - bool fill_with_next_samples, - bool sample_random_offsets -) { - GGML_ASSERT(samples_count > 0); - GGML_ASSERT(ggml_is_matrix(tokens_input)); - GGML_ASSERT(ggml_is_3d(target_probs)); - int64_t n_vocab = target_probs->ne[0]; - int64_t n_tokens = tokens_input->ne[0]; - int64_t n_batch = tokens_input->ne[1]; - GGML_ASSERT(n_vocab == target_probs->ne[0]); - GGML_ASSERT(n_tokens == target_probs->ne[1]); - GGML_ASSERT(n_batch == target_probs->ne[2]); - - int64_t used_samples = 0; - - ggml_set_f32(target_probs, 0.0f); - llama_token bos = llama_token_bos(llama_get_model(lctx)); - llama_token eos = llama_token_eos(llama_get_model(lctx)); - // printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples); - for (int k=0; k= sample_size && fill_with_next_samples) { - if (!sample_separation_eos) { - // insert eos token to separate samples - sample_separation_eos = true; - } else if (!sample_separation_bos) { - // insert bos token to separate samples - sample_separation_bos = true; - token = bos; - } else { - // sample separation is done, continue with next sample - sample_separation_eos = !separate_with_eos; - sample_separation_bos = !separate_with_bos; - sample_offs = 0; - sample_idx = (example_id + used_samples) % samples_count; - sample_begin = samples_begin[sample_idx]; - sample_size = samples_size[sample_idx]; - ++used_samples; - } - } - // note: no else-if here - if (sample_offs < sample_size) { - token = clamp(train_data[sample_begin+sample_offs], 0, (llama_token) (n_vocab - 1)); - ++sample_offs; - } - ggml_set_f32_nd(target_probs, token, (int) i, (int) k, 0, +1.0f); - if (i+1> rng; -} - -std::string mt19937_get_state(const std::mt19937& rng) { - std::stringstream s_rng_state; - s_rng_state.imbue(std::locale::classic()); - s_rng_state << rng; - return s_rng_state.str(); -} - -std::string mt19937_seed_to_state(unsigned seed) { - std::mt19937 rng(seed); - return mt19937_get_state(rng); -} - -std::string shuffle_samples( - const std::string & rng_state, - size_t * shuffled_offs, - size_t * shuffled_begins, - size_t * shuffled_sizes, - const size_t * begins, - const size_t * sizes, - size_t count) { - if (count == 0) return rng_state; - - std::mt19937 rng; - mt19937_set_state(rng, rng_state); - - // sort indices by random value for each index - std::vector idcs; - { - std::vector rnd; - idcs.resize(count); - rnd.resize(count); - for (unsigned i=0; i h_string; - std::hash h_ull; - size_t h = h_string(std::string(fn)); - h = hash_combine(h, h_ull((unsigned long long) sample_count)); - for (size_t i=0; i< sample_count; ++i) { - h = hash_combine(h, h_ull((unsigned long long) samples_begin[i])); - h = hash_combine(h, h_ull((unsigned long long) samples_size[i])); - } - return h; -} - -std::string replace_str(const char * s, const char * needle, const char * replacement) { - std::string str = s; - size_t pos = str.find(needle); - if (pos != std::string::npos) { - str.replace(pos, strlen(needle), replacement); - } - return str; -} - -void print_duration(double fmillis) { - if (fmillis < 1000.0f) { - printf("%.1fms", (float) fmillis); - return; - } - const int64_t one_sec = 1000; - const int64_t one_min = one_sec * 60; - const int64_t one_hour = one_min * 60; - const int64_t one_day = one_hour * 24; - - int64_t millis = (int64_t) fmillis; - int64_t days = millis/one_day; - int64_t hours = (millis - days*one_day)/one_hour; - int64_t minutes = (millis - days*one_day - hours*one_hour)/one_min; - int64_t seconds = (millis - days*one_day - hours*one_hour - minutes*one_min)/one_sec; - - // to print int64_t either cast to (long long int) or use macro PRId64 from - if (days > 0) { - printf("%lldd ", (long long int) days); - } - printf("%02lld:%02lld:%02lld", (long long int) hours, (long long int) minutes, (long long int) seconds); -} - -float cosine_decay(int64_t step, int64_t decay_steps, float minimum) { - if (step > decay_steps) { - step = decay_steps; - } - const float cosine_decay = 0.50f*(1.0f + cosf(3.14159265359f*step/decay_steps)); - const float decay = (1 - minimum)*cosine_decay + minimum; - return decay; -} - -float cosine_decay_restart(int64_t step, int64_t decay_steps, float minimum, float restart_step_mult) { - while (step > decay_steps) { - step -= decay_steps; - decay_steps = (int64_t) (restart_step_mult * decay_steps); - } - return cosine_decay(step, decay_steps, minimum); -} - -float learning_schedule( - int64_t step, - int64_t warmup_steps, - int64_t cos_decay_steps, - float learning_rate, - float overall_minimum, - float cos_decay_minimum, - float cos_decay_restart_step_mult, - bool enable_restart) { - - float result = - (step < warmup_steps) - ? (float) step / (float) warmup_steps - : enable_restart - ? cosine_decay_restart( - step - warmup_steps, - cos_decay_steps, - cos_decay_minimum, - cos_decay_restart_step_mult) - : cosine_decay( - step, - cos_decay_steps, - cos_decay_minimum); - - float min = overall_minimum / learning_rate; - result = min + result * (1.0f - min); - return result; -} - -static bool are_same_layout(struct ggml_tensor * a, struct ggml_tensor * b) { - GGML_ASSERT(a != NULL); - GGML_ASSERT(b != NULL); - GGML_ASSERT(a->type == b->type); - GGML_ASSERT(ggml_are_same_shape(a, b)); - GGML_ASSERT(ggml_is_contiguous(a) && ggml_is_contiguous(b)); - - return true; -} - -void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name) { - if (dst == NULL) { - return; - } - struct ggml_tensor * t = ggml_get_tensor(ctx, name); - GGML_ASSERT(are_same_layout(dst, t)); - memcpy(dst->data, t->data, ggml_nbytes(t)); - - if (strlen(ggml_get_name(dst)) == 0) { - ggml_set_name(dst, name); - } -} - -// gguf constants -static const char * LLM_KV_OPTIMIZER_TYPE = "optimizer.type"; -static const char * LLM_KV_OPTIMIZER_TYPE_ADAM = "adam"; -static const char * LLM_KV_OPTIMIZER_TYPE_LBFGS = "lbfgs"; -static const char * LLM_KV_OPTIMIZER_FILE_VERSION = "optimizer.file_version"; -static const char * LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT = "optimizer.convergence_past_count"; -static const char * LLM_KV_OPTIMIZER_PARAMETER_COUNT = "optimizer.parameter_count"; -static const char * LLM_KV_OPTIMIZER_ITERATION_COUNT = "optimizer.iteration_count"; -static const char * LLM_KV_OPTIMIZER_JUST_INITIALIZED = "optimizer.just_initialized"; -static const char * LLM_KV_OPTIMIZER_ADAM_BEST_LOSS = "optimizer.adam.best_loss"; -static const char * LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS = "optimizer.adam.previous_loss"; -static const char * LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT = "optimizer.adam.no_improvement_count"; -static const char * LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT = "optimizer.lbfgs.approx_hessian_count"; -static const char * LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS = "optimizer.lbfgs.best_loss"; -static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP = "optimizer.lbfgs.line_search_step"; -static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J = "optimizer.lbfgs.line_search_j"; -static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K = "optimizer.lbfgs.line_search_k"; -static const char * LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END = "optimizer.lbfgs.line_search_end"; -static const char * LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT = "optimizer.lbfgs.no_improvement_count"; - -static const char * LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS = "optimizer.adam.first_moments"; -static const char * LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS = "optimizer.adam.second_moments"; -static const char * LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES = "optimizer.adam.past_loss_values"; - -static const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS = "optimizer.lbfgs.current_parameters"; -static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS = "optimizer.lbfgs.previous_parameters"; -static const char * LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS = "optimizer.lbfgs.current_gradients"; -static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS = "optimizer.lbfgs.previous_gradients"; -static const char * LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION = "optimizer.lbfgs.search_direction"; -static const char * LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES = "optimizer.lbfgs.past_loss_values"; -static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA = "optimizer.lbfgs.memory_alpha"; -static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.lbfgs.memory_ys"; -static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s"; -static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y"; - -static const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version"; -static const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count"; -static const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count"; -static const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count"; -static const char * LLM_KV_TRAINING_EPOCH_COUNT = "training.epoch_count"; -static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH = "training.shuffle.samples_hash"; -static const char * LLM_KV_TRAINING_SHUFFLE_RNG_STATE = "training.shuffle.rng_state"; -static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT = "training.shuffle.sample_count"; -static const char * LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE = "training.shuffle.next_sample"; - -#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \ -{ \ - const std::string skey(key); \ - const int kid = gguf_find_key(ctx, skey.c_str()); \ - if (kid >= 0) { \ - enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \ - if (ktype != (type)) { \ - die_fmt("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype)); \ - } \ - (dst) = func(ctx, kid); \ - } else if (req) { \ - die_fmt("key not found in model: %s", skey.c_str()); \ - } \ -} - -void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt) { - // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read - - uint32_t file_version; - GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_FILE_VERSION); - GGML_ASSERT(file_version == 0); - - GGUF_GET_KEY(fctx, opt->params.past, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT); - GGUF_GET_KEY(fctx, opt->iter, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ITERATION_COUNT); - GGUF_GET_KEY(fctx, opt->just_initialized, gguf_get_val_bool, GGUF_TYPE_BOOL, true, LLM_KV_OPTIMIZER_JUST_INITIALIZED); - - uint64_t nx; - GGUF_GET_KEY(fctx, nx, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_OPTIMIZER_PARAMETER_COUNT); - opt->nx = (size_t) nx; - - // don't call ggml_opt_init until optimizer type and optimizer specific parameters are know - - std::string opt_type; - GGUF_GET_KEY(fctx, opt_type, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_OPTIMIZER_TYPE); - if (opt_type == LLM_KV_OPTIMIZER_TYPE_ADAM) { - opt->params.type = GGML_OPT_TYPE_ADAM; - - GGUF_GET_KEY(fctx, opt->adam.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS); - GGUF_GET_KEY(fctx, opt->adam.fx_prev, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS); - GGUF_GET_KEY(fctx, opt->adam.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT); - - ggml_opt_init(opt->ctx, opt, opt->params, opt->nx); - - copy_tensor_by_name(opt->adam.m, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS); - copy_tensor_by_name(opt->adam.v, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS); - copy_tensor_by_name(opt->adam.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES); - } else if (opt_type == LLM_KV_OPTIMIZER_TYPE_LBFGS) { - opt->params.type = GGML_OPT_TYPE_LBFGS; - - GGUF_GET_KEY(fctx, opt->params.lbfgs.m, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT); - GGUF_GET_KEY(fctx, opt->lbfgs.fx_best, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS); - GGUF_GET_KEY(fctx, opt->lbfgs.step, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP); - GGUF_GET_KEY(fctx, opt->lbfgs.j, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J); - GGUF_GET_KEY(fctx, opt->lbfgs.k, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K); - GGUF_GET_KEY(fctx, opt->lbfgs.end, gguf_get_val_i32, GGUF_TYPE_INT32, true, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END); - GGUF_GET_KEY(fctx, opt->lbfgs.n_no_improvement, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT); - - ggml_opt_init(opt->ctx, opt, opt->params, opt->nx); - - copy_tensor_by_name(opt->lbfgs.x, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS); - copy_tensor_by_name(opt->lbfgs.xp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS); - copy_tensor_by_name(opt->lbfgs.g, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS); - copy_tensor_by_name(opt->lbfgs.gp, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS); - copy_tensor_by_name(opt->lbfgs.d, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION); - copy_tensor_by_name(opt->lbfgs.pf, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES); - copy_tensor_by_name(opt->lbfgs.lmal, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA); - copy_tensor_by_name(opt->lbfgs.lmys, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS); - copy_tensor_by_name(opt->lbfgs.lms, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S); - copy_tensor_by_name(opt->lbfgs.lmy, f_ggml_ctx, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y); - } else { - die("unknown optimizer type\n"); - } -} - -void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt) { - gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_FILE_VERSION, 0); - gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT, opt->params.past); - gguf_set_val_u64(fctx, LLM_KV_OPTIMIZER_PARAMETER_COUNT, (uint64_t) opt->nx); - gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ITERATION_COUNT, opt->iter); - gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, opt->just_initialized); - - switch (opt->params.type) { - case GGML_OPT_TYPE_ADAM: - { - gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_ADAM); - gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_BEST_LOSS, opt->adam.fx_best); - gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS, opt->adam.fx_prev); - gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT, opt->adam.n_no_improvement); - - ggml_set_name(opt->adam.m, LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS); - ggml_set_name(opt->adam.v, LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS); - if (opt->adam.pf) { - ggml_set_name(opt->adam.pf, LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES); - } - - gguf_add_tensor(fctx, opt->adam.m); - gguf_add_tensor(fctx, opt->adam.v); - if (opt->adam.pf) { - gguf_add_tensor(fctx, opt->adam.pf); - } - } break; - case GGML_OPT_TYPE_LBFGS: - { - gguf_set_val_str(fctx, LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS); - gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, opt->params.lbfgs.m); - gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS, opt->lbfgs.fx_best); - gguf_set_val_f32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP, opt->lbfgs.step); - gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J, opt->lbfgs.j); - gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K, opt->lbfgs.k); - gguf_set_val_i32(fctx, LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END, opt->lbfgs.end); - gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, opt->lbfgs.n_no_improvement); - - ggml_set_name(opt->lbfgs.x, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS); - ggml_set_name(opt->lbfgs.xp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS); - ggml_set_name(opt->lbfgs.g, LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS); - ggml_set_name(opt->lbfgs.gp, LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS); - ggml_set_name(opt->lbfgs.d, LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION); - if (opt->lbfgs.pf) { - ggml_set_name(opt->lbfgs.pf, LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES); - } - ggml_set_name(opt->lbfgs.lmal, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA); - ggml_set_name(opt->lbfgs.lmys, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS); - ggml_set_name(opt->lbfgs.lms, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S); - ggml_set_name(opt->lbfgs.lmy, LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y); - - gguf_add_tensor(fctx, opt->lbfgs.x); - gguf_add_tensor(fctx, opt->lbfgs.xp); - gguf_add_tensor(fctx, opt->lbfgs.g); - gguf_add_tensor(fctx, opt->lbfgs.gp); - gguf_add_tensor(fctx, opt->lbfgs.d); - if (opt->lbfgs.pf) { - gguf_add_tensor(fctx, opt->lbfgs.pf); - } - gguf_add_tensor(fctx, opt->lbfgs.lmal); - gguf_add_tensor(fctx, opt->lbfgs.lmys); - gguf_add_tensor(fctx, opt->lbfgs.lms); - gguf_add_tensor(fctx, opt->lbfgs.lmy); - } break; - } -} - -bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train) { - if (gguf_find_key(fctx, LLM_KV_TRAINING_FILE_VERSION) < 0) { - return false; - } - - uint32_t file_version; - GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION); - GGML_ASSERT(file_version <= 1); - - if (file_version == 0) { - - GGUF_GET_KEY(fctx, train->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT); - GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT); - GGUF_GET_KEY(fctx, train->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT); - - } else if (file_version == 1) { - - GGUF_GET_KEY(fctx, train->train_its, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_ITERATION_COUNT); - GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_SAMPLE_COUNT); - GGUF_GET_KEY(fctx, train->train_tokens, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_TOKEN_COUNT); - GGUF_GET_KEY(fctx, train->train_epochs, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_EPOCH_COUNT); - - GGUF_GET_KEY(fctx, train->shuffle_samples_hash, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH); - GGUF_GET_KEY(fctx, train->shuffle_rng_state_current, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_SHUFFLE_RNG_STATE); - GGUF_GET_KEY(fctx, train->shuffle_sample_count, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT); - GGUF_GET_KEY(fctx, train->shuffle_next_sample, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE); - } - - load_opt_context_gguf(fctx, f_ggml_ctx, train->opt); - return true; -} - -void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train) { - gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 1); - gguf_set_val_u64(fctx, LLM_KV_TRAINING_ITERATION_COUNT, train->train_its); - gguf_set_val_u64(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, train->train_samples); - gguf_set_val_u64(fctx, LLM_KV_TRAINING_TOKEN_COUNT, train->train_tokens); - gguf_set_val_u64(fctx, LLM_KV_TRAINING_EPOCH_COUNT, train->train_epochs); - - gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH, (uint64_t) train->shuffle_samples_hash); - gguf_set_val_str(fctx, LLM_KV_TRAINING_SHUFFLE_RNG_STATE, train->shuffle_rng_state_current.c_str()); - gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT, (uint64_t) train->shuffle_sample_count); - gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE, (uint64_t) train->shuffle_next_sample); - - save_opt_context_gguf(fctx, train->opt); -} - - -struct llama_file { - // use FILE * so we don't have to re-open the file to mmap - FILE * fp; - size_t size; - - llama_file(const char * fname, const char * mode) { - fp = std::fopen(fname, mode); - if (fp == NULL) { - size = 0; - } else { - seek(0, SEEK_END); - size = tell(); - seek(0, SEEK_SET); - } - } - - size_t tell() const { -#ifdef _WIN32 - __int64 ret = _ftelli64(fp); -#else - long ret = std::ftell(fp); -#endif - GGML_ASSERT(ret != -1); // this really shouldn't fail - return (size_t) ret; - } - - void seek(size_t offset, int whence) { -#ifdef _WIN32 - int ret = _fseeki64(fp, (__int64) offset, whence); -#else - int ret = std::fseek(fp, (long) offset, whence); -#endif - GGML_ASSERT(ret == 0); // same - } - - void read_raw(void * ptr, size_t size) { - if (size == 0) { - return; - } - errno = 0; - std::size_t ret = std::fread(ptr, size, 1, fp); - if (ferror(fp)) { - die_fmt("read error: %s", strerror(errno)); - } - if (ret != 1) { - die("unexpectedly reached end of file"); - } - } - - std::uint32_t read_u32() { - std::uint32_t ret; - read_raw(&ret, sizeof(ret)); - return ret; - } - - std::string read_string(std::uint32_t len) { - std::vector chars(len); - read_raw(chars.data(), len); - return std::string(chars.data(), len); - } - - void write_raw(const void * ptr, size_t size) { - if (size == 0) { - return; - } - errno = 0; - size_t ret = std::fwrite(ptr, size, 1, fp); - if (ret != 1) { - die_fmt("write error: %s", strerror(errno)); - } - } - - void write_u32(std::uint32_t val) { - write_raw(&val, sizeof(val)); - } - - ~llama_file() { - if (fp) { - std::fclose(fp); - } - } -}; - -static size_t utf8_len(char src) { - const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; - uint8_t highbits = static_cast(src) >> 4; - return lookup[highbits]; -} - -// mark each byte with its utf8 unit number. -// returns the number of utf8 characters. -// e.g. when bytes == '\x61\xD0\xB0\x62', -// then utf8_units will become [0,0,1,0] -// utf8_nunits will become [1,2,2,1] and 3 is returned. -// bytes where utf8_units is zero, are the begin of an utf8 character. -static size_t mark_utf8_units(const char* bytes, int * utf8_units, int * utf8_nunits, size_t count) { - size_t offs = 0; - size_t count_utf8 = 0; - while(offs < count) { - int len = (int) utf8_len(bytes[offs]); - for (int i=0; i & out_tokens, - std::vector & out_samples_begin, - std::vector & out_samples_size) { - struct llama_file f(filename, "rb"); - - if (f.size == 0) { - out_tokens.clear(); - out_samples_begin.clear(); - out_samples_size.clear(); - printf("%s: warning: empty or not existing training data file '%s'\n", - __func__, filename); - return out_tokens.size(); - } - - // account for possible leading whitespace that will be added by tokenizer - // e.g. '\t' will be tokenized by llama spm tokenizer to [29871, 12] - const int n_max_tokens_overhead = 1; - - std::vector buf; - buf.resize(f.size); - - f.read_raw(buf.data(), f.size); - - std::vector utf8_units; - std::vector utf8_nunits; - utf8_units.resize(buf.size()); - utf8_nunits.resize(buf.size()); - mark_utf8_units(buf.data(), utf8_units.data(), utf8_nunits.data(), buf.size()); - - if (sample_start.size() == 0) { - // tokenize all data at once - out_tokens.resize(buf.size() + n_max_tokens_overhead); - - int n_tokens = llama_tokenize( - llama_get_model(lctx), - buf.data(), - (int) buf.size(), - out_tokens.data(), - (int) out_tokens.size(), - false, false); - if (n_tokens < 0) { - out_tokens.resize(-n_tokens); - n_tokens = llama_tokenize( - llama_get_model(lctx), - buf.data(), - (int) buf.size(), - out_tokens.data(), - (int) out_tokens.size(), - false, false); - } - if (n_tokens >= 0) { - out_tokens.resize(n_tokens); - } - - // generate sample starts at all token positions - out_samples_begin.clear(); - out_samples_begin.push_back(0); - out_samples_size.push_back(std::min((size_t) context_length, out_tokens.size())); - size_t end = (out_tokens.size() >= context_length) ? (out_tokens.size() - context_length) : 0; - for (size_t sample_begin = 1; sample_begin < end; ++sample_begin) { - out_samples_begin.push_back(sample_begin); - out_samples_size.push_back(context_length); - } - } else { - // split data into samples and tokenize each sample - std::string data_str(buf.data(), buf.size()); - out_samples_begin.clear(); - out_samples_size.clear(); - out_tokens.clear(); - - // find all positions of pattern sample_start - size_t sample_begin = data_str.find(sample_start, 0); - while (sample_begin != std::string::npos) { - out_samples_begin.push_back(sample_begin); - const size_t search_start = sample_begin + sample_start.size(); - sample_begin = data_str.find(sample_start, search_start); - } - if (out_samples_begin.size() == 0) { - printf("%s: warning: sample start pattern '%s' not found. inserting single sample at data begin\n", - __func__, sample_start.c_str()); - out_samples_begin.push_back(0); - } - - out_samples_size.resize(out_samples_begin.size(), 0); - - std::vector buf_sample; - std::vector tok_sample; - - const size_t sample_begin_offset = (include_sample_start ? 0 : sample_start.size()); - size_t found_too_big_sample = 0; - size_t found_too_small_sample = 0; - size_t found_empty_sample = 0; - size_t found_min_sample_size = SIZE_MAX; - size_t found_max_sample_size = 0; - - size_t max_token_text_size = 0; - int n_vocab = llama_n_vocab(llama_get_model(lctx)); - for (llama_token token=0; token < n_vocab; ++token) { - max_token_text_size = std::max( - max_token_text_size, - strlen(llama_token_get_text(llama_get_model(lctx), token))); - } - - // upper bound of context byte length. - // strings with this byte length should always tokenize to at least context_length tokens. - size_t context_byte_len = max_token_text_size*context_length; - - for (unsigned i=0; i 0) { - // sample end is in the middle of an utf8 character. - // advance sample_end to the begin of the next utf8 character. - sample_end += utf8_nunits[sample_end] - utf8_units[sample_end]; - } - size_t sample_size = sample_end - sample_begin; - if (sample_size == 0) { - ++found_empty_sample; - } - - if (sample_size > 0) { - // llama_tokenize expects zero terminated string, - // copy sample into buffer and zero terminate it. - buf_sample.resize(sample_size); - memcpy(buf_sample.data(), data_str.data() + sample_begin, sample_size); - - // printf("sample: '%s'\n", buf_sample.data()); - - // tokenize the sample - tok_sample.resize(buf_sample.size() + n_max_tokens_overhead); - int n_tokens = llama_tokenize(llama_get_model(lctx), - buf_sample.data(), - (int) buf_sample.size(), - tok_sample.data(), - (int) tok_sample.size(), - false, false); - if (n_tokens < 0) { - tok_sample.resize(-n_tokens); - n_tokens = llama_tokenize(llama_get_model(lctx), - buf_sample.data(), - (int) buf_sample.size(), - tok_sample.data(), - (int) tok_sample.size(), - false, false); - GGML_ASSERT(n_tokens >= 0); - } - GGML_ASSERT(n_tokens <= (int) tok_sample.size()); - - if ((size_t) n_tokens > context_length) { - ++found_too_big_sample; - } else if ((size_t) n_tokens < context_length) { - ++found_too_small_sample; - } - found_max_sample_size = std::max(found_max_sample_size, (size_t) n_tokens); - found_min_sample_size = std::min(found_min_sample_size, (size_t) n_tokens); - - // write out tokens, start and size of sample - // overwrite the string start position with the token start position - out_samples_begin[i] = out_tokens.size(); - out_samples_size[i] = (size_t) n_tokens; - out_tokens.insert(out_tokens.end(), tok_sample.begin(), tok_sample.begin() + n_tokens); - } else { - out_samples_begin[i] = out_tokens.size(); - out_samples_size[i] = 0; - } - - } - if (found_too_big_sample > 0) { - printf("%s: warning: found %zu samples (max length %zu) that exceed context length of %u. samples will be cut off.\n", - __func__, found_too_big_sample, found_max_sample_size, context_length); - } - - if (found_too_small_sample > 0) { - printf("%s: warning: found %zu samples (min length %zu) that are shorter than context length of %u.\n", - __func__, found_too_small_sample, found_min_sample_size, context_length); - } - - if (found_empty_sample) { - printf("%s: warning: found %zu empty samples.\n", - __func__, found_empty_sample); - } - } - printf("%s: total number of samples: %zu\n", - __func__, out_samples_begin.size()); - - GGML_ASSERT(out_samples_begin.size() == out_samples_size.size()); - - return out_tokens.size(); -} - -std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration) { - std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest); - return replace_str(filename, pattern_it, sit.c_str()); -} - -struct train_params_common get_default_train_params_common() { - struct train_params_common params; - params.fn_train_data = "shakespeare.txt"; - params.fn_checkpoint_in = "checkpoint.gguf"; - params.fn_checkpoint_out = "checkpoint-ITERATION.gguf"; - params.pattern_fn_it = "ITERATION"; - params.fn_latest = "LATEST"; - - params.print_usage = false; - - params.save_every = 10; - - params.seed = -1; - - params.n_ctx = 128; - params.n_threads = 6; - params.n_batch = 8; - params.n_gradient_accumulation = 1; - params.n_epochs = -1; - params.n_gpu_layers = 0; - - params.custom_n_ctx = false; - - params.use_flash = false; - params.use_checkpointing = true; - - params.sample_start = ""; - params.include_sample_start = false; - params.escape = false; - params.overlapping_samples = false; - params.fill_with_next_samples = false; - params.separate_with_eos = false; - params.separate_with_bos = true; - params.sample_random_offsets = false; - params.force_reshuffle = false; - - params.opt_past = 0; - params.opt_delta = 1e-5f; - params.opt_max_no_improvement = 0; - - params.warmup = 100; - params.cos_decay_steps = 1000; - params.cos_decay_restart = 1.1f; - params.cos_decay_min = 0.1f; - params.enable_restart = false; - - params.adam_n_iter = 256; - params.adam_alpha = 1e-3f; - params.adam_min_alpha = 0; - params.adam_decay = 1e-1f; - params.adam_decay_min_ndim = 2; - params.adam_beta1 = 0.9f; - params.adam_beta2 = 0.999f; - params.adam_gclip = 1.0f; - params.adam_eps_f = 0.0f; - - return params; -} - -void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train_params_common * params) { - // fprintf(stderr, "usage: %s [options]\n", argv[0]); - // fprintf(stderr, "\n"); - // fprintf(stderr, "options:\n"); - // fprintf(stderr, " -h, --help show this help message and exit\n"); - fprintf(stderr, " --train-data FNAME path from which to load training data (default '%s')\n", params->fn_train_data); - fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in); - fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out); - fprintf(stderr, " --pattern-fn-it STR pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it); - fprintf(stderr, " --fn-latest STR string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest); - fprintf(stderr, " --save-every N save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every); - fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n"); - fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx); - fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads); - fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch); - fprintf(stderr, " --grad-acc N Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation); - fprintf(stderr, " --sample-start STR Sets the starting point for samples after the specified pattern. If empty use every token position as sample start. (default '%s')\n", params->sample_start.c_str()); - fprintf(stderr, " --include-sample-start Include the sample start in the samples. (default off)\n"); - fprintf(stderr, " --escape process sample start escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n"); - fprintf(stderr, " --overlapping-samples Samples may overlap, will include sample-start of second and following samples. When off, samples will end at begin of next sample. (default off)\n"); - fprintf(stderr, " --fill-with-next-samples Samples shorter than context length will be followed by the next (shuffled) samples. (default off)\n"); - fprintf(stderr, " --separate-with-eos When fill-with-next-samples, insert end-of-sequence token between samples.%s\n", params->separate_with_eos ? " (default)" : ""); - fprintf(stderr, " --separate-with-bos When fill-with-next-samples, insert begin-of-sequence token between samples.%s\n", params->separate_with_bos ? " (default)" : ""); - fprintf(stderr, " --no-separate-with-eos When fill-with-next-samples, don't insert end-of-sequence token between samples.%s\n", !params->separate_with_eos ? " (default)" : ""); - fprintf(stderr, " --no-separate-with-bos When fill-with-next-samples, don't insert begin-of-sequence token between samples.%s\n", !params->separate_with_bos ? " (default)" : ""); - fprintf(stderr, " --sample-random-offsets Use samples beginning at random offsets. Together with fill-with-next-samples this may help for training endless text generation.%s\n", params->sample_random_offsets ? " (default)" : ""); - fprintf(stderr, " --force-reshuffle Force a reshuffling of data at program start, otherwise the shuffling of loaded checkpoint is resumed.\n"); - fprintf(stderr, " --no-flash Don't use flash attention \n"); - fprintf(stderr, " --use-flash Use flash attention (default)\n"); - fprintf(stderr, " --no-checkpointing Don't use gradient checkpointing\n"); - fprintf(stderr, " --use-checkpointing Use gradient checkpointing (default)\n"); - fprintf(stderr, " --warmup N Only for Adam optimizer. Number of warmup steps (default %d)\n", params->warmup); - fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps); - fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart); - fprintf(stderr, " --cos-decay-min N Only for Adam optimizer. Cosine decay minimum (default %f)\n", params->cos_decay_min); - fprintf(stderr, " --enable-restart N Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : ""); - fprintf(stderr, " --disable-restart N Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : ""); - fprintf(stderr, " --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past); - fprintf(stderr, " --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta); - fprintf(stderr, " --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement); - fprintf(stderr, " --epochs N Maximum number epochs to process. (default %d)\n", params->n_epochs); - fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter); - fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha); - fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha); - fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay); - fprintf(stderr, " --adam-decay-min-ndim N Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim); - fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1); - fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2); - fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip); - fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f); - fprintf(stderr, " -ngl N, --n-gpu-layers N Number of model layers to offload to GPU (default %d)", params->n_gpu_layers); - fprintf(stderr, "\n"); -} - -bool consume_common_train_arg( - int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param -) { - int& i = *idx; - std::string arg = argv[i]; - const std::string arg_prefix = "--"; - if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { - std::replace(arg.begin(), arg.end(), '_', '-'); - } - if (arg == "--train-data") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->fn_train_data = argv[i]; - } else if (arg == "--checkpoint-in") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->fn_checkpoint_in = argv[i]; - } else if (arg == "--checkpoint-out") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->fn_checkpoint_out = argv[i]; - } else if (arg == "--pattern-fn-it") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->pattern_fn_it = argv[i]; - } else if (arg == "--fn-latest") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->fn_latest = argv[i]; - } else if (arg == "--save-every") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->save_every = std::stoi(argv[i]); - } else if (arg == "-s" || arg == "--seed") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->seed = std::stoi(argv[i]); - } else if (arg == "-c" || arg == "--ctx") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->n_ctx = std::stoi(argv[i]); - params->custom_n_ctx = true; - } else if (arg == "-t" || arg == "--threads") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->n_threads = std::stoi(argv[i]); - } else if (arg == "-b" || arg == "--batch") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->n_batch = std::stoi(argv[i]); - } else if (arg == "--grad-acc") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->n_gradient_accumulation = std::max(1, std::stoi(argv[i])); - } else if (arg == "--sample-start") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->sample_start = std::string(argv[i]); - } else if (arg == "--escape") { - params->escape = true; - } else if (arg == "--include-sample-start") { - params->include_sample_start = true; - } else if (arg == "--overlapping-samples") { - params->overlapping_samples = true; - } else if (arg == "--fill-with-next-samples") { - params->fill_with_next_samples = true; - } else if (arg == "--separate-with-eos") { - params->separate_with_eos = true; - } else if (arg == "--separate-with-bos") { - params->separate_with_bos = true; - } else if (arg == "--no-separate-with-eos") { - params->separate_with_eos = false; - } else if (arg == "--no-separate-with-bos") { - params->separate_with_bos = false; - } else if (arg == "--sample-random-offsets") { - params->sample_random_offsets = true; - } else if (arg == "--force-reshuffle") { - params->force_reshuffle = true; - } else if (arg == "--no-flash") { - params->use_flash = false; - } else if (arg == "--use-flash") { - params->use_flash = true; - } else if (arg == "--no-checkpointing") { - params->use_checkpointing = false; - } else if (arg == "--use-checkpointing") { - params->use_checkpointing = true; - } else if (arg == "--warmup") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->warmup = std::stoi(argv[i]); - } else if (arg == "--cos-decay-steps") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->cos_decay_steps = std::stoi(argv[i]); - } else if (arg == "--cos-decay-restart") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->cos_decay_restart = std::stof(argv[i]); - } else if (arg == "--cos-decay-min") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->cos_decay_min = std::stof(argv[i]); - } else if (arg == "--enable-restart") { - params->enable_restart = true; - } else if (arg == "--disable-restart") { - params->enable_restart = false; - } else if (arg == "--opt-past") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->opt_past = std::stoi(argv[i]); - } else if (arg == "--opt-delta") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->opt_delta = std::stof(argv[i]); - } else if (arg == "--opt-max-no-improvement") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->opt_max_no_improvement = std::stoi(argv[i]); - } else if (arg == "--adam-epsf") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->adam_eps_f = std::stof(argv[i]); - } else if (arg == "--epochs") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->n_epochs = std::stoi(argv[i]); - } else if (arg == "--adam-iter") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->adam_n_iter = std::stoi(argv[i]); - } else if (arg == "--adam-alpha") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->adam_alpha = std::stof(argv[i]); - } else if (arg == "--adam-min-alpha") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->adam_min_alpha = std::stof(argv[i]); - } else if (arg == "--adam-decay") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->adam_decay = std::stof(argv[i]); - } else if (arg == "--adam-decay-min-ndim") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->adam_decay_min_ndim = std::stoi(argv[i]); - } else if (arg == "--adam-beta1") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->adam_beta1 = std::stof(argv[i]); - } else if (arg == "--adam-beta2") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->adam_beta2 = std::stof(argv[i]); - } else if (arg == "--adam-gclip") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - params->adam_gclip = std::stof(argv[i]); - } else if (arg == "-ngl" || arg == "--n-gpu-layers") { - if (++i >= argc) { - *invalid_param = true; - return true; - } - if (llama_supports_gpu_offload()) { - params->n_gpu_layers = std::stoi(argv[i]); - } else { - fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n"); - fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); - } - } else if (arg == "-h" || arg == "--help") { - params->print_usage = true; - return true; - } else { - return false; - } - return true; -} - -void finish_processing_train_args(struct train_params_common * params) { - if (params->escape) { - string_process_escapes(params->sample_start); - } -} - -void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel) { - struct train_opt_callback_data * data = (struct train_opt_callback_data *) vdata; - struct train_params_common * params = data->params; - struct train_state * train = data->train; - struct ggml_opt_context * opt = train->opt; - int n_batch = params->n_batch; - int n_ctx = params->n_ctx; - - if (accum_step == 0) { - // time measurement - int64_t now = ggml_time_ms(); - if (now > data->last_time && opt->iter > data->first_iter) { - double dt = (double) (now - data->last_time); - if (data->millis_per_iter == 0.0) { - data->millis_per_iter = dt; - } else { - const double gain = 0.7; - data->millis_per_iter = data->millis_per_iter*(1.0-gain) + dt*gain; - } - } - - double remaining_millis = 0.0; - if (data->millis_per_iter > 0.0) { - const int n_iter = params->adam_n_iter; - const int done_iter = opt->iter - data->first_iter; - const int remaining_iter = n_iter - done_iter; - remaining_millis = remaining_iter * data->millis_per_iter; - } - - // file saving - const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every); - if (save_now) { - int new_iters = opt->iter - data->last_save_iter; - train->train_its += new_iters; - train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx; - - if (data->save_cb) { - data->save_cb(data->save_data, train); - } - - data->last_save_iter = opt->iter; - } - - // exclude file saving from time measurement, by measuring last_time after saving - data->last_time = ggml_time_ms(); - - *sched = learning_schedule( - opt->iter, - params->warmup, - params->cos_decay_steps, - params->adam_alpha, - params->adam_min_alpha, - params->cos_decay_min, - params->cos_decay_restart, - params->enable_restart); - - int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f); - if (impr_plot > 0) impr_plot = 0; - if (std::isnan(opt->loss_before) || std::isnan(opt->loss_after)) impr_plot = 0; - printf("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f", - __func__, opt->iter, std::min(1+train->shuffle_next_sample, train->shuffle_sample_count), train->shuffle_sample_count, - *sched, opt->loss_after); - - - if (data->millis_per_iter > 0) { - printf(" dt="); - print_duration(data->millis_per_iter); - printf(" eta="); - print_duration(remaining_millis); - } - - float improvement = opt->loss_before - opt->loss_after; - const float plot_scale = 10.0f; - int bar_len = (int)(1 + improvement*plot_scale + 0.5); - printf(" |"); - for (int i=0; i"); - printf("\n"); - } - - int64_t used_samples = get_example_targets_batch( - data->lctx, - data->tokens_input, - data->target_probs, - train->shuffle_next_sample, - data->shuffled_samples_offs, - data->shuffled_samples_begin, - data->shuffled_samples_size, - data->samples_count, - data->tokens_data, - data->tokens_size, - params->separate_with_eos, - params->separate_with_bos, - params->fill_with_next_samples, - params->sample_random_offsets); - - train->train_samples += used_samples; - train->shuffle_next_sample += used_samples; - - if (train->shuffle_next_sample >= train->shuffle_sample_count) { - ++train->train_epochs; - printf("%s: reshuffle samples. completed epochs: %llu\n", __func__, (long long unsigned) train->train_epochs); - // note: we may have used some samples from the current shuffling more than once - train->shuffle_rng_state_current = train->shuffle_rng_state_next; - train->shuffle_rng_state_next = shuffle_samples( - train->shuffle_rng_state_current, - data->shuffled_samples_offs, - data->shuffled_samples_begin, - data->shuffled_samples_size, - data->samples_begin, - data->samples_size, - data->samples_count); - train->shuffle_next_sample = 0; - } - - const bool last_epoch_reached = (params->n_epochs > 0 && (int64_t) train->train_epochs - data->first_epoch >= params->n_epochs); - if (last_epoch_reached) { - // allow optimization iteration at last epoch to be completed before canceling - if (data->iter_at_last_epoch < 0) { - data->iter_at_last_epoch = opt->iter; - } else if (opt->iter > data->iter_at_last_epoch) { - *cancel = true; - } - } -} diff --git a/common/train.h b/common/train.h deleted file mode 100644 index 263d940c0..000000000 --- a/common/train.h +++ /dev/null @@ -1,233 +0,0 @@ -// Various helper functions and utilities for training - -#pragma once - -#include -#include -#include - -#include "ggml.h" -#include "llama.h" - -#define LLAMA_TRAIN_MAX_NODES 16384 - -typedef std::string mt19937_state; - -struct train_state { - struct ggml_opt_context * opt; - - uint64_t train_its; - uint64_t train_samples; - uint64_t train_tokens; - uint64_t train_epochs; - - size_t shuffle_samples_hash; // fn, sample_count, *zip(sample_begins, sample_sizes) - mt19937_state shuffle_rng_state_current; - mt19937_state shuffle_rng_state_next; - size_t shuffle_sample_count; - size_t shuffle_next_sample; -}; - -struct train_params_common { - const char * fn_train_data; - const char * fn_checkpoint_in; - const char * fn_checkpoint_out; - const char * pattern_fn_it; - const char * fn_latest; - - bool print_usage; - - int save_every; - - uint32_t seed; - - int n_ctx; - int n_threads; - int n_batch; - int n_gradient_accumulation; - int n_epochs; - int n_gpu_layers; - - bool custom_n_ctx; - - bool use_flash; - bool use_checkpointing; - - std::string sample_start; - bool include_sample_start; - bool escape; - bool overlapping_samples; - bool fill_with_next_samples; - bool separate_with_eos; - bool separate_with_bos; - bool sample_random_offsets; - - bool force_reshuffle; - - int warmup; - int cos_decay_steps; - float cos_decay_restart; - float cos_decay_min; - bool enable_restart; - - int opt_past; - float opt_delta; - int opt_max_no_improvement; - - int adam_n_iter; - float adam_alpha; - float adam_min_alpha; - float adam_decay; - int adam_decay_min_ndim; - float adam_beta1; - float adam_beta2; - float adam_gclip; - float adam_eps_f; -}; - -typedef void (*save_train_files_callback)(void * data, struct train_state * train); - -struct train_opt_callback_data { - struct train_params_common * params; - struct train_state * train; - save_train_files_callback save_cb; - void * save_data; - struct llama_context * lctx; - int last_save_iter; - llama_token * tokens_data; - size_t tokens_size; - size_t * samples_begin; - size_t * samples_size; - size_t * shuffled_samples_offs; - size_t * shuffled_samples_begin; - size_t * shuffled_samples_size; - size_t samples_count; - struct ggml_tensor * tokens_input; - struct ggml_tensor * target_probs; - int first_iter; - int first_epoch; - int iter_at_last_epoch; - int64_t last_time; - double millis_per_iter; -}; - -struct train_state * init_train_state(); -void free_train_state(struct train_state * state); - -struct train_params_common get_default_train_params_common(); -void print_common_train_usage(int /*argc*/, char ** argv, const struct train_params_common * params); - -bool consume_common_train_arg(int argc, char ** argv, int * idx, struct train_params_common * params, bool * invalid_param); -void finish_processing_train_args(struct train_params_common * params); - -struct random_normal_distribution; -struct random_uniform_distribution; - -struct random_normal_distribution * init_random_normal_distribution (int seed, float mean, float std, float min, float max); -struct random_uniform_distribution * init_random_uniform_distribution(int seed, float min, float max); - -void free_random_normal_distribution (struct random_normal_distribution * rnd); -void free_random_uniform_distribution(struct random_uniform_distribution * rnd); - -struct ggml_tensor * randomize_tensor_normal (struct ggml_tensor * tensor, struct random_normal_distribution * rnd); -struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struct random_uniform_distribution * rnd); - -// generate random float in interval [0,1) -float frand(); -float frand_normal (struct random_normal_distribution * rnd); -float frand_uniform(struct random_uniform_distribution * rnd); - -int clamp (const int v, const int min, const int max); -float fclamp(const float v, const float min, const float max); - -void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0); -void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1); -void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2); -void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3); - -size_t tokenize_file( - struct llama_context * lctx, - const char * filename, - const std::string & sample_start, - bool include_sample_start, - bool overlapping_samples, - unsigned context_length, - std::vector & out_tokens, - std::vector & out_samples_begin, - std::vector & out_samples_size); - -int64_t get_example_targets_batch( - struct llama_context * lctx, - struct ggml_tensor * tokens_input, - struct ggml_tensor * target_probs, - int64_t example_id, - const size_t * samples_offs, - const size_t * samples_begin, - const size_t * samples_size, - size_t samples_count, - const llama_token * train_data, - size_t n_train_data, - bool separate_with_eos, - bool separate_with_bos, - bool fill_with_next_samples, - bool sample_random_offsets); - - -void mt19937_set_state(std::mt19937& rng, const mt19937_state& rng_state); -mt19937_state mt19937_get_state(const std::mt19937& rng); -mt19937_state mt19937_seed_to_state(unsigned seed); - -mt19937_state shuffle_samples( - const mt19937_state & rng_state, - size_t * shuffled_offs, - size_t * shuffled_begins, - size_t * shuffled_sizes, - const size_t * begins, - const size_t * sizes, - size_t count); - -size_t hash_combine(size_t h1, size_t h2); - -size_t compute_samples_hash( - const char* fn, - const size_t* samples_begin, - const size_t* samples_size, - size_t sample_count); - - -std::string replace_str(const char * s, const char * needle, const char * replacement); - -void print_duration(double milliseconds); - -float cosine_decay( - int64_t step, - int64_t decay_steps, - float minimum); - -float cosine_decay_restart( - int64_t step, - int64_t decay_steps, - float minimum, - float restart_step_mult); - -float learning_schedule( - int64_t step, - int64_t warmup_steps, - int64_t decay_steps, - float learning_rate, - float overall_minimum, - float cos_decay_minimum, - float cos_decay_restart_step_mult, - bool enable_restart); - -void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, const char * name); - -void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt); -void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt); - -bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train); -void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train); - -std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration); - -void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0a9bbc829..63b54a9cf 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -15,6 +15,7 @@ from enum import IntEnum from pathlib import Path from hashlib import sha256 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast +from itertools import chain import math import numpy as np @@ -64,7 +65,6 @@ class Model: model_name: str | None metadata_override: Path | None dir_model_card: Path - is_lora: bool # subclasses should define this! model_arch: gguf.MODEL_ARCH @@ -72,7 +72,8 @@ class Model: def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False, use_temp_file: bool = False, eager: bool = False, metadata_override: Path | None = None, model_name: str | None = None, - split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False, is_lora: bool = False): + split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, + small_first_shard: bool = False, hparams: dict[str, Any] | None = None): if type(self) is Model: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") @@ -87,14 +88,13 @@ class Model: self.is_safetensors = len(self.part_names) > 0 if not self.is_safetensors: self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin") - self.hparams = Model.load_hparams(self.dir_model) + self.hparams = Model.load_hparams(self.dir_model) if hparams is None else hparams self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) self.tensor_names = None self.metadata_override = metadata_override self.model_name = model_name self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py - self.is_lora = is_lora # true if model is used inside convert_lora_to_gguf.py # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type if self.ftype == gguf.LlamaFileType.GUESSED: @@ -132,12 +132,14 @@ class Model: def get_tensors(self) -> Iterator[tuple[str, Tensor]]: tensor_names_from_parts: set[str] = set() - if len(self.part_names) > 1: + index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin" + index_name += ".index.json" + index_file = self.dir_model / index_name + + if index_file.is_file(): self.tensor_names = set() - index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin" - index_name += ".index.json" logger.info(f"gguf: loading model weight map from '{index_name}'") - with open(self.dir_model / index_name, "r", encoding="utf-8") as f: + with open(index_file, "r", encoding="utf-8") as f: index: dict[str, Any] = json.load(f) weight_map = index.get("weight_map") if weight_map is None or not isinstance(weight_map, dict): @@ -145,6 +147,7 @@ class Model: self.tensor_names.update(weight_map.keys()) else: self.tensor_names = tensor_names_from_parts + weight_map = {} for part_name in self.part_names: logger.info(f"gguf: loading model part '{part_name}'") @@ -171,9 +174,17 @@ class Model: data = LazyTorchTensor.from_eager(data) yield name, data - # only verify tensor name presence; it doesn't matter if they are not in the right files - if len(sym_diff := tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0: - raise ValueError(f"Mismatch between weight map and model parts for tensor names: {sym_diff}") + # verify tensor name presence and identify potentially missing files + if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0: + missing = sorted(self.tensor_names.difference(tensor_names_from_parts)) + extra = sorted(tensor_names_from_parts.difference(self.tensor_names)) + missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map)) + if len(extra) == 0 and len(missing_files) > 0: + raise ValueError(f"Missing or incomplete model files: {missing_files}") + else: + raise ValueError("Mismatch between weight map and model parts for tensor names:\n" + f"Missing tensors: {missing}\n" + f"Extra tensors: {extra}") def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str: if key not in gguf.MODEL_TENSORS[self.model_arch]: @@ -210,17 +221,17 @@ class Model: self.gguf_writer.add_context_length(n_ctx) logger.info(f"gguf: context length = {n_ctx}") - n_embd = self.find_hparam(["hidden_size", "n_embd"]) - self.gguf_writer.add_embedding_length(n_embd) - logger.info(f"gguf: embedding length = {n_embd}") + if (n_embd := self.find_hparam(["hidden_size", "n_embd"], optional=True)) is not None: + self.gguf_writer.add_embedding_length(n_embd) + logger.info(f"gguf: embedding length = {n_embd}") if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None: self.gguf_writer.add_feed_forward_length(n_ff) logger.info(f"gguf: feed forward length = {n_ff}") - n_head = self.find_hparam(["num_attention_heads", "n_head"]) - self.gguf_writer.add_head_count(n_head) - logger.info(f"gguf: head count = {n_head}") + if (n_head := self.find_hparam(["num_attention_heads", "n_head"], optional=True)) is not None: + self.gguf_writer.add_head_count(n_head) + logger.info(f"gguf: head count = {n_head}") if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None: self.gguf_writer.add_head_count_kv(n_head_kv) @@ -259,10 +270,14 @@ class Model: return False + # some models need extra generated tensors (like rope_freqs) + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + return () + def prepare_tensors(self): max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") - for name, data_torch in self.get_tensors(): + for name, data_torch in chain(self.generate_extra_tensors(), self.get_tensors()): # we don't need these if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): continue @@ -280,8 +295,15 @@ class Model: bid = int(part) break - for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)): - data: np.ndarray # type hint + for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)): + # TODO: why do we squeeze here? + # data = data_torch.squeeze().numpy() + data = data_torch.numpy() + + # if data ends up empty, it means data_torch was a scalar tensor -> restore + if len(data.shape) == 0: + data = data_torch.numpy() + n_dims = len(data.shape) data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims) @@ -302,6 +324,11 @@ class Model: gguf.MODEL_TENSOR.TIME_MIX_FIRST, gguf.MODEL_TENSOR.TIME_MIX_W1, gguf.MODEL_TENSOR.TIME_MIX_W2, + gguf.MODEL_TENSOR.TIME_MIX_DECAY_W1, + gguf.MODEL_TENSOR.TIME_MIX_DECAY_W2, + gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED, + gguf.MODEL_TENSOR.POSNET_NORM1, + gguf.MODEL_TENSOR.POSNET_NORM2, ) ) or not new_name.endswith(".weight") @@ -451,6 +478,11 @@ class Model: return modelcls return func + @classmethod + def print_registered_models(cls): + for name in sorted(cls._model_classes.keys()): + logger.error(f"- {name}") + @classmethod def from_model_architecture(cls, arch: str) -> type[Model]: try: @@ -503,9 +535,19 @@ class Model: else: token: str = reverse_vocab[i] if token in added_vocab: + # The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized. + # To avoid unexpected issues - we make sure to normalize non-normalized tokens + if not tokenizer.added_tokens_decoder[i].normalized: + previous_token = token + token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False)) + if previous_token != token: + logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer") + if tokenizer.added_tokens_decoder[i].special or self.does_token_look_special(token): toktypes.append(gguf.TokenType.CONTROL) else: + # NOTE: this was added for Gemma. + # Encoding and decoding the tokens above isn't sufficient for this case. token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces toktypes.append(gguf.TokenType.USER_DEFINED) else: @@ -549,9 +591,15 @@ class Model: if chkhsh == "8aeee3860c56296a157a1fe2fad249ec40aa59b1bb5709f4ade11c4e6fe652ed": # ref: https://huggingface.co/tiiuae/falcon-7b res = "falcon" + if chkhsh == "9d032fcbd5501f4a38150912590928bfb36091efb5df11b8e2124b0390e3fb1e": + # ref: https://huggingface.co/tiiuae/Falcon3-7B-Base + res = "falcon3" if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f": # ref: https://huggingface.co/BAAI/bge-small-en-v1.5 res = "bert-bge" + if chkhsh == "8e62295832751ca1e8f92f2226f403dea30dc5165e448b5bfa05af5340c64ec7": + # ref: https://huggingface.co/BAAI/bge-large-zh-v1.5 + res = "bert-bge-large" if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166": # ref: https://huggingface.co/mosaicml/mpt-7b res = "mpt" @@ -579,6 +627,9 @@ class Model: if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e": # ref: https://huggingface.co/databricks/dbrx-base res = "dbrx" + if chkhsh == "c7699093ba4255a91e702aa38a596aa81669f3525dae06c2953267dde580f448": + # ref: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + res = "jina-v1-en" if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f": # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en res = "jina-v2-en" @@ -624,6 +675,30 @@ class Model: if chkhsh == "4e2b24cc4770243d65a2c9ec19770a72f08cffc161adbb73fcbb6b7dd45a0aae": # ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct res = "exaone" + if chkhsh == "fcace8b9cac38ce847670c970cd5892031a753a1ef381abd1d9af00f713da085": + # ref: https://huggingface.co/microsoft/phi-2 + res = "phi-2" + if chkhsh == "60824e3c0d9401f89943cbb2fff727f0e2d4c545ba4df2d6e4f09a6db0f5b450": + # ref: https://huggingface.co/facebook/chameleon-7b + res = "chameleon" + if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35": + # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0 + res = "minerva-7b" + if chkhsh == "8b5a93ed704057481f240da0be7e7dca721d7f8f4755263b6807227a2cbeae65": + # ref: https://huggingface.co/sentence-transformers/stsb-roberta-base + res = "roberta-bpe" + if chkhsh == "ad851be1dba641f2e3711822f816db2c265f788b37c63b4e1aeacb9ee92de8eb": + # ref: https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct + res = "gigachat" + if chkhsh == "d4c8f286ea6b520b3d495c4455483cfa2302c0cfcd4be05d781b6a8a0a7cdaf1": + # ref: https://huggingface.co/Infinigence/Megrez-3B-Instruct + res = "megrez" + if chkhsh == "877081d19cf6996e2c4ff0e1236341e9b7bde288f5311a56a937f0afbbb3aeb5": + # ref: https://huggingface.co/deepseek-ai/DeepSeek-V3 + res = "deepseek-v3" + if chkhsh == "b3f499bb4255f8ca19fccd664443283318f2fd2414d5e0b040fbdd0cc195d6c5": + # ref: https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B + res = "deepseek-r1-qwen" if res is None: logger.warning("\n") @@ -646,6 +721,9 @@ class Model: return res # Marker: End get_vocab_base_pre + def _set_vocab_none(self) -> None: + self.gguf_writer.add_tokenizer_model("none") + def _set_vocab_gpt2(self) -> None: tokens, toktypes, tokpre = self.get_vocab_base() self.gguf_writer.add_tokenizer_model("gpt2") @@ -1482,7 +1560,7 @@ class StableLMModel(Model): raise ValueError(f"Unprocessed norms: {norms}") -@Model.register("LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM") +@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM") class LlamaModel(Model): model_arch = gguf.MODEL_ARCH.LLAMA @@ -1508,6 +1586,17 @@ class LlamaModel(Model): special_vocab._set_special_token("eot", 32010) special_vocab.add_to_gguf(self.gguf_writer) + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + if "add_prefix_space" in tokenizer_config_json: + self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"]) + + # Apply to granite small models only + if self.hparams.get("vocab_size", 32000) == 49152: + self.gguf_writer.add_add_bos_token(False) + def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams @@ -1524,17 +1613,6 @@ class LlamaModel(Model): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) - tokenizer_config_file = self.dir_model / 'tokenizer_config.json' - if tokenizer_config_file.is_file(): - with open(tokenizer_config_file, "r", encoding="utf-8") as f: - tokenizer_config_json = json.load(f) - if "add_prefix_space" in tokenizer_config_json: - self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"]) - - # Apply to granite small models only - if self.hparams.get("vocab_size", 32000) == 49152: - self.gguf_writer.add_add_bos_token(False) - @staticmethod def permute(weights: Tensor, n_head: int, n_head_kv: int | None): if n_head_kv is not None and n_head != n_head_kv: @@ -1590,7 +1668,7 @@ class LlamaModel(Model): return [(self.map_tensor_name(name), data_torch)] - def prepare_tensors(self): + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): if rope_scaling.get("rope_type", '').lower() == "llama3": base = self.hparams.get("rope_theta", 10000.0) @@ -1617,9 +1695,9 @@ class LlamaModel(Model): smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) rope_factors.append(1 / ((1 - smooth) / factor + smooth)) - if not self.is_lora: - self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) + def prepare_tensors(self): super().prepare_tensors() if self._experts is not None: @@ -1629,6 +1707,178 @@ class LlamaModel(Model): raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("DeciLMForCausalLM") +class DeciModel(Model): + model_arch = gguf.MODEL_ARCH.DECI + + @staticmethod + def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int: + # DeciLM-specific code + intermediate_size = int(2 * ffn_mult * n_embd / 3) + return DeciModel._find_multiple(intermediate_size, 256) + + @staticmethod + def _find_multiple(n: int, k: int) -> int: + # DeciLM-specific code + if n % k == 0: + return n + return n + k - (n % k) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B + _block_configs: list[dict[str,Any]] = self.hparams["block_configs"] + assert self.block_count == len(_block_configs) + self._num_kv_heads = list() + self._num_heads = list() + _ffn_multipliers = list() + # ***linear attention layer*** + # if n_heads_in_group is None and replace_with_linear is True + # then _num_kv_heads[il] is 0 and _num_heads[il] is num_attention_heads + # ***attention-free layer*** + # if n_heads_in_group is None and replace_with_linear is False + # then _num_kv_heads[il] is 0 and _num_heads[il] is 0 + # ***normal attention-layer*** + # if n_heads_in_group is not None, then + # _num_kv_heads[il] is num_attention_head // n_heads_in_group and + # _num_heads[il] is num_attention_head + for il in range(len(_block_configs)): + if _block_configs[il]["attention"]["n_heads_in_group"] is None: + if _block_configs[il]["attention"]["replace_with_linear"] is True: + self._num_kv_heads.append(0) + self._num_heads.append(self.hparams["num_attention_heads"]) + else: + self._num_kv_heads.append(0) + self._num_heads.append(0) + else: + self._num_kv_heads.append(self.hparams["num_attention_heads"] // _block_configs[il]["attention"]["n_heads_in_group"]) + self._num_heads.append(self.hparams["num_attention_heads"]) + _ffn_multipliers.append(_block_configs[il]["ffn"]["ffn_mult"]) + assert self.block_count == len(self._num_kv_heads) + assert self.block_count == len(self._num_heads) + assert self.block_count == len(_ffn_multipliers) + assert isinstance(self._num_kv_heads, list) and isinstance(self._num_kv_heads[0], int) + assert isinstance(self._num_heads, list) and isinstance(self._num_heads[0], int) + assert isinstance(_ffn_multipliers, list) and isinstance(_ffn_multipliers[0], float) + self._ffn_dims: list[int] = [ + DeciModel._ffn_mult_to_intermediate_size(multiplier, self.hparams["hidden_size"]) + for multiplier in _ffn_multipliers + ] + + def set_vocab(self): + # Please change tokenizer_config.json of Llama-3_1-Nemotron-51B's + # eos_token from '|eot_id|' to '|end_of_text|' + if self.hparams.get("vocab_size", 128256) == 128256: + tokens, toktypes, tokpre = self.get_vocab_base() + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + special_vocab.add_to_gguf(self.gguf_writer) + else: + # DeciLM-7B + self._set_vocab_llama_hf() + + def set_gguf_parameters(self): + if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B + assert self.block_count == len(self._num_kv_heads) + assert self.block_count == len(self._num_heads) + assert self.block_count == len(self._ffn_dims) + if (rope_theta := self.hparams.get("rope_theta")) is not None: + self.gguf_writer.add_rope_freq_base(rope_theta) + self.gguf_writer.add_head_count_kv(self._num_kv_heads) + self.gguf_writer.add_head_count(self._num_heads) + self.gguf_writer.add_feed_forward_length(self._ffn_dims) + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) + self.gguf_writer.add_key_length(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_value_length(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_file_type(self.ftype) + else: # DeciLM-7B + super().set_gguf_parameters() + if "num_key_value_heads_per_layer" in self.hparams: # DeciLM-7B + self._num_kv_heads: list[int] = self.hparams["num_key_value_heads_per_layer"] + assert self.block_count == len(self._num_kv_heads) + self.gguf_writer.add_head_count_kv(self._num_kv_heads) + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + + if "head_dim" in hparams: + rope_dim = hparams["head_dim"] + else: + rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(rope_dim) + + if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: + if self.hparams["rope_scaling"].get("type") == "linear": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + + @staticmethod + def permute(weights: Tensor, n_head: int, n_head_kv: int | None): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams["num_attention_heads"] + if bid is not None: + if "num_key_value_heads_per_layer" in self.hparams: + n_kv_head = self.hparams["num_key_value_heads_per_layer"][bid] + elif "block_configs" in self.hparams: + n_kv_head = self._num_kv_heads[bid] + n_head = self._num_heads[bid] + else: + n_kv_head = self.hparams.get("num_key_value_heads") + else: + n_kv_head = self.hparams.get("num_key_value_heads") + + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = DeciModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = DeciModel.permute(data_torch, n_head, n_kv_head) + return [(self.map_tensor_name(name), data_torch)] + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): + if rope_scaling.get("rope_type", '').lower() == "llama3": + base = self.hparams.get("rope_theta", 10000.0) + dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + + factor = rope_scaling.get("factor", 8.0) + low_freq_factor = rope_scaling.get("low_freq_factor", 1.0) + high_freq_factor = rope_scaling.get("high_freq_factor", 4.0) + old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + assert low_freq_wavelen != high_freq_wavelen + + rope_factors = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + rope_factors.append(1) + elif wavelen > low_freq_wavelen: + rope_factors.append(factor) + else: + smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + rope_factors.append(1 / ((1 - smooth) / factor + smooth)) + + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) + + def prepare_tensors(self): + super().prepare_tensors() + + @Model.register("BitnetForCausalLM") class BitnetModel(Model): model_arch = gguf.MODEL_ARCH.BITNET @@ -1797,19 +2047,97 @@ class MiniCPMModel(Model): model_arch = gguf.MODEL_ARCH.MINICPM def set_gguf_parameters(self): - block_count = self.hparams["num_hidden_layers"] - self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) - self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) - self.gguf_writer.add_block_count(block_count) - self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) - self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) - self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) - self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"]) - self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) - self.gguf_writer.add_file_type(self.ftype) + super().set_gguf_parameters() + embedding_scale = float(self.hparams["scale_emb"]) + self.gguf_writer.add_embedding_scale(embedding_scale) + logger.info(f"gguf: (minicpm) embedding_scale = {embedding_scale}") + residual_scale = self.hparams["scale_depth"] / self.hparams["num_hidden_layers"] ** 0.5 + self.gguf_writer.add_residual_scale(residual_scale) + logger.info(f"gguf: (minicpm) residual_scale = {residual_scale}") + logit_scale = self.hparams["hidden_size"] / self.hparams["dim_model_base"] + self.gguf_writer.add_logit_scale(logit_scale) + logger.info(f"gguf: (minicpm) logit_scale = {logit_scale}") + if self.hparams.get("rope_scaling") is not None: + if self.hparams["rope_scaling"].get("type") == "longrope": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LONGROPE) + logger.info(f"gguf: (minicpm) rope_scaling_type = {gguf.RopeScalingType.LONGROPE}") + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + rope_dims = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + + rope_scaling = self.find_hparam(['rope_scaling'], True) + if rope_scaling is not None: + long_factors = rope_scaling.get('long_factor', None) + short_factors = rope_scaling.get('short_factor', None) + + if long_factors is None or short_factors is None: + raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') + + if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: + raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}') + + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32)) def set_vocab(self): - self._set_vocab_llama_hf() + self._set_vocab_sentencepiece() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + + # HF models permute some of the tensors, so we need to undo that + if name.endswith(("q_proj.weight")): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight")): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + + return [(self.map_tensor_name(name), data_torch)] + + +@Model.register("MiniCPM3ForCausalLM") +class MiniCPM3Model(Model): + model_arch = gguf.MODEL_ARCH.MINICPM3 + + def set_gguf_parameters(self): + hparams = self.hparams + + self.gguf_writer.add_file_type(self.ftype) + self.gguf_writer.add_context_length(hparams["max_position_embeddings"]) + self.gguf_writer.add_embedding_length(hparams["hidden_size"]) + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) + self.gguf_writer.add_head_count(hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"]) + self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: + self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) + self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) + self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + rope_scaling = self.find_hparam(['rope_scaling'], True) + if rope_scaling is not None: + rope_dims = self.hparams["qk_rope_head_dim"] + + long_factors = rope_scaling.get('long_factor', None) + short_factors = rope_scaling.get('short_factor', None) + + if long_factors is None or short_factors is None: + raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') + + if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: + raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}') + + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32)) + + def set_vocab(self): + self._set_vocab_sentencepiece() def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor: if n_kv_head is not None and n_head != n_kv_head: @@ -1821,20 +2149,6 @@ class MiniCPMModel(Model): .reshape(weights.shape) ) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - - n_head = self.hparams["num_attention_heads"] - n_kv_head = self.hparams.get("num_key_value_heads") - - # HF models permute some of the tensors, so we need to undo that - if name.endswith(("q_proj.weight")): - data_torch = self._reverse_hf_permute(data_torch, n_head, n_head) - if name.endswith(("k_proj.weight")): - data_torch = self._reverse_hf_permute(data_torch, n_head, n_kv_head) - - return [(self.map_tensor_name(name), data_torch)] - @Model.register("QWenLMHeadModel") class QwenModel(Model): @@ -1888,6 +2202,75 @@ class Qwen2Model(Model): except FileNotFoundError: self._set_vocab_gpt2() + def set_gguf_parameters(self): + super().set_gguf_parameters() + if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: + if self.hparams["rope_scaling"].get("type") == "yarn": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"]) + + +@Model.register("Qwen2VLForConditionalGeneration") +class Qwen2VLModel(Model): + model_arch = gguf.MODEL_ARCH.QWEN2VL + + def set_gguf_parameters(self): + super().set_gguf_parameters() + mrope_section = self.hparams["rope_scaling"]["mrope_section"] + mrope_section += [0] * max(0, 4 - len(mrope_section)) + self.gguf_writer.add_rope_dimension_sections(mrope_section) + + def set_vocab(self): + try: + self._set_vocab_sentencepiece() + except FileNotFoundError: + self._set_vocab_gpt2() + + def get_tensors(self) -> Iterator[tuple[str, Tensor]]: + for name, data in super().get_tensors(): + if name.startswith("visual."): + continue + yield name, data + + +@Model.register("WavTokenizerDec") +class WavTokenizerDecModel(Model): + model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if \ + name.endswith("codebook.cluster_size") or \ + name.endswith("codebook.embed_avg") or \ + name.endswith("codebook.inited"): + logger.debug(f"Skipping {name!r}") + return [] + + logger.info(f"{self.map_tensor_name(name)} -> {data_torch.shape}") + + return [(self.map_tensor_name(name), data_torch)] + + def set_vocab(self): + self._set_vocab_none() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_vocab_size (self.hparams["vocab_size"]) + self.gguf_writer.add_features_length (self.hparams["n_embd_features"]) + self.gguf_writer.add_feed_forward_length(self.hparams["n_ff"]) + self.gguf_writer.add_group_norm_eps (self.hparams["group_norm_epsilon"]) + self.gguf_writer.add_group_norm_groups (self.hparams["group_norm_groups"]) + + self.gguf_writer.add_posnet_embedding_length(self.hparams["posnet"]["n_embd"]) + self.gguf_writer.add_posnet_block_count (self.hparams["posnet"]["n_layer"]) + + self.gguf_writer.add_convnext_embedding_length(self.hparams["convnext"]["n_embd"]) + self.gguf_writer.add_convnext_block_count (self.hparams["convnext"]["n_layer"]) + + self.gguf_writer.add_causal_attention(False) + @Model.register("Qwen2MoeForCausalLM") class Qwen2MoeModel(Model): @@ -2017,6 +2400,15 @@ class Phi3MiniModel(Model): model_arch = gguf.MODEL_ARCH.PHI3 def set_vocab(self): + # Phi-4 model uses GPT2Tokenizer + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + tokenizer_class = tokenizer_config_json['tokenizer_class'] + if tokenizer_class == 'GPT2Tokenizer': + return self._set_vocab_gpt2() + from sentencepiece import SentencePieceProcessor tokenizer_path = self.dir_model / 'tokenizer.model' @@ -2133,7 +2525,18 @@ class Phi3MiniModel(Model): self.gguf_writer.add_rope_dimension_count(rope_dims) self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"])) self.gguf_writer.add_file_type(self.ftype) - self.gguf_writer.add_sliding_window(self.find_hparam(["sliding_window"])) + sliding_window = self.hparams.get("sliding_window") + # use zero value of sliding_window to distinguish Phi-4 from other PHI3 models + if sliding_window is None: + sliding_window = 0 + self.gguf_writer.add_sliding_window(sliding_window) + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + n_embd = self.find_hparam(["hidden_size", "n_embd"]) + n_head = self.find_hparam(["num_attention_heads", "n_head"]) + max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"]) + orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"]) + rope_dims = n_embd // n_head # write rope scaling for long context (128k) model rope_scaling = self.find_hparam(['rope_scaling'], True) @@ -2164,9 +2567,65 @@ class Phi3MiniModel(Model): if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}') - if not self.is_lora: - self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32)) - self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32)) + + +@Model.register("PhiMoEForCausalLM") +class PhiMoeModel(Phi3MiniModel): + model_arch = gguf.MODEL_ARCH.PHIMOE + + _experts: list[dict[str, Tensor]] | None = None + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"]) + self.gguf_writer.add_expert_count(self.hparams["num_local_experts"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # process the experts separately + if name.find("block_sparse_moe.experts") != -1: + n_experts = self.hparams["num_local_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["w1", "w2", "w3"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") @Model.register("PlamoForCausalLM") @@ -2426,7 +2885,67 @@ class InternLM2Model(Model): return [(self.map_tensor_name(name), data_torch)] -@Model.register("BertModel", "CamembertModel") +@Model.register("InternLM3ForCausalLM") +class InternLM3Model(Model): + model_arch = gguf.MODEL_ARCH.LLAMA + + def set_vocab(self): + tokens, scores, toktypes = self._create_vocab_sentencepiece() + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + if "add_prefix_space" in tokenizer_config_json: + self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"]) + + if "added_tokens_decoder" in tokenizer_config_json: + for token_id, token_data in tokenizer_config_json["added_tokens_decoder"].items(): + if token_data.get("special"): + token_id = int(token_id) + token = token_data["content"] + special_vocab._set_special_token(token, token_id) + # update eos token + if token == '<|im_end|>' and "eos" in special_vocab.special_token_ids: + special_vocab.special_token_ids["eos"] = token_id + + special_vocab.add_to_gguf(self.gguf_writer) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + + if "head_dim" in hparams: + rope_dim = hparams["head_dim"] + else: + rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(rope_dim) + + if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: + if self.hparams["rope_scaling"].get("type") == "linear" or self.hparams["rope_scaling"].get("rope_type") == "linear": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + return [(self.map_tensor_name(name), data_torch)] + + +@Model.register("BertModel", "BertForMaskedLM", "CamembertModel") class BertModel(Model): model_arch = gguf.MODEL_ARCH.BERT @@ -2467,7 +2986,8 @@ class BertModel(Model): # we need this to validate the size of the token_type embeddings # though currently we are passing all zeros to the token_type embeddings - self.gguf_writer.add_token_type_count(2) # "Sequence A" or "Sequence B" + # "Sequence A" or "Sequence B" + self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1)) # convert to phantom space vocab def phantom(tok): @@ -2491,13 +3011,73 @@ class BertModel(Model): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused + if name.startswith("bert."): + name = name[5:] + + if name.endswith(".gamma"): + name = name[:-6] + ".weight" + + if name.endswith(".beta"): + name = name[:-5] + ".bias" + # we are only using BERT for embeddings so we don't need the pooling layer if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"): return [] # we don't need these + if name.startswith("cls.predictions"): + return [] + + if name.startswith("cls.seq_relationship"): + return [] + return [(self.map_tensor_name(name), data_torch)] +@Model.register("RobertaModel") +class RobertaModel(BertModel): + model_arch = gguf.MODEL_ARCH.BERT + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # we need the pad_token_id to know how to chop down position_embd matrix + if (pad_token_id := self.hparams.get("pad_token_id")) is not None: + self._position_offset = 1 + pad_token_id + if "max_position_embeddings" in self.hparams: + self.hparams["max_position_embeddings"] -= self._position_offset + else: + self._position_offset = None + + def set_vocab(self): + """Support BPE tokenizers for roberta models""" + bpe_tok_path = self.dir_model / "tokenizer.json" + if bpe_tok_path.exists(): + self._set_vocab_gpt2() + self.gguf_writer.add_add_bos_token(True) + self.gguf_writer.add_add_eos_token(True) + + # we need this to validate the size of the token_type embeddings + # though currently we are passing all zeros to the token_type embeddings + # "Sequence A" or "Sequence B" + self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1)) + + else: + return super().set_vocab() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # if name starts with "roberta.", remove the prefix + # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main + if name.startswith("roberta."): + name = name[8:] + + # position embeddings start at pad_token_id + 1, so just chop down the weight tensor + if name == "embeddings.position_embeddings.weight": + if self._position_offset is not None: + data_torch = data_torch[self._position_offset:,:] + + return super().modify_tensors(data_torch, name, bid) + + @Model.register("NomicBertModel") class NomicBertModel(BertModel): model_arch = gguf.MODEL_ARCH.NOMIC_BERT @@ -2528,7 +3108,7 @@ class NomicBertModel(BertModel): self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"]) -@Model.register("XLMRobertaModel") +@Model.register("XLMRobertaModel", "XLMRobertaForSequenceClassification") class XLMRobertaModel(BertModel): model_arch = gguf.MODEL_ARCH.BERT @@ -2614,7 +3194,7 @@ class XLMRobertaModel(BertModel): self.gguf_writer.add_token_scores(scores) self.gguf_writer.add_token_types(toktypes) self.gguf_writer.add_add_space_prefix(add_prefix) - self.gguf_writer.add_token_type_count(1) + self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1)) self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces) if precompiled_charsmap: self.gguf_writer.add_precompiled_charsmap(precompiled_charsmap) @@ -2626,6 +3206,11 @@ class XLMRobertaModel(BertModel): self.gguf_writer.add_add_eos_token(True) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # if name starts with "roberta.", remove the prefix + # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main + if name.startswith("roberta."): + name = name[8:] + # position embeddings start at pad_token_id + 1, so just chop down the weight tensor if name == "embeddings.position_embeddings.weight": if self._position_offset is not None: @@ -2769,6 +3354,11 @@ class Rwkv6Model(Model): self.gguf_writer.add_tokenizer_model("rwkv") self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_types(toktypes) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False) + special_vocab.chat_template = "rwkv-world" + # hack: Add '\n\n' as the EOT token to make it chat normally + special_vocab._set_special_token("eot", 261) + special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): block_count = self.hparams["num_hidden_layers"] @@ -2795,6 +3385,8 @@ class Rwkv6Model(Model): # required by llama.cpp, unused self.gguf_writer.add_head_count(0) + lerp_weights: dict[int, dict[str, Tensor]] = {} + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: new_name = self.map_tensor_name(name) @@ -2807,14 +3399,87 @@ class Rwkv6Model(Model): if new_name.endswith("time_mix_w2.weight"): data_torch = data_torch.permute(0, 2, 1) - rescale_every_n_layers = self.hparams["rescale_every"] - if rescale_every_n_layers > 0: - if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"): - data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers)) + if new_name.endswith("time_mix_decay.weight") or "lerp" in new_name: + data_torch = data_torch.squeeze() + + try: + rescale_every_n_layers = self.hparams["rescale_every"] + if rescale_every_n_layers > 0: + if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"): + data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers)) + except KeyError: + pass + + # concat time_mix_lerp weights to reduce some cpu overhead + # also reduces the number of tensors in the model + if bid is not None and "time_mix_lerp" in new_name and "time_mix_lerp_x" not in new_name: + try: + self.lerp_weights[bid][new_name] = data_torch + except KeyError: + self.lerp_weights[bid] = {new_name: data_torch} + if all(f"blk.{bid}.time_mix_lerp_{i}.weight" in self.lerp_weights[bid].keys() for i in ["w", "k", "v", "r", "g"]): + new_name = f"blk.{bid}.time_mix_lerp_fused.weight" + data = torch.stack([self.lerp_weights[bid][f"blk.{bid}.time_mix_lerp_{i}.weight"].unsqueeze(0) for i in ["w", "k", "v", "r", "g"]], dim=0).unsqueeze(1) + yield (new_name, data) + return yield (new_name, data_torch) +@Model.register("RWKV6Qwen2ForCausalLM") +class RWKV6Qwen2Model(Rwkv6Model): + model_arch = gguf.MODEL_ARCH.RWKV6QWEN2 + + def set_vocab(self): + try: + self._set_vocab_sentencepiece() + except FileNotFoundError: + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + block_count = self.hparams["num_hidden_layers"] + num_attention_heads = self.hparams["num_attention_heads"] + num_key_value_heads = self.hparams["num_key_value_heads"] + hidden_size = self.hparams["hidden_size"] + head_size = hidden_size // num_attention_heads + rms_norm_eps = self.hparams["rms_norm_eps"] + intermediate_size = self.hparams["intermediate_size"] + time_mix_extra_dim = 64 if hidden_size >= 4096 else 32 + time_decay_extra_dim = 128 if hidden_size >= 4096 else 64 + + # RWKV isn't context limited + self.gguf_writer.add_context_length(1048576) + self.gguf_writer.add_embedding_length(hidden_size) + self.gguf_writer.add_block_count(block_count) + self.gguf_writer.add_wkv_head_size(head_size) + self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim) + self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim) + self.gguf_writer.add_feed_forward_length(intermediate_size) + self.gguf_writer.add_file_type(self.ftype) + + # special parameters for time_mixing in RWKV6QWEN2 + self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_token_shift_count(1) + # RWKV6QWEN2 use grouped key/value like GQA + self.gguf_writer.add_head_count_kv(num_key_value_heads) + + # required by llama.cpp, unused + self.gguf_writer.add_head_count(0) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + for new_name, data in super().modify_tensors(data_torch, name, bid): + if "time_mix_w1" in new_name or "time_mix_w2" in new_name: + data = data.view(5, -1, data.shape[-1]) + # rwkv6qwen2 has a different order of rkvwg instead of the original wkvrg + # permute them here to avoid code changes + data = torch.stack([data[3], data[1], data[2], data[0], data[4]], dim=0).view(-1, data.shape[-1]) + if "w2" in new_name: + data = data.view(5, -1, data.shape[-1]) + yield (new_name, data) + continue + yield (new_name, data) + + @Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM") class MambaModel(Model): model_arch = gguf.MODEL_ARCH.MAMBA @@ -2909,6 +3574,24 @@ class CommandR2Model(Model): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) +@Model.register("Cohere2ForCausalLM") +class Cohere2Model(Model): + model_arch = gguf.MODEL_ARCH.COHERE2 + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + self.gguf_writer.add_logit_scale(self.hparams["logit_scale"]) + self.gguf_writer.add_sliding_window(self.hparams["sliding_window"]) + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + + rotary_pct = self.hparams["rotary_pct"] + hidden_size = self.hparams["hidden_size"] + num_attention_heads = self.hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(int(rotary_pct * (hidden_size // num_attention_heads))) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + + @Model.register("OlmoForCausalLM") @Model.register("OLMoForCausalLM") class OlmoModel(Model): @@ -2937,6 +3620,71 @@ class OlmoModel(Model): return [(self.map_tensor_name(name), data_torch)] +@Model.register("Olmo2ForCausalLM") +class Olmo2Model(Model): + model_arch = gguf.MODEL_ARCH.OLMO2 + + +@Model.register("OlmoeForCausalLM") +class OlmoeModel(Model): + model_arch = gguf.MODEL_ARCH.OLMOE + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_layer_norm_rms_eps(1e-5) + if (n_experts := self.hparams.get("num_experts")) is not None: + self.gguf_writer.add_expert_count(n_experts) + + _experts: list[dict[str, Tensor]] | None = None + + # Copied from: Qwen2MoeModel + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # process the experts separately + if name.find("experts") != -1: + n_experts = self.hparams["num_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + # Copied from: Qwen2MoeModel + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + @Model.register("JinaBertModel", "JinaBertForMaskedLM") class JinaBertV2Model(BertModel): model_arch = gguf.MODEL_ARCH.JINA_BERT_V2 @@ -2975,6 +3723,14 @@ class JinaBertV2Model(BertModel): self.gguf_writer.add_add_bos_token(True) self.gguf_writer.add_add_eos_token(True) + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # if name starts with "bert.", remove the prefix + # e.g. https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + if name.startswith("bert."): + name = name[5:] + + return super().modify_tensors(data_torch, name, bid) + @Model.register("OpenELMForCausalLM") class OpenELMModel(Model): @@ -3202,7 +3958,99 @@ class ArcticModel(Model): raise ValueError(f"Unprocessed experts: {experts}") +@Model.register("DeepseekForCausalLM") +class DeepseekModel(Model): + model_arch = gguf.MODEL_ARCH.DEEPSEEK + + def set_vocab(self): + try: + self._set_vocab_sentencepiece() + except FileNotFoundError: + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + if "head_dim" in hparams: + rope_dim = hparams["head_dim"] + else: + rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] + + self.gguf_writer.add_rope_dimension_count(rope_dim) + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) + self.gguf_writer.add_expert_weights_scale(1.0) + self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) + self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) + + _experts: list[dict[str, Tensor]] | None = None + + @staticmethod + def permute(weights: Tensor, n_head: int, n_head_kv: int | None): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = DeepseekModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = DeepseekModel.permute(data_torch, n_head, n_kv_head) + + # process the experts separately + if name.find("mlp.experts") != -1: + n_experts = self.hparams["n_routed_experts"] + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + @Model.register("DeepseekV2ForCausalLM") +@Model.register("DeepseekV3ForCausalLM") class DeepseekV2Model(Model): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 @@ -3224,6 +4072,15 @@ class DeepseekV2Model(Model): self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) + self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) + + if hparams["scoring_func"] == "sigmoid": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + elif hparams["scoring_func"] == "softmax": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) + else: + raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}") + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: @@ -3236,6 +4093,16 @@ class DeepseekV2Model(Model): _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # rename e_score_correction_bias tensors + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") + + # skip Multi-Token Prediction (MTP) layers + block_count = self.hparams["num_hidden_layers"] + match = re.match(r"model.layers.(\d+)", name) + if match and int(match.group(1)) >= block_count: + return [] + # process the experts separately if name.find("mlp.experts") != -1: n_experts = self.hparams["n_routed_experts"] @@ -3577,10 +4444,7 @@ class JaisModel(Model): # Embeddings scale self.embeddings_scale = 1.0 - # note: For some JAIS flavors, output is tied to (same as) wte in original model - self.output_is_wte = False if 'mup_embeddings_scale' in self.hparams: - self.output_is_wte = True # Hack (?) self.embeddings_scale = self.hparams['mup_embeddings_scale'] elif 'embeddings_scale' in self.hparams: self.embeddings_scale = self.hparams['embeddings_scale'] @@ -3637,10 +4501,7 @@ class JaisModel(Model): if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD): tensors.append((new_name, data_torch * self.embeddings_scale)) - if self.output_is_wte: - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch * self.width_scale)) elif new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT): - assert not self.output_is_wte tensors.append((new_name, data_torch * self.width_scale)) else: tensors.append((new_name, data_torch)) @@ -3915,7 +4776,7 @@ class ExaoneModel(Model): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"]) - def prepare_tensors(self): + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): if rope_scaling.get("rope_type", '').lower() == "llama3": base = self.hparams.get("rope_theta", 10000.0) @@ -3942,14 +4803,112 @@ class ExaoneModel(Model): smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) rope_factors.append(1 / ((1 - smooth) / factor + smooth)) - if not self.is_lora: - self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32)) + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) - super().prepare_tensors() + +@Model.register("GraniteForCausalLM") +class GraniteModel(LlamaModel): + """Conversion for IBM's GraniteForCausalLM""" + model_arch = gguf.MODEL_ARCH.GRANITE + + def set_gguf_parameters(self): + """Granite uses standard llama parameters with the following differences: + + - No head_dim support + - New multiplier params: + - attention_scale + - embedding_scale + - residual_scale + - logits_scaling + """ + if head_dim := self.hparams.pop("head_dim", None): + logger.warning("Ignoring head_dim (%s) from config for Granite", head_dim) + super().set_gguf_parameters() + # NOTE: Convert _multiplier params to _scale params for naming + # consistency + if attention_scale := self.hparams.get("attention_multiplier"): + self.gguf_writer.add_attention_scale(attention_scale) + logger.info("gguf: (granite) attention_scale = %s", attention_scale) + if embedding_scale := self.hparams.get("embedding_multiplier"): + self.gguf_writer.add_embedding_scale(embedding_scale) + logger.info("gguf: (granite) embedding_scale = %s", embedding_scale) + if residual_scale := self.hparams.get("residual_multiplier"): + self.gguf_writer.add_residual_scale(residual_scale) + logger.info("gguf: (granite) residual_scale = %s", residual_scale) + if logits_scale := self.hparams.get("logits_scaling"): + self.gguf_writer.add_logit_scale(logits_scale) + logger.info("gguf: (granite) logits_scale = %s", logits_scale) + + +@Model.register("GraniteMoeForCausalLM") +class GraniteMoeModel(GraniteModel): + """Conversion for IBM's GraniteMoeForCausalLM""" + model_arch = gguf.MODEL_ARCH.GRANITE_MOE + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + """In modeling_granitemoe, the JetMoe implementation of parallel experts + is used. This essentially merges w1 and w3 into a single tensor with 2x + the hidden size that is then split during forward. To keep compatibility + with existing mixtral support, we pull them apart here. + """ + + if name.endswith("block_sparse_moe.input_linear.weight"): + ffn_dim = self.hparams["intermediate_size"] + assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size" + gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :] + return [ + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate), + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up), + ] + + return super().modify_tensors(data_torch, name, bid) + + +@Model.register("ChameleonForConditionalGeneration") +@Model.register("ChameleonForCausalLM") # obsolete +class ChameleonModel(Model): + model_arch = gguf.MODEL_ARCH.CHAMELEON + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_swin_norm(self.hparams.get("swin_norm", False)) + + def set_vocab(self): + self._set_vocab_gpt2() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # ignore image tokenizer for now + # TODO: remove this once image support is implemented for Chameleon + if name.startswith("model.vqmodel"): + return [] + + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + hidden_dim = self.hparams.get("hidden_size") + + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + if name.endswith(("q_norm.weight", "q_norm.bias")): + data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_head, hidden_dim) + if name.endswith(("k_norm.weight", "k_norm.bias")): + data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_kv_head, hidden_dim) + + return [(self.map_tensor_name(name), data_torch)] + + # see: https://github.com/huggingface/transformers/blob/72fb02c47dbbe1999ae105319f24631cad6e2e00/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py#L176-L203 + @staticmethod + def _reverse_hf_permute(data_torch, n_heads, hidden_dim): + head_dim = hidden_dim // n_heads + data_torch = data_torch[0].view(2, head_dim // 2).t().reshape(1, -1) + data_torch = data_torch.repeat_interleave(n_heads, 0) + return data_torch ###### CONVERSION LOGIC ###### + # tree of lazy tensors class LazyTorchTensor(gguf.LazyBase): _tensor_type = torch.Tensor @@ -4038,6 +4997,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "model", type=Path, help="directory containing model file", + nargs="?", ) parser.add_argument( "--use-temp-file", action="store_true", @@ -4075,8 +5035,15 @@ def parse_args() -> argparse.Namespace: "--metadata", type=Path, help="Specify the path for an authorship metadata override file" ) + parser.add_argument( + "--print-supported-models", action="store_true", + help="Print the supported models" + ) - return parser.parse_args() + args = parser.parse_args() + if not args.print_supported_models and args.model is None: + parser.error("the following arguments are required: model") + return args def split_str_to_n_bytes(split_str: str) -> int: @@ -4100,6 +5067,11 @@ def split_str_to_n_bytes(split_str: str) -> int: def main() -> None: args = parse_args() + if args.print_supported_models: + logger.error("Supported models:") + Model.print_registered_models() + sys.exit(0) + if args.verbose: logging.basicConfig(level=logging.DEBUG) else: diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index ff4955f9c..cea34413f 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -17,7 +17,7 @@ # # python3 convert_hf_to_gguf_update.py # -# - Copy-paste the generated get_vocab_base_pre() function into convert_hf_to_gguf.py +# - The convert_hf_to_gguf.py script will have had its get_vocab_base_pre() function updated # - Update llama.cpp with the new pre-tokenizer if necessary # # TODO: generate tokenizer tests for llama.cpp @@ -31,6 +31,7 @@ import re import requests import sys import json +import shutil from hashlib import sha256 from enum import IntEnum, auto @@ -64,39 +65,50 @@ else: # TODO: add models here, base models preferred models = [ - {"name": "llama-spm", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/meta-llama/Llama-2-7b-hf", }, - {"name": "llama-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Meta-Llama-3-8B", }, - {"name": "phi-3", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct", }, - {"name": "deepseek-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-llm-7b-base", }, - {"name": "deepseek-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base", }, - {"name": "falcon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/falcon-7b", }, - {"name": "bert-bge", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/BAAI/bge-small-en-v1.5", }, - {"name": "mpt", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mosaicml/mpt-7b", }, - {"name": "starcoder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/starcoder2-3b", }, - {"name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", }, - {"name": "stablelm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b", }, - {"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", }, - {"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", }, - {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", }, - {"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", }, - {"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", }, - {"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM! - {"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", }, - {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", }, - {"name": "smaug-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", }, - {"name": "poro-chat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", }, - {"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", }, - {"name": "viking", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Viking-7B", }, # Also used for Viking 13B and 33B - {"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", }, - {"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", }, - {"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", }, - {"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", }, - {"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", }, - {"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", }, - {"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", }, - {'name': "bloom", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigscience/bloom", }, - {'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", }, - {"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", }, + {"name": "llama-spm", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/meta-llama/Llama-2-7b-hf", }, + {"name": "llama-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Meta-Llama-3-8B", }, + {"name": "phi-3", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct", }, + {"name": "deepseek-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-llm-7b-base", }, + {"name": "deepseek-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base", }, + {"name": "falcon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/falcon-7b", }, + {"name": "bert-bge", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/BAAI/bge-small-en-v1.5", }, + {"name": "falcon3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon3-7B-Base", }, + {"name": "bert-bge-large", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/BAAI/bge-large-zh-v1.5", }, + {"name": "mpt", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mosaicml/mpt-7b", }, + {"name": "starcoder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/starcoder2-3b", }, + {"name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", }, + {"name": "stablelm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b", }, + {"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", }, + {"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", }, + {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", }, + {"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", }, + {"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", }, + {"name": "jina-v1-en", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-reranker-v1-tiny-en", }, + {"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM! + {"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", }, + {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", }, + {"name": "smaug-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", }, + {"name": "poro-chat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", }, + {"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", }, + {"name": "viking", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Viking-7B", }, # Also used for Viking 13B and 33B + {"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", }, + {"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", }, + {"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", }, + {"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", }, + {"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", }, + {"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", }, + {"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", }, + {'name': "bloom", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigscience/bloom", }, + {'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", }, + {"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", }, + {"name": "phi-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/microsoft/phi-2", }, + {"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", }, + {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", }, + {"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"}, + {"name": "gigachat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"}, + {"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"}, + {"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"}, + {"name": "deepseek-r1-qwen", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"}, ] @@ -125,12 +137,27 @@ def download_model(model): if tokt == TOKENIZER_TYPE.UGM: files.append("spiece.model") - for file in files: - save_path = f"models/tokenizers/{name}/{file}" - if os.path.isfile(save_path): - logger.info(f"{name}: File {save_path} already exists - skipping") - continue - download_file_with_auth(f"{repo}/resolve/main/{file}", token, save_path) + if os.path.isdir(repo): + # If repo is a path on the file system, copy the directory + for file in files: + src_path = os.path.join(repo, file) + dst_path = f"models/tokenizers/{name}/{file}" + if os.path.isfile(dst_path): + logger.info(f"{name}: File {dst_path} already exists - skipping") + continue + if os.path.isfile(src_path): + shutil.copy2(src_path, dst_path) + logger.info(f"{name}: Copied {src_path} to {dst_path}") + else: + logger.warning(f"{name}: Source file {src_path} does not exist") + else: + # If repo is a URL, download the files + for file in files: + save_path = f"models/tokenizers/{name}/{file}" + if os.path.isfile(save_path): + logger.info(f"{name}: File {save_path} already exists - skipping") + continue + download_file_with_auth(f"{repo}/resolve/main/{file}", token, save_path) for model in models: diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index ddd347a2a..6dea14a23 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -12,6 +12,7 @@ import json from math import prod from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast +from transformers import AutoConfig import torch @@ -225,12 +226,15 @@ def get_base_tensor_name(lora_tensor_name: str) -> str: base_name = lora_tensor_name.replace("base_model.model.", "") base_name = base_name.replace(".lora_A.weight", ".weight") base_name = base_name.replace(".lora_B.weight", ".weight") + # models produced by mergekit-extract-lora have token embeddings in the adapter + base_name = base_name.replace(".lora_embedding_A", ".weight") + base_name = base_name.replace(".lora_embedding_B", ".weight") return base_name def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file") + description="Convert a Hugging Face PEFT LoRA adapter to a GGUF file") parser.add_argument( "--outfile", type=Path, help="path to write to; default: based on input. {ftype} will be replaced by the outtype.", @@ -256,17 +260,27 @@ def parse_args() -> argparse.Namespace: help="only print out what will be done, without writing any new files", ) parser.add_argument( - "--base", type=Path, required=True, - help="directory containing base model file", + "--base", type=Path, + help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required. If base model is unspecified, it will be loaded from Hugging Face hub based on the adapter config", + ) + parser.add_argument( + "--base-model-id", type=str, + help="the model ID of the base model, if it is not available locally or in the adapter config. If specified, it will ignore --base and load the base model config from the Hugging Face hub (Example: 'meta-llama/Llama-3.2-1B-Instruct')", ) parser.add_argument( "lora_path", type=Path, - help="directory containing LoRA adapter file", + help="directory containing Hugging Face PEFT LoRA config (adapter_model.json) and weights (adapter_model.safetensors or adapter_model.bin)", ) return parser.parse_args() +def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]: + # normally, adapter does not come with base model config, we need to load it from AutoConfig + config = AutoConfig.from_pretrained(hf_model_id) + return config.to_dict() + + if __name__ == '__main__': args = parse_args() logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) @@ -281,8 +295,9 @@ if __name__ == '__main__': ftype = ftype_map[args.outtype] - dir_base_model: Path = args.base + dir_base_model: Path | None = args.base dir_lora: Path = args.lora_path + base_model_id: str | None = args.base_model_id lora_config = dir_lora / "adapter_config.json" input_model = dir_lora / "adapter_model.safetensors" @@ -301,9 +316,32 @@ if __name__ == '__main__': input_model = os.path.join(dir_lora, "adapter_model.bin") lora_model = torch.load(input_model, map_location="cpu", weights_only=True) + # load LoRA config + with open(lora_config, "r") as f: + lparams: dict[str, Any] = json.load(f) + # load base model - logger.info(f"Loading base model: {dir_base_model.name}") - hparams = Model.load_hparams(dir_base_model) + if base_model_id is not None: + logger.info(f"Loading base model from Hugging Face: {base_model_id}") + hparams = load_hparams_from_hf(base_model_id) + elif dir_base_model is None: + if "base_model_name_or_path" in lparams: + model_id = lparams["base_model_name_or_path"] + logger.info(f"Loading base model from Hugging Face: {model_id}") + try: + hparams = load_hparams_from_hf(model_id) + except OSError as e: + logger.error(f"Failed to load base model config: {e}") + logger.error("Please try downloading the base model and add its path to --base") + sys.exit(1) + else: + logger.error("'base_model_name_or_path' is not found in adapter_config.json") + logger.error("Base model config is required. Please download the base model and add its path to --base") + sys.exit(1) + else: + logger.info(f"Loading base model: {dir_base_model.name}") + hparams = Model.load_hparams(dir_base_model) + with torch.inference_mode(): try: model_class = Model.from_model_architecture(hparams["architectures"][0]) @@ -323,13 +361,19 @@ if __name__ == '__main__': self.dir_model_card = dir_lora_model self.lora_alpha = float(lora_alpha) + def set_vocab(self): + pass + def set_type(self): self.gguf_writer.add_type(gguf.GGUFType.ADAPTER) self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora") def set_gguf_parameters(self): self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha) - super().set_gguf_parameters() + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + # Never add extra tensors (e.g. rope_freqs) for LoRA adapters + return () def get_tensors(self) -> Iterator[tuple[str, Tensor]]: tensor_map: dict[str, PartialLoraTensor] = {} @@ -338,12 +382,20 @@ if __name__ == '__main__': if self.lazy: tensor = LazyTorchTensor.from_eager(tensor) base_name = get_base_tensor_name(name) - is_lora_a = ".lora_A.weight" in name - is_lora_b = ".lora_B.weight" in name + # note: mergekit-extract-lora also adds token embeddings to the adapter + is_lora_a = ".lora_A.weight" in name or ".lora_embedding_A" in name + is_lora_b = ".lora_B.weight" in name or ".lora_embedding_B" in name if not is_lora_a and not is_lora_b: if ".base_layer.weight" in name: continue + # mergekit-extract-lora add these layernorm to the adapter, we need to keep them + if "_layernorm" in name or ".norm" in name: + yield (base_name, tensor) + continue logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor") + if ".embed_tokens.weight" in name or ".lm_head.weight" in name: + logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning") + logger.error("Please refer to https://github.com/ggerganov/llama.cpp/pull/9948") sys.exit(1) if base_name in tensor_map: @@ -363,17 +415,32 @@ if __name__ == '__main__': yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B))) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - dest = super().modify_tensors(data_torch, name, bid) + dest = list(super().modify_tensors(data_torch, name, bid)) + # some archs may have the same tensor for lm_head and output (tie word embeddings) + # in this case, adapters targeting lm_head will fail when using llama-export-lora + # therefore, we ignore them for now + # see: https://github.com/ggerganov/llama.cpp/issues/9065 + if name == "lm_head.weight" and len(dest) == 0: + raise ValueError("lm_head is present in adapter, but is ignored in base model") for dest_name, dest_data in dest: + # mergekit-extract-lora add these layernorm to the adapter + if "_norm" in dest_name: + assert dest_data.dim() == 1 + yield (dest_name, dest_data) + continue + + # otherwise, we must get the lora_A and lora_B tensors assert isinstance(dest_data, LoraTorchTensor) lora_a, lora_b = dest_data.get_lora_A_B() + # note: mergekit-extract-lora flip and transpose A and B + # here we only need to transpose token_embd.lora_a, see llm_build_inp_embd() + if "token_embd.weight" in dest_name: + lora_a = lora_a.T + yield (dest_name + ".lora_a", lora_a) yield (dest_name + ".lora_b", lora_b) - with open(lora_config, "r") as f: - lparams: dict[str, Any] = json.load(f) - alpha: float = lparams["lora_alpha"] model_instance = LoraModel( @@ -386,7 +453,7 @@ if __name__ == '__main__': dry_run=args.dry_run, dir_lora_model=dir_lora, lora_alpha=alpha, - is_lora=True, + hparams=hparams, ) logger.info("Exporting model...") diff --git a/docs/android.md b/docs/android.md index cec4358d9..47530c6c1 100644 --- a/docs/android.md +++ b/docs/android.md @@ -2,55 +2,82 @@ # Android ## Build on Android using Termux -[Termux](https://github.com/termux/termux-app#installation) is a method to execute `llama.cpp` on an Android device (no root required). -``` -apt update && apt upgrade -y -apt install git make cmake -``` -It's recommended to move your model inside the `~/` directory for best performance: -``` -cd storage/downloads -mv model.gguf ~/ -``` +[Termux](https://termux.dev/en/) is an Android terminal emulator and Linux environment app (no root required). As of writing, Termux is available experimentally in the Google Play Store; otherwise, it may be obtained directly from the project repo or on F-Droid. -[Get the code](https://github.com/ggerganov/llama.cpp#get-the-code) & [follow the Linux build instructions](https://github.com/ggerganov/llama.cpp#build) to build `llama.cpp`. - -## Building the Project using Android NDK -Obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake. - -Execute the following commands on your computer to avoid downloading the NDK to your mobile. Alternatively, you can also do this in Termux: -``` -$ mkdir build-android -$ cd build-android -$ export NDK= -$ cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_C_FLAGS=-march=armv8.4a+dotprod .. -$ make -``` - -Install [termux](https://github.com/termux/termux-app#installation) on your device and run `termux-setup-storage` to get access to your SD card (if Android 11+ then run the command twice). - -Finally, copy these built `llama` binaries and the model file to your device storage. Because the file permissions in the Android sdcard cannot be changed, you can copy the executable files to the `/data/data/com.termux/files/home/bin` path, and then execute the following commands in Termux to add executable permission: - -(Assumed that you have pushed the built executable files to the /sdcard/llama.cpp/bin path using `adb push`) -``` -$cp -r /sdcard/llama.cpp/bin /data/data/com.termux/files/home/ -$cd /data/data/com.termux/files/home/bin -$chmod +x ./* -``` - -Download model [llama-2-7b-chat.Q4_K_M.gguf](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/blob/main/llama-2-7b-chat.Q4_K_M.gguf), and push it to `/sdcard/llama.cpp/`, then move it to `/data/data/com.termux/files/home/model/` +With Termux, you can install and run `llama.cpp` as if the environment were Linux. Once in the Termux shell: ``` -$mv /sdcard/llama.cpp/llama-2-7b-chat.Q4_K_M.gguf /data/data/com.termux/files/home/model/ +$ apt update && apt upgrade -y +$ apt install git cmake ``` -Now, you can start chatting: +Then, follow the [build instructions](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md), specifically for CMake. + +Once the binaries are built, download your model of choice (e.g., from Hugging Face). It's recommended to place it in the `~/` directory for best performance: + ``` -$cd /data/data/com.termux/files/home/bin -$./llama-cli -m ../model/llama-2-7b-chat.Q4_K_M.gguf -n 128 -cml +$ curl -L {model-url} -o ~/{model}.gguf ``` -Here's a demo of an interactive session running on Pixel 5 phone: +Then, if you are not already in the repo directory, `cd` into `llama.cpp` and: + +``` +$ ./build/bin/llama-cli -m ~/{model}.gguf -c {context-size} -p "{your-prompt}" +``` + +Here, we show `llama-cli`, but any of the executables under `examples` should work, in theory. Be sure to set `context-size` to a reasonable number (say, 4096) to start with; otherwise, memory could spike and kill your terminal. + +To see what it might look like visually, here's an old demo of an interactive session running on a Pixel 5 phone: https://user-images.githubusercontent.com/271616/225014776-1d567049-ad71-4ef2-b050-55b0b3b9274c.mp4 + +## Cross-compile using Android NDK +It's possible to build `llama.cpp` for Android on your host system via CMake and the Android NDK. If you are interested in this path, ensure you already have an environment prepared to cross-compile programs for Android (i.e., install the Android SDK). Note that, unlike desktop environments, the Android environment ships with a limited set of native libraries, and so only those libraries are available to CMake when building with the Android NDK (see: https://developer.android.com/ndk/guides/stable_apis.) + +Once you're ready and have cloned `llama.cpp`, invoke the following in the project directory: + +``` +$ cmake \ + -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-28 \ + -DCMAKE_C_FLAGS="-march=armv8.7a" \ + -DCMAKE_CXX_FLAGS="-march=armv8.7a" \ + -DGGML_OPENMP=OFF \ + -DGGML_LLAMAFILE=OFF \ + -B build-android +``` + +Notes: + - While later versions of Android NDK ship with OpenMP, it must still be installed by CMake as a dependency, which is not supported at this time + - `llamafile` does not appear to support Android devices (see: https://github.com/Mozilla-Ocho/llamafile/issues/325) + +The above command should configure `llama.cpp` with the most performant options for modern devices. Even if your device is not running `armv8.7a`, `llama.cpp` includes runtime checks for available CPU features it can use. + +Feel free to adjust the Android ABI for your target. Once the project is configured: + +``` +$ cmake --build build-android --config Release -j{n} +$ cmake --install build-android --prefix {install-dir} --config Release +``` + +After installing, go ahead and download the model of your choice to your host system. Then: + +``` +$ adb shell "mkdir /data/local/tmp/llama.cpp" +$ adb push {install-dir} /data/local/tmp/llama.cpp/ +$ adb push {model}.gguf /data/local/tmp/llama.cpp/ +$ adb shell +``` + +In the `adb shell`: + +``` +$ cd /data/local/tmp/llama.cpp +$ LD_LIBRARY_PATH=lib ./bin/llama-simple -m {model}.gguf -c {context-size} -p "{your-prompt}" +``` + +That's it! + +Be aware that Android will not find the library path `lib` on its own, so we must specify `LD_LIBRARY_PATH` in order to run the installed executables. Android does support `RPATH` in later API levels, so this could change in the future. Refer to the previous section for information about `context-size` (very important!) and running other `examples`. diff --git a/docs/backend/BLIS.md b/docs/backend/BLIS.md index 35d06bd0f..904548577 100644 --- a/docs/backend/BLIS.md +++ b/docs/backend/BLIS.md @@ -27,13 +27,6 @@ We recommend using openmp since it's easier to modify the cores being used. ### llama.cpp compilation -Makefile: - -```bash -make GGML_BLIS=1 -j -# make GGML_BLIS=1 llama-benchmark-matmult -``` - CMake: ```bash diff --git a/docs/backend/CANN.md b/docs/backend/CANN.md index 6bdd9d2da..23f10175a 100644 --- a/docs/backend/CANN.md +++ b/docs/backend/CANN.md @@ -23,6 +23,8 @@ The llama.cpp CANN backend is designed to support Ascend NPU. It utilize the abi ## News +- 2024.11 + - Support F16 and F32 data type model for Ascend 310P NPU. - 2024.8 - Support `Q4_0` and `Q8_0` data type for Ascend NPU. - 2024.7 @@ -40,9 +42,11 @@ The llama.cpp CANN backend is designed to support Ascend NPU. It utilize the abi ### Ascend NPU **Verified devices** + | Ascend NPU | Status | |:-----------------------------:|:-------:| | Atlas 300T A2 | Support | +| Atlas 300I Duo | Support | *Notes:* diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md index e3b9572cc..89ddbd669 100644 --- a/docs/backend/SYCL.md +++ b/docs/backend/SYCL.md @@ -26,7 +26,7 @@ ### Llama.cpp + SYCL -The llama.cpp SYCL backend is designed to support **Intel GPU** firstly. Based on the cross-platform feature of SYCL, it could support other vendor GPUs: Nvidia GPU (*AMD GPU coming*). +The llama.cpp SYCL backend is designed to support **Intel GPU** firstly. Based on the cross-platform feature of SYCL, it also supports other vendor GPUs: Nvidia and AMD. ## Recommended Release @@ -34,13 +34,16 @@ The SYCL backend would be broken by some PRs due to no online CI. The following release is verified with good quality: -|Commit ID|Tag|Release|Verified Platform| -|-|-|-|-| -|fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggerganov/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1
MTL Arc GPU/Windows 11/oneAPI 2024.1| +|Commit ID|Tag|Release|Verified Platform| Update date| +|-|-|-|-|-| +|3bcd40b3c593d14261fb2abfabad3c0fb5b9e318|b4040 |[llama-b4040-bin-win-sycl-x64.zip](https://github.com/ggerganov/llama.cpp/releases/download/b4040/llama-b4040-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1
MTL Arc GPU/Windows 11/oneAPI 2024.1| 2024-11-19| +|fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggerganov/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1
MTL Arc GPU/Windows 11/oneAPI 2024.1|| ## News +- 2024.11 + - Use syclcompat to improve the performance on some platforms. This requires to use oneAPI 2025.0 or newer. - 2024.8 - Use oneDNN as the default GEMM library, improve the compatibility for new Intel GPUs. @@ -111,10 +114,18 @@ SYCL backend supports Intel GPU Family: **Verified devices** -| Nvidia GPU | Status | Verified Model | -|--------------------------|---------|----------------| -| Ampere Series | Support | A100, A4000 | -| Ampere Series *(Mobile)* | Support | RTX 40 Series | +| Nvidia GPU | Status | Verified Model | +|--------------------------|-----------|----------------| +| Ampere Series | Supported | A100, A4000 | +| Ampere Series *(Mobile)* | Supported | RTX 40 Series | + +| AMD GPU | Status | Verified Model | +|--------------------------|--------------|----------------| +| Radeon Pro | Experimental | W6800 | +| Radeon RX | Experimental | 6700 XT | + +Note: AMD GPU support is highly experimental and is incompatible with F16. +Additionally, it only supports GPUs with a sub_group_size (warp size) of 32. ## Docker The docker build option is currently limited to *intel GPU* targets. @@ -122,7 +133,7 @@ The docker build option is currently limited to *intel GPU* targets. ### Build image ```sh # Using FP16 -docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" -f .devops/llama-cli-intel.Dockerfile . +docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" --target light -f .devops/intel.Dockerfile . ``` *Notes*: @@ -186,6 +197,10 @@ Platform #0: Intel(R) OpenCL HD Graphics In order to target Nvidia GPUs through SYCL, please make sure the CUDA/CUBLAS native requirements *-found [here](README.md#cuda)-* are installed. +- **AMD GPU** + +To target AMD GPUs with SYCL, the ROCm stack must be installed first. + 2. **Install Intel® oneAPI Base toolkit** - **For Intel GPU** @@ -212,6 +227,19 @@ cmake -B buildWithCublas -DCMAKE_CXX_COMPILER=icpx -DCMAKE_C_COMPILER=icx -DENAB cmake --build buildWithCublas --config Release ``` +- **Adding support to AMD GPUs** + +**oneAPI Plugin**: In order to enable SYCL support on AMD GPUs, please install the [Codeplay oneAPI Plugin for AMD GPUs](https://developer.codeplay.com/products/oneapi/amd/download). As with Nvidia GPUs, the user should also make sure the plugin version matches the installed base toolkit. + +**oneMKL for rocBlas**: The current oneMKL releases *(shipped with the oneAPI base-toolkit)* doesn't contain the rocBLAS backend. A build from source of the upstream [oneMKL](https://github.com/oneapi-src/oneMKL) with the *rocBLAS* backend enabled is thus required to run it on AMD GPUs. + +```sh +git clone https://github.com/oneapi-src/oneMKL +cd oneMKL +# Find your HIPTARGET with rocminfo, under the key 'Name:' +cmake -B buildWithrocBLAS -DCMAKE_CXX_COMPILER=icpx -DCMAKE_C_COMPILER=icx -DENABLE_MKLGPU_BACKEND=OFF -DENABLE_MKLCPU_BACKEND=OFF -DENABLE_ROCBLAS_BACKEND=ON -DHIPTARGETS=${HIPTARGET} -DTARGET_DOMAINS=blas +cmake --build buildWithrocBLAS --config Release +``` 3. **Verify installation and environment** @@ -223,22 +251,32 @@ sycl-ls - **Intel GPU** -When targeting an intel GPU, the user should expect one or more level-zero devices among the available SYCL devices. Please make sure that at least one GPU is present, for instance [`ext_oneapi_level_zero:gpu:0`] in the sample output below: +When targeting an intel GPU, the user should expect one or more level-zero devices among the available SYCL devices. Please make sure that at least one GPU is present, for instance [`level_zero:gpu`] in the sample output below: ``` -[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.10.0.17_160000] -[opencl:cpu:1] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i7-13700K OpenCL 3.0 (Build 0) [2023.16.10.0.17_160000] -[opencl:gpu:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics OpenCL 3.0 NEO [23.30.26918.50] -[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26918] +[opencl:acc][opencl:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.10.0.17_160000] +[opencl:cpu][opencl:1] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i7-13700K OpenCL 3.0 (Build 0) [2023.16.10.0.17_160000] +[opencl:gpu][opencl:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics OpenCL 3.0 NEO [23.30.26918.50] +[level_zero:gpu][level_zero:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26918] ``` - **Nvidia GPU** -Similarly, user targeting Nvidia GPUs should expect at least one SYCL-CUDA device [`ext_oneapi_cuda:gpu`] as bellow: +Similarly, user targeting Nvidia GPUs should expect at least one SYCL-CUDA device [`cuda:gpu`] as below: + ``` -[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.12.0.12_195853.xmain-hotfix] -[opencl:cpu:1] Intel(R) OpenCL, Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz OpenCL 3.0 (Build 0) [2023.16.12.0.12_195853.xmain-hotfix] -[ext_oneapi_cuda:gpu:0] NVIDIA CUDA BACKEND, NVIDIA A100-PCIE-40GB 8.0 [CUDA 12.2] +[opencl:acc][opencl:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.12.0.12_195853.xmain-hotfix] +[opencl:cpu][opencl:1] Intel(R) OpenCL, Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz OpenCL 3.0 (Build 0) [2023.16.12.0.12_195853.xmain-hotfix] +[cuda:gpu][cuda:0] NVIDIA CUDA BACKEND, NVIDIA A100-PCIE-40GB 8.0 [CUDA 12.5] +``` + +- **AMD GPU** + +For AMD GPUs we should expect at least one SYCL-HIP device [`hip:gpu`]: + +``` +[opencl:cpu][opencl:0] Intel(R) OpenCL, 12th Gen Intel(R) Core(TM) i9-12900K OpenCL 3.0 (Build 0) [2024.18.6.0.02_160000] +[hip:gpu][hip:0] AMD HIP BACKEND, AMD Radeon PRO W6800 gfx1030 [HIP 60140.9] ``` ### II. Build llama.cpp @@ -266,6 +304,7 @@ cmake --build build --config Release -j -v ``` #### Nvidia GPU + ```sh # Export relevant ENV variables export LD_LIBRARY_PATH=/path/to/oneMKL/buildWithCublas/lib:$LD_LIBRARY_PATH @@ -274,16 +313,37 @@ export CPLUS_INCLUDE_DIR=/path/to/oneMKL/buildWithCublas/include:$CPLUS_INCLUDE_ export CPLUS_INCLUDE_DIR=/path/to/oneMKL/include:$CPLUS_INCLUDE_DIR # Build LLAMA with Nvidia BLAS acceleration through SYCL +# Setting GGML_SYCL_DEVICE_ARCH is optional but can improve performance +GGML_SYCL_DEVICE_ARCH=sm_80 # Example architecture # Option 1: Use FP32 (recommended for better performance in most cases) -cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx +cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx # Option 2: Use FP16 -cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON +cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON # build all binary cmake --build build --config Release -j -v +``` +#### AMD GPU + +```sh +# Export relevant ENV variables +export LD_LIBRARY_PATH=/path/to/oneMKL/buildWithrocBLAS/lib:$LD_LIBRARY_PATH +export LIBRARY_PATH=/path/to/oneMKL/buildWithrocBLAS/lib:$LIBRARY_PATH +export CPLUS_INCLUDE_DIR=/path/to/oneMKL/buildWithrocBLAS/include:$CPLUS_INCLUDE_DIR + +# Build LLAMA with rocBLAS acceleration through SYCL + +## AMD +# Use FP32, FP16 is not supported +# Find your GGML_SYCL_DEVICE_ARCH with rocminfo, under the key 'Name:' +GGML_SYCL_DEVICE_ARCH=gfx90a # Example architecture +cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=AMD -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx + +# build all binary +cmake --build build --config Release -j -v ``` ### III. Run the inference @@ -323,7 +383,7 @@ found 2 SYCL devices: |Chosen Device ID|Setting| |-|-| -|0|`export ONEAPI_DEVICE_SELECTOR="level_zero:1"` or no action| +|0|`export ONEAPI_DEVICE_SELECTOR="level_zero:0"` or no action| |1|`export ONEAPI_DEVICE_SELECTOR="level_zero:1"`| |0 & 1|`export ONEAPI_DEVICE_SELECTOR="level_zero:0;level_zero:1"`| @@ -586,11 +646,12 @@ use 1 SYCL GPUs: [0] with Max compute units:512 #### Build -| Name | Value | Function | -|--------------------|-----------------------------------|---------------------------------------------| -| GGML_SYCL | ON (mandatory) | Enable build with SYCL code path.
FP32 path - recommended for better perforemance than FP16 on quantized model| -| GGML_SYCL_TARGET | INTEL *(default)* \| NVIDIA | Set the SYCL target device type. | -| GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. | +| Name | Value | Function | +|--------------------|---------------------------------------|---------------------------------------------| +| GGML_SYCL | ON (mandatory) | Enable build with SYCL code path.
FP32 path - recommended for better perforemance than FP16 on quantized model| +| GGML_SYCL_TARGET | INTEL *(default)* \| NVIDIA \| AMD | Set the SYCL target device type. | +| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. | +| GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. | | CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. | | CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. | @@ -636,6 +697,14 @@ use 1 SYCL GPUs: [0] with Max compute units:512 It's same for other projects including llama.cpp SYCL backend. +- Meet issue: `Native API failed. Native API returns: -6 (PI_ERROR_OUT_OF_HOST_MEMORY) -6 (PI_ERROR_OUT_OF_HOST_MEMORY) -999 (UNKNOWN PI error)` or `failed to allocate SYCL0 buffer` + + Device Memory is not enough. + + |Reason|Solution| + |-|-| + |Default Context is too big. It leads to more memory usage.|Set `-c 8192` or smaller value.| + |Model is big and require more memory than device's.|Choose smaller quantized model, like Q5 -> Q4;
Use more than one devices to load model.| ### **GitHub contribution**: Please add the **[SYCL]** prefix/tag in issues/PRs titles to help the SYCL-team check/address them without delay. diff --git a/docs/build.md b/docs/build.md index 152d46d6f..dd6495028 100644 --- a/docs/build.md +++ b/docs/build.md @@ -7,124 +7,75 @@ git clone https://github.com/ggerganov/llama.cpp cd llama.cpp ``` -In order to build llama.cpp you have four different options. +The following sections describe how to build with different backends and options. -- Using `make`: - - On Linux or MacOS: +## CPU Build - ```bash - make - ``` +Build llama.cpp using `CMake`: - - On Windows (x86/x64 only, arm64 requires cmake): +```bash +cmake -B build +cmake --build build --config Release +``` - 1. Download the latest fortran version of [w64devkit](https://github.com/skeeto/w64devkit/releases). - 2. Extract `w64devkit` on your pc. - 3. Run `w64devkit.exe`. - 4. Use the `cd` command to reach the `llama.cpp` folder. - 5. From here you can run: - ```bash - make - ``` +**Notes**: - - Notes: - - For `Q4_0_4_4` quantization type build, add the `GGML_NO_LLAMAFILE=1` flag. For example, use `make GGML_NO_LLAMAFILE=1`. - - For faster compilation, add the `-j` argument to run multiple jobs in parallel. For example, `make -j 8` will run 8 jobs in parallel. - - For faster repeated compilation, install [ccache](https://ccache.dev/). - - For debug builds, run `make LLAMA_DEBUG=1` +- For faster compilation, add the `-j` argument to run multiple jobs in parallel, or use a generator that does this automatically such as Ninja. For example, `cmake --build build --config Release -j 8` will run 8 jobs in parallel. +- For faster repeated compilation, install [ccache](https://ccache.dev/) +- For debug builds, there are two cases: -- Using `CMake`: + 1. Single-config generators (e.g. default = `Unix Makefiles`; note that they just ignore the `--config` flag): - ```bash - cmake -B build + ```bash + cmake -B build -DCMAKE_BUILD_TYPE=Debug + cmake --build build + ``` + + 2. Multi-config generators (`-G` param set to Visual Studio, XCode...): + + ```bash + cmake -B build -G "Xcode" + cmake --build build --config Debug + ``` + + For more details and a list of supported generators, see the [CMake documentation](https://cmake.org/cmake/help/latest/manual/cmake-generators.7.html). +- For static builds, add `-DBUILD_SHARED_LIBS=OFF`: + ``` + cmake -B build -DBUILD_SHARED_LIBS=OFF cmake --build build --config Release ``` - **Notes**: - - - For `Q4_0_4_4` quantization type build, add the `-DGGML_LLAMAFILE=OFF` cmake option. For example, use `cmake -B build -DGGML_LLAMAFILE=OFF`. - - For faster compilation, add the `-j` argument to run multiple jobs in parallel. For example, `cmake --build build --config Release -j 8` will run 8 jobs in parallel. - - For faster repeated compilation, install [ccache](https://ccache.dev/). - - For debug builds, there are two cases: - - 1. Single-config generators (e.g. default = `Unix Makefiles`; note that they just ignore the `--config` flag): +- Building for Windows (x86, x64 and arm64) with MSVC or clang as compilers: + - Install Visual Studio 2022, e.g. via the [Community Edition](https://visualstudio.microsoft.com/de/vs/community/). In the installer, select at least the following options (this also automatically installs the required additional tools like CMake,...): + - Tab Workload: Desktop-development with C++ + - Tab Components (select quickly via search): C++-_CMake_ Tools for Windows, _Git_ for Windows, C++-_Clang_ Compiler for Windows, MS-Build Support for LLVM-Toolset (clang) + - Please remember to always use a Developer Command Prompt / PowerShell for VS2022 for git, build, test + - For Windows on ARM (arm64, WoA) build with: + ```bash + cmake --preset arm64-windows-llvm-release -D GGML_OPENMP=OFF + cmake --build build-arm64-windows-llvm-release + ``` + Building for arm64 can also be done with the MSVC compiler with the build-arm64-windows-MSVC preset, or the standard CMake build instructions. However, note that the MSVC compiler does not support inline ARM assembly code, used e.g. for the accelerated Q4_0_N_M CPU kernels. + For building with ninja generator and clang compiler as default: + -set path:set LIB=C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\um\x64;C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Tools\MSVC\14.41.34120\lib\x64\uwp;C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\ucrt\x64 ```bash - cmake -B build -DCMAKE_BUILD_TYPE=Debug - cmake --build build + cmake --preset x64-windows-llvm-release + cmake --build build-x64-windows-llvm-release ``` - 2. Multi-config generators (`-G` param set to Visual Studio, XCode...): - - ```bash - cmake -B build -G "Xcode" - cmake --build build --config Debug - ``` - - Building for Windows (x86, x64 and arm64) with MSVC or clang as compilers: - - Install Visual Studio 2022, e.g. via the [Community Edition](https://visualstudio.microsoft.com/de/vs/community/). In the installer, select at least the following options (this also automatically installs the required additional tools like CMake,...): - - Tab Workload: Desktop-development with C++ - - Tab Components (select quickly via search): C++-_CMake_ Tools for Windows, _Git_ for Windows, C++-_Clang_ Compiler for Windows, MS-Build Support for LLVM-Toolset (clang) - - Please remember to always use a Developer Command Prompt / PowerShell for VS2022 for git, build, test - - For Windows on ARM (arm64, WoA) build with: - ```bash - cmake --preset arm64-windows-llvm-release -D GGML_OPENMP=OFF - cmake --build build-arm64-windows-llvm-release - ``` - Note: Building for arm64 could also be done just with MSVC (with the build-arm64-windows-MSVC preset, or the standard CMake build instructions). But MSVC does not support inline ARM assembly-code, used e.g. for the accelerated Q4_0_4_8 CPU kernels. - -- Using `gmake` (FreeBSD): - - 1. Install and activate [DRM in FreeBSD](https://wiki.freebsd.org/Graphics) - 2. Add your user to **video** group - 3. Install compilation dependencies. - - ```bash - sudo pkg install gmake automake autoconf pkgconf llvm15 openblas - - gmake CC=/usr/local/bin/clang15 CXX=/usr/local/bin/clang++15 -j4 - ``` - -## Metal Build - -On MacOS, Metal is enabled by default. Using Metal makes the computation run on the GPU. -To disable the Metal build at compile time use the `GGML_NO_METAL=1` flag or the `GGML_METAL=OFF` cmake option. - -When built with Metal support, you can explicitly disable GPU inference with the `--n-gpu-layers|-ngl 0` command-line -argument. - ## BLAS Build -Building the program with BLAS support may lead to some performance improvements in prompt processing using batch sizes higher than 32 (the default is 512). Support with CPU-only BLAS implementations doesn't affect the normal generation performance. We may see generation performance improvements with GPU-involved BLAS implementations, e.g. cuBLAS, hipBLAS. There are currently several different BLAS implementations available for build and use: +Building the program with BLAS support may lead to some performance improvements in prompt processing using batch sizes higher than 32 (the default is 512). Using BLAS doesn't affect the generation performance. There are currently several different BLAS implementations available for build and use: -### Accelerate Framework: +### Accelerate Framework This is only available on Mac PCs and it's enabled by default. You can just build using the normal instructions. -### OpenBLAS: +### OpenBLAS This provides BLAS acceleration using only the CPU. Make sure to have OpenBLAS installed on your machine. -- Using `make`: - - On Linux: - ```bash - make GGML_OPENBLAS=1 - ``` - - - On Windows: - - 1. Download the latest fortran version of [w64devkit](https://github.com/skeeto/w64devkit/releases). - 2. Download the latest version of [OpenBLAS for Windows](https://github.com/xianyi/OpenBLAS/releases). - 3. Extract `w64devkit` on your pc. - 4. From the OpenBLAS zip that you just downloaded copy `libopenblas.a`, located inside the `lib` folder, inside `w64devkit\x86_64-w64-mingw32\lib`. - 5. From the same OpenBLAS zip copy the content of the `include` folder inside `w64devkit\x86_64-w64-mingw32\include`. - 6. Run `w64devkit.exe`. - 7. Use the `cd` command to reach the `llama.cpp` folder. - 8. From here you can run: - - ```bash - make GGML_OPENBLAS=1 - ``` - - Using `CMake` on Linux: ```bash @@ -136,14 +87,6 @@ This provides BLAS acceleration using only the CPU. Make sure to have OpenBLAS i Check [BLIS.md](./backend/BLIS.md) for more information. -### SYCL - -SYCL is a higher-level programming model to improve programming productivity on various hardware accelerators. - -llama.cpp based on SYCL is used to **support Intel GPU** (Data Center Max series, Flex series, Arc series, Built-in GPU and iGPU). - -For detailed info, please refer to [llama.cpp for SYCL](./backend/SYCL.md). - ### Intel oneMKL Building through oneAPI compilers will make avx_vnni instruction set available for intel processors that do not support avx512 and avx512_vnni. Please note that this build config **does not support Intel GPU**. For Intel GPU support, please refer to [llama.cpp for SYCL](./backend/SYCL.md). @@ -161,16 +104,31 @@ Building through oneAPI compilers will make avx_vnni instruction set available f Check [Optimizing and Running LLaMA2 on Intel® CPU](https://www.intel.com/content/www/us/en/content-details/791610/optimizing-and-running-llama2-on-intel-cpu.html) for more information. -### CUDA +### Other BLAS libraries -This provides GPU acceleration using the CUDA cores of your Nvidia GPU. Make sure to have the CUDA toolkit installed. You can download it from your Linux distro's package manager (e.g. `apt install nvidia-cuda-toolkit`) or from here: [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads). +Any other BLAS library can be used by setting the `GGML_BLAS_VENDOR` option. See the [CMake documentation](https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors) for a list of supported vendors. -For Jetson user, if you have Jetson Orin, you can try this: [Offical Support](https://www.jetson-ai-lab.com/tutorial_text-generation.html). If you are using an old model(nano/TX2), need some additional operations before compiling. +## Metal Build + +On MacOS, Metal is enabled by default. Using Metal makes the computation run on the GPU. +To disable the Metal build at compile time use the `-DGGML_METAL=OFF` cmake option. + +When built with Metal support, you can explicitly disable GPU inference with the `--n-gpu-layers 0` command-line argument. + +## SYCL + +SYCL is a higher-level programming model to improve programming productivity on various hardware accelerators. + +llama.cpp based on SYCL is used to **support Intel GPU** (Data Center Max series, Flex series, Arc series, Built-in GPU and iGPU). + +For detailed info, please refer to [llama.cpp for SYCL](./backend/SYCL.md). + +## CUDA + +This provides GPU acceleration using an NVIDIA GPU. Make sure to have the CUDA toolkit installed. You can download it from your Linux distro's package manager (e.g. `apt install nvidia-cuda-toolkit`) or from the [NVIDIA developer site](https://developer.nvidia.com/cuda-downloads). + +If you are using Fedora (using Fedora Workstation, or an 'Atomic' variant such as Silverblue), or would like to set up CUDA in a toolbox, please consider our [Fedora CUDA guide](./cuda-fedora.md). Unfortunately, the process is not as simple as one might expect. -- Using `make`: - ```bash - make GGML_CUDA=1 - ``` - Using `CMake`: ```bash @@ -186,22 +144,16 @@ The following compilation options are also available to tweak performance: | Option | Legal values | Default | Description | |-------------------------------|------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| GGML_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. | -| GGML_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. | -| GGML_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. | | GGML_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. | | GGML_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models | | GGML_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. | -| GGML_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | | GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. | | GGML_CUDA_FA_ALL_QUANTS | Boolean | false | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. | -### MUSA +## MUSA + +This provides GPU acceleration using the MUSA cores of your Moore Threads MTT GPU. Make sure to have the MUSA SDK installed. You can download it from here: [MUSA SDK](https://developer.mthreads.com/sdk/download/musa). -- Using `make`: - ```bash - make GGML_MUSA=1 - ``` - Using `CMake`: ```bash @@ -209,20 +161,22 @@ The following compilation options are also available to tweak performance: cmake --build build --config Release ``` -### hipBLAS +The environment variable [`MUSA_VISIBLE_DEVICES`](https://docs.mthreads.com/musa-sdk/musa-sdk-doc-online/programming_guide/Z%E9%99%84%E5%BD%95/) can be used to specify which GPU(s) will be used. -This provides BLAS acceleration on HIP-supported AMD GPUs. +The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enable unified memory in Linux. This allows swapping to system RAM instead of crashing when the GPU VRAM is exhausted. + +Most of the compilation options available for CUDA should also be available for MUSA, though they haven't been thoroughly tested yet. + +## HIP + +This provides GPU acceleration on HIP-supported AMD GPUs. Make sure to have ROCm installed. You can download it from your Linux distro's package manager or from here: [ROCm Quick Start (Linux)](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/tutorial/quick-start.html#rocm-install-quick). -- Using `make`: - ```bash - make GGML_HIPBLAS=1 - ``` - Using `CMake` for Linux (assuming a gfx1030-compatible AMD GPU): ```bash HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \ - cmake -S . -B build -DGGML_HIPBLAS=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \ + cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \ && cmake --build build --config Release -- -j 16 ``` On Linux it is also possible to use unified memory architecture (UMA) to share main memory between the CPU and integrated GPU by setting `-DGGML_HIP_UMA=ON`. @@ -239,19 +193,14 @@ You can download it from your Linux distro's package manager or from here: [ROCm ```bash HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -p)" \ HIP_DEVICE_LIB_PATH= \ - cmake -S . -B build -DGGML_HIPBLAS=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \ + cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \ && cmake --build build -- -j 16 ``` -- Using `make` (example for target gfx1030, build with 16 CPU threads): - ```bash - make -j16 GGML_HIPBLAS=1 GGML_HIP_UMA=1 AMDGPU_TARGETS=gfx1030 - ``` - - Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS, and assuming a gfx1100-compatible AMD GPU): ```bash set PATH=%HIP_PATH%\bin;%PATH% - cmake -S . -B build -G Ninja -DAMDGPU_TARGETS=gfx1100 -DGGML_HIPBLAS=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release + cmake -S . -B build -G Ninja -DAMDGPU_TARGETS=gfx1100 -DGGML_HIP=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release cmake --build build ``` Make sure that `AMDGPU_TARGETS` is set to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors) @@ -260,23 +209,16 @@ You can download it from your Linux distro's package manager or from here: [ROCm The environment variable [`HIP_VISIBLE_DEVICES`](https://rocm.docs.amd.com/en/latest/understand/gpu_isolation.html#hip-visible-devices) can be used to specify which GPU(s) will be used. If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 (e.g. gfx1030, gfx1031, or gfx1035) or 11.0.0 on RDNA3. -The following compilation options are also available to tweak performance (yes, they refer to CUDA, not HIP, because it uses the same code as the cuBLAS version above): -| Option | Legal values | Default | Description | -|------------------------|------------------------|---------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| GGML_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the HIP dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. | -| GGML_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the HIP mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. | -| GGML_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per HIP thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | - -### Vulkan +## Vulkan **Windows** -#### w64devkit +### w64devkit -Download and extract [w64devkit](https://github.com/skeeto/w64devkit/releases). +Download and extract [`w64devkit`](https://github.com/skeeto/w64devkit/releases). -Download and install the [Vulkan SDK](https://vulkan.lunarg.com/sdk/home#windows). When selecting components, only the Vulkan SDK Core is required. +Download and install the [`Vulkan SDK`](https://vulkan.lunarg.com/sdk/home#windows) with the default settings. Launch `w64devkit.exe` and run the following commands to copy Vulkan dependencies: ```sh @@ -292,18 +234,47 @@ Libs: -lvulkan-1 EOF ``` -Switch into the `llama.cpp` directory and run `make GGML_VULKAN=1`. -#### MSYS2 +Switch into the `llama.cpp` directory and build using CMake. +```sh +cmake -B build -DGGML_VULKAN=ON +cmake --build build --config Release +``` + +### Git Bash MINGW64 + +Download and install [`Git-SCM`](https://git-scm.com/downloads/win) with the default settings + +Download and install [`Visual Studio Community Edition`](https://visualstudio.microsoft.com/) and make sure you select `C++` + +Download and install [`CMake`](https://cmake.org/download/) with the default settings + +Download and install the [`Vulkan SDK`](https://vulkan.lunarg.com/sdk/home#windows) with the default settings. + +Go into your `llama.cpp` directory and right click, select `Open Git Bash Here` and then run the following commands + +``` +cmake -B build -DGGML_VULKAN=ON +cmake --build build --config Release +``` + +Now you can load the model in conversation mode using `Vulkan` + +```sh +build/bin/Release/llama-cli -m "[PATH TO MODEL]" -ngl 100 -c 16384 -t 10 -n -2 -cnv +``` + +### MSYS2 Install [MSYS2](https://www.msys2.org/) and then run the following commands in a UCRT terminal to install dependencies. - ```sh - pacman -S git \ - mingw-w64-ucrt-x86_64-gcc \ - mingw-w64-ucrt-x86_64-cmake \ - mingw-w64-ucrt-x86_64-vulkan-devel \ - mingw-w64-ucrt-x86_64-shaderc - ``` -Switch into `llama.cpp` directory and build using CMake. +```sh +pacman -S git \ + mingw-w64-ucrt-x86_64-gcc \ + mingw-w64-ucrt-x86_64-cmake \ + mingw-w64-ucrt-x86_64-vulkan-devel \ + mingw-w64-ucrt-x86_64-shaderc +``` + +Switch into the `llama.cpp` directory and build using CMake. ```sh cmake -B build -DGGML_VULKAN=ON cmake --build build --config Release @@ -315,7 +286,7 @@ You don't need to install Vulkan SDK. It will be installed inside the container. ```sh # Build the image -docker build -t llama-cpp-vulkan -f .devops/llama-cli-vulkan.Dockerfile . +docker build -t llama-cpp-vulkan --target light -f .devops/vulkan.Dockerfile . # Then, use it: docker run -it --rm -v "$(pwd):/app:Z" --device /dev/dri/renderD128:/dev/dri/renderD128 --device /dev/dri/card1:/dev/dri/card1 llama-cpp-vulkan -m "/app/models/YOUR_MODEL_FILE" -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 @@ -352,7 +323,7 @@ cmake --build build --config Release # ggml_vulkan: Using Intel(R) Graphics (ADL GT2) | uma: 1 | fp16: 1 | warp size: 32 ``` -### CANN +## CANN This provides NPU acceleration using the AI cores of your Ascend NPU. And [CANN](https://www.hiascend.com/en/software/cann) is a hierarchical APIs to help you to quickly build AI applications and service based on Ascend NPU. For more information about Ascend NPU in [Ascend Community](https://www.hiascend.com/en/). @@ -367,16 +338,26 @@ cmake --build build --config release You can test with: -`./build/llama-cli -m PATH_TO_MODEL -p "Building a website can be done in 10 steps:" -ngl 32` - -If the fllowing info is output on screen, you are using `llama.cpp by CANN backend`: ```bash -llm_load_tensors: CANN buffer size = 13313.00 MiB +./build/bin/llama-cli -m PATH_TO_MODEL -p "Building a website can be done in 10 steps:" -ngl 32 +``` + +If the following info is output on screen, you are using `llama.cpp` with the CANN backend: +```bash +llm_load_tensors: CANN model buffer size = 13313.00 MiB llama_new_context_with_model: CANN compute buffer size = 1260.81 MiB ``` For detailed info, such as model/device supports, CANN install, please refer to [llama.cpp for CANN](./backend/CANN.md). -### Android +## Android To read documentation for how to build on Android, [click here](./android.md) + +## Notes about GPU-accelerated backends + +The GPU may still be used to accelerate some parts of the computation even when using the `-ngl 0` option. You can fully disable GPU acceleration by using `--device none`. + +In most cases, it is possible to build and use multiple backends at the same time. For example, you can build llama.cpp with both CUDA and Vulkan support by using the `-DGGML_CUDA=ON -DGGML_VULKAN=ON` options with CMake. At runtime, you can specify which backend devices to use with the `--device` option. To see a list of available devices, use the `--list-devices` option. + +Backends can be built as dynamic libraries that can be loaded dynamically at runtime. This allows you to use the same llama.cpp binary on different machines with different GPUs. To enable this feature, use the `GGML_BACKEND_DL` option when building. diff --git a/docs/cuda-fedora.md b/docs/cuda-fedora.md new file mode 100644 index 000000000..b993386c8 --- /dev/null +++ b/docs/cuda-fedora.md @@ -0,0 +1,317 @@ +# Setting Up CUDA on Fedora + +In this guide we setup [Nvidia CUDA](https://docs.nvidia.com/cuda/) in a toolbox container. This guide is applicable for: +- [Fedora Workstation](https://fedoraproject.org/workstation/) +- [Atomic Desktops for Fedora](https://fedoraproject.org/atomic-desktops/) +- [Fedora Spins](https://fedoraproject.org/spins) +- [Other Distributions](https://containertoolbx.org/distros/), including `Red Hat Enterprise Linux >= 8.`, `Arch Linux`, and `Ubuntu`. + + +## Table of Contents + +- [Prerequisites](#prerequisites) +- [Monitoring NVIDIA CUDA Repositories](#monitoring-nvidia-cuda-repositories) +- [Using the Fedora 39 CUDA Repository](#using-the-fedora-39-cuda-repository) +- [Creating a Fedora Toolbox Environment](#creating-a-fedora-toolbox-environment) +- [Installing Essential Development Tools](#installing-essential-development-tools) +- [Adding the CUDA Repository](#adding-the-cuda-repository) +- [Installing `nvidia-driver-libs`](#installing-nvidia-driver-libs) +- [Manually Resolving Package Conflicts](#manually-resolving-package-conflicts) +- [Finalizing the Installation of `nvidia-driver-libs`](#finalizing-the-installation-of-nvidia-driver-libs) +- [Installing the CUDA Meta-Package](#installing-the-cuda-meta-package) +- [Configuring the Environment](#configuring-the-environment) +- [Verifying the Installation](#verifying-the-installation) +- [Conclusion](#conclusion) +- [Troubleshooting](#troubleshooting) +- [Additional Notes](#additional-notes) +- [References](#references) + +## Prerequisites + +- **Toolbox Installed on the Host System** `Fedora Silverblue` and `Fedora Workstation` both have toolbox by default, other distributions may need to install the [toolbox package](https://containertoolbx.org/install/). +- **NVIDIA Drivers and Graphics Card installed on Host System (optional)** To run CUDA program, such as `llama.cpp`, the host should be setup to access your NVIDIA hardware. Fedora Hosts can use the [RPM Fusion Repository](https://rpmfusion.org/Howto/NVIDIA). +- **Internet connectivity** to download packages. + +### Monitoring NVIDIA CUDA Repositories + +Before proceeding, it is advisable to check if NVIDIA has updated their CUDA repositories for your Fedora version. NVIDIA's repositories can be found at: + +- [Fedora 40 CUDA Repository](https://developer.download.nvidia.com/compute/cuda/repos/fedora40/x86_64/) +- [Fedora 41 CUDA Repository](https://developer.download.nvidia.com/compute/cuda/repos/fedora41/x86_64/) + +As of the latest update, these repositories do not contain the `cuda` meta-package or are missing essential components. + +### Using the Fedora 39 CUDA Repository + +Since the newer repositories are incomplete, we'll use the Fedora 39 repository: + +- [Fedora 39 CUDA Repository](https://developer.download.nvidia.com/compute/cuda/repos/fedora39/x86_64/) + +**Note:** Fedora 39 is no longer maintained, so we recommend using a toolbox environment to prevent system conflicts. + +## Creating a Fedora Toolbox Environment + +This guide focuses on Fedora hosts, but with small adjustments, it can work for other hosts. Using a Fedora 39 toolbox allows us to install the necessary packages without affecting the host system. + +**Note:** Toolbox is available for other systems, and even without Toolbox, it is possible to use Podman or Docker. + +We do not recommend installing on the host system, as Fedora 39 is out-of-maintenance, and instead you should upgrade to a maintained version of Fedora for your host. + +1. **Create a Fedora 39 Toolbox:** + + ```bash + toolbox create --image registry.fedoraproject.org/fedora-toolbox:39 --container fedora-toolbox-39-cuda + ``` + +2. **Enter the Toolbox:** + + ```bash + toolbox enter --container fedora-toolbox-39-cuda + ``` + + Inside the toolbox, you have root privileges and can install packages without affecting the host system. + +## Installing Essential Development Tools + +1. **Synchronize the DNF Package Manager:** + + ```bash + sudo dnf distro-sync + ``` + +2. **Install the Default Text Editor (Optional):** + + ```bash + sudo dnf install vim-default-editor --allowerasing + ``` + + The `--allowerasing` flag resolves any package conflicts. + +3. **Install Development Tools and Libraries:** + + ```bash + sudo dnf install @c-development @development-tools cmake + ``` + + This installs essential packages for compiling software, including `gcc`, `make`, and other development headers. + +## Adding the CUDA Repository + +Add the NVIDIA CUDA repository to your DNF configuration: + +```bash +sudo dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/fedora39/x86_64/cuda-fedora39.repo +``` + +After adding the repository, synchronize the package manager again: + +```bash +sudo dnf distro-sync +``` + +## Installing `nvidia-driver-libs` + +Attempt to install `nvidia-driver-libs`: + +```bash +sudo dnf install nvidia-driver-libs +``` + +**Explanation:** + +- `nvidia-driver-libs` contains necessary NVIDIA driver libraries required by CUDA. +- This step might fail due to conflicts with existing NVIDIA drivers on the host system. + +## Manually Resolving Package Conflicts + +If the installation fails due to conflicts, we'll manually download and install the required packages, excluding conflicting files. + +### 1. Download the `nvidia-driver-libs` RPM + +```bash +sudo dnf download --arch x86_64 nvidia-driver-libs +``` + +You should see a file similar to: + +``` +nvidia-driver-libs-560.35.05-1.fc39.x86_64.rpm +``` + +### 2. Attempt to Install the RPM + +```bash +sudo dnf install nvidia-driver-libs-560.35.05-1.fc39.x86_64.rpm +``` + +**Expected Error:** + +Installation may fail with errors pointing to conflicts with `egl-gbm` and `egl-wayland`. + +**Note: It is important to carefully read the error messages to identify the exact paths that need to be excluded.** + +### 3. Download Dependencies + +```bash +sudo dnf download --arch x86_64 egl-gbm egl-wayland +``` + +### 4. Install `egl-gbm` with Excluded Paths + +Exclude conflicting files during installation: + +```bash +sudo rpm --install --verbose --hash \ + --excludepath=/usr/lib64/libnvidia-egl-gbm.so.1.1.2 \ + --excludepath=/usr/share/egl/egl_external_platform.d/15_nvidia_gbm.json \ + egl-gbm-1.1.2^20240919gitb24587d-3.fc39.x86_64.rpm +``` + +**Explanation:** + +- The `--excludepath` option skips installing files that conflict with existing files. +- Adjust the paths based on the error messages you receive. + +### 5. Install `egl-wayland` with Excluded Paths + +```bash +sudo rpm --install --verbose --hash \ + --excludepath=/usr/share/egl/egl_external_platform.d/10_nvidia_wayland.json \ + egl-wayland-1.1.17^20241118giteeb29e1-5.fc39.x86_64.rpm +``` + +### 6. Install `nvidia-driver-libs` with Excluded Paths + +```bash +sudo rpm --install --verbose --hash \ + --excludepath=/usr/share/glvnd/egl_vendor.d/10_nvidia.json \ + --excludepath=/usr/share/nvidia/nvoptix.bin \ + nvidia-driver-libs-560.35.05-1.fc39.x86_64.rpm +``` + +**Note:** + +- Replace the paths with the ones causing conflicts in your installation if they differ. +- The `--verbose` and `--hash` options provide detailed output during installation. + +## Finalizing the Installation of `nvidia-driver-libs` + +After manually installing the dependencies, run: + +```bash +sudo dnf install nvidia-driver-libs +``` + +You should receive a message indicating the package is already installed: + +``` +Package nvidia-driver-libs-3:560.35.05-1.fc39.x86_64 is already installed. +Dependencies resolved. +Nothing to do. +Complete! +``` + +## Installing the CUDA Meta-Package + +Now that the driver libraries are installed, proceed to install CUDA: + +```bash +sudo dnf install cuda +``` + +This installs the CUDA toolkit and associated packages. + +## Configuring the Environment + +To use CUDA, add its binary directory to your system's `PATH`. + +1. **Create a Profile Script:** + + ```bash + sudo sh -c 'echo "export PATH=\$PATH:/usr/local/cuda/bin" >> /etc/profile.d/cuda.sh' + ``` + + **Explanation:** + + - We add to `/etc/profile.d/` as the `/etc/` folder is unique to this particular container, and is not shared with other containers or the host system. + - The backslash `\` before `$PATH` ensures the variable is correctly written into the script. + +2. **Make the Script Executable:** + + ```bash + sudo chmod +x /etc/profile.d/cuda.sh + ``` + +3. **Source the Script to Update Your Environment:** + + ```bash + source /etc/profile.d/cuda.sh + ``` + + **Note:** This command updates your current shell session with the new `PATH`. The `/etc/profile.d/cuda.sh` script ensures that the CUDA binaries are available in your `PATH` for all future sessions. + +## Verifying the Installation + +To confirm that CUDA is correctly installed and configured, check the version of the NVIDIA CUDA Compiler (`nvcc`): + +```bash +nvcc --version +``` + +You should see output similar to: + +``` +nvcc: NVIDIA (R) Cuda compiler driver +Copyright (c) 2005-2024 NVIDIA Corporation +Built on Tue_Oct_29_23:50:19_PDT_2024 +Cuda compilation tools, release 12.6, V12.6.85 +Build cuda_12.6.r12.6/compiler.35059454_0 +``` + +This output confirms that the CUDA compiler is accessible and indicates the installed version. + +## Conclusion + +You have successfully set up CUDA on Fedora within a toolbox environment using the Fedora 39 CUDA repository. By manually resolving package conflicts and configuring the environment, you can develop CUDA applications without affecting your host system. + +## Troubleshooting + +- **Installation Failures:** + - If you encounter errors during installation, carefully read the error messages. They often indicate conflicting files or missing dependencies. + - Use the `--excludepath` option with `rpm` to exclude conflicting files during manual installations. + +- **Driver Conflicts:** + - Since the host system may already have NVIDIA drivers installed, conflicts can arise. Using the toolbox environment helps isolate these issues. + +- **Environment Variables Not Set:** + - If `nvcc` is not found after installation, ensure that `/usr/local/cuda/bin` is in your `PATH`. + - Run `echo $PATH` to check if the path is included. + - Re-source the profile script or open a new terminal session. + +## Additional Notes + +- **Updating CUDA in the Future:** + - Keep an eye on the official NVIDIA repositories for updates to your Fedora version. + - When an updated repository becomes available, adjust your `dnf` configuration accordingly. + +- **Building `llama.cpp`:** + - With CUDA installed, you can follow these [build instructions for `llama.cpp`](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) to compile it with CUDA support. + - Ensure that any CUDA-specific build flags or paths are correctly set in your build configuration. + +- **Using the Toolbox Environment:** + - The toolbox environment is isolated from your host system, which helps prevent conflicts. + - Remember that system files and configurations inside the toolbox are separate from the host. By default the home directory of the user is shared between the host and the toolbox. + +--- + +**Disclaimer:** Manually installing and modifying system packages can lead to instability of the container. The above steps are provided as a guideline and may need adjustments based on your specific system configuration. Always back up important data before making significant system changes, especially as your home folder is writable and shared with he toolbox. + +**Acknowledgments:** Special thanks to the Fedora community and NVIDIA documentation for providing resources that assisted in creating this guide. + +## References + +- [Fedora Toolbox Documentation](https://docs.fedoraproject.org/en-US/fedora-silverblue/toolbox/) +- [NVIDIA CUDA Installation Guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) +- [Podman Documentation](https://podman.io/get-started) + +--- diff --git a/docs/development/HOWTO-add-model.md b/docs/development/HOWTO-add-model.md index 04c5ccbbe..8fcd70811 100644 --- a/docs/development/HOWTO-add-model.md +++ b/docs/development/HOWTO-add-model.md @@ -28,7 +28,7 @@ The required steps to implement for an HF model are: ```python @Model.register("MyModelForCausalLM") class MyModel(Model): - model_arch = gguf.MODEL_ARCH.GROK + model_arch = gguf.MODEL_ARCH.MYMODEL ``` 2. Define the layout of the GGUF tensors in [constants.py](/gguf-py/gguf/constants.py) @@ -79,14 +79,14 @@ Depending on the model configuration, tokenizer, code and tensors layout, you wi - `Model#set_vocab` - `Model#write_tensors` -NOTE: Tensor names must end with `.weight` suffix, that is the convention and several tools like `quantize` expect this to proceed the weights. +NOTE: Tensor names must end with `.weight` or `.bias` suffixes, that is the convention and several tools like `quantize` expect this to proceed the weights. ### 2. Define the model architecture in `llama.cpp` The model params and tensors layout must be defined in `llama.cpp`: 1. Define a new `llm_arch` 2. Define the tensors layout in `LLM_TENSOR_NAMES` -3. Add any non standard metadata in `llm_load_hparams` +3. Add any non-standard metadata in `llm_load_hparams` 4. Create the tensors for inference in `llm_load_tensors` 5. If the model has a RoPE operation, add the rope type in `llama_rope_type` @@ -96,9 +96,9 @@ NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorc This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`. -Have a look at existing implementation like `build_llama`, `build_dbrx` or `build_bert`. +Have a look at existing implementations like `build_llama`, `build_dbrx` or `build_bert`. -When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support for missing backend operations can be added in another PR. +Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR. Note: to debug the inference graph: you can use [llama-eval-callback](/examples/eval-callback/). diff --git a/docs/docker.md b/docs/docker.md index e8a084173..dac9a9ec1 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -19,8 +19,11 @@ Additionally, there the following images, similar to the above: - `ghcr.io/ggerganov/llama.cpp:full-rocm`: Same as `full` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`) - `ghcr.io/ggerganov/llama.cpp:light-rocm`: Same as `light` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`) - `ghcr.io/ggerganov/llama.cpp:server-rocm`: Same as `server` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`) +- `ghcr.io/ggerganov/llama.cpp:full-musa`: Same as `full` but compiled with MUSA support. (platforms: `linux/amd64`) +- `ghcr.io/ggerganov/llama.cpp:light-musa`: Same as `light` but compiled with MUSA support. (platforms: `linux/amd64`) +- `ghcr.io/ggerganov/llama.cpp:server-musa`: Same as `server` but compiled with MUSA support. (platforms: `linux/amd64`) -The GPU enabled images are not currently tested by CI beyond being built. They are not built with any variation from the ones in the Dockerfiles defined in [.devops/](../.devops/) and the GitHub Action defined in [.github/workflows/docker.yml](../.github/workflows/docker.yml). If you need different settings (for example, a different CUDA or ROCm library, you'll need to build the images locally for now). +The GPU enabled images are not currently tested by CI beyond being built. They are not built with any variation from the ones in the Dockerfiles defined in [.devops/](../.devops/) and the GitHub Action defined in [.github/workflows/docker.yml](../.github/workflows/docker.yml). If you need different settings (for example, a different CUDA, ROCm or MUSA library, you'll need to build the images locally for now). ## Usage @@ -57,9 +60,9 @@ Assuming one has the [nvidia-container-toolkit](https://github.com/NVIDIA/nvidia ## Building Docker locally ```bash -docker build -t local/llama.cpp:full-cuda -f .devops/full-cuda.Dockerfile . -docker build -t local/llama.cpp:light-cuda -f .devops/llama-cli-cuda.Dockerfile . -docker build -t local/llama.cpp:server-cuda -f .devops/llama-server-cuda.Dockerfile . +docker build -t local/llama.cpp:full-cuda --target full -f .devops/cuda.Dockerfile . +docker build -t local/llama.cpp:light-cuda --target light -f .devops/cuda.Dockerfile . +docker build -t local/llama.cpp:server-cuda --target server -f .devops/cuda.Dockerfile . ``` You may want to pass in some different `ARGS`, depending on the CUDA environment supported by your container host, as well as the GPU architecture. @@ -84,3 +87,37 @@ docker run --gpus all -v /path/to/models:/models local/llama.cpp:full-cuda --run docker run --gpus all -v /path/to/models:/models local/llama.cpp:light-cuda -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1 docker run --gpus all -v /path/to/models:/models local/llama.cpp:server-cuda -m /models/7B/ggml-model-q4_0.gguf --port 8000 --host 0.0.0.0 -n 512 --n-gpu-layers 1 ``` + +## Docker With MUSA + +Assuming one has the [mt-container-toolkit](https://developer.mthreads.com/musa/native) properly installed on Linux, `muBLAS` should be accessible inside the container. + +## Building Docker locally + +```bash +docker build -t local/llama.cpp:full-musa --target full -f .devops/musa.Dockerfile . +docker build -t local/llama.cpp:light-musa --target light -f .devops/musa.Dockerfile . +docker build -t local/llama.cpp:server-musa --target server -f .devops/musa.Dockerfile . +``` + +You may want to pass in some different `ARGS`, depending on the MUSA environment supported by your container host, as well as the GPU architecture. + +The defaults are: + +- `MUSA_VERSION` set to `rc3.1.0` + +The resulting images, are essentially the same as the non-MUSA images: + +1. `local/llama.cpp:full-musa`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. +2. `local/llama.cpp:light-musa`: This image only includes the main executable file. +3. `local/llama.cpp:server-musa`: This image only includes the server executable file. + +## Usage + +After building locally, Usage is similar to the non-MUSA examples, but you'll need to set `mthreads` as default Docker runtime. This can be done by executing `(cd /usr/bin/musa && sudo ./docker setup $PWD)` and verifying the changes by executing `docker info | grep mthreads` on the host machine. You will also want to use the `--n-gpu-layers` flag. + +```bash +docker run -v /path/to/models:/models local/llama.cpp:full-musa --run -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1 +docker run -v /path/to/models:/models local/llama.cpp:light-musa -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1 +docker run -v /path/to/models:/models local/llama.cpp:server-musa -m /models/7B/ggml-model-q4_0.gguf --port 8000 --host 0.0.0.0 -n 512 --n-gpu-layers 1 +``` diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 67b3d2774..66cfab2c3 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -6,22 +6,26 @@ find_package(Threads REQUIRED) # ... +# flags + +llama_add_compile_flags() + # examples include_directories(${CMAKE_CURRENT_SOURCE_DIR}) if (EMSCRIPTEN) else() - add_subdirectory(cvector-generator) - add_subdirectory(baby-llama) add_subdirectory(batched-bench) add_subdirectory(batched) - add_subdirectory(benchmark) - add_subdirectory(convert-llama2c-to-ggml) add_subdirectory(embedding) add_subdirectory(eval-callback) - add_subdirectory(export-lora) - add_subdirectory(gbnf-validator) + + if (NOT WIN32) + # disabled on Windows because it uses internal functions not exported with LLAMA_API + add_subdirectory(gbnf-validator) + endif() + add_subdirectory(gguf-hash) add_subdirectory(gguf-split) add_subdirectory(gguf) @@ -29,27 +33,41 @@ else() add_subdirectory(imatrix) add_subdirectory(infill) add_subdirectory(llama-bench) - add_subdirectory(llava) add_subdirectory(lookahead) add_subdirectory(lookup) add_subdirectory(main) add_subdirectory(parallel) add_subdirectory(passkey) add_subdirectory(perplexity) - add_subdirectory(quantize-stats) add_subdirectory(quantize) add_subdirectory(retrieval) - if (GGML_RPC) - add_subdirectory(rpc) - endif() if (LLAMA_BUILD_SERVER) - add_subdirectory(server) - endif() - if (GGML_SYCL) - add_subdirectory(sycl) + add_subdirectory(server) endif() add_subdirectory(save-load-state) + add_subdirectory(run) add_subdirectory(simple) + add_subdirectory(simple-chat) add_subdirectory(speculative) + add_subdirectory(speculative-simple) add_subdirectory(tokenize) + add_subdirectory(tts) + add_subdirectory(gen-docs) + if (NOT GGML_BACKEND_DL) + # these examples use the backends directly and cannot be built with dynamic loading + add_subdirectory(convert-llama2c-to-ggml) + add_subdirectory(cvector-generator) + add_subdirectory(export-lora) + if (NOT WIN32) + # disabled on Windows because it uses internal functions not exported with LLAMA_API + add_subdirectory(quantize-stats) + endif() + add_subdirectory(llava) + if (GGML_RPC) + add_subdirectory(rpc) + endif() + if (GGML_SYCL) + add_subdirectory(sycl) + endif() + endif() endif() diff --git a/examples/baby-llama/CMakeLists.txt b/examples/baby-llama/CMakeLists.txt deleted file mode 100644 index 71b82105c..000000000 --- a/examples/baby-llama/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -set(TARGET llama-baby-llama) -add_executable(${TARGET} baby-llama.cpp) -install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp deleted file mode 100644 index 3ce91070b..000000000 --- a/examples/baby-llama/baby-llama.cpp +++ /dev/null @@ -1,1639 +0,0 @@ -#include "ggml.h" -#include "train.h" - -#include -#include -#include -#include -#include - -#if defined(_MSC_VER) -#pragma warning(disable: 4244 4267) // possible loss of data -#endif - -#ifdef LLAMA_DEFAULT_RMS_EPS -constexpr float rms_norm_eps = LLAMA_DEFAULT_RMS_EPS; -#else -constexpr float rms_norm_eps = 5e-6f; -#endif - -static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * graph, int n_threads) { - struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr); - - if (plan.work_size > 0) { - buf.resize(plan.work_size); - plan.work_data = buf.data(); - } - - ggml_graph_compute(graph, &plan); -} - -static struct ggml_tensor * randomize_tensor( - struct ggml_tensor * tensor, int ndims, const int64_t ne[], float fmin, float fmax -) { - switch (ndims) { - case 1: - for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)tensor->data)[i0] = frand()*(fmax - fmin) + fmin; - } - break; - case 2: - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)tensor->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; - } - } - break; - case 3: - for (int i2 = 0; i2 < ne[2]; i2++) { - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)tensor->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; - } - } - } - break; - case 4: - for (int i3 = 0; i3 < ne[3]; i3++) { - for (int i2 = 0; i2 < ne[2]; i2++) { - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)tensor->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; - } - } - } - } - break; - default: - assert(false); - } - - return tensor; -} - -struct llama_hparams { - uint32_t n_vocab = 32000; - uint32_t n_ctx = 512; // this is provided as user input? - uint32_t n_embd = 4096; - uint32_t n_mult = 4; - uint32_t n_head = 32; - uint32_t n_layer = 32; - uint32_t n_rot = 64; - - bool operator!=(const llama_hparams & other) const { - return memcmp(this, &other, sizeof(llama_hparams)); - } -}; - -static uint32_t get_n_ff(const struct llama_hparams* hparams) { - const uint32_t n_ff = ((2*(4*hparams->n_embd)/3 + hparams->n_mult - 1)/hparams->n_mult)*hparams->n_mult; - return n_ff; -} - -struct llama_hparams_lora { - uint32_t n_vocab = 32000; - uint32_t n_ctx = 512; // this is provided as user input? - uint32_t n_embd = 4096; - uint32_t n_mult = 4; - uint32_t n_head = 32; - uint32_t n_layer = 32; - uint32_t n_rot = 64; - uint32_t n_lora = 64; - - bool operator!=(const llama_hparams_lora & other) const { - return memcmp(this, &other, sizeof(llama_hparams_lora)) != 0; - } -}; - -struct llama_layer { - // normalization - struct ggml_tensor * attention_norm; - - // attention - struct ggml_tensor * wq; - struct ggml_tensor * wk; - struct ggml_tensor * wv; - struct ggml_tensor * wo; - - // normalization - struct ggml_tensor * ffn_norm; - - // ff - struct ggml_tensor * w1; - struct ggml_tensor * w2; - struct ggml_tensor * w3; -}; - -struct llama_layer_lora { - // normalization - struct ggml_tensor * attention_norm; - - // attention - struct ggml_tensor * wqa; - struct ggml_tensor * wqb; - struct ggml_tensor * wka; - struct ggml_tensor * wkb; - struct ggml_tensor * wva; - struct ggml_tensor * wvb; - struct ggml_tensor * woa; - struct ggml_tensor * wob; - - // normalization - struct ggml_tensor * ffn_norm; - - // ff - struct ggml_tensor * w1; - struct ggml_tensor * w2; - struct ggml_tensor * w3; -}; - - -struct llama_kv_cache { - struct ggml_context * ctx = NULL; - - struct ggml_tensor * k; - struct ggml_tensor * v; - - // llama_ctx_buffer buf; - - int n; // number of tokens currently in the cache -}; - -struct llama_model { - struct ggml_context * ctx = NULL; - - llama_hparams hparams; - - struct ggml_tensor * tok_embeddings; - - struct ggml_tensor * norm; - struct ggml_tensor * output; - - std::vector layers; -}; - -struct llama_model_lora { - struct ggml_context * ctx = NULL; - - llama_hparams_lora hparams; - - struct ggml_tensor * tok_embeddings; - - struct ggml_tensor * norm; - struct ggml_tensor * outputa; - struct ggml_tensor * outputb; - - std::vector layers; -}; - -static void init_model(struct llama_model * model) { - const auto & hparams = model->hparams; - - const uint32_t n_embd = hparams.n_embd; - const uint32_t n_layer = hparams.n_layer; - const uint32_t n_vocab = hparams.n_vocab; - - const uint32_t n_ff = get_n_ff(&hparams); - - struct ggml_context * ctx = model->ctx; - - model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); // ("tok_embeddings.weight", {n_embd, n_vocab}); - model->norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); // ("norm.weight", {n_embd}); - model->output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); // ("output.weight", {n_embd, n_vocab}); - - model->layers.resize(n_layer); - for (uint32_t i = 0; i < n_layer; ++i) { - auto & layer = model->layers[i]; - - // std::string layers_i = "layers." + std::to_string(i); - - layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); // (layers_i + ".attention_norm.weight", {n_embd}); - - layer.wq = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd); // (layers_i + ".attention.wq.weight", {n_embd, n_embd}); - layer.wk = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd); // (layers_i + ".attention.wk.weight", {n_embd, n_embd}); - layer.wv = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd); // (layers_i + ".attention.wv.weight", {n_embd, n_embd}); - layer.wo = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd); // (layers_i + ".attention.wo.weight", {n_embd, n_embd}); - - layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); // (layers_i + ".ffn_norm.weight", {n_embd}); - - layer.w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff); // (layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}); - layer.w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd); // (layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}); - layer.w3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff); // (layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}); - } -} - - -static void init_model_lora(struct llama_model_lora * model) { - const auto & hparams = model->hparams; - - const uint32_t n_embd = hparams.n_embd; - const uint32_t n_mult = hparams.n_mult; - const uint32_t n_layer = hparams.n_layer; - const uint32_t n_vocab = hparams.n_vocab; - const uint32_t n_lora = hparams.n_lora; - - const uint32_t n_ff = ((2*(4*n_embd)/3 + n_mult - 1)/n_mult)*n_mult; - - struct ggml_context * ctx = model->ctx; - - model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); // ("tok_embeddings.weight", {n_embd, n_vocab}); - model->norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); // ("norm.weight", {n_embd}); - model->outputa = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_vocab); // ("output.weight", {n_embd, n_vocab}); - model->outputb = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_lora); // ("output.weight", {n_embd, n_vocab}); - - model->layers.resize(n_layer); - for (uint32_t i = 0; i < n_layer; ++i) { - auto & layer = model->layers[i]; - - // std::string layers_i = "layers." + std::to_string(i); - - layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); // (layers_i + ".attention_norm.weight", {n_embd}); - - layer.wqa = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_embd); // (layers_i + ".attention.wq.weight", {n_embd, n_embd}); - layer.wqb = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_lora); // (layers_i + ".attention.wq.weight", {n_embd, n_embd}); - layer.wka = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_embd); // (layers_i + ".attention.wk.weight", {n_embd, n_embd}); - layer.wkb = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_lora); // (layers_i + ".attention.wk.weight", {n_embd, n_embd}); - layer.wva = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_embd); // (layers_i + ".attention.wv.weight", {n_embd, n_embd}); - layer.wvb = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_lora); // (layers_i + ".attention.wv.weight", {n_embd, n_embd}); - layer.woa = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_lora, n_embd); // (layers_i + ".attention.wo.weight", {n_embd, n_embd}); - layer.wob = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_lora); // (layers_i + ".attention.wo.weight", {n_embd, n_embd}); - - layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); // (layers_i + ".ffn_norm.weight", {n_embd}); - - layer.w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff); // (layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}); - layer.w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd); // (layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}); - layer.w3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff); // (layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}); - } -} - -static void set_param_model(struct llama_model * model) { - const auto& hparams = model->hparams; - - const uint32_t n_layer = hparams.n_layer; - - struct ggml_context* ctx = model->ctx; - - ggml_set_param(ctx, model->tok_embeddings); - ggml_set_param(ctx, model->norm); - ggml_set_param(ctx, model->output); - - for (uint32_t i = 0; i < n_layer; ++i) { - auto & layer = model->layers[i]; - - ggml_set_param(ctx, layer.attention_norm); - ggml_set_param(ctx, layer.wq); - ggml_set_param(ctx, layer.wk); - ggml_set_param(ctx, layer.wv); - ggml_set_param(ctx, layer.wo); - ggml_set_param(ctx, layer.ffn_norm); - ggml_set_param(ctx, layer.w1); - ggml_set_param(ctx, layer.w2); - ggml_set_param(ctx, layer.w3); - } -} - -static void set_param_model_lora(struct llama_model_lora * model) { - const auto& hparams = model->hparams; - - const uint32_t n_layer = hparams.n_layer; - - struct ggml_context* ctx = model->ctx; - - ggml_set_param(ctx, model->tok_embeddings); - ggml_set_param(ctx, model->norm); - ggml_set_param(ctx, model->outputa); - ggml_set_param(ctx, model->outputb); - - for (uint32_t i = 0; i < n_layer; ++i) { - auto & layer = model->layers[i]; - - ggml_set_param(ctx, layer.attention_norm); - ggml_set_param(ctx, layer.wqa); - ggml_set_param(ctx, layer.wqb); - ggml_set_param(ctx, layer.wka); - ggml_set_param(ctx, layer.wkb); - ggml_set_param(ctx, layer.wva); - ggml_set_param(ctx, layer.wvb); - ggml_set_param(ctx, layer.woa); - ggml_set_param(ctx, layer.wob); - ggml_set_param(ctx, layer.ffn_norm); - ggml_set_param(ctx, layer.w1); - ggml_set_param(ctx, layer.w2); - ggml_set_param(ctx, layer.w3); - } -} - -static void randomize_model(struct llama_model * model, int seed, float mean, float std, float min, float max) { - const auto & hparams = model->hparams; - - const uint32_t n_layer = hparams.n_layer; - - struct random_normal_distribution * rnd = init_random_normal_distribution(seed, mean, std, min, max); - - randomize_tensor_normal(model->tok_embeddings , rnd); - randomize_tensor_normal(model->norm , rnd); - randomize_tensor_normal(model->output , rnd); - - for (uint32_t i = 0; i < n_layer; ++i) { - auto & layer = model->layers[i]; - randomize_tensor_normal(layer.attention_norm, rnd); - - randomize_tensor_normal(layer.wq, rnd); - randomize_tensor_normal(layer.wk, rnd); - randomize_tensor_normal(layer.wv, rnd); - randomize_tensor_normal(layer.wo, rnd); - - randomize_tensor_normal(layer.ffn_norm, rnd); - - randomize_tensor_normal(layer.w1, rnd); - randomize_tensor_normal(layer.w2, rnd); - randomize_tensor_normal(layer.w3, rnd); - } - - free_random_normal_distribution(rnd); -} - - -static void randomize_model_lora( - struct llama_model_lora * model, int seed, float mean, float std, float min, float max -) { - const auto & hparams = model->hparams; - - const uint32_t n_layer = hparams.n_layer; - - struct random_normal_distribution * rnd = init_random_normal_distribution(seed, mean, std, min, max); - - randomize_tensor_normal(model->tok_embeddings, rnd); - randomize_tensor_normal(model->norm , rnd); - randomize_tensor_normal(model->outputa , rnd); - randomize_tensor_normal(model->outputb , rnd); - - for (uint32_t i = 0; i < n_layer; ++i) { - auto & layer = model->layers[i]; - randomize_tensor_normal(layer.attention_norm, rnd); - - randomize_tensor_normal(layer.wqa, rnd); - randomize_tensor_normal(layer.wqb, rnd); - randomize_tensor_normal(layer.wka, rnd); - randomize_tensor_normal(layer.wkb, rnd); - randomize_tensor_normal(layer.wva, rnd); - randomize_tensor_normal(layer.wvb, rnd); - randomize_tensor_normal(layer.woa, rnd); - randomize_tensor_normal(layer.wob, rnd); - - randomize_tensor_normal(layer.ffn_norm, rnd); - - randomize_tensor_normal(layer.w1, rnd); - randomize_tensor_normal(layer.w2, rnd); - randomize_tensor_normal(layer.w3, rnd); - } - - free_random_normal_distribution(rnd); -} - -static void init_kv_cache(struct llama_kv_cache* cache, struct llama_model * model, int n_batch) { - const auto & hparams = model->hparams; - - const uint32_t n_ctx = hparams.n_ctx; - const uint32_t n_embd = hparams.n_embd; - const uint32_t n_layer = hparams.n_layer; - - const int64_t n_mem = n_layer*n_ctx*n_batch; - const int64_t n_elements = n_embd*n_mem; - - // cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB); - - // struct ggml_init_params params; - // params.mem_size = cache.buf.size; - // params.mem_buffer = cache.buf.addr; - // params.no_alloc = false; - if (!cache->ctx) { - struct ggml_init_params params; - params.mem_size = 2u*n_elements*ggml_type_size(GGML_TYPE_F32) + 2u*1024*1024; - params.mem_buffer = NULL; - params.no_alloc = false; - - cache->ctx = ggml_init(params); - - if (!cache->ctx) { - fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); - exit(1); - } - } - - cache->k = ggml_new_tensor_1d(cache->ctx, GGML_TYPE_F32, n_elements); - cache->v = ggml_new_tensor_1d(cache->ctx, GGML_TYPE_F32, n_elements); -} - -static bool init_kv_cache_lora(struct llama_kv_cache* cache, struct llama_model_lora * model, int n_batch) { - const auto & hparams = model->hparams; - - const uint32_t n_ctx = hparams.n_ctx; - const uint32_t n_embd = hparams.n_embd; - const uint32_t n_layer = hparams.n_layer; - - const int64_t n_mem = n_layer*n_ctx*n_batch; - const int64_t n_elements = n_embd*n_mem; - - // cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB); - - // struct ggml_init_params params; - // params.mem_size = cache.buf.size; - // params.mem_buffer = cache.buf.addr; - // params.no_alloc = false; - if (!cache->ctx) { - struct ggml_init_params params; - params.mem_size = 2u*n_elements*ggml_type_size(GGML_TYPE_F32) + 2u*1024*1024; - params.mem_buffer = NULL; - params.no_alloc = false; - - cache->ctx = ggml_init(params); - - if (!cache->ctx) { - fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); - return false; - } - } - - cache->k = ggml_new_tensor_1d(cache->ctx, GGML_TYPE_F32, n_elements); - cache->v = ggml_new_tensor_1d(cache->ctx, GGML_TYPE_F32, n_elements); - - return true; -} - -static struct ggml_tensor * forward( - struct llama_model * model, - struct llama_kv_cache * cache, - struct ggml_context * ctx0, - struct ggml_cgraph * gf, - struct ggml_tensor * tokens_input, - const int n_tokens, - const int n_past -) { - const int N = n_tokens; - - struct llama_kv_cache& kv_self = *cache; - const auto & hparams = model->hparams; - const int n_ctx = hparams.n_ctx; - const int n_embd = hparams.n_embd; - const int n_layer = hparams.n_layer; - const int n_head = hparams.n_head; - const int n_rot = hparams.n_rot; - - struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - memcpy(tokens->data, tokens_input->data, N*ggml_element_size(tokens)); - - struct ggml_tensor * kc = kv_self.k; - struct ggml_tensor * vc = kv_self.v; - - struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - { - int * data = (int *) KQ_pos->data; - for (int i = 0; i < N; ++i) { - data[i] = n_past + i; - } - } - - // inpL shape [n_embd,N,1,1] - struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * inpSA = inpL; - - struct ggml_tensor * cur; - - // lctx.use_buf(ctx0, 0); - - // norm - { - // cur shape [n_embd,N,1,1] - cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); - - // cur = attention_norm*cur - cur = ggml_mul(ctx0, - ggml_repeat(ctx0, model->layers[il].attention_norm, cur), - cur); - } - - // self-attention - { - // compute Q and K and RoPE them - // wq shape [n_embd, n_embd, 1, 1] - // wk shape [n_embd, n_embd, 1, 1] - // Qcur shape [n_embd/n_head, n_head, N, 1] - // Kcur shape [n_embd/n_head, n_head, N, 1] - struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0); - struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0); - - // store key and value to memory - { - // compute the transposed [N, n_embd] V matrix - // wv shape [n_embd, n_embd, 1, 1] - // Vcur shape [n_embd, N, 1, 1] - struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wv, cur), n_embd, N))); - - // kv_self.k shape [n_embd * n_ctx * n_layer, 1] - // kv_self.v shape [n_embd * n_ctx * n_layer, 1] - // k shape [n_embd * N, 1] == kv_self.k[:,n_past:n_past+N,il,0] - // v shape [N, n_embd, 1, 1] == kv_self.v[:,n_past:n_past+N,il,0] - - /* { - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v)); - - // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); - } //*/ - - kc = ggml_set_1d(ctx0, kc, ggml_reshape_1d(ctx0, Kcur, n_embd*N), (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); - vc = ggml_set_2d(ctx0, vc, Vcur, ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v)); - } - - // Qcur shape [n_embd/n_head, n_head, N, 1] - // Q shape [n_embd/n_head, N, n_head, 1] - struct ggml_tensor * Q = - ggml_permute(ctx0, - Qcur, - 0, 2, 1, 3); - - // kv_self.k shape [n_embd * n_ctx * n_layer, 1] - // K shape [n_embd/n_head, n_past + N, n_head, 1] - struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, kc, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kc)*n_embd), - n_embd/n_head, n_head, n_past + N), - 0, 2, 1, 3); - - // K * Q - // KQ shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - - // KQ_scaled = KQ / sqrt(n_embd/n_head) - // KQ_scaled shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head)); - - // KQ_masked = mask_past(KQ_scaled) - // KQ_masked shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); - - // KQ = soft_max(KQ_masked) - // KQ_soft_max shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - - // split cached V into n_head heads - //// V shape [n_past + N, n_embd/n_head, n_head, 1] - // V shape [n_past + N, n_embd/n_head, n_head, 1] == kv_self.v[:,:(n_past+N),il,1] - struct ggml_tensor * V = - ggml_view_3d(ctx0, vc, - n_past + N, n_embd/n_head, n_head, - n_ctx*ggml_element_size(vc), - n_ctx*ggml_element_size(vc)*n_embd/n_head, - il*n_ctx*ggml_element_size(vc)*n_embd); - - // KQV shape [n_embd/n_head, N, n_head, 1] - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - - // KQV_merged = KQV.permute(0, 2, 1, 3) - // KQV_merged shape [n_embd/n_head, n_head, N, 1] - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - // KQV_merged shape - - // cur = KQV_merged.contiguous().view(n_embd, N) - // cur shape [n_embd,N,1,1] - cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N); - // cur = ggml_cpy(ctx0, - // KQV_merged, - // ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); - - // projection (no bias) - // cur shape [n_embd,N,1,1] - cur = ggml_mul_mat(ctx0, - model->layers[il].wo, - cur); - } - - // lctx.use_buf(ctx0, 1); - - // inpFF shape [n_embd,N,1,1] - struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); - - // feed-forward network - { - // norm - { - // cur shape [n_embd,N,1,1] - cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps); - - // cur = ffn_norm*cur - // cur shape [n_embd,N,1,1] - cur = ggml_mul(ctx0, - ggml_repeat(ctx0, model->layers[il].ffn_norm, cur), - cur); - } - - // tmp shape [n_ff,N,1,1] - struct ggml_tensor * tmp = ggml_mul_mat(ctx0, - model->layers[il].w3, - cur); - - // cur shape [n_ff,N,1,1] - cur = ggml_mul_mat(ctx0, - model->layers[il].w1, - cur); - - // SILU activation - // cur shape [n_ff,N,1,1] - cur = ggml_silu(ctx0, cur); - - // cur shape [n_ff,N,1,1] - cur = ggml_mul(ctx0, cur, tmp); - - // cur shape [n_embd,N,1,1] - cur = ggml_mul_mat(ctx0, - model->layers[il].w2, - cur); - } - - // cur shape [n_embd,N,1,1] - cur = ggml_add(ctx0, cur, inpFF); - - // input for next layer - // inpL shape [n_embd,N,1,1] - inpL = cur; - } - - // norm - { - - // inpL shape [n_embd,N,1,1] - inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps); - - // inpL = norm*inpL - // inpL shape [n_embd,N,1,1] - inpL = ggml_mul(ctx0, - ggml_repeat(ctx0, model->norm, inpL), - inpL); - - //embeddings = inpL; - } - - // lm_head - // inpL shape [n_vocab,N,1,1] - inpL = ggml_mul_mat(ctx0, model->output, inpL); - - // run the computation - ggml_build_forward_expand(gf, inpL); - - return inpL; -} - -static struct ggml_tensor * forward_batch( - struct llama_model * model, - struct llama_kv_cache * cache, - struct ggml_context * ctx0, - struct ggml_cgraph * gf, - struct ggml_tensor * tokens_input, - const int n_tokens, - const int n_past, - const int n_batch -) { - const int N = n_tokens; - - struct llama_kv_cache& kv_self = *cache; - const auto & hparams = model->hparams; - const int n_ctx = hparams.n_ctx; - const int n_vocab = hparams.n_vocab; - const int n_embd = hparams.n_embd; - const int n_layer = hparams.n_layer; - const int n_head = hparams.n_head; - const int n_rot = hparams.n_rot; - const int n_ff = get_n_ff(&hparams); - - struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); - memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch); - - struct ggml_tensor * kc = kv_self.k; - struct ggml_tensor * vc = kv_self.v; - - struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - { - int * data = (int *) KQ_pos->data; - for (int i = 0; i < N; ++i) { - data[i] = n_past + i; - } - } - - // inpL shape [n_embd,N*n_batch,1] - struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); - assert_shape_2d(inpL, n_embd, N*n_batch); - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * inpSA = inpL; - - struct ggml_tensor * cur; - - // lctx.use_buf(ctx0, 0); - - // norm - { - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); - assert_shape_2d(cur, n_embd, N*n_batch); - - // cur = attention_norm*cur - cur = ggml_mul(ctx0, - ggml_repeat(ctx0, model->layers[il].attention_norm, cur), - cur); - assert_shape_2d(cur, n_embd, N*n_batch); - } - - // self-attention - { - // compute Q and K and RoPE them - // wq shape [n_embd, n_embd, 1, 1] - // wk shape [n_embd, n_embd, 1, 1] - // Qcur shape [n_embd/n_head, n_head, N, n_batch] - // Kcur shape [n_embd/n_head, n_head, N, n_batch] - struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0); - struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0); - assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch); - assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch); - - // store key and value to memory - { - // compute the transposed [N, n_embd] V matrix - // wv shape [n_embd, n_embd, 1, 1] - // Vcur shape [N, n_embd, n_batch, 1] - struct ggml_tensor * Vcur = ggml_cont(ctx0, - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - ggml_mul_mat(ctx0, - model->layers[il].wv, - cur), - n_embd, N, n_batch), - 1, 0, 2, 3)); - - assert_shape_3d(Vcur, N, n_embd, n_batch); - - // kv_self.k shape [n_embd * n_ctx * n_batch * n_layer] - // kv_self.v shape [n_ctx * n_embd * n_batch * n_layer] - // k shape [n_embd * N, n_batch] == kv_self.k[:,n_past:n_past+N,:,il] - // v shape [N, n_embd, n_batch, 1] == kv_self.v[:,n_past:n_past+N,:,il] - - /* { - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v)); - - // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); - } //*/ - - kc = ggml_set_2d(ctx0, kc, - ggml_reshape_2d(ctx0, Kcur, n_embd*N, n_batch), - ggml_element_size(kc)*n_embd*n_ctx, - (ggml_element_size(kc)*n_embd)*(il*n_batch*n_ctx + n_past)); - vc = ggml_set_2d(ctx0, vc, - ggml_reshape_2d(ctx0, Vcur, N*n_embd, n_batch), - ggml_element_size(vc)*n_ctx*n_embd, - ggml_element_size(vc)*(n_past + il*n_embd*n_batch*n_ctx)); - - assert_shape_1d(kc, n_embd * n_ctx * n_batch * n_layer); - assert_shape_1d(vc, n_embd * n_ctx * n_batch * n_layer); - } - - // Qcur shape [n_embd/n_head, n_head, N, n_batch] - // Q shape [n_embd/n_head, N, n_head, n_batch] - struct ggml_tensor * Q = - ggml_permute(ctx0, - Qcur, - 0, 2, 1, 3); - assert_shape_4d(Q, n_embd/n_head, N, n_head, n_batch); - - // kv_self.k shape [n_embd * n_ctx * n_batch * n_layer] - // K shape [n_embd/n_head, n_past + N, n_head, n_batch] - struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_reshape_4d(ctx0, - ggml_view_3d(ctx0, - kc, - n_embd, - (n_past + N), - n_batch, - n_embd*ggml_element_size(kc), - n_ctx*n_embd*ggml_element_size(kc), - il*n_batch*n_ctx*n_embd*ggml_element_size(kc)), - n_embd/n_head, n_head, n_past + N, n_batch), - 0, 2, 1, 3); - assert_shape_4d(K, n_embd/n_head, n_past + N, n_head, n_batch); - - // K * Q - // KQ shape [n_past + N, N, n_head, n_batch] - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - assert_shape_4d(KQ, n_past + N, N, n_head, n_batch); - - // KQ_scaled = KQ / sqrt(n_embd/n_head) - // KQ_scaled shape [n_past + N, N, n_head, n_batch] - struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head)); - assert_shape_4d(KQ_scaled, n_past + N, N, n_head, n_batch); - - // KQ_masked = mask_past(KQ_scaled) - // KQ_masked shape [n_past + N, N, n_head, n_batch] - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); - assert_shape_4d(KQ_masked, n_past + N, N, n_head, n_batch); - - // KQ = soft_max(KQ_masked) - // KQ_soft_max shape [n_past + N, N, n_head, n_batch] - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - assert_shape_4d(KQ_soft_max, n_past + N, N, n_head, n_batch); - - // split cached V into n_head heads - // kv_self.v shape [n_ctx * n_embd * n_batch * n_layer] - // V shape [n_past + N, n_embd/n_head, n_head, n_batch] == kv_self.v[:(n_past+N),:,:,il] - struct ggml_tensor * V = - ggml_view_4d(ctx0, vc, - n_past + N, n_embd/n_head, n_head, n_batch, - ggml_element_size(vc)*n_ctx, - ggml_element_size(vc)*n_ctx*n_embd/n_head, - ggml_element_size(vc)*n_ctx*n_embd, - il*n_batch*n_ctx*n_embd*ggml_element_size(vc)); - assert_shape_4d(V, n_past + N, n_embd/n_head, n_head, n_batch); - - // KQV shape [n_embd/n_head, N, n_head, n_batch] - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch); - - // KQV_merged = KQV.permute(0, 2, 1, 3) - // KQV_merged shape [n_embd/n_head, n_head, N, n_batch] - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - assert_shape_4d(KQV_merged, n_embd/n_head, n_head, N, n_batch); - // KQV_merged shape - - // cur = KQV_merged.contiguous().view(n_embd, N) - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N*n_batch); - assert_shape_2d(cur, n_embd, N*n_batch); - // cur = ggml_cpy(ctx0, - // KQV_merged, - // ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); - - // projection (no bias) - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_mul_mat(ctx0, - model->layers[il].wo, - cur); - assert_shape_2d(cur, n_embd, N*n_batch); - } - - // lctx.use_buf(ctx0, 1); - - // inpFF shape [n_embd,N*n_batch,1,1] - struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); - assert_shape_2d(inpFF, n_embd, N*n_batch); - - // feed-forward network - { - // norm - { - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps); - assert_shape_2d(cur, n_embd, N*n_batch); - - // cur = ffn_norm*cur - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_mul(ctx0, - ggml_repeat(ctx0, model->layers[il].ffn_norm, cur), - cur); - assert_shape_2d(cur, n_embd, N*n_batch); - } - - // tmp shape [n_ff,N*n_batch,1,1] - struct ggml_tensor * tmp = ggml_mul_mat(ctx0, - model->layers[il].w3, - cur); - assert_shape_2d(tmp, n_ff, N*n_batch); - - // cur shape [n_ff,N*n_batch,1,1] - cur = ggml_mul_mat(ctx0, - model->layers[il].w1, - cur); - assert_shape_2d(cur, n_ff, N*n_batch); - - // SILU activation - // cur shape [n_ff,N*n_batch,1,1] - cur = ggml_silu(ctx0, cur); - assert_shape_2d(cur, n_ff, N*n_batch); - - // cur shape [n_ff,N*n_batch,1,1] - cur = ggml_mul(ctx0, cur, tmp); - assert_shape_2d(cur, n_ff, N*n_batch); - - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_mul_mat(ctx0, - model->layers[il].w2, - cur); - assert_shape_2d(cur, n_embd, N*n_batch); - } - - // cur shape [n_embd,N*n_batch,1,1] - cur = ggml_add(ctx0, cur, inpFF); - assert_shape_2d(cur, n_embd, N*n_batch); - - // input for next layer - // inpL shape [n_embd,N*n_batch,1,1] - inpL = cur; - assert_shape_2d(inpL, n_embd, N*n_batch); - } - - // norm - { - - // inpL shape [n_embd,N*n_batch,1,1] - inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps); - assert_shape_2d(inpL, n_embd, N*n_batch); - - // inpL = norm*inpL - // inpL shape [n_embd,N*n_batch,1,1] - inpL = ggml_mul(ctx0, - ggml_repeat(ctx0, model->norm, inpL), - inpL); - - assert_shape_2d(inpL, n_embd, N*n_batch); - - //embeddings = inpL; - } - - // lm_head - // inpL shape [n_vocab,N*n_batch,1,1] - inpL = ggml_mul_mat(ctx0, model->output, inpL); - assert_shape_2d(inpL, n_vocab, N*n_batch); - - { - // inpL shape [n_vocab,N,n_batch,1] - inpL = ggml_reshape_3d(ctx0, - inpL, - n_vocab, N, n_batch); - assert_shape_3d(inpL, n_vocab, N, n_batch); - } - - // run the computation - ggml_build_forward_expand(gf, inpL); - - return inpL; -} - -static struct ggml_tensor * forward_lora( - struct llama_model_lora * model, - struct llama_kv_cache * cache, - struct ggml_context * ctx0, - struct ggml_cgraph * gf, - struct ggml_tensor * tokens_input, - const int n_tokens, - const int n_past -) { - const int N = n_tokens; - - struct llama_kv_cache& kv_self = *cache; - const auto & hparams = model->hparams; - - const int n_ctx = hparams.n_ctx; - const int n_embd = hparams.n_embd; - const int n_layer = hparams.n_layer; - const int n_head = hparams.n_head; - const int n_rot = hparams.n_rot; - - struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - memcpy(tokens->data, tokens_input->data, N*ggml_element_size(tokens)); - - struct ggml_tensor * kc = kv_self.k; - struct ggml_tensor * vc = kv_self.v; - - struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - { - int * data = (int *) KQ_pos->data; - for (int i = 0; i < N; ++i) { - data[i] = n_past + i; - } - } - - // inpL shape [n_embd,N,1,1] - struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * inpSA = inpL; - - struct ggml_tensor * cur; - - // norm - { - // cur shape [n_embd,N,1,1] - cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); - - // cur = attention_norm*cur - cur = ggml_mul(ctx0, - ggml_repeat(ctx0, model->layers[il].attention_norm, cur), - cur); - } - - // self-attention - { - // compute Q and K and RoPE them - // wq shape [n_embd, n_embd, 1, 1] - // wk shape [n_embd, n_embd, 1, 1] - // Qcur shape [n_embd/n_head, n_head, N, 1] - // Kcur shape [n_embd/n_head, n_head, N, 1] - struct ggml_tensor * Qcur = ggml_rope(ctx0, - ggml_reshape_3d(ctx0, - ggml_mul_mat(ctx0, - model->layers[il].wqa, - ggml_mul_mat(ctx0, - model->layers[il].wqb, - cur)), - n_embd/n_head, n_head, N), - KQ_pos, n_rot, 0); - struct ggml_tensor * Kcur = ggml_rope(ctx0, - ggml_reshape_3d(ctx0, - ggml_mul_mat(ctx0, - model->layers[il].wka, - ggml_mul_mat(ctx0, - model->layers[il].wkb, - cur)), - n_embd/n_head, n_head, N), - KQ_pos, n_rot, 0); - - // store key and value to memory - { - // compute the transposed [N, n_embd] V matrix - // wv shape [n_embd, n_embd, 1, 1] - // Vcur shape [n_embd, N, 1, 1] - struct ggml_tensor * Vcur = ggml_cont(ctx0, - ggml_transpose(ctx0, - ggml_reshape_2d(ctx0, - ggml_mul_mat(ctx0, - model->layers[il].wva, - ggml_mul_mat(ctx0, - model->layers[il].wvb, - cur)), - n_embd, N))); - - // kv_self.k shape [n_embd * n_ctx * n_layer, 1] - // kv_self.v shape [n_embd * n_ctx * n_layer, 1] - // k shape [n_embd * N, 1] == kv_self.k[:,n_past:n_past+N,il,0] - // v shape [N, n_embd, 1, 1] == kv_self.v[:,n_past:n_past+N,il,0] - - /* { - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v)); - - // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); - } //*/ - - kc = ggml_set_1d(ctx0, kc, ggml_reshape_1d(ctx0, Kcur, n_embd*N), (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); - vc = ggml_set_2d(ctx0, vc, Vcur, ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v)); - } - - // Qcur shape [n_embd/n_head, n_head, N, 1] - // Q shape [n_embd/n_head, N, n_head, 1] - struct ggml_tensor * Q = - ggml_permute(ctx0, - Qcur, - 0, 2, 1, 3); - - // kv_self.k shape [n_embd * n_ctx * n_layer, 1] - // K shape [n_embd/n_head, n_past + N, n_head, 1] - struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, kc, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kc)*n_embd), - n_embd/n_head, n_head, n_past + N), - 0, 2, 1, 3); - - // K * Q - // KQ shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - - // KQ_scaled = KQ / sqrt(n_embd/n_head) - // KQ_scaled shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head)); - - // KQ_masked = mask_past(KQ_scaled) - // KQ_masked shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); - - // KQ = soft_max(KQ_masked) - // KQ_soft_max shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); - - // split cached V into n_head heads - //// V shape [n_past + N, n_embd/n_head, n_head, 1] - // V shape [n_past + N, n_embd/n_head, n_head, 1] == kv_self.v[:,:(n_past+N),il,1] - struct ggml_tensor * V = - ggml_view_3d(ctx0, vc, - n_past + N, n_embd/n_head, n_head, - n_ctx*ggml_element_size(vc), - n_ctx*ggml_element_size(vc)*n_embd/n_head, - il*n_ctx*ggml_element_size(vc)*n_embd); - - // KQV shape [n_embd/n_head, N, n_head, 1] - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - - // KQV_merged = KQV.permute(0, 2, 1, 3) - // KQV_merged shape [n_embd/n_head, n_head, N, 1] - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - // KQV_merged shape - - // cur = KQV_merged.contiguous().view(n_embd, N) - // cur shape [n_embd,N,1,1] - cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N); - // cur = ggml_cpy(ctx0, - // KQV_merged, - // ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); - - // projection (no bias) - // cur shape [n_embd,N,1,1] - cur = ggml_mul_mat(ctx0, - model->layers[il].woa, - ggml_mul_mat(ctx0, - model->layers[il].wob, - cur)); - } - - // inpFF shape [n_embd,N,1,1] - struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); - - // feed-forward network - { - // norm - { - // cur shape [n_embd,N,1,1] - cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps); - - // cur = ffn_norm*cur - // cur shape [n_embd,N,1,1] - cur = ggml_mul(ctx0, - ggml_repeat(ctx0, model->layers[il].ffn_norm, cur), - cur); - } - - // tmp shape [n_ff,N,1,1] - struct ggml_tensor * tmp = ggml_mul_mat(ctx0, - model->layers[il].w3, - cur); - - // cur shape [n_ff,N,1,1] - cur = ggml_mul_mat(ctx0, - model->layers[il].w1, - cur); - - // SILU activation - // cur shape [n_ff,N,1,1] - cur = ggml_silu(ctx0, cur); - - // cur shape [n_ff,N,1,1] - cur = ggml_mul(ctx0, cur, tmp); - - // cur shape [n_embd,N,1,1] - cur = ggml_mul_mat(ctx0, - model->layers[il].w2, - cur); - } - - // cur shape [n_embd,N,1,1] - cur = ggml_add(ctx0, cur, inpFF); - - // input for next layer - // inpL shape [n_embd,N,1,1] - inpL = cur; - } - - // norm - { - - // inpL shape [n_embd,N,1,1] - inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps); - - // inpL = norm*inpL - // inpL shape [n_embd,N,1,1] - inpL = ggml_mul(ctx0, - ggml_repeat(ctx0, model->norm, inpL), - inpL); - - //embeddings = inpL; - } - - - // lm_head - // inpL shape [n_vocab,N,1,1] - inpL = ggml_mul_mat(ctx0, - model->outputa, - ggml_mul_mat(ctx0, - model->outputb, - inpL)); - - // ggml_set_scratch(ctx0, { 0, 0, nullptr, }); - // run the computation - ggml_build_forward_expand(gf, inpL); - - return inpL; -} - -static void sample_softmax(struct ggml_tensor * logits, struct ggml_tensor * probs, struct ggml_tensor * best_samples) { - assert(ggml_is_matrix(logits)); - assert(ggml_is_matrix(probs)); - assert(ggml_is_vector(best_samples)); - assert(logits->ne[1] == best_samples->ne[0]); - assert(logits->ne[0] == probs->ne[0]); - assert(logits->ne[1] == probs->ne[1]); - for (int i = 0; i < logits->ne[1]; ++i) { - float max_logit = ggml_get_f32_1d(logits, i * logits->ne[0]); - ggml_set_i32_1d(best_samples, i, 0); - for (int k = 0; k < logits->ne[0]; ++k) { - float logit = ggml_get_f32_1d(logits, i * logits->ne[0] + k); - if (logit > max_logit) { - max_logit = logit; - ggml_set_i32_1d(best_samples, i, k); - } - } - float psum = 0; - for (int k = 0; k < logits->ne[0]; ++k) { - float logit = ggml_get_f32_1d(logits, i * logits->ne[0] + k); - float p = (logit == -INFINITY) ? 0 : expf(logit - max_logit); - psum += p; - ggml_set_f32_1d(probs, i * probs->ne[0] + k, p); - } - for (int k = 0; k < logits->ne[0]; ++k) { - float p = ggml_get_f32_1d(probs, i*probs->ne[0] + k); - ggml_set_f32_1d(probs, i * probs->ne[0] + k, p / psum); - } - } -} - -static void sample_softmax_batch( - struct ggml_context * ctx, struct ggml_tensor * logits, struct ggml_tensor * probs, - struct ggml_tensor * best_samples -) { - GGML_ASSERT(ggml_is_matrix(best_samples)); - GGML_ASSERT(ggml_is_3d(logits)); - GGML_ASSERT(ggml_is_3d(probs)); - int n_tokens = best_samples->ne[0]; - int n_batch = best_samples->ne[1]; - int n_vocab = logits->ne[0]; - GGML_ASSERT(n_tokens == logits->ne[1]); - GGML_ASSERT(n_batch == logits->ne[2]); - GGML_ASSERT(n_vocab == probs->ne[0]); - GGML_ASSERT(n_tokens == probs->ne[1]); - GGML_ASSERT(n_batch == probs->ne[2]); - - for (int k = 0; k < n_batch; ++k) { - struct ggml_tensor * best_samples_k = ggml_view_1d(ctx, - best_samples, - best_samples->ne[0], - k*best_samples->nb[1]); - struct ggml_tensor * logits_k = ggml_view_2d(ctx, - logits, - logits->ne[0], - logits->ne[1], - logits->nb[1], - k*logits->nb[2]); - struct ggml_tensor * probs_k = ggml_view_2d(ctx, - probs, - probs->ne[0], - probs->ne[1], - probs->nb[1], - k*probs->nb[2]); - sample_softmax(logits_k, probs_k, best_samples_k); - } -} - -static void print_row(struct ggml_tensor * probs, int i) { - for (int k = 0; k < probs->ne[0]; ++k) { - float p = ggml_get_f32_1d(probs, i*probs->ne[0] + k); - printf(" %.2f", p); - } - printf("\n"); -} - -static void print_matrix(struct ggml_tensor * probs) { - assert(ggml_is_matrix(probs)); - for (int i = 0; i < probs->ne[1]; ++i) { - for (int k = 0; k < probs->ne[0]; ++k) { - float p = ggml_get_f32_1d(probs, i*probs->ne[0] + k); - printf(" %.2f", p); - } - printf("\n"); - } -} - -static void print_token(int token, int n_vocab) { - for (int k = 0; k < token; ++k) { - printf(" "); - } - printf("X"); - for (int k = token+1; k < n_vocab; ++k) { - printf(" "); - } - printf("\n"); -} - -static void print_tokens(struct ggml_tensor * tokens, int n_vocab) { - for (int i=0; ine[0]; ++i) { - int token = ggml_get_i32_1d(tokens, i); - print_token(token, n_vocab); - } -} - -static void get_example_targets(int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * targets) { - int n_tokens = tokens_input->ne[0]; - int n_vocab = targets->ne[0]; - float randomness = 0.0f; - // ggml_set_zero(targets); - ggml_set_f32(targets, -1.0f); - ggml_set_i32_1d(tokens_input, 0, 0); - for (int i=1; i 1.0f) ? 1.0f : z; // clamp to [0..1] - int token = std::max(1,std::min(1+(int)(z*(float)(n_vocab-1)), n_vocab-1)); - ggml_set_f32_1d(targets, (i-1)*n_vocab + token, +1.0f); - if (ine[0]; - int n_batch = tokens_input->ne[1]; - GGML_ASSERT(n_tokens == targets->ne[1]); - GGML_ASSERT(n_batch == targets->ne[2]); - - for (int k=0; kne[0], - k*tokens_input->nb[1]); - struct ggml_tensor * targets_k = ggml_view_2d(ctx, - targets, - targets->ne[0], - targets->ne[1], - targets->nb[1], - k*targets->nb[2]); - get_example_targets(example_id*n_batch + k, tokens_input_k, targets_k); - } -} - -static void lshift_examples(struct ggml_tensor * tokens_input, struct ggml_tensor * targets, int n_shift) { - int n_tokens = tokens_input->ne[0]; - int n_vocab = targets->ne[0]; - for (int i=0; i work_buffer; - - for (int ex=0; ex "" [extra-main-args] -# - -if [ $# -lt 2 ]; then - echo "Usage: ./base-translate.sh \"\" [extra-main-args]" - exit 1 -fi - -eargs="" -if [ $# -gt 2 ]; then - eargs="${@:3}" -fi - -ftmp="__llama.cpp_example_tmp__.txt" -trap "rm -f $ftmp" EXIT - -echo "Translate from English to French: - -=== - -sea otter, peppermint, plush girafe: - -sea otter => loutre de mer -peppermint => menthe poivrée -plush girafe => girafe peluche - -=== - -violin - -violin => violon - -=== - -phone, computer, mouse, keyboard: - -phone => téléphone -computer => ordinateur -mouse => souris -keyboard => clavier - -=== -" > $ftmp - -echo "$2 -" >> $ftmp - -model=$1 - -# generate the most likely continuation until the string "===" is found -./llama-cli -m $model -f $ftmp -n 64 --temp 0 --repeat-penalty 1.0 --no-penalize-nl -r "===" $eargs diff --git a/examples/batched-bench/CMakeLists.txt b/examples/batched-bench/CMakeLists.txt index 959acaeee..68ad707f3 100644 --- a/examples/batched-bench/CMakeLists.txt +++ b/examples/batched-bench/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-batched-bench) add_executable(${TARGET} batched-bench.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index f3b0c433b..0659ab6f1 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -1,47 +1,28 @@ +#include "arg.h" #include "common.h" +#include "log.h" #include "llama.h" #include -#include #include #include #include -// mutates the input string -static std::vector parse_list(char * p) { - std::vector ret; - - char * q = p; - - while (*p) { - if (*p == ',') { - *p = '\0'; - ret.push_back(std::atoi(q)); - q = p + 1; - } - - ++p; - } - - ret.push_back(std::atoi(q)); - - return ret; -} - static void print_usage(int, char ** argv) { - LOG_TEE("\nexample usage:\n"); - LOG_TEE("\n %s -m model.gguf -c 2048 -b 2048 -ub 512 -npp 128,256,512 -ntg 128,256 -npl 1,2,4,8,16,32 [-pps]\n", argv[0]); - LOG_TEE("\n"); + LOG("\nexample usage:\n"); + LOG("\n %s -m model.gguf -c 2048 -b 2048 -ub 512 -npp 128,256,512 -ntg 128,256 -npl 1,2,4,8,16,32 [-pps]\n", argv[0]); + LOG("\n"); } int main(int argc, char ** argv) { - gpt_params params; + common_params params; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_BENCH, print_usage); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_BENCH, print_usage)) { return 1; } + common_init(); + int is_pp_shared = params.is_pp_shared; std::vector n_pp = params.n_pp; @@ -55,21 +36,21 @@ int main(int argc, char ** argv) { // initialize the model - llama_model_params model_params = llama_model_params_from_gpt_params(params); + llama_model_params model_params = common_model_params_to_llama(params); - llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params.model.c_str(), model_params); if (model == NULL) { fprintf(stderr , "%s: error: unable to load model\n" , __func__); return 1; } - llama_context_params ctx_params = llama_context_params_from_gpt_params(params); + llama_context_params ctx_params = common_context_params_to_llama(params); // ensure enough sequences are available ctx_params.n_seq_max = n_pl.empty() ? 1 : *std::max_element(n_pl.begin(), n_pl.end()); - llama_context * ctx = llama_new_context_with_model(model, ctx_params); + llama_context * ctx = llama_init_from_model(model, ctx_params); if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); @@ -93,12 +74,11 @@ int main(int argc, char ** argv) { batch.n_seq_id + i, batch.seq_id + i, batch.logits + i, - 0, 0, 0, // unused }; const int ret = llama_decode(ctx, batch_view); if (ret != 0) { - LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); + LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); return false; } @@ -111,21 +91,21 @@ int main(int argc, char ** argv) { // warm up { for (int i = 0; i < 16; ++i) { - llama_batch_add(batch, 0, i, { 0 }, false); + common_batch_add(batch, 0, i, { 0 }, false); } if (!decode_helper(ctx, batch, ctx_params.n_batch)) { - LOG_TEE("%s: llama_decode() failed\n", __func__); + LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } } if (!params.batched_bench_output_jsonl) { - LOG_TEE("\n"); - LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch); - LOG_TEE("\n"); - LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s"); - LOG_TEE("|%6s-|-%6s-|-%4s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "----", "------", "--------", "--------", "--------", "--------", "--------", "--------"); + LOG("\n"); + LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch); + LOG("\n"); + LOG("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s"); + LOG("|%6s-|-%6s-|-%4s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "----", "------", "--------", "--------", "--------", "--------", "--------", "--------"); } for ( int i_pp = 0; i_pp < (int) n_pp.size(); ++i_pp) { @@ -141,11 +121,11 @@ int main(int argc, char ** argv) { continue; } - llama_batch_clear(batch); + common_batch_clear(batch); for (int i = 0; i < pp; ++i) { for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) { - llama_batch_add(batch, 0, i, { j }, false); + common_batch_add(batch, 0, i, { j }, false); } } batch.logits[batch.n_tokens - 1] = true; @@ -155,7 +135,7 @@ int main(int argc, char ** argv) { llama_kv_cache_clear(ctx); if (!decode_helper(ctx, batch, ctx_params.n_batch)) { - LOG_TEE("%s: llama_decode() failed\n", __func__); + LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -170,14 +150,14 @@ int main(int argc, char ** argv) { const auto t_tg_start = ggml_time_us(); for (int i = 0; i < tg; ++i) { - llama_batch_clear(batch); + common_batch_clear(batch); for (int j = 0; j < pl; ++j) { - llama_batch_add(batch, 0, pp + i, { j }, true); + common_batch_add(batch, 0, pp + i, { j }, true); } if (!decode_helper(ctx, batch, ctx_params.n_batch)) { - LOG_TEE("%s: llama_decode() failed\n", __func__); + LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } } @@ -195,30 +175,30 @@ int main(int argc, char ** argv) { const float speed = n_kv / t; if(params.batched_bench_output_jsonl) { - LOG_TEE( + LOG( "{\"n_kv_max\": %d, \"n_batch\": %d, \"n_ubatch\": %d, \"flash_attn\": %d, \"is_pp_shared\": %d, \"n_gpu_layers\": %d, \"n_threads\": %u, \"n_threads_batch\": %u, " "\"pp\": %d, \"tg\": %d, \"pl\": %d, \"n_kv\": %d, \"t_pp\": %f, \"speed_pp\": %f, \"t_tg\": %f, \"speed_tg\": %f, \"t\": %f, \"speed\": %f}\n", n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch, pp, tg, pl, n_kv, t_pp, speed_pp, t_tg, speed_tg, t, speed ); } else { - LOG_TEE("|%6d | %6d | %4d | %6d | %8.3f | %8.2f | %8.3f | %8.2f | %8.3f | %8.2f |\n", pp, tg, pl, n_kv, t_pp, speed_pp, t_tg, speed_tg, t, speed); + LOG("|%6d | %6d | %4d | %6d | %8.3f | %8.2f | %8.3f | %8.2f | %8.3f | %8.2f |\n", pp, tg, pl, n_kv, t_pp, speed_pp, t_tg, speed_tg, t, speed); } } } } - LOG_TEE("\n"); - llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); + LOG("\n"); + llama_perf_context_print(ctx); llama_batch_free(batch); llama_free(ctx); - llama_free_model(model); + llama_model_free(model); llama_backend_free(); - fprintf(stderr, "\n\n"); + LOG("\n\n"); return 0; } diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 4bc2bbf2c..371917b2e 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -23,12 +23,12 @@ defer { } let model_params = llama_model_default_params() -guard let model = llama_load_model_from_file(modelPath.cString(using: .utf8), model_params) else { +guard let model = llama_model_load_from_file(modelPath.cString(using: .utf8), model_params) else { print("Failed to load model") exit(1) } defer { - llama_free_model(model) + llama_model_free(model) } var tokens = tokenize(text: prompt, add_bos: true) @@ -140,10 +140,8 @@ while n_cur <= n_len { let new_token_id = llama_sampler_sample(smpl, context, i_batch[i]) - llama_sampler_accept(smpl, new_token_id) - // is it an end of stream? -> mark the stream as finished - if llama_token_is_eog(model, new_token_id) || n_cur == n_len { + if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len { i_batch[i] = -1 // print("") if n_parallel > 1 { @@ -202,8 +200,8 @@ let t_main_end = ggml_time_us() print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n\n") -llama_perf_print(UnsafeRawPointer(context), LLAMA_PERF_TYPE_CONTEXT) -llama_perf_print(UnsafeRawPointer(smpl), LLAMA_PERF_TYPE_SAMPLER_CHAIN) +llama_perf_sampler_print(smpl) +llama_perf_context_print(context) private func tokenize(text: String, add_bos: Bool) -> [llama_token] { let utf8Count = text.utf8.count diff --git a/examples/batched/CMakeLists.txt b/examples/batched/CMakeLists.txt index 77e33343b..0d439f498 100644 --- a/examples/batched/CMakeLists.txt +++ b/examples/batched/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-batched) add_executable(${TARGET} batched.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index f5f309022..21b95ef5e 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -1,4 +1,6 @@ +#include "arg.h" #include "common.h" +#include "log.h" #include "llama.h" #include @@ -7,22 +9,22 @@ #include static void print_usage(int, char ** argv) { - LOG_TEE("\nexample usage:\n"); - LOG_TEE("\n %s -m model.gguf -p \"Hello my name is\" -n 32 -np 4\n", argv[0]); - LOG_TEE("\n"); + LOG("\nexample usage:\n"); + LOG("\n %s -m model.gguf -p \"Hello my name is\" -n 32 -np 4\n", argv[0]); + LOG("\n"); } int main(int argc, char ** argv) { - gpt_params params; + common_params params; params.prompt = "Hello my name is"; params.n_predict = 32; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON, print_usage); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) { return 1; } + common_init(); // number of parallel batches int n_parallel = params.n_parallel; @@ -37,66 +39,67 @@ int main(int argc, char ** argv) { // initialize the model - llama_model_params model_params = llama_model_params_from_gpt_params(params); + llama_model_params model_params = common_model_params_to_llama(params); - llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params.model.c_str(), model_params); if (model == NULL) { - fprintf(stderr , "%s: error: unable to load model\n" , __func__); + LOG_ERR("%s: error: unable to load model\n" , __func__); return 1; } + const llama_vocab * vocab = llama_model_get_vocab(model); + // tokenize the prompt std::vector tokens_list; - tokens_list = ::llama_tokenize(model, params.prompt, true); + tokens_list = common_tokenize(vocab, params.prompt, true); const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size())*n_parallel; // initialize the context - llama_context_params ctx_params = llama_context_params_from_gpt_params(params); + llama_context_params ctx_params = common_context_params_to_llama(params); ctx_params.n_ctx = n_kv_req; ctx_params.n_batch = std::max(n_predict, n_parallel); - llama_context * ctx = llama_new_context_with_model(model, ctx_params); + llama_context * ctx = llama_init_from_model(model, ctx_params); auto sparams = llama_sampler_chain_default_params(); + sparams.no_perf = false; llama_sampler * smpl = llama_sampler_chain_init(sparams); - llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sparams.top_k)); - llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sparams.top_p, params.sparams.min_keep)); - llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sparams.temp)); - llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sparams.seed)); + llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k)); + llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep)); + llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp)); + llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed)); if (ctx == NULL) { - fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); + LOG_ERR("%s: error: failed to create the llama_context\n" , __func__); return 1; } const int n_ctx = llama_n_ctx(ctx); - LOG_TEE("\n%s: n_predict = %d, n_ctx = %d, n_batch = %u, n_parallel = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req); + LOG_INF("\n%s: n_predict = %d, n_ctx = %d, n_batch = %u, n_parallel = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req); // make sure the KV cache is big enough to hold all the prompt and generated tokens if (n_kv_req > n_ctx) { - LOG_TEE("%s: error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", __func__, n_kv_req); - LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__); + LOG_ERR("%s: error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", __func__, n_kv_req); + LOG_ERR("%s: either reduce n_parallel or increase n_ctx\n", __func__); return 1; } // print the prompt token-by-token - fprintf(stderr, "\n"); + LOG("\n"); for (auto id : tokens_list) { - fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str()); + LOG("%s", common_token_to_piece(ctx, id).c_str()); } - fflush(stderr); - // create a llama_batch // we use this object to submit token data for decoding llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel); @@ -108,30 +111,30 @@ int main(int argc, char ** argv) { // evaluate the initial prompt for (size_t i = 0; i < tokens_list.size(); ++i) { - llama_batch_add(batch, tokens_list[i], i, seq_ids, false); + common_batch_add(batch, tokens_list[i], i, seq_ids, false); } GGML_ASSERT(batch.n_tokens == (int) tokens_list.size()); if (llama_model_has_encoder(model)) { if (llama_encode(ctx, batch)) { - LOG_TEE("%s : failed to eval\n", __func__); + LOG_ERR("%s : failed to eval\n", __func__); return 1; } llama_token decoder_start_token_id = llama_model_decoder_start_token(model); - if (decoder_start_token_id == -1) { - decoder_start_token_id = llama_token_bos(model); + if (decoder_start_token_id == LLAMA_TOKEN_NULL) { + decoder_start_token_id = llama_vocab_bos(vocab); } - llama_batch_clear(batch); - llama_batch_add(batch, decoder_start_token_id, 0, seq_ids, false); + common_batch_clear(batch); + common_batch_add(batch, decoder_start_token_id, 0, seq_ids, false); } // llama_decode will output logits only for the last token of the prompt batch.logits[batch.n_tokens - 1] = true; if (llama_decode(ctx, batch) != 0) { - LOG_TEE("%s: llama_decode() failed\n", __func__); + LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -142,7 +145,7 @@ int main(int argc, char ** argv) { //} if (n_parallel > 1) { - LOG_TEE("\n\n%s: generating %d sequences ...\n", __func__, n_parallel); + LOG("\n\n%s: generating %d sequences ...\n", __func__, n_parallel); } // main loop @@ -161,7 +164,7 @@ int main(int argc, char ** argv) { while (n_cur <= n_predict) { // prepare the next batch - llama_batch_clear(batch); + common_batch_clear(batch); // sample the next token for each parallel sequence / stream for (int32_t i = 0; i < n_parallel; ++i) { @@ -172,14 +175,12 @@ int main(int argc, char ** argv) { const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]); - llama_sampler_accept(smpl, new_token_id); - // is it an end of generation? -> mark the stream as finished - if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { + if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) { i_batch[i] = -1; - LOG_TEE("\n"); + LOG("\n"); if (n_parallel > 1) { - LOG_TEE("%s: stream %d finished at n_cur = %d", __func__, i, n_cur); + LOG_INF("%s: stream %d finished at n_cur = %d", __func__, i, n_cur); } continue; @@ -187,16 +188,15 @@ int main(int argc, char ** argv) { // if there is only one stream, we print immediately to stdout if (n_parallel == 1) { - LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str()); - fflush(stdout); + LOG("%s", common_token_to_piece(ctx, new_token_id).c_str()); } - streams[i] += llama_token_to_piece(ctx, new_token_id); + streams[i] += common_token_to_piece(ctx, new_token_id); i_batch[i] = batch.n_tokens; // push this new token for next evaluation - llama_batch_add(batch, new_token_id, n_cur, { i }, true); + common_batch_add(batch, new_token_id, n_cur, { i }, true); n_decode += 1; } @@ -210,29 +210,27 @@ int main(int argc, char ** argv) { // evaluate the current batch with the transformer model if (llama_decode(ctx, batch)) { - fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); + LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); return 1; } } - LOG_TEE("\n"); - if (n_parallel > 1) { - LOG_TEE("\n"); + LOG("\n"); for (int32_t i = 0; i < n_parallel; ++i) { - LOG_TEE("sequence %d:\n\n%s%s\n\n", i, params.prompt.c_str(), streams[i].c_str()); + LOG("sequence %d:\n\n%s%s\n\n", i, params.prompt.c_str(), streams[i].c_str()); } } const auto t_main_end = ggml_time_us(); - LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", + LOG_INF("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); - LOG_TEE("\n"); - llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN); - llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); + LOG("\n"); + llama_perf_sampler_print(smpl); + llama_perf_context_print(ctx); fprintf(stderr, "\n"); @@ -240,7 +238,7 @@ int main(int argc, char ** argv) { llama_sampler_free(smpl); llama_free(ctx); - llama_free_model(model); + llama_model_free(model); llama_backend_free(); diff --git a/examples/benchmark/CMakeLists.txt b/examples/benchmark/CMakeLists.txt deleted file mode 100644 index 34a58cc02..000000000 --- a/examples/benchmark/CMakeLists.txt +++ /dev/null @@ -1,6 +0,0 @@ -set(TARGET llama-bench-matmult) -add_executable(${TARGET} benchmark-matmult.cpp) -install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE llama build_info ${CMAKE_THREAD_LIBS_INIT}) -target_include_directories(${TARGET} PRIVATE ../../common) -target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/benchmark/benchmark-matmult.cpp b/examples/benchmark/benchmark-matmult.cpp deleted file mode 100644 index 97622f4f4..000000000 --- a/examples/benchmark/benchmark-matmult.cpp +++ /dev/null @@ -1,275 +0,0 @@ -#include "common.h" -#include "ggml.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined(_MSC_VER) -#pragma warning(disable: 4244 4267) // possible loss of data -#endif - -static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * graph, int n_threads) { - struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr); - - if (plan.work_size > 0) { - buf.resize(plan.work_size); - plan.work_data = buf.data(); - } - - ggml_graph_compute(graph, &plan); -} - -static float tensor_sum_elements(const ggml_tensor * tensor) { - double sum = 0; - if (tensor->type == GGML_TYPE_F32) { - for (int j = 0; j < tensor->ne[1]; j++) { - for (int k = 0; k < tensor->ne[0]; k++) { - sum += ((float *) tensor->data)[j*tensor->ne[0] + k]; - } - } - } - return sum; -} - -static void tensor_dump(const ggml_tensor * tensor, const char * name) { - printf("%15s: type = %i (%5s) ne = %5" PRIi64 " x %5" PRIi64 " x %5" PRIi64 ", nb = (%5zi, %5zi, %5zi) - ", name, - tensor->type, ggml_type_name(tensor->type), - tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->nb[0], tensor->nb[1], tensor->nb[2]); - float sum = tensor_sum_elements(tensor); - printf("Sum of tensor %s is %6.2f\n", name, sum); -} - -#define TENSOR_DUMP(tensor) tensor_dump(tensor, #tensor) - -struct benchmark_params_struct { - int n_threads = 1; - int32_t n_iterations = 10; -}; - -static void print_usage(int /*argc*/, char ** argv, struct benchmark_params_struct params) { - fprintf(stderr, "usage: %s [options]\n", argv[0]); - fprintf(stderr, "\n"); - fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help show this help message and exit\n"); - fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); - fprintf(stderr, " -i N, --iter N number of iterations to use during computation (default: %d)\n", params.n_iterations); - fprintf(stderr, "\n"); -} - -int main(int argc, char ** argv) { - struct benchmark_params_struct benchmark_params; - - bool invalid_param = false; - std::string arg; - for (int i = 1; i < argc; i++) { - arg = argv[i]; - - if (arg == "-t" || arg == "--threads") { - if (++i >= argc) { - invalid_param = true; - break; - } - benchmark_params.n_threads = std::stoi(argv[i]); - } else if (arg == "-i" || arg == "--iter") { - if (++i >= argc) { - invalid_param = true; - break; - } - benchmark_params.n_iterations = std::stoi(argv[i]); - } else if (arg == "-h" || arg == "--help") { - print_usage(argc, argv, benchmark_params); - exit(0); - } - } - if (invalid_param) { - fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str()); - print_usage(argc, argv, benchmark_params); - exit(1); - } - - print_build_info(); - printf("Starting Test\n"); - - // create the ggml context - struct ggml_context * ctx; - //const int sizex = 4096; - //const int sizey = 11008; - -#undef VERBOSE_DEBUGGING -#ifndef VERBOSE_DEBUGGING - const int sizey = 4096; - const int sizex = 11008; - const int sizez = 128; -#else - /* Working - let's increase size */ - const int sizey = 1; - const int sizex = (8*32); - const int sizez = 1; - - /*const int sizey = 1; - const int sizex = 3*(8*32); - const int sizez = 1;*/ -#endif - - //printf("Memsize required = %i\n", sizex*sizex); - - // TODO: perform the bench for all types or for a user specified type - const ggml_type qtype = GGML_TYPE_Q4_1; - - size_t ctx_size = 0; - ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizey); - ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizey); - ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizez); - ctx_size += ggml_row_size(qtype, sizex*sizey); - ctx_size += ggml_row_size(qtype, sizex*sizey); - ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizey); // BLAS - ctx_size += ggml_row_size(GGML_TYPE_F32, sizex*sizey); // BLAS - ctx_size += 1024*1024*16; - - printf("Allocating Memory of size %zi bytes, %zi MB\n",ctx_size, (ctx_size/1024/1024)); - - struct ggml_init_params params = { - /*.mem_size =*/ ctx_size, - /*.mem_buffer =*/ NULL, - /* no_alloc =*/ 0 - }; - - ctx = ggml_init(params); - if (!ctx) { - fprintf(stderr, "%s: ggml_init() failed\n", __func__); - return 1; - } - - - printf("Creating new tensors\n"); - // printf("Creating new tensor m1\n"); - struct ggml_tensor * m11 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, sizex, sizey); - ggml_set_f32(m11, 1.0f); - - // printf("Creating new tensor m1\n"); - struct ggml_tensor * m12 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, sizex, sizey); - ggml_set_f32(m12, 1.5f); - - // printf("Creating new tensor m2\n"); - struct ggml_tensor * m2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, sizex, sizez); - ggml_set_f32(m2, 2.0f); - - printf("\n------ Test 1 - Matrix Mult via F32 code\n"); - // printf("Creating new tensor m11xm2\n"); - struct ggml_tensor * m11xm2 = ggml_mul_mat(ctx, m11, m2); - - // printf("Creating compute graph\n"); - struct ggml_cgraph * gf = ggml_new_graph(ctx); - ggml_build_forward_expand(gf, m11xm2); - - printf("n_threads=%i\n", benchmark_params.n_threads); - - TENSOR_DUMP(m11); - TENSOR_DUMP(m2); - - std::vector work_buffer; - - ggml_graph_compute_helper(work_buffer, gf, benchmark_params.n_threads); - - TENSOR_DUMP(gf->nodes[0]); - - printf("\n------ Test 2 - Matrix Mult via %s code\n", ggml_type_name(qtype)); - - int32_t nelements = sizex*sizey; - - // Set up a the benchmark matrices - // printf("Creating new tensor q11 & Running quantize\n"); - struct ggml_tensor * q11 = ggml_new_tensor_2d(ctx, qtype, sizex, sizey); - ggml_quantize_chunk(qtype, (const float *) m11->data, q11->data, 0, nelements/m11->ne[0], m11->ne[0], nullptr); - - // Set up a the compute graph - // printf("Creating new tensor q31\n"); - struct ggml_tensor * q31 = ggml_mul_mat(ctx, q11, m2); - - // printf("Creating compute graph\n"); - struct ggml_cgraph * gf31 = ggml_new_graph(ctx); - ggml_build_forward_expand(gf31, q31); - - // Set up a second graph computation to make sure we override the CPU cache lines - // printf("Creating new tensor q12 & Running quantize\n"); - struct ggml_tensor * q12 = ggml_new_tensor_2d(ctx, qtype, sizex, sizey); - ggml_quantize_chunk(qtype, (const float *) m12->data, q12->data, 0, nelements/m12->ne[0], m12->ne[0], nullptr); - - // printf("Creating new tensor q32\n"); - struct ggml_tensor * q32 = ggml_mul_mat(ctx, q12, m2); - - //printf("Creating compute graph\n"); - struct ggml_cgraph * gf32 = ggml_new_graph(ctx); - ggml_build_forward_expand(gf32, q32); - printf("n_threads=%i\n", benchmark_params.n_threads); - - const int dimx = sizex; - const int dimy = sizey; - const int dimz = sizez; - long long int flops_per_dot_product = dimy + dimy; - long long int flops_per_matrix = flops_per_dot_product * dimx * dimz; ; - printf("Matrix Multiplication of (%i,%i,%i) x (%i,%i,%i) - about %6.2f gFLOPS\n\n", sizex, sizey, 1, sizex, sizez, 1, 1.0f*flops_per_matrix / 1000 / 1000 / 1000); - - - // Let's use the F32 result from above as a reference for the quantized multiplication - float sum_of_F32_reference = tensor_sum_elements(gf->nodes[0]); - - printf("Iteration;NThreads; SizeX; SizeY; SizeZ; Required_FLOPS; Elapsed_u_Seconds; gigaFLOPS\n"); - printf("=====================================================================================\n"); - - double gflops_sum = 0; - for (int i=0;inodes[0]); - float delta = std::abs(sum_of_Q4_result - sum_of_F32_reference); - float allowed_delta = (sum_of_F32_reference) / 1000 / 1000; // Let's accept an epsilon of 10^-6 - - if (delta > allowed_delta) { - printf("\nABORT - ERROR in Matrix Multiplication result - expected %6.2f, got %6.2f (delta %6.2f > allowed_delta %6.2f)\n", - sum_of_F32_reference, - sum_of_Q4_result, - delta, - allowed_delta - ); - exit(0); - } - - // Running a different graph computation to make sure we override the CPU cache lines - ggml_graph_compute_helper(work_buffer, gf32, benchmark_params.n_threads); - } - printf("\n"); - printf("Average%78.2f\n",gflops_sum/((double)benchmark_params.n_iterations)); - printf("=====================================================================================\n"); -} diff --git a/examples/chat-persistent.sh b/examples/chat-persistent.sh index d9cab9836..9d761ebb8 100755 --- a/examples/chat-persistent.sh +++ b/examples/chat-persistent.sh @@ -23,8 +23,9 @@ CUR_PROMPT_CACHE="${CHAT_SAVE_DIR}/current-cache.bin" NEXT_PROMPT_FILE="${CHAT_SAVE_DIR}/next-prompt.txt" NEXT_PROMPT_CACHE="${CHAT_SAVE_DIR}/next-cache.bin" -SESSION_SIZE_MSG_PATTERN='main: session file matches [[:digit:]]+ / [[:digit:]]+' -SAMPLE_TIME_MSG_PATTERN='sample time =[[:space:]]+[[:digit:]]+.[[:digit:]]+ ms /[[:space:]]+[[:digit:]]+' +SESSION_AND_SAMPLE_PATTERN='main: session file matches [[:digit:]]+ / [[:digit:]]+'\ +'|'\ +'sampling time =[[:space:]]+[[:digit:]]+.[[:digit:]]+ ms /[[:space:]]+[[:digit:]]+' SED_DELETE_MESSAGES="/^(${USER_NAME}:|${AI_NAME}:|\\.\\.\\.)/,\$d" CTX_SIZE=2048 @@ -129,15 +130,12 @@ while read -e line; do printf ' ' - # HACK get num tokens from debug message - # TODO get both messages in one go - if ! session_size_msg="$(tail -n30 "$LOG" | grep -oE "$SESSION_SIZE_MSG_PATTERN")" || - ! sample_time_msg="$(tail -n10 "$LOG" | grep -oE "$SAMPLE_TIME_MSG_PATTERN")"; then + if ! session_and_sample_msg=$(tail -n30 "$LOG" | grep -oE "$SESSION_AND_SAMPLE_PATTERN"); then echo >&2 "Couldn't get number of tokens from ./llama-cli output!" exit 1 fi - n_tokens=$(($(cut -d/ -f2 <<<"$session_size_msg") + $(cut -d/ -f2 <<<"$sample_time_msg"))) + n_tokens=$(awk '{sum+=$1} END {print sum}' <<< "$(cut -d/ -f2 <<< "$session_and_sample_msg")") if ((n_tokens > CTX_ROTATE_POINT)); then tail -c+$((n_prompt_len_pre + 1)) "$CUR_PROMPT_FILE" >>"$NEXT_PROMPT_FILE" diff --git a/examples/convert-llama2c-to-ggml/CMakeLists.txt b/examples/convert-llama2c-to-ggml/CMakeLists.txt index a6790e617..44e5f722a 100644 --- a/examples/convert-llama2c-to-ggml/CMakeLists.txt +++ b/examples/convert-llama2c-to-ggml/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-convert-llama2c-to-ggml) add_executable(${TARGET} convert-llama2c-to-ggml.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/convert-llama2c-to-ggml/README.md b/examples/convert-llama2c-to-ggml/README.md index 5774ac83c..46a42da69 100644 --- a/examples/convert-llama2c-to-ggml/README.md +++ b/examples/convert-llama2c-to-ggml/README.md @@ -2,11 +2,8 @@ This example reads weights from project [llama2.c](https://github.com/karpathy/llama2.c) and saves them in ggml compatible format. The vocab that is available in `models/ggml-vocab.bin` is used by default. -To convert the model first download the models from the [llama2.c](https://github.com/karpathy/llama2.c) repository: +To convert the model first download the models from the [llama2.c](https://github.com/karpathy/llama2.c) repository. -`$ make -j` - -After successful compilation, following usage options are available: ``` usage: ./llama-convert-llama2c-to-ggml [options] diff --git a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp index 8ca9f8915..bdf0eed2a 100644 --- a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +++ b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp @@ -1,4 +1,6 @@ #include "ggml.h" +#include "gguf.h" + #include "llama.h" #include "common.h" #include "log.h" @@ -9,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -105,43 +108,43 @@ static void alloc_weights(TransformerWeights * w, const Config * p, bool shared_ const int n_multiqueries = p->n_kv_heads <= 0 || p->n_kv_heads >= p->n_heads ? 1 : p->n_heads / p->n_kv_heads; try { w->token_embedding_table.resize(p->vocab_size * p->dim); - LOG("%s: Allocating [%d] x [%d] = [%d] float space for w->token_embedding_table\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] = [%d] float space for w->token_embedding_table\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim); w->rms_att_weight.resize(p->n_layers * p->dim); - LOG("%s: Allocating [%d] x [%d] = [%d] float space for w->rms_att_weight\n",__func__,p->n_layers, p->dim, p->n_layers * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] = [%d] float space for w->rms_att_weight\n",__func__,p->n_layers, p->dim, p->n_layers * p->dim); w->rms_ffn_weight.resize(p->n_layers * p->dim); - LOG("%s: Allocating [%d] x [%d] = [%d] float space for w->rms_ffn_weight\n",__func__,p->n_layers , p->dim, p->n_layers * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] = [%d] float space for w->rms_ffn_weight\n",__func__,p->n_layers , p->dim, p->n_layers * p->dim); w->wq.resize(p->n_layers * p->dim * p->dim); - LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wq\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wq\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim); w->wk.resize(p->n_layers * p->dim * p->dim / n_multiqueries); - LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wk\n",__func__,p->n_layers, p->dim, p->dim / n_multiqueries, p->n_layers * p->dim * p->dim / n_multiqueries); + LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wk\n",__func__,p->n_layers, p->dim, p->dim / n_multiqueries, p->n_layers * p->dim * p->dim / n_multiqueries); w->wv.resize(p->n_layers * p->dim * p->dim / n_multiqueries); - LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wv\n",__func__, p->n_layers, p->dim, p->dim / n_multiqueries, p->n_layers * p->dim * p->dim / n_multiqueries); + LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wv\n",__func__, p->n_layers, p->dim, p->dim / n_multiqueries, p->n_layers * p->dim * p->dim / n_multiqueries); w->wo.resize(p->n_layers * p->dim * p->dim); - LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wo\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wo\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim); w->w1.resize(p->n_layers * p->hidden_dim * p->dim); - LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w1\n",__func__,p->n_layers, p->hidden_dim, p->dim, p->n_layers * p->hidden_dim * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w1\n",__func__,p->n_layers, p->hidden_dim, p->dim, p->n_layers * p->hidden_dim * p->dim); w->w2.resize(p->n_layers * p->hidden_dim * p->dim); - LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w2\n",__func__,p->n_layers, p->dim, p->hidden_dim, p->n_layers * p->hidden_dim * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w2\n",__func__,p->n_layers, p->dim, p->hidden_dim, p->n_layers * p->hidden_dim * p->dim); w->w3.resize(p->n_layers * p->hidden_dim * p->dim); - LOG("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w3\n",__func__,p->n_layers, p->hidden_dim, p->dim, p->n_layers * p->hidden_dim * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w3\n",__func__,p->n_layers, p->hidden_dim, p->dim, p->n_layers * p->hidden_dim * p->dim); w->rms_final_weight.resize(p->dim); - LOG("%s: Allocating [%d] float space for w->rms_final_weight\n",__func__,p->dim); + LOG_INF("%s: Allocating [%d] float space for w->rms_final_weight\n",__func__,p->dim); if (shared_weights) { w->wcls = {}; } else { w->wcls.resize(p->vocab_size * p->dim); - LOG("%s: Allocating [%d] x [%d] = [%d] float space for w->wcls\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] = [%d] float space for w->wcls\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim); } } catch (std::length_error &) { @@ -173,7 +176,7 @@ static int checkpoint_init_weights(TransformerWeights * w, const Config * p, FIL fseek(f, 0, SEEK_END); auto end = ftell(f); if (curr != end) { - LOG("%s: Error: failed to read the checkpoint file to the end (curr = %ld, end = %ld)\n", __func__, curr, end); + LOG_ERR("%s: Error: failed to read the checkpoint file to the end (curr = %ld, end = %ld)\n", __func__, curr, end); return 1; } @@ -181,26 +184,26 @@ static int checkpoint_init_weights(TransformerWeights * w, const Config * p, FIL } static void print_sample_weights(TransformerWeights *w){ - LOG("----- Quick print of first of the weight vales of all the variables\n"); - LOG("%f\n", w->token_embedding_table[0]); - LOG("%f\n", w->rms_att_weight[0]); - LOG("%f\n", w->rms_ffn_weight[0]); + LOG_INF("----- Quick print of first of the weight vales of all the variables\n"); + LOG_INF("%f\n", w->token_embedding_table[0]); + LOG_INF("%f\n", w->rms_att_weight[0]); + LOG_INF("%f\n", w->rms_ffn_weight[0]); - LOG("%f\n", w->wq[0]); - LOG("%f\n", w->wk[0]); - LOG("%f\n", w->wv[0]); - LOG("%f\n", w->wo[0]); - LOG("%f\n", w->w1[0]); - LOG("%f\n", w->w2[0]); - LOG("%f\n", w->w3[0]); - LOG("%f\n", w->rms_att_weight[0]); - if (!w->wcls.empty()) LOG("%f\n", w->wcls[0]); + LOG_INF("%f\n", w->wq[0]); + LOG_INF("%f\n", w->wk[0]); + LOG_INF("%f\n", w->wv[0]); + LOG_INF("%f\n", w->wo[0]); + LOG_INF("%f\n", w->w1[0]); + LOG_INF("%f\n", w->w2[0]); + LOG_INF("%f\n", w->w3[0]); + LOG_INF("%f\n", w->rms_att_weight[0]); + if (!w->wcls.empty()) LOG_INF("%f\n", w->wcls[0]); } //////////////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////// ggml structs and functions required to load models, configs and save the model. -struct llama_vocab { +struct my_llama_vocab { using id = int32_t; using token = std::string; using ttype = llama_token_type; @@ -318,20 +321,20 @@ struct train_params { }; static void print_params(struct my_llama_hparams * params) { - LOG("%s: n_vocab: %u\n", __func__, params->n_vocab); - LOG("%s: n_ctx: %u\n", __func__, params->n_ctx); - LOG("%s: n_embd: %u\n", __func__, params->n_embd); - LOG("%s: n_mult: %u\n", __func__, params->n_mult); - LOG("%s: n_head: %u\n", __func__, params->n_head); - LOG("%s: n_head_kv: %u\n", __func__, params->n_head_kv); - LOG("%s: n_ff: %u\n", __func__, params->n_ff); - LOG("%s: n_layer: %u\n", __func__, params->n_layer); - LOG("%s: n_rot: %u\n", __func__, params->n_rot); + LOG_INF("%s: n_vocab: %u\n", __func__, params->n_vocab); + LOG_INF("%s: n_ctx: %u\n", __func__, params->n_ctx); + LOG_INF("%s: n_embd: %u\n", __func__, params->n_embd); + LOG_INF("%s: n_mult: %u\n", __func__, params->n_mult); + LOG_INF("%s: n_head: %u\n", __func__, params->n_head); + LOG_INF("%s: n_head_kv: %u\n", __func__, params->n_head_kv); + LOG_INF("%s: n_ff: %u\n", __func__, params->n_ff); + LOG_INF("%s: n_layer: %u\n", __func__, params->n_layer); + LOG_INF("%s: n_rot: %u\n", __func__, params->n_rot); } static void print_tensor_info(const struct ggml_context * ctx) { for (auto t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { - LOG("%s: Allocating ", __func__); + LOG_INF("%s: Allocating ", __func__); int64_t total = 1; int i = 0; for (; i < ggml_n_dims(t); ++i) { @@ -433,12 +436,12 @@ static void print_matrix(struct ggml_tensor * probs) { } } -struct llama_file { +struct my_llama_file { // use FILE * so we don't have to re-open the file to mmap FILE * fp; size_t size; - llama_file(const char * fname, const char * mode) { + my_llama_file(const char * fname, const char * mode) { fp = std::fopen(fname, mode); if (fp == NULL) { size = 0; @@ -499,7 +502,7 @@ struct llama_file { return std::string(chars.data(), len); } - ~llama_file() { + ~my_llama_file() { if (fp) { std::fclose(fp); } @@ -507,7 +510,7 @@ struct llama_file { }; static bool is_ggml_file(const char * filename) { - llama_file file(filename, "rb"); + my_llama_file file(filename, "rb"); if (file.size < 4) { return false; } @@ -524,9 +527,9 @@ static std::string llama_escape_whitespaces(const std::string & text) { return out.str(); } -static void load_vocab(const char * filename, const Config * config, struct llama_vocab * vocab) { +static void load_vocab(const char * filename, const Config * config, struct my_llama_vocab * vocab) { if (is_ggml_file(filename)) { - LOG("%s: Loading vocabulary from gguf file %s\n", __func__, filename); + LOG_INF("%s: Loading vocabulary from gguf file %s\n", __func__, filename); struct ggml_context * ctx_data = NULL; struct gguf_init_params params = { @@ -574,21 +577,21 @@ static void load_vocab(const char * filename, const Config * config, struct llam gguf_free(ctx); } else { // assume llama2.c vocabulary - LOG("%s: Assuming llama2.c vocabulary since %s is not a gguf file\n", __func__, filename); - llama_file file(filename, "rb"); + LOG_INF("%s: Assuming llama2.c vocabulary since %s is not a gguf file\n", __func__, filename); + my_llama_file file(filename, "rb"); if (!file.fp) { die_fmt("%s: %s", strerror(errno), filename); } const int n_vocab = config->vocab_size; /* uint32_t max_token_length = */ file.read_u32(); // unused vocab->id_to_token.resize(n_vocab); - for (llama_vocab::id id=0; idtoken_embedding_table -> model->tok_embeddings @@ -670,7 +673,7 @@ static void save_as_llama_model( std::vector tokens; std::vector scores; std::vector token_types; - for (const llama_vocab::token_data & token_data : vocab->id_to_token) { + for (const my_llama_vocab::token_data & token_data : vocab->id_to_token) { tokens.push_back(token_data.text.c_str()); scores.push_back(token_data.score); token_types.push_back(token_data.type); @@ -688,8 +691,8 @@ static void save_as_llama_model( gguf_set_val_u32(ctx, KV_TOKENIZER_UNK_ID, UNKNOWN_TOKEN_ID); gguf_set_val_u32(ctx, KV_TOKENIZER_BOS_ID, BOS_TOKEN_ID); gguf_set_val_u32(ctx, KV_TOKENIZER_EOS_ID, EOS_TOKEN_ID); - gguf_set_val_u32(ctx, KV_TOKENIZER_SEP_ID, -1); - gguf_set_val_u32(ctx, KV_TOKENIZER_PAD_ID, -1); + gguf_set_val_u32(ctx, KV_TOKENIZER_SEP_ID, LLAMA_TOKEN_NULL); + gguf_set_val_u32(ctx, KV_TOKENIZER_PAD_ID, LLAMA_TOKEN_NULL); gguf_set_val_u32(ctx, KV_CONTEXT_LENGTH, model->hparams.n_ctx); gguf_set_val_u32(ctx, KV_EMBEDDING_LENGTH, model->hparams.n_embd); @@ -871,23 +874,25 @@ static std::string basename(const std::string &path) { } int main(int argc, char ** argv) { + common_init(); + struct train_params params = get_default_train_params(); if (!params_parse(argc, argv, ¶ms)) { return 1; } - log_set_target(stdout); + Config config; TransformerWeights weights = {}; { - LOG("%s: Loading llama2c model from %s\n", __func__, params.fn_llama2c_model); + LOG_INF("%s: Loading llama2c model from %s\n", __func__, params.fn_llama2c_model); FILE * file = fopen(params.fn_llama2c_model, "rb"); if (!file) { - LOG("%s: Unable to open the checkpoint file %s!\n", __func__, params.fn_llama2c_model); + LOG_ERR("%s: Unable to open the checkpoint file %s!\n", __func__, params.fn_llama2c_model); return 1; } // read in the config header if (fread(&config, sizeof(Config), 1, file) != 1) { - LOG("%s: Unable to read llama2c config from %s!\n",__func__,params.fn_llama2c_model); + LOG_ERR("%s: Unable to read llama2c config from %s!\n",__func__,params.fn_llama2c_model); return 1; } auto shared_weights = config.vocab_size > 0; @@ -896,17 +901,17 @@ int main(int argc, char ** argv) { // read in the Transformer weights alloc_weights(&weights, &config, shared_weights); if (checkpoint_init_weights(&weights, &config, file, shared_weights)) { - LOG("%s: Unable to initialize transformer weights from %s!",__func__,params.fn_llama2c_model); + LOG_ERR("%s: Unable to initialize transformer weights from %s!",__func__,params.fn_llama2c_model); return 1; } fclose(file); } - struct llama_vocab vocab; + struct my_llama_vocab vocab; load_vocab(params.fn_vocab_model, &config, &vocab); struct my_llama_model model; - model.hparams.n_vocab = config.vocab_size; //llama_n_vocab(lctx); + model.hparams.n_vocab = config.vocab_size; //llama_vocab_n_vocab(lctx); model.hparams.n_ctx = params.n_ctx; model.hparams.n_embd = config.dim; //params.n_embd; model.hparams.n_ff = config.hidden_dim; @@ -929,7 +934,7 @@ int main(int argc, char ** argv) { model.name = basename(params.fn_llama2c_model); save_as_llama_model(&vocab, &model, &weights, params.fn_llama2c_output_model); - LOG("%s: Saving llama.c model file %s in ggml format at %s\n", __func__, params.fn_llama2c_model, params.fn_llama2c_output_model); + LOG_INF("%s: Saving llama.c model file %s in ggml format at %s\n", __func__, params.fn_llama2c_model, params.fn_llama2c_output_model); ggml_free(model.ctx); return 0; diff --git a/examples/convert_legacy_llama.py b/examples/convert_legacy_llama.py index 9ab9ab06e..c4ec5c524 100755 --- a/examples/convert_legacy_llama.py +++ b/examples/convert_legacy_llama.py @@ -840,6 +840,8 @@ class OutputFile: self.gguf.add_base_model_version(key, base_model_entry["version"]) if "organization" in base_model_entry: self.gguf.add_base_model_organization(key, base_model_entry["organization"]) + if "description" in base_model_entry: + self.gguf.add_base_model_description(key, base_model_entry["description"]) if "url" in base_model_entry: self.gguf.add_base_model_url(key, base_model_entry["url"]) if "doi" in base_model_entry: @@ -849,12 +851,32 @@ class OutputFile: if "repo_url" in base_model_entry: self.gguf.add_base_model_repo_url(key, base_model_entry["repo_url"]) + if metadata.datasets is not None: + self.gguf.add_dataset_count(len(metadata.datasets)) + for key, dataset_entry in enumerate(metadata.datasets): + if "name" in dataset_entry: + self.gguf.add_dataset_name(key, dataset_entry["name"]) + if "author" in dataset_entry: + self.gguf.add_dataset_author(key, dataset_entry["author"]) + if "version" in dataset_entry: + self.gguf.add_dataset_version(key, dataset_entry["version"]) + if "organization" in dataset_entry: + self.gguf.add_dataset_organization(key, dataset_entry["organization"]) + if "description" in dataset_entry: + self.gguf.add_dataset_description(key, dataset_entry["description"]) + if "url" in dataset_entry: + self.gguf.add_dataset_url(key, dataset_entry["url"]) + if "doi" in dataset_entry: + self.gguf.add_dataset_doi(key, dataset_entry["doi"]) + if "uuid" in dataset_entry: + self.gguf.add_dataset_uuid(key, dataset_entry["uuid"]) + if "repo_url" in dataset_entry: + self.gguf.add_dataset_repo_url(key, dataset_entry["repo_url"]) + if metadata.tags is not None: self.gguf.add_tags(metadata.tags) if metadata.languages is not None: self.gguf.add_languages(metadata.languages) - if metadata.datasets is not None: - self.gguf.add_datasets(metadata.datasets) def add_meta_arch(self, params: Params) -> None: # Metadata About The Neural Architecture Itself diff --git a/examples/cvector-generator/CMakeLists.txt b/examples/cvector-generator/CMakeLists.txt index 0a559d60c..49ad9561c 100644 --- a/examples/cvector-generator/CMakeLists.txt +++ b/examples/cvector-generator/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-cvector-generator) add_executable(${TARGET} cvector-generator.cpp pca.hpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index 0795175a1..413b71d34 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -1,6 +1,9 @@ +#include "ggml.h" +#include "gguf.h" + +#include "arg.h" #include "common.h" #include "llama.h" -#include "ggml.h" #include "pca.hpp" #include "mean.hpp" @@ -12,14 +15,15 @@ #include "ggml-metal.h" #endif +#include +#include #include +#include +#include +#include #include #include #include -#include -#include -#include -#include ////////////////////////////////////////////////// @@ -29,7 +33,7 @@ template static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { std::string ret; for (; begin != end; ++begin) { - ret += llama_token_to_piece(ctx, *begin); + ret += common_token_to_piece(ctx, *begin); } return ret; @@ -269,9 +273,11 @@ struct tokenized_prompt { size_t max_seq_len; tokenized_prompt(llama_context * ctx, std::string pos, std::string neg) { - const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); - tokens_pos = ::llama_tokenize(ctx, pos, add_bos, true); - tokens_neg = ::llama_tokenize(ctx, neg, add_bos, true); + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + const bool add_bos = llama_vocab_get_add_bos(vocab); + tokens_pos = common_tokenize(ctx, pos, add_bos, true); + tokens_neg = common_tokenize(ctx, neg, add_bos, true); max_seq_len = std::max(tokens_pos.size(), tokens_neg.size()); padding_seq(ctx, tokens_pos, max_seq_len); padding_seq(ctx, tokens_neg, max_seq_len); @@ -279,7 +285,7 @@ struct tokenized_prompt { void padding_seq(llama_context * ctx, std::vector & tokens, size_t len) { // TODO: customize padding token - std::vector pad_tokens = ::llama_tokenize(ctx, " ", false); + std::vector pad_tokens = common_tokenize(ctx, " ", false); llama_token pad_tok = pad_tokens.back(); while (tokens.size() < len) { tokens.push_back(pad_tok); @@ -337,7 +343,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { static bool get_hidden_layers(llama_context * ctx, std::vector & tokens) { llama_kv_cache_clear(ctx); - if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) { + if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; } @@ -368,7 +374,7 @@ static void export_gguf(const std::vector & v_ctrl, const * Load prompt files and completion file. * Then format each pair of prompt + completion to make an entry. */ -static int prepare_entries(gpt_params & params, train_context & ctx_train) { +static int prepare_entries(common_params & params, train_context & ctx_train) { // load prompts std::vector positive_prompts = ctrlvec_load_prompt_file(params.cvector_positive_file, true); std::vector negative_prompts = ctrlvec_load_prompt_file(params.cvector_negative_file, true); @@ -386,10 +392,9 @@ static int prepare_entries(gpt_params & params, train_context & ctx_train) { } int main(int argc, char ** argv) { - gpt_params params; + common_params params; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_CVECTOR_GENERATOR, print_usage); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_CVECTOR_GENERATOR, print_usage)) { return 1; } @@ -412,14 +417,15 @@ int main(int argc, char ** argv) { llama_numa_init(params.numa); // load the model to get hparams - llama_init_result llama_init = llama_init_from_gpt_params(params); + common_init_result llama_init = common_init_from_params(params); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); // int n_ctx = llama_n_ctx(ctx); - int n_layers = llama_n_layer(model); - int n_embd = llama_n_embd(model); + int n_layers = llama_model_n_layer(model); + int n_embd = llama_model_n_embd(model); + // get model hint param (a.k.a model arch name) char model_hint[128]; llama_model_meta_val_str(model, "general.architecture", model_hint, 128); @@ -473,8 +479,6 @@ int main(int argc, char ** argv) { // done with the model, we can now free it to make gain some memory printf("Done evaluate prompts, unload model...\n"); - llama_free(ctx); - llama_free_model(model); bool use_pca = params.cvector_dimre_method == DIMRE_METHOD_PCA; diff --git a/examples/cvector-generator/mean.hpp b/examples/cvector-generator/mean.hpp index 16be5ce3e..4eeac1eeb 100644 --- a/examples/cvector-generator/mean.hpp +++ b/examples/cvector-generator/mean.hpp @@ -15,7 +15,7 @@ static void run( for (size_t il = 0; il < v_input.size(); ++il) { // prepare output vector struct ggml_tensor * ctrl_out = v_output[il]; - ggml_format_name(ctrl_out, "direction.%ld", il+1); + ggml_format_name(ctrl_out, "direction.%zu", il+1); // calculate mean vector struct ggml_tensor * t_layer = v_input[il]; diff --git a/examples/cvector-generator/pca.hpp b/examples/cvector-generator/pca.hpp index 6ec3141af..e88bbdde9 100644 --- a/examples/cvector-generator/pca.hpp +++ b/examples/cvector-generator/pca.hpp @@ -12,12 +12,9 @@ #include #include +#include #include -#include #include -#include -#include -#include #define DEBUG_POS 5 @@ -207,13 +204,6 @@ static ggml_status compute_piter( ggml_backend_cpu_set_n_threads(model.backend, params.n_threads); } -// TODO: enable GPU support when support for GGML_OP_SQRT is added -//#ifdef GGML_USE_METAL -// if (ggml_backend_is_metal(model.backend)) { -// ggml_backend_metal_set_n_cb(model.backend, params.n_threads); -// } -//#endif - ggml_status res = ggml_backend_graph_compute(model.backend, gf); if (res == GGML_STATUS_SUCCESS) { auto extract_i = [](std::string prefix, std::string str) -> int { @@ -229,8 +219,8 @@ static ggml_status compute_piter( result.eigenvectors.resize(params.n_batch); result.distances.resize(params.n_batch); // get output nodes - for (int i = 0; i < gf->n_nodes; ++i) { - auto node = gf->nodes[i]; + for (int i = 0; i < ggml_graph_n_nodes(gf); ++i) { + auto node = ggml_graph_node(gf, i); int iter = -1; // find b_tensor (without copying data from device) if ((iter = extract_i("b_tensor_norm_", node->name)) > -1) { @@ -312,7 +302,7 @@ static void run_pca( // prepare output vector struct ggml_tensor * ctrl_out = v_output[il]; - ggml_format_name(ctrl_out, "direction.%ld", il+1); + ggml_format_name(ctrl_out, "direction.%zu", il+1); // run power_iteration params.i_layer = il; diff --git a/examples/deprecation-warning/deprecation-warning.cpp b/examples/deprecation-warning/deprecation-warning.cpp index 11b35d2c2..c2958ea12 100644 --- a/examples/deprecation-warning/deprecation-warning.cpp +++ b/examples/deprecation-warning/deprecation-warning.cpp @@ -12,7 +12,7 @@ int main(int argc, char** argv) { } // Get only the program name from the full path - auto pos = filename.find_last_of('/'); + auto pos = filename.find_last_of("/\\"); if (pos != std::string::npos) { filename = filename.substr(pos+1); } diff --git a/examples/embedding/CMakeLists.txt b/examples/embedding/CMakeLists.txt index 8256e789a..809040307 100644 --- a/examples/embedding/CMakeLists.txt +++ b/examples/embedding/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-embedding) add_executable(${TARGET} embedding.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 630f7c1c7..38d22c90f 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -1,4 +1,6 @@ +#include "arg.h" #include "common.h" +#include "log.h" #include "llama.h" #include @@ -26,7 +28,7 @@ static std::vector split_lines(const std::string & s, const std::st static void batch_add_seq(llama_batch & batch, const std::vector & tokens, llama_seq_id seq_id) { size_t n_tokens = tokens.size(); for (size_t i = 0; i < n_tokens; i++) { - llama_batch_add(batch, tokens[i], i, { seq_id }, true); + common_batch_add(batch, tokens[i], i, { seq_id }, true); } } @@ -38,16 +40,16 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu llama_kv_cache_clear(ctx); // run model - fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); + LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) { // encoder-only model if (llama_encode(ctx, batch) < 0) { - fprintf(stderr, "%s : failed to encode\n", __func__); + LOG_ERR("%s : failed to encode\n", __func__); } } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { // decoder-only model if (llama_decode(ctx, batch) < 0) { - fprintf(stderr, "%s : failed to decode\n", __func__); + LOG_ERR("%s : failed to decode\n", __func__); } } @@ -72,58 +74,58 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu } float * out = output + embd_pos * n_embd; - llama_embd_normalize(embd, out, n_embd, embd_norm); + common_embd_normalize(embd, out, n_embd, embd_norm); } } int main(int argc, char ** argv) { - gpt_params params; + common_params params; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_EMBEDDING); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EMBEDDING)) { return 1; } + common_init(); + params.embedding = true; // For non-causal models, batch size must be equal to ubatch size params.n_ubatch = params.n_batch; - print_build_info(); - - LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed); - llama_backend_init(); llama_numa_init(params.numa); // load the model - llama_init_result llama_init = llama_init_from_gpt_params(params); + common_init_result llama_init = common_init_from_params(params); + + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; if (model == NULL) { - fprintf(stderr, "%s: error: unable to load model\n", __func__); + LOG_ERR("%s: unable to load model\n", __func__); return 1; } - const int n_ctx_train = llama_n_ctx_train(model); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_ctx_train = llama_model_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); if (llama_model_has_encoder(model) && llama_model_has_decoder(model)) { - fprintf(stderr, "%s: error: computing embeddings in encoder-decoder models is not supported\n", __func__); + LOG_ERR("%s: computing embeddings in encoder-decoder models is not supported\n", __func__); return 1; } if (n_ctx > n_ctx_train) { - fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n", + LOG_WRN("%s: warning: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx); } // print system information { - fprintf(stderr, "\n"); - fprintf(stderr, "%s\n", gpt_params_get_system_info(params).c_str()); + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); } // split the prompt into lines @@ -136,9 +138,9 @@ int main(int argc, char ** argv) { // tokenize the prompts and trim std::vector> inputs; for (const auto & prompt : prompts) { - auto inp = ::llama_tokenize(ctx, prompt, true, false); + auto inp = common_tokenize(ctx, prompt, true, true); if (inp.size() > n_batch) { - fprintf(stderr, "%s: error: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n", + LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n", __func__, (long long int) inp.size(), (long long int) n_batch); return 1; } @@ -148,21 +150,21 @@ int main(int argc, char ** argv) { // check if the last token is SEP // it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true' for (auto & inp : inputs) { - if (inp.empty() || inp.back() != llama_token_sep(model)) { - fprintf(stderr, "%s: warning: last token in the prompt is not SEP\n", __func__); - fprintf(stderr, "%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__); + if (inp.empty() || inp.back() != llama_vocab_sep(vocab)) { + LOG_WRN("%s: last token in the prompt is not SEP\n", __func__); + LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__); } } // tokenization stats if (params.verbose_prompt) { for (int i = 0; i < (int) inputs.size(); i++) { - fprintf(stderr, "%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str()); - fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size()); + LOG_INF("%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str()); + LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size()); for (int j = 0; j < (int) inputs[i].size(); j++) { - fprintf(stderr, "%6d -> '%s'\n", inputs[i][j], llama_token_to_piece(ctx, inputs[i][j]).c_str()); + LOG("%6d -> '%s'\n", inputs[i][j], common_token_to_piece(ctx, inputs[i][j]).c_str()); } - fprintf(stderr, "\n\n"); + LOG("\n\n"); } } @@ -181,7 +183,7 @@ int main(int argc, char ** argv) { } // allocate output - const int n_embd = llama_n_embd(model); + const int n_embd = llama_model_n_embd(model); std::vector embeddings(n_embd_count * n_embd, 0); float * emb = embeddings.data(); @@ -200,7 +202,7 @@ int main(int argc, char ** argv) { batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s; s = 0; - llama_batch_clear(batch); + common_batch_clear(batch); } // add to batch @@ -213,57 +215,62 @@ int main(int argc, char ** argv) { batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); if (params.embd_out.empty()) { - fprintf(stdout, "\n"); + LOG("\n"); if (pooling_type == LLAMA_POOLING_TYPE_NONE) { for (int j = 0; j < n_embd_count; j++) { - fprintf(stdout, "embedding %d: ", j); + LOG("embedding %d: ", j); for (int i = 0; i < std::min(3, n_embd); i++) { if (params.embd_normalize == 0) { - fprintf(stdout, "%6.0f ", emb[j * n_embd + i]); + LOG("%6.0f ", emb[j * n_embd + i]); } else { - fprintf(stdout, "%9.6f ", emb[j * n_embd + i]); + LOG("%9.6f ", emb[j * n_embd + i]); } } - fprintf(stdout, " ... "); + LOG(" ... "); for (int i = n_embd - 3; i < n_embd; i++) { if (params.embd_normalize == 0) { - fprintf(stdout, "%6.0f ", emb[j * n_embd + i]); + LOG("%6.0f ", emb[j * n_embd + i]); } else { - fprintf(stdout, "%9.6f ", emb[j * n_embd + i]); + LOG("%9.6f ", emb[j * n_embd + i]); } } - fprintf(stdout, "\n"); + LOG("\n"); + } + } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) { + for (int j = 0; j < n_embd_count; j++) { + // NOTE: if you change this log - update the tests in ci/run.sh + LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]); } } else { // print the first part of the embeddings or for a single prompt, the full embedding for (int j = 0; j < n_prompts; j++) { - fprintf(stdout, "embedding %d: ", j); + LOG("embedding %d: ", j); for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) { if (params.embd_normalize == 0) { - fprintf(stdout, "%6.0f ", emb[j * n_embd + i]); + LOG("%6.0f ", emb[j * n_embd + i]); } else { - fprintf(stdout, "%9.6f ", emb[j * n_embd + i]); + LOG("%9.6f ", emb[j * n_embd + i]); } } - fprintf(stdout, "\n"); + LOG("\n"); } // print cosine similarity matrix if (n_prompts > 1) { - fprintf(stdout, "\n"); - printf("cosine similarity matrix:\n\n"); + LOG("\n"); + LOG("cosine similarity matrix:\n\n"); for (int i = 0; i < n_prompts; i++) { - fprintf(stdout, "%6.6s ", prompts[i].c_str()); + LOG("%6.6s ", prompts[i].c_str()); } - fprintf(stdout, "\n"); + LOG("\n"); for (int i = 0; i < n_prompts; i++) { for (int j = 0; j < n_prompts; j++) { - float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); - fprintf(stdout, "%6.2f ", sim); + float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); + LOG("%6.2f ", sim); } - fprintf(stdout, "%1.10s", prompts[i].c_str()); - fprintf(stdout, "\n"); + LOG("%1.10s", prompts[i].c_str()); + LOG("\n"); } } } @@ -272,48 +279,46 @@ int main(int argc, char ** argv) { if (params.embd_out == "json" || params.embd_out == "json+" || params.embd_out == "array") { const bool notArray = params.embd_out != "array"; - fprintf(stdout, notArray ? "{\n \"object\": \"list\",\n \"data\": [\n" : "["); + LOG(notArray ? "{\n \"object\": \"list\",\n \"data\": [\n" : "["); for (int j = 0;;) { // at least one iteration (one prompt) - if (notArray) fprintf(stdout, " {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j); - fprintf(stdout, "["); + if (notArray) LOG(" {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j); + LOG("["); for (int i = 0;;) { // at least one iteration (n_embd > 0) - fprintf(stdout, params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd + i]); + LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd + i]); i++; - if (i < n_embd) fprintf(stdout, ","); else break; + if (i < n_embd) LOG(","); else break; } - fprintf(stdout, notArray ? "]\n }" : "]"); + LOG(notArray ? "]\n }" : "]"); j++; - if (j < n_embd_count) fprintf(stdout, notArray ? ",\n" : ","); else break; + if (j < n_embd_count) LOG(notArray ? ",\n" : ","); else break; } - fprintf(stdout, notArray ? "\n ]" : "]\n"); + LOG(notArray ? "\n ]" : "]\n"); if (params.embd_out == "json+" && n_prompts > 1) { - fprintf(stdout, ",\n \"cosineSimilarity\": [\n"); + LOG(",\n \"cosineSimilarity\": [\n"); for (int i = 0;;) { // at least two iteration (n_embd_count > 1) - fprintf(stdout, " ["); + LOG(" ["); for (int j = 0;;) { // at least two iteration (n_embd_count > 1) - float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); - fprintf(stdout, "%6.2f", sim); + float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); + LOG("%6.2f", sim); j++; - if (j < n_embd_count) fprintf(stdout, ", "); else break; + if (j < n_embd_count) LOG(", "); else break; } - fprintf(stdout, " ]"); + LOG(" ]"); i++; - if (i < n_embd_count) fprintf(stdout, ",\n"); else break; + if (i < n_embd_count) LOG(",\n"); else break; } - fprintf(stdout, "\n ]"); + LOG("\n ]"); } - if (notArray) fprintf(stdout, "\n}\n"); + if (notArray) LOG("\n}\n"); } - LOG_TEE("\n"); - llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); + LOG("\n"); + llama_perf_context_print(ctx); // clean up llama_batch_free(batch); - llama_free(ctx); - llama_free_model(model); llama_backend_free(); return 0; diff --git a/examples/eval-callback/CMakeLists.txt b/examples/eval-callback/CMakeLists.txt index a48753d38..95915ed91 100644 --- a/examples/eval-callback/CMakeLists.txt +++ b/examples/eval-callback/CMakeLists.txt @@ -2,8 +2,9 @@ set(TARGET llama-eval-callback) add_executable(${TARGET} eval-callback.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) set(TEST_TARGET test-eval-callback) -add_test(NAME ${TEST_TARGET} COMMAND llama-eval-callback --hf-repo ggml-org/models --hf-file tinyllamas/stories260K.gguf --model stories260K.gguf --prompt hello --seed 42 -ngl 0) +add_test(NAME ${TEST_TARGET} + COMMAND llama-eval-callback --hf-repo ggml-org/models --hf-file tinyllamas/stories260K.gguf --model stories260K.gguf --prompt hello --seed 42 -ngl 0) set_property(TEST ${TEST_TARGET} PROPERTY LABELS eval-callback curl) diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 881111ffd..fb188f5a9 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -1,11 +1,11 @@ +#include "arg.h" #include "common.h" +#include "log.h" #include "llama.h" #include "ggml.h" #include -#include #include -#include #include /** @@ -31,22 +31,22 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne GGML_ASSERT(n > 0); float sum = 0; for (int64_t i3 = 0; i3 < ne[3]; i3++) { - printf(" [\n"); + LOG(" [\n"); for (int64_t i2 = 0; i2 < ne[2]; i2++) { if (i2 == n && ne[2] > 2*n) { - printf(" ..., \n"); + LOG(" ..., \n"); i2 = ne[2] - n; } - printf(" [\n"); + LOG(" [\n"); for (int64_t i1 = 0; i1 < ne[1]; i1++) { if (i1 == n && ne[1] > 2*n) { - printf(" ..., \n"); + LOG(" ..., \n"); i1 = ne[1] - n; } - printf(" ["); + LOG(" ["); for (int64_t i0 = 0; i0 < ne[0]; i0++) { if (i0 == n && ne[0] > 2*n) { - printf("..., "); + LOG("..., "); i0 = ne[0] - n; } size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; @@ -64,16 +64,16 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne } else { GGML_ABORT("fatal error"); } - printf("%12.4f", v); + LOG("%12.4f", v); sum += v; - if (i0 < ne[0] - 1) printf(", "); + if (i0 < ne[0] - 1) LOG(", "); } - printf("],\n"); + LOG("],\n"); } - printf(" ],\n"); + LOG(" ],\n"); } - printf(" ]\n"); - printf(" sum = %f\n", sum); + LOG(" ]\n"); + LOG(" sum = %f\n", sum); } } @@ -102,11 +102,11 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str()); } - printf("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, - t->name, ggml_type_name(t->type), ggml_op_desc(t), - src0->name, ggml_ne_string(src0).c_str(), - src1 ? src1_str : "", - ggml_ne_string(t).c_str()); + LOG("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, + t->name, ggml_type_name(t->type), ggml_op_desc(t), + src0->name, ggml_ne_string(src0).c_str(), + src1 ? src1_str : "", + ggml_ne_string(t).c_str()); // copy the data from the GPU memory if needed @@ -126,13 +126,16 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { return true; } -static bool run(llama_context * ctx, const gpt_params & params) { - const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); +static bool run(llama_context * ctx, const common_params & params) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); - std::vector tokens = ::llama_tokenize(ctx, params.prompt, add_bos); + const bool add_bos = llama_vocab_get_add_bos(vocab); - if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) { - fprintf(stderr, "%s : failed to eval\n", __func__); + std::vector tokens = common_tokenize(ctx, params.prompt, add_bos); + + if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { + LOG_ERR("%s : failed to eval\n", __func__); return false; } @@ -142,14 +145,13 @@ static bool run(llama_context * ctx, const gpt_params & params) { int main(int argc, char ** argv) { callback_data cb_data; - gpt_params params; + common_params params; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { return 1; } - print_build_info(); + common_init(); llama_backend_init(); llama_numa_init(params.numa); @@ -161,19 +163,21 @@ int main(int argc, char ** argv) { params.warmup = false; // init - llama_init_result llama_init = llama_init_from_gpt_params(params); + common_init_result llama_init = common_init_from_params(params); + + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; if (model == nullptr || ctx == nullptr) { - fprintf(stderr, "%s : failed to init\n", __func__); + LOG_ERR("%s : failed to init\n", __func__); return 1; } // print system information { - fprintf(stderr, "\n"); - fprintf(stderr, "%s\n", gpt_params_get_system_info(params).c_str()); + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + LOG_INF("\n"); } bool OK = run(ctx, params); @@ -181,11 +185,8 @@ int main(int argc, char ** argv) { return 1; } - LOG_TEE("\n"); - llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); - - llama_free(ctx); - llama_free_model(model); + LOG("\n"); + llama_perf_context_print(ctx); llama_backend_free(); diff --git a/examples/export-lora/CMakeLists.txt b/examples/export-lora/CMakeLists.txt index 1cef6e716..310455787 100644 --- a/examples/export-lora/CMakeLists.txt +++ b/examples/export-lora/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-export-lora) add_executable(${TARGET} export-lora.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/export-lora/export-lora.cpp b/examples/export-lora/export-lora.cpp index 544e7fff6..91238e4be 100644 --- a/examples/export-lora/export-lora.cpp +++ b/examples/export-lora/export-lora.cpp @@ -1,11 +1,13 @@ -#include "common.h" #include "ggml.h" #include "ggml-alloc.h" +#include "gguf.h" + +#include "arg.h" +#include "common.h" #include #include #include -#include #include static bool g_verbose = false; @@ -127,7 +129,7 @@ struct lora_merge_ctx { lora_merge_ctx( std::string & base_fname, - std::vector & lora_files, + std::vector & lora_files, std::string & outfile, int n_threads) : base_model(base_fname, 0), n_threads(n_threads), fout(outfile, std::ios::binary) { fout.exceptions(std::ofstream::failbit); // fail fast on write errors @@ -264,8 +266,8 @@ struct lora_merge_ctx { fout.write((const char *)data.data(), data.size()); } - printf("%s : merged %ld tensors with lora adapters\n", __func__, n_merged); - printf("%s : wrote %ld tensors to output file\n", __func__, trans.size()); + printf("%s : merged %zu tensors with lora adapters\n", __func__, n_merged); + printf("%s : wrote %zu tensors to output file\n", __func__, trans.size()); } void copy_tensor(struct ggml_tensor * base) { @@ -313,9 +315,9 @@ struct lora_merge_ctx { // optionally dequantize it printf("%s : + dequantize base tensor from %s to F32\n", __func__, ggml_type_name(base->type)); auto nels = ggml_nelements(inp_base); - ggml_type_traits_t qtype = ggml_internal_get_type_traits(base->type); + const auto * qtype = ggml_get_type_traits(base->type); std::vector dequant_buf(nels * sizeof(float)); - qtype.to_float(read_buf.data(), (float *)dequant_buf.data(), nels); + qtype->to_float(read_buf.data(), (float *)dequant_buf.data(), nels); ggml_backend_tensor_set(inp_base, dequant_buf.data(), 0, dequant_buf.size()); } else { ggml_backend_tensor_set(inp_base, read_buf.data(), 0, ggml_nbytes(inp_base)); @@ -343,15 +345,25 @@ struct lora_merge_ctx { gf = ggml_new_graph(ctx0); struct ggml_tensor * cur = inp_base; for (size_t i = 0; i < adapters.size(); ++i) { - struct ggml_tensor * a_T = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32))); - struct ggml_tensor * delta = ggml_mul_mat(ctx0, a_T, ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32)); + struct ggml_tensor * delta; + bool is_tok_embd = string_starts_with(name_base, "token_embd"); + if (is_tok_embd) { + printf("%s : detected token embeddings tensor\n", __func__); + delta = ggml_mul_mat(ctx0, + ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32), + ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32)); + } else { + delta = ggml_mul_mat(ctx0, + ggml_cont(ctx0, ggml_transpose(ctx0, ggml_cast(ctx0, inp_a[i], GGML_TYPE_F32))), + ggml_cast(ctx0, inp_b[i], GGML_TYPE_F32)); + } // scale const float alpha = adapters[i]->alpha; const float rank = (float) inp_b[i]->ne[0]; const float scale = alpha ? adapters[i]->scale * alpha / rank : adapters[i]->scale; delta = ggml_scale(ctx0, delta, scale); cur = ggml_add(ctx0, delta, cur); - printf("%s : + merging from adapter[%ld] type=%s\n", __func__, i, ggml_type_name(inp_a[i]->type)); + printf("%s : + merging from adapter[%zu] type=%s\n", __func__, i, ggml_type_name(inp_a[i]->type)); printf("%s : input_scale=%f calculated_scale=%f rank=%d\n", __func__, adapters[i]->scale, scale, (int) inp_b[i]->ne[0]); } cur = ggml_cast(ctx0, cur, out->type); @@ -369,7 +381,7 @@ struct lora_merge_ctx { // write data to output file { - auto result = gf->nodes[gf->n_nodes - 1]; + auto * result = ggml_graph_node(gf, -1); size_t len = ggml_nbytes(result); if (read_buf.size() < len) { read_buf.resize(len); @@ -399,14 +411,13 @@ static void print_usage(int, char ** argv) { } int main(int argc, char ** argv) { - gpt_params params; + common_params params; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_EXPORT_LORA, print_usage); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EXPORT_LORA, print_usage)) { return 1; } - g_verbose = (params.verbosity == 1); + g_verbose = (params.verbosity > 1); try { lora_merge_ctx ctx(params.model, params.lora_adapters, params.lora_outfile, params.cpuparams.n_threads); ctx.run_merge(); diff --git a/examples/gbnf-validator/CMakeLists.txt b/examples/gbnf-validator/CMakeLists.txt index 4edd6ec73..d2cb524c0 100644 --- a/examples/gbnf-validator/CMakeLists.txt +++ b/examples/gbnf-validator/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-gbnf-validator) add_executable(${TARGET} gbnf-validator.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index 7493af9d3..a610e6a0b 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -11,19 +11,15 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) { const auto cpts = unicode_cpts_from_utf8(input_str); - const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); - llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); + auto & stacks_cur = llama_grammar_get_stacks(grammar); size_t pos = 0; for (const auto & cpt : cpts) { - const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy - - llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur); + llama_grammar_accept(grammar, cpt); if (stacks_cur.empty()) { error_pos = pos; error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'"; - stacks_cur = stacks_prev; return false; } ++pos; @@ -80,9 +76,10 @@ int main(int argc, char** argv) { grammar_str = buffer.str(); } - llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root"); + llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0); if (grammar == nullptr) { - throw std::runtime_error("Failed to initialize llama_grammar"); + fprintf(stdout, "Failed to initialize llama_grammar\n"); + return 1; } // Read the input file std::string input_str; diff --git a/examples/gen-docs/CMakeLists.txt b/examples/gen-docs/CMakeLists.txt index c94cda776..25de0af35 100644 --- a/examples/gen-docs/CMakeLists.txt +++ b/examples/gen-docs/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-gen-docs) add_executable(${TARGET} gen-docs.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/gen-docs/gen-docs.cpp b/examples/gen-docs/gen-docs.cpp index 8b1dafd63..77c59a836 100644 --- a/examples/gen-docs/gen-docs.cpp +++ b/examples/gen-docs/gen-docs.cpp @@ -1,3 +1,4 @@ +#include "arg.h" #include "common.h" #include @@ -5,42 +6,73 @@ // Export usage message (-h) to markdown format +static void write_table_header(std::ofstream & file) { + file << "| Argument | Explanation |\n"; + file << "| -------- | ----------- |\n"; +} + +static void write_table_entry(std::ofstream & file, const common_arg & opt) { + file << "| `"; + // args + for (const auto & arg : opt.args) { + if (arg == opt.args.front()) { + file << arg; + if (opt.args.size() > 1) file << ", "; + } else { + file << arg << (arg != opt.args.back() ? ", " : ""); + } + } + // value hint + if (opt.value_hint) { + std::string md_value_hint(opt.value_hint); + string_replace_all(md_value_hint, "|", "\\|"); + file << " " << md_value_hint; + } + if (opt.value_hint_2) { + std::string md_value_hint_2(opt.value_hint_2); + string_replace_all(md_value_hint_2, "|", "\\|"); + file << " " << md_value_hint_2; + } + // help text + std::string md_help(opt.help); + string_replace_all(md_help, "\n", "
"); + string_replace_all(md_help, "|", "\\|"); + file << "` | " << md_help << " |\n"; +} + +static void write_table(std::ofstream & file, std::vector & opts) { + write_table_header(file); + for (const auto & opt : opts) { + write_table_entry(file, *opt); + } +} + static void export_md(std::string fname, llama_example ex) { std::ofstream file(fname, std::ofstream::out | std::ofstream::trunc); - gpt_params params; - auto options = gpt_params_parser_init(params, ex); + common_params params; + auto ctx_arg = common_params_parser_init(params, ex); - file << "| Argument | Explanation |\n"; - file << "| -------- | ----------- |\n"; - for (auto & opt : options) { - file << "| `"; - // args - for (const auto & arg : opt.args) { - if (arg == opt.args.front()) { - file << arg; - if (opt.args.size() > 1) file << ", "; - } else { - file << arg << (arg != opt.args.back() ? ", " : ""); - } + std::vector common_options; + std::vector sparam_options; + std::vector specific_options; + for (auto & opt : ctx_arg.options) { + // in case multiple LLAMA_EXAMPLE_* are set, we prioritize the LLAMA_EXAMPLE_* matching current example + if (opt.is_sparam) { + sparam_options.push_back(&opt); + } else if (opt.in_example(ctx_arg.ex)) { + specific_options.push_back(&opt); + } else { + common_options.push_back(&opt); } - // value hint - if (opt.value_hint) { - std::string md_value_hint(opt.value_hint); - string_replace_all(md_value_hint, "|", "\\|"); - file << " " << md_value_hint; - } - if (opt.value_hint_2) { - std::string md_value_hint_2(opt.value_hint_2); - string_replace_all(md_value_hint_2, "|", "\\|"); - file << " " << md_value_hint_2; - } - // help text - std::string md_help(opt.help); - string_replace_all(md_help, "\n", "
"); - string_replace_all(md_help, "|", "\\|"); - file << "` | " << md_help << " |\n"; } + + file << "**Common params**\n\n"; + write_table(file, common_options); + file << "\n\n**Sampling params**\n\n"; + write_table(file, sparam_options); + file << "\n\n**Example-specific params**\n\n"; + write_table(file, specific_options); } int main(int, char **) { diff --git a/examples/gguf-hash/CMakeLists.txt b/examples/gguf-hash/CMakeLists.txt index 633f45535..15c5c68c6 100644 --- a/examples/gguf-hash/CMakeLists.txt +++ b/examples/gguf-hash/CMakeLists.txt @@ -4,12 +4,19 @@ install(TARGETS ${TARGET} RUNTIME) # clibs dependencies include_directories(deps/) + add_library(xxhash OBJECT deps/xxhash/xxhash.c deps/xxhash/xxhash.h) target_link_libraries(${TARGET} PRIVATE xxhash) + add_library(sha1 OBJECT deps/sha1/sha1.c deps/sha1/sha1.h) target_link_libraries(${TARGET} PRIVATE sha1) +if (NOT MSVC) + # disable warnings in 3rd party code + target_compile_options(sha1 PRIVATE -w) +endif() + add_library(sha256 OBJECT deps/sha256/sha256.c deps/sha256/sha256.h) target_link_libraries(${TARGET} PRIVATE sha256) target_link_libraries(${TARGET} PRIVATE ggml ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/gguf-hash/gguf-hash.cpp b/examples/gguf-hash/gguf-hash.cpp index e96c75117..9523ec122 100644 --- a/examples/gguf-hash/gguf-hash.cpp +++ b/examples/gguf-hash/gguf-hash.cpp @@ -1,4 +1,5 @@ #include "ggml.h" +#include "gguf.h" #include /* abort() */ #include diff --git a/examples/gguf-split/CMakeLists.txt b/examples/gguf-split/CMakeLists.txt index f63887da7..c407e2f0a 100644 --- a/examples/gguf-split/CMakeLists.txt +++ b/examples/gguf-split/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-gguf-split) add_executable(${TARGET} gguf-split.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/gguf-split/gguf-split.cpp b/examples/gguf-split/gguf-split.cpp index 881f0451c..ef3ceb686 100644 --- a/examples/gguf-split/gguf-split.cpp +++ b/examples/gguf-split/gguf-split.cpp @@ -1,18 +1,19 @@ +#include "ggml.h" +#include "gguf.h" #include "llama.h" #include "common.h" #include -#include +#include +#include +#include #include +#include +#include #include #include #include -#include -#include -#include -#include - #if defined(_WIN32) #include #ifndef PATH_MAX @@ -22,12 +23,20 @@ #endif enum split_operation : uint8_t { - SPLIT_OP_SPLIT, - SPLIT_OP_MERGE, + OP_NONE, + OP_SPLIT, + OP_MERGE, +}; + +enum split_mode : uint8_t { + MODE_NONE, + MODE_TENSOR, + MODE_SIZE, }; struct split_params { - split_operation operation = SPLIT_OP_SPLIT; + split_operation operation = OP_NONE; + split_mode mode = MODE_NONE; size_t n_bytes_split = 0; int n_split_tensors = 128; std::string input; @@ -87,59 +96,52 @@ static void split_params_parse_ex(int argc, const char ** argv, split_params & p } bool arg_found = false; - bool is_op_set = false; - bool is_mode_set = false; if (arg == "-h" || arg == "--help") { split_print_usage(argv[0]); exit(0); - } - if (arg == "--version") { + } else if (arg == "--version") { fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT); fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET); exit(0); - } - if (arg == "--dry-run") { + } else if (arg == "--dry-run") { arg_found = true; params.dry_run = true; - } - if (arg == "--no-tensor-first-split") { + } else if (arg == "--no-tensor-first-split") { arg_found = true; params.no_tensor_first_split = true; - } - - if (is_op_set) { - throw std::invalid_argument("error: either --split or --merge can be specified, but not both"); - } - if (arg == "--merge") { + } else if (arg == "--merge") { arg_found = true; - is_op_set = true; - params.operation = SPLIT_OP_MERGE; - } - if (arg == "--split") { + if (params.operation != OP_NONE && params.operation != OP_MERGE) { + throw std::invalid_argument("error: either --split or --merge can be specified, but not both"); + } + params.operation = OP_MERGE; + } else if (arg == "--split") { arg_found = true; - is_op_set = true; - params.operation = SPLIT_OP_SPLIT; - } - - if (is_mode_set) { - throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both"); - } - if (arg == "--split-max-tensors") { + if (params.operation != OP_NONE && params.operation != OP_SPLIT) { + throw std::invalid_argument("error: either --split or --merge can be specified, but not both"); + } + params.operation = OP_SPLIT; + } else if (arg == "--split-max-tensors") { if (++arg_idx >= argc) { invalid_param = true; break; } arg_found = true; - is_mode_set = true; + if (params.mode != MODE_NONE && params.mode != MODE_TENSOR) { + throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both"); + } + params.mode = MODE_TENSOR; params.n_split_tensors = atoi(argv[arg_idx]); - } - if (arg == "--split-max-size") { + } else if (arg == "--split-max-size") { if (++arg_idx >= argc) { invalid_param = true; break; } arg_found = true; - is_mode_set = true; + if (params.mode != MODE_NONE && params.mode != MODE_SIZE) { + throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both"); + } + params.mode = MODE_SIZE; params.n_bytes_split = split_str_to_n_bytes(argv[arg_idx]); } @@ -148,11 +150,20 @@ static void split_params_parse_ex(int argc, const char ** argv, split_params & p } } + // the operation is split if not specified + if (params.operation == OP_NONE) { + params.operation = OP_SPLIT; + } + // the split mode is by tensor if not specified + if (params.mode == MODE_NONE) { + params.mode = MODE_TENSOR; + } + if (invalid_param) { throw std::invalid_argument("error: invalid parameter for argument: " + arg); } - if (argc - arg_idx < 2) { + if (argc - arg_idx != 2) { throw std::invalid_argument("error: bad arguments"); } @@ -265,17 +276,19 @@ struct split_strategy { } bool should_split(int i_tensor, size_t next_size) { - if (params.n_bytes_split > 0) { + if (params.mode == MODE_SIZE) { // split by max size per file return next_size > params.n_bytes_split; - } else { + } else if (params.mode == MODE_TENSOR) { // split by number of tensors per file return i_tensor > 0 && i_tensor < n_tensors && i_tensor % params.n_split_tensors == 0; } + // should never happen + GGML_ABORT("invalid mode"); } void print_info() { - printf("n_split: %ld\n", ctx_outs.size()); + printf("n_split: %zu\n", ctx_outs.size()); int i_split = 0; for (auto & ctx_out : ctx_outs) { // re-calculate the real gguf size for each split (= metadata size + total size of all tensors) @@ -285,7 +298,7 @@ struct split_strategy { total_size += ggml_nbytes(t); } total_size = total_size / 1000 / 1000; // convert to megabytes - printf("split %05d: n_tensors = %d, total_size = %ldM\n", i_split + 1, gguf_get_n_tensors(ctx_out), total_size); + printf("split %05d: n_tensors = %" PRIi64 ", total_size = %zuM\n", i_split + 1, gguf_get_n_tensors(ctx_out), total_size); i_split++; } } @@ -389,10 +402,17 @@ static void gguf_merge(const split_params & split_params) { int n_split = 1; int total_tensors = 0; - auto * ctx_out = gguf_init_empty(); + // avoid overwriting existing output file + if (std::ifstream(split_params.output.c_str())) { + fprintf(stderr, "%s: output file %s already exists\n", __func__, split_params.output.c_str()); + exit(EXIT_FAILURE); + } + std::ofstream fout(split_params.output.c_str(), std::ios::binary); fout.exceptions(std::ofstream::failbit); // fail fast on write errors + auto * ctx_out = gguf_init_empty(); + std::vector read_data; std::vector ctx_metas; std::vector ctx_ggufs; @@ -552,9 +572,9 @@ int main(int argc, const char ** argv) { split_params_parse(argc, argv, params); switch (params.operation) { - case SPLIT_OP_SPLIT: gguf_split(params); + case OP_SPLIT: gguf_split(params); break; - case SPLIT_OP_MERGE: gguf_merge(params); + case OP_MERGE: gguf_merge(params); break; default: split_print_usage(argv[0]); exit(EXIT_FAILURE); diff --git a/examples/gguf-split/tests.sh b/examples/gguf-split/tests.sh index d5a92d605..05a932227 100755 --- a/examples/gguf-split/tests.sh +++ b/examples/gguf-split/tests.sh @@ -41,7 +41,7 @@ echo PASS echo # 2b. Test the sharded model is loading properly -$MAIN --model $WORK_PATH/ggml-model-split-00001-of-00006.gguf --n-predict 32 +$MAIN -no-cnv --model $WORK_PATH/ggml-model-split-00001-of-00006.gguf --n-predict 32 echo PASS echo @@ -51,7 +51,7 @@ echo PASS echo # 3b. Test the merged model is loading properly -$MAIN --model $WORK_PATH/ggml-model-merge.gguf --n-predict 32 +$MAIN -no-cnv --model $WORK_PATH/ggml-model-merge.gguf --n-predict 32 echo PASS echo @@ -61,7 +61,7 @@ echo PASS echo # 4b. Test the sharded model is loading properly -$MAIN --model $WORK_PATH/ggml-model-split-32-tensors-00001-of-00007.gguf --n-predict 32 +$MAIN -no-cnv --model $WORK_PATH/ggml-model-split-32-tensors-00001-of-00007.gguf --n-predict 32 echo PASS echo @@ -71,7 +71,7 @@ echo #echo # 5b. Test the merged model is loading properly -#$MAIN --model $WORK_PATH/ggml-model-merge-2.gguf --n-predict 32 +#$MAIN -no-cnv --model $WORK_PATH/ggml-model-merge-2.gguf --n-predict 32 #echo PASS #echo @@ -81,7 +81,7 @@ echo PASS echo # 6b. Test the sharded model is loading properly -$MAIN --model $WORK_PATH/ggml-model-split-2G-00001-of-00002.gguf --n-predict 32 +$MAIN -no-cnv --model $WORK_PATH/ggml-model-split-2G-00001-of-00002.gguf --n-predict 32 echo PASS echo diff --git a/examples/gguf/CMakeLists.txt b/examples/gguf/CMakeLists.txt index a9569b411..fb04eb83f 100644 --- a/examples/gguf/CMakeLists.txt +++ b/examples/gguf/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-gguf) add_executable(${TARGET} gguf.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE ggml ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/gguf/gguf.cpp b/examples/gguf/gguf.cpp index 7498f85ef..f31989c8c 100644 --- a/examples/gguf/gguf.cpp +++ b/examples/gguf/gguf.cpp @@ -1,10 +1,9 @@ #include "ggml.h" +#include "gguf.h" #include -#include #include #include -#include #include #undef MIN @@ -135,9 +134,10 @@ static bool gguf_ex_read_0(const std::string & fname) { for (int i = 0; i < n_tensors; ++i) { const char * name = gguf_get_tensor_name (ctx, i); + const size_t size = gguf_get_tensor_size (ctx, i); const size_t offset = gguf_get_tensor_offset(ctx, i); - printf("%s: tensor[%d]: name = %s, offset = %zu\n", __func__, i, name, offset); + printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu\n", __func__, i, name, size, offset); } } @@ -182,9 +182,10 @@ static bool gguf_ex_read_1(const std::string & fname, bool check_data) { for (int i = 0; i < n_tensors; ++i) { const char * name = gguf_get_tensor_name (ctx, i); + const size_t size = gguf_get_tensor_size (ctx, i); const size_t offset = gguf_get_tensor_offset(ctx, i); - printf("%s: tensor[%d]: name = %s, offset = %zu\n", __func__, i, name, offset); + printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu\n", __func__, i, name, size, offset); } } @@ -199,7 +200,8 @@ static bool gguf_ex_read_1(const std::string & fname, bool check_data) { struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name); - printf("%s: tensor[%d]: n_dims = %d, name = %s, data = %p\n", __func__, i, ggml_n_dims(cur), cur->name, cur->data); + printf("%s: tensor[%d]: n_dims = %d, ne = (%d, %d, %d, %d), name = %s, data = %p\n", + __func__, i, ggml_n_dims(cur), int(cur->ne[0]), int(cur->ne[1]), int(cur->ne[2]), int(cur->ne[3]), cur->name, cur->data); // print first 10 elements const float * data = (const float *) cur->data; @@ -215,7 +217,7 @@ static bool gguf_ex_read_1(const std::string & fname, bool check_data) { const float * data = (const float *) cur->data; for (int j = 0; j < ggml_nelements(cur); ++j) { if (data[j] != 100 + i) { - fprintf(stderr, "%s: tensor[%d]: data[%d] = %f\n", __func__, i, j, data[j]); + fprintf(stderr, "%s: tensor[%d], data[%d]: found %f, expected %f\n", __func__, i, j, data[j], float(100 + i)); gguf_free(ctx); return false; } @@ -245,6 +247,8 @@ int main(int argc, char ** argv) { check_data = false; } + srand(123456); + const std::string fname(argv[1]); const std::string mode (argv[2]); diff --git a/examples/gritlm/CMakeLists.txt b/examples/gritlm/CMakeLists.txt index 86dfddca3..fa1b4dc70 100644 --- a/examples/gritlm/CMakeLists.txt +++ b/examples/gritlm/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-gritlm) add_executable(${TARGET} gritlm.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index e1efbf573..72eb46257 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -1,3 +1,4 @@ +#include "arg.h" #include "common.h" #include "llama.h" @@ -10,24 +11,25 @@ static std::vector> encode(llama_context * ctx, const std::ve std::vector> result; const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1); for (uint64_t i = 0; i < sentences.size(); i++) { - llama_batch_clear(batch); + common_batch_clear(batch); const std::string input_string = instruction + sentences[i]; - std::vector inputs = llama_tokenize(model, input_string, true, false); + std::vector inputs = common_tokenize(vocab, input_string, true, false); const int32_t n_toks = inputs.size(); // GritLM seems to have EOS = "" // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18 - // inputs.push_back(llama_token_eos(model)); + // inputs.push_back(llama_vocab_eos(vocab)); // we want to ignore instruction tokens for mean pooling - const int32_t n_inst = llama_tokenize(model, instruction, true, false).size(); + const int32_t n_inst = common_tokenize(vocab, instruction, true, false).size(); #ifdef GRIT_DEBUG // debug tokens - should be matching as referenced in the GritLM sample @@ -39,7 +41,7 @@ static std::vector> encode(llama_context * ctx, const std::ve // add input to batch (this increments n_tokens) for (int32_t j = 0; j < n_toks; j++) { - llama_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst); + common_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst); } // clear previous kv_cache values (irrelevant for embeddings) @@ -51,7 +53,7 @@ static std::vector> encode(llama_context * ctx, const std::ve llama_decode(ctx, batch); // get embedding dimensions - uint64_t n_embd = llama_n_embd(model); + uint64_t n_embd = llama_model_n_embd(model); // allocate embedding output std::vector emb_unorm(n_embd, 0.0f); @@ -74,7 +76,7 @@ static std::vector> encode(llama_context * ctx, const std::ve } std::vector emb_norm(emb_unorm.size()); - llama_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd); + common_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd, 2); result.push_back(emb_norm); #ifdef GRIT_DEBUG @@ -96,7 +98,9 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std std::string result; const llama_model * model = llama_get_model(ctx); - llama_token eos_token = llama_token_eos(model); + const llama_vocab * vocab = llama_model_get_vocab(model); + + llama_token eos_token = llama_vocab_eos(vocab); llama_kv_cache_clear(ctx); llama_set_embeddings(ctx, false); @@ -104,16 +108,16 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); - std::vector inputs = llama_tokenize(model, prompt, false, true); + std::vector inputs = common_tokenize(vocab, prompt, false, true); int32_t i_current_token = 0; while (true) { - llama_batch_clear(bat); + common_batch_clear(bat); { const int32_t n_inputs = inputs.size(); for (int32_t i = 0; i < n_inputs; i++) { - llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1); + common_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1); } } inputs.clear(); @@ -121,13 +125,12 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std llama_decode(ctx, bat); llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1); - llama_sampler_accept(smpl, token); if (token == eos_token) { break; } - std::string piece = llama_token_to_piece(ctx, token); + std::string piece = common_token_to_piece(ctx, token); if (stream) { std::printf("%s", piece.c_str()); std::fflush(stdout); @@ -152,22 +155,23 @@ static std::string gritlm_instruction(const std::string & instruction) { } int main(int argc, char * argv[]) { - gpt_params params; + common_params params; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { return 1; } - llama_model_params mparams = llama_model_params_from_gpt_params(params); - llama_context_params cparams = llama_context_params_from_gpt_params(params); + common_init(); + + llama_model_params mparams = common_model_params_to_llama(params); + llama_context_params cparams = common_context_params_to_llama(params); llama_backend_init(); - llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams); + llama_model * model = llama_model_load_from_file(params.model.c_str(), mparams); // create generation context - llama_context * ctx = llama_new_context_with_model(model, cparams); + llama_context * ctx = llama_init_from_model(model, cparams); auto sparams = llama_sampler_chain_default_params(); @@ -196,12 +200,12 @@ int main(int argc, char * argv[]) { const std::vector> d_rep = encode(ctx, documents, gritlm_instruction("")); const std::vector> q_rep = encode(ctx, queries, gritlm_instruction(instruction)); - const int n_embd = llama_n_embd(model); + const int n_embd = llama_model_n_embd(model); - const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd); - const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd); - const float cosine_sim_q1_d0 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[0].data(), n_embd); - const float cosine_sim_q1_d1 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[1].data(), n_embd); + const float cosine_sim_q0_d0 = common_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd); + const float cosine_sim_q0_d1 = common_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd); + const float cosine_sim_q1_d0 = common_embd_similarity_cos(q_rep[1].data(), d_rep[0].data(), n_embd); + const float cosine_sim_q1_d1 = common_embd_similarity_cos(q_rep[1].data(), d_rep[1].data(), n_embd); std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[0].c_str(), cosine_sim_q0_d0); std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[1].c_str(), cosine_sim_q0_d1); @@ -218,7 +222,7 @@ int main(int argc, char * argv[]) { llama_sampler_free(smpl); llama_free(ctx); - llama_free_model(model); + llama_model_free(model); llama_backend_free(); return 0; diff --git a/examples/imatrix/CMakeLists.txt b/examples/imatrix/CMakeLists.txt index d4c8265bd..412696c47 100644 --- a/examples/imatrix/CMakeLists.txt +++ b/examples/imatrix/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-imatrix) add_executable(${TARGET} imatrix.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/imatrix/README.md b/examples/imatrix/README.md index bb5faec94..9c056986b 100644 --- a/examples/imatrix/README.md +++ b/examples/imatrix/README.md @@ -25,8 +25,6 @@ For faster computation, make sure to use GPU offloading via the `-ngl` argument ## Example ```bash -GGML_CUDA=1 make -j - # generate importance matrix (imatrix.dat) ./llama-imatrix -m ggml-model-f16.gguf -f train-data.txt -ngl 99 diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 0e4cc8e68..99056e74c 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -1,5 +1,8 @@ +#include "arg.h" #include "common.h" +#include "log.h" #include "llama.h" +#include "gguf.h" #include #include @@ -17,12 +20,12 @@ #endif static void print_usage(int, char ** argv) { - LOG_TEE("\nexample usage:\n"); - LOG_TEE("\n %s \\\n" - " -m model.gguf -f some-text.txt [-o imatrix.gguf] [--process-output] [--verbosity 1] \\\n" + LOG("\nexample usage:\n"); + LOG("\n %s \\\n" + " -m model.gguf -f some-text.txt [-o imatrix.gguf] [--process-output] \\\n" " [--no-ppl] [--chunk 123] [--output-frequency 10] [--save-frequency 0] \\\n" " [--in-file imatrix-prev-0.gguf --in-file imatrix-prev-1.gguf ...]\n" , argv[0]); - LOG_TEE("\n"); + LOG("\n"); } static bool str_remove_suffix(std::string & str, const std::string & suffix) { @@ -45,13 +48,13 @@ struct Stats { class IMatrixCollector { public: IMatrixCollector() = default; - void set_params(gpt_params params) { m_params = std::move(params); } + void set_params(common_params params) { m_params = std::move(params); } bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data); void save_imatrix(int32_t n_chunk = -1) const; bool load_imatrix(const char * file_name); private: std::unordered_map m_stats; - gpt_params m_params; + common_params m_params; std::mutex m_mutex; int32_t m_last_chunk = 0; std::vector m_src1_data; @@ -136,16 +139,14 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * e.counts.resize(n_as, 0); } else if (e.values.size() != (size_t)src1->ne[0]*n_as) { - fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as); + LOG_ERR("%s: inconsistent size for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as); exit(1); //GGML_ABORT("fatal error"); } else if (e.counts.size() != (size_t)n_as) { - fprintf(stderr, "Oops: inconsistent expert count for %s (%d vs %d)\n", wname.c_str(), (int)e.counts.size(), (int)n_as); + LOG_ERR("Oops: inconsistent expert count for %s (%d vs %d)\n", wname.c_str(), (int)e.counts.size(), (int)n_as); exit(1); //GGML_ABORT("fatal error"); } - if (m_params.verbosity > 1) { - printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_chunk, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type); - } + LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_chunk, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type); // loop over all possible experts, regardless if they are used or not in the batch for (int ex = 0; ex < n_as; ++ex) { size_t e_start = ex*src1->ne[0]; @@ -167,7 +168,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * for (int j = 0; j < (int)src1->ne[0]; ++j) { e.values[e_start + j] = std::fma(x[j], x[j], e.values[e_start + j]); if (!std::isfinite((float)e.values[e_start + j])) { - fprintf(stderr, "%f detected in %s\n", (float)e.values[e_start + j], wname.c_str()); + LOG_ERR("%f detected in %s\n", (float)e.values[e_start + j], wname.c_str()); exit(1); } } @@ -192,16 +193,14 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * e.counts.resize(1, 0); } else if (e.values.size() != (size_t)src1->ne[0]) { - fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]); + LOG_ERR("%s: inconsistent size for %s (%d vs %d)\n", __func__, wname.c_str(), (int)e.values.size(), (int)src1->ne[0]); exit(1); //GGML_ABORT("fatal error"); } else if (e.counts.size() != 1) { - fprintf(stderr, "Oops: inconsistent expert count for %s (%d vs %d)\n", wname.c_str(), (int)e.counts.size(), 1); + LOG_ERR("Oops: inconsistent expert count for %s (%d vs %d)\n", wname.c_str(), (int)e.counts.size(), 1); exit(1); //GGML_ABORT("fatal error"); } - if (m_params.verbosity > 1) { - printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_chunk, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type); - } + LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_chunk, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type); // TODO: higher dimensions for (int row = 0; row < (int)src1->ne[1]; ++row) { const float * x = data + row * src1->ne[0]; @@ -209,7 +208,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * for (int j = 0; j < (int)src1->ne[0]; ++j) { e.values[j] = std::fma(x[j], x[j], e.values[j]); if (!std::isfinite((float)e.values[j])) { - fprintf(stderr, "%f detected in %s\n", (float)e.values[j], wname.c_str()); + LOG_ERR("%f detected in %s\n", (float)e.values[j], wname.c_str()); exit(1); } } @@ -263,17 +262,17 @@ void IMatrixCollector::save_imatrix(int32_t n_chunk) const { } if (n_zeros != 0 && is_first) { - fprintf(stderr, "\n"); + LOG_INF("\n"); is_first = false; } if (n_zeros == n_all) { - fprintf(stderr, "%s: entry '%40s' has no data - skipping\n", __func__, kv.first.c_str()); + LOG_WRN("%s: entry '%40s' has no data - skipping\n", __func__, kv.first.c_str()); continue; } if (n_zeros > 0) { - fprintf(stderr, "%s: entry '%40s' has partial data (%.2f%%) - skipping\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all); + LOG_WRN("%s: entry '%40s' has partial data (%.2f%%) - skipping\n", __func__, kv.first.c_str(), 100.0f * (n_all - n_zeros) / n_all); continue; } @@ -283,7 +282,7 @@ void IMatrixCollector::save_imatrix(int32_t n_chunk) const { } if (to_store.size() < m_stats.size()) { - fprintf(stderr, "%s: warning: storing only %zu out of %zu entries\n", __func__, to_store.size(), m_stats.size()); + LOG_WRN("%s: storing only %zu out of %zu entries\n", __func__, to_store.size(), m_stats.size()); } // deterministic tensor name order @@ -328,9 +327,8 @@ void IMatrixCollector::save_imatrix(int32_t n_chunk) const { gguf_write_to_file(ctx_gguf, fname.c_str(), false); - if (m_params.verbosity > 0) { - fprintf(stderr, "\n%s: stored collected data after %d chunks in %s\n", __func__, m_last_chunk, fname.c_str()); - } + LOGV(1, "\n"); + LOG_DBGV(1, "%s: stored collected data after %d chunks in %s\n", __func__, m_last_chunk, fname.c_str()); gguf_free(ctx_gguf); ggml_free(ctx); @@ -348,7 +346,7 @@ bool IMatrixCollector::load_imatrix(const char * file_name) { } const int32_t n_entries = gguf_get_n_tensors(ctx_gguf); if (n_entries < 1) { - fprintf(stderr, "%s: no data in file %s\n", __func__, file_name); + LOG_ERR("%s: no data in file %s\n", __func__, file_name); gguf_free(ctx_gguf); ggml_free(ctx); return false; @@ -375,7 +373,7 @@ bool IMatrixCollector::load_imatrix(const char * file_name) { // counts sums_counts_for[name].second = cur; } else { - fprintf(stderr, "%s: invalid imatrix tensor name: %s\n", __func__, name.c_str()); + LOG_ERR("%s: invalid imatrix tensor name: %s\n", __func__, name.c_str()); gguf_free(ctx_gguf); ggml_free(ctx); return false; @@ -388,7 +386,7 @@ bool IMatrixCollector::load_imatrix(const char * file_name) { const struct ggml_tensor * counts = sc.second.second; if (!sums || !counts) { - fprintf(stderr, "%s: mismatched sums and counts for %s\n", __func__, name.c_str()); + LOG_ERR("%s: mismatched sums and counts for %s\n", __func__, name.c_str()); gguf_free(ctx_gguf); ggml_free(ctx); return false; @@ -400,7 +398,7 @@ bool IMatrixCollector::load_imatrix(const char * file_name) { if (e.values.empty()) { e.values.resize(nval, 0); } else if ((size_t) nval != e.values.size()) { - fprintf(stderr, "%s: mismatched sums size for %s: %zu != %zu\n", __func__, name.c_str(), (size_t) nval, e.values.size()); + LOG_ERR("%s: mismatched sums size for %s: %zu != %zu\n", __func__, name.c_str(), (size_t) nval, e.values.size()); gguf_free(ctx_gguf); ggml_free(ctx); return false; @@ -413,7 +411,7 @@ bool IMatrixCollector::load_imatrix(const char * file_name) { // broadcast, when loading an old imatrix e.counts.resize(ncounts, e.counts[0]); } else if ((size_t) ncounts != e.counts.size()) { - fprintf(stderr, "%s: mismatched counts size for %s: %zu != %zu\n", __func__, name.c_str(), (size_t) ncounts, e.counts.size()); + LOG_ERR("%s: mismatched counts size for %s: %zu != %zu\n", __func__, name.c_str(), (size_t) ncounts, e.counts.size()); gguf_free(ctx_gguf); ggml_free(ctx); return false; @@ -511,31 +509,34 @@ static void process_logits( } } -static bool compute_imatrix(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) { - const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); - GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); +static bool compute_imatrix(llama_context * ctx, const common_params & params, const int32_t n_ctx) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const bool add_bos = llama_vocab_get_add_bos(vocab); + + GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); auto tim1 = std::chrono::high_resolution_clock::now(); - fprintf(stderr, "%s: tokenizing the input ..\n", __func__); + LOG_INF("%s: tokenizing the input ..\n", __func__); - std::vector tokens = ::llama_tokenize(ctx, params.prompt, true); + std::vector tokens = common_tokenize(ctx, params.prompt, true); auto tim2 = std::chrono::high_resolution_clock::now(); - fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast(tim2-tim1).count()); + LOG_INF("%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast(tim2-tim1).count()); if (params.i_chunk > 0) { if (size_t((params.i_chunk + 2)*n_ctx) >= tokens.size()) { - fprintf(stderr, "%s: there will be not enough tokens left after removing %d chunks\n", __func__, params.i_chunk); + LOG_ERR("%s: there will be not enough tokens left after removing %d chunks\n", __func__, params.i_chunk); return false; } - fprintf(stderr, "%s: removing initial %d chunks (%d tokens)\n", __func__, params.i_chunk, params.i_chunk*n_ctx); + LOG_INF("%s: removing initial %d chunks (%d tokens)\n", __func__, params.i_chunk, params.i_chunk*n_ctx); tokens.erase(tokens.begin(), tokens.begin() + params.i_chunk*n_ctx); } if (int(tokens.size()) < 2*n_ctx) { - fprintf(stderr, "%s: you need at least %d tokens for a context of %d tokens\n",__func__,2*n_ctx, - n_ctx); - fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size()); + LOG_ERR("%s: you need at least %d tokens for a context of %d tokens\n", __func__, 2*n_ctx, n_ctx); + LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n", __func__, tokens.size()); return false; } @@ -550,15 +551,13 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, cons const int n_chunk_max = tokens.size() / n_ctx; const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + const int n_vocab = llama_vocab_n_tokens(vocab); const int n_batch = params.n_batch; int count = 0; double nll = 0.0; double nll2 = 0.0; - std::vector workers(std::thread::hardware_concurrency() - 1); - const int num_batches = (n_ctx + n_batch - 1) / n_batch; const int n_seq = std::max(1, n_batch / n_ctx); @@ -572,7 +571,9 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, cons logits.reserve((size_t)n_ctx * n_vocab); } - fprintf(stderr, "%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq); + LOG_INF("%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq); + + std::vector workers(std::thread::hardware_concurrency() - 1); for (int i = 0; i < n_chunk; i += n_seq) { const int start = i * n_ctx; @@ -590,7 +591,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, cons const int batch_size = std::min(end - batch_start, n_batch); // clear the batch - llama_batch_clear(batch); + common_batch_clear(batch); for (int seq = 0; seq < n_seq_batch; seq++) { int seq_start = batch_start + seq*n_ctx; @@ -600,23 +601,23 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, cons // add BOS token for the first batch of each chunk if (add_bos && j == 0) { - tokens[seq_start] = llama_token_bos(llama_get_model(ctx)); + tokens[seq_start] = llama_vocab_bos(vocab); } - for (int k = 0; k < batch_size; ++k) { // NOTE: specifying all logits to get activations for the output.weight tensor // and also for the perplexity calculation. // TODO: only get outputs when (params.process_output || params.compute_ppl) // (not possible when this skips FFN computation of the last layer) - llama_batch_add(batch, tokens[seq_start + k], j*n_batch + k, { seq }, true); + common_batch_add(batch, tokens[seq_start + k], j*n_batch + k, { seq }, true); } - + // restore the original token in case it was set to BOS tokens[seq_start] = token_org; } if (llama_decode(ctx, batch)) { - fprintf(stderr, "%s : failed to eval\n", __func__); + LOG_ERR("%s : failed to eval\n", __func__); + llama_batch_free(batch); return false; } @@ -631,13 +632,13 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, cons llama_synchronize(ctx); const auto t_end = std::chrono::high_resolution_clock::now(); const float t_total = std::chrono::duration(t_end - t_start).count(); - fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total); - int total_seconds = (int)(t_total*n_chunk/n_seq); + LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total); + int total_seconds = (int)(t_total * n_chunk / n_seq); if (total_seconds >= 60*60) { - fprintf(stderr, "%d hours ", total_seconds / (60*60)); + LOG("%d hours ", total_seconds / (60*60)); total_seconds = total_seconds % (60*60); } - fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0); + LOG("%.2f minutes\n", total_seconds / 60.0); } if (params.compute_ppl) { @@ -655,14 +656,15 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, cons count += n_ctx - first - 1; - printf("[%d]%.4lf,", i + seq + 1, std::exp(nll / count)); + LOG("[%d]%.4lf,", i + seq + 1, std::exp(nll / count)); } fflush(stdout); logits.clear(); } } - printf("\n"); + + LOG("\n"); if (params.compute_ppl) { nll2 /= count; @@ -671,31 +673,34 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, cons nll2 -= nll * nll; if (nll2 > 0) { nll2 = sqrt(nll2/(count-1)); - printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl); + LOG("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl); } else { - printf("Unexpected negative standard deviation of log(prob)\n"); + LOG("Unexpected negative standard deviation of log(prob)\n"); } } + llama_batch_free(batch); + return true; } int main(int argc, char ** argv) { - gpt_params params; + common_params params; params.n_ctx = 512; params.logits_all = true; - params.verbosity = 1; + params.escape = false; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_IMATRIX, print_usage); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_IMATRIX, print_usage)) { return 1; } + common_init(); + const int32_t n_ctx = params.n_ctx; if (n_ctx <= 0) { - fprintf(stderr, "%s: imatrix tool requires '--ctx-size' > 0\n", __func__); + LOG_ERR("%s: imatrix tool requires '--ctx-size' > 0\n", __func__); return 1; } @@ -712,15 +717,15 @@ int main(int argc, char ** argv) { g_collector.set_params(params); for (const auto & in_file : params.in_files) { - printf("%s : loading imatrix from '%s'\n", __func__, in_file.c_str()); + LOG_INF("%s : loading imatrix from '%s'\n", __func__, in_file.c_str()); if (!g_collector.load_imatrix(in_file.c_str())) { - fprintf(stderr, "%s : failed to load %s\n", __func__, in_file.c_str()); + LOG_ERR("%s : failed to load %s\n", __func__, in_file.c_str()); return 1; } } if (params.in_files.size() > 1) { - printf("%s : saving combined imatrix to '%s'\n", __func__, params.out_file.c_str()); + LOG_INF("%s : saving combined imatrix to '%s'\n", __func__, params.out_file.c_str()); g_collector.save_imatrix(); } @@ -734,38 +739,45 @@ int main(int argc, char ** argv) { params.warmup = false; // init - llama_init_result llama_init = llama_init_from_gpt_params(params); + common_init_result llama_init = common_init_from_params(params); + + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; if (model == nullptr || ctx == nullptr) { - fprintf(stderr, "%s : failed to init\n", __func__); + LOG_ERR("%s : failed to init\n", __func__); return 1; } - const int n_ctx_train = llama_n_ctx_train(model); + const int n_ctx_train = llama_model_n_ctx_train(model); if (params.n_ctx > n_ctx_train) { - fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n", + LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, params.n_ctx); } // print system information { - fprintf(stderr, "\n"); - fprintf(stderr, "%s\n", gpt_params_get_system_info(params).c_str()); + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); } - if (!compute_imatrix(ctx, params, n_ctx)) { - return 1; + if (params.prompt.empty()) { + if (params.in_files.empty()) { + LOG_ERR("Error: No prompt provided and no precomputed matrices (--in-file) to combine.\n"); + return 1; + } + LOG_INF("No prompt provided; combining precomputed matrices only.\n"); + } else { + if (!compute_imatrix(ctx, params, n_ctx)) { + return 1; + } } + g_collector.save_imatrix(); - LOG_TEE("\n"); - llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); - - llama_free(ctx); - llama_free_model(model); + LOG("\n"); + llama_perf_context_print(ctx); llama_backend_free(); diff --git a/examples/infill/CMakeLists.txt b/examples/infill/CMakeLists.txt index 9b1aa3b63..fb26628d8 100644 --- a/examples/infill/CMakeLists.txt +++ b/examples/infill/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-infill) add_executable(${TARGET} infill.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/infill/README.md b/examples/infill/README.md index 810a0c5e7..df4d976f2 100644 --- a/examples/infill/README.md +++ b/examples/infill/README.md @@ -14,7 +14,7 @@ In this section, we cover the most commonly used options for running the `infill - `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`). - `-i, --interactive`: Run the program in interactive mode, allowing you to provide input directly and receive real-time responses. - `-n N, --n-predict N`: Set the number of tokens to predict when generating text. Adjusting this value can influence the length of the generated text. -- `-c N, --ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference. +- `-c N, --ctx-size N`: Set the size of the prompt context. The default is 4096, but if a LLaMA model was built with a longer context, increasing this value will provide better results for longer input/inference. - `--spm-infill`: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. ## Input Prompts diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index d06071377..489a208b6 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -1,6 +1,8 @@ +#include "arg.h" #include "common.h" - #include "console.h" +#include "sampling.h" +#include "log.h" #include "llama.h" #include @@ -33,58 +35,14 @@ static llama_context ** g_ctx; static llama_model ** g_model; -static gpt_sampler ** g_smpl; -static gpt_params * g_params; +static common_sampler ** g_smpl; +static common_params * g_params; static std::vector * g_input_tokens; static std::ostringstream * g_output_ss; static std::vector * g_output_tokens; static bool is_interacting = false; -static void write_logfile( - const llama_context * ctx, const gpt_params & params, const llama_model * model, - const std::vector & input_tokens, const std::string & output, - const std::vector & output_tokens -) { - if (params.logdir.empty()) { - return; - } - - const std::string timestamp = string_get_sortable_timestamp(); - - const bool success = fs_create_directory_with_parents(params.logdir); - if (!success) { - fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n", - __func__, params.logdir.c_str()); - return; - } - - const std::string logfile_path = params.logdir + timestamp + ".yml"; - FILE * logfile = fopen(logfile_path.c_str(), "w"); - - if (logfile == NULL) { - fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str()); - return; - } - - fprintf(logfile, "binary: infill\n"); - char model_desc[128]; - llama_model_desc(model, model_desc, sizeof(model_desc)); - yaml_dump_non_result_info(logfile, params, ctx, timestamp, input_tokens, model_desc); - - fprintf(logfile, "\n"); - fprintf(logfile, "######################\n"); - fprintf(logfile, "# Generation Results #\n"); - fprintf(logfile, "######################\n"); - fprintf(logfile, "\n"); - - yaml_dump_string_multiline(logfile, "output", output.c_str()); - yaml_dump_vector_int(logfile, "output_tokens", output_tokens); - - llama_perf_dump_yaml(logfile, ctx); - fclose(logfile); -} - #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) static void sigint_handler(int signo) { if (signo == SIGINT) { @@ -92,9 +50,13 @@ static void sigint_handler(int signo) { is_interacting = true; } else { console::cleanup(); - printf("\n"); - gpt_perf_print(*g_ctx, *g_smpl); - write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); + LOG("\n"); + common_perf_print(*g_ctx, *g_smpl); + + // make sure all logs are flushed + LOG("Interrupted by user\n"); + common_log_pause(common_log_main()); + _exit(130); } } @@ -102,142 +64,135 @@ static void sigint_handler(int signo) { #endif int main(int argc, char ** argv) { - gpt_params params; + common_params params; g_params = ¶ms; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_INFILL); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_INFILL)) { return 1; } - auto & sparams = params.sparams; + common_init(); -#ifndef LOG_DISABLE_LOGS - log_set_target(log_filename_generator("infill", "log")); - LOG_TEE("Log start\n"); - log_dump_cmdline(argc, argv); -#endif // LOG_DISABLE_LOGS + auto & sparams = params.sampling; console::init(params.simple_io, params.use_color); atexit([]() { console::cleanup(); }); if (params.logits_all) { - printf("\n************\n"); - printf("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__); - printf("************\n\n"); + LOG_ERR("\n************\n"); + LOG_ERR("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__); + LOG_ERR("************\n\n"); return 0; } if (params.embedding) { - printf("\n************\n"); - printf("%s: please use the 'embedding' tool for embedding calculations\n", __func__); - printf("************\n\n"); + LOG_ERR("\n************\n"); + LOG_ERR("%s: please use the 'embedding' tool for embedding calculations\n", __func__); + LOG_ERR("************\n\n"); return 0; } if (params.n_ctx != 0 && params.n_ctx < 8) { - LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__); + LOG_WRN("%s: minimum context size is 8, using minimum size.\n", __func__); params.n_ctx = 8; } + if (!params.interactive_first && (params.input_prefix.empty() && params.input_suffix.empty())) { - printf("\n************\n"); - printf("%s: please use '--interactive_first' or specify '--in_prefix' and/or '--in_suffix'\n", __func__); - printf("************\n\n"); + LOG_ERR("\n************\n"); + LOG_ERR("%s: please use '--interactive_first' or specify '--in_prefix' and/or '--in_suffix'\n", __func__); + LOG_ERR("************\n\n"); return 0; } if (params.rope_freq_base != 0.0) { - LOG_TEE("%s: warning: changing RoPE frequency base to %g.\n", __func__, params.rope_freq_base); + LOG_WRN("%s: changing RoPE frequency base to %g.\n", __func__, params.rope_freq_base); } if (params.rope_freq_scale != 0.0) { - LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale); + LOG_WRN("%s: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale); } - print_build_info(); - - LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed); - - LOG("%s: llama backend init\n", __func__); + LOG_INF("%s: llama backend init\n", __func__); llama_backend_init(); llama_numa_init(params.numa); llama_model * model = nullptr; llama_context * ctx = nullptr; - gpt_sampler * smpl = nullptr; + common_sampler * smpl = nullptr; g_model = &model; g_ctx = &ctx; g_smpl = &smpl; // load the model and apply lora adapter, if any - LOG("%s: load the model and apply lora adapter, if any\n", __func__); - llama_init_result llama_init = llama_init_from_gpt_params(params); + LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__); + common_init_result llama_init = common_init_from_params(params); - model = llama_init.model; - ctx = llama_init.context; + model = llama_init.model.get(); + ctx = llama_init.context.get(); if (model == NULL) { - LOG_TEE("%s: error: unable to load model\n", __func__); + LOG_ERR("%s: unable to load model\n", __func__); return 1; } - const int n_ctx_train = llama_n_ctx_train(model); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_ctx_train = llama_model_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); - LOG("n_ctx: %d\n", n_ctx); + LOG_DBG("n_ctx: %d\n", n_ctx); if (n_ctx > n_ctx_train) { - LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n", - __func__, n_ctx_train, n_ctx); + LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx); } // print system information { - LOG_TEE("\n"); - LOG_TEE("%s\n", gpt_params_get_system_info(params).c_str()); + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); } - const bool add_bos = llama_add_bos_token(model); - GGML_ASSERT(!llama_add_eos_token(model)); - LOG("add_bos: %d\n", add_bos); + const bool add_bos = llama_vocab_get_add_bos(vocab); + GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); std::vector embd_inp; std::vector embd_end; - std::vector inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false); - std::vector inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false); + std::vector inp_pfx = common_tokenize(ctx, params.input_prefix, false); + std::vector inp_sfx = common_tokenize(ctx, params.input_suffix, false); - GGML_ASSERT(llama_token_prefix(model) >= 0); - GGML_ASSERT(llama_token_suffix(model) >= 0); + GGML_ASSERT(llama_vocab_fim_pre(vocab) >= 0); + GGML_ASSERT(llama_vocab_fim_suf(vocab) >= 0); - inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model)); - inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model)); + inp_pfx.insert(inp_pfx.begin(), llama_vocab_fim_pre(vocab)); + inp_sfx.insert(inp_sfx.begin(), llama_vocab_fim_suf(vocab)); embd_inp = params.spm_infill ? inp_sfx : inp_pfx; embd_end = params.spm_infill ? inp_pfx : inp_sfx; if (add_bos) { - embd_inp.insert(embd_inp.begin(), llama_token_bos(model)); + embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); } embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); - const llama_token middle_token = llama_token_middle(model); + const llama_token middle_token = llama_vocab_fim_mid(vocab); if (middle_token >= 0) { embd_inp.push_back(middle_token); } - LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix)); - LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix)); - LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); + LOG_DBG("add_bos: %d\n", add_bos); + LOG_DBG("prefix: \"%s\"\n", params.input_prefix.c_str()); + LOG_DBG("suffix: \"%s\"\n", params.input_suffix.c_str()); + LOG_DBG("tokens: %s\n", string_from(ctx, embd_inp).c_str()); // Should not run without any tokens if (embd_inp.empty()) { - embd_inp.push_back(llama_token_bos(model)); - LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); + embd_inp.push_back(llama_vocab_bos(vocab)); + LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str()); } if ((int) embd_inp.size() > n_ctx - 4) { - LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4); + LOG_ERR("%s: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4); return 1; } @@ -246,9 +201,8 @@ int main(int argc, char ** argv) { params.n_keep = (int)embd_inp.size(); } - LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str()); - LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str()); - + LOG_INF("inp_pfx: %s\n", string_from(ctx, inp_pfx).c_str()); + LOG_INF("inp_sfx: %s\n", string_from(ctx, inp_sfx).c_str()); // enable interactive mode if interactive start is specified if (params.interactive_first) { @@ -256,21 +210,21 @@ int main(int argc, char ** argv) { } if (params.verbose_prompt) { - LOG_TEE("\n"); - LOG_TEE("%s: prompt: '%s'\n", __func__, params.prompt.c_str()); - LOG_TEE("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); + LOG_INF("\n"); + LOG_INF("%s: prompt: '%s'\n", __func__, params.prompt.c_str()); + LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); for (int i = 0; i < (int) embd_inp.size(); i++) { - LOG_TEE("%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str()); + LOG_INF("%6d -> '%s'\n", embd_inp[i], common_token_to_piece(ctx, embd_inp[i]).c_str()); } if (params.n_keep > 0) { - LOG_TEE("%s: static prompt based on n_keep: '", __func__); + LOG_INF("%s: static prompt based on n_keep: '", __func__); for (int i = 0; i < params.n_keep; i++) { - LOG_TEE("%s", llama_token_to_piece(ctx, embd_inp[i]).c_str()); + LOG_CNT("%s", common_token_to_piece(ctx, embd_inp[i]).c_str()); } - LOG_TEE("'\n"); + LOG_CNT("'\n"); } - LOG_TEE("\n"); + LOG_INF("\n"); } if (params.interactive) { @@ -287,30 +241,30 @@ int main(int argc, char ** argv) { SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); #endif - LOG_TEE("%s: interactive mode on.\n", __func__); + LOG_INF("%s: interactive mode on.\n", __func__); if (params.input_prefix_bos) { - LOG_TEE("Input prefix with BOS\n"); + LOG_INF("Input prefix with BOS\n"); } if (!params.input_prefix.empty()) { - LOG_TEE("Input prefix: '%s'\n", params.input_prefix.c_str()); + LOG_INF("Input prefix: '%s'\n", params.input_prefix.c_str()); } if (!params.input_suffix.empty()) { - LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str()); + LOG_INF("Input suffix: '%s'\n", params.input_suffix.c_str()); } } - LOG_TEE("sampling: \n%s\n", sparams.print().c_str()); - LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); - LOG_TEE("\n\n"); + smpl = common_sampler_init(model, sparams); - LOG_TEE("\n##### Infill mode #####\n\n"); - if (params.infill) { - printf("\n************\n"); - printf("no need to specify '--infill', always running infill\n"); - printf("************\n\n"); - } + LOG_INF("sampler seed: %u\n", common_sampler_get_seed(smpl)); + LOG_INF("sampler params: \n%s\n", sparams.print().c_str()); + LOG_INF("sampler chain: %s\n", common_sampler_print(smpl).c_str()); + + LOG_INF("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); + + LOG_INF("\n"); + LOG_INF("\n##### Infill mode #####\n\n"); if (params.interactive) { const char *control_message; if (params.multiline_input) { @@ -321,11 +275,11 @@ int main(int argc, char ** argv) { " - To return control without starting a new line, end your input with '/'.\n" " - If you want to submit another line, end your input with '\\'.\n"; } - LOG_TEE("== Running in interactive mode. ==\n"); + LOG_INF("== Running in interactive mode. ==\n"); #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) - LOG_TEE( " - Press Ctrl+C to interject at any time.\n"); + LOG_INF( " - Press Ctrl+C to interject at any time.\n"); #endif - LOG_TEE( "%s\n", control_message); + LOG_INF( "%s\n", control_message); is_interacting = params.interactive_first; } @@ -345,8 +299,6 @@ int main(int argc, char ** argv) { std::vector embd; - smpl = gpt_sampler_init(model, sparams); - while (n_remain != 0 || params.interactive) { // predict if (!embd.empty()) { @@ -360,9 +312,8 @@ int main(int argc, char ** argv) { embd.resize(max_embd_size); console::set_display(console::error); - printf("<>", skipped_tokens, skipped_tokens != 1 ? "s" : ""); + LOG_WRN("<>", skipped_tokens, skipped_tokens != 1 ? "s" : ""); console::set_display(console::reset); - fflush(stdout); } // infinite text generation via context swapping @@ -371,14 +322,14 @@ int main(int argc, char ** argv) { // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches if (n_past + (int) embd.size() > n_ctx) { if (params.n_predict == -2) { - LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); + LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); break; } const int n_left = n_past - params.n_keep - 1; const int n_discard = n_left/2; - LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", + LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", n_past, n_left, n_ctx, params.n_keep, n_discard); llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); @@ -386,9 +337,9 @@ int main(int argc, char ** argv) { n_past -= n_discard; - LOG("after swap: n_past = %d\n", n_past); + LOG_DBG("after swap: n_past = %d\n", n_past); - LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); + LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str()); } @@ -400,16 +351,16 @@ int main(int argc, char ** argv) { n_eval = params.n_batch; } - LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); + LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) { - LOG_TEE("%s : failed to eval\n", __func__); + if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) { + LOG_ERR("%s : failed to eval\n", __func__); return 1; } n_past += n_eval; - LOG("n_past = %d\n", n_past); + LOG_DBG("n_past = %d\n", n_past); } } @@ -417,11 +368,11 @@ int main(int argc, char ** argv) { embd.clear(); if ((int) embd_inp.size() <= n_consumed && !is_interacting) { - const llama_token id = gpt_sampler_sample(smpl, ctx, -1); + const llama_token id = common_sampler_sample(smpl, ctx, -1); - gpt_sampler_accept(smpl, id, true); + common_sampler_accept(smpl, id, true); - // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str()); + // LOG_DBG("last: %s\n", string_from(ctx, smpl->prev.to_vector()).c_str()); embd.push_back(id); @@ -431,16 +382,16 @@ int main(int argc, char ** argv) { // decrement remaining sampling budget --n_remain; - LOG("n_remain: %d\n", n_remain); + LOG_DBG("n_remain: %d\n", n_remain); } else { // some user input remains from prompt or interaction, forward it to processing - LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed); + LOG_DBG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed); while ((int) embd_inp.size() > n_consumed) { embd.push_back(embd_inp[n_consumed]); // push the prompt in the sampling context in order to apply repetition penalties later // for the prompt, we don't apply grammar rules - gpt_sampler_accept(smpl, embd_inp[n_consumed], false); + common_sampler_accept(smpl, embd_inp[n_consumed], false); ++n_consumed; if ((int) embd.size() >= params.n_batch) { @@ -452,8 +403,8 @@ int main(int argc, char ** argv) { // display text if (input_echo) { for (auto id : embd) { - const std::string token_str = llama_token_to_piece(ctx, id); - printf("%s", token_str.c_str()); + const std::string token_str = common_token_to_piece(ctx, id); + LOG("%s", token_str.c_str()); if (embd.size() > 1) { input_tokens.push_back(id); @@ -462,7 +413,6 @@ int main(int argc, char ** argv) { output_ss << token_str; } } - fflush(stdout); } // reset color to default if we there is no pending user input if (input_echo && (int) embd_inp.size() == n_consumed) { @@ -472,13 +422,12 @@ int main(int argc, char ** argv) { // if not currently processing queued inputs; if ((int) embd_inp.size() <= n_consumed) { // deal with eot token in infill mode - if ((gpt_sampler_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){ + if ((common_sampler_last(smpl) == llama_vocab_eot(vocab) || is_interacting) && params.interactive){ if (is_interacting && !params.interactive_first) { // print an eot token - printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str()); + LOG("%s", common_token_to_piece(ctx, llama_vocab_eot(vocab)).c_str()); } - fflush(stdout); - printf("\n"); + LOG("\n"); console::set_display(console::user_input); std::string buffer; std::string line; @@ -513,16 +462,16 @@ int main(int argc, char ** argv) { } // tokenize new prefix and suffix - std::vector inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false); - std::vector inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false); + std::vector inp_pfx = common_tokenize(ctx, params.input_prefix, false); + std::vector inp_sfx = common_tokenize(ctx, params.input_suffix, false); - inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model)); - inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model)); + inp_pfx.insert(inp_pfx.begin(), llama_vocab_fim_pre(vocab)); + inp_sfx.insert(inp_sfx.begin(), llama_vocab_fim_suf(vocab)); embd_inp = params.spm_infill ? inp_sfx : inp_pfx; embd_end = params.spm_infill ? inp_pfx : inp_sfx; if (add_bos) { - embd_inp.insert(embd_inp.begin(), llama_token_bos(model)); + embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); } embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); @@ -534,35 +483,33 @@ int main(int argc, char ** argv) { n_remain = params.n_predict; n_past = 0; n_consumed = 0; - // LOG_TEE("took new input\n"); is_interacting = false; } // deal with end of generation tokens in interactive mode - else if (llama_token_is_eog(model, gpt_sampler_last(smpl))) { - LOG("found EOS token\n"); + else if (llama_vocab_is_eog(vocab, common_sampler_last(smpl))) { + LOG_DBG("found EOS token\n"); if (params.interactive) { is_interacting = true; - printf("\n"); + LOG("\n"); console::set_display(console::user_input); - fflush(stdout); } } if (n_past > 0 && is_interacting && !params.interactive) { - LOG("waiting for user input\n"); + LOG_DBG("waiting for user input\n"); if (params.input_prefix_bos) { - LOG("adding input prefix BOS token\n"); - embd_inp.push_back(llama_token_bos(model)); + LOG_DBG("adding input prefix BOS token\n"); + embd_inp.push_back(llama_vocab_bos(vocab)); } std::string buffer; if (!params.input_prefix.empty()) { - LOG("appending input prefix: '%s'\n", params.input_prefix.c_str()); + LOG_DBG("appending input prefix: '%s'\n", params.input_prefix.c_str()); buffer += params.input_prefix; - printf("%s", buffer.c_str()); + LOG("%s", buffer.c_str()); } std::string line; @@ -580,30 +527,30 @@ int main(int argc, char ** argv) { if (buffer.length() > 1) { // append input suffix if any if (!params.input_suffix.empty()) { - LOG("appending input suffix: '%s'\n", params.input_suffix.c_str()); + LOG_DBG("appending input suffix: '%s'\n", params.input_suffix.c_str()); buffer += params.input_suffix; - printf("%s", params.input_suffix.c_str()); + LOG("%s", params.input_suffix.c_str()); } - LOG("buffer: '%s'\n", buffer.c_str()); + LOG_DBG("buffer: '%s'\n", buffer.c_str()); const size_t original_size = embd_inp.size(); - const auto line_inp = ::llama_tokenize(ctx, buffer, false); - LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str()); + const auto line_inp = common_tokenize(ctx, buffer, false); + LOG_DBG("input tokens: %s\n", string_from(ctx, line_inp).c_str()); embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); for (size_t i = original_size; i < embd_inp.size(); ++i) { const llama_token token = embd_inp[i]; output_tokens.push_back(token); - output_ss << llama_token_to_piece(ctx, token); + output_ss << common_token_to_piece(ctx, token); } n_remain -= line_inp.size(); - LOG("n_remain: %d\n", n_remain); + LOG_DBG("n_remain: %d\n", n_remain); } else { - LOG("empty line, passing control back\n"); + LOG_DBG("empty line, passing control back\n"); } input_echo = false; // do not echo this again @@ -611,14 +558,14 @@ int main(int argc, char ** argv) { if (n_past > 0) { if (is_interacting) { - gpt_sampler_reset(smpl); + common_sampler_reset(smpl); } is_interacting = false; } } // end of generation - if (!embd.empty() && llama_token_is_eog(model, embd.back()) && !params.interactive) { + if (!embd.empty() && llama_vocab_is_eog(vocab, embd.back()) && !params.interactive) { break; } @@ -630,23 +577,14 @@ int main(int argc, char ** argv) { } } if (!params.interactive && n_remain <= 0) { - printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str()); - fflush(stdout); + LOG("%s", common_token_to_piece(ctx, llama_vocab_eot(vocab)).c_str()); } - LOG_TEE("\n"); - gpt_perf_print(ctx, smpl); - write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); + LOG("\n"); + common_perf_print(ctx, smpl); - llama_free(ctx); - llama_free_model(model); - - gpt_sampler_free(smpl); + common_sampler_free(smpl); llama_backend_free(); -#ifndef LOG_DISABLE_LOGS - LOG_TEE("Log end\n"); -#endif // LOG_DISABLE_LOGS - return 0; } diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index a8779bf3b..fc9f0097f 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -540,7 +540,7 @@ class SchemaConverter: return self._add_rule( name, to_rule(transform()) if self._raw_pattern \ - else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space") + else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space") def _resolve_ref(self, ref): diff --git a/examples/llama-bench/CMakeLists.txt b/examples/llama-bench/CMakeLists.txt index 5bdbea4e2..17e3b9b87 100644 --- a/examples/llama-bench/CMakeLists.txt +++ b/examples/llama-bench/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-bench) add_executable(${TARGET} llama-bench.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index d7db5af72..4ac19ca86 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -6,34 +6,28 @@ #include #include #include +#include #include #include -#include #include #include #include #include #include #include -#include #include +#include +#include "common.h" #include "ggml.h" #include "llama.h" -#include "common.h" -#include "ggml-cuda.h" -#include "ggml-sycl.h" - -#ifdef GGML_USE_CANN -#include "ggml-cann.h" -#endif #ifdef _WIN32 -#define WIN32_LEAN_AND_MEAN -#ifndef NOMINMAX -# define NOMINMAX -#endif -#include +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include #endif // utils @@ -42,8 +36,7 @@ static uint64_t get_time_ns() { return std::chrono::nanoseconds(clock::now().time_since_epoch()).count(); } -template -static std::string join(const std::vector & values, const std::string & delim) { +template static std::string join(const std::vector & values, const std::string & delim) { std::ostringstream str; for (size_t i = 0; i < values.size(); i++) { str << values[i]; @@ -54,137 +47,73 @@ static std::string join(const std::vector & values, const std::string & delim return str.str(); } -template -static std::vector transform_to_str(const std::vector & values, F f) { +template static std::vector transform_to_str(const std::vector & values, F f) { std::vector str_values; std::transform(values.begin(), values.end(), std::back_inserter(str_values), f); return str_values; } -template -static T avg(const std::vector & v) { +template static T avg(const std::vector & v) { if (v.empty()) { return 0; } T sum = std::accumulate(v.begin(), v.end(), T(0)); - return sum / (T)v.size(); + return sum / (T) v.size(); } -template -static T stdev(const std::vector & v) { +template static T stdev(const std::vector & v) { if (v.size() <= 1) { return 0; } - T mean = avg(v); + T mean = avg(v); T sq_sum = std::inner_product(v.begin(), v.end(), v.begin(), T(0)); - T stdev = std::sqrt(sq_sum / (T)(v.size() - 1) - mean * mean * (T)v.size() / (T)(v.size() - 1)); + T stdev = std::sqrt(sq_sum / (T) (v.size() - 1) - mean * mean * (T) v.size() / (T) (v.size() - 1)); return stdev; } static std::string get_cpu_info() { - std::string id; -#ifdef __linux__ - FILE * f = fopen("/proc/cpuinfo", "r"); - if (f) { - char buf[1024]; - while (fgets(buf, sizeof(buf), f)) { - if (strncmp(buf, "model name", 10) == 0) { - char * p = strchr(buf, ':'); - if (p) { - p++; - while (std::isspace(*p)) { - p++; - } - while (std::isspace(p[strlen(p) - 1])) { - p[strlen(p) - 1] = '\0'; - } - id = p; - break; - } - } - } - fclose(f); - } -#elif defined(_WIN32) - HKEY hKey; - if (RegOpenKeyEx(HKEY_LOCAL_MACHINE, - TEXT("HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0"), - 0, - KEY_READ, - &hKey) != ERROR_SUCCESS) { - // fail to open registry key - return ""; - } - char cpu_brand[256]; - DWORD cpu_brand_size = sizeof(cpu_brand); - if (RegQueryValueExA(hKey, - TEXT("ProcessorNameString"), - NULL, - NULL, - (LPBYTE)cpu_brand, - &cpu_brand_size) == ERROR_SUCCESS) { - id.assign(cpu_brand, cpu_brand_size); - if (id.find('\0') != std::string::npos) { - id.resize(id.find('\0')); + std::vector cpu_list; + for (size_t i = 0; i < ggml_backend_dev_count(); i++) { + auto * dev = ggml_backend_dev_get(i); + auto dev_type = ggml_backend_dev_type(dev); + if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU || dev_type == GGML_BACKEND_DEVICE_TYPE_ACCEL) { + cpu_list.push_back(ggml_backend_dev_description(dev)); } } - RegCloseKey(hKey); -#endif - // TODO: other platforms - return id; + return join(cpu_list, ", "); } static std::string get_gpu_info() { - std::string id; -#ifdef GGML_USE_CUDA - int count = ggml_backend_cuda_get_device_count(); - for (int i = 0; i < count; i++) { - char buf[128]; - ggml_backend_cuda_get_device_description(i, buf, sizeof(buf)); - id += buf; - if (i < count - 1) { - id += "/"; + std::vector gpu_list; + for (size_t i = 0; i < ggml_backend_dev_count(); i++) { + auto * dev = ggml_backend_dev_get(i); + auto dev_type = ggml_backend_dev_type(dev); + if (dev_type == GGML_BACKEND_DEVICE_TYPE_GPU) { + gpu_list.push_back(ggml_backend_dev_description(dev)); } } -#endif -#ifdef GGML_USE_SYCL - int count = ggml_backend_sycl_get_device_count(); - for (int i = 0; i < count; i++) { - char buf[128]; - ggml_sycl_get_device_description(i, buf, sizeof(buf)); - id += buf; - if (i < count - 1) { - id += "/"; - } - } -#endif -#ifdef GGML_USE_CANN - uint32_t count = ggml_backend_cann_get_device_count(); - for (uint32_t i = 0; i < count; i++) { - char buf[128]; - ggml_backend_cann_get_device_description(i, buf, sizeof(buf)); - id += buf; - if (i < count - 1) { - id += "/"; - } - } -#endif - // TODO: other backends - return id; + return join(gpu_list, ", "); } // command line params -enum output_formats {NONE, CSV, JSON, JSONL, MARKDOWN, SQL}; +enum output_formats { NONE, CSV, JSON, JSONL, MARKDOWN, SQL }; static const char * output_format_str(output_formats format) { switch (format) { - case NONE: return "none"; - case CSV: return "csv"; - case JSON: return "json"; - case JSONL: return "jsonl"; - case MARKDOWN: return "md"; - case SQL: return "sql"; - default: GGML_ABORT("invalid output format"); + case NONE: + return "none"; + case CSV: + return "csv"; + case JSON: + return "json"; + case JSONL: + return "jsonl"; + case MARKDOWN: + return "md"; + case SQL: + return "sql"; + default: + GGML_ABORT("invalid output format"); } } @@ -209,10 +138,14 @@ static bool output_format_from_str(const std::string & s, output_formats & forma static const char * split_mode_str(llama_split_mode mode) { switch (mode) { - case LLAMA_SPLIT_MODE_NONE: return "none"; - case LLAMA_SPLIT_MODE_LAYER: return "layer"; - case LLAMA_SPLIT_MODE_ROW: return "row"; - default: GGML_ABORT("invalid split mode"); + case LLAMA_SPLIT_MODE_NONE: + return "none"; + case LLAMA_SPLIT_MODE_LAYER: + return "layer"; + case LLAMA_SPLIT_MODE_ROW: + return "row"; + default: + GGML_ABORT("invalid split mode"); } } @@ -223,59 +156,59 @@ static std::string pair_str(const std::pair & p) { } struct cmd_params { - std::vector model; - std::vector n_prompt; - std::vector n_gen; + std::vector model; + std::vector n_prompt; + std::vector n_gen; std::vector> n_pg; - std::vector n_batch; - std::vector n_ubatch; - std::vector type_k; - std::vector type_v; - std::vector n_threads; - std::vector cpu_mask; - std::vector cpu_strict; - std::vector poll; - std::vector n_gpu_layers; - std::vector rpc_servers; - std::vector split_mode; - std::vector main_gpu; - std::vector no_kv_offload; - std::vector flash_attn; - std::vector> tensor_split; - std::vector use_mmap; - std::vector embeddings; - ggml_numa_strategy numa; - int reps; - ggml_sched_priority prio; - int delay; - bool verbose; - bool progress; - output_formats output_format; - output_formats output_format_stderr; + std::vector n_batch; + std::vector n_ubatch; + std::vector type_k; + std::vector type_v; + std::vector n_threads; + std::vector cpu_mask; + std::vector cpu_strict; + std::vector poll; + std::vector n_gpu_layers; + std::vector rpc_servers; + std::vector split_mode; + std::vector main_gpu; + std::vector no_kv_offload; + std::vector flash_attn; + std::vector> tensor_split; + std::vector use_mmap; + std::vector embeddings; + ggml_numa_strategy numa; + int reps; + ggml_sched_priority prio; + int delay; + bool verbose; + bool progress; + output_formats output_format; + output_formats output_format_stderr; }; static const cmd_params cmd_params_defaults = { - /* model */ {"models/7B/ggml-model-q4_0.gguf"}, - /* n_prompt */ {512}, - /* n_gen */ {128}, + /* model */ { "models/7B/ggml-model-q4_0.gguf" }, + /* n_prompt */ { 512 }, + /* n_gen */ { 128 }, /* n_pg */ {}, - /* n_batch */ {2048}, - /* n_ubatch */ {512}, - /* type_k */ {GGML_TYPE_F16}, - /* type_v */ {GGML_TYPE_F16}, - /* n_threads */ {cpu_get_num_math()}, - /* cpu_mask */ {"0x0"}, - /* cpu_strict */ {false}, - /* poll */ {50}, - /* n_gpu_layers */ {99}, - /* rpc_servers */ {""}, - /* split_mode */ {LLAMA_SPLIT_MODE_LAYER}, - /* main_gpu */ {0}, - /* no_kv_offload */ {false}, - /* flash_attn */ {false}, - /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, - /* use_mmap */ {true}, - /* embeddings */ {false}, + /* n_batch */ { 2048 }, + /* n_ubatch */ { 512 }, + /* type_k */ { GGML_TYPE_F16 }, + /* type_v */ { GGML_TYPE_F16 }, + /* n_threads */ { cpu_get_num_math() }, + /* cpu_mask */ { "0x0" }, + /* cpu_strict */ { false }, + /* poll */ { 50 }, + /* n_gpu_layers */ { 99 }, + /* rpc_servers */ { "" }, + /* split_mode */ { LLAMA_SPLIT_MODE_LAYER }, + /* main_gpu */ { 0 }, + /* no_kv_offload */ { false }, + /* flash_attn */ { false }, + /* tensor_split */ { std::vector(llama_max_devices(), 0.0f) }, + /* use_mmap */ { true }, + /* embeddings */ { false }, /* numa */ GGML_NUMA_STRATEGY_DISABLED, /* reps */ 5, /* prio */ GGML_SCHED_PRIO_NORMAL, @@ -292,44 +225,68 @@ static void print_usage(int /* argc */, char ** argv) { printf("options:\n"); printf(" -h, --help\n"); printf(" -m, --model (default: %s)\n", join(cmd_params_defaults.model, ",").c_str()); - printf(" -p, --n-prompt (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str()); + printf(" -p, --n-prompt (default: %s)\n", + join(cmd_params_defaults.n_prompt, ",").c_str()); printf(" -n, --n-gen (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str()); - printf(" -pg (default: %s)\n", join(transform_to_str(cmd_params_defaults.n_pg, pair_str), ",").c_str()); - printf(" -b, --batch-size (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str()); - printf(" -ub, --ubatch-size (default: %s)\n", join(cmd_params_defaults.n_ubatch, ",").c_str()); - printf(" -ctk, --cache-type-k (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str()); - printf(" -ctv, --cache-type-v (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str()); - printf(" -t, --threads (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str()); - printf(" -C, --cpu-mask (default: %s)\n", join(cmd_params_defaults.cpu_mask, ",").c_str()); - printf(" --cpu-strict <0|1> (default: %s)\n", join(cmd_params_defaults.cpu_strict, ",").c_str()); + printf(" -pg (default: %s)\n", + join(transform_to_str(cmd_params_defaults.n_pg, pair_str), ",").c_str()); + printf(" -b, --batch-size (default: %s)\n", + join(cmd_params_defaults.n_batch, ",").c_str()); + printf(" -ub, --ubatch-size (default: %s)\n", + join(cmd_params_defaults.n_ubatch, ",").c_str()); + printf(" -ctk, --cache-type-k (default: %s)\n", + join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str()); + printf(" -ctv, --cache-type-v (default: %s)\n", + join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str()); + printf(" -t, --threads (default: %s)\n", + join(cmd_params_defaults.n_threads, ",").c_str()); + printf(" -C, --cpu-mask (default: %s)\n", + join(cmd_params_defaults.cpu_mask, ",").c_str()); + printf(" --cpu-strict <0|1> (default: %s)\n", + join(cmd_params_defaults.cpu_strict, ",").c_str()); printf(" --poll <0...100> (default: %s)\n", join(cmd_params_defaults.poll, ",").c_str()); - printf(" -ngl, --n-gpu-layers (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str()); -#ifdef GGML_USE_RPC - printf(" -rpc, --rpc (default: %s)\n", join(cmd_params_defaults.rpc_servers, ",").c_str()); -#endif - printf(" -sm, --split-mode (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str()); - printf(" -mg, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); - printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str()); - printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); - printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); + printf(" -ngl, --n-gpu-layers (default: %s)\n", + join(cmd_params_defaults.n_gpu_layers, ",").c_str()); + if (llama_supports_rpc()) { + printf(" -rpc, --rpc (default: %s)\n", + join(cmd_params_defaults.rpc_servers, ",").c_str()); + } + printf(" -sm, --split-mode (default: %s)\n", + join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str()); + printf(" -mg, --main-gpu (default: %s)\n", + join(cmd_params_defaults.main_gpu, ",").c_str()); + printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", + join(cmd_params_defaults.no_kv_offload, ",").c_str()); + printf(" -fa, --flash-attn <0|1> (default: %s)\n", + join(cmd_params_defaults.flash_attn, ",").c_str()); + printf(" -mmp, --mmap <0|1> (default: %s)\n", + join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" --numa (default: disabled)\n"); - printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str()); + printf(" -embd, --embeddings <0|1> (default: %s)\n", + join(cmd_params_defaults.embeddings, ",").c_str()); printf(" -ts, --tensor-split (default: 0)\n"); printf(" -r, --repetitions (default: %d)\n", cmd_params_defaults.reps); printf(" --prio <0|1|2|3> (default: %d)\n", cmd_params_defaults.prio); printf(" --delay <0...N> (seconds) (default: %d)\n", cmd_params_defaults.delay); - printf(" -o, --output (default: %s)\n", output_format_str(cmd_params_defaults.output_format)); - printf(" -oe, --output-err (default: %s)\n", output_format_str(cmd_params_defaults.output_format_stderr)); + printf(" -o, --output (default: %s)\n", + output_format_str(cmd_params_defaults.output_format)); + printf(" -oe, --output-err (default: %s)\n", + output_format_str(cmd_params_defaults.output_format_stderr)); printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0"); printf(" --progress (default: %s)\n", cmd_params_defaults.progress ? "1" : "0"); printf("\n"); - printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n"); + printf( + "Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter " + "multiple times.\n"); } static ggml_type ggml_type_from_name(const std::string & s) { if (s == "f16") { return GGML_TYPE_F16; } + if (s == "bf16") { + return GGML_TYPE_BF16; + } if (s == "q8_0") { return GGML_TYPE_Q8_0; } @@ -352,22 +309,21 @@ static ggml_type ggml_type_from_name(const std::string & s) { return GGML_TYPE_COUNT; } - static cmd_params parse_cmd_params(int argc, char ** argv) { - cmd_params params; - std::string arg; - bool invalid_param = false; - const std::string arg_prefix = "--"; - const char split_delim = ','; + cmd_params params; + std::string arg; + bool invalid_param = false; + const std::string arg_prefix = "--"; + const char split_delim = ','; - params.verbose = cmd_params_defaults.verbose; - params.output_format = cmd_params_defaults.output_format; + params.verbose = cmd_params_defaults.verbose; + params.output_format = cmd_params_defaults.output_format; params.output_format_stderr = cmd_params_defaults.output_format_stderr; - params.reps = cmd_params_defaults.reps; - params.numa = cmd_params_defaults.numa; - params.prio = cmd_params_defaults.prio; - params.delay = cmd_params_defaults.delay; - params.progress = cmd_params_defaults.progress; + params.reps = cmd_params_defaults.reps; + params.numa = cmd_params_defaults.numa; + params.prio = cmd_params_defaults.prio; + params.delay = cmd_params_defaults.delay; + params.progress = cmd_params_defaults.progress; for (int i = 1; i < argc; i++) { arg = argv[i]; @@ -409,7 +365,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { invalid_param = true; break; } - params.n_pg.push_back({std::stoi(p[0]), std::stoi(p[1])}); + params.n_pg.push_back({ std::stoi(p[0]), std::stoi(p[1]) }); } else if (arg == "-b" || arg == "--batch-size") { if (++i >= argc) { invalid_param = true; @@ -429,7 +385,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { invalid_param = true; break; } - auto p = string_split(argv[i], split_delim); + auto p = string_split(argv[i], split_delim); std::vector types; for (const auto & t : p) { ggml_type gt = ggml_type_from_name(t); @@ -439,13 +395,16 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } types.push_back(gt); } + if (invalid_param) { + break; + } params.type_k.insert(params.type_k.end(), types.begin(), types.end()); } else if (arg == "-ctv" || arg == "--cache-type-v") { if (++i >= argc) { invalid_param = true; break; } - auto p = string_split(argv[i], split_delim); + auto p = string_split(argv[i], split_delim); std::vector types; for (const auto & t : p) { ggml_type gt = ggml_type_from_name(t); @@ -455,6 +414,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } types.push_back(gt); } + if (invalid_param) { + break; + } params.type_v.insert(params.type_v.end(), types.begin(), types.end()); } else if (arg == "-t" || arg == "--threads") { if (++i >= argc) { @@ -491,20 +453,18 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = string_split(argv[i], split_delim); params.n_gpu_layers.insert(params.n_gpu_layers.end(), p.begin(), p.end()); -#ifdef GGML_USE_RPC - } else if (arg == "-rpc" || arg == "--rpc") { + } else if (llama_supports_rpc() && (arg == "-rpc" || arg == "--rpc")) { if (++i >= argc) { invalid_param = true; break; } params.rpc_servers.push_back(argv[i]); -#endif } else if (arg == "-sm" || arg == "--split-mode") { if (++i >= argc) { invalid_param = true; break; } - auto p = string_split(argv[i], split_delim); + auto p = string_split(argv[i], split_delim); std::vector modes; for (const auto & m : p) { llama_split_mode mode; @@ -520,6 +480,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } modes.push_back(mode); } + if (invalid_param) { + break; + } params.split_mode.insert(params.split_mode.end(), modes.begin(), modes.end()); } else if (arg == "-mg" || arg == "--main-gpu") { if (++i >= argc) { @@ -540,10 +503,16 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } else { std::string value(argv[i]); - /**/ if (value == "distribute" || value == "" ) { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; } - else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; } - else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; } - else { invalid_param = true; break; } + /**/ if (value == "distribute" || value == "") { + params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; + } else if (value == "isolate") { + params.numa = GGML_NUMA_STRATEGY_ISOLATE; + } else if (value == "numactl") { + params.numa = GGML_NUMA_STRATEGY_NUMACTL; + } else { + invalid_param = true; + break; + } } } else if (arg == "-fa" || arg == "--flash-attn") { if (++i >= argc) { @@ -573,9 +542,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } for (auto ts : string_split(argv[i], split_delim)) { // split string by ; and / - const std::regex regex{R"([;/]+)"}; - std::sregex_token_iterator it{ts.begin(), ts.end(), regex, -1}; - std::vector split_arg{it, {}}; + const std::regex regex{ R"([;/]+)" }; + std::sregex_token_iterator it{ ts.begin(), ts.end(), regex, -1 }; + std::vector split_arg{ it, {} }; GGML_ASSERT(split_arg.size() <= llama_max_devices()); std::vector tensor_split(llama_max_devices()); @@ -634,89 +603,156 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } // set defaults - if (params.model.empty()) { params.model = cmd_params_defaults.model; } - if (params.n_prompt.empty()) { params.n_prompt = cmd_params_defaults.n_prompt; } - if (params.n_gen.empty()) { params.n_gen = cmd_params_defaults.n_gen; } - if (params.n_pg.empty()) { params.n_pg = cmd_params_defaults.n_pg; } - if (params.n_batch.empty()) { params.n_batch = cmd_params_defaults.n_batch; } - if (params.n_ubatch.empty()) { params.n_ubatch = cmd_params_defaults.n_ubatch; } - if (params.type_k.empty()) { params.type_k = cmd_params_defaults.type_k; } - if (params.type_v.empty()) { params.type_v = cmd_params_defaults.type_v; } - if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; } - if (params.rpc_servers.empty()) { params.rpc_servers = cmd_params_defaults.rpc_servers; } - if (params.split_mode.empty()) { params.split_mode = cmd_params_defaults.split_mode; } - if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; } - if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; } - if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; } - if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; } - if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; } - if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; } - if (params.n_threads.empty()) { params.n_threads = cmd_params_defaults.n_threads; } - if (params.cpu_mask.empty()) { params.cpu_mask = cmd_params_defaults.cpu_mask; } - if (params.cpu_strict.empty()) { params.cpu_strict = cmd_params_defaults.cpu_strict; } - if (params.poll.empty()) { params.poll = cmd_params_defaults.poll; } + if (params.model.empty()) { + params.model = cmd_params_defaults.model; + } + if (params.n_prompt.empty()) { + params.n_prompt = cmd_params_defaults.n_prompt; + } + if (params.n_gen.empty()) { + params.n_gen = cmd_params_defaults.n_gen; + } + if (params.n_pg.empty()) { + params.n_pg = cmd_params_defaults.n_pg; + } + if (params.n_batch.empty()) { + params.n_batch = cmd_params_defaults.n_batch; + } + if (params.n_ubatch.empty()) { + params.n_ubatch = cmd_params_defaults.n_ubatch; + } + if (params.type_k.empty()) { + params.type_k = cmd_params_defaults.type_k; + } + if (params.type_v.empty()) { + params.type_v = cmd_params_defaults.type_v; + } + if (params.n_gpu_layers.empty()) { + params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; + } + if (params.rpc_servers.empty()) { + params.rpc_servers = cmd_params_defaults.rpc_servers; + } + if (params.split_mode.empty()) { + params.split_mode = cmd_params_defaults.split_mode; + } + if (params.main_gpu.empty()) { + params.main_gpu = cmd_params_defaults.main_gpu; + } + if (params.no_kv_offload.empty()) { + params.no_kv_offload = cmd_params_defaults.no_kv_offload; + } + if (params.flash_attn.empty()) { + params.flash_attn = cmd_params_defaults.flash_attn; + } + if (params.tensor_split.empty()) { + params.tensor_split = cmd_params_defaults.tensor_split; + } + if (params.use_mmap.empty()) { + params.use_mmap = cmd_params_defaults.use_mmap; + } + if (params.embeddings.empty()) { + params.embeddings = cmd_params_defaults.embeddings; + } + if (params.n_threads.empty()) { + params.n_threads = cmd_params_defaults.n_threads; + } + if (params.cpu_mask.empty()) { + params.cpu_mask = cmd_params_defaults.cpu_mask; + } + if (params.cpu_strict.empty()) { + params.cpu_strict = cmd_params_defaults.cpu_strict; + } + if (params.poll.empty()) { + params.poll = cmd_params_defaults.poll; + } return params; } struct cmd_params_instance { - std::string model; - int n_prompt; - int n_gen; - int n_batch; - int n_ubatch; - ggml_type type_k; - ggml_type type_v; - int n_threads; - std::string cpu_mask; - bool cpu_strict; - int poll; - int n_gpu_layers; - std::string rpc_servers; - llama_split_mode split_mode; - int main_gpu; - bool no_kv_offload; - bool flash_attn; + std::string model; + int n_prompt; + int n_gen; + int n_batch; + int n_ubatch; + ggml_type type_k; + ggml_type type_v; + int n_threads; + std::string cpu_mask; + bool cpu_strict; + int poll; + int n_gpu_layers; + std::string rpc_servers_str; + llama_split_mode split_mode; + int main_gpu; + bool no_kv_offload; + bool flash_attn; std::vector tensor_split; - bool use_mmap; - bool embeddings; + bool use_mmap; + bool embeddings; llama_model_params to_llama_mparams() const { llama_model_params mparams = llama_model_default_params(); mparams.n_gpu_layers = n_gpu_layers; - if (!rpc_servers.empty()) { - mparams.rpc_servers = rpc_servers.c_str(); + if (!rpc_servers_str.empty()) { + auto rpc_servers = string_split(rpc_servers_str, ','); + + // add RPC devices + if (!rpc_servers.empty()) { + ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC"); + if (!rpc_reg) { + fprintf(stderr, "%s: failed to find RPC backend\n", __func__); + exit(1); + } + + typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint); + ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device"); + if (!ggml_backend_rpc_add_device_fn) { + fprintf(stderr, "%s: failed to find RPC device add function\n", __func__); + exit(1); + } + static std::vector devices; + devices.clear(); + for (const std::string & server : rpc_servers) { + ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str()); + if (dev) { + devices.push_back(dev); + } else { + fprintf(stderr, "%s: failed to add RPC device for server '%s'\n", __func__, server.c_str()); + exit(1); + } + } + devices.push_back(nullptr); + mparams.devices = devices.data(); + } } - mparams.split_mode = split_mode; - mparams.main_gpu = main_gpu; + mparams.split_mode = split_mode; + mparams.main_gpu = main_gpu; mparams.tensor_split = tensor_split.data(); - mparams.use_mmap = use_mmap; + mparams.use_mmap = use_mmap; return mparams; } bool equal_mparams(const cmd_params_instance & other) const { - return model == other.model && - n_gpu_layers == other.n_gpu_layers && - rpc_servers == other.rpc_servers && - split_mode == other.split_mode && - main_gpu == other.main_gpu && - use_mmap == other.use_mmap && + return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers_str == other.rpc_servers_str && + split_mode == other.split_mode && main_gpu == other.main_gpu && use_mmap == other.use_mmap && tensor_split == other.tensor_split; } llama_context_params to_llama_cparams() const { llama_context_params cparams = llama_context_default_params(); - cparams.n_ctx = n_prompt + n_gen; - cparams.n_batch = n_batch; - cparams.n_ubatch = n_ubatch; - cparams.type_k = type_k; - cparams.type_v = type_v; + cparams.n_ctx = n_prompt + n_gen; + cparams.n_batch = n_batch; + cparams.n_ubatch = n_ubatch; + cparams.type_k = type_k; + cparams.type_v = type_v; cparams.offload_kqv = !no_kv_offload; - cparams.flash_attn = flash_attn; - cparams.embeddings = embeddings; + cparams.flash_attn = flash_attn; + cparams.embeddings = embeddings; return cparams; } @@ -726,6 +762,7 @@ static std::vector get_cmd_params_instances(const cmd_param std::vector instances; // this ordering minimizes the number of times that each model needs to be reloaded + // clang-format off for (const auto & m : params.model) for (const auto & nl : params.n_gpu_layers) for (const auto & rpc : params.rpc_servers) @@ -831,165 +868,125 @@ static std::vector get_cmd_params_instances(const cmd_param instances.push_back(instance); } } + // clang-format on return instances; } struct test { static const std::string build_commit; - static const int build_number; - static const bool cuda; - static const bool vulkan; - static const bool kompute; - static const bool metal; - static const bool sycl; - static const bool gpu_blas; - static const bool blas; + static const int build_number; static const std::string cpu_info; static const std::string gpu_info; - std::string model_filename; - std::string model_type; - uint64_t model_size; - uint64_t model_n_params; - int n_batch; - int n_ubatch; - int n_threads; - std::string cpu_mask; - bool cpu_strict; - int poll; - bool has_rpc; - ggml_type type_k; - ggml_type type_v; - int n_gpu_layers; - llama_split_mode split_mode; - int main_gpu; - bool no_kv_offload; - bool flash_attn; - std::vector tensor_split; - bool use_mmap; - bool embeddings; - int n_prompt; - int n_gen; - std::string test_time; - std::vector samples_ns; + std::string model_filename; + std::string model_type; + uint64_t model_size; + uint64_t model_n_params; + int n_batch; + int n_ubatch; + int n_threads; + std::string cpu_mask; + bool cpu_strict; + int poll; + ggml_type type_k; + ggml_type type_v; + int n_gpu_layers; + llama_split_mode split_mode; + int main_gpu; + bool no_kv_offload; + bool flash_attn; + std::vector tensor_split; + bool use_mmap; + bool embeddings; + int n_prompt; + int n_gen; + std::string test_time; + std::vector samples_ns; test(const cmd_params_instance & inst, const llama_model * lmodel, const llama_context * ctx) { model_filename = inst.model; char buf[128]; llama_model_desc(lmodel, buf, sizeof(buf)); - model_type = buf; - model_size = llama_model_size(lmodel); + model_type = buf; + model_size = llama_model_size(lmodel); model_n_params = llama_model_n_params(lmodel); - n_batch = inst.n_batch; - n_ubatch = inst.n_ubatch; - n_threads = inst.n_threads; - cpu_mask = inst.cpu_mask; - cpu_strict = inst.cpu_strict; - poll = inst.poll; - has_rpc = !inst.rpc_servers.empty(); - type_k = inst.type_k; - type_v = inst.type_v; - n_gpu_layers = inst.n_gpu_layers; - split_mode = inst.split_mode; - main_gpu = inst.main_gpu; - no_kv_offload = inst.no_kv_offload; - flash_attn = inst.flash_attn; - tensor_split = inst.tensor_split; - use_mmap = inst.use_mmap; - embeddings = inst.embeddings; - n_prompt = inst.n_prompt; - n_gen = inst.n_gen; + n_batch = inst.n_batch; + n_ubatch = inst.n_ubatch; + n_threads = inst.n_threads; + cpu_mask = inst.cpu_mask; + cpu_strict = inst.cpu_strict; + poll = inst.poll; + type_k = inst.type_k; + type_v = inst.type_v; + n_gpu_layers = inst.n_gpu_layers; + split_mode = inst.split_mode; + main_gpu = inst.main_gpu; + no_kv_offload = inst.no_kv_offload; + flash_attn = inst.flash_attn; + tensor_split = inst.tensor_split; + use_mmap = inst.use_mmap; + embeddings = inst.embeddings; + n_prompt = inst.n_prompt; + n_gen = inst.n_gen; // RFC 3339 date-time format - time_t t = time(NULL); + time_t t = time(NULL); std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t)); test_time = buf; (void) ctx; } - uint64_t avg_ns() const { - return ::avg(samples_ns); - } + uint64_t avg_ns() const { return ::avg(samples_ns); } - uint64_t stdev_ns() const { - return ::stdev(samples_ns); - } + uint64_t stdev_ns() const { return ::stdev(samples_ns); } std::vector get_ts() const { - int n_tokens = n_prompt + n_gen; + int n_tokens = n_prompt + n_gen; std::vector ts; - std::transform(samples_ns.begin(), samples_ns.end(), std::back_inserter(ts), [n_tokens](uint64_t t) { return 1e9 * n_tokens / t; }); + std::transform(samples_ns.begin(), samples_ns.end(), std::back_inserter(ts), + [n_tokens](uint64_t t) { return 1e9 * n_tokens / t; }); return ts; } - double avg_ts() const { - return ::avg(get_ts()); - } + double avg_ts() const { return ::avg(get_ts()); } - double stdev_ts() const { - return ::stdev(get_ts()); - } + double stdev_ts() const { return ::stdev(get_ts()); } static std::string get_backend() { - if (cuda) { - return GGML_CUDA_NAME; + std::vector backends; + for (size_t i = 0; i < ggml_backend_reg_count(); i++) { + auto * reg = ggml_backend_reg_get(i); + std::string name = ggml_backend_reg_name(reg); + if (name != "CPU") { + backends.push_back(ggml_backend_reg_name(reg)); + } } - if (vulkan) { - return "Vulkan"; - } - if (kompute) { - return "Kompute"; - } - if (metal) { - return "Metal"; - } - if (sycl) { - return GGML_SYCL_NAME; - } - if (gpu_blas) { - return "GPU BLAS"; - } - if (blas) { - return "BLAS"; - } - - return "CPU"; + return backends.empty() ? "CPU" : join(backends, ","); } static const std::vector & get_fields() { static const std::vector fields = { - "build_commit", "build_number", - "cuda", "vulkan", "kompute", "metal", "sycl", "rpc", "gpu_blas", "blas", - "cpu_info", "gpu_info", - "model_filename", "model_type", "model_size", "model_n_params", - "n_batch", "n_ubatch", - "n_threads", "cpu_mask", "cpu_strict", "poll", - "type_k", "type_v", - "n_gpu_layers", "split_mode", - "main_gpu", "no_kv_offload", "flash_attn", - "tensor_split", "use_mmap", "embeddings", - "n_prompt", "n_gen", "test_time", - "avg_ns", "stddev_ns", - "avg_ts", "stddev_ts", + "build_commit", "build_number", "cpu_info", "gpu_info", "backends", "model_filename", + "model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads", + "cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers", + "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "use_mmap", + "embeddings", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", + "avg_ts", "stddev_ts", }; return fields; } - enum field_type {STRING, BOOL, INT, FLOAT}; + enum field_type { STRING, BOOL, INT, FLOAT }; static field_type get_field_type(const std::string & field) { - if (field == "build_number" || field == "n_batch" || field == "n_ubatch" || - field == "n_threads" || field == "poll" || - field == "model_size" || field == "model_n_params" || - field == "n_gpu_layers" || field == "main_gpu" || - field == "n_prompt" || field == "n_gen" || - field == "avg_ns" || field == "stddev_ns") { + if (field == "build_number" || field == "n_batch" || field == "n_ubatch" || field == "n_threads" || + field == "poll" || field == "model_size" || field == "model_n_params" || field == "n_gpu_layers" || + field == "main_gpu" || field == "n_prompt" || field == "n_gen" || field == "avg_ns" || + field == "stddev_ns") { return INT; } - if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || - field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || - field == "cpu_strict" || - field == "flash_attn" || field == "use_mmap" || field == "embeddings") { + if (field == "f16_kv" || field == "no_kv_offload" || field == "cpu_strict" || field == "flash_attn" || + field == "use_mmap" || field == "embeddings") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -1000,7 +997,7 @@ struct test { std::vector get_values() const { std::string tensor_split_str; - int max_nonzero = 0; + int max_nonzero = 0; for (size_t i = 0; i < llama_max_devices(); i++) { if (tensor_split[i] > 0) { max_nonzero = i; @@ -1014,44 +1011,53 @@ struct test { tensor_split_str += "/"; } } - std::vector values = { - build_commit, std::to_string(build_number), - std::to_string(cuda), std::to_string(vulkan), std::to_string(vulkan), - std::to_string(metal), std::to_string(sycl), std::to_string(has_rpc), std::to_string(gpu_blas), std::to_string(blas), - cpu_info, gpu_info, - model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params), - std::to_string(n_batch), std::to_string(n_ubatch), - std::to_string(n_threads), cpu_mask, std::to_string(cpu_strict), std::to_string(poll), - ggml_type_name(type_k), ggml_type_name(type_v), - std::to_string(n_gpu_layers), split_mode_str(split_mode), - std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), - tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), - std::to_string(n_prompt), std::to_string(n_gen), test_time, - std::to_string(avg_ns()), std::to_string(stdev_ns()), - std::to_string(avg_ts()), std::to_string(stdev_ts()) - }; + std::vector values = { build_commit, + std::to_string(build_number), + cpu_info, + gpu_info, + get_backend(), + model_filename, + model_type, + std::to_string(model_size), + std::to_string(model_n_params), + std::to_string(n_batch), + std::to_string(n_ubatch), + std::to_string(n_threads), + cpu_mask, + std::to_string(cpu_strict), + std::to_string(poll), + ggml_type_name(type_k), + ggml_type_name(type_v), + std::to_string(n_gpu_layers), + split_mode_str(split_mode), + std::to_string(main_gpu), + std::to_string(no_kv_offload), + std::to_string(flash_attn), + tensor_split_str, + std::to_string(use_mmap), + std::to_string(embeddings), + std::to_string(n_prompt), + std::to_string(n_gen), + test_time, + std::to_string(avg_ns()), + std::to_string(stdev_ns()), + std::to_string(avg_ts()), + std::to_string(stdev_ts()) }; return values; } std::map get_map() const { std::map map; - auto fields = get_fields(); - auto values = get_values(); - std::transform(fields.begin(), fields.end(), values.begin(), - std::inserter(map, map.end()), std::make_pair); + auto fields = get_fields(); + auto values = get_values(); + std::transform(fields.begin(), fields.end(), values.begin(), std::inserter(map, map.end()), + std::make_pair); return map; } }; const std::string test::build_commit = LLAMA_COMMIT; const int test::build_number = LLAMA_BUILD_NUMBER; -const bool test::cuda = !!ggml_cpu_has_cuda(); -const bool test::vulkan = !!ggml_cpu_has_vulkan(); -const bool test::kompute = !!ggml_cpu_has_kompute(); -const bool test::metal = !!ggml_cpu_has_metal(); -const bool test::gpu_blas = !!ggml_cpu_has_gpublas(); -const bool test::blas = !!ggml_cpu_has_blas(); -const bool test::sycl = !!ggml_cpu_has_sycl(); const std::string test::cpu_info = get_cpu_info(); const std::string test::gpu_info = get_gpu_info(); @@ -1059,9 +1065,12 @@ struct printer { virtual ~printer() {} FILE * fout; + virtual void print_header(const cmd_params & params) { (void) params; } + virtual void print_test(const test & t) = 0; - virtual void print_footer() { } + + virtual void print_footer() {} }; struct csv_printer : public printer { @@ -1077,7 +1086,7 @@ struct csv_printer : public printer { return escaped; } - void print_header(const cmd_params & params) override { + void print_header(const cmd_params & params) override { std::vector fields = test::get_fields(); fprintf(fout, "%s\n", join(fields, ",").c_str()); (void) params; @@ -1090,7 +1099,6 @@ struct csv_printer : public printer { } }; - static std::string escape_json(const std::string & value) { std::string escaped; for (auto c : value) { @@ -1098,7 +1106,7 @@ static std::string escape_json(const std::string & value) { escaped += "\\\""; } else if (c == '\\') { escaped += "\\\\"; - } else if (c <= 0x1f) { + } else if (c <= 0x1f) { char buf[8]; snprintf(buf, sizeof(buf), "\\u%04x", c); escaped += buf; @@ -1131,7 +1139,8 @@ struct json_printer : public printer { void print_fields(const std::vector & fields, const std::vector & values) { assert(fields.size() == values.size()); for (size_t i = 0; i < fields.size(); i++) { - fprintf(fout, " \"%s\": %s,\n", fields.at(i).c_str(), format_json_value(fields.at(i), values.at(i)).c_str()); + fprintf(fout, " \"%s\": %s,\n", fields.at(i).c_str(), + format_json_value(fields.at(i), values.at(i)).c_str()); } } @@ -1149,12 +1158,9 @@ struct json_printer : public printer { fflush(fout); } - void print_footer() override { - fprintf(fout, "\n]\n"); - } + void print_footer() override { fprintf(fout, "\n]\n"); } }; - struct jsonl_printer : public printer { void print_fields(const std::vector & fields, const std::vector & values) { assert(fields.size() == values.size()); @@ -1214,7 +1220,7 @@ struct markdown_printer : public printer { return 13; } - int width = std::max((int)field.length(), 10); + int width = std::max((int) field.length(), 10); if (test::get_field_type(field) == test::STRING) { return -width; @@ -1256,7 +1262,8 @@ struct markdown_printer : public printer { fields.emplace_back("size"); fields.emplace_back("params"); fields.emplace_back("backend"); - bool is_cpu_backend = test::get_backend() == "CPU" || test::get_backend() == "BLAS"; + bool is_cpu_backend = test::get_backend().find("CPU") != std::string::npos || + test::get_backend().find("BLAS") != std::string::npos; if (!is_cpu_backend) { fields.emplace_back("n_gpu_layers"); } @@ -1327,18 +1334,18 @@ struct markdown_printer : public printer { fprintf(fout, "|"); for (const auto & field : fields) { std::string value; - char buf[128]; + char buf[128]; if (field == "model") { value = t.model_type; } else if (field == "size") { - if (t.model_size < 1024*1024*1024) { + if (t.model_size < 1024 * 1024 * 1024) { snprintf(buf, sizeof(buf), "%.2f MiB", t.model_size / 1024.0 / 1024.0); } else { snprintf(buf, sizeof(buf), "%.2f GiB", t.model_size / 1024.0 / 1024.0 / 1024.0); } value = buf; } else if (field == "params") { - if (t.model_n_params < 1000*1000*1000) { + if (t.model_n_params < 1000 * 1000 * 1000) { snprintf(buf, sizeof(buf), "%.2f M", t.model_n_params / 1e6); } else { snprintf(buf, sizeof(buf), "%.2f B", t.model_n_params / 1e9); @@ -1346,9 +1353,6 @@ struct markdown_printer : public printer { value = buf; } else if (field == "backend") { value = test::get_backend(); - if (t.has_rpc) { - value += "+RPC"; - } } else if (field == "test") { if (t.n_prompt > 0 && t.n_gen == 0) { snprintf(buf, sizeof(buf), "pp%d", t.n_prompt); @@ -1403,7 +1407,8 @@ struct sql_printer : public printer { std::vector fields = test::get_fields(); fprintf(fout, "CREATE TABLE IF NOT EXISTS test (\n"); for (size_t i = 0; i < fields.size(); i++) { - fprintf(fout, " %s %s%s\n", fields.at(i).c_str(), get_sql_field_type(fields.at(i)).c_str(), i < fields.size() - 1 ? "," : ""); + fprintf(fout, " %s %s%s\n", fields.at(i).c_str(), get_sql_field_type(fields.at(i)).c_str(), + i < fields.size() - 1 ? "," : ""); } fprintf(fout, ");\n"); fprintf(fout, "\n"); @@ -1421,11 +1426,12 @@ struct sql_printer : public printer { } }; -static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) { +static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) { llama_set_n_threads(ctx, n_threads, n_threads); - const llama_model * model = llama_get_model(ctx); - const int32_t n_vocab = llama_n_vocab(model); + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + const int32_t n_vocab = llama_vocab_n_tokens(vocab); std::vector tokens(n_batch); @@ -1433,27 +1439,28 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat while (n_processed < n_prompt) { int n_tokens = std::min(n_prompt - n_processed, n_batch); - tokens[0] = n_processed == 0 && llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab; + tokens[0] = n_processed == 0 && llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab; for (int i = 1; i < n_tokens; i++) { tokens[i] = std::rand() % n_vocab; } - llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0)); + llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens)); n_processed += n_tokens; } llama_synchronize(ctx); } -static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) { +static void test_gen(llama_context * ctx, int n_gen, int n_threads) { llama_set_n_threads(ctx, n_threads, n_threads); - const llama_model * model = llama_get_model(ctx); - const int32_t n_vocab = llama_n_vocab(model); + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + const int32_t n_vocab = llama_vocab_n_tokens(vocab); - llama_token token = llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab; + llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab; for (int i = 0; i < n_gen; i++) { - llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0)); + llama_decode(ctx, llama_batch_get_one(&token, 1)); llama_synchronize(ctx); token = std::rand() % n_vocab; } @@ -1501,6 +1508,17 @@ int main(int argc, char ** argv) { cmd_params params = parse_cmd_params(argc, argv); + // initialize backends + ggml_backend_load_all(); + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!cpu_dev) { + fprintf(stderr, "%s: error: CPU backend is not loaded\n", __func__); + return 1; + } + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto * ggml_threadpool_new_fn = (decltype(ggml_threadpool_new) *) ggml_backend_reg_get_proc_address(cpu_reg, "ggml_threadpool_new"); + auto * ggml_threadpool_free_fn = (decltype(ggml_threadpool_free) *) ggml_backend_reg_get_proc_address(cpu_reg, "ggml_threadpool_free"); + // initialize llama.cpp if (!params.verbose) { llama_log_set(llama_null_log_callback, NULL); @@ -1511,7 +1529,7 @@ int main(int argc, char ** argv) { set_process_priority(params.prio); // initialize printer - std::unique_ptr p = create_printer(params.output_format); + std::unique_ptr p = create_printer(params.output_format); std::unique_ptr p_err = create_printer(params.output_format_stderr); if (p) { @@ -1526,23 +1544,23 @@ int main(int argc, char ** argv) { std::vector params_instances = get_cmd_params_instances(params); - llama_model * lmodel = nullptr; + llama_model * lmodel = nullptr; const cmd_params_instance * prev_inst = nullptr; - int params_idx = 0; + int params_idx = 0; auto params_count = params_instances.size(); for (const auto & inst : params_instances) { - params_idx ++; + params_idx++; if (params.progress) { - fprintf(stderr, "llama-bench: benchmark %d/%ld: starting\n", params_idx, params_count); + fprintf(stderr, "llama-bench: benchmark %d/%zu: starting\n", params_idx, params_count); } // keep the same model between tests when possible if (!lmodel || !prev_inst || !inst.equal_mparams(*prev_inst)) { if (lmodel) { - llama_free_model(lmodel); + llama_model_free(lmodel); } - lmodel = llama_load_model_from_file(inst.model.c_str(), inst.to_llama_mparams()); + lmodel = llama_model_load_from_file(inst.model.c_str(), inst.to_llama_mparams()); if (lmodel == NULL) { fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, inst.model.c_str()); return 1; @@ -1550,10 +1568,10 @@ int main(int argc, char ** argv) { prev_inst = &inst; } - llama_context * ctx = llama_new_context_with_model(lmodel, inst.to_llama_cparams()); + llama_context * ctx = llama_init_from_model(lmodel, inst.to_llama_cparams()); if (ctx == NULL) { fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, inst.model.c_str()); - llama_free_model(lmodel); + llama_model_free(lmodel); return 1; } @@ -1575,7 +1593,7 @@ int main(int argc, char ** argv) { tpp.poll = t.poll; tpp.prio = params.prio; - struct ggml_threadpool* threadpool = ggml_threadpool_new(&tpp); + struct ggml_threadpool * threadpool = ggml_threadpool_new_fn(&tpp); if (!threadpool) { fprintf(stderr, "%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads); exit(1); @@ -1586,16 +1604,16 @@ int main(int argc, char ** argv) { // warmup run if (t.n_prompt > 0) { if (params.progress) { - fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup prompt run\n", params_idx, params_count); + fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count); } //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads); - test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads); + test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads); } if (t.n_gen > 0) { if (params.progress) { - fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup generation run\n", params_idx, params_count); + fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup generation run\n", params_idx, params_count); } - test_gen(ctx, 1, 0, t.n_threads); + test_gen(ctx, 1, t.n_threads); } for (int i = 0; i < params.reps; i++) { @@ -1605,15 +1623,17 @@ int main(int argc, char ** argv) { if (t.n_prompt > 0) { if (params.progress) { - fprintf(stderr, "llama-bench: benchmark %d/%ld: prompt run %d/%d\n", params_idx, params_count, i + 1, params.reps); + fprintf(stderr, "llama-bench: benchmark %d/%zu: prompt run %d/%d\n", params_idx, params_count, + i + 1, params.reps); } - test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads); + test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads); } if (t.n_gen > 0) { if (params.progress) { - fprintf(stderr, "llama-bench: benchmark %d/%ld: generation run %d/%d\n", params_idx, params_count, i + 1, params.reps); + fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count, + i + 1, params.reps); } - test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads); + test_gen(ctx, t.n_gen, t.n_threads); } uint64_t t_ns = get_time_ns() - t_start; @@ -1630,14 +1650,14 @@ int main(int argc, char ** argv) { fflush(p_err->fout); } - llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); + llama_perf_context_print(ctx); llama_free(ctx); - ggml_threadpool_free(threadpool); + ggml_threadpool_free_fn(threadpool); } - llama_free_model(lmodel); + llama_model_free(lmodel); if (p) { p->print_footer(); diff --git a/examples/llama.android/llama/build.gradle.kts b/examples/llama.android/llama/build.gradle.kts index 0a3806172..28dbc1904 100644 --- a/examples/llama.android/llama/build.gradle.kts +++ b/examples/llama.android/llama/build.gradle.kts @@ -18,6 +18,8 @@ android { } externalNativeBuild { cmake { + arguments += "-DLLAMA_BUILD_COMMON=ON" + arguments += "-DGGML_LLAMAFILE=OFF" arguments += "-DCMAKE_BUILD_TYPE=Release" cppFlags += listOf() arguments += listOf() diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 06ec160c2..2a73983a9 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -87,7 +87,7 @@ Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring fi auto path_to_model = env->GetStringUTFChars(filename, 0); LOGi("Loading model from %s", path_to_model); - auto model = llama_load_model_from_file(path_to_model, model_params); + auto model = llama_model_load_from_file(path_to_model, model_params); env->ReleaseStringUTFChars(filename, path_to_model); if (!model) { @@ -102,7 +102,7 @@ Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring fi extern "C" JNIEXPORT void JNICALL Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) { - llama_free_model(reinterpret_cast(model)); + llama_model_free(reinterpret_cast(model)); } extern "C" @@ -186,11 +186,11 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( for (nri = 0; nri < nr; nri++) { LOGi("Benchmark prompt processing (pp)"); - llama_batch_clear(*batch); + common_batch_clear(*batch); const int n_tokens = pp; for (i = 0; i < n_tokens; i++) { - llama_batch_add(*batch, 0, i, { 0 }, false); + common_batch_add(*batch, 0, i, { 0 }, false); } batch->logits[batch->n_tokens - 1] = true; @@ -210,9 +210,9 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( const auto t_tg_start = ggml_time_us(); for (i = 0; i < tg; i++) { - llama_batch_clear(*batch); + common_batch_clear(*batch); for (j = 0; j < pl; j++) { - llama_batch_add(*batch, 0, i, { j }, true); + common_batch_add(*batch, 0, i, { j }, true); } LOGi("llama_decode() text generation: %d", i); @@ -283,9 +283,6 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, nullptr, nullptr, nullptr, - 0, - 0, - 0, }; if (embd) { @@ -308,7 +305,9 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, extern "C" JNIEXPORT void JNICALL Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) { - llama_batch_free(*reinterpret_cast(batch_pointer)); + //llama_batch_free(*reinterpret_cast(batch_pointer)); + const auto batch = reinterpret_cast(batch_pointer); + delete batch; } extern "C" @@ -348,6 +347,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init( jlong context_pointer, jlong batch_pointer, jstring jtext, + jboolean format_chat, jint n_len ) { @@ -357,7 +357,8 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init( const auto context = reinterpret_cast(context_pointer); const auto batch = reinterpret_cast(batch_pointer); - const auto tokens_list = llama_tokenize(context, text, 1); + bool parse_special = (format_chat == JNI_TRUE); + const auto tokens_list = common_tokenize(context, text, true, parse_special); auto n_ctx = llama_n_ctx(context); auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size()); @@ -369,14 +370,14 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init( } for (auto id : tokens_list) { - LOGi("%s", llama_token_to_piece(context, id).c_str()); + LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id); } - llama_batch_clear(*batch); + common_batch_clear(*batch); // evaluate the initial prompt for (auto i = 0; i < tokens_list.size(); i++) { - llama_batch_add(*batch, tokens_list[i], i, { 0 }, false); + common_batch_add(*batch, tokens_list[i], i, { 0 }, false); } // llama_decode will output logits only for the last token of the prompt @@ -406,6 +407,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( const auto batch = reinterpret_cast(batch_pointer); const auto sampler = reinterpret_cast(sampler_pointer); const auto model = llama_get_model(context); + const auto vocab = llama_model_get_vocab(model); if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur); if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I"); @@ -414,14 +416,12 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( // sample the most likely token const auto new_token_id = llama_sampler_sample(sampler, context, -1); - llama_sampler_accept(sampler, new_token_id); - const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); - if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { + if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) { return nullptr; } - auto new_token_chars = llama_token_to_piece(context, new_token_id); + auto new_token_chars = common_token_to_piece(context, new_token_id); cached_token_chars += new_token_chars; jstring new_token = nullptr; @@ -433,8 +433,8 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( new_token = env->NewStringUTF(""); } - llama_batch_clear(*batch); - llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true); + common_batch_clear(*batch); + common_batch_add(*batch, new_token_id, n_cur, { 0 }, true); env->CallVoidMethod(intvar_ncur, la_int_var_inc); diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt index cf520e459..b964d93e3 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt @@ -65,6 +65,7 @@ class LLamaAndroid { context: Long, batch: Long, text: String, + formatChat: Boolean, nLen: Int ): Int @@ -115,10 +116,10 @@ class LLamaAndroid { } } - fun send(message: String): Flow = flow { + fun send(message: String, formatChat: Boolean = false): Flow = flow { when (val state = threadLocalState.get()) { is State.Loaded -> { - val ncur = IntVar(completion_init(state.context, state.batch, message, nlen)) + val ncur = IntVar(completion_init(state.context, state.batch, message, formatChat, nlen)) while (ncur.value <= nlen) { val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur) if (str == null) { diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 92f61fe83..477c3e6f2 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -46,15 +46,14 @@ actor LlamaContext { let sparams = llama_sampler_chain_default_params() self.sampling = llama_sampler_chain_init(sparams) llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4)) - llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax()) llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234)) } deinit { llama_sampler_free(sampling) llama_batch_free(batch) + llama_model_free(model) llama_free(context) - llama_free_model(model) llama_backend_free() } @@ -66,7 +65,7 @@ actor LlamaContext { model_params.n_gpu_layers = 0 print("Running on simulator, force use n_gpu_layers = 0") #endif - let model = llama_load_model_from_file(path, model_params) + let model = llama_model_load_from_file(path, model_params) guard let model else { print("Could not load model at \(path)") throw LlamaError.couldNotInitializeContext @@ -152,9 +151,7 @@ actor LlamaContext { new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1) - llama_sampler_accept(sampling, new_token_id) - - if llama_token_is_eog(model, new_token_id) || n_cur == n_len { + if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len { print("\n") is_done = true let new_token_str = String(cString: temporary_invalid_cchars + [0]) @@ -213,20 +210,20 @@ actor LlamaContext { llama_kv_cache_clear(context) - let t_pp_start = ggml_time_us() + let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000; if llama_decode(context, batch) != 0 { print("llama_decode() failed during prompt") } llama_synchronize(context) - let t_pp_end = ggml_time_us() + let t_pp_end = DispatchTime.now().uptimeNanoseconds / 1000; // bench text generation llama_kv_cache_clear(context) - let t_tg_start = ggml_time_us() + let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000; for i in 0.. -" Similarly, you could add an insert mode keybind with -" inoremap call llama#doLlamaGen() +" LLM-based text completion using llama.cpp " -" g:llama_api_url, g:llama_api_key and g:llama_overrides can be configured in your .vimrc -" let g:llama_api_url = "192.168.1.10:8080" -" llama_overrides can also be set through buffer/window scopes. For instance -" autocmd filetype python let b:llama_overrides = {"temp": 0.2} -" Could be added to your .vimrc to automatically set a lower temperature when -" editing a python script -" Additionally, an override dict can be stored at the top of a file -" !*{"stop": ["User:"]} -" Could be added to the start of your chatlog.txt to set the stopping token -" These parameter dicts are merged together from lowest to highest priority: -" server default -> g:llama_overrides -> w:llama_overrides -> -" b:llama_overrides -> in file (!*) overrides +" requires: +" +" - neovim or vim +" - curl +" - llama.cpp server instance +" - FIM-compatible model +" +" sample config: +" +" - Tab - accept the current suggestion +" - Shift+Tab - accept just the first line of the suggestion +" - Ctrl+F - toggle FIM completion manually +" +" make symlink or copy this file to ~/.config/nvim/autoload/llama.vim +" +" start the llama.cpp server with a FIM-compatible model. for example: +" +" $ llama-server -m {model.gguf} --port 8012 -ngl 99 -fa -dt 0.1 --ubatch-size 512 --batch-size 1024 --cache-reuse 256 +" +" --batch-size [512, model max context] +" +" adjust the batch size to control how much of the provided local context will be used during the inference +" lower values will use smaller part of the context around the cursor, which will result in faster processing +" +" --ubatch-size [64, 2048] +" +" chunks the batch into smaller chunks for faster processing +" depends on the specific hardware. use llama-bench to profile and determine the best size +" +" --cache-reuse (ge:llama_config.n_predict, 1024] +" +" this should be either 0 (disabled) or strictly larger than g:llama_config.n_predict +" using non-zero value enables context reuse on the server side which dramatically improves the performance at +" large contexts. a value of 256 should be good for all cases +" +" run this once to initialise llama.vim: +" +" :call llama#init() +" +" more info: https://github.com/ggerganov/llama.cpp/pull/9787 " -" Sublists (like logit_bias and stop) are overridden, not merged -" Example override: -" !*{"logit_bias": [[13, -5], [2, false]], "temperature": 1, "top_k": 5, "top_p": 0.5, "n_predict": 256, "repeat_last_n": 256, "repeat_penalty": 1.17647} -if !exists("g:llama_api_url") - let g:llama_api_url= "127.0.0.1:8080" -endif -if !exists("g:llama_overrides") - let g:llama_overrides = {} -endif -const s:querydata = {"n_predict": 256, "stop": [ "\n" ], "stream": v:true } -const s:curlcommand = ['curl','--data-raw', "{\"prompt\":\"### System:\"}", '--silent', '--no-buffer', '--request', 'POST', '--url', g:llama_api_url .. '/completion', '--header', "Content-Type: application/json"] -let s:linedict = {} -func s:callbackHandler(bufn, channel, msg) - if len(a:msg) < 3 - return - elseif a:msg[0] == "d" - let l:msg = a:msg[6:-1] - else - let l:msg = a:msg - endif - let l:decoded_msg = json_decode(l:msg) - let l:newtext = split(l:decoded_msg['content'], "\n", 1) - if len(l:newtext) > 0 - call setbufline(a:bufn, s:linedict[a:bufn], getbufline(a:bufn, s:linedict[a:bufn])[0] .. newtext[0]) - else - echo "nothing genned" - endif - if len(newtext) > 1 - let l:failed = appendbufline(a:bufn, s:linedict[a:bufn], newtext[1:-1]) - let s:linedict[a:bufn] = s:linedict[a:bufn] + len(newtext)-1 - endif - if has_key(l:decoded_msg, "stop") && l:decoded_msg.stop - echo "Finished generation" - endif -endfunction +" colors (adjust to your liking) +highlight llama_hl_hint guifg=#ff772f ctermfg=202 +highlight llama_hl_info guifg=#77ff2f ctermfg=119 -func llama#doLlamaGen() - if exists("b:job") - if job_status(b:job) == "run" - call job_stop(b:job) - return - endif - endif +" general parameters: +" +" endpoint: llama.cpp server endpoint +" n_prefix: number of lines before the cursor location to include in the local prefix +" n_suffix: number of lines after the cursor location to include in the local suffix +" n_predict: max number of tokens to predict +" t_max_prompt_ms: max alloted time for the prompt processing (TODO: not yet supported) +" t_max_predict_ms: max alloted time for the prediction +" show_info: show extra info about the inference (0 - disabled, 1 - statusline, 2 - inline) +" auto_fim: trigger FIM completion automatically on cursor movement +" max_line_suffix: do not auto-trigger FIM completion if there are more than this number of characters to the right of the cursor +" +" ring buffer of chunks, accumulated with time upon: +" +" - completion request +" - yank +" - entering a buffer +" - leaving a buffer +" - writing a file +" +" parameters for the ring-buffer with extra context: +" +" ring_n_chunks: max number of chunks to pass as extra context to the server (0 to disable) +" ring_chunk_size: max size of the chunks (in number of lines) +" note: adjust these numbers so that you don't overrun your context +" at ring_n_chunks = 64 and ring_chunk_size = 64 you need ~32k context +" ring_scope: the range around the cursor position (in number of lines) for gathering chunks after FIM +" ring_update_ms: how often to process queued chunks in normal mode +" +let s:default_config = { + \ 'endpoint': 'http://127.0.0.1:8012/infill', + \ 'n_prefix': 256, + \ 'n_suffix': 64, + \ 'n_predict': 128, + \ 't_max_prompt_ms': 500, + \ 't_max_predict_ms': 3000, + \ 'show_info': 2, + \ 'auto_fim': v:true, + \ 'max_line_suffix': 8, + \ 'ring_n_chunks': 64, + \ 'ring_chunk_size': 64, + \ 'ring_scope': 1024, + \ 'ring_update_ms': 1000, + \ } - let l:cbuffer = bufnr("%") - let s:linedict[l:cbuffer] = line('$') - let l:buflines = getbufline(l:cbuffer, 1, 1000) - let l:querydata = copy(s:querydata) - call extend(l:querydata, g:llama_overrides) - if exists("w:llama_overrides") - call extend(l:querydata, w:llama_overrides) - endif - if exists("b:llama_overrides") - call extend(l:querydata, b:llama_overrides) - endif - if l:buflines[0][0:1] == '!*' - let l:userdata = json_decode(l:buflines[0][2:-1]) - call extend(l:querydata, l:userdata) - let l:buflines = l:buflines[1:-1] - endif - let l:querydata.prompt = join(l:buflines, "\n") - let l:curlcommand = copy(s:curlcommand) - if exists("g:llama_api_key") - call extend(l:curlcommand, ['--header', 'Authorization: Bearer ' .. g:llama_api_key]) - endif - let l:curlcommand[2] = json_encode(l:querydata) - let b:job = job_start(l:curlcommand, {"callback": function("s:callbackHandler", [l:cbuffer])}) -endfunction +let g:llama_config = get(g:, 'llama_config', s:default_config) -" Echos the tokkenization of the provided string , or cursor to end of word -" Onus is placed on the user to include the preceding space -func llama#tokenizeWord(...) - if (a:0 > 0) - let l:input = a:1 - else - exe "normal \"*ye" - let l:input = @* - endif - let l:querydata = {"content": l:input} - let l:curlcommand = copy(s:curlcommand) - let l:curlcommand[2] = json_encode(l:querydata) - let l:curlcommand[8] = g:llama_api_url .. "/tokenize" - let s:token_job = job_start(l:curlcommand, {"callback": function("s:tokenizeWordCallback", [l:input])}) -endfunction - -func s:tokenizeWordCallback(plaintext, channel, msg) - echo '"' .. a:plaintext ..'" - ' .. string(json_decode(a:msg).tokens) -endfunction - - -" Echos the token count of the entire buffer (or provided string) -" Example usage :echo llama#tokenCount() -func llama#tokenCount(...) - if (a:0 > 0) - let l:buflines = a:1 - else - let l:buflines = getline(1,1000) - if l:buflines[0][0:1] == '!*' - let l:buflines = l:buflines[1:-1] +function! s:get_indent(str) + let l:count = 0 + for i in range(len(a:str)) + if a:str[i] == "\t" + let l:count += &tabstop - 1 + else + break endif - let l:buflines = join(l:buflines, "\n") - endif - let l:querydata = {"content": l:buflines} - let l:curlcommand = copy(s:curlcommand) - let l:curlcommand[2] = json_encode(l:querydata) - let l:curlcommand[8] = g:llama_api_url .. "/tokenize" - let s:token_job = job_start(l:curlcommand, {"callback": "s:tokenCountCallback"}) + endfor + return l:count endfunction -func s:tokenCountCallback(channel, msg) - let resp = json_decode(a:msg) - echo len(resp.tokens) +function! s:rand(i0, i1) abort + return a:i0 + rand() % (a:i1 - a:i0 + 1) +endfunction + +function! llama#init() + if !executable('curl') + echohl WarningMsg + echo 'llama.vim requires the "curl" command to be available' + echohl None + return + endif + + let s:pos_x = 0 " cursor position upon start of completion + let s:pos_y = 0 + + let s:line_cur = '' + + let s:line_cur_prefix = '' + let s:line_cur_suffix = '' + + let s:ring_chunks = [] " current set of chunks used as extra context + let s:ring_queued = [] " chunks that are queued to be sent for processing + let s:ring_n_evict = 0 + + let s:hint_shown = v:false + let s:pos_y_pick = -9999 " last y where we picked a chunk + let s:pos_dx = 0 + let s:content = [] + let s:can_accept = v:false + + let s:timer_fim = -1 + let s:t_fim_start = reltime() " used to measure total FIM time + let s:t_last_move = reltime() " last time the cursor moved + + let s:current_job = v:null + + let s:ghost_text_nvim = exists('*nvim_buf_get_mark') + let s:ghost_text_vim = has('textprop') + + if s:ghost_text_vim + let s:hlgroup_hint = 'llama_hl_hint' + let s:hlgroup_info = 'llama_hl_info' + + if empty(prop_type_get(s:hlgroup_hint)) + call prop_type_add(s:hlgroup_hint, {'highlight': s:hlgroup_hint}) + endif + if empty(prop_type_get(s:hlgroup_info)) + call prop_type_add(s:hlgroup_info, {'highlight': s:hlgroup_info}) + endif + endif + + augroup llama + autocmd! + autocmd InsertEnter * inoremap llama#fim_inline(v:false) + autocmd InsertLeavePre * call llama#fim_cancel() + + autocmd CursorMoved * call s:on_move() + autocmd CursorMovedI * call s:on_move() + autocmd CompleteChanged * call llama#fim_cancel() + + if g:llama_config.auto_fim + autocmd CursorMovedI * call llama#fim(v:true) + endif + + " gather chunks upon yanking + autocmd TextYankPost * if v:event.operator ==# 'y' | call s:pick_chunk(v:event.regcontents, v:false, v:true) | endif + + " gather chunks upon entering/leaving a buffer + autocmd BufEnter * call timer_start(100, {-> s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true)}) + autocmd BufLeave * call s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true) + + " gather chunk upon saving the file + autocmd BufWritePost * call s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true) + augroup END + + silent! call llama#fim_cancel() + + " init background update of the ring buffer + if g:llama_config.ring_n_chunks > 0 + call s:ring_update() + endif +endfunction + +" compute how similar two chunks of text are +" 0 - no similarity, 1 - high similarity +" TODO: figure out something better +function! s:chunk_sim(c0, c1) + let l:lines0 = len(a:c0) + let l:lines1 = len(a:c1) + + let l:common = 0 + + for l:line0 in a:c0 + for l:line1 in a:c1 + if l:line0 == l:line1 + let l:common += 1 + break + endif + endfor + endfor + + return 2.0 * l:common / (l:lines0 + l:lines1) +endfunction + +" pick a random chunk of size g:llama_config.ring_chunk_size from the provided text and queue it for processing +" +" no_mod - do not pick chunks from buffers with pending changes +" do_evict - evict chunks that are very similar to the new one +" +function! s:pick_chunk(text, no_mod, do_evict) + " do not pick chunks from buffers with pending changes or buffers that are not files + if a:no_mod && (getbufvar(bufnr('%'), '&modified') || !buflisted(bufnr('%')) || !filereadable(expand('%'))) + return + endif + + " if the extra context option is disabled - do nothing + if g:llama_config.ring_n_chunks <= 0 + return + endif + + " don't pick very small chunks + if len(a:text) < 3 + return + endif + + if len(a:text) + 1 < g:llama_config.ring_chunk_size + let l:chunk = a:text + else + let l:l0 = s:rand(0, max([0, len(a:text) - g:llama_config.ring_chunk_size/2])) + let l:l1 = min([l:l0 + g:llama_config.ring_chunk_size/2, len(a:text)]) + + let l:chunk = a:text[l:l0:l:l1] + endif + + let l:chunk_str = join(l:chunk, "\n") . "\n" + + " check if this chunk is already added + let l:exist = v:false + + for i in range(len(s:ring_chunks)) + if s:ring_chunks[i].data == l:chunk + let l:exist = v:true + break + endif + endfor + + for i in range(len(s:ring_queued)) + if s:ring_queued[i].data == l:chunk + let l:exist = v:true + break + endif + endfor + + if l:exist + return + endif + + " evict queued chunks that are very similar to the new one + for i in range(len(s:ring_queued) - 1, 0, -1) + if s:chunk_sim(s:ring_queued[i].data, l:chunk) > 0.9 + if a:do_evict + call remove(s:ring_queued, i) + let s:ring_n_evict += 1 + else + return + endif + endif + endfor + + " also from s:ring_chunks + for i in range(len(s:ring_chunks) - 1, 0, -1) + if s:chunk_sim(s:ring_chunks[i].data, l:chunk) > 0.9 + if a:do_evict + call remove(s:ring_chunks, i) + let s:ring_n_evict += 1 + else + return + endif + endif + endfor + + " TODO: become parameter ? + if len(s:ring_queued) == 16 + call remove(s:ring_queued, 0) + endif + + call add(s:ring_queued, {'data': l:chunk, 'str': l:chunk_str, 'time': reltime(), 'filename': expand('%')}) + + "let &statusline = 'extra context: ' . len(s:ring_chunks) . ' / ' . len(s:ring_queued) +endfunction + +" picks a queued chunk, sends it for processing and adds it to s:ring_chunks +" called every g:llama_config.ring_update_ms +function! s:ring_update() + call timer_start(g:llama_config.ring_update_ms, {-> s:ring_update()}) + + " update only if in normal mode or if the cursor hasn't moved for a while + if mode() !=# 'n' && reltimefloat(reltime(s:t_last_move)) < 3.0 + return + endif + + if len(s:ring_queued) == 0 + return + endif + + " move the first queued chunk to the ring buffer + if len(s:ring_chunks) == g:llama_config.ring_n_chunks + call remove(s:ring_chunks, 0) + endif + + call add(s:ring_chunks, remove(s:ring_queued, 0)) + + "let &statusline = 'updated context: ' . len(s:ring_chunks) . ' / ' . len(s:ring_queued) + + " send asynchronous job with the new extra context so that it is ready for the next FIM + let l:extra_context = [] + for l:chunk in s:ring_chunks + call add(l:extra_context, { + \ 'text': l:chunk.str, + \ 'time': l:chunk.time, + \ 'filename': l:chunk.filename + \ }) + endfor + + " no samplers needed here + let l:request = json_encode({ + \ 'input_prefix': "", + \ 'input_suffix': "", + \ 'input_extra': l:extra_context, + \ 'prompt': "", + \ 'n_predict': 1, + \ 'temperature': 0.0, + \ 'stream': v:false, + \ 'samplers': ["temperature"], + \ 'cache_prompt': v:true, + \ 't_max_prompt_ms': 1, + \ 't_max_predict_ms': 1 + \ }) + + let l:curl_command = [ + \ "curl", + \ "--silent", + \ "--no-buffer", + \ "--request", "POST", + \ "--url", g:llama_config.endpoint, + \ "--header", "Content-Type: application/json", + \ "--data", l:request + \ ] + + " no callbacks because we don't need to process the response + if s:ghost_text_nvim + call jobstart(l:curl_command, {}) + elseif s:ghost_text_vim + call job_start(l:curl_command, {}) + endif +endfunction + +" necessary for 'inoremap ' +function! llama#fim_inline(is_auto) abort + call llama#fim(a:is_auto) + return '' +endfunction + +" the main FIM call +" takes local context around the cursor and sends it together with the extra context to the server for completion +function! llama#fim(is_auto) abort + " we already have a suggestion for the current cursor position + if s:hint_shown && !a:is_auto + call llama#fim_cancel() + return + endif + + call llama#fim_cancel() + + " avoid sending repeated requests too fast + if reltimefloat(reltime(s:t_fim_start)) < 0.6 + if s:timer_fim != -1 + call timer_stop(s:timer_fim) + let s:timer_fim = -1 + endif + + let s:t_fim_start = reltime() + let s:timer_fim = timer_start(600, {-> llama#fim(v:true)}) + return + endif + + let s:t_fim_start = reltime() + + let s:content = [] + let s:can_accept = v:false + + let s:pos_x = col('.') - 1 + let s:pos_y = line('.') + let l:max_y = line('$') + + let l:lines_prefix = getline(max([1, s:pos_y - g:llama_config.n_prefix]), s:pos_y - 1) + let l:lines_suffix = getline(s:pos_y + 1, min([l:max_y, s:pos_y + g:llama_config.n_suffix])) + + let s:line_cur = getline('.') + + let s:line_cur_prefix = strpart(s:line_cur, 0, s:pos_x) + let s:line_cur_suffix = strpart(s:line_cur, s:pos_x) + + if a:is_auto && len(s:line_cur_suffix) > g:llama_config.max_line_suffix + return + endif + + let l:prefix = "" + \ . join(l:lines_prefix, "\n") + \ . "\n" + + let l:prompt = "" + \ . s:line_cur_prefix + + let l:suffix = "" + \ . s:line_cur_suffix + \ . "\n" + \ . join(l:lines_suffix, "\n") + \ . "\n" + + " prepare the extra context data + let l:extra_context = [] + for l:chunk in s:ring_chunks + call add(l:extra_context, { + \ 'text': l:chunk.str, + \ 'time': l:chunk.time, + \ 'filename': l:chunk.filename + \ }) + endfor + + " the indentation of the current line + let l:indent = strlen(matchstr(s:line_cur_prefix, '^\s*')) + + let l:request = json_encode({ + \ 'input_prefix': l:prefix, + \ 'input_suffix': l:suffix, + \ 'input_extra': l:extra_context, + \ 'prompt': l:prompt, + \ 'n_predict': g:llama_config.n_predict, + \ 'n_indent': l:indent, + \ 'top_k': 40, + \ 'top_p': 0.99, + \ 'stream': v:false, + \ 'samplers': ["top_k", "top_p", "infill"], + \ 'cache_prompt': v:true, + \ 't_max_prompt_ms': g:llama_config.t_max_prompt_ms, + \ 't_max_predict_ms': g:llama_config.t_max_predict_ms + \ }) + + let l:curl_command = [ + \ "curl", + \ "--silent", + \ "--no-buffer", + \ "--request", "POST", + \ "--url", g:llama_config.endpoint, + \ "--header", "Content-Type: application/json", + \ "--data", l:request + \ ] + + if s:current_job != v:null + if s:ghost_text_nvim + call jobstop(s:current_job) + elseif s:ghost_text_vim + call job_stop(s:current_job) + endif + endif + + " send the request asynchronously + if s:ghost_text_nvim + let s:current_job = jobstart(l:curl_command, { + \ 'on_stdout': function('s:fim_on_stdout', [s:pos_x, s:pos_y, a:is_auto]), + \ 'on_exit': function('s:fim_on_exit'), + \ 'stdout_buffered': v:true + \ }) + elseif s:ghost_text_vim + let s:current_job = job_start(l:curl_command, { + \ 'out_cb': function('s:fim_on_stdout', [s:pos_x, s:pos_y, a:is_auto]), + \ 'exit_cb': function('s:fim_on_exit') + \ }) + endif + + " TODO: per-file location + let l:delta_y = abs(s:pos_y - s:pos_y_pick) + + " gather some extra context nearby and process it in the background + " only gather chunks if the cursor has moved a lot + " TODO: something more clever? reranking? + if a:is_auto && l:delta_y > 32 + " expand the prefix even further + call s:pick_chunk(getline(max([1, s:pos_y - g:llama_config.ring_scope]), max([1, s:pos_y - g:llama_config.n_prefix])), v:false, v:false) + + " pick a suffix chunk + call s:pick_chunk(getline(min([l:max_y, s:pos_y + g:llama_config.n_suffix]), min([l:max_y, s:pos_y + g:llama_config.n_suffix + g:llama_config.ring_chunk_size])), v:false, v:false) + + let s:pos_y_pick = s:pos_y + endif +endfunction + +" if first_line == v:true accept only the first line of the response +function! llama#fim_accept(first_line) + " insert the suggestion at the cursor location + if s:can_accept && len(s:content) > 0 + call setline(s:pos_y, s:line_cur[:(s:pos_x - 1)] . s:content[0]) + if len(s:content) > 1 + if !a:first_line + call append(s:pos_y, s:content[1:-1]) + endif + endif + + " move the cursor to the end of the accepted text + if !a:first_line && len(s:content) > 1 + call cursor(s:pos_y + len(s:content) - 1, s:pos_x + s:pos_dx + 1) + else + call cursor(s:pos_y, s:pos_x + len(s:content[0])) + endif + endif + + call llama#fim_cancel() +endfunction + +function! llama#fim_cancel() + let s:hint_shown = v:false + + " clear the virtual text + let l:bufnr = bufnr('%') + + if s:ghost_text_nvim + let l:id_vt_fim = nvim_create_namespace('vt_fim') + call nvim_buf_clear_namespace(l:bufnr, l:id_vt_fim, 0, -1) + elseif s:ghost_text_vim + call prop_remove({'type': s:hlgroup_hint, 'all': v:true}) + call prop_remove({'type': s:hlgroup_info, 'all': v:true}) + endif + + " remove the mappings + silent! iunmap + silent! iunmap + silent! iunmap +endfunction + +function! s:on_move() + let s:t_last_move = reltime() + + call llama#fim_cancel() +endfunction + +" callback that processes the FIM result from the server and displays the suggestion +function! s:fim_on_stdout(pos_x, pos_y, is_auto, job_id, data, event = v:null) + if s:ghost_text_nvim + let l:raw = join(a:data, "\n") + elseif s:ghost_text_vim + let l:raw = a:data + endif + + if len(l:raw) == 0 + return + endif + + if a:pos_x != col('.') - 1 || a:pos_y != line('.') + return + endif + + " show the suggestion only in insert mode + if mode() !=# 'i' + return + endif + + let s:pos_x = a:pos_x + let s:pos_y = a:pos_y + + let s:can_accept = v:true + let l:has_info = v:false + + if s:can_accept && v:shell_error + if !a:is_auto + call add(s:content, "<| curl error: is the server on? |>") + endif + let s:can_accept = v:false + endif + + let l:n_prompt = 0 + let l:t_prompt_ms = 1.0 + let l:s_prompt = 0 + + let l:n_predict = 0 + let l:t_predict_ms = 1.0 + let l:s_predict = 0 + + " get the generated suggestion + if s:can_accept + let l:response = json_decode(l:raw) + + for l:part in split(get(l:response, 'content', ''), "\n", 1) + call add(s:content, l:part) + endfor + + " remove trailing new lines + while len(s:content) > 0 && s:content[-1] == "" + call remove(s:content, -1) + endwhile + + let l:generation_settings = get(l:response, 'generation_settings', {}) + let l:n_ctx = get(l:generation_settings, 'n_ctx', 0) + + let l:n_cached = get(l:response, 'tokens_cached', 0) + let l:truncated = get(l:response, 'truncated', v:false) + + " if response.timings is available + if len(get(l:response, 'timings', {})) > 0 + let l:has_info = v:true + let l:timings = get(l:response, 'timings', {}) + + let l:n_prompt = get(l:timings, 'prompt_n', 0) + let l:t_prompt_ms = get(l:timings, 'prompt_ms', 1) + let l:s_prompt = get(l:timings, 'prompt_per_second', 0) + + let l:n_predict = get(l:timings, 'predicted_n', 0) + let l:t_predict_ms = get(l:timings, 'predicted_ms', 1) + let l:s_predict = get(l:timings, 'predicted_per_second', 0) + endif + endif + + if len(s:content) == 0 + call add(s:content, "") + let s:can_accept = v:false + endif + + if len(s:content) == 0 + return + endif + + " NOTE: the following is logic for discarding predictions that repeat existing text + " the code is quite ugly and there is very likely a simpler and more canonical way to implement this + " + " still, I wonder if there is some better way that avoids having to do these special hacks? + " on one hand, the LLM 'sees' the contents of the file before we start editing, so it is normal that it would + " start generating whatever we have given it via the extra context. but on the other hand, it's not very + " helpful to re-generate the same code that is already there + + " truncate the suggestion if the first line is empty + if len(s:content) == 1 && s:content[0] == "" + let s:content = [""] + endif + + " ... and the next lines are repeated + if len(s:content) > 1 && s:content[0] == "" && s:content[1:] == getline(s:pos_y + 1, s:pos_y + len(s:content) - 1) + let s:content = [""] + endif + + " truncate the suggestion if it repeats the suffix + if len(s:content) == 1 && s:content[0] == s:line_cur_suffix + let s:content = [""] + endif + + " find the first non-empty line (strip whitespace) + let l:cmp_y = s:pos_y + 1 + while l:cmp_y < line('$') && getline(l:cmp_y) =~? '^\s*$' + let l:cmp_y += 1 + endwhile + + if (s:line_cur_prefix . s:content[0]) == getline(l:cmp_y) + " truncate the suggestion if it repeats the next line + if len(s:content) == 1 + let s:content = [""] + endif + + " ... or if the second line of the suggestion is the prefix of line l:cmp_y + 1 + if len(s:content) == 2 && s:content[-1] == getline(l:cmp_y + 1)[:len(s:content[-1]) - 1] + let s:content = [""] + endif + + " ... or if the middle chunk of lines of the suggestion is the same as [l:cmp_y + 1, l:cmp_y + len(s:content) - 1) + if len(s:content) > 2 && join(s:content[1:-1], "\n") == join(getline(l:cmp_y + 1, l:cmp_y + len(s:content) - 1), "\n") + let s:content = [""] + endif + endif + + " keep only lines that have the same or larger whitespace prefix as s:line_cur_prefix + "let l:indent = strlen(matchstr(s:line_cur_prefix, '^\s*')) + "for i in range(1, len(s:content) - 1) + " if strlen(matchstr(s:content[i], '^\s*')) < l:indent + " let s:content = s:content[:i - 1] + " break + " endif + "endfor + + let s:pos_dx = len(s:content[-1]) + + let s:content[-1] .= s:line_cur_suffix + + call llama#fim_cancel() + + " display virtual text with the suggestion + let l:bufnr = bufnr('%') + + if s:ghost_text_nvim + let l:id_vt_fim = nvim_create_namespace('vt_fim') + endif + + " construct the info message + if g:llama_config.show_info > 0 && l:has_info + let l:prefix = ' ' + + if l:truncated + let l:info = printf("%s | WARNING: the context is full: %d / %d, increase the server context size or reduce g:llama_config.ring_n_chunks", + \ g:llama_config.show_info == 2 ? l:prefix : 'llama.vim', + \ l:n_cached, l:n_ctx + \ ) + else + let l:info = printf("%s | c: %d / %d, r: %d / %d, e: %d, q: %d / 16 | p: %d (%.2f ms, %.2f t/s) | g: %d (%.2f ms, %.2f t/s) | t: %.2f ms", + \ g:llama_config.show_info == 2 ? l:prefix : 'llama.vim', + \ l:n_cached, l:n_ctx, len(s:ring_chunks), g:llama_config.ring_n_chunks, s:ring_n_evict, len(s:ring_queued), + \ l:n_prompt, l:t_prompt_ms, l:s_prompt, + \ l:n_predict, l:t_predict_ms, l:s_predict, + \ 1000.0 * reltimefloat(reltime(s:t_fim_start)) + \ ) + endif + + if g:llama_config.show_info == 1 + " display the info in the statusline + let &statusline = l:info + let l:info = '' + endif + endif + + " display the suggestion and append the info to the end of the first line + if s:ghost_text_nvim + call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, s:pos_x - 1, { + \ 'virt_text': [[s:content[0], 'llama_hl_hint'], [l:info, 'llama_hl_info']], + \ 'virt_text_win_col': virtcol('.') - 1 + \ }) + + call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, 0, { + \ 'virt_lines': map(s:content[1:], {idx, val -> [[val, 'llama_hl_hint']]}), + \ 'virt_text_win_col': virtcol('.') + \ }) + elseif s:ghost_text_vim + let l:new_suffix = s:content[0] + if !empty(l:new_suffix) + call prop_add(s:pos_y, s:pos_x + 1, { + \ 'type': s:hlgroup_hint, + \ 'text': l:new_suffix + \ }) + endif + for line in s:content[1:] + call prop_add(s:pos_y, 0, { + \ 'type': s:hlgroup_hint, + \ 'text': line, + \ 'text_padding_left': s:get_indent(line), + \ 'text_align': 'below' + \ }) + endfor + if !empty(l:info) + call prop_add(s:pos_y, 0, { + \ 'type': s:hlgroup_info, + \ 'text': l:info, + \ 'text_padding_left': col('$'), + \ 'text_wrap': 'truncate' + \ }) + endif + endif + + " setup accept shortcuts + inoremap :call llama#fim_accept(v:false) + inoremap :call llama#fim_accept(v:true) + + let s:hint_shown = v:true +endfunction + +function! s:fim_on_exit(job_id, exit_code, event = v:null) + if a:exit_code != 0 + echom "Job failed with exit code: " . a:exit_code + endif + + let s:current_job = v:null endfunction diff --git a/examples/llava/CMakeLists.txt b/examples/llava/CMakeLists.txt index bbf5fec58..3ce0d60c8 100644 --- a/examples/llava/CMakeLists.txt +++ b/examples/llava/CMakeLists.txt @@ -11,7 +11,7 @@ target_include_directories(llava PUBLIC .) target_include_directories(llava PUBLIC ../..) target_include_directories(llava PUBLIC ../../common) -target_compile_features(llava PRIVATE cxx_std_11) +target_compile_features(llava PRIVATE cxx_std_17) add_library(llava_static STATIC $) if (BUILD_SHARED_LIBS) @@ -35,11 +35,18 @@ add_executable(${TARGET} llava-cli.cpp) set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-cli) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) set(TARGET llama-minicpmv-cli) add_executable(${TARGET} minicpmv-cli.cpp) set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-minicpmv-cli) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) + +set(TARGET llama-qwen2vl-cli) +add_executable(${TARGET} qwen2vl-cli.cpp) +set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-qwen2vl-cli) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/llava/MobileVLM-README.md b/examples/llava/MobileVLM-README.md index 06a65fba4..4f783f3ce 100644 --- a/examples/llava/MobileVLM-README.md +++ b/examples/llava/MobileVLM-README.md @@ -39,7 +39,7 @@ python ./examples/llava/llava_surgery.py -m path/to/MobileVLM-1.7B 3. Use `convert_image_encoder_to_gguf.py` with `--projector-type ldp` (for **V2** please use `--projector-type ldpv2`) to convert the LLaVA image encoder to GGUF: ```sh -python ./examples/llava/convert_image_encoder_to_gguf \ +python ./examples/llava/convert_image_encoder_to_gguf.py \ -m path/to/clip-vit-large-patch14-336 \ --llava-projector path/to/MobileVLM-1.7B/llava.projector \ --output-dir path/to/MobileVLM-1.7B \ @@ -47,7 +47,7 @@ python ./examples/llava/convert_image_encoder_to_gguf \ ``` ```sh -python ./examples/llava/convert_image_encoder_to_gguf \ +python ./examples/llava/convert_image_encoder_to_gguf.py \ -m path/to/clip-vit-large-patch14-336 \ --llava-projector path/to/MobileVLM-1.7B_V2/llava.projector \ --output-dir path/to/MobileVLM-1.7B_V2 \ @@ -57,12 +57,12 @@ python ./examples/llava/convert_image_encoder_to_gguf \ 4. Use `examples/convert_legacy_llama.py` to convert the LLaMA part of LLaVA to GGUF: ```sh -python ./examples/convert_legacy_llama.py path/to/MobileVLM-1.7B +python ./examples/convert_legacy_llama.py path/to/MobileVLM-1.7B --skip-unknown ``` -5. Use `quantize` to convert LLaMA part's DataType from `fp16` to `q4_k` +5. Use `quantize` to convert LLaMA part's DataType from `fp32` to `q4_k` ```sh -./llama-quantize path/to/MobileVLM-1.7B/ggml-model-f16.gguf path/to/MobileVLM-1.7B/ggml-model-q4_k.gguf q4_k_s +./llama-quantize path/to/MobileVLM-1.7B/ggml-model-F32.gguf path/to/MobileVLM-1.7B/ggml-model-q4_k.gguf q4_k_s ``` Now both the LLaMA part and the image encoder is in the `MobileVLM-1.7B` directory. diff --git a/examples/llava/README-minicpmo2.6.md b/examples/llava/README-minicpmo2.6.md new file mode 100644 index 000000000..8713a43d6 --- /dev/null +++ b/examples/llava/README-minicpmo2.6.md @@ -0,0 +1,46 @@ +## MiniCPM-o 2.6 +Currently, this readme only supports minicpm-omni's image capabilities, and we will update the full-mode support as soon as possible. + +### Prepare models and code + +Download [MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6) PyTorch model from huggingface to "MiniCPM-o-2_6" folder. + +Clone llama.cpp: +```bash +git clone git@github.com:OpenBMB/llama.cpp.git +cd llama.cpp +git checkout minicpm-omni +``` + +### Usage of MiniCPM-o 2.6 + +Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-o-2_6-gguf) by us) + +```bash +python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-o-2_6 +python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-o-2_6 --minicpmv-projector ../MiniCPM-o-2_6/minicpmv.projector --output-dir ../MiniCPM-o-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 4 +python ./convert_hf_to_gguf.py ../MiniCPM-o-2_6/model + +# quantize int4 version +./llama-quantize ../MiniCPM-o-2_6/model/ggml-model-f16.gguf ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M +``` + +Build llama.cpp using `CMake`: +https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md + +```bash +cmake -B build +cmake --build build --config Release +``` + +Inference on Linux or Mac +``` +# run f16 version +./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" + +# run quantized int4 version +./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" + +# or run in interactive mode +./llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i +``` diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index 9b890571e..24073c5a9 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -3,26 +3,31 @@ // I'll gradually clean and extend it // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch #include "clip.h" -#include "log.h" #include "ggml.h" +#include "ggml-cpu.h" #include "ggml-alloc.h" #include "ggml-backend.h" +#include "gguf.h" -#ifdef GGML_USE_CUDA -#include "ggml-cuda.h" -#endif - -#ifdef GGML_USE_METAL -#include "ggml-metal.h" -#endif - -#ifdef GGML_USE_CANN -#include "ggml-cann.h" -#endif - -#ifdef GGML_USE_VULKAN -#include "ggml-vulkan.h" -#endif +//#ifdef GGML_USE_CUDA +//#include "ggml-cuda.h" +//#endif +// +//#ifdef GGML_USE_SYCL +//#include "ggml-sycl.h" +//#endif +// +//#ifdef GGML_USE_METAL +//#include "ggml-metal.h" +//#endif +// +//#ifdef GGML_USE_CANN +//#include "ggml-cann.h" +//#endif +// +//#ifdef GGML_USE_VULKAN +//#include "ggml-vulkan.h" +//#endif #define STB_IMAGE_IMPLEMENTATION #include "stb_image.h" @@ -40,6 +45,18 @@ #include #include +#if defined(LLAVA_LOG_OFF) +# define LOG_INF(...) +# define LOG_WRN(...) +# define LOG_ERR(...) +# define LOG_DBG(...) +#else // defined(LLAVA_LOG_OFF) +# define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0) +# define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0) +# define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0) +# define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0) +#endif // defined(LLAVA_LOG_OFF) + //#define CLIP_DEBUG_FUNCTIONS // RGB uint8 image @@ -86,7 +103,9 @@ static std::string format(const char * fmt, ...) { #define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector" #define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector" #define KEY_MINICPMV_VERSION "clip.minicpmv_version" +#define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger" #define KEY_USE_GELU "clip.use_gelu" +#define KEY_USE_SILU "clip.use_silu" #define KEY_N_EMBD "clip.%s.embedding_length" #define KEY_N_FF "clip.%s.feed_forward_length" #define KEY_N_BLOCK "clip.%s.block_count" @@ -113,7 +132,8 @@ static std::string format(const char * fmt, ...) { #define TN_TOKEN_EMBD "%s.token_embd.weight" #define TN_POS_EMBD "%s.position_embd.weight" #define TN_CLASS_EMBD "v.class_embd" -#define TN_PATCH_EMBD "v.patch_embd.weight" +#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat +#define TN_PATCH_EMBD_1 "v.patch_embd.weight.1" #define TN_PATCH_BIAS "v.patch_embd.bias" #define TN_ATTN_K "%s.blk.%d.attn_k.%s" #define TN_ATTN_Q "%s.blk.%d.attn_q.%s" @@ -147,6 +167,7 @@ enum projector_type { PROJECTOR_TYPE_LDP, PROJECTOR_TYPE_LDPV2, PROJECTOR_TYPE_RESAMPLER, + PROJECTOR_TYPE_MERGER, PROJECTOR_TYPE_UNKNOWN, }; @@ -155,6 +176,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LDP, "ldp" }, { PROJECTOR_TYPE_LDPV2, "ldpv2"}, { PROJECTOR_TYPE_RESAMPLER, "resampler"}, + { PROJECTOR_TYPE_MERGER, "qwen2vl_merger"}, }; @@ -165,7 +187,7 @@ static std::map PROJECTOR_TYPE_NAMES = { static int get_key_idx(const gguf_context * ctx, const char * key) { int i = gguf_find_key(ctx, key); if (i == -1) { - LOG_TEE("key %s not found in file\n", key); + LOG_ERR("key %s not found in file\n", key); throw std::runtime_error(format("Missing required key: %s", key)); } @@ -241,7 +263,7 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { { const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i); int arr_n = gguf_get_arr_n(ctx_gguf, i); - const void * data = gguf_get_arr_data(ctx_gguf, i); + const void * data = arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx_gguf, i); std::stringstream ss; ss << "["; for (int j = 0; j < arr_n; j++) { @@ -270,7 +292,7 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { static void print_tensor_info(const ggml_tensor * tensor, const char * prefix = "") { size_t tensor_size = ggml_nbytes(tensor); - LOG_TEE("%s: n_dims = %d, name = %s, tensor_size=%zu, shape:[%" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "], type = %s\n", + LOG_INF("%s: n_dims = %d, name = %s, tensor_size=%zu, shape:[%" PRId64 ", %" PRId64 ", %" PRId64 ", %" PRId64 "], type = %s\n", prefix, ggml_n_dims(tensor), tensor->name, tensor_size, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], ggml_type_name(tensor->type)); } @@ -288,7 +310,7 @@ static projector_type clip_projector_type_from_string(const std::string & name) static void clip_image_write_image_to_ppm(const clip_image_u8& img, const std::string& filename) { std::ofstream file(filename, std::ios::binary); if (!file.is_open()) { - LOG_TEE("Failed to open file for writing: %s\n", filename.c_str()); + LOG_ERR("Failed to open file for writing: %s\n", filename.c_str()); return; } @@ -307,7 +329,7 @@ static void clip_image_write_image_to_ppm(const clip_image_u8& img, const std::s static void clip_image_save_to_bmp(const clip_image_u8& img, const std::string& filename) { std::ofstream file(filename, std::ios::binary); if (!file.is_open()) { - LOG_TEE("Failed to open file for writing: %s\n", filename.c_str()); + LOG_ERR("Failed to open file for writing: %s\n", filename.c_str()); return; } @@ -447,7 +469,8 @@ struct clip_vision_model { // embeddings struct ggml_tensor * class_embedding; - struct ggml_tensor * patch_embeddings; + struct ggml_tensor * patch_embeddings_0; + struct ggml_tensor * patch_embeddings_1; // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL) struct ggml_tensor * patch_bias; struct ggml_tensor * position_embeddings; @@ -537,6 +560,7 @@ struct clip_ctx { bool has_vision_encoder = false; bool has_llava_projector = false; bool has_minicpmv_projector = false; + bool has_qwen2vl_merger = false; int minicpmv_version = 2; struct clip_vision_model vision_model; @@ -545,6 +569,7 @@ struct clip_ctx { float image_mean[3]; float image_std[3]; bool use_gelu = false; + bool use_silu = false; int32_t ftype = 1; bool has_class_embedding = true; @@ -568,7 +593,7 @@ struct clip_ctx { static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) { if (!ctx->has_vision_encoder) { - LOG_TEE("This gguf file seems to have no vision encoder\n"); + LOG_ERR("This gguf file seems to have no vision encoder\n"); return nullptr; } @@ -582,7 +607,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 if (load_image_size == nullptr) { load_image_size = clip_image_size_init(); } - LOG_TEE("%s: %d %d\n", __func__, load_image_size->width, load_image_size->height); + LOG_DBG("%s: %d %d\n", __func__, load_image_size->width, load_image_size->height); image_size_width = load_image_size->width; image_size_height = load_image_size->height; if (is_inf) { @@ -590,14 +615,26 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 image_size_height = imgs->data->ny; } } + else if (ctx->has_qwen2vl_merger) { + // use the image's native resolution when image is avaible + if (is_inf) { + // if (imgs->data->nx && imgs->data->ny) { + image_size_width = imgs->data->nx; + image_size_height = imgs->data->ny; + } + } const int patch_size = hparams.patch_size; const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); + const int patches_w = image_size_width / patch_size; + const int patches_h = image_size_height / patch_size; const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0); + const int num_position_ids = ctx->has_qwen2vl_merger ? num_positions * 4 : num_positions; const int hidden_size = hparams.hidden_size; const int n_head = hparams.n_head; const int d_head = hidden_size / n_head; int n_layer = hparams.n_layer; const float eps = hparams.eps; + int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; const int batch_size = imgs->size; @@ -618,10 +655,30 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 ggml_set_name(inp_raw, "inp_raw"); ggml_set_input(inp_raw); - struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); - inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size); - inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3)); + if (ctx->has_qwen2vl_merger) { + GGML_ASSERT(image_size_width % (patch_size * 2) == 0); + GGML_ASSERT(image_size_height % (patch_size * 2) == 0); + + auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + inp = ggml_add(ctx0, inp, inp_1); + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b] + inp = ggml_reshape_4d( + ctx0, inp, + hidden_size * 2, patches_w / 2, patches_h, batch_size); + inp = ggml_reshape_4d( + ctx0, inp, + hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2)); + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3)); + inp = ggml_reshape_3d( + ctx0, inp, + hidden_size, patches_w * patches_h, batch_size); + } + else { + inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size); + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3)); + } if (ctx->has_patch_bias) { // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp)); @@ -643,12 +700,14 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 } } - struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions); + struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); ggml_set_name(positions, "positions"); ggml_set_input(positions); - embeddings = - ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions)); + if (!ctx->has_qwen2vl_merger) { // qwen2vl use rope position embedding + embeddings = + ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions)); + } if (ctx->has_minicpmv_projector) { int pos_w = image_size_width/patch_size; @@ -659,6 +718,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 else if (ctx->minicpmv_version == 3) { pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); } + else if (ctx->minicpmv_version == 4) { + pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); + } ggml_set_name(pos_embed, "pos_embed"); ggml_set_input(pos_embed); } @@ -672,7 +734,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 } // loop over layers - if (ctx->has_minicpmv_projector) { + if (ctx->has_minicpmv_projector || ctx->has_qwen2vl_merger) { + // TODO: figure out why we doing thing in this way ??? n_layer += 1; } for (int il = 0; il < n_layer - 1; il++) { @@ -694,8 +757,13 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b); - Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size); + if (ctx->has_qwen2vl_merger) { + Q = ggml_rope_multi( + ctx0, Q, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + } + Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size); @@ -703,6 +771,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b); K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); + if (ctx->has_qwen2vl_merger) { + K = ggml_rope_multi( + ctx0, K, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + } K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size); @@ -742,6 +815,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 if (ctx->use_gelu) { cur = ggml_gelu_inplace(ctx0, cur); + } else if (ctx->use_silu) { + cur = ggml_silu_inplace(ctx0, cur); } else { cur = ggml_gelu_quick_inplace(ctx0, cur); } @@ -753,6 +828,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 cur = ggml_add(ctx0, embeddings, cur); embeddings = cur; + } // post-layernorm @@ -824,7 +900,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 mlp_3 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_3, 1, 0, 2, 3)); mlp_3 = ggml_reshape_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]); // stride = 1, padding = 1, bias is nullptr - block_1 = ggml_conv_depthwise_2d(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1); + block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1); // layer norm // // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1] @@ -872,7 +948,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 // block_2 { // stride = 2 - block_1 = ggml_conv_depthwise_2d(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1); + block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1); // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1] // layer norm @@ -933,7 +1009,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 // mlp_2 ne [24, 24, 2048, 1] mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0); // weight ne = [3, 3, 2048, 1] - struct ggml_tensor * peg_0 = ggml_conv_depthwise_2d(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1); + struct ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1); peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3)); peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b); mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3)); @@ -980,6 +1056,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 n_head = hidden_size/d_head; num_query = 64; } + else if (ctx->minicpmv_version == 4) { + hidden_size = 3584; + n_head = hidden_size/d_head; + num_query = 64; + } struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); @@ -1014,6 +1095,19 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 GGML_ASSERT(false); } } + else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { + embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size); + + embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); + + // GELU activation + embeddings = ggml_gelu(ctx0, embeddings); + + // Second linear layer + embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings); + embeddings = ggml_add(ctx0, embeddings, model.mm_1_b); + } // build the graph ggml_build_forward_expand(gf, embeddings); @@ -1047,21 +1141,21 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { const int idx_name = gguf_find_key(ctx, KEY_NAME); if (idx_name != -1) { // make name optional temporarily as some of the uploaded models missing it due to a bug const std::string name = gguf_get_val_str(ctx, idx_name); - LOG_TEE("%s: model name: %s\n", __func__, name.c_str()); + LOG_INF("%s: model name: %s\n", __func__, name.c_str()); } - LOG_TEE("%s: description: %s\n", __func__, description.c_str()); - LOG_TEE("%s: GGUF version: %d\n", __func__, gguf_get_version(ctx)); - LOG_TEE("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx)); - LOG_TEE("%s: n_tensors: %d\n", __func__, n_tensors); - LOG_TEE("%s: n_kv: %d\n", __func__, n_kv); - LOG_TEE("%s: ftype: %s\n", __func__, ftype_str.c_str()); - LOG_TEE("\n"); + LOG_INF("%s: description: %s\n", __func__, description.c_str()); + LOG_INF("%s: GGUF version: %d\n", __func__, gguf_get_version(ctx)); + LOG_INF("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx)); + LOG_INF("%s: n_tensors: %d\n", __func__, n_tensors); + LOG_INF("%s: n_kv: %d\n", __func__, n_kv); + LOG_INF("%s: ftype: %s\n", __func__, ftype_str.c_str()); + LOG_INF("\n"); } const int n_tensors = gguf_get_n_tensors(ctx); // kv const int n_kv = gguf_get_n_kv(ctx); - LOG_TEE("%s: loaded meta data with %d key-value pairs and %d tensors from %s\n", + LOG_INF("%s: loaded meta data with %d key-value pairs and %d tensors from %s\n", __func__, n_kv, n_tensors, fname); { std::map n_type; @@ -1072,7 +1166,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { n_type[type]++; } - LOG_TEE("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__); + LOG_INF("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__); for (int i = 0; i < n_kv; i++) { const char * name = gguf_get_key(ctx, i); const enum gguf_type type = gguf_get_kv_type(ctx, i); @@ -1088,7 +1182,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { } replace_all(value, "\n", "\\n"); - LOG_TEE("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str()); + LOG_INF("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str()); } // print type counts @@ -1097,7 +1191,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { continue; } - LOG_TEE("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second); + LOG_INF("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second); } } @@ -1112,7 +1206,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { size_t tensor_size = ggml_nbytes(cur); model_size += tensor_size; if (verbosity >= 3) { - LOG_TEE("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu, shape:[%" PRIu64 ", %" PRIu64 ", %" PRIu64 ", %" PRIu64 "], type = %s\n", + LOG_INF("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu, shape:[%" PRIu64 ", %" PRIu64 ", %" PRIu64 ", %" PRIu64 "], type = %s\n", __func__, i, ggml_n_dims(cur), cur->name, tensor_size, offset, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_type_name(type)); } } @@ -1137,29 +1231,34 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { } } -#ifdef GGML_USE_CUDA - new_clip->backend = ggml_backend_cuda_init(0); - LOG_TEE("%s: CLIP using CUDA backend\n", __func__); -#endif - -#ifdef GGML_USE_METAL - new_clip->backend = ggml_backend_metal_init(); - LOG_TEE("%s: CLIP using Metal backend\n", __func__); -#endif - -#ifdef GGML_USE_CANN - new_clip->backend = ggml_backend_cann_init(0); - LOG_TEE("%s: CLIP using CANN backend\n", __func__); -#endif - -#ifdef GGML_USE_VULKAN - new_clip->backend = ggml_backend_vk_init(0); - LOG_TEE("%s: CLIP using Vulkan backend\n", __func__); -#endif +//#ifdef GGML_USE_CUDA +// new_clip->backend = ggml_backend_cuda_init(0); +// LOG_INF("%s: CLIP using CUDA backend\n", __func__); +//#endif +// +//#ifdef GGML_USE_METAL +// new_clip->backend = ggml_backend_metal_init(); +// LOG_INF("%s: CLIP using Metal backend\n", __func__); +//#endif +// +//#ifdef GGML_USE_CANN +// new_clip->backend = ggml_backend_cann_init(0); +// LOG_INF("%s: CLIP using CANN backend\n", __func__); +//#endif +// +//#ifdef GGML_USE_VULKAN +// new_clip->backend = ggml_backend_vk_init(0); +// LOG_INF("%s: CLIP using Vulkan backend\n", __func__); +//#endif +// +//#ifdef GGML_USE_SYCL +// new_clip->backend = ggml_backend_sycl_init(0); +// LOG_INF("%s: CLIP using SYCL backend\n", __func__); +//#endif if (!new_clip->backend) { new_clip->backend = ggml_backend_cpu_init(); - LOG_TEE("%s: CLIP using CPU backend\n", __func__); + LOG_INF("%s: CLIP using CPU backend\n", __func__); } // model size and capabilities @@ -1185,6 +1284,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx); } + idx = gguf_find_key(ctx, KEY_HAS_QWEN2VL_MERGER); + if (idx != -1) { + new_clip->has_qwen2vl_merger = gguf_get_val_bool(ctx, idx); + } // GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search GGML_ASSERT(new_clip->has_vision_encoder); @@ -1193,17 +1296,24 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { idx = get_key_idx(ctx, KEY_USE_GELU); new_clip->use_gelu = gguf_get_val_bool(ctx, idx); + try { + idx = get_key_idx(ctx, KEY_USE_SILU); + new_clip->use_silu = gguf_get_val_bool(ctx, idx); + } catch (std::runtime_error & /*e*/) { + new_clip->use_silu = false; + } + if (verbosity >= 1) { - LOG_TEE("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder); - LOG_TEE("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder); - LOG_TEE("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector); - LOG_TEE("%s: minicpmv_projector: %d\n", __func__, new_clip->has_minicpmv_projector); - LOG_TEE("%s: model size: %.2f MB\n", __func__, model_size / 1024.0 / 1024.0); - LOG_TEE("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0); + LOG_INF("%s: text_encoder: %d\n", __func__, new_clip->has_text_encoder); + LOG_INF("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder); + LOG_INF("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector); + LOG_INF("%s: minicpmv_projector: %d\n", __func__, new_clip->has_minicpmv_projector); + LOG_INF("%s: model size: %.2f MB\n", __func__, model_size / 1024.0 / 1024.0); + LOG_INF("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0); } } - LOG_TEE("%s: params backend buffer size = % 6.2f MB (%i tensors)\n", __func__, model_size / (1024.0 * 1024.0), n_tensors); + LOG_INF("%s: params backend buffer size = % 6.2f MB (%i tensors)\n", __func__, model_size / (1024.0 * 1024.0), n_tensors); // load tensors { @@ -1216,7 +1326,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->ctx_data = ggml_init(params); if (!new_clip->ctx_data) { - LOG_TEE("%s: ggml_init() failed\n", __func__); + LOG_ERR("%s: ggml_init() failed\n", __func__); clip_free(new_clip); gguf_free(ctx); return nullptr; @@ -1224,7 +1334,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { auto fin = std::ifstream(fname, std::ios::binary); if (!fin) { - LOG_TEE("cannot open model file for loading tensors\n"); + LOG_ERR("cannot open model file for loading tensors\n"); clip_free(new_clip); gguf_free(ctx); return nullptr; @@ -1246,7 +1356,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { const size_t offset = gguf_get_data_offset(ctx) + gguf_get_tensor_offset(ctx, i); fin.seekg(offset, std::ios::beg); if (!fin) { - LOG_TEE("%s: failed to seek for tensor %s\n", __func__, name); + LOG_ERR("%s: failed to seek for tensor %s\n", __func__, name); clip_free(new_clip); gguf_free(ctx); return nullptr; @@ -1317,23 +1427,23 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { } if (verbosity >= 2) { - LOG_TEE("\n%s: vision model hparams\n", __func__); - LOG_TEE("image_size %d\n", hparams.image_size); - LOG_TEE("patch_size %d\n", hparams.patch_size); - LOG_TEE("v_hidden_size %d\n", hparams.hidden_size); - LOG_TEE("v_n_intermediate %d\n", hparams.n_intermediate); - LOG_TEE("v_projection_dim %d\n", hparams.projection_dim); - LOG_TEE("v_n_head %d\n", hparams.n_head); - LOG_TEE("v_n_layer %d\n", hparams.n_layer); - LOG_TEE("v_eps %f\n", hparams.eps); - LOG_TEE("v_image_mean %f %f %f\n", new_clip->image_mean[0], new_clip->image_mean[1], new_clip->image_mean[2]); - LOG_TEE("v_image_std %f %f %f\n", new_clip->image_std[0], new_clip->image_std[1], new_clip->image_std[2]); - LOG_TEE("v_image_grid_pinpoints: "); + LOG_INF("\n%s: vision model hparams\n", __func__); + LOG_INF("image_size %d\n", hparams.image_size); + LOG_INF("patch_size %d\n", hparams.patch_size); + LOG_INF("v_hidden_size %d\n", hparams.hidden_size); + LOG_INF("v_n_intermediate %d\n", hparams.n_intermediate); + LOG_INF("v_projection_dim %d\n", hparams.projection_dim); + LOG_INF("v_n_head %d\n", hparams.n_head); + LOG_INF("v_n_layer %d\n", hparams.n_layer); + LOG_INF("v_eps %f\n", hparams.eps); + LOG_INF("v_image_mean %f %f %f\n", new_clip->image_mean[0], new_clip->image_mean[1], new_clip->image_mean[2]); + LOG_INF("v_image_std %f %f %f\n", new_clip->image_std[0], new_clip->image_std[1], new_clip->image_std[2]); + LOG_INF("v_image_grid_pinpoints: "); for (int i = 0; i < 32 && (hparams.image_grid_pinpoints[i] != 0); ++i) { - LOG_TEE("%d ", hparams.image_grid_pinpoints[i]); + LOG_INF("%d ", hparams.image_grid_pinpoints[i]); } - LOG_TEE("\n"); - LOG_TEE("v_mm_patch_merge_type: %s\n", hparams.mm_patch_merge_type); + LOG_INF("\n"); + LOG_INF("v_mm_patch_merge_type: %s\n", hparams.mm_patch_merge_type); } @@ -1368,10 +1478,15 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { } try { - vision_model.patch_embeddings = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD); + vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD); vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v")); } catch(const std::exception& /*e*/) { - LOG_TEE("%s: failed to load vision model tensors\n", __func__); + LOG_ERR("%s: failed to load vision model tensors\n", __func__); + } + try { + vision_model.patch_embeddings_1 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD_1); + } catch(const std::exception& /*e*/) { + new_clip->has_qwen2vl_merger = false; } // LLaVA projection @@ -1400,7 +1515,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { } catch (std::runtime_error & /*e*/) { } try { vision_model.image_newline = get_tensor(new_clip->ctx_data, TN_IMAGE_NEWLINE); - // LOG_TEE("%s: image_newline tensor (llava-1.6) found\n", __func__); + // LOG_INF("%s: image_newline tensor (llava-1.6) found\n", __func__); } catch (std::runtime_error & /*e*/) { } } else if (new_clip->proj_type == PROJECTOR_TYPE_LDP) { // MobileVLM projection @@ -1460,6 +1575,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight")); vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias")); } + else if (new_clip->proj_type == PROJECTOR_TYPE_MERGER) { + vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight")); + vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias")); + vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight")); + vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias")); + } else { std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type]; throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str())); @@ -1498,10 +1619,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_clip->backend)); clip_image_f32_batch batch; batch.size = 1; + batch.data = nullptr; ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch, nullptr, false); ggml_gallocr_reserve(new_clip->compute_alloc, gf); size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0); - LOG_TEE("%s: compute allocated memory: %.2f MB\n", __func__, compute_memory_buffer_size /1024.0/1024.0); + LOG_INF("%s: compute allocated memory: %.2f MB\n", __func__, compute_memory_buffer_size /1024.0/1024.0); } return new_clip; @@ -1511,6 +1633,10 @@ void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size ctx_clip->load_image_size = load_image_size; } +struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip) { + return ctx_clip->load_image_size; +} + struct clip_image_size * clip_image_size_init() { struct clip_image_size * load_image_size = new struct clip_image_size(); load_image_size->width = 448; @@ -1552,7 +1678,7 @@ bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) { int nx, ny, nc; auto * data = stbi_load(fname, &nx, &ny, &nc, 3); if (!data) { - LOG_TEE("%s: failed to load image '%s'\n", __func__, fname); + LOG_ERR("%s: failed to load image '%s'\n", __func__, fname); return false; } build_clip_img_from_data(data, nx, ny, img); @@ -1564,7 +1690,7 @@ bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length int nx, ny, nc; auto * data = stbi_load_from_memory(bytes, bytes_length, &nx, &ny, &nc, 3); if (!data) { - LOG_TEE("%s: failed to decode image bytes\n", __func__); + LOG_ERR("%s: failed to decode image bytes\n", __func__); return false; } build_clip_img_from_data(data, nx, ny, img); @@ -1754,7 +1880,7 @@ static std::pair select_best_resolution(const std::pair & or int downscaled_height = static_cast(original_height * scale); int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height); int wasted_resolution = (width * height) - effective_resolution; - // LOG_TEE("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution); + // LOG_INF("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution); if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) { max_effective_resolution = effective_resolution; min_wasted_resolution = wasted_resolution; @@ -1872,7 +1998,7 @@ static std::vector> uhd_slice_image(const clip_imag const int multiple = fmin(ceil(ratio), max_slice_nums); std::vector> images; - LOG_TEE("%s: multiple %d\n", __func__, multiple); + LOG_INF("%s: multiple %d\n", __func__, multiple); images.push_back(std::vector()); if (multiple <= 1) { @@ -1887,17 +2013,17 @@ static std::vector> uhd_slice_image(const clip_imag clip_image_u8 * source_image = clip_image_u8_init(); bicubic_resize(*img, *source_image, best_size.first, best_size.second); // source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC) - LOG_TEE("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img->nx, img->ny, best_size.first, best_size.second); + LOG_INF("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img->nx, img->ny, best_size.first, best_size.second); images[images.size()-1].push_back(source_image); std::pair best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio); - LOG_TEE("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second); + LOG_INF("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second); auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true); clip_image_u8 * refine_image = clip_image_u8_init(); bicubic_resize(*img, *refine_image, refine_size.first, refine_size.second); - LOG_TEE("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image->nx, refine_image->ny, refine_size.first, refine_size.second); + LOG_INF("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image->nx, refine_image->ny, refine_size.first, refine_size.second); // split_to_patches int width = refine_image->nx; @@ -1923,6 +2049,7 @@ static std::vector> uhd_slice_image(const clip_imag images[images.size()-1].push_back(patch); } } + clip_image_u8_free(refine_image); } return images; } @@ -1954,19 +2081,43 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli int idx = 0; for (size_t i = 0; i < imgs.size(); ++i) { for (size_t j = 0; j < imgs[i].size(); ++j) { - LOG_TEE("%s: %d %d\n", __func__,imgs[i][j]->nx,imgs[i][j]->ny); + LOG_DBG("%s: %d %d\n", __func__,imgs[i][j]->nx,imgs[i][j]->ny); clip_image_f32 * res = clip_image_f32_init(); normalize_image_u8_to_f32(imgs[i][j], res, ctx->image_mean, ctx->image_std); res_imgs->data[idx++] = *res; clip_image_f32_free(res); } } + for (size_t i = 0; i < imgs.size(); ++i) { + for (size_t j = 0; j < imgs[i].size(); ++j) { + if (imgs[i][j] != nullptr) { + clip_image_u8_free(imgs[i][j]); + } + } + } + return true; + } + else if (ctx->has_qwen2vl_merger) { + clip_image_u8 * resized = clip_image_u8_init(); + auto patch_size = clip_patch_size(ctx) * 2; + int nx = ceil((float)img->nx / patch_size) * patch_size; + int ny = ceil((float)img->ny / patch_size) * patch_size; + bicubic_resize(*img, *resized, nx, ny); + + res_imgs->data = new clip_image_f32[1]; + // clip_image_f32 * res = clip_image_f32_init(); + normalize_image_u8_to_f32(resized, res_imgs->data, ctx->image_mean, ctx->image_std); + // res_imgs->data[0] = *res; + res_imgs->size = 1; + + // clip_image_f32_free(res); + clip_image_u8_free(resized); return true; } bool pad_to_square = true; if (!ctx->has_vision_encoder) { - LOG_TEE("This gguf file seems to have no vision encoder\n"); + LOG_ERR("This gguf file seems to have no vision encoder\n"); return false; } auto & params = ctx->vision_model.hparams; @@ -2043,7 +2194,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli } for (size_t i = 0; i < patches.size(); i++) { - // LOG_TEE("patch %d: %d %d\n", i, patches[i]->nx, patches[i]->ny); + // LOG_DBG("patch %d: %d %d\n", i, patches[i]->nx, patches[i]->ny); clip_image_u8_free(patches[i]); } @@ -2152,6 +2303,13 @@ size_t clip_embd_nbytes(const struct clip_ctx * ctx) { return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float); } +size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) { + clip_image_f32 img; + img.nx = img_w; + img.ny = img_h; + return clip_n_patches_by_img(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float); +} + int32_t clip_image_size(const struct clip_ctx * ctx) { return ctx->vision_model.hparams.image_size; } @@ -2173,6 +2331,13 @@ const int32_t * clip_image_grid(const struct clip_ctx * ctx) { } int clip_n_patches(const struct clip_ctx * ctx) { + clip_image_f32 img; + img.nx = ctx->vision_model.hparams.image_size; + img.ny = ctx->vision_model.hparams.image_size; + return clip_n_patches_by_img(ctx, &img); +} + +int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img) { const auto & params = ctx->vision_model.hparams; int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size); @@ -2186,6 +2351,14 @@ int clip_n_patches(const struct clip_ctx * ctx) { else if (ctx->minicpmv_version == 3) { n_patches = 64; } + else if (ctx->minicpmv_version == 4) { + n_patches = 64; + } + } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { + int patch_size = params.patch_size * 2; + int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); + int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0); + n_patches = x_patch * y_patch; } return n_patches; @@ -2279,7 +2452,7 @@ static std::vector> get_2d_sincos_pos_embed(int embed_dim, co bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) { if (!ctx->has_vision_encoder) { - LOG_TEE("This gguf file seems to have no vision encoder\n"); + LOG_ERR("This gguf file seems to have no vision encoder\n"); return false; } @@ -2291,7 +2464,7 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3 bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs, float * vec) { if (!ctx->has_vision_encoder) { - LOG_TEE("This gguf file seems to have no vision encoder\n"); + LOG_ERR("This gguf file seems to have no vision encoder\n"); return false; } @@ -2314,7 +2487,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima const int image_size = hparams.image_size; int image_size_width = image_size; int image_size_height = image_size; - if (ctx->has_minicpmv_projector) { + if (ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger) { image_size_width = imgs->data[0].nx; image_size_height = imgs->data[0].ny; } @@ -2334,7 +2507,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima for (size_t i = 0; i < imgs->size; i++) { const int nx = imgs->data[i].nx; const int ny = imgs->data[i].ny; - if (!ctx->has_minicpmv_projector) { + if (!(ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger)) { GGML_ASSERT(nx == image_size && ny == image_size); } @@ -2360,8 +2533,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); int* positions_data = (int*)malloc(ggml_nbytes(positions)); - int bucket_coords_h[70]; - int bucket_coords_w[70]; + int bucket_coords_h[1024]; + int bucket_coords_w[1024]; for (int i = 0; i < pos_h; i++){ bucket_coords_h[i] = std::floor(70.0*i/pos_h); } @@ -2389,12 +2562,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima else if (ctx->minicpmv_version == 3) { embed_dim = 3584; } + else if (ctx->minicpmv_version == 4) { + embed_dim = 3584; + } auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed)); - for(int i=0;ihas_qwen2vl_merger) { + struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); + + const int pw = image_size_width / patch_size; + const int ph = image_size_height / patch_size; + int* positions_data = (int*)malloc(ggml_nbytes(positions)); + + int ptr = 0; + for (int y = 0; y < ph; y+=2) + { + for (int x = 0; x < pw; x+=2) + { + for (int dy = 0; dy < 2; dy++) { + for (int dx = 0; dx < 2; dx++) { + positions_data[ptr] = y + dy; + positions_data[num_patches + ptr] = x + dx; + positions_data[num_patches * 2 + ptr] = y + dy; + positions_data[num_patches * 3 + ptr] = x + dx; + ptr++; + } + } + } + } + + ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); + free(positions_data); + } + else { struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); int* positions_data = (int*)malloc(ggml_nbytes(positions)); @@ -2423,16 +2626,16 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); free(positions_data); - } - { - struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); - int* patches_data = (int*)malloc(ggml_nbytes(patches)); - for (int i = 0; i < num_patches; i++) { - patches_data[i] = i + 1; + { + struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); + int* patches_data = (int*)malloc(ggml_nbytes(patches)); + for (int i = 0; i < num_patches; i++) { + patches_data[i] = i + 1; + } + ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches)); + free(patches_data); } - ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches)); - free(patches_data); } } @@ -2440,16 +2643,10 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_cpu_set_n_threads(ctx->backend, n_threads); } -#ifdef GGML_USE_METAL - if (ggml_backend_is_metal(ctx->backend)) { - ggml_backend_metal_set_n_cb(ctx->backend, n_threads); - } -#endif - ggml_backend_graph_compute(ctx->backend, gf); // the last node is the embedding tensor - struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 1]; + struct ggml_tensor * embeddings = ggml_graph_node(gf, -1); // copy the embeddings to the location passed by the user ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings)); @@ -2521,7 +2718,7 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i new_type = type; if (new_type >= GGML_TYPE_Q2_K && name.find("embd") != std::string::npos) { new_type = GGML_TYPE_Q8_0; // ggml_get_rows needs non K type - // LOG_TEE("%s: quantizing %s to %s\n", __func__, name.c_str(), ggml_type_name(new_type)); + // LOG_ERR("%s: quantizing %s to %s\n", __func__, name.c_str(), ggml_type_name(new_type)); } const size_t n_elms = ggml_nelements(cur); float * f32_data; @@ -2540,7 +2737,7 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i f32_data = (float *)conv_buf.data(); break; default: - LOG_TEE("Please use an input file in f32 or f16\n"); + LOG_ERR("Please use an input file in f32 or f16\n"); gguf_free(ctx_out); return false; } @@ -2560,14 +2757,15 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i total_size_org += orig_size; total_size_new += new_size; gguf_set_tensor_type(ctx_out, name.c_str(), new_type); - gguf_set_tensor_data(ctx_out, name.c_str(), new_data, new_size); + GGML_ASSERT(gguf_get_tensor_size(ctx_out, gguf_find_tensor(ctx_out, name.c_str())) == new_size); + gguf_set_tensor_data(ctx_out, name.c_str(), new_data); fout.write((const char *)new_data, new_size); size_t pad = GGML_PAD(new_size, gguf_get_alignment(ctx_out)) - new_size; for (size_t j = 0; j < pad; ++j) { fout.put(0); } - LOG_TEE("%s: n_dims = %d | quantize=%d | size = %f MB -> %f MB\n", name.c_str(), ggml_n_dims(cur), quantize, + LOG_INF("%s: n_dims = %d | quantize=%d | size = %f MB -> %f MB\n", name.c_str(), ggml_n_dims(cur), quantize, orig_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); } @@ -2583,8 +2781,8 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i gguf_free(ctx_out); { - LOG_TEE("%s: original size = %8.2f MB\n", __func__, total_size_org / 1024.0 / 1024.0); - LOG_TEE("%s: quantized size = %8.2f MB\n", __func__, total_size_new / 1024.0 / 1024.0); + LOG_INF("%s: original size = %8.2f MB\n", __func__, total_size_org / 1024.0 / 1024.0); + LOG_INF("%s: quantized size = %8.2f MB\n", __func__, total_size_new / 1024.0 / 1024.0); } return true; @@ -2610,6 +2808,12 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { else if (ctx->minicpmv_version == 3) { return 3584; } + else if (ctx->minicpmv_version == 4) { + return 3584; + } + } + if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { + return ctx->vision_model.mm_1_b->ne[0]; } std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type]; @@ -2622,3 +2826,21 @@ int clip_is_minicpmv(const struct clip_ctx * ctx) { } return 0; } + +bool clip_is_qwen2vl(const struct clip_ctx * ctx) { + return ctx->has_qwen2vl_merger; +} + + +bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) { + clip_image_f32 clip_img; + clip_img.buf.resize(h * w * 3); + for (int i = 0; i < h*w*3; i++) + { + clip_img.buf[i] = img[i]; + } + clip_img.nx = w; + clip_img.ny = h; + clip_image_encode(ctx, n_threads, &clip_img, vec); + return true; +} diff --git a/examples/llava/clip.h b/examples/llava/clip.h index 78588bdf1..1603edd26 100644 --- a/examples/llava/clip.h +++ b/examples/llava/clip.h @@ -45,6 +45,7 @@ CLIP_API struct clip_ctx * clip_model_load_cpu(const char * fname, int verbosity CLIP_API void clip_free(struct clip_ctx * ctx); CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx); +CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w); CLIP_API int32_t clip_image_size (const struct clip_ctx * ctx); CLIP_API int32_t clip_patch_size (const struct clip_ctx * ctx); @@ -55,11 +56,13 @@ CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx); CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx); -CLIP_API int clip_n_patches (const struct clip_ctx * ctx); -CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx); +CLIP_API int clip_n_patches (const struct clip_ctx * ctx); +CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img); +CLIP_API int clip_n_mmproj_embd (const struct clip_ctx * ctx); CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip); CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size); +CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip); CLIP_API struct clip_image_size * clip_image_size_init(); CLIP_API struct clip_image_u8 * clip_image_u8_init (); @@ -86,6 +89,9 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype); CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx); +CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx); + +CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec); #ifdef __cplusplus } diff --git a/examples/llava/convert_image_encoder_to_gguf.py b/examples/llava/convert_image_encoder_to_gguf.py index 36f6b92fb..4fa1d6cea 100644 --- a/examples/llava/convert_image_encoder_to_gguf.py +++ b/examples/llava/convert_image_encoder_to_gguf.py @@ -274,7 +274,7 @@ fout.add_bool("clip.use_gelu", use_gelu) if has_llava_projector: - model.vision_model.encoder.layers.pop(-1) # pyright: ignore[reportAttributeAccessIssue] + model.vision_model.encoder.layers.pop(-1) projector = torch.load(args.llava_projector) for name, data in projector.items(): name = get_tensor_name(name) @@ -288,7 +288,7 @@ if has_llava_projector: print("Projector tensors added\n") -state_dict = model.state_dict() # pyright: ignore[reportAttributeAccessIssue] +state_dict = model.state_dict() for name, data in state_dict.items(): if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_llava_projector): # we don't need this diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 5845d0106..40aa0876f 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -1,14 +1,16 @@ -#include "ggml.h" +#include "arg.h" +#include "base64.hpp" #include "log.h" #include "common.h" +#include "sampling.h" #include "clip.h" #include "llava.h" #include "llama.h" - -#include "base64.hpp" +#include "ggml.h" #include #include +#include #include static bool eval_tokens(struct llama_context * ctx_llama, std::vector tokens, int n_batch, int * n_past) { @@ -18,8 +20,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector n_batch) { n_eval = n_batch; } - if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) { - LOG_TEE("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); + if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) { + LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); return false; } *n_past += n_eval; @@ -35,21 +37,25 @@ static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) { static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos){ std::string str2 = str; - std::vector embd_inp = ::llama_tokenize(ctx_llama, str2, add_bos, true); + std::vector embd_inp = common_tokenize(ctx_llama, str2, add_bos, true); eval_tokens(ctx_llama, embd_inp, n_batch, n_past); return true; } -static const char * sample(struct gpt_sampler * smpl, +static const char * sample(struct common_sampler * smpl, struct llama_context * ctx_llama, int * n_past) { - const llama_token id = gpt_sampler_sample(smpl, ctx_llama, -1); - gpt_sampler_accept(smpl, id, true); + const llama_token id = common_sampler_sample(smpl, ctx_llama, -1); + common_sampler_accept(smpl, id, true); + + const llama_model * model = llama_get_model(ctx_llama); + const llama_vocab * vocab = llama_model_get_vocab(model); + static std::string ret; - if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { + if (llama_vocab_is_eog(vocab, id)) { ret = ""; } else { - ret = llama_token_to_piece(ctx_llama, id); + ret = common_token_to_piece(ctx_llama, id); } eval_id(ctx_llama, id, n_past); return ret.c_str(); @@ -74,7 +80,7 @@ static llava_image_embed * llava_image_embed_make_with_prompt_base64(struct clip size_t img_base64_str_start, img_base64_str_end; find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end); if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) { - LOG_TEE("%s: invalid base64 image tag. must be %s%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END); + LOG_ERR("%s: invalid base64 image tag. must be %s%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END); return NULL; } @@ -88,7 +94,7 @@ static llava_image_embed * llava_image_embed_make_with_prompt_base64(struct clip auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, img_bytes.data(), img_bytes.size()); if (!embed) { - LOG_TEE("%s: could not load image from base64 string.\n", __func__); + LOG_ERR("%s: could not load image from base64 string.\n", __func__); return NULL; } @@ -113,23 +119,23 @@ struct llava_context { }; static void print_usage(int, char ** argv) { - LOG_TEE("\n example usage:\n"); - LOG_TEE("\n %s -m --mmproj --image --image [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]); - LOG_TEE("\n note: a lower temperature value like 0.1 is recommended for better quality.\n"); + LOG("\n example usage:\n"); + LOG("\n %s -m --mmproj --image --image [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]); + LOG("\n note: a lower temperature value like 0.1 is recommended for better quality.\n"); } -static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params, const std::string & fname) { +static struct llava_image_embed * load_image(llava_context * ctx_llava, common_params * params, const std::string & fname) { // load and preprocess the image llava_image_embed * embed = NULL; auto prompt = params->prompt; if (prompt_contains_image(prompt)) { if (!params->image.empty()) { - LOG_TEE("using base64 encoded image instead of command line image path\n"); + LOG_INF("using base64 encoded image instead of command line image path\n"); } embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->cpuparams.n_threads, prompt); if (!embed) { - LOG_TEE("%s: can't load image from prompt\n", __func__); + LOG_ERR("%s: can't load image from prompt\n", __func__); return NULL; } params->prompt = remove_image_from_prompt(prompt); @@ -144,7 +150,7 @@ static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_para return embed; } -static void process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, gpt_params * params, const std::string & prompt) { +static void process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, common_params * params, const std::string & prompt) { int n_past = 0; const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict; @@ -155,18 +161,18 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ // new templating mode: Provide the full prompt including system message and use as a placeholder for the image system_prompt = prompt.substr(0, image_pos); user_prompt = prompt.substr(image_pos + std::string("").length()); - LOG_TEE("system_prompt: %s\n", system_prompt.c_str()); + LOG_INF("system_prompt: %s\n", system_prompt.c_str()); if (params->verbose_prompt) { - auto tmp = ::llama_tokenize(ctx_llava->ctx_llama, system_prompt, true, true); + auto tmp = common_tokenize(ctx_llava->ctx_llama, system_prompt, true, true); for (int i = 0; i < (int) tmp.size(); i++) { - LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str()); + LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str()); } } - LOG_TEE("user_prompt: %s\n", user_prompt.c_str()); + LOG_INF("user_prompt: %s\n", user_prompt.c_str()); if (params->verbose_prompt) { - auto tmp = ::llama_tokenize(ctx_llava->ctx_llama, user_prompt, true, true); + auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true); for (int i = 0; i < (int) tmp.size(); i++) { - LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str()); + LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str()); } } } else { @@ -174,9 +180,9 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:"; user_prompt = prompt + "\nASSISTANT:"; if (params->verbose_prompt) { - auto tmp = ::llama_tokenize(ctx_llava->ctx_llama, user_prompt, true, true); + auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true); for (int i = 0; i < (int) tmp.size(); i++) { - LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str()); + LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str()); } } } @@ -187,11 +193,11 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ // generate the response - LOG_TEE("\n"); + LOG("\n"); - struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams); + struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling); if (!smpl) { - fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); + LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__); exit(1); } @@ -201,7 +207,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ response += tmp; if (strcmp(tmp, "") == 0) break; if (strstr(tmp, "###")) break; // Yi-VL behavior - printf("%s", tmp); + LOG("%s", tmp); if (strstr(response.c_str(), "<|im_end|>")) break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works) if (strstr(response.c_str(), "<|im_start|>")) break; // Yi-34B llava-1.6 if (strstr(response.c_str(), "USER:")) break; // mistral llava-1.6 @@ -209,25 +215,25 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ fflush(stdout); } - gpt_sampler_free(smpl); - printf("\n"); + common_sampler_free(smpl); + LOG("\n"); } -static struct llama_model * llava_init(gpt_params * params) { +static struct llama_model * llava_init(common_params * params) { llama_backend_init(); llama_numa_init(params->numa); - llama_model_params model_params = llama_model_params_from_gpt_params(*params); + llama_model_params model_params = common_model_params_to_llama(*params); - llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params->model.c_str(), model_params); if (model == NULL) { - LOG_TEE("%s: error: unable to load model\n" , __func__); + LOG_ERR("%s: unable to load model\n" , __func__); return NULL; } return model; } -static struct llava_context * llava_init_context(gpt_params * params, llama_model * model) { +static struct llava_context * llava_init_context(common_params * params, llama_model * model) { const char * clip_path = params->mmproj.c_str(); auto prompt = params->prompt; @@ -237,18 +243,17 @@ static struct llava_context * llava_init_context(gpt_params * params, llama_mode auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); - - llama_context_params ctx_params = llama_context_params_from_gpt_params(*params); + llama_context_params ctx_params = common_context_params_to_llama(*params); ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings - llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params); + llama_context * ctx_llama = llama_init_from_model(model, ctx_params); if (ctx_llama == NULL) { - LOG_TEE("%s: error: failed to create the llama_context\n" , __func__); + LOG_ERR("%s: failed to create the llama_context\n" , __func__); return NULL; } - auto ctx_llava = (struct llava_context *)malloc(sizeof(llava_context)); + auto * ctx_llava = (struct llava_context *)malloc(sizeof(llava_context)); ctx_llava->ctx_llama = ctx_llama; ctx_llava->ctx_clip = ctx_clip; @@ -263,76 +268,65 @@ static void llava_free(struct llava_context * ctx_llava) { } llama_free(ctx_llava->ctx_llama); - llama_free_model(ctx_llava->model); + llama_model_free(ctx_llava->model); llama_backend_free(); } -static void llama_log_callback_logTee(ggml_log_level level, const char * text, void * user_data) { - (void) level; - (void) user_data; - LOG_TEE("%s", text); -} - int main(int argc, char ** argv) { ggml_time_init(); - gpt_params params; + common_params params; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_LLAVA, print_usage); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, print_usage)) { return 1; } -#ifndef LOG_DISABLE_LOGS - log_set_target(log_filename_generator("llava", "log")); - LOG_TEE("Log start\n"); - log_dump_cmdline(argc, argv); - llama_log_set(llama_log_callback_logTee, nullptr); -#endif // LOG_DISABLE_LOGS + common_init(); if (params.mmproj.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) { print_usage(argc, argv); return 1; } - auto model = llava_init(¶ms); + + auto * model = llava_init(¶ms); if (model == NULL) { fprintf(stderr, "%s: error: failed to init llava model\n", __func__); return 1; } if (prompt_contains_image(params.prompt)) { - auto ctx_llava = llava_init_context(¶ms, model); + auto * ctx_llava = llava_init_context(¶ms, model); - auto image_embed = load_image(ctx_llava, ¶ms, ""); + auto * image_embed = load_image(ctx_llava, ¶ms, ""); // process the prompt process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); - llama_perf_print(ctx_llava->ctx_llama, LLAMA_PERF_TYPE_CONTEXT); + llama_perf_context_print(ctx_llava->ctx_llama); llava_image_embed_free(image_embed); ctx_llava->model = NULL; llava_free(ctx_llava); } else { for (auto & image : params.image) { - auto ctx_llava = llava_init_context(¶ms, model); + auto * ctx_llava = llava_init_context(¶ms, model); - auto image_embed = load_image(ctx_llava, ¶ms, image); + auto * image_embed = load_image(ctx_llava, ¶ms, image); if (!image_embed) { - std::cerr << "error: failed to load image " << image << ". Terminating\n\n"; + LOG_ERR("%s: failed to load image %s. Terminating\n\n", __func__, image.c_str()); return 1; } // process the prompt process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); - llama_perf_print(ctx_llava->ctx_llama, LLAMA_PERF_TYPE_CONTEXT); + llama_perf_context_print(ctx_llava->ctx_llama); llava_image_embed_free(image_embed); ctx_llava->model = NULL; llava_free(ctx_llava); } } - llama_free_model(model); + llama_model_free(model); return 0; } diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 851af0f00..2cac7933d 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -1,13 +1,27 @@ #include "clip.h" -#include "common.h" -#include "llama.h" #include "llava.h" -#include "base64.hpp" +#include "llama.h" + +#include +#include #include #include +#include +#include #include -#include + +#if defined(LLAVA_LOG_OFF) +# define LOG_INF(...) +# define LOG_WRN(...) +# define LOG_ERR(...) +# define LOG_DBG(...) +#else // defined(LLAVA_LOG_OFF) +# define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0) +# define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0) +# define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0) +# define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0) +#endif // defined(LLAVA_LOG_OFF) // RGB uint8 image struct clip_image_u8 { @@ -54,7 +68,7 @@ static std::pair select_best_resolution(const std::pair& ori int downscaled_height = static_cast(original_height * scale); int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height); int wasted_resolution = (width * height) - effective_resolution; - // LOG_TEE("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution); + // LOG_DBG("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution); if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) { max_effective_resolution = effective_resolution; min_wasted_resolution = wasted_resolution; @@ -184,7 +198,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector // ggml_tensor_printf(flatten,"flatten",__LINE__,false,false); ggml_build_forward_expand(gf, flatten); ggml_graph_compute_with_ctx(model.ctx, gf, 1); - struct ggml_tensor* result = gf->nodes[gf->n_nodes - 1]; + struct ggml_tensor* result = ggml_graph_node(gf, -1); memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as global context // append without newline tokens (default behavior in llava_arch when not using unpad ): @@ -202,7 +216,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector return true; } -static clip_image_f32 * only_v2_5_reshape_by_patch(clip_image_f32 * image, int patch_size) { +static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) { int width = image->nx; int height = image->ny; int num_patches = (height / patch_size) * (width / patch_size); @@ -236,7 +250,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli img_res_v.size = 0; img_res_v.data = nullptr; if (!clip_image_preprocess(ctx_clip, img, &img_res_v)) { - LOG_TEE("%s: unable to preprocess image\n", __func__); + LOG_ERR("%s: unable to preprocess image\n", __func__); delete[] img_res_v.data; return false; } @@ -245,39 +259,44 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip); - if (clip_is_minicpmv(ctx_clip)) { + if (clip_is_minicpmv(ctx_clip) || clip_is_qwen2vl(ctx_clip)) { std::vector image_embd_v; image_embd_v.resize(img_res_v.size); struct clip_image_size * load_image_size = clip_image_size_init(); + for (size_t i = 0; i < img_res_v.size; i++) { const int64_t t_img_enc_step_start_us = ggml_time_us(); - image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip)); + image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny)); int patch_size=14; load_image_size->width = img_res_v.data[i].nx; load_image_size->height = img_res_v.data[i].ny; clip_add_load_image_size(ctx_clip, load_image_size); + bool encoded = false; - int has_minicpmv_projector = clip_is_minicpmv(ctx_clip); - if (has_minicpmv_projector == 2) { - encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); - } - else if (has_minicpmv_projector == 3) { + if (clip_is_qwen2vl(ctx_clip)) { encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); } + else { + encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); + } + if (!encoded) { - LOG_TEE("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size); + LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size); return false; } const int64_t t_img_enc_steop_batch_us = ggml_time_us(); - LOG_TEE("%s: step %d of %d encoded in %8.2f ms\n", __func__, (int)i+1, (int)img_res_v.size, (t_img_enc_steop_batch_us - t_img_enc_step_start_us) / 1000.0); + LOG_INF("%s: step %d of %d encoded in %8.2f ms\n", __func__, (int)i+1, (int)img_res_v.size, (t_img_enc_steop_batch_us - t_img_enc_step_start_us) / 1000.0); } const int64_t t_img_enc_batch_us = ggml_time_us(); - LOG_TEE("%s: all %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); + LOG_INF("%s: all %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); int n_img_pos_out = 0; for (size_t i = 0; i < image_embd_v.size(); i++) { - std::memcpy(image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip), image_embd_v[i], clip_embd_nbytes(ctx_clip)); - n_img_pos_out += clip_n_patches(ctx_clip); + std::memcpy( + image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip), + image_embd_v[i], + clip_embd_nbytes_by_img(ctx_clip, img_res_v.data[i].nx, img_res_v.data[i].ny)); + n_img_pos_out += clip_n_patches_by_img(ctx_clip, &img_res_v.data[i]); } *n_img_pos = n_img_pos_out; for (size_t i = 0; i < image_embd_v.size(); i++) { @@ -287,7 +306,10 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli load_image_size->width = img->nx; load_image_size->height = img->ny; clip_add_load_image_size(ctx_clip, load_image_size); - LOG_TEE("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height); + LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height); + delete[] img_res_v.data; + img_res_v.size = 0; + img_res_v.data = nullptr; } else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { // flat / default llava-1.5 type embedding @@ -295,7 +317,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); // image_embd shape is 576 x 4096 delete[] img_res_v.data; if (!encoded) { - LOG_TEE("Unable to encode image\n"); + LOG_ERR("Unable to encode image\n"); return false; } @@ -309,12 +331,12 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip)); // 576 patches * 4096 embeddings * 4 bytes = 9437184 const bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside if (!encoded) { - LOG_TEE("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size); + LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size); return false; } } const int64_t t_img_enc_batch_us = ggml_time_us(); - LOG_TEE("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); + LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); const int32_t * image_grid = clip_image_grid(ctx_clip); @@ -347,22 +369,22 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli // clip_image_save_to_bmp(*tmp, "image_feature.bmp"); } - LOG_TEE("%s: image embedding created: %d tokens\n", __func__, *n_img_pos); + LOG_INF("%s: image embedding created: %d tokens\n", __func__, *n_img_pos); const int64_t t_img_enc_end_us = ggml_time_us(); float t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0; - LOG_TEE("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / *n_img_pos); + LOG_INF("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / *n_img_pos); return true; } bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) { // make sure that the correct mmproj was used, i.e., compare apples to apples - int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama)); + int n_llama_embd = llama_model_n_embd(llama_get_model(ctx_llama)); auto n_image_embd = clip_n_mmproj_embd(ctx_clip); if (n_image_embd != n_llama_embd) { - LOG_TEE("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd); + LOG_ERR("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd); return false; } return true; @@ -373,15 +395,21 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co if (clip_is_minicpmv(ctx_clip)) { num_max_patches = 10; } - float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*num_max_patches); // TODO: base on gridsize/llava model + float * image_embd; + if (clip_is_qwen2vl(ctx_clip)) { + // qwen2vl don't split image into chunks, so `num_max_patches` is not needed. + image_embd = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, img->nx, img->ny)); + } else { + image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*num_max_patches); // TODO: base on gridsize/llava model + } if (!image_embd) { - LOG_TEE("Unable to allocate memory for image embeddings\n"); + LOG_ERR("Unable to allocate memory for image embeddings\n"); return false; } int n_img_pos; if (!encode_image_with_clip(ctx_clip, n_threads, img, image_embd, &n_img_pos)) { - LOG_TEE("%s: cannot encode image, aborting\n", __func__); + LOG_ERR("%s: cannot encode image, aborting\n", __func__); free(image_embd); return false; } @@ -391,17 +419,51 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co return true; } +struct llava_embd_batch { + std::vector pos; + std::vector n_seq_id; + std::vector seq_id_0; + std::vector seq_ids; + std::vector logits; + llama_batch batch; + llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { + pos .resize(n_tokens); + n_seq_id.resize(n_tokens); + seq_ids .resize(n_tokens + 1); + logits .resize(n_tokens); + seq_id_0.resize(1); + seq_id_0[0] = seq_id; + seq_ids [n_tokens] = nullptr; + batch = { + /*n_tokens =*/ n_tokens, + /*tokens =*/ nullptr, + /*embd =*/ embd, + /*pos =*/ pos.data(), + /*n_seq_id =*/ n_seq_id.data(), + /*seq_id =*/ seq_ids.data(), + /*logits =*/ logits.data(), + }; + for (int i = 0; i < n_tokens; i++) { + batch.pos [i] = pos_0 + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i] = seq_id_0.data(); + batch.logits [i] = false; + } + } +}; + bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) { - int n_embd = llama_n_embd(llama_get_model(ctx_llama)); + int n_embd = llama_model_n_embd(llama_get_model(ctx_llama)); for (int i = 0; i < image_embed->n_image_pos; i += n_batch) { int n_eval = image_embed->n_image_pos - i; if (n_eval > n_batch) { n_eval = n_batch; } - llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, }; - if (llama_decode(ctx_llama, batch)) { - LOG_TEE("%s : failed to eval\n", __func__); + float * embd = image_embed->embed+i*n_embd; + llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0); + if (llama_decode(ctx_llama, llava_batch.batch)) { + LOG_ERR("%s : failed to eval\n", __func__); return false; } *n_past += n_eval; @@ -413,7 +475,7 @@ struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * c clip_image_u8 * img = clip_image_u8_init(); if (!clip_image_load_from_bytes(image_bytes, image_bytes_length, img)) { clip_image_u8_free(img); - LOG_TEE("%s: can't load image from bytes, is it a valid image?", __func__); + LOG_ERR("%s: can't load image from bytes, is it a valid image?", __func__); return NULL; } @@ -422,7 +484,7 @@ struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * c bool image_embed_result = llava_image_embed_make_with_clip_img(ctx_clip, n_threads, img, &image_embed, &n_image_pos); if (!image_embed_result) { clip_image_u8_free(img); - LOG_TEE("%s: coulnd't embed the image\n", __func__); + LOG_ERR("%s: couldn't embed the image\n", __func__); return NULL; } @@ -436,7 +498,7 @@ struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * c static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long *sizeOut) { auto file = fopen(path, "rb"); if (file == NULL) { - LOG_TEE("%s: can't read file %s\n", __func__, path); + LOG_ERR("%s: can't read file %s\n", __func__, path); return false; } @@ -446,7 +508,7 @@ static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long auto buffer = (unsigned char *)malloc(fileSize); // Allocate memory to hold the file data if (buffer == NULL) { - LOG_TEE("%s: failed to alloc %ld bytes for file %s\n", __func__, fileSize, path); + LOG_ERR("%s: failed to alloc %ld bytes for file %s\n", __func__, fileSize, path); perror("Memory allocation error"); fclose(file); return false; @@ -454,10 +516,16 @@ static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long errno = 0; size_t ret = fread(buffer, 1, fileSize, file); // Read the file into the buffer if (ferror(file)) { - die_fmt("read error: %s", strerror(errno)); + LOG_ERR("read error: %s", strerror(errno)); + free(buffer); + fclose(file); + return false; } if (ret != (size_t) fileSize) { - die("unexpectedly reached end of file"); + LOG_ERR("unexpectedly reached end of file"); + free(buffer); + fclose(file); + return false; } fclose(file); // Close the file @@ -471,7 +539,7 @@ struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx long image_bytes_length; auto loaded = load_file_to_bytes(image_path, &image_bytes, &image_bytes_length); if (!loaded) { - LOG_TEE("%s: failed to load %s\n", __func__, image_path); + LOG_ERR("%s: failed to load %s\n", __func__, image_path); return NULL; } diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index 57e7d42c5..53d902d61 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -1,13 +1,18 @@ -#include "ggml.h" +#include "arg.h" #include "log.h" #include "common.h" +#include "sampling.h" #include "clip.h" #include "llava.h" #include "llama.h" +#include "ggml.h" +#include #include #include +#include #include +#include // TODO: remove me struct llava_context { struct clip_ctx * ctx_clip = NULL; @@ -16,53 +21,47 @@ struct llava_context { }; static void show_additional_info(int /*argc*/, char ** argv) { - LOG_TEE("\n example usage: %s -m --mmproj --image --image [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]); - LOG_TEE(" note: a lower temperature value like 0.1 is recommended for better quality.\n"); + LOG("\nexample usage:\n\n%s -m --mmproj --image --image [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]); + LOG("\nnote: a lower temperature value like 0.1 is recommended for better quality.\n"); } -static void llama_log_callback_logTee(ggml_log_level level, const char * text, void * user_data) { - (void) level; - (void) user_data; - LOG_TEE("%s", text); -} - -static struct llama_model * llava_init(gpt_params * params) { +static struct llama_model * llava_init(common_params * params) { llama_backend_init(); llama_numa_init(params->numa); - llama_model_params model_params = llama_model_params_from_gpt_params(*params); + llama_model_params model_params = common_model_params_to_llama(*params); - llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params->model.c_str(), model_params); if (model == NULL) { - LOG_TEE("%s: error: unable to load model\n" , __func__); + LOG_ERR("%s: unable to load model\n" , __func__); return NULL; } return model; } -static struct llava_context * llava_init_context(gpt_params * params, llama_model * model) { +static struct llava_context * llava_init_context(common_params * params, llama_model * model) { auto prompt = params->prompt; if (prompt.empty()) { prompt = "describe the image in detail."; } - llama_context_params ctx_params = llama_context_params_from_gpt_params(*params); + llama_context_params ctx_params = common_context_params_to_llama(*params); if (params->n_ctx < 2048) { // warn user here, "Image processing requires at least 2048 context, setting context to 2048" - LOG_TEE("%s: warn: Image processing requires at least 2048 context, setting context to 2048\n" , __func__); + LOG_WRN("%s: Image processing requires at least 2048 context, setting context to 2048\n" , __func__); ctx_params.n_ctx = 2048; } else { ctx_params.n_ctx = params->n_ctx; } - llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params); + llama_context * ctx_llama = llama_init_from_model(model, ctx_params); if (ctx_llama == NULL) { - LOG_TEE("%s: error: failed to create the llama_context\n" , __func__); + LOG_ERR("%s: failed to create the llama_context\n" , __func__); return NULL; } - auto ctx_llava = (struct llava_context *)malloc(sizeof(llava_context)); + auto * ctx_llava = (struct llava_context *)malloc(sizeof(llava_context)); ctx_llava->ctx_llama = ctx_llama; ctx_llava->model = model; @@ -76,18 +75,18 @@ static void llava_free(struct llava_context * ctx_llava) { } llama_free(ctx_llava->ctx_llama); - llama_free_model(ctx_llava->model); + llama_model_free(ctx_llava->model); llama_backend_free(); } -static struct clip_ctx * clip_init_context(gpt_params * params) { +static struct clip_ctx * clip_init_context(common_params * params) { const char * clip_path = params->mmproj.c_str(); auto prompt = params->prompt; if (prompt.empty()) { prompt = "describe the image in detail."; } - auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); + auto * ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); return ctx_clip; } @@ -98,8 +97,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector n_batch) { n_eval = n_batch; } - if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) { - LOG_TEE("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); + if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) { + LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); return false; } *n_past += n_eval; @@ -115,7 +114,7 @@ static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) { static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos){ std::string str2 = str; - std::vector embd_inp = ::llama_tokenize(ctx_llama, str2, add_bos, true); + std::vector embd_inp = common_tokenize(ctx_llama, str2, add_bos, true); return eval_tokens(ctx_llama, embd_inp, n_batch, n_past); } @@ -123,14 +122,14 @@ static void process_eval_image_embed(struct llava_context * ctx_llava, const str float * image_embed = (float *)malloc(clip_embd_nbytes(ctx_llava->ctx_clip)); std::memcpy(image_embed, embeds->embed + idx * clip_n_patches(ctx_llava->ctx_clip) * clip_n_mmproj_embd(ctx_llava->ctx_clip), clip_embd_nbytes(ctx_llava->ctx_clip)); - auto slice_embed = (llava_image_embed*)malloc(sizeof(llava_image_embed)); + auto * slice_embed = (llava_image_embed*)malloc(sizeof(llava_image_embed)); slice_embed->embed = image_embed; slice_embed->n_image_pos = clip_n_patches(ctx_llava->ctx_clip); llava_eval_image_embed(ctx_llava->ctx_llama, slice_embed, n_batch, n_past); llava_image_embed_free(slice_embed); } -static void process_image(struct llava_context * ctx_llava, struct llava_image_embed * embeds, gpt_params * params, int &n_past) { +static void process_image(struct llava_context * ctx_llava, struct llava_image_embed * embeds, common_params * params, int &n_past) { std::string system_prompt; int idx = 0; int num_image_embeds = embeds->n_image_pos / clip_n_patches(ctx_llava->ctx_clip); @@ -141,7 +140,10 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e else if (has_minicpmv_projector == 3) { system_prompt = "<|im_start|>user\n"; } - LOG_TEE("%s: image token past: %d\n", __func__, n_past); + else if (has_minicpmv_projector == 4) { + system_prompt = "<|im_start|>user\n"; + } + LOG_INF("%s: image token past: %d\n", __func__, n_past); eval_string(ctx_llava->ctx_llama, (system_prompt+"").c_str(), params->n_batch, &n_past, false); process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++); eval_string(ctx_llava->ctx_llama, std::string("").c_str(), params->n_batch, &n_past, false); @@ -160,61 +162,65 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e } eval_string(ctx_llava->ctx_llama, std::string("").c_str(), params->n_batch, &n_past, false); } - LOG_TEE("%s: image token past: %d\n", __func__, n_past); + LOG_INF("%s: image token past: %d\n", __func__, n_past); } -static const char * sample(struct gpt_sampler * smpl, +static const char * sample(struct common_sampler * smpl, struct llama_context * ctx_llama, int * n_past) { - const llama_token id = gpt_sampler_sample(smpl, ctx_llama, -1); - gpt_sampler_accept(smpl, id, true); + const llama_token id = common_sampler_sample(smpl, ctx_llama, -1); + common_sampler_accept(smpl, id, true); + + const llama_model * model = llama_get_model(ctx_llama); + const llama_vocab * vocab = llama_model_get_vocab(model); + static std::string ret; - if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { + if (llama_vocab_is_eog(vocab, id)) { ret = ""; } else { - ret = llama_token_to_piece(ctx_llama, id); + ret = common_token_to_piece(ctx_llama, id); } eval_id(ctx_llama, id, n_past); return ret.c_str(); } -static struct llava_context * minicpmv_init(gpt_params * params, const std::string & fname, int &n_past){ - auto ctx_clip = clip_init_context(params); - auto embeds = llava_image_embed_make_with_filename(ctx_clip, params->cpuparams.n_threads, fname.c_str()); +static struct llava_context * minicpmv_init(common_params * params, const std::string & fname, int &n_past){ + auto * ctx_clip = clip_init_context(params); + auto * embeds = llava_image_embed_make_with_filename(ctx_clip, params->cpuparams.n_threads, fname.c_str()); if (!embeds) { - std::cerr << "error: failed to load image " << fname << ". Terminating\n\n"; + LOG_ERR("failed to load image %s. Terminating\n\n", fname.c_str()); return NULL; } // process the prompt if (params->prompt.empty() && params->interactive == false) { - LOG_TEE("prompt should be given or interactive mode should be on"); + LOG_ERR("prompt should be given or interactive mode should be on"); return NULL; } - auto model = llava_init(params); + auto * model = llava_init(params); if (model == NULL) { fprintf(stderr, "%s: error: failed to init minicpmv model\n", __func__); return NULL; } const int64_t t_llava_init_start_us = ggml_time_us(); - auto ctx_llava = llava_init_context(params, model); + auto * ctx_llava = llava_init_context(params, model); ctx_llava->ctx_clip = ctx_clip; const int64_t t_llava_init_end_us = ggml_time_us(); float t_llava_init_ms = (t_llava_init_end_us - t_llava_init_start_us) / 1000.0; - LOG_TEE("\n%s: llava init in %8.2f ms.\n", __func__, t_llava_init_ms); + LOG_INF("%s: llava init in %8.2f ms.\n", __func__, t_llava_init_ms); const int64_t t_process_image_start_us = ggml_time_us(); process_image(ctx_llava, embeds, params, n_past); const int64_t t_process_image_end_us = ggml_time_us(); float t_process_image_ms = (t_process_image_end_us - t_process_image_start_us) / 1000.0; - LOG_TEE("\n%s: llama process image in %8.2f ms.\n", __func__, t_process_image_ms); + LOG_INF("%s: llama process image in %8.2f ms.\n", __func__, t_process_image_ms); llava_image_embed_free(embeds); return ctx_llava; } -static struct gpt_sampler * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){ +static struct common_sampler * llama_init(struct llava_context * ctx_llava, common_params * params, const std::string & prompt, int & n_past, bool is_first = false){ std::string user_prompt = prompt; int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip); if (!is_first) { @@ -224,6 +230,9 @@ static struct gpt_sampler * llama_init(struct llava_context * ctx_llava, gpt_par else if (has_minicpmv_projector == 3) { user_prompt = "<|im_start|>user\n" + prompt; } + else if (has_minicpmv_projector == 4) { + user_prompt = "<|im_start|>user\n" + prompt; + } } eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false); @@ -233,16 +242,19 @@ static struct gpt_sampler * llama_init(struct llava_context * ctx_llava, gpt_par else if (has_minicpmv_projector == 3) { eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false); } + else if (has_minicpmv_projector == 4) { + eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false); + } // generate the response - LOG_TEE("\n"); + LOG_INF("\n"); - struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams); + struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling); return smpl; } -static const char * llama_loop(struct llava_context * ctx_llava,struct gpt_sampler * smpl, int &n_past){ +static const char * llama_loop(struct llava_context * ctx_llava,struct common_sampler * smpl, int &n_past){ const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past); return tmp; @@ -251,19 +263,13 @@ static const char * llama_loop(struct llava_context * ctx_llava,struct gpt_sampl int main(int argc, char ** argv) { ggml_time_init(); - gpt_params params; + common_params params; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON, show_additional_info); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) { return 1; } -#ifndef LOG_DISABLE_LOGS - log_set_target(log_filename_generator("llava", "log")); - LOG_TEE("Log start\n"); - log_dump_cmdline(argc, argv); - llama_log_set(llama_log_callback_logTee, nullptr); -#endif // LOG_DISABLE_LOGS + common_init(); if (params.mmproj.empty() || (params.image.empty())) { show_additional_info(argc, argv); @@ -272,21 +278,23 @@ int main(int argc, char ** argv) { for (auto & image : params.image) { int n_past = 0; - auto ctx_llava = minicpmv_init(¶ms, image, n_past); + auto * ctx_llava = minicpmv_init(¶ms, image, n_past); if (!params.prompt.empty()) { - LOG_TEE("%s\n", params.prompt.c_str()); - LOG_TEE(""); - auto smpl = llama_init(ctx_llava, ¶ms, params.prompt.c_str(), n_past, true); + LOG("%s\n", params.prompt.c_str()); + LOG(""); + auto * smpl = llama_init(ctx_llava, ¶ms, params.prompt, n_past, true); const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict; - std::string response = ""; + std::string response; bool have_tmp = false; for (int i = 0; i < max_tgt_len; i++) { - auto tmp = llama_loop(ctx_llava, smpl, n_past); + const auto * tmp = llama_loop(ctx_llava, smpl, n_past); response += tmp; if (strcmp(tmp, "") == 0){ - if(!have_tmp)continue; - else break; + if (!have_tmp) { + continue; + } + break; } if (strstr(tmp, "###")) break; // Yi-VL behavior have_tmp = true; @@ -295,30 +303,29 @@ int main(int argc, char ** argv) { fflush(stdout); } - gpt_sampler_free(smpl); + common_sampler_free(smpl); }else { while (true) { - LOG_TEE(""); + LOG(""); std::string prompt; std::getline(std::cin, prompt); - LOG_TEE(""); - auto smpl = llama_init(ctx_llava, ¶ms, prompt, n_past, true); + LOG(""); + auto * smpl = llama_init(ctx_llava, ¶ms, prompt, n_past, true); const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict; - std::string response = ""; + std::string response; for (int i = 0; i < max_tgt_len; i++) { - auto tmp = llama_loop(ctx_llava, smpl, n_past); + const auto * tmp = llama_loop(ctx_llava, smpl, n_past); response += tmp; if (strcmp(tmp, "") == 0) break; - if (strstr(tmp, "###")) break; // Yi-VL behavior printf("%s", tmp);// mistral llava-1.6 if (strstr(response.c_str(), "")) break; // minicpm-v fflush(stdout); } - gpt_sampler_free(smpl); + common_sampler_free(smpl); } } printf("\n"); - llama_perf_print(ctx_llava->ctx_llama, LLAMA_PERF_TYPE_CONTEXT); + llama_perf_context_print(ctx_llava->ctx_llama); ctx_llava->model = NULL; llava_free(ctx_llava); diff --git a/examples/llava/minicpmv-convert-image-encoder-to-gguf.py b/examples/llava/minicpmv-convert-image-encoder-to-gguf.py index ea773742a..9b196757f 100644 --- a/examples/llava/minicpmv-convert-image-encoder-to-gguf.py +++ b/examples/llava/minicpmv-convert-image-encoder-to-gguf.py @@ -501,7 +501,7 @@ default_image_mean = [0.48145466, 0.4578275, 0.40821073] default_image_std = [0.26862954, 0.26130258, 0.27577711] ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None) ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None) -ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3', default=2) +ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3; MiniCPM-o-2.6 use 4', default=2) # with proper args = ap.parse_args() @@ -545,12 +545,19 @@ if args.use_f32: minicpmv_version = args.minicpmv_version emb_dim = 4096 +block_count = 26 if minicpmv_version == 1: emb_dim = 2304 + block_count = 26 elif minicpmv_version == 2: emb_dim = 4096 + block_count = 27 elif minicpmv_version == 3: emb_dim = 3584 + block_count = 27 +elif minicpmv_version == 4: + emb_dim = 3584 + block_count = 27 default_vision_config = { "hidden_size": 1152, @@ -567,6 +574,9 @@ model = Idefics2VisionTransformer(vision_config) if minicpmv_version == 3: vision_config = SiglipVisionConfig(**default_vision_config) model = SiglipVisionTransformer(vision_config) +elif minicpmv_version == 4: + vision_config = SiglipVisionConfig(**default_vision_config) + model = SiglipVisionTransformer(vision_config) processor = None # if model.attn_pool is not None: @@ -587,7 +597,7 @@ elif args.minicpmv_projector is not None: fname_middle = "mmproj-" has_text_encoder = False has_minicpmv_projector = True - minicpmv_version = 3 + minicpmv_version = 4 elif args.vision_only: fname_middle = "vision-" has_text_encoder = False @@ -625,7 +635,6 @@ if has_vision_encoder: fout.add_uint32("clip.vision.projection_dim", 0) fout.add_uint32(add_key_str(KEY_ATTENTION_HEAD_COUNT, VISION), 16) fout.add_float32(add_key_str(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6) - block_count = 26 fout.add_uint32(add_key_str(KEY_BLOCK_COUNT, VISION), block_count) if processor is not None: diff --git a/examples/llava/minicpmv-surgery.py b/examples/llava/minicpmv-surgery.py index 748ff5c57..ba8211658 100644 --- a/examples/llava/minicpmv-surgery.py +++ b/examples/llava/minicpmv-surgery.py @@ -8,7 +8,7 @@ ap.add_argument("-m", "--model", help="Path to MiniCPM-V model") args = ap.parse_args() # find the model part that includes the the multimodal projector weights -model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True) +model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True, torch_dtype=torch.bfloat16) checkpoint = model.state_dict() # get a list of mm tensor names diff --git a/examples/llava/qwen2_vl_surgery.py b/examples/llava/qwen2_vl_surgery.py new file mode 100644 index 000000000..c87606b4f --- /dev/null +++ b/examples/llava/qwen2_vl_surgery.py @@ -0,0 +1,165 @@ +import argparse +from typing import Dict + +import torch +import numpy as np +from gguf import * +from transformers import ( + Qwen2VLForConditionalGeneration, + Qwen2VLProcessor, + AutoProcessor, + Qwen2VLConfig +) + + +VISION = "clip.vision" + + +def k(raw_key: str, arch: str) -> str: + return raw_key.format(arch=arch) + + +def to_gguf_name(name: str) -> str: + og = name + name = name.replace("text_model", "t").replace("vision_model", "v") + name = name.replace("blocks", "blk").replace("embeddings.", "") + name = name.replace("attn.", "attn_") + name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.") + # name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln") + name = name.replace("norm1", "ln1").replace("norm2", "ln2") + name = name.replace("merger.mlp", 'mm') + print(f"[to_gguf_name] {og} --> {name}") + return name + + +def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]: + vision_model = qwen2vl.visual + tensor_map = {} + for name, ten in vision_model.state_dict().items(): + ten = ten.numpy() + if 'qkv' in name: + if ten.ndim == 2: # weight + c3, _ = ten.shape + else: # bias + c3 = ten.shape[0] + assert c3 % 3 == 0 + c = c3 // 3 + wq = ten[:c] + wk = ten[c: c * 2] + wv = ten[c * 2:] + tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq + tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk + tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv + elif 'merger' in name: + if name.endswith("ln_q.weight"): + tensor_map['v.post_ln.weight'] = ten + elif name.endswith("ln_q.bias"): + tensor_map['v.post_ln.bias'] = ten + else: + # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias" + tensor_map[to_gguf_name(name)] = ten + elif 'patch_embed.proj.weight' in name: + # NOTE: split Conv3D into Conv2Ds + c1, c2, kt, kh, kw = ten.shape + assert kt == 2, "Current implmentation only support temporal_patch_size of 2" + tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...] + tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...] + else: + tensor_map[to_gguf_name(f"vision_model.{name}")] = ten + + for new_name, ten in tensor_map.items(): + if ten.ndim <= 1 or new_name.endswith("_norm.weight"): + tensor_map[new_name] = ten.astype(np.float32) + else: + tensor_map[new_name] = ten.astype(dtype) + tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32) # dummy tensor, just here as a placeholder + return tensor_map + + +def main(args): + if args.data_type == 'fp32': + dtype = torch.float32 + np_dtype = np.float32 + ftype = 0 + elif args.data_type == 'fp16': + dtype = torch.float32 + np_dtype = np.float16 + ftype = 1 + else: + raise ValueError() + + local_model = False + model_path = "" + model_name = args.model_name + print("model_name: ", model_name) + qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained( + model_name, torch_dtype=dtype, device_map="cpu" + ) + cfg: Qwen2VLConfig = qwen2vl.config # type: ignore[reportAssignmentType] + vcfg = cfg.vision_config + + if os.path.isdir(model_name): + local_model = True + if model_name.endswith(os.sep): + model_name = model_name[:-1] + model_path = model_name + model_name = os.path.basename(model_name) + fname_out = f"{model_name.replace('/', '-').lower()}-vision.gguf" + + fout = GGUFWriter(path=fname_out, arch="clip") + fout.add_description("image encoder for Qwen2VL") + + fout.add_file_type(ftype) + fout.add_bool("clip.has_text_encoder", False) + fout.add_bool("clip.has_vision_encoder", True) + fout.add_bool("clip.has_qwen2vl_merger", True) + fout.add_string("clip.projector_type", "qwen2vl_merger") + + print(cfg.vision_config) + if 'silu' in cfg.vision_config.hidden_act.lower(): + fout.add_bool("clip.use_silu", True) + fout.add_bool("clip.use_gelu", False) + elif 'gelu' in cfg.vision_config.hidden_act.lower(): + fout.add_bool("clip.use_silu", False) + fout.add_bool("clip.use_gelu", 'quick' not in cfg.vision_config.hidden_act.lower()) + else: + raise ValueError() + + tensor_map = find_vision_tensors(qwen2vl, np_dtype) + for name, data in tensor_map.items(): + fout.add_tensor(name, data) + + fout.add_uint32("clip.vision.patch_size", vcfg.patch_size) + fout.add_uint32("clip.vision.image_size", 14 * 40) # some reasonable size that is divable by (14*2) + fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim) + fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size) + fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads) + fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6) + fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth) + fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 0) # not sure what this does, put 0 here as a placeholder + fout.add_name(model_name) + """ + HACK: Since vision rope related parameter aren't stored in the `Qwen2VLConfig, + it will be hardcoded in the `clip_image_build_graph` from `clip.cpp`. + """ + + if local_model: + processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_path) + else: + processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_name) + fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean) # type: ignore[reportAttributeAccessIssue] + fout.add_array("clip.vision.image_std", processor.image_processor.image_std) # type: ignore[reportAttributeAccessIssue] + + fout.write_header_to_file() + fout.write_kv_data_to_file() + fout.write_tensors_to_file() + fout.close() + print("save model as: ", fname_out) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct") + parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32") + args = parser.parse_args() + main(args) diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp new file mode 100644 index 000000000..132a7da54 --- /dev/null +++ b/examples/llava/qwen2vl-cli.cpp @@ -0,0 +1,584 @@ +#include "arg.h" +#include "base64.hpp" +#include "log.h" +#include "common.h" +#include "sampling.h" +#include "clip.h" +#include "llava.h" +#include "llama.h" +#include "ggml.h" + +#ifdef GGML_USE_CUDA +#include "ggml-cuda.h" +#endif +#ifdef NDEBUG +#include "ggml-alloc.h" +#include "ggml-backend.h" +#endif + +#include +#include +#include +#include +#include +#include +#include + + +static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, + int n_batch, int * n_past, int * st_pos_id, struct clip_image_size * image_size) { + int n_embd = llama_model_n_embd(llama_get_model(ctx_llama)); + const int patch_size = 14 * 2; + const int ph = image_size->height / patch_size + (image_size->height % patch_size > 0); + const int pw = image_size->width / patch_size + (image_size->width % patch_size > 0); + auto img_tokens = image_embed->n_image_pos; + // llama_pos mrope_pos[img_tokens * 4]; + std::vector mrope_pos; + mrope_pos.resize(img_tokens * 4); + + for (int y = 0; y < ph; y++) + { + for (int x = 0; x < pw; x++) + { + int i = y * pw + x; + mrope_pos[i] = *st_pos_id; + mrope_pos[i + img_tokens] = *st_pos_id + y; + mrope_pos[i + img_tokens * 2] = *st_pos_id + x; + mrope_pos[i + img_tokens * 3] = 0; + } + } + *st_pos_id += std::max(pw, ph); + + int processed = 0; + std::vector batch_mrope_pos; + batch_mrope_pos.resize(img_tokens * 4); + + for (int i = 0; i < img_tokens; i += n_batch) { + int n_eval = img_tokens - i; + if (n_eval > n_batch) { + n_eval = n_batch; + } + + // llama_pos batch_mrope_pos[n_eval * 4]; + std::fill(batch_mrope_pos.begin(), batch_mrope_pos.end(), 0); + memcpy(batch_mrope_pos.data(), &mrope_pos[processed], n_eval * sizeof(llama_pos)); + memcpy(&batch_mrope_pos[n_eval * 1], &mrope_pos[img_tokens * 1 + processed], n_eval * sizeof(llama_pos)); + memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos)); + memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos)); + + llama_batch batch = { + int32_t(n_eval), // n_tokens + nullptr, // token + (image_embed->embed+i*n_embd), // embed + batch_mrope_pos.data(), // pos + nullptr, // n_seq_id + nullptr, // seq_id + nullptr, // logits + }; + + if (llama_decode(ctx_llama, batch)) { + LOG_ERR("%s : failed to eval\n", __func__); + return false; + } + *n_past += n_eval; + processed += n_eval; + } + return true; +} + + +static bool eval_tokens(struct llama_context * ctx_llama, std::vector tokens, int n_batch, int * n_past, int * st_pos_id) { + int N = (int) tokens.size(); + std::vector pos; + for (int i = 0; i < N; i += n_batch) { + int n_eval = (int) tokens.size() - i; + if (n_eval > n_batch) { + n_eval = n_batch; + } + auto batch = llama_batch_get_one(&tokens[i], n_eval); + // TODO: add mrope pos ids somewhere else + pos.resize(batch.n_tokens * 4); + std::fill(pos.begin(), pos.end(), 0); + for (int j = 0; j < batch.n_tokens * 3; j ++) { + pos[j] = *st_pos_id + (j % batch.n_tokens); + } + batch.pos = pos.data(); + + if (llama_decode(ctx_llama, batch)) { + LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); + return false; + } + *n_past += n_eval; + *st_pos_id += n_eval; + } + return true; +} + +static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past, int * st_pos_id) { + std::vector tokens; + tokens.push_back(id); + return eval_tokens(ctx_llama, tokens, 1, n_past, st_pos_id); +} + +static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, int * st_pos_id, bool add_bos){ + std::string str2 = str; + std::vector embd_inp = common_tokenize(ctx_llama, str2, add_bos, true); + eval_tokens(ctx_llama, embd_inp, n_batch, n_past, st_pos_id); + return true; +} + +static const char * sample(struct common_sampler * smpl, + struct llama_context * ctx_llama, + int * n_past, int * st_pos_id) { + const llama_token id = common_sampler_sample(smpl, ctx_llama, -1); + common_sampler_accept(smpl, id, true); + + const llama_model * model = llama_get_model(ctx_llama); + const llama_vocab * vocab = llama_model_get_vocab(model); + + static std::string ret; + if (llama_vocab_is_eog(vocab, id)) { + ret = ""; + } else { + ret = common_token_to_piece(ctx_llama, id); + } + eval_id(ctx_llama, id, n_past, st_pos_id); + return ret.c_str(); +} + +static const char* IMG_BASE64_TAG_BEGIN = ""; + +static void find_image_tag_in_prompt(const std::string& prompt, size_t& begin_out, size_t& end_out) { + begin_out = prompt.find(IMG_BASE64_TAG_BEGIN); + end_out = prompt.find(IMG_BASE64_TAG_END, (begin_out == std::string::npos) ? 0UL : begin_out); +} + +static bool prompt_contains_image(const std::string& prompt) { + size_t begin, end; + find_image_tag_in_prompt(prompt, begin, end); + return (begin != std::string::npos); +} + +// replaces the base64 image tag in the prompt with `replacement` +static llava_image_embed * llava_image_embed_make_with_prompt_base64(struct clip_ctx * ctx_clip, int n_threads, const std::string& prompt) { + size_t img_base64_str_start, img_base64_str_end; + find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end); + if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) { + LOG_ERR("%s: invalid base64 image tag. must be %s%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END); + return NULL; + } + + auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN); + auto base64_bytes_count = img_base64_str_end - base64_bytes_start; + auto base64_str = prompt.substr(base64_bytes_start, base64_bytes_count ); + + auto required_bytes = base64::required_encode_size(base64_str.size()); + auto img_bytes = std::vector(required_bytes); + base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin()); + + auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, img_bytes.data(), img_bytes.size()); + if (!embed) { + LOG_ERR("%s: could not load image from base64 string.\n", __func__); + return NULL; + } + + return embed; +} + +static std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") { + size_t begin, end; + find_image_tag_in_prompt(prompt, begin, end); + if (begin == std::string::npos || end == std::string::npos) { + return prompt; + } + auto pre = prompt.substr(0, begin); + auto post = prompt.substr(end + strlen(IMG_BASE64_TAG_END)); + return pre + replacement + post; +} + +struct llava_context { + struct clip_ctx * ctx_clip = NULL; + struct llama_context * ctx_llama = NULL; + struct llama_model * model = NULL; +}; + +static void print_usage(int, char ** argv) { + LOG("\n example usage:\n"); + LOG("\n %s -m --mmproj --image --image [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]); + LOG("\n note: a lower temperature value like 0.1 is recommended for better quality.\n"); +} + +static struct llava_image_embed * load_image(llava_context * ctx_llava, common_params * params, const std::string & fname) { + + // load and preprocess the image + llava_image_embed * embed = NULL; + auto prompt = params->prompt; + if (prompt_contains_image(prompt)) { + if (!params->image.empty()) { + LOG_INF("using base64 encoded image instead of command line image path\n"); + } + embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->cpuparams.n_threads, prompt); + if (!embed) { + LOG_ERR("%s: can't load image from prompt\n", __func__); + return NULL; + } + params->prompt = remove_image_from_prompt(prompt); + } else { + embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->cpuparams.n_threads, fname.c_str()); + if (!embed) { + fprintf(stderr, "%s: is %s really an image file?\n", __func__, fname.c_str()); + return NULL; + } + } + + return embed; +} + +static void process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, common_params * params, const std::string & prompt) { + int n_past = 0; + int cur_pos_id = 0; + + const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict; + + std::string system_prompt, user_prompt; + size_t image_pos = prompt.find("<|vision_start|>"); + if (image_pos != std::string::npos) { + // new templating mode: Provide the full prompt including system message and use as a placeholder for the image + system_prompt = prompt.substr(0, image_pos); + user_prompt = prompt.substr(image_pos + std::string("<|vision_pad|>").length()); + LOG_INF("system_prompt: %s\n", system_prompt.c_str()); + if (params->verbose_prompt) { + auto tmp = common_tokenize(ctx_llava->ctx_llama, system_prompt, true, true); + for (int i = 0; i < (int) tmp.size(); i++) { + LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str()); + } + } + LOG_INF("user_prompt: %s\n", user_prompt.c_str()); + if (params->verbose_prompt) { + auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true); + for (int i = 0; i < (int) tmp.size(); i++) { + LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str()); + } + } + } else { + // llava-1.5 native mode + system_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|>"; + user_prompt = "<|vision_end|>" + prompt + "<|im_end|>\n<|im_start|>assistant\n"; + if (params->verbose_prompt) { + auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true); + for (int i = 0; i < (int) tmp.size(); i++) { + LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str()); + } + } + } + + eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, &cur_pos_id, true); + if (image_embed != nullptr) { + auto image_size = clip_get_load_image_size(ctx_llava->ctx_clip); + qwen2vl_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past, &cur_pos_id, image_size); + } + eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, &cur_pos_id, false); + + // generate the response + + LOG("\n"); + + struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling); + if (!smpl) { + LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__); + exit(1); + } + + std::string response = ""; + for (int i = 0; i < max_tgt_len; i++) { + const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past, &cur_pos_id); + response += tmp; + if (strcmp(tmp, "") == 0) break; + if (strstr(tmp, "###")) break; // Yi-VL behavior + LOG("%s", tmp); + if (strstr(response.c_str(), "<|im_end|>")) break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works) + if (strstr(response.c_str(), "<|im_start|>")) break; // Yi-34B llava-1.6 + if (strstr(response.c_str(), "USER:")) break; // mistral llava-1.6 + + fflush(stdout); + } + + common_sampler_free(smpl); + LOG("\n"); +} + +static struct llama_model * llava_init(common_params * params) { + llama_backend_init(); + llama_numa_init(params->numa); + + llama_model_params model_params = common_model_params_to_llama(*params); + + llama_model * model = llama_model_load_from_file(params->model.c_str(), model_params); + if (model == NULL) { + LOG_ERR("%s: unable to load model\n" , __func__); + return NULL; + } + return model; +} + +static struct llava_context * llava_init_context(common_params * params, llama_model * model) { + const char * clip_path = params->mmproj.c_str(); + + auto prompt = params->prompt; + if (prompt.empty()) { + prompt = "describe the image in detail."; + } + + auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); + + llama_context_params ctx_params = common_context_params_to_llama(*params); + ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings + + llama_context * ctx_llama = llama_init_from_model(model, ctx_params); + + if (ctx_llama == NULL) { + LOG_ERR("%s: failed to create the llama_context\n" , __func__); + return NULL; + } + + auto * ctx_llava = (struct llava_context *)malloc(sizeof(llava_context)); + + ctx_llava->ctx_llama = ctx_llama; + ctx_llava->ctx_clip = ctx_clip; + ctx_llava->model = model; + return ctx_llava; +} + +static void llava_free(struct llava_context * ctx_llava) { + if (ctx_llava->ctx_clip) { + clip_free(ctx_llava->ctx_clip); + ctx_llava->ctx_clip = NULL; + } + + llama_free(ctx_llava->ctx_llama); + llama_model_free(ctx_llava->model); + llama_backend_free(); +} + +#ifndef NDEBUG + +static void debug_test_mrope_2d() { + // 1. Initialize backend + ggml_backend_t backend = NULL; + std::string backend_name = ""; +#ifdef GGML_USE_CUDA + fprintf(stderr, "%s: using CUDA backend\n", __func__); + backend = ggml_backend_cuda_init(0); // init device 0 + backend_name = "cuda"; + if (!backend) { + fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); + } +#endif + // if there aren't GPU Backends fallback to CPU backend + if (!backend) { + backend = ggml_backend_cpu_init(); + backend_name = "cpu"; + } + + // Calculate the size needed to allocate + size_t ctx_size = 0; + ctx_size += 2 * ggml_tensor_overhead(); // tensors + // no need to allocate anything else! + + // 2. Allocate `ggml_context` to store tensor data + struct ggml_init_params params = { + /*.mem_size =*/ ctx_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_backend_alloc_ctx_tensors() + }; + struct ggml_context * ctx = ggml_init(params); + + struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128, 12, 30); + ggml_set_name(inp_raw, "inp_raw"); + ggml_set_input(inp_raw); + + struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 30 * 4); + ggml_set_name(pos, "pos"); + ggml_set_input(pos); + + std::vector dummy_q; + dummy_q.resize(128 * 12 * 30); + std::fill(dummy_q.begin(), dummy_q.end(), 0.1); + // memcpy(inp_raw->data, dummy_q.data(), 128 * 12 * 30 * ggml_element_size(inp_raw)); + + std::vector pos_id; + pos_id.resize(30 * 4); + for (int i = 0; i < 30; i ++) { + pos_id[i] = i; + pos_id[i + 30] = i + 10; + pos_id[i + 60] = i + 20; + pos_id[i + 90] = i + 30; + } + int sections[4] = {32, 32, 0, 0}; + + // 4. Allocate a `ggml_backend_buffer` to store all tensors + ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + + // 5. Copy tensor data from main memory (RAM) to backend buffer + ggml_backend_tensor_set(inp_raw, dummy_q.data(), 0, ggml_nbytes(inp_raw)); + ggml_backend_tensor_set(pos, pos_id.data(), 0, ggml_nbytes(pos)); + + // 6. Create a `ggml_cgraph` for mul_mat operation + struct ggml_cgraph * gf = NULL; + struct ggml_context * ctx_cgraph = NULL; + + // create a temporally context to build the graph + struct ggml_init_params params0 = { + /*.mem_size =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + ctx_cgraph = ggml_init(params0); + gf = ggml_new_graph(ctx_cgraph); + + struct ggml_tensor * result0 = ggml_rope_multi( + ctx_cgraph, inp_raw, pos, nullptr, + 128/2, sections, LLAMA_ROPE_TYPE_VISION, 32768, 1000000, 1, + 0, 1, 32, 1); + + // Add "result" tensor and all of its dependencies to the cgraph + ggml_build_forward_expand(gf, result0); + + // 7. Create a `ggml_gallocr` for cgraph computation + ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + ggml_gallocr_alloc_graph(allocr, gf); + + // 9. Run the computation + int n_threads = 1; // Optional: number of threads to perform some operations with multi-threading + if (ggml_backend_is_cpu(backend)) { + ggml_backend_cpu_set_n_threads(backend, n_threads); + } + ggml_backend_graph_compute(backend, gf); + + // 10. Retrieve results (output tensors) + // in this example, output tensor is always the last tensor in the graph + struct ggml_tensor * result = result0; + // struct ggml_tensor * result = gf->nodes[gf->n_nodes - 1]; + float * result_data = (float *)malloc(ggml_nbytes(result)); + // because the tensor data is stored in device buffer, we need to copy it back to RAM + ggml_backend_tensor_get(result, result_data, 0, ggml_nbytes(result)); + const std::string bin_file = "mrope_2d_" + backend_name +".bin"; + std::ofstream outFile(bin_file, std::ios::binary); + + if (outFile.is_open()) { + outFile.write(reinterpret_cast(result_data), ggml_nbytes(result)); + outFile.close(); + std::cout << "Data successfully written to " + bin_file << std::endl; + } else { + std::cerr << "Error opening file!" << std::endl; + } + + free(result_data); + // 11. Free memory and exit + ggml_free(ctx_cgraph); + ggml_gallocr_free(allocr); + ggml_free(ctx); + ggml_backend_buffer_free(buffer); + ggml_backend_free(backend); +} + +static void debug_dump_img_embed(struct llava_context * ctx_llava) { + int n_embd = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama)); + int ne = n_embd * 4; + float vals[56 * 56 * 3]; + // float embd[ne]; + std::vector embd; + embd.resize(ne); + + for (int i = 0; i < 56*56; i++) + { + for (int c = 0; c < 3; c++) + vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56); + } + + clip_encode_float_image(ctx_llava->ctx_clip, 16, vals, 56, 56, embd.data()); + + std::ofstream outFile("img_embed.bin", std::ios::binary); + if (outFile.is_open()) { + outFile.write(reinterpret_cast(embd.data()), ne * sizeof(float)); + + outFile.close(); + std::cout << "Data successfully written to mrope.bin" << std::endl; + } else { + std::cerr << "Error opening file!" << std::endl; + } +} + +#endif + + +int main(int argc, char ** argv) { + ggml_time_init(); + + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, print_usage)) { + return 1; + } + + common_init(); + + if (params.mmproj.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) { + print_usage(argc, argv); + return 1; + } + + auto * model = llava_init(¶ms); + if (model == NULL) { + fprintf(stderr, "%s: error: failed to init llava model\n", __func__); + return 1; + } + + if (prompt_contains_image(params.prompt)) { + auto * ctx_llava = llava_init_context(¶ms, model); + + auto * image_embed = load_image(ctx_llava, ¶ms, ""); + + // process the prompt + process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); + + llama_perf_context_print(ctx_llava->ctx_llama); + llava_image_embed_free(image_embed); + ctx_llava->model = NULL; + llava_free(ctx_llava); +#ifndef NDEBUG + } else if (params.image[0].empty()) { + auto ctx_llava = llava_init_context(¶ms, model); + + debug_test_mrope_2d(); + debug_dump_img_embed(ctx_llava); + + llama_perf_context_print(ctx_llava->ctx_llama); + ctx_llava->model = NULL; + llava_free(ctx_llava); +#endif + } else { + for (auto & image : params.image) { + auto * ctx_llava = llava_init_context(¶ms, model); + + auto * image_embed = load_image(ctx_llava, ¶ms, image); + if (!image_embed) { + LOG_ERR("%s: failed to load image %s. Terminating\n\n", __func__, image.c_str()); + return 1; + } + + // process the prompt + process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); + + llama_perf_context_print(ctx_llava->ctx_llama); + llava_image_embed_free(image_embed); + ctx_llava->model = NULL; + llava_free(ctx_llava); + } + } + + llama_model_free(model); + + return 0; +} diff --git a/examples/lookahead/CMakeLists.txt b/examples/lookahead/CMakeLists.txt index f0ae5cd89..346861314 100644 --- a/examples/lookahead/CMakeLists.txt +++ b/examples/lookahead/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-lookahead) add_executable(${TARGET} lookahead.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 5027a483a..2f0898e62 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -1,4 +1,7 @@ +#include "arg.h" #include "common.h" +#include "sampling.h" +#include "log.h" #include "llama.h" #include @@ -34,54 +37,51 @@ struct ngram_container { }; int main(int argc, char ** argv) { - gpt_params params; + common_params params; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { return 1; } + common_init(); + const int W = 15; // lookahead window const int N = 5; // n-gram size const int G = 15; // max verification n-grams const bool dump_kv_cache = params.dump_kv_cache; -#ifndef LOG_DISABLE_LOGS - log_set_target(log_filename_generator("lookahead", "log")); - LOG_TEE("Log start\n"); - log_dump_cmdline(argc, argv); -#endif // LOG_DISABLE_LOGS - // init llama.cpp llama_backend_init(); llama_numa_init(params.numa); // load the target model - llama_init_result llama_init = llama_init_from_gpt_params(params); + common_init_result llama_init = common_init_from_params(params); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); + + const llama_vocab * vocab = llama_model_get_vocab(model); // Tokenize the prompt std::vector inp; std::vector all; - inp = ::llama_tokenize(ctx, params.prompt, true, true); + inp = common_tokenize(ctx, params.prompt, true, true); all = inp; const int max_context_size = llama_n_ctx(ctx); const int max_tokens_list_size = max_context_size - 4; if ((int) inp.size() > max_tokens_list_size) { - fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size); + LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size); return 1; } - fprintf(stderr, "\n\n"); + LOG("\n\n"); for (auto id : inp) { - fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str()); + LOG("%s", common_token_to_piece(ctx, id).c_str()); } fflush(stderr); @@ -91,8 +91,8 @@ int main(int argc, char ** argv) { const auto t_enc_start = ggml_time_us(); // eval the prompt - llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0)); - llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); + llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1)); + llama_decode(ctx, llama_batch_get_one(&inp.back(), 1)); for (int s = 1; s < W + G + 1; ++s) { llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); @@ -117,7 +117,7 @@ int main(int argc, char ** argv) { llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); // target model sampling context - struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams); + struct common_sampler * smpl = common_sampler_init(model, params.sampling); // verification n-grams std::vector ngrams_cur(G); @@ -149,7 +149,7 @@ int main(int argc, char ** argv) { } // here we keep adding new n-grams as we go - ngram_container ngrams_observed(llama_n_vocab(model), N, G); + ngram_container ngrams_observed(llama_vocab_n_tokens(vocab), N, G); // debug struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1); @@ -158,14 +158,14 @@ int main(int argc, char ** argv) { // sample first token { - id = gpt_sampler_sample(smpl, ctx, 0); + id = common_sampler_sample(smpl, ctx, 0); - gpt_sampler_accept(smpl, id, true); + common_sampler_accept(smpl, id, true); { - const std::string token_str = llama_token_to_piece(ctx, id); + const std::string token_str = common_token_to_piece(ctx, id); - printf("%s", token_str.c_str()); + LOG("%s", token_str.c_str()); fflush(stdout); } } @@ -174,7 +174,7 @@ int main(int argc, char ** argv) { // debug if (dump_kv_cache) { llama_kv_cache_view_update(ctx, &kvc_view); - llama_kv_cache_dump_view_seqs(kvc_view, 40); + common_kv_cache_dump_view_seqs(kvc_view, 40); } // build the mask from https://lmsys.org/blog/2023-11-21-lookahead-decoding/ @@ -203,10 +203,10 @@ int main(int argc, char ** argv) { // V V V V V V // id { - llama_batch_clear(batch); + common_batch_clear(batch); // current token - first token of the first level - llama_batch_add(batch, id, n_past, seq_id_all, true); + common_batch_add(batch, id, n_past, seq_id_all, true); // verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation { @@ -231,7 +231,7 @@ int main(int argc, char ** argv) { ngrams_cur[g].tokens [j + 1] = t; ngrams_cur[g].i_batch[j + 1] = batch.n_tokens; - llama_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true); + common_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true); } } } @@ -243,19 +243,19 @@ int main(int argc, char ** argv) { seq_id_look[j] = i + j + 1; } - llama_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false); + common_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false); } // fill the rest of the levels for (int j = 1; j < N - 1; j++) { for (int i = 0; i < W; i++) { - llama_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2); + common_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2); } } } if (llama_decode(ctx, batch) != 0) { - fprintf(stderr, "\n\n%s: error: llama_decode failed - increase KV cache size\n", __func__); + LOG_ERR("\n\n%s: llama_decode failed - increase KV cache size\n", __func__); return 1; } @@ -283,23 +283,23 @@ int main(int argc, char ** argv) { } // sample the next token - id = gpt_sampler_sample(smpl, ctx, i_batch); + id = common_sampler_sample(smpl, ctx, i_batch); - gpt_sampler_accept(smpl, id, true); + common_sampler_accept(smpl, id, true); // print { - const std::string token_str = llama_token_to_piece(ctx, id); + const std::string token_str = common_token_to_piece(ctx, id); if (v == 0) { - printf("%s", token_str.c_str()); + LOG("%s", token_str.c_str()); } else { // print light cyan - printf("\033[0;96m%s\033[0m", token_str.c_str()); + LOG("\033[0;96m%s\033[0m", token_str.c_str()); } fflush(stdout); - if (llama_token_is_eog(model, id)) { + if (llama_vocab_is_eog(vocab, id)) { has_eos = true; } @@ -329,21 +329,21 @@ int main(int argc, char ** argv) { // print known n-grams starting with token id (debug) if (0 && v == 0) { if (ngrams_observed.cnt[id] > 0) { - printf("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], llama_token_to_piece(ctx, id).c_str()); + LOG("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], common_token_to_piece(ctx, id).c_str()); } for (int i = 0; i < ngrams_observed.cnt[id]; i++) { - printf(" - ngram %2d: ", i); + LOG(" - ngram %2d: ", i); const int idx = id*(N - 1)*G + i*(N - 1); for (int j = 0; j < N - 1; j++) { - const std::string token_str = llama_token_to_piece(ctx, ngrams_observed.tokens[idx + j]); + const std::string token_str = common_token_to_piece(ctx, ngrams_observed.tokens[idx + j]); - printf("%s", token_str.c_str()); + LOG("%s", token_str.c_str()); } - printf("\n"); + LOG("\n"); } } @@ -360,7 +360,7 @@ int main(int argc, char ** argv) { if (v == 0) { // sample from the last level for (int i = 0; i < W; i++) { - tokens_j[N - 2][i] = gpt_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i); + tokens_j[N - 2][i] = common_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i); } } else { for (int i = 0; i < W; i++) { @@ -454,34 +454,31 @@ int main(int argc, char ** argv) { auto t_dec_end = ggml_time_us(); - LOG_TEE("\n\n"); + LOG("\n\n"); - LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f)); - LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f)); + LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f)); + LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f)); - LOG_TEE("\n"); - LOG_TEE("W = %2d\n", W); - LOG_TEE("N = %2d\n", N); - LOG_TEE("G = %2d\n", G); - LOG_TEE("\n"); - LOG_TEE("n_predict = %d\n", n_predict); - LOG_TEE("n_accept = %d\n", n_accept); + LOG_INF("\n"); + LOG_INF("W = %2d\n", W); + LOG_INF("N = %2d\n", N); + LOG_INF("G = %2d\n", G); + LOG_INF("\n"); + LOG_INF("n_predict = %d\n", n_predict); + LOG_INF("n_accept = %d\n", n_accept); - LOG_TEE("\n"); - gpt_perf_print(ctx, smpl); + LOG_INF("\n"); + common_perf_print(ctx, smpl); - gpt_sampler_free(smpl); + common_sampler_free(smpl); llama_kv_cache_view_free(&kvc_view); llama_batch_free(batch); - llama_free(ctx); - llama_free_model(model); - llama_backend_free(); - fprintf(stderr, "\n\n"); + LOG("\n\n"); return 0; } diff --git a/examples/lookup/CMakeLists.txt b/examples/lookup/CMakeLists.txt index ef19fe25e..fba78ceda 100644 --- a/examples/lookup/CMakeLists.txt +++ b/examples/lookup/CMakeLists.txt @@ -2,22 +2,22 @@ set(TARGET llama-lookup) add_executable(${TARGET} lookup.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) set(TARGET llama-lookup-create) add_executable(${TARGET} lookup-create.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) set(TARGET llama-lookup-merge) add_executable(${TARGET} lookup-merge.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) set(TARGET llama-lookup-stats) add_executable(${TARGET} lookup-stats.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/lookup/lookup-create.cpp b/examples/lookup/lookup-create.cpp index 795b06c88..3da45ed9e 100644 --- a/examples/lookup/lookup-create.cpp +++ b/examples/lookup/lookup-create.cpp @@ -1,20 +1,15 @@ -#include "ggml.h" -#include "llama.h" +#include "arg.h" #include "common.h" #include "ngram-cache.h" +#include "llama.h" -#include -#include -#include #include -#include #include int main(int argc, char ** argv){ - gpt_params params; + common_params params; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) { return 1; } @@ -23,21 +18,23 @@ int main(int argc, char ** argv){ llama_numa_init(params.numa); // load the model - llama_init_result llama_init = llama_init_from_gpt_params(params); + common_init_result llama_init = common_init_from_params(params); + + llama_model_ptr & model = llama_init.model; + llama_context_ptr & ctx = llama_init.context; - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; GGML_ASSERT(model != nullptr); // tokenize the prompt std::vector inp; - inp = ::llama_tokenize(ctx, params.prompt, true, true); + inp = common_tokenize(ctx.get(), params.prompt, true, true); fprintf(stderr, "%s: tokenization done\n", __func__); - - llama_ngram_cache ngram_cache; - llama_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true); + common_ngram_cache ngram_cache; + common_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true); fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str()); - llama_ngram_cache_save(ngram_cache, params.lookup_cache_static); + common_ngram_cache_save(ngram_cache, params.lookup_cache_static); + + return 0; } diff --git a/examples/lookup/lookup-merge.cpp b/examples/lookup/lookup-merge.cpp index 81e2b0436..6871c0f5f 100644 --- a/examples/lookup/lookup-merge.cpp +++ b/examples/lookup/lookup-merge.cpp @@ -33,15 +33,15 @@ int main(int argc, char ** argv){ } fprintf(stderr, "lookup-merge: loading file %s\n", args[0].c_str()); - llama_ngram_cache ngram_cache_merged = llama_ngram_cache_load(args[0]); + common_ngram_cache ngram_cache_merged = common_ngram_cache_load(args[0]); for (size_t i = 1; i < args.size()-1; ++i) { fprintf(stderr, "lookup-merge: loading file %s\n", args[i].c_str()); - llama_ngram_cache ngram_cache = llama_ngram_cache_load(args[i]); + common_ngram_cache ngram_cache = common_ngram_cache_load(args[i]); - llama_ngram_cache_merge(ngram_cache_merged, ngram_cache); + common_ngram_cache_merge(ngram_cache_merged, ngram_cache); } fprintf(stderr, "lookup-merge: saving file %s\n", args.back().c_str()); - llama_ngram_cache_save(ngram_cache_merged, args.back()); + common_ngram_cache_save(ngram_cache_merged, args.back()); } diff --git a/examples/lookup/lookup-stats.cpp b/examples/lookup/lookup-stats.cpp index 93299ef8b..fcb289abe 100644 --- a/examples/lookup/lookup-stats.cpp +++ b/examples/lookup/lookup-stats.cpp @@ -1,44 +1,45 @@ -#include "ggml.h" +#include "arg.h" #include "common.h" -#include "llama.h" #include "log.h" #include "ngram-cache.h" +#include "llama.h" +#include "ggml.h" -#include #include #include +#include #include #include #include -#include int main(int argc, char ** argv){ - gpt_params params; + common_params params; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) { return 1; } - const int n_draft = params.n_draft; + common_init(); + + const int n_draft = params.speculative.n_max; // init llama.cpp llama_backend_init(); llama_numa_init(params.numa); // load the model - llama_init_result llama_init = llama_init_from_gpt_params(params); + common_init_result llama_init = common_init_from_params(params); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; + llama_context_ptr & ctx = llama_init.context; // tokenize the prompt std::vector inp; - inp = ::llama_tokenize(ctx, params.prompt, true, true); + inp = common_tokenize(ctx.get(), params.prompt, true, true); + + common_ngram_cache ngram_cache_context; + common_ngram_cache ngram_cache_dynamic; + common_ngram_cache ngram_cache_static; - llama_ngram_cache ngram_cache_context; - llama_ngram_cache ngram_cache_dynamic; - llama_ngram_cache ngram_cache_static; int64_t t_draft_flat_us = 0; int64_t t_draft_us = 0; @@ -47,16 +48,16 @@ int main(int argc, char ** argv){ if (!params.lookup_cache_static.empty()) { try { - ngram_cache_static = llama_ngram_cache_load(params.lookup_cache_static); + ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static); } catch (std::ifstream::failure const &) { - fprintf(stderr, "error: failed to open static lookup cache: %s", params.lookup_cache_static.c_str()); + LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str()); exit(1); } } if (!params.lookup_cache_dynamic.empty()) { try { - ngram_cache_dynamic = llama_ngram_cache_load(params.lookup_cache_dynamic); + ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic); } catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program } @@ -64,7 +65,7 @@ int main(int argc, char ** argv){ } const int n_input = inp.size(); - const int n_ctx = llama_n_ctx(ctx); + const int n_ctx = llama_n_ctx(ctx.get()); int n_drafted = 0; int n_accept = 0; @@ -85,7 +86,7 @@ int main(int argc, char ** argv){ { const int64_t t_start_draft_us = ggml_time_us(); - llama_ngram_cache_draft(pseudo_output, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static); + common_ngram_cache_draft(pseudo_output, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static); t_draft_us += ggml_time_us() - t_start_draft_us; } @@ -104,7 +105,7 @@ int main(int argc, char ** argv){ { const int64_t t_start_draft_us = ggml_time_us(); - llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false); + common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false); t_draft_us += ggml_time_us() - t_start_draft_us; } } @@ -114,7 +115,7 @@ int main(int argc, char ** argv){ pseudo_output.push_back(inp_slice[pseudo_output.size()]); { const int64_t t_start_draft_us = ggml_time_us(); - llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false); + common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false); t_draft_us += ggml_time_us() - t_start_draft_us; } } @@ -128,32 +129,29 @@ int main(int argc, char ** argv){ const int64_t eta_min = eta_ms / (60*1000); const int64_t eta_s = (eta_ms - 60*1000*eta_min) / 1000; - LOG_TEE("lookup-stats: %d/%d done, ETA: %02" PRId64 ":%02" PRId64 "\n", i_start, n_input, eta_min, eta_s); + LOG_INF("lookup-stats: %d/%d done, ETA: %02" PRId64 ":%02" PRId64 "\n", i_start, n_input, eta_min, eta_s); } // After each chunk, update the dynamic ngram cache with the context ngram cache: - llama_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context); + common_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context); ngram_cache_context.clear(); } - LOG_TEE("\n"); + LOG("\n"); - LOG_TEE("\n"); - LOG_TEE("n_draft = %d\n", n_draft); - LOG_TEE("n_predict = %d\n", n_input - n_input % n_ctx); - LOG_TEE("n_drafted = %d\n", n_drafted); - LOG_TEE("t_draft_flat = %.2f ms\n", t_draft_flat_us*1e-3); - LOG_TEE("t_draft = %.2f ms, %.2f us per token, %.2f tokens per second\n", + LOG_INF("\n"); + LOG_INF("n_draft = %d\n", n_draft); + LOG_INF("n_predict = %d\n", n_input - n_input % n_ctx); + LOG_INF("n_drafted = %d\n", n_drafted); + LOG_INF("t_draft_flat = %.2f ms\n", t_draft_flat_us*1e-3); + LOG_INF("t_draft = %.2f ms, %.2f us per token, %.2f tokens per second\n", t_draft_us*1e-3, 1.0f*t_draft_us/n_drafted, n_drafted/(1e-6*t_draft_us)); - LOG_TEE("n_accept = %d\n", n_accept); - LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); - - llama_free(ctx); - llama_free_model(model); + LOG_INF("n_accept = %d\n", n_accept); + LOG_INF("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); llama_backend_free(); - fprintf(stderr, "\n\n"); + LOG("\n\n"); return 0; } diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 9ac7f6b47..dbd0444ec 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -1,7 +1,10 @@ +#include "arg.h" #include "ggml.h" -#include "llama.h" #include "common.h" #include "ngram-cache.h" +#include "sampling.h" +#include "log.h" +#include "llama.h" #include #include @@ -10,61 +13,58 @@ #include int main(int argc, char ** argv){ - gpt_params params; + common_params params; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) { return 1; } + common_init(); + // max. number of additional tokens to draft if match is found - const int n_draft = params.n_draft; + const int n_draft = params.speculative.n_max; const bool dump_kv_cache = params.dump_kv_cache; -#ifndef LOG_DISABLE_LOGS - log_set_target(log_filename_generator("lookup", "log")); - LOG_TEE("Log start\n"); - log_dump_cmdline(argc, argv); -#endif // LOG_DISABLE_LOGS - // init llama.cpp llama_backend_init(); llama_numa_init(params.numa); // load the model - llama_init_result llama_init = llama_init_from_gpt_params(params); + common_init_result llama_init = common_init_from_params(params); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); + + const llama_vocab * vocab = llama_model_get_vocab(model); // tokenize the prompt std::vector inp; - inp = ::llama_tokenize(ctx, params.prompt, true, true); + inp = common_tokenize(ctx, params.prompt, true, true); - llama_ngram_cache ngram_cache_context; - llama_ngram_cache ngram_cache_dynamic; - llama_ngram_cache ngram_cache_static; + common_ngram_cache ngram_cache_context; + common_ngram_cache ngram_cache_dynamic; + common_ngram_cache ngram_cache_static; int64_t t_draft_flat_us = 0; int64_t t_draft_us = 0; { // Fill up context ngram cache with tokens from user input: const int64_t t_start_draft_us = ggml_time_us(); - llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false); + common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false); if (!params.lookup_cache_static.empty()) { try { - ngram_cache_static = llama_ngram_cache_load(params.lookup_cache_static); + ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static); } catch (std::ifstream::failure const &) { - fprintf(stderr, "error: failed to open static lookup cache: %s", params.lookup_cache_static.c_str()); + LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str()); exit(1); } } if (!params.lookup_cache_dynamic.empty()) { try { - ngram_cache_dynamic = llama_ngram_cache_load(params.lookup_cache_dynamic); + ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic); } catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program } @@ -75,14 +75,14 @@ int main(int argc, char ** argv){ const int max_tokens_list_size = max_context_size - 4; if ((int) inp.size() > max_tokens_list_size) { - fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size); + LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size); return 1; } - fprintf(stderr, "\n\n"); + LOG("\n\n"); for (auto id : inp) { - fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str()); + LOG("%s", common_token_to_piece(ctx, id).c_str()); } fflush(stderr); @@ -91,8 +91,8 @@ int main(int argc, char ** argv){ const auto t_enc_start = ggml_time_us(); - llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0)); - llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); + llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1)); + llama_decode(ctx, llama_batch_get_one(&inp.back(), 1)); const auto t_enc_end = ggml_time_us(); @@ -104,7 +104,7 @@ int main(int argc, char ** argv){ bool has_eos = false; - struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams); + struct common_sampler * smpl = common_sampler_init(model, params.sampling); std::vector draft; @@ -119,26 +119,26 @@ int main(int argc, char ** argv){ // debug if (dump_kv_cache) { llama_kv_cache_view_update(ctx, &kvc_view); - llama_kv_cache_dump_view_seqs(kvc_view, 40); + common_kv_cache_dump_view_seqs(kvc_view, 40); } // print current draft sequence - LOG("drafted %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, draft).c_str()); + LOG_DBG("drafted %s\n", string_from(ctx, draft).c_str()); int i_dft = 0; while (true) { // sample from the target model - llama_token id = gpt_sampler_sample(smpl, ctx, i_dft); + llama_token id = common_sampler_sample(smpl, ctx, i_dft); - gpt_sampler_accept(smpl, id, true); + common_sampler_accept(smpl, id, true); - const std::string token_str = llama_token_to_piece(ctx, id); + const std::string token_str = common_token_to_piece(ctx, id); if (!params.use_color) { - printf("%s", token_str.c_str()); + LOG("%s", token_str.c_str()); } - if (llama_token_is_eog(model, id)) { + if (llama_vocab_is_eog(vocab, id)) { has_eos = true; } @@ -146,7 +146,7 @@ int main(int argc, char ** argv){ // check if the target token matches the draft if (i_dft < (int) draft.size() && id == draft[i_dft]) { - LOG("the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n", i_dft, id, token_str.c_str()); + LOG_DBG("the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n", i_dft, id, token_str.c_str()); ++n_accept; ++n_past; ++i_dft; @@ -154,25 +154,25 @@ int main(int argc, char ** argv){ { // Update context ngram cache with the newly accepted token: const int64_t t_start_draft_us = ggml_time_us(); - llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false); + common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false); t_draft_us += ggml_time_us() - t_start_draft_us; } if (params.use_color) { // color accepted draft token - printf("\033[34m%s\033[0m", token_str.c_str()); + LOG("\033[34m%s\033[0m", token_str.c_str()); fflush(stdout); } continue; } if (params.use_color) { - printf("%s", token_str.c_str()); + LOG("%s", token_str.c_str()); } fflush(stdout); - LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str()); + LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str()); draft.clear(); draft.push_back(id); @@ -180,7 +180,7 @@ int main(int argc, char ** argv){ { // Update context ngram cache with the newly accepted token: const int64_t t_start_draft_us = ggml_time_us(); - llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false); + common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false); t_draft_us += ggml_time_us() - t_start_draft_us; } break; @@ -194,18 +194,18 @@ int main(int argc, char ** argv){ // clean the cache of draft tokens that weren't accepted llama_kv_cache_seq_rm(ctx, 0, n_past, -1); - llama_batch_clear(batch_tgt); - llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); + common_batch_clear(batch_tgt); + common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); // Draft already contains a single token sampled from the model: GGML_ASSERT(draft.size() == 1); GGML_ASSERT(draft[0] == inp.back()); const int64_t t_start_draft_us = ggml_time_us(); - llama_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static); + common_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static); for (size_t i = 1; i < draft.size(); ++i) { - llama_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); + common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); } t_draft_us += ggml_time_us() - t_start_draft_us; @@ -220,38 +220,34 @@ int main(int argc, char ** argv){ auto t_dec_end = ggml_time_us(); // Update dynamic ngram cache with context ngram cache and save it to disk: - llama_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context); - llama_ngram_cache_save(ngram_cache_dynamic, params.lookup_cache_dynamic); + common_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context); + common_ngram_cache_save(ngram_cache_dynamic, params.lookup_cache_dynamic); - LOG_TEE("\n\n"); + LOG("\n\n"); - LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f)); - LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f)); + LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f)); + LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f)); - LOG_TEE("\n"); - LOG_TEE("n_draft = %d\n", n_draft); - LOG_TEE("n_predict = %d\n", n_predict); - LOG_TEE("n_drafted = %d\n", n_drafted); - LOG_TEE("t_draft_flat = %.2f ms\n", t_draft_flat_us*1e-3); - LOG_TEE("t_draft = %.2f ms, %.2f us per token, %.2f tokens per second\n", + LOG_INF("\n"); + LOG_INF("n_draft = %d\n", n_draft); + LOG_INF("n_predict = %d\n", n_predict); + LOG_INF("n_drafted = %d\n", n_drafted); + LOG_INF("t_draft_flat = %.2f ms\n", t_draft_flat_us*1e-3); + LOG_INF("t_draft = %.2f ms, %.2f us per token, %.2f tokens per second\n", t_draft_us*1e-3, 1.0f*t_draft_us/n_drafted, n_drafted/(1e-6*t_draft_us)); - LOG_TEE("n_accept = %d\n", n_accept); - LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); + LOG_INF("n_accept = %d\n", n_accept); + LOG_INF("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); - LOG_TEE("\ntarget:\n\n"); - llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN); - llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); + LOG_INF("\ntarget:\n\n"); + common_perf_print(ctx, smpl); - gpt_sampler_free(smpl); + common_sampler_free(smpl); llama_batch_free(batch_tgt); - llama_free(ctx); - llama_free_model(model); - llama_backend_free(); - fprintf(stderr, "\n\n"); + LOG("\n\n"); return 0; } diff --git a/examples/main-cmake-pkg/CMakeLists.txt b/examples/main-cmake-pkg/CMakeLists.txt deleted file mode 100644 index 3b38db292..000000000 --- a/examples/main-cmake-pkg/CMakeLists.txt +++ /dev/null @@ -1,32 +0,0 @@ -cmake_minimum_required(VERSION 3.12) -project("llama-cli-cmake-pkg" C CXX) -set(TARGET llama-cli-cmake-pkg) - -find_package(Llama 0.0.1 REQUIRED) - -# Bake common functionality in with target. Because applications -# using the relocatable Llama package should be outside of the -# source tree, llama-cli-cmake-pkg pretends the dependencies are built-in. -set(_common_path "${CMAKE_CURRENT_LIST_DIR}/../../common") -add_library(common OBJECT) -file(GLOB _common_files - "${_common_path}/*.h" - "${_common_path}/*.cpp" -) -target_sources(common PRIVATE ${_common_files}) - -# If the common project was part of "llama-cli-cmake-pkg" the transient -# defines would automatically be attached. Because the common func- -# tionality is separate, but dependent upon the defines, it must be -# explicitly extracted from the "llama" target. -# -get_target_property(_llama_transient_defines llama - INTERFACE_COMPILE_DEFINITIONS) - -target_compile_definitions(common PRIVATE "${_llama_transient_defines}") - -add_executable(${TARGET} ${CMAKE_CURRENT_LIST_DIR}/../main/main.cpp) -target_include_directories(${TARGET} PRIVATE ${_common_path}) -install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/main-cmake-pkg/README.md b/examples/main-cmake-pkg/README.md deleted file mode 100644 index 08d83dd08..000000000 --- a/examples/main-cmake-pkg/README.md +++ /dev/null @@ -1,31 +0,0 @@ -# llama.cpp/example/main-cmake-pkg - -This program builds [llama-cli](../main) using a relocatable CMake package. It serves as an example of using the `find_package()` CMake command to conveniently include [llama.cpp](https://github.com/ggerganov/llama.cpp) in projects which live outside of the source tree. - -## Building - -Because this example is "outside of the source tree", it is important to first build/install llama.cpp using CMake. An example is provided here, but please see the [llama.cpp build instructions](../..) for more detailed build instructions. - -### Considerations - -When hardware acceleration libraries are used (e.g. CUDA, Metal, etc.), CMake must be able to locate the associated CMake package. - -### Build llama.cpp and install to C:\LlamaCPP directory - -```cmd -git clone https://github.com/ggerganov/llama.cpp -cd llama.cpp -cmake -B build -DBUILD_SHARED_LIBS=OFF -G "Visual Studio 17 2022" -A x64 -cmake --build build --config Release -cmake --install build --prefix C:/LlamaCPP -``` - -### Build llama-cli-cmake-pkg - - -```cmd -cd ..\examples\main-cmake-pkg -cmake -B build -DBUILD_SHARED_LIBS=OFF -DCMAKE_PREFIX_PATH="C:/LlamaCPP/lib/cmake/Llama" -G "Visual Studio 17 2022" -A x64 -cmake --build build --config Release -cmake --install build --prefix C:/MyLlamaApp -``` diff --git a/examples/main/CMakeLists.txt b/examples/main/CMakeLists.txt index 5f6efaa9a..af3d9150f 100644 --- a/examples/main/CMakeLists.txt +++ b/examples/main/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-cli) add_executable(${TARGET} main.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/main/README.md b/examples/main/README.md index 9396a34fa..46f92eb7a 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -66,10 +66,10 @@ In this section, we cover the most commonly used options for running the `llama- - `-mu MODEL_URL --model-url MODEL_URL`: Specify a remote http url to download the file (e.g [https://huggingface.co/ggml-org/gemma-1.1-7b-it-Q4_K_M-GGUF/resolve/main/gemma-1.1-7b-it.Q4_K_M.gguf?download=true](https://huggingface.co/ggml-org/gemma-1.1-7b-it-Q4_K_M-GGUF/resolve/main/gemma-1.1-7b-it.Q4_K_M.gguf?download=true)). - `-i, --interactive`: Run the program in interactive mode, allowing you to provide input directly and receive real-time responses. - `-n N, --n-predict N`: Set the number of tokens to predict when generating text. Adjusting this value can influence the length of the generated text. -- `-c N, --ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference. +- `-c N, --ctx-size N`: Set the size of the prompt context. The default is 4096, but if a LLaMA model was built with a longer context, increasing this value will provide better results for longer input/inference. - `-mli, --multiline-input`: Allows you to write or paste multiple lines without ending each in '\' - `-t N, --threads N`: Set the number of threads to use during generation. For optimal performance, it is recommended to set this value to the number of physical CPU cores your system has. -- - `-ngl N, --n-gpu-layers N`: When compiled with GPU support, this option allows offloading some layers to the GPU for computation. Generally results in increased performance. +- `-ngl N, --n-gpu-layers N`: When compiled with GPU support, this option allows offloading some layers to the GPU for computation. Generally results in increased performance. ## Input Prompts @@ -131,7 +131,7 @@ During text generation, LLaMA models have a limited context size, which means th ### Context Size -- `-c N, --ctx-size N`: Set the size of the prompt context (default: 0, 0 = loaded from model). The LLaMA models were built with a context of 2048-8192, which will yield the best results on longer input/inference. +- `-c N, --ctx-size N`: Set the size of the prompt context (default: 4096, 0 = loaded from model). If a LLaMA model was built with a longer context, increasing this value will yield the best results on longer input/inference. ### Extended Context Size @@ -161,6 +161,8 @@ A value of -1 will enable infinite text generation, even though we have a finite If the pause is undesirable, a value of -2 will stop generation immediately when the context is filled. +The `--no-context-shift` option allows you to stop the infinite text generation once the finite context window is full. + It is important to note that the generated text may be shorter than the specified number of tokens if an End-of-Sequence (EOS) token or a reverse prompt is encountered. In interactive mode, text generation will pause and control will be returned to the user. In non-interactive mode, the program will end. In both cases, the text generation may stop before reaching the specified `--predict` value. If you want the model to keep going without ever producing End-of-Sequence on its own, you can use the `--ignore-eos` parameter. ### Temperature @@ -175,15 +177,34 @@ Example usage: `--temp 0` - `--repeat-penalty N`: Control the repetition of token sequences in the generated text default: 1.0, 1.0 = disabled). - `--repeat-last-n N`: Last n tokens to consider for penalizing repetition (default: 64, 0 = disabled, -1 = ctx-size). -- `--no-penalize-nl`: Disable penalization for newline tokens when applying the repeat penalty. The `repeat-penalty` option helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. The default value is 1. The `repeat-last-n` option controls the number of tokens in the history to consider for penalizing repetition. A larger value will look further back in the generated text to prevent repetitions, while a smaller value will only consider recent tokens. A value of 0 disables the penalty, and a value of -1 sets the number of tokens considered equal to the context size (`ctx-size`). -Use the `--no-penalize-nl` option to disable newline penalization when applying the repeat penalty. This option is particularly useful for generating chat conversations, dialogues, code, poetry, or any text where newline tokens play a significant role in structure and formatting. Disabling newline penalization helps maintain the natural flow and intended formatting in these specific use cases. +### DRY Repetition Penalty -Example usage: `--repeat-penalty 1.15 --repeat-last-n 128 --no-penalize-nl` +DRY (Don't Repeat Yourself) sampling is an effective technique for reducing repetition in generated text even across long contexts by penalizing tokens based on their recent usage patterns (original [PR link](https://github.com/oobabooga/text-generation-webui/pull/5677)). + +- `--dry-multiplier N`: Set the DRY sampling multiplier (default: 0.0, 0.0 = disabled). +- `--dry-base N`: Set the DRY sampling base value (default: 1.75). +- `--dry-allowed-length N`: Set the allowed length for DRY sampling (default: 2). +- `--dry-penalty-last-n N`: Set DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size). +- `--dry-sequence-breaker STRING`: Add a sequence breaker for DRY sampling. Can be used more than once to add multiple sequence breakers. Using this clears out the default breakers, which consist of: `['\n', ':', '"', '*']`. If the string `"none"` is supplied, no sequence breakers are used. + +The `dry-multiplier` option controls the strength of the DRY sampling effect. A value of 0.0 disables DRY sampling, while higher values increase its influence. A typical recommended value is 0.8. + +The `dry-base` option sets the base value for the exponential penalty calculation in DRY sampling. Higher values lead to more aggressive penalization of repetitions. + +The `dry-allowed-length` option sets the maximum length of repeated sequences that will not be penalized. Repetitions shorter than or equal to this length are not penalized, allowing for natural repetitions of short phrases or common words. + +The `dry-penalty-last-n` option controls how many recent tokens to consider when applying the DRY penalty. A value of -1 considers the entire context. Use a positive value to limit the consideration to a specific number of recent tokens. + +The `dry-sequence-breaker` option adds a single sequence breaker and can be used more than once to specify multiple sequence breakers. Sequence breakers interrupt sequence matching and break the input into parts where matching can be applied. + +DRY sampling provides more nuanced control over text generation, particularly for reducing long-range repetitions and maintaining global coherence. + +Example usage: `--dry-multiplier 0.8 --dry-base 1.75 --dry-allowed-length 2 --dry-penalty-last-n -1 --dry-sequence-breaker "—" --dry-sequence-breaker "##"` ### Top-K Sampling @@ -209,14 +230,6 @@ The Min-P sampling method was designed as an alternative to Top-P, and aims to e Example usage: `--min-p 0.05` -### Tail-Free Sampling (TFS) - -- `--tfs N`: Enable tail free sampling with parameter z (default: 1.0, 1.0 = disabled). - -Tail-free sampling (TFS) is a text generation technique that aims to reduce the impact of less likely tokens, which may be less relevant, less coherent, or nonsensical, on the output. Similar to Top-P it tries to determine the bulk of the most likely tokens dynamically. But TFS filters out logits based on the second derivative of their probabilities. Adding tokens is stopped after the sum of the second derivatives reaches the parameter z. In short: TFS looks at how quickly the probabilities of the tokens decrease and cuts off the tail of unlikely tokens using the parameter z. Typical values for z are in the range of 0.9 to 0.95. A value of 1.0 would include all tokens and thus disables the effect of TFS. - -Example usage: `--tfs 0.95` - ### Locally Typical Sampling - `--typical N`: Enable locally typical sampling with parameter p (default: 1.0, 1.0 = disabled). @@ -239,6 +252,19 @@ The `--mirostat-ent` option sets the Mirostat target entropy (tau), which repres Example usage: `--mirostat 2 --mirostat-lr 0.05 --mirostat-ent 3.0` +### XTC Sampling + +- `--xtc-probability N`: Sets the chance for token removal (checked once on sampler start) (default: 0.0). +- `--xtc-threshold N`: Sets a minimum probability threshold for tokens to be removed (default: 0.1). + +Exclude Top Choices (XTC) is a unique sampler that is designed to remove top tokens from consideration and avoid more obvious and repetitive outputs. With a chance of `xtc-probability` it searches for tokens with probabilities of `xtc-threshold` and above, then removes all such tokens except the least probable one. + +By removing top tokens XTC can improve the variety of answers, break writing clichés and inhibit repition, since clichés and repeated phrases are usually more likely to appear. By keeping the last token above the threshold, XTC ensures that the answer is still coherent. XTC is meant to be used for creative tasks, but feel free to experiment with different settings for different models. + +Being experimental and unique, XTC is disabled by default. The recommended combination of samplers is Min-P followed by XTC on its default settings: `--sampling-seq mx --min-p 0.02 --xtc-probability 0.5`. + +Example usage: `--xtc-probability 0.5 --xtc-threshold 0.1` + ### Logit Bias - `-l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS`: Modify the likelihood of a token appearing in the generated text completion. @@ -282,15 +308,11 @@ These options help improve the performance and memory usage of the LLaMA models. These flags attempt optimizations that help on some systems with non-uniform memory access. This currently consists of one of the above strategies, and disabling prefetch and readahead for mmap. The latter causes mapped pages to be faulted in on first access instead of all at once, and in combination with pinning threads to NUMA nodes, more of the pages end up on the NUMA node where they are used. Note that if the model is already in the system page cache, for example because of a previous run without this option, this will have little effect unless you drop the page cache first. This can be done by rebooting the system or on Linux by writing '3' to '/proc/sys/vm/drop_caches' as root. -### Memory Float 32 - -- `--memory-f32`: Use 32-bit floats instead of 16-bit floats for memory key+value. This doubles the context memory requirement and cached prompt file size but does not appear to increase generation quality in a measurable way. Not recommended. - ### Batch Size -- `-b N, --batch-size N`: Set the batch size for prompt processing (default: `2048`). This large batch size benefits users who have BLAS installed and enabled it during the build. If you don't have BLAS enabled ("BLAS=0"), you can use a smaller number, such as 8, to see the prompt progress as it's evaluated in some situations. +- `-ub N`, `--ubatch-size N`: Physical batch size. This is the maximum number of tokens that may be processed at a time. Increasing this value may improve performance during prompt processing, at the expense of higher memory usage. Default: `512`. -- `-ub N`, `--ubatch-size N`: physical maximum batch size. This is for pipeline parallelization. Default: `512`. +- `-b N`, `--batch-size N`: Logical batch size. Increasing this value above the value of the physical batch size may improve prompt processing performance when using multiple GPUs with pipeline parallelism. Default: `2048`. ### Prompt Caching @@ -306,14 +328,22 @@ These options help improve the performance and memory usage of the LLaMA models. For information about 4-bit quantization, which can significantly improve performance and reduce memory usage, please refer to llama.cpp's primary [README](../../README.md#prepare-and-quantize). +## LoRA (Low-Rank Adaptation) adapters + +- `--lora FNAME`: Optional path to a LoRA adapter to use with scaling of 1.0. Can be mixed with `--lora-scaled` and can be repeated to use multiple adapters. +- `--lora-scaled FNAME`: Optional path to a LoRA adapter with user-defined scaling. Can be mixed with `--lora` and can repeated to use multiple adapters. + +You can add LoRA adapters using `--lora` or `--lora-scaled`. For example: `--lora my_adapter_1.gguf --lora my_adapter_2.gguf ...` or `--lora-scaled lora_task_A.gguf 0.5 --lora-scaled lora_task_B.gguf 0.5`. + +LoRA adapters should be in GGUF format. To convert from Hugging Face format use the `convert-lora-to-gguf.py` script. LoRA adapters are loaded separately and applied during inference - they are not merged with the main model. This means that mmap model loading is fully supported when using LoRA adapters. The old `--lora-base` flag has been removed now that merging is no longer performed. + ## Additional Options These options provide extra functionality and customization when running the LLaMA models: - `-h, --help`: Display a help message showing all available options and their default values. This is particularly useful for checking the latest options and default values, as they can change frequently, and the information in this document may become outdated. - `--verbose-prompt`: Print the prompt before generating text. +- `--no-display-prompt`: Don't print prompt at generation. - `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. - `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. -- `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains. -- `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation. - `-hfr URL --hf-repo URL`: The url to the Hugging Face model repository. Used in conjunction with `--hf-file` or `-hff`. The model is downloaded and stored in the file provided by `-m` or `--model`. If `-m` is not provided, the model is auto-stored in the path specified by the `LLAMA_CACHE` environment variable or in an OS-specific local cache. diff --git a/examples/main/main.cpp b/examples/main/main.cpp index ef2158842..e654d3542 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -1,11 +1,11 @@ +#include "arg.h" #include "common.h" - #include "console.h" +#include "log.h" +#include "sampling.h" #include "llama.h" +#include "chat-template.hpp" -#include -#include -#include #include #include #include @@ -31,21 +31,25 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif +static const char * DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant"; + static llama_context ** g_ctx; static llama_model ** g_model; -static gpt_sampler ** g_smpl; -static gpt_params * g_params; +static common_sampler ** g_smpl; +static common_params * g_params; static std::vector * g_input_tokens; static std::ostringstream * g_output_ss; static std::vector * g_output_tokens; static bool is_interacting = false; static bool need_insert_eot = false; -static void print_usage(int, char ** argv) { - printf("\nexample usage:\n"); - printf("\n text generation: %s -m your_model.gguf -p \"I believe the meaning of life is\" -n 128\n", argv[0]); - printf("\n chat (conversation): %s -m your_model.gguf -p \"You are a helpful assistant\" -cnv\n", argv[0]); - printf("\n"); +static void print_usage(int argc, char ** argv) { + (void) argc; + + LOG("\nexample usage:\n"); + LOG("\n text generation: %s -m your_model.gguf -p \"I believe the meaning of life is\" -n 128\n", argv[0]); + LOG("\n chat (conversation): %s -m your_model.gguf -p \"You are a helpful assistant\" -cnv\n", argv[0]); + LOG("\n"); } static bool file_exists(const std::string & path) { @@ -60,50 +64,6 @@ static bool file_is_empty(const std::string & path) { return f.tellg() == 0; } -static void write_logfile( - const llama_context * ctx, const gpt_params & params, const llama_model * model, - const std::vector & input_tokens, const std::string & output, - const std::vector & output_tokens -) { - if (params.logdir.empty()) { - return; - } - - const std::string timestamp = string_get_sortable_timestamp(); - - const bool success = fs_create_directory_with_parents(params.logdir); - if (!success) { - fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n", - __func__, params.logdir.c_str()); - return; - } - - const std::string logfile_path = params.logdir + timestamp + ".yml"; - FILE * logfile = fopen(logfile_path.c_str(), "w"); - - if (logfile == NULL) { - fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str()); - return; - } - - fprintf(logfile, "binary: main\n"); - char model_desc[128]; - llama_model_desc(model, model_desc, sizeof(model_desc)); - yaml_dump_non_result_info(logfile, params, ctx, timestamp, input_tokens, model_desc); - - fprintf(logfile, "\n"); - fprintf(logfile, "######################\n"); - fprintf(logfile, "# Generation Results #\n"); - fprintf(logfile, "######################\n"); - fprintf(logfile, "\n"); - - yaml_dump_string_multiline(logfile, "output", output.c_str()); - yaml_dump_vector_int(logfile, "output_tokens", output_tokens); - - llama_perf_dump_yaml(logfile, ctx); - fclose(logfile); -} - #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) static void sigint_handler(int signo) { if (signo == SIGINT) { @@ -112,49 +72,29 @@ static void sigint_handler(int signo) { need_insert_eot = true; } else { console::cleanup(); - printf("\n"); - gpt_perf_print(*g_ctx, *g_smpl); - write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); + LOG("\n"); + common_perf_print(*g_ctx, *g_smpl); + + // make sure all logs are flushed + LOG("Interrupted by user\n"); + common_log_pause(common_log_main()); + _exit(130); } } } #endif -static void llama_log_callback_logTee(ggml_log_level level, const char * text, void * user_data) { - (void) level; - (void) user_data; - LOG_TEE("%s", text); -} - -static std::string chat_add_and_format(struct llama_model * model, std::vector & chat_msgs, std::string role, std::string content) { - llama_chat_msg new_msg{role, content}; - auto formatted = llama_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user"); - chat_msgs.push_back({role, content}); - LOG("formatted: %s\n", formatted.c_str()); - return formatted; -} - int main(int argc, char ** argv) { - gpt_params params; + common_params params; g_params = ¶ms; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_MAIN, print_usage); - - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MAIN, print_usage)) { return 1; } - auto & sparams = params.sparams; + common_init(); -#ifndef LOG_DISABLE_LOGS - log_set_target(log_filename_generator("main", "log")); - LOG_TEE("Log start\n"); - log_dump_cmdline(argc, argv); - llama_log_set(llama_log_callback_logTee, nullptr); -#endif // LOG_DISABLE_LOGS - - // TODO: Dump params ? - //LOG("Params perplexity: %s\n", LOG_TOSTR(params.perplexity)); + auto & sparams = params.sampling; // save choice to use color for later // (note for later: this is a slightly awkward choice) @@ -162,68 +102,70 @@ int main(int argc, char ** argv) { atexit([]() { console::cleanup(); }); if (params.logits_all) { - printf("\n************\n"); - printf("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__); - printf("************\n\n"); + LOG_ERR("************\n"); + LOG_ERR("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__); + LOG_ERR("************\n\n"); return 0; } if (params.embedding) { - printf("\n************\n"); - printf("%s: please use the 'embedding' tool for embedding calculations\n", __func__); - printf("************\n\n"); + LOG_ERR("************\n"); + LOG_ERR("%s: please use the 'embedding' tool for embedding calculations\n", __func__); + LOG_ERR("************\n\n"); return 0; } if (params.n_ctx != 0 && params.n_ctx < 8) { - LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__); + LOG_WRN("%s: warning: minimum context size is 8, using minimum size.\n", __func__); params.n_ctx = 8; } if (params.rope_freq_base != 0.0) { - LOG_TEE("%s: warning: changing RoPE frequency base to %g.\n", __func__, params.rope_freq_base); + LOG_WRN("%s: warning: changing RoPE frequency base to %g.\n", __func__, params.rope_freq_base); } if (params.rope_freq_scale != 0.0) { - LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale); + LOG_WRN("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale); } - print_build_info(); + LOG_INF("%s: llama backend init\n", __func__); - LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed); - - LOG("%s: llama backend init\n", __func__); llama_backend_init(); llama_numa_init(params.numa); llama_model * model = nullptr; llama_context * ctx = nullptr; - gpt_sampler * smpl = nullptr; - - std::vector chat_msgs; + common_sampler * smpl = nullptr; g_model = &model; g_ctx = &ctx; g_smpl = &smpl; - // load the model and apply lora adapter, if any - LOG("%s: load the model and apply lora adapter, if any\n", __func__); - llama_init_result llama_init = llama_init_from_gpt_params(params); + std::vector chat_msgs; - model = llama_init.model; - ctx = llama_init.context; + // load the model and apply lora adapter, if any + LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__); + common_init_result llama_init = common_init_from_params(params); + + model = llama_init.model.get(); + ctx = llama_init.context.get(); if (model == NULL) { - LOG_TEE("%s: error: unable to load model\n", __func__); + LOG_ERR("%s: error: unable to load model\n", __func__); return 1; } - LOG("%s: llama threadpool init = n_threads = %d\n", - __func__, - (int) params.cpuparams.n_threads - ); + const llama_vocab * vocab = llama_model_get_vocab(model); + auto chat_templates = common_chat_templates_from_model(model, params.chat_template); + + LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads); + + auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU)); + auto * ggml_threadpool_new_fn = (decltype(ggml_threadpool_new) *) ggml_backend_reg_get_proc_address(reg, "ggml_threadpool_new"); + auto * ggml_threadpool_free_fn = (decltype(ggml_threadpool_free) *) ggml_backend_reg_get_proc_address(reg, "ggml_threadpool_free"); + struct ggml_threadpool_params tpp_batch = ggml_threadpool_params_from_cpu_params(params.cpuparams_batch); struct ggml_threadpool_params tpp = @@ -233,108 +175,134 @@ int main(int argc, char ** argv) { struct ggml_threadpool * threadpool_batch = NULL; if (!ggml_threadpool_params_match(&tpp, &tpp_batch)) { - threadpool_batch = ggml_threadpool_new(&tpp_batch); + threadpool_batch = ggml_threadpool_new_fn(&tpp_batch); if (!threadpool_batch) { - LOG_TEE("%s: batch threadpool create failed : n_threads %d\n", __func__, tpp_batch.n_threads); - exit(1); + LOG_ERR("%s: batch threadpool create failed : n_threads %d\n", __func__, tpp_batch.n_threads); + return 1; } // Start the non-batch threadpool in the paused state tpp.paused = true; } - struct ggml_threadpool * threadpool = ggml_threadpool_new(&tpp); + struct ggml_threadpool * threadpool = ggml_threadpool_new_fn(&tpp); if (!threadpool) { - LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads); - exit(1); + LOG_ERR("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads); + return 1; } llama_attach_threadpool(ctx, threadpool, threadpool_batch); - const int n_ctx_train = llama_n_ctx_train(model); + const int n_ctx_train = llama_model_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); - LOG("n_ctx: %d\n", n_ctx); if (n_ctx > n_ctx_train) { - LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n", - __func__, n_ctx_train, n_ctx); + LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx); + } + + // auto enable conversation mode if chat template is available + const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.template_default; + if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) { + if (has_chat_template) { + LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__); + params.conversation_mode = COMMON_CONVERSATION_MODE_ENABLED; + } else { + params.conversation_mode = COMMON_CONVERSATION_MODE_DISABLED; + } + } + + // in case user force-activate conversation mode (via -cnv) without proper chat template, we show a warning + if (params.conversation_mode && !has_chat_template) { + LOG_WRN("%s: chat template is not available or is not supported. This may cause the model to output suboptimal responses\n", __func__); } // print chat template example in conversation mode - if (params.conversation) { + if (params.conversation_mode) { if (params.enable_chat_template) { - LOG_TEE("%s: chat template example: %s\n", __func__, llama_chat_format_example(model, params.chat_template).c_str()); + LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str()); } else { - LOG_TEE("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__); + LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__); } } // print system information { - LOG_TEE("\n"); - LOG_TEE("%s\n", gpt_params_get_system_info(params).c_str()); + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + LOG_INF("\n"); } std::string path_session = params.path_prompt_cache; std::vector session_tokens; if (!path_session.empty()) { - LOG_TEE("%s: attempting to load saved session from '%s'\n", __func__, path_session.c_str()); + LOG_INF("%s: attempting to load saved session from '%s'\n", __func__, path_session.c_str()); if (!file_exists(path_session)) { - LOG_TEE("%s: session file does not exist, will create.\n", __func__); + LOG_INF("%s: session file does not exist, will create.\n", __func__); } else if (file_is_empty(path_session)) { - LOG_TEE("%s: The session file is empty. A new session will be initialized.\n", __func__); + LOG_INF("%s: The session file is empty. A new session will be initialized.\n", __func__); } else { // The file exists and is not empty session_tokens.resize(n_ctx); size_t n_token_count_out = 0; if (!llama_state_load_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) { - LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str()); + LOG_ERR("%s: failed to load session file '%s'\n", __func__, path_session.c_str()); return 1; } session_tokens.resize(n_token_count_out); - LOG_TEE("%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size()); + LOG_INF("%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size()); } } - const bool add_bos = llama_add_bos_token(model); + const bool add_bos = llama_vocab_get_add_bos(vocab) && !params.use_jinja; if (!llama_model_has_encoder(model)) { - GGML_ASSERT(!llama_add_eos_token(model)); + GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); } - LOG("add_bos: %d\n", add_bos); + + LOG_DBG("n_ctx: %d, add_bos: %d\n", n_ctx, add_bos); std::vector embd_inp; + auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) { + common_chat_msg new_msg{role, content, {}}; + auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja); + chat_msgs.push_back({role, content, {}}); + LOG_DBG("formatted: '%s'\n", formatted.c_str()); + return formatted; + }; + { - auto prompt = (params.conversation && params.enable_chat_template && !params.prompt.empty()) - ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode + auto prompt = (params.conversation_mode && params.enable_chat_template) + // format the system prompt in conversation mode (fallback to default if empty) + ? chat_add_and_format("system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt) + // otherwise use the prompt as is : params.prompt; if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) { - LOG("tokenize the prompt\n"); - embd_inp = ::llama_tokenize(ctx, prompt, true, true); + LOG_DBG("tokenize the prompt\n"); + embd_inp = common_tokenize(ctx, prompt, true, true); } else { - LOG("use session tokens\n"); + LOG_DBG("use session tokens\n"); embd_inp = session_tokens; } - LOG("prompt: \"%s\"\n", log_tostr(prompt)); - LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); + LOG_DBG("prompt: \"%s\"\n", prompt.c_str()); + LOG_DBG("tokens: %s\n", string_from(ctx, embd_inp).c_str()); } // Should not run without any tokens if (embd_inp.empty()) { if (add_bos) { - embd_inp.push_back(llama_token_bos(model)); - LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); + embd_inp.push_back(llama_vocab_bos(vocab)); + LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str()); } else { - LOG_TEE("error: input is empty\n"); + LOG_ERR("input is empty\n"); return -1; } } // Tokenize negative prompt if ((int) embd_inp.size() > n_ctx - 4) { - LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4); + LOG_ERR("%s: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4); return 1; } @@ -348,29 +316,28 @@ int main(int argc, char ** argv) { n_matching_session_tokens++; } if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) { - LOG_TEE("%s: using full prompt from session file\n", __func__); + LOG_INF("%s: using full prompt from session file\n", __func__); } else if (n_matching_session_tokens >= embd_inp.size()) { - LOG_TEE("%s: session file has exact match for prompt!\n", __func__); + LOG_INF("%s: session file has exact match for prompt!\n", __func__); } else if (n_matching_session_tokens < (embd_inp.size() / 2)) { - LOG_TEE("%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n", - __func__, n_matching_session_tokens, embd_inp.size()); + LOG_WRN("%s: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n", + __func__, n_matching_session_tokens, embd_inp.size()); } else { - LOG_TEE("%s: session file matches %zu / %zu tokens of prompt\n", - __func__, n_matching_session_tokens, embd_inp.size()); + LOG_INF("%s: session file matches %zu / %zu tokens of prompt\n", + __func__, n_matching_session_tokens, embd_inp.size()); } // remove any "future" tokens that we might have inherited from the previous session llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1); } - LOGLN( - "recalculate the cached logits (check): embd_inp.empty() %s, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu", - log_tostr(embd_inp.empty()), n_matching_session_tokens, embd_inp.size(), session_tokens.size()); + LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n", + embd_inp.size(), n_matching_session_tokens, embd_inp.size(), session_tokens.size()); // if we will use the cache for the full prompt without reaching the end of the cache, force // reevaluation of the last token to recalculate the cached logits if (!embd_inp.empty() && n_matching_session_tokens == embd_inp.size() && session_tokens.size() > embd_inp.size()) { - LOGLN("recalculate the cached logits (do): session_tokens.resize( %zu )", embd_inp.size() - 1); + LOG_DBG("recalculate the cached logits (do): session_tokens.resize( %zu )\n", embd_inp.size() - 1); session_tokens.resize(embd_inp.size() - 1); } @@ -382,7 +349,7 @@ int main(int argc, char ** argv) { params.n_keep += add_bos; // always keep the BOS token } - if (params.conversation) { + if (params.conversation_mode) { params.interactive_first = true; } @@ -392,21 +359,20 @@ int main(int argc, char ** argv) { } if (params.verbose_prompt) { - LOG_TEE("\n"); - LOG_TEE("%s: prompt: '%s'\n", __func__, params.prompt.c_str()); - LOG_TEE("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); + LOG_INF("%s: prompt: '%s'\n", __func__, params.prompt.c_str()); + LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); for (int i = 0; i < (int) embd_inp.size(); i++) { - LOG_TEE("%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str()); + LOG_INF("%6d -> '%s'\n", embd_inp[i], common_token_to_piece(ctx, embd_inp[i]).c_str()); } if (params.n_keep > add_bos) { - LOG_TEE("%s: static prompt based on n_keep: '", __func__); + LOG_INF("%s: static prompt based on n_keep: '", __func__); for (int i = 0; i < params.n_keep; i++) { - LOG_TEE("%s", llama_token_to_piece(ctx, embd_inp[i]).c_str()); + LOG_CNT("%s", common_token_to_piece(ctx, embd_inp[i]).c_str()); } - LOG_TEE("'\n"); + LOG_CNT("'\n"); } - LOG_TEE("\n"); + LOG_INF("\n"); } // ctrl+C handling @@ -426,54 +392,56 @@ int main(int argc, char ** argv) { } if (params.interactive) { - LOG_TEE("%s: interactive mode on.\n", __func__); + LOG_INF("%s: interactive mode on.\n", __func__); if (!params.antiprompt.empty()) { for (const auto & antiprompt : params.antiprompt) { - LOG_TEE("Reverse prompt: '%s'\n", antiprompt.c_str()); + LOG_INF("Reverse prompt: '%s'\n", antiprompt.c_str()); if (params.verbose_prompt) { - auto tmp = ::llama_tokenize(ctx, antiprompt, false, true); + auto tmp = common_tokenize(ctx, antiprompt, false, true); for (int i = 0; i < (int) tmp.size(); i++) { - LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str()); + LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx, tmp[i]).c_str()); } } } } if (params.input_prefix_bos) { - LOG_TEE("Input prefix with BOS\n"); + LOG_INF("Input prefix with BOS\n"); } if (!params.input_prefix.empty()) { - LOG_TEE("Input prefix: '%s'\n", params.input_prefix.c_str()); + LOG_INF("Input prefix: '%s'\n", params.input_prefix.c_str()); if (params.verbose_prompt) { - auto tmp = ::llama_tokenize(ctx, params.input_prefix, true, true); + auto tmp = common_tokenize(ctx, params.input_prefix, true, true); for (int i = 0; i < (int) tmp.size(); i++) { - LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str()); + LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx, tmp[i]).c_str()); } } } if (!params.input_suffix.empty()) { - LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str()); + LOG_INF("Input suffix: '%s'\n", params.input_suffix.c_str()); if (params.verbose_prompt) { - auto tmp = ::llama_tokenize(ctx, params.input_suffix, false, true); + auto tmp = common_tokenize(ctx, params.input_suffix, false, true); for (int i = 0; i < (int) tmp.size(); i++) { - LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str()); + LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx, tmp[i]).c_str()); } } } } - smpl = gpt_sampler_init(model, sparams); + smpl = common_sampler_init(model, sparams); if (!smpl) { - fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); - exit(1); + LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__); + return 1; } - LOG_TEE("sampling params: \n%s\n", sparams.print().c_str()); - LOG_TEE(" sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str()); - LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); + LOG_INF("sampler seed: %u\n", common_sampler_get_seed(smpl)); + LOG_INF("sampler params: \n%s\n", sparams.print().c_str()); + LOG_INF("sampler chain: %s\n", common_sampler_print(smpl).c_str()); + + LOG_INF("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); // group-attention state // number of grouped KV tokens so far (used only if params.grp_attn_n > 1) @@ -487,9 +455,9 @@ int main(int argc, char ** argv) { GGML_ASSERT(ga_w % ga_n == 0 && "grp_attn_w must be a multiple of grp_attn_n"); // NOLINT //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of grp_attn_w"); // NOLINT //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT - LOG_TEE("self-extend: n_ctx_train = %d, grp_attn_n = %d, grp_attn_w = %d\n", n_ctx_train, ga_n, ga_w); + LOG_INF("self-extend: n_ctx_train = %d, grp_attn_n = %d, grp_attn_w = %d\n", n_ctx_train, ga_n, ga_w); } - LOG_TEE("\n\n"); + LOG_INF("\n"); if (params.interactive) { const char * control_message; @@ -501,11 +469,15 @@ int main(int argc, char ** argv) { " - To return control without starting a new line, end your input with '/'.\n" " - If you want to submit another line, end your input with '\\'.\n"; } - LOG_TEE("== Running in interactive mode. ==\n"); + LOG_INF("== Running in interactive mode. ==\n"); #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) - LOG_TEE( " - Press Ctrl+C to interject at any time.\n"); + LOG_INF( " - Press Ctrl+C to interject at any time.\n"); #endif - LOG_TEE( "%s\n", control_message); + LOG_INF( "%s", control_message); + if (params.conversation_mode && params.enable_chat_template && params.prompt.empty()) { + LOG_INF( " - Using default system message. To change it, set a different value via -p PROMPT or -f FILE argument.\n"); + } + LOG_INF("\n"); is_interacting = params.interactive_first; } @@ -531,26 +503,28 @@ int main(int argc, char ** argv) { std::vector embd; - // tokenized antiprompts - std::vector> antiprompt_ids; + // single-token antiprompts + std::vector antiprompt_token; - antiprompt_ids.reserve(params.antiprompt.size()); for (const std::string & antiprompt : params.antiprompt) { - antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true)); + auto ids = ::common_tokenize(ctx, antiprompt, false, true); + if (ids.size() == 1) { + antiprompt_token.push_back(ids[0]); + } } if (llama_model_has_encoder(model)) { int enc_input_size = embd_inp.size(); llama_token * enc_input_buf = embd_inp.data(); - if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) { - LOG_TEE("%s : failed to eval\n", __func__); + if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size))) { + LOG_ERR("%s : failed to eval\n", __func__); return 1; } llama_token decoder_start_token_id = llama_model_decoder_start_token(model); - if (decoder_start_token_id == -1) { - decoder_start_token_id = llama_token_bos(model); + if (decoder_start_token_id == LLAMA_TOKEN_NULL) { + decoder_start_token_id = llama_vocab_bos(vocab); } embd_inp.clear(); @@ -570,9 +544,8 @@ int main(int argc, char ** argv) { embd.resize(max_embd_size); console::set_display(console::error); - printf("<>", skipped_tokens, skipped_tokens != 1 ? "s" : ""); + LOG_WRN("<>", skipped_tokens, skipped_tokens != 1 ? "s" : ""); console::set_display(console::reset); - fflush(stdout); } if (ga_n == 1) { @@ -580,16 +553,22 @@ int main(int argc, char ** argv) { // if we run out of context: // - take the n_keep first tokens from the original prompt (via n_past) // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches + if (n_past + (int) embd.size() >= n_ctx) { + if (!params.ctx_shift){ + LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__); + break; + } + if (params.n_predict == -2) { - LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); + LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); break; } const int n_left = n_past - params.n_keep; const int n_discard = n_left/2; - LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", + LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", n_past, n_left, n_ctx, params.n_keep, n_discard); llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard); @@ -597,11 +576,11 @@ int main(int argc, char ** argv) { n_past -= n_discard; - LOG("after swap: n_past = %d\n", n_past); + LOG_DBG("after swap: n_past = %d\n", n_past); - LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); + LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str()); - LOG("clear session path\n"); + LOG_DBG("clear session path\n"); path_session.clear(); } } else { @@ -611,10 +590,10 @@ int main(int argc, char ** argv) { const int bd = (ga_w/ga_n)*(ga_n - 1); const int dd = (ga_w/ga_n) - ib*bd - ga_w; - LOG("\n"); - LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i, n_past, ib*bd, ga_i + ib*bd, n_past + ib*bd); - LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n); - LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd); + LOG_DBG("\n"); + LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i, n_past, ib*bd, ga_i + ib*bd, n_past + ib*bd); + LOG_DBG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n); + LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd); llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd); llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); @@ -624,7 +603,7 @@ int main(int argc, char ** argv) { ga_i += ga_w/ga_n; - LOG("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past + bd, n_past, ga_i); + LOG_DBG("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past + bd, n_past, ga_i); } } @@ -656,19 +635,19 @@ int main(int argc, char ** argv) { n_eval = params.n_batch; } - LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); + LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) { - LOG_TEE("%s : failed to eval\n", __func__); + if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) { + LOG_ERR("%s : failed to eval\n", __func__); return 1; } n_past += n_eval; - LOG("n_past = %d\n", n_past); + LOG_DBG("n_past = %d\n", n_past); // Display total tokens alongside total time if (params.n_print > 0 && n_past % params.n_print == 0) { - LOG_TEE("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx); + LOG_DBG("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx); } } @@ -686,14 +665,14 @@ int main(int argc, char ** argv) { need_to_save_session = false; llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); - LOG("saved session to %s\n", path_session.c_str()); + LOG_DBG("saved session to %s\n", path_session.c_str()); } - const llama_token id = gpt_sampler_sample(smpl, ctx, -1); + const llama_token id = common_sampler_sample(smpl, ctx, -1); - gpt_sampler_accept(smpl, id, /* apply_grammar= */ true); + common_sampler_accept(smpl, id, /* accept_grammar= */ true); - // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str()); + // LOG_DBG("last: %s\n", string_from(ctx, smpl->prev.to_vector()).c_str()); embd.push_back(id); @@ -703,16 +682,16 @@ int main(int argc, char ** argv) { // decrement remaining sampling budget --n_remain; - LOG("n_remain: %d\n", n_remain); + LOG_DBG("n_remain: %d\n", n_remain); } else { // some user input remains from prompt or interaction, forward it to processing - LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed); + LOG_DBG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed); while ((int) embd_inp.size() > n_consumed) { embd.push_back(embd_inp[n_consumed]); // push the prompt in the sampling context in order to apply repetition penalties later // for the prompt, we don't apply grammar rules - gpt_sampler_accept(smpl, embd_inp[n_consumed], /* apply_grammar= */ false); + common_sampler_accept(smpl, embd_inp[n_consumed], /* accept_grammar= */ false); ++n_consumed; if ((int) embd.size() >= params.n_batch) { @@ -724,10 +703,10 @@ int main(int argc, char ** argv) { // display text if (input_echo && display) { for (auto id : embd) { - const std::string token_str = llama_token_to_piece(ctx, id, params.special); + const std::string token_str = common_token_to_piece(ctx, id, params.special); // Console/Stream Output - fprintf(stdout, "%s", token_str.c_str()); + LOG("%s", token_str.c_str()); // Record Displayed Tokens To Log // Note: Generated tokens are created one by one hence this check @@ -739,8 +718,6 @@ int main(int argc, char ** argv) { output_tokens.push_back(id); output_ss << token_str; } - - fflush(stdout); } } @@ -755,7 +732,7 @@ int main(int argc, char ** argv) { // check for reverse prompt in the last n_prev tokens if (!params.antiprompt.empty()) { const int n_prev = 32; - const std::string last_output = gpt_sampler_prev_str(smpl, ctx, n_prev); + const std::string last_output = common_sampler_prev_str(smpl, ctx, n_prev); is_antiprompt = false; // Check if each of the reverse prompts appears at the end of the output. @@ -777,64 +754,61 @@ int main(int argc, char ** argv) { } // check for reverse prompt using special tokens - llama_token last_token = gpt_sampler_last(smpl); - for (std::vector ids : antiprompt_ids) { - if (ids.size() == 1 && last_token == ids[0]) { - if (params.interactive) { - is_interacting = true; - } - is_antiprompt = true; - break; + llama_token last_token = common_sampler_last(smpl); + if (std::find(antiprompt_token.begin(), antiprompt_token.end(), last_token) != antiprompt_token.end()) { + if (params.interactive) { + is_interacting = true; } + is_antiprompt = true; } if (is_antiprompt) { - LOG("found antiprompt: %s\n", last_output.c_str()); + LOG_DBG("found antiprompt: %s\n", last_output.c_str()); } } // deal with end of generation tokens in interactive mode - if (llama_token_is_eog(model, gpt_sampler_last(smpl))) { - LOG("found an EOG token\n"); + if (llama_vocab_is_eog(vocab, common_sampler_last(smpl))) { + LOG_DBG("found an EOG token\n"); if (params.interactive) { if (!params.antiprompt.empty()) { // tokenize and inject first reverse prompt - const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false, true); + const auto first_antiprompt = common_tokenize(ctx, params.antiprompt.front(), false, true); embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end()); is_antiprompt = true; } if (params.enable_chat_template) { - chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str()); + chat_add_and_format("assistant", assistant_ss.str()); } is_interacting = true; - printf("\n"); + LOG("\n"); } } // if current token is not EOG, we add it to current assistant message - if (params.conversation) { - const auto id = gpt_sampler_last(smpl); - assistant_ss << llama_token_to_piece(ctx, id, false); + if (params.conversation_mode) { + const auto id = common_sampler_last(smpl); + assistant_ss << common_token_to_piece(ctx, id, false); } if (n_past > 0 && is_interacting) { - LOG("waiting for user input\n"); + LOG_DBG("waiting for user input\n"); - if (params.conversation) { - printf("\n> "); + if (params.conversation_mode) { + LOG("\n> "); } if (params.input_prefix_bos) { - LOG("adding input prefix BOS token\n"); - embd_inp.push_back(llama_token_bos(model)); + LOG_DBG("adding input prefix BOS token\n"); + embd_inp.push_back(llama_vocab_bos(vocab)); } std::string buffer; - if (!params.input_prefix.empty() && !params.conversation) { - LOG("appending input prefix: '%s'\n", params.input_prefix.c_str()); - printf("%s", params.input_prefix.c_str()); + if (!params.input_prefix.empty() && !params.conversation_mode) { + LOG_DBG("appending input prefix: '%s'\n", params.input_prefix.c_str()); + LOG("%s", params.input_prefix.c_str()); } // color user input only @@ -856,12 +830,12 @@ int main(int argc, char ** argv) { // Entering a empty line lets the user pass control back if (buffer.length() > 1) { // append input suffix if any - if (!params.input_suffix.empty() && !params.conversation) { - LOG("appending input suffix: '%s'\n", params.input_suffix.c_str()); - printf("%s", params.input_suffix.c_str()); + if (!params.input_suffix.empty() && !params.conversation_mode) { + LOG_DBG("appending input suffix: '%s'\n", params.input_suffix.c_str()); + LOG("%s", params.input_suffix.c_str()); } - LOG("buffer: '%s'\n", buffer.c_str()); + LOG_DBG("buffer: '%s'\n", buffer.c_str()); const size_t original_size = embd_inp.size(); @@ -869,21 +843,21 @@ int main(int argc, char ** argv) { string_process_escapes(buffer); } - bool format_chat = params.conversation && params.enable_chat_template; + bool format_chat = params.conversation_mode && params.enable_chat_template; std::string user_inp = format_chat - ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer)) + ? chat_add_and_format("user", std::move(buffer)) : std::move(buffer); // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix) - const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true); - const auto line_inp = ::llama_tokenize(ctx, user_inp, false, format_chat); - const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true); + const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true); + const auto line_inp = common_tokenize(ctx, user_inp, false, format_chat); + const auto line_sfx = common_tokenize(ctx, params.input_suffix, false, true); - LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str()); + LOG_DBG("input tokens: %s\n", string_from(ctx, line_inp).c_str()); // if user stop generation mid-way, we must add EOT to finish model's last response if (need_insert_eot && format_chat) { - llama_token eot = llama_token_eot(model); - embd_inp.push_back(eot == -1 ? llama_token_eos(model) : eot); + llama_token eot = llama_vocab_eot(vocab); + embd_inp.push_back(eot == LLAMA_TOKEN_NULL ? llama_vocab_eos(vocab) : eot); need_insert_eot = false; } @@ -894,16 +868,16 @@ int main(int argc, char ** argv) { for (size_t i = original_size; i < embd_inp.size(); ++i) { const llama_token token = embd_inp[i]; output_tokens.push_back(token); - output_ss << llama_token_to_piece(ctx, token); + output_ss << common_token_to_piece(ctx, token); } // reset assistant message assistant_ss.str(""); n_remain -= line_inp.size(); - LOG("n_remain: %d\n", n_remain); + LOG_DBG("n_remain: %d\n", n_remain); } else { - LOG("empty line, passing control back\n"); + LOG_DBG("empty line, passing control back\n"); } input_echo = false; // do not echo this again @@ -911,15 +885,15 @@ int main(int argc, char ** argv) { if (n_past > 0) { if (is_interacting) { - gpt_sampler_reset(smpl); + common_sampler_reset(smpl); } is_interacting = false; } } // end of generation - if (!embd.empty() && llama_token_is_eog(model, embd.back()) && !(params.interactive)) { - LOG_TEE(" [end of text]\n"); + if (!embd.empty() && llama_vocab_is_eog(vocab, embd.back()) && !(params.interactive)) { + LOG(" [end of text]\n"); break; } @@ -932,27 +906,19 @@ int main(int argc, char ** argv) { } if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) { - LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str()); + LOG("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str()); llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); } - LOG_TEE("\n"); - gpt_perf_print(ctx, smpl); - write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); + LOG("\n\n"); + common_perf_print(ctx, smpl); - gpt_sampler_free(smpl); - - llama_free(ctx); - llama_free_model(model); + common_sampler_free(smpl); llama_backend_free(); - ggml_threadpool_free(threadpool); - ggml_threadpool_free(threadpool_batch); - -#ifndef LOG_DISABLE_LOGS - LOG_TEE("Log end\n"); -#endif // LOG_DISABLE_LOGS + ggml_threadpool_free_fn(threadpool); + ggml_threadpool_free_fn(threadpool_batch); return 0; } diff --git a/examples/parallel/CMakeLists.txt b/examples/parallel/CMakeLists.txt index c13557bac..847e916de 100644 --- a/examples/parallel/CMakeLists.txt +++ b/examples/parallel/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-parallel) add_executable(${TARGET} parallel.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 7f512d8ad..7ef43d5e1 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -1,7 +1,10 @@ // A basic application simulating a server with multiple clients. // The clients submit requests to the server and they are processed in parallel. +#include "arg.h" #include "common.h" +#include "sampling.h" +#include "log.h" #include "llama.h" #include @@ -51,7 +54,7 @@ static std::vector k_prompts = { struct client { ~client() { if (smpl) { - gpt_sampler_free(smpl); + common_sampler_free(smpl); } } @@ -72,7 +75,7 @@ struct client { std::string prompt; std::string response; - struct gpt_sampler * smpl = nullptr; + struct common_sampler * smpl = nullptr; }; static void print_date_time() { @@ -81,7 +84,9 @@ static void print_date_time() { char buffer[80]; strftime(buffer, sizeof(buffer), "%Y-%m-%d %H:%M:%S", local_time); - printf("\n\033[35mrun parameters as at %s\033[0m\n", buffer); + LOG_INF("\n"); + LOG_INF("\033[35mrun parameters as of %s\033[0m\n", buffer); + LOG_INF("\n"); } // Define a split string function to ... @@ -98,13 +103,14 @@ static std::vector split_string(const std::string& input, char deli int main(int argc, char ** argv) { srand(1234); - gpt_params params; + common_params params; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) { return 1; } + common_init(); + // number of simultaneous "clients" to simulate const int32_t n_clients = params.n_parallel; @@ -119,41 +125,36 @@ int main(int argc, char ** argv) { const bool dump_kv_cache = params.dump_kv_cache; -#ifndef LOG_DISABLE_LOGS - log_set_target(log_filename_generator("parallel", "log")); - LOG_TEE("Log start\n"); - log_dump_cmdline(argc, argv); -#endif // LOG_DISABLE_LOGS - // init llama.cpp llama_backend_init(); llama_numa_init(params.numa); // load the target model - llama_init_result llama_init = llama_init_from_gpt_params(params); + common_init_result llama_init = common_init_from_params(params); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); + + const llama_vocab * vocab = llama_model_get_vocab(model); // load the prompts from an external file if there are any if (params.prompt.empty()) { - printf("\n\033[32mNo new questions so proceed with build-in defaults.\033[0m\n"); + LOG_INF("\033[32mNo new questions so proceed with build-in defaults.\033[0m\n"); } else { // Output each line of the input params.prompts vector and copy to k_prompts int index = 0; - printf("\n\033[32mNow printing the external prompt file %s\033[0m\n\n", params.prompt_file.c_str()); + LOG_INF("\033[32mNow printing the external prompt file %s\033[0m\n\n", params.prompt_file.c_str()); std::vector prompts = split_string(params.prompt, '\n'); for (const auto& prompt : prompts) { k_prompts.resize(index + 1); k_prompts[index] = prompt; index++; - printf("%3d prompt: %s\n", index, prompt.c_str()); + LOG_INF("%3d prompt: %s\n", index, prompt.c_str()); } } - fprintf(stderr, "\n\n"); - fflush(stderr); + LOG_INF("\n\n"); const int n_ctx = llama_n_ctx(ctx); @@ -161,11 +162,11 @@ int main(int argc, char ** argv) { for (size_t i = 0; i < clients.size(); ++i) { auto & client = clients[i]; client.id = i; - client.smpl = gpt_sampler_init(model, params.sparams); + client.smpl = common_sampler_init(model, params.sampling); } std::vector tokens_system; - tokens_system = ::llama_tokenize(ctx, k_system, true); + tokens_system = common_tokenize(ctx, k_system, true); const int32_t n_tokens_system = tokens_system.size(); llama_seq_id g_seq_id = 0; @@ -182,19 +183,19 @@ int main(int argc, char ** argv) { const auto t_main_start = ggml_time_us(); - LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__); - LOG_TEE("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system); - LOG_TEE("\n"); + LOG_INF("%s: Simulating parallel requests from clients:\n", __func__); + LOG_INF("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system); + LOG_INF("\n"); { - LOG_TEE("%s: Evaluating the system prompt ...\n", __func__); + LOG_INF("%s: Evaluating the system prompt ...\n", __func__); for (int32_t i = 0; i < n_tokens_system; ++i) { - llama_batch_add(batch, tokens_system[i], i, { 0 }, false); + common_batch_add(batch, tokens_system[i], i, { 0 }, false); } if (llama_decode(ctx, batch) != 0) { - LOG_TEE("%s: llama_decode() failed\n", __func__); + LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -203,18 +204,18 @@ int main(int argc, char ** argv) { llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); } - LOG_TEE("\n"); + LOG_INF("\n"); } - LOG_TEE("Processing requests ...\n\n"); + LOG_INF("Processing requests ...\n\n"); while (true) { if (dump_kv_cache) { llama_kv_cache_view_update(ctx, &kvc_view); - llama_kv_cache_dump_view_seqs(kvc_view, 40); + common_kv_cache_dump_view_seqs(kvc_view, 40); } - llama_batch_clear(batch); + common_batch_clear(batch); // decode any currently ongoing sequences for (auto & client : clients) { @@ -224,7 +225,7 @@ int main(int argc, char ** argv) { client.i_batch = batch.n_tokens; - llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true); + common_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true); client.n_decoded += 1; } @@ -237,7 +238,7 @@ int main(int argc, char ** argv) { llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); } - LOG_TEE("%s: clearing the KV cache\n", __func__); + LOG_INF("%s: clearing the KV cache\n", __func__); } // insert new sequences for decoding @@ -253,14 +254,14 @@ int main(int argc, char ** argv) { client.prompt = client.input + "\nAssistant:"; client.response = ""; - gpt_sampler_reset(client.smpl); + common_sampler_reset(client.smpl); // do not prepend BOS because we have a system prompt! std::vector tokens_prompt; - tokens_prompt = ::llama_tokenize(ctx, client.prompt, false); + tokens_prompt = common_tokenize(ctx, client.prompt, false); for (size_t i = 0; i < tokens_prompt.size(); ++i) { - llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false); + common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false); } // extract the logits only for the last token @@ -272,7 +273,7 @@ int main(int argc, char ** argv) { client.n_decoded = 0; client.i_batch = batch.n_tokens - 1; - LOG_TEE("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id); + LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id); g_seq_id += 1; @@ -309,18 +310,17 @@ int main(int argc, char ** argv) { batch.n_seq_id + i, batch.seq_id + i, batch.logits + i, - 0, 0, 0, // unused }; const int ret = llama_decode(ctx, batch_view); if (ret != 0) { if (n_batch == 1 || ret < 0) { // if you get here, it means the KV cache is full - try increasing it via the context size - LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret); + LOG_ERR("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret); return 1; } - LOG("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2); + LOG_ERR("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2); n_cache_miss += 1; @@ -331,7 +331,7 @@ int main(int argc, char ** argv) { continue; } - LOG("%s : decoded batch of %d tokens\n", __func__, n_tokens); + LOG_DBG("%s : decoded batch of %d tokens\n", __func__, n_tokens); for (auto & client : clients) { if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) { @@ -341,9 +341,9 @@ int main(int argc, char ** argv) { //printf("client %d, seq %d, token %d, pos %d, batch %d\n", // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch); - const llama_token id = gpt_sampler_sample(client.smpl, ctx, client.i_batch - i); + const llama_token id = common_sampler_sample(client.smpl, ctx, client.i_batch - i); - gpt_sampler_accept(client.smpl, id, true); + common_sampler_accept(client.smpl, id, true); if (client.n_decoded == 1) { // start measuring generation time after the first token to make sure all concurrent clients @@ -351,7 +351,7 @@ int main(int argc, char ** argv) { client.t_start_gen = ggml_time_us(); } - const std::string token_str = llama_token_to_piece(ctx, id); + const std::string token_str = common_token_to_piece(ctx, id); client.response += token_str; client.sampled = id; @@ -360,7 +360,7 @@ int main(int argc, char ** argv) { // client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str()); if (client.n_decoded > 2 && - (llama_token_is_eog(model, id) || + (llama_vocab_is_eog(vocab, id) || (params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) || client.response.find("User:") != std::string::npos || client.response.find('\n') != std::string::npos)) { @@ -376,7 +376,7 @@ int main(int argc, char ** argv) { const auto t_main_end = ggml_time_us(); - LOG_TEE("\033[31mClient %3d, seq %3d/%3d, prompt %4d t, response %4d t, time %5.2f s, speed %5.2f t/s, cache miss %d \033[0m \nInput: %s\n\033[35mResponse: %s\033[0m\n\n", + LOG_INF("\033[31mClient %3d, seq %3d/%3d, prompt %4d t, response %4d t, time %5.2f s, speed %5.2f t/s, cache miss %d \033[0m \n\nInput: %s\n\033[35mResponse: %s\033[0m\n\n", client.id, client.seq_id, n_seq, client.n_prompt, client.n_decoded, (t_main_end - client.t_start_prompt) / 1e6, (double) (client.n_prompt + client.n_decoded) / (t_main_end - client.t_start_prompt) * 1e6, @@ -399,31 +399,28 @@ int main(int argc, char ** argv) { print_date_time(); - LOG_TEE("\n%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system); + LOG_INF("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system); if (params.prompt_file.empty()) { params.prompt_file = "used built-in defaults"; } - LOG_TEE("External prompt file: \033[32m%s\033[0m\n", params.prompt_file.c_str()); - LOG_TEE("Model and path used: \033[32m%s\033[0m\n\n", params.model.c_str()); + LOG_INF("External prompt file: \033[32m%s\033[0m\n", params.prompt_file.c_str()); + LOG_INF("Model and path used: \033[32m%s\033[0m\n\n", params.model.c_str()); - LOG_TEE("Total prompt tokens: %6d, speed: %5.2f t/s\n", n_total_prompt, (double) (n_total_prompt ) / (t_main_end - t_main_start) * 1e6); - LOG_TEE("Total gen tokens: %6d, speed: %5.2f t/s\n", n_total_gen, (double) (n_total_gen ) / (t_main_end - t_main_start) * 1e6); - LOG_TEE("Total speed (AVG): %6s speed: %5.2f t/s\n", "", (double) (n_total_prompt + n_total_gen) / (t_main_end - t_main_start) * 1e6); - LOG_TEE("Cache misses: %6d\n", n_cache_miss); + LOG_INF("Total prompt tokens: %6d, speed: %5.2f t/s\n", n_total_prompt, (double) (n_total_prompt ) / (t_main_end - t_main_start) * 1e6); + LOG_INF("Total gen tokens: %6d, speed: %5.2f t/s\n", n_total_gen, (double) (n_total_gen ) / (t_main_end - t_main_start) * 1e6); + LOG_INF("Total speed (AVG): %6s speed: %5.2f t/s\n", "", (double) (n_total_prompt + n_total_gen) / (t_main_end - t_main_start) * 1e6); + LOG_INF("Cache misses: %6d\n", n_cache_miss); - LOG_TEE("\n"); + LOG_INF("\n"); // TODO: print sampling/grammar timings for all clients - llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); + llama_perf_context_print(ctx); llama_batch_free(batch); - llama_free(ctx); - llama_free_model(model); - llama_backend_free(); - fprintf(stderr, "\n\n"); + LOG("\n\n"); return 0; } diff --git a/examples/passkey/CMakeLists.txt b/examples/passkey/CMakeLists.txt index dc467a5d3..9bc5110c2 100644 --- a/examples/passkey/CMakeLists.txt +++ b/examples/passkey/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-passkey) add_executable(${TARGET} passkey.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 76d235c2c..5953928d4 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -1,4 +1,6 @@ +#include "arg.h" #include "common.h" +#include "log.h" #include "llama.h" #include @@ -7,23 +9,24 @@ #include static void print_usage(int, char ** argv) { - LOG_TEE("\nexample usage:\n"); - LOG_TEE("\n %s -m model.gguf --junk 250 --pos 90 --keep 32 --grp-attn-n 2 [--seed 1234]\n", argv[0]); - LOG_TEE("\n"); + LOG("\nexample usage:\n"); + LOG("\n %s -m model.gguf --junk 250 --pos 90 --keep 32 --grp-attn-n 2 [--seed 1234]\n", argv[0]); + LOG("\n"); } int main(int argc, char ** argv) { - gpt_params params; + common_params params; params.n_junk = 250; params.n_keep = 32; params.i_pos = -1; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_PASSKEY, print_usage); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PASSKEY, print_usage)) { return 1; } + common_init(); + int n_junk = params.n_junk; int n_keep = params.n_keep; int n_grp = params.grp_attn_n; @@ -58,26 +61,28 @@ int main(int argc, char ** argv) { // initialize the model - llama_model_params model_params = llama_model_params_from_gpt_params(params); + llama_model_params model_params = common_model_params_to_llama(params); - llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(params.model.c_str(), model_params); if (model == NULL) { - fprintf(stderr , "%s: error: unable to load model\n" , __func__); + LOG_ERR("%s: unable to load model\n" , __func__); return 1; } + const llama_vocab * vocab = llama_model_get_vocab(model); + // initialize the context - llama_context_params ctx_params = llama_context_params_from_gpt_params(params); + llama_context_params ctx_params = common_context_params_to_llama(params); - ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep; + ctx_params.n_ctx = llama_model_n_ctx_train(model)*n_grp + n_keep; GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp"); - llama_context * ctx = llama_new_context_with_model(model, ctx_params); + llama_context * ctx = llama_init_from_model(model, ctx_params); if (ctx == NULL) { - fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); + LOG_ERR("%s: failed to create the llama_context\n" , __func__); return 1; } @@ -89,10 +94,10 @@ int main(int argc, char ** argv) { // tokenize the prompt std::vector tokens_list; - tokens_list = ::llama_tokenize(ctx, params.prompt, true); + tokens_list = common_tokenize(ctx, params.prompt, true); // tokenize the prefix and use it as a sink - const int n_tokens_prefix = ::llama_tokenize(ctx, prompt_prefix, true).size(); + const int n_tokens_prefix = common_tokenize(ctx, prompt_prefix, true).size(); const int n_tokens_all = tokens_list.size(); @@ -107,14 +112,14 @@ int main(int argc, char ** argv) { const int n_batch = ctx_params.n_batch; const int n_batch_grp = ctx_params.n_batch/n_grp; - LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d, n_junk = %d, i_pos = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch, n_junk, i_pos); + LOG_INF("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d, n_junk = %d, i_pos = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch, n_junk, i_pos); // print the prompt token-by-token - LOG_TEE("\n"); - LOG_TEE("prefix tokens: %d\n", n_tokens_prefix); - LOG_TEE("prompt tokens: %d\n", n_tokens_all); - //LOG_TEE("prompt: %s\n", params.prompt.c_str()); + LOG_INF("\n"); + LOG_INF("prefix tokens: %d\n", n_tokens_prefix); + LOG_INF("prompt tokens: %d\n", n_tokens_all); + //LOG_INF("prompt: %s\n", params.prompt.c_str()); llama_batch batch = llama_batch_init(params.n_batch, 0, 1); @@ -134,10 +139,10 @@ int main(int argc, char ** argv) { n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; } - llama_batch_clear(batch); + common_batch_clear(batch); for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { - llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); + common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); } if (i + n_batch >= n_tokens_all) { @@ -145,11 +150,11 @@ int main(int argc, char ** argv) { } if (llama_decode(ctx, batch) != 0) { - LOG_TEE("%s: llama_decode() failed\n", __func__); + LOG_INF("%s: llama_decode() failed\n", __func__); return 1; } - LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all)); + LOG_INF("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all)); if (i + n_batch >= n_tokens_all) { break; @@ -159,7 +164,7 @@ int main(int argc, char ** argv) { for (int i = n_ctx; i < n_tokens_all; i += n_batch) { const int n_discard = n_batch; - LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard); + LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard); llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); @@ -168,10 +173,10 @@ int main(int argc, char ** argv) { n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; - llama_batch_clear(batch); + common_batch_clear(batch); for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { - llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); + common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); } if (i + n_batch >= n_tokens_all) { @@ -179,18 +184,18 @@ int main(int argc, char ** argv) { } if (llama_decode(ctx, batch) != 0) { - LOG_TEE("%s: llama_decode() failed\n", __func__); + LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } - LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all)); + LOG_INF("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all)); } { const int n_discard = n_past - n_ctx + n_predict; if (n_discard > 0) { - LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard); + LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard); llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); @@ -201,17 +206,16 @@ int main(int argc, char ** argv) { } } - LOG_TEE("\n"); - LOG_TEE("%s: passkey = %d, inserted at position %d / %d (token pos: ~%d)\n", __func__, passkey, i_pos, n_junk, (i_pos * n_tokens_all) / n_junk); - LOG_TEE("\n"); + LOG_INF("\n"); + LOG_INF("%s: passkey = %d, inserted at position %d / %d (token pos: ~%d)\n", __func__, passkey, i_pos, n_junk, (i_pos * n_tokens_all) / n_junk); + LOG_INF("\n"); // main loop int n_cur = n_tokens_all; int n_decode = 0; - LOG_TEE("%s", prompt_suffix.c_str()); - fflush(stdout); + LOG_INF("%s", prompt_suffix.c_str()); const auto t_main_start = ggml_time_us(); @@ -220,54 +224,51 @@ int main(int argc, char ** argv) { { const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); - llama_sampler_accept(smpl, new_token_id); - // is it an end of generation? - if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { - LOG_TEE("\n"); + if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) { + LOG("\n"); break; } - LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str()); - fflush(stdout); + LOG("%s", common_token_to_piece(ctx, new_token_id).c_str()); n_decode += 1; // prepare the next batch - llama_batch_clear(batch); + common_batch_clear(batch); // push this new token for next evaluation - llama_batch_add(batch, new_token_id, n_past++, { 0 }, true); + common_batch_add(batch, new_token_id, n_past++, { 0 }, true); } n_cur += 1; // evaluate the current batch with the transformer model if (llama_decode(ctx, batch)) { - fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); + LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); return 1; } } - LOG_TEE("\n"); + LOG("\n"); const auto t_main_end = ggml_time_us(); - LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", + LOG_INF("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); - LOG_TEE("\n"); - llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); + LOG("\n"); + llama_perf_context_print(ctx); - fprintf(stderr, "\n"); + LOG("\n"); llama_sampler_free(smpl); llama_batch_free(batch); llama_free(ctx); - llama_free_model(model); + llama_model_free(model); llama_backend_free(); diff --git a/examples/perplexity/CMakeLists.txt b/examples/perplexity/CMakeLists.txt index be0f2fd02..3e6864093 100644 --- a/examples/perplexity/CMakeLists.txt +++ b/examples/perplexity/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-perplexity) add_executable(${TARGET} perplexity.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 570ee8aeb..9bf6c5743 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -1,18 +1,21 @@ +#include "arg.h" #include "common.h" +#include "log.h" #include "llama.h" +#include +#include +#include #include #include #include #include +#include +#include +#include #include #include -#include -#include #include -#include -#include -#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -31,55 +34,6 @@ struct results_log_softmax { float prob; }; -static void write_logfile( - const llama_context * ctx, const gpt_params & params, const llama_model * model, - const struct results_perplexity & results -) { - if (params.logdir.empty()) { - return; - } - - if (params.hellaswag) { - fprintf(stderr, "%s: warning: logging results is not implemented for HellaSwag. No files will be written.\n", __func__); - return; - } - - const std::string timestamp = string_get_sortable_timestamp(); - - const bool success = fs_create_directory_with_parents(params.logdir); - if (!success) { - fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n", - __func__, params.logdir.c_str()); - return; - } - - const std::string logfile_path = params.logdir + timestamp + ".yml"; - FILE * logfile = fopen(logfile_path.c_str(), "w"); - - if (logfile == NULL) { - fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str()); - return; - } - - fprintf(logfile, "binary: main\n"); - char model_desc[128]; - llama_model_desc(model, model_desc, sizeof(model_desc)); - yaml_dump_non_result_info(logfile, params, ctx, timestamp, results.tokens, model_desc); - - fprintf(logfile, "\n"); - fprintf(logfile, "######################\n"); - fprintf(logfile, "# Perplexity Results #\n"); - fprintf(logfile, "######################\n"); - fprintf(logfile, "\n"); - - yaml_dump_vector_float(logfile, "logits", results.logits); - fprintf(logfile, "ppl_value: %f\n", results.ppl_value); - yaml_dump_vector_float(logfile, "probs", results.probs); - - llama_perf_dump_yaml(logfile, ctx); - fclose(logfile); -} - static std::vector softmax(const std::vector& logits) { std::vector probs(logits.size()); float max_logit = logits[0]; @@ -166,7 +120,7 @@ static void process_logits( break; } lock.unlock(); - const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]); + const results_log_softmax results = log_softmax(n_vocab, logits + size_t(i)*n_vocab, tokens[i+1]); const double v = -results.log_softmax; local_nll += v; local_nll2 += v*v; @@ -200,7 +154,7 @@ static void process_logits(std::ostream& out, int n_vocab, const float * logits, break; } lock.unlock(); - const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]); + const double v = log_softmax(n_vocab, logits + size_t(i)*n_vocab, log_probs.data() + i*nv, tokens[i+1]); local_nll += v; local_nll2 += v*v; } @@ -278,7 +232,9 @@ static std::pair log_softmax(int n_vocab, const float * logits, c kld.sum_kld += sum; kld.sum_kld2 += sum*sum; ++kld.count; - if (imax == imax_base) ++kld.n_same_top; + if (imax == imax_base) { + ++kld.n_same_top; + } const float p_base = expf(-nll_base); const float p = expf(-nll); @@ -320,7 +276,7 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens break; } lock.unlock(); - std::pair v = log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld); + std::pair v = log_softmax(n_vocab, logits + size_t(i)*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld); kld_values[i] = (float)v.first; p_diff_values[i] = v.second; } @@ -334,25 +290,28 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens } } -static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) { +static results_perplexity perplexity_v2(llama_context * ctx, const common_params & params) { // Download: https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` // Output: `perplexity: 13.5106 [114/114]` // BOS tokens will be added for each chunk before eval - const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); - GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); - fprintf(stderr, "%s: tokenizing the input ..\n", __func__); + const bool add_bos = llama_vocab_get_add_bos(vocab); + GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); - std::vector tokens = ::llama_tokenize(ctx, params.prompt, true); + LOG_INF("%s: tokenizing the input ..\n", __func__); + + std::vector tokens = common_tokenize(ctx, params.prompt, true); const int n_ctx = llama_n_ctx(ctx); if (int(tokens.size()) < 2*n_ctx) { - fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx, + LOG_ERR("%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx, n_ctx); - fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size()); + LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size()); return {std::move(tokens), 0., {}, {}}; } @@ -363,16 +322,16 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & prob_history.resize(tokens.size()); if (params.ppl_stride <= 0) { - fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride); + LOG_ERR("%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride); return {tokens, -1, logit_history, prob_history}; } const int calc_chunk = n_ctx; - fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk); + LOG_INF("%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk); if (int(tokens.size()) <= calc_chunk) { - fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__, + LOG_ERR("%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__, tokens.size(), n_ctx, params.ppl_stride); return {tokens, -1, logit_history, prob_history}; } @@ -380,20 +339,21 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride; const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_batch = params.n_batch; + const int n_vocab = llama_vocab_n_tokens(vocab); + int count = 0; double nll = 0.0; - fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); + LOG_INF("%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); for (int i = 0; i < n_chunk; ++i) { const int start = i * params.ppl_stride; const int end = start + calc_chunk; const int num_batches = (calc_chunk + n_batch - 1) / n_batch; - //fprintf(stderr, "%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches); + //LOG_DBG("%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches); std::vector logits; @@ -402,14 +362,21 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & // clear the KV cache llama_kv_cache_clear(ctx); + llama_batch batch = llama_batch_init(n_batch, 0, 1); + for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); - //fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); - // TODO: use llama_batch.logits instead of relying on logits_all == true - if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) { - //fprintf(stderr, "%s : failed to eval\n", __func__); + common_batch_clear(batch); + for (int i = 0; i < batch_size; i++) { + common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); + } + + //LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); + if (llama_decode(ctx, batch)) { + //LOG_ERR("%s : failed to eval\n", __func__); + llama_batch_free(batch); return {tokens, -1, logit_history, prob_history}; } @@ -418,37 +385,38 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & // add BOS token for the first batch of each chunk if (add_bos && j == 0) { - tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); + tokens[batch_start] = llama_vocab_bos(vocab); } - const auto batch_logits = llama_get_logits(ctx); - logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); + const auto * batch_logits = llama_get_logits(ctx); + logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab); if (j == 0) { tokens[batch_start] = token_org; } } + llama_batch_free(batch); + const auto t_end = std::chrono::high_resolution_clock::now(); if (i == 0) { const float t_total = std::chrono::duration(t_end - t_start).count(); - fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total); + LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total); int total_seconds = (int)(t_total * n_chunk); if (total_seconds >= 60*60) { - fprintf(stderr, "%d hours ", total_seconds / (60*60)); + LOG("%d hours ", total_seconds / (60*60)); total_seconds = total_seconds % (60*60); } - fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0); + LOG("%.2f minutes\n", total_seconds / 60.0); } - //fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start); + //LOG_DBG("%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start); for (int j = n_ctx - params.ppl_stride - 1; j < n_ctx - 1; ++j) { - // Calculate probability of next token, given the previous ones. const std::vector tok_logits( - logits.begin() + (j + 0) * n_vocab, - logits.begin() + (j + 1) * n_vocab); + logits.begin() + size_t(j + 0) * n_vocab, + logits.begin() + size_t(j + 1) * n_vocab); const float prob = softmax(tok_logits)[tokens[start + j + 1]]; logit_history[start + j + 1] = tok_logits[tokens[start + j + 1]]; @@ -459,18 +427,17 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & } // perplexity is e^(average negative log-likelihood) if (params.ppl_output_type == 0) { - printf("[%d]%.4lf,", i + 1, std::exp(nll / count)); + LOG("[%d]%.4lf,", i + 1, std::exp(nll / count)); } else { - printf("%8d %.4lf\n", i*params.ppl_stride, std::exp(nll / count)); + LOG("%8d %.4lf\n", i*params.ppl_stride, std::exp(nll / count)); } - fflush(stdout); } - printf("\n"); + LOG("\n"); return {tokens, std::exp(nll / count), logit_history, prob_history}; } -static results_perplexity perplexity(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) { +static results_perplexity perplexity(llama_context * ctx, const common_params & params, const int32_t n_ctx) { if (params.ppl_stride > 0) { return perplexity_v2(ctx, params); } @@ -480,33 +447,36 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par // Output: `perplexity: 13.5106 [114/114]` // BOS tokens will be added for each chunk before eval - const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); - GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const bool add_bos = llama_vocab_get_add_bos(vocab); + GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); std::ofstream logits_stream; if (!params.logits_file.empty()) { logits_stream.open(params.logits_file.c_str(), std::ios::binary); if (!logits_stream.is_open()) { - fprintf(stderr, "%s: failed to open %s for writing\n", __func__, params.logits_file.c_str()); + LOG_ERR("%s: failed to open %s for writing\n", __func__, params.logits_file.c_str()); return {}; } - fprintf(stderr, "%s: saving all logits to %s\n", __func__, params.logits_file.c_str()); + LOG_INF("%s: saving all logits to %s\n", __func__, params.logits_file.c_str()); logits_stream.write("_logits_", 8); logits_stream.write(reinterpret_cast(&n_ctx), sizeof(n_ctx)); } auto tim1 = std::chrono::high_resolution_clock::now(); - fprintf(stderr, "%s: tokenizing the input ..\n", __func__); + LOG_INF("%s: tokenizing the input ..\n", __func__); - std::vector tokens = ::llama_tokenize(ctx, params.prompt, true); + std::vector tokens = common_tokenize(ctx, params.prompt, true); auto tim2 = std::chrono::high_resolution_clock::now(); - fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast(tim2-tim1).count()); + LOG_INF("%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast(tim2-tim1).count()); if (int(tokens.size()) < 2*n_ctx) { - fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx, + LOG_ERR("%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx, n_ctx); - fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size()); + LOG_ERR("%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size()); return {std::move(tokens), 0., {}, {}}; } @@ -519,9 +489,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par const int n_chunk_max = tokens.size() / n_ctx; const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_batch = params.n_batch; + const int n_vocab = llama_vocab_n_tokens(vocab); + int count = 0; double nll = 0.0; double nll2 = 0.0; @@ -536,10 +507,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par std::vector logits; if (num_batches > 1) { - logits.reserve((size_t)n_ctx * n_vocab); + logits.reserve(size_t(n_ctx) * n_vocab); } - fprintf(stderr, "%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq); + LOG_INF("%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq); std::vector workers(std::thread::hardware_concurrency() - 1); @@ -592,7 +563,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par // add BOS token for the first batch of each chunk if (add_bos && j == 0) { - tokens[seq_start] = llama_token_bos(llama_get_model(ctx)); + tokens[seq_start] = llama_vocab_bos(vocab); } for (int k = 0; k < batch_size; ++k) { @@ -612,13 +583,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par } if (llama_decode(ctx, batch)) { - fprintf(stderr, "%s : failed to eval\n", __func__); + LOG_INF("%s : failed to eval\n", __func__); return {tokens, -1, logit_history, prob_history}; } if (num_batches > 1 && n_outputs > 0) { const auto * batch_logits = llama_get_logits(ctx); - logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab); + logits.insert(logits.end(), batch_logits, batch_logits + size_t(n_outputs) * n_vocab); } } @@ -627,13 +598,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par llama_synchronize(ctx); const auto t_end = std::chrono::high_resolution_clock::now(); const float t_total = std::chrono::duration(t_end - t_start).count(); - fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total); + LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total); int total_seconds = (int)(t_total*n_chunk/n_seq); if (total_seconds >= 60*60) { - fprintf(stderr, "%d hours ", total_seconds / (60*60)); + LOG("%d hours ", total_seconds / (60*60)); total_seconds = total_seconds % (60*60); } - fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0); + LOG("%.2f minutes\n", total_seconds / 60.0); } for (int seq = 0; seq < n_seq_batch; seq++) { @@ -655,19 +626,20 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par // perplexity is e^(average negative log-likelihood) if (params.ppl_output_type == 0) { - printf("[%d]%.4lf,", i + seq + 1, std::exp(nll / count)); + LOG("[%d]%.4lf,", i + seq + 1, std::exp(nll / count)); } else { double av = nll/count; double av2 = nll2/count - av*av; - if (av2 > 0) av2 = sqrt(av2/(count-1)); - printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2); + if (av2 > 0) { + av2 = sqrt(av2/(count-1)); + } + LOG("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2); } } - fflush(stdout); logits.clear(); } - printf("\n"); + LOG("\n"); nll2 /= count; nll /= count; @@ -675,9 +647,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par nll2 -= nll * nll; if (nll2 > 0) { nll2 = sqrt(nll2/(count-1)); - printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl); + LOG_INF("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl); } else { - printf("Unexpected negative standard deviation of log(prob)\n"); + LOG_ERR("Unexpected negative standard deviation of log(prob)\n"); } llama_batch_free(batch); @@ -685,10 +657,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par return {tokens, ppl, logit_history, prob_history}; } -static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector & batch_logits, int32_t n_batch, int32_t n_vocab) { +static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector & batch_logits, int n_batch, int n_vocab) { int prev_outputs = 0; - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + for (int i = 0; i < (int) batch.n_tokens; i += n_batch) { + const int n_tokens = std::min(n_batch, batch.n_tokens - i); llama_batch batch_view = { n_tokens, @@ -698,12 +670,11 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector< batch.n_seq_id + i, batch.seq_id + i, batch.logits + i, - 0, 0, 0, // unused }; const int ret = llama_decode(ctx, batch_view); if (ret != 0) { - LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); + LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); return false; } @@ -712,7 +683,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector< n_outputs += batch_view.logits[i] != 0; } - memcpy(batch_logits.data() + prev_outputs*n_vocab, llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float)); + memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float)); prev_outputs += n_outputs; } @@ -727,7 +698,9 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto if (eval_results.size() != eval_pairs.size()) { eval_results.resize(eval_pairs.size()); } - if (eval_pairs.empty()) return; + if (eval_pairs.empty()) { + return; + } size_t max_threads = std::min((eval_pairs.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK, workers.size()); @@ -735,11 +708,13 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () { float local_logprobs[K_TOKEN_CHUNK]; while (true) { - size_t first = counter.fetch_add(K_TOKEN_CHUNK, std::memory_order_relaxed); - if (first >= eval_results.size()) break; - size_t last = std::min(first + K_TOKEN_CHUNK, eval_results.size()); + const size_t first = counter.fetch_add(K_TOKEN_CHUNK, std::memory_order_relaxed); + if (first >= eval_results.size()) { + break; + } + const size_t last = std::min(first + K_TOKEN_CHUNK, eval_results.size()); for (size_t i = first; i < last; ++i) { - auto logits = batch_logits + eval_pairs[i].first * n_vocab; + const auto * logits = batch_logits + eval_pairs[i].first * n_vocab; float max_logit = logits[0]; for (int j = 1; j < n_vocab; ++j) { max_logit = std::max(max_logit, logits[j]); @@ -762,7 +737,10 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto } } -static void hellaswag_score(llama_context * ctx, const gpt_params & params) { +static void hellaswag_score(llama_context * ctx, const common_params & params) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + // Calculates hellaswag score (acc_norm) from prompt // // Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl @@ -789,15 +767,15 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { } if (prompt_lines.size() % 6 != 0) { - fprintf(stderr, "%s : number of lines in prompt not a multiple of 6.\n", __func__); + LOG_ERR("%s : number of lines in prompt not a multiple of 6.\n", __func__); return; } size_t hs_task_count = prompt_lines.size()/6; - fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count); + LOG_INF("%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count); - const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM; - fprintf(stderr, "================================= is_spm = %d\n", is_spm); + const bool is_spm = llama_vocab_type(vocab) == LLAMA_VOCAB_TYPE_SPM; + LOG_INF("================================= is_spm = %d\n", is_spm); // The tasks should be randomized so the score stabilizes quickly. bool randomize_tasks = true; @@ -824,7 +802,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { std::vector seq_tokens[4]; }; - fprintf(stderr, "%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first") ); + LOG_INF("%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first") ); // Select and read data from prompt lines std::vector hs_data(hs_task_count); @@ -843,7 +821,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { hs_cur.gold_ending_idx = std::stoi( prompt_lines[idx*6+1] ); for (size_t j = 0; j < 4; j++) { hs_cur.ending[j] = prompt_lines[idx*6+2+j]; - hs_cur.seq_tokens[j] = ::llama_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], true); + hs_cur.seq_tokens[j] = common_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], true); } // determine the common prefix of the endings @@ -870,16 +848,17 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { } } - fprintf(stderr, "%s : calculating hellaswag score over selected tasks.\n", __func__); + LOG_INF("%s : calculating hellaswag score over selected tasks.\n", __func__); - printf("\ntask\tacc_norm\n"); + LOG("\ntask\tacc_norm\n"); double acc = 0.0f; - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_ctx = llama_n_ctx(ctx); const int n_batch = params.n_batch; + const int n_vocab = llama_vocab_n_tokens(vocab); + const int max_tasks_per_batch = 32; const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); @@ -887,7 +866,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { std::vector tok_logits(n_vocab); // TODO: this could be made smaller; it's currently the worst-case size - std::vector batch_logits(n_vocab*n_ctx); + std::vector batch_logits(size_t(n_ctx)*n_vocab); std::vector> eval_pairs; std::vector eval_results; @@ -899,7 +878,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { size_t i1 = i0; size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch - llama_batch_clear(batch); + common_batch_clear(batch); // batch as much tasks as possible into the available context // each task has 4 unique sequence ids - one for each ending @@ -915,7 +894,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { } for (size_t i = 0; i < hs_cur.common_prefix; ++i) { - llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); + common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); } batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix n_logits += 1; @@ -925,7 +904,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { // TODO: don't evaluate the last token of each sequence for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) { const bool needs_logits = i < seq_tokens_size - 1; - llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits); + common_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits); n_logits += needs_logits; } } @@ -940,7 +919,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { } if (i0 == i1) { - fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0); + LOG_ERR("%s : task %zu does not fit in the context window\n", __func__, i0); return; } @@ -948,7 +927,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { - fprintf(stderr, "%s: llama_decode() failed\n", __func__); + LOG_ERR("%s: llama_decode() failed\n", __func__); return; } @@ -974,7 +953,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { auto & hs_cur = hs_data[i]; // get the logits of the last token of the common prefix - std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*hs_cur.i_logits, n_vocab*sizeof(float)); + std::memcpy(tok_logits.data(), batch_logits.data() + hs_cur.i_logits*n_vocab, n_vocab*sizeof(float)); const auto first_probs = softmax(tok_logits); @@ -998,7 +977,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { } } - //printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_cur.gold_ending_idx); + //LOG("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_cur.gold_ending_idx); // If the gold ending got the maximum logprobe add one accuracy point if (ending_logprob_max_idx == hs_cur.gold_ending_idx) { @@ -1006,8 +985,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { } // Print the accumulated accuracy mean x 100 - printf("%zu\t%.8lf\n", i + 1, acc/double(i + 1)*100.0); - fflush(stdout); + LOG("%zu\t%.8lf\n", i + 1, acc/double(i + 1)*100.0); } i0 = i1 - 1; @@ -1015,7 +993,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { llama_batch_free(batch); - printf("\n"); + LOG("\n"); } struct winogrande_entry { @@ -1059,7 +1037,7 @@ static std::vector load_winogrande_from_csv(const std::string } } if (ipos != 4) { - printf("%s: failed to find comma separators in <%s>\n", __func__, line.c_str()); + LOG_ERR("%s: failed to find comma separators in <%s>\n", __func__, line.c_str()); continue; } auto sentence = line[comma_pos[0]+1] == '"' ? line.substr(comma_pos[0]+2, comma_pos[1] - comma_pos[0] - 3) @@ -1073,13 +1051,13 @@ static std::vector load_winogrande_from_csv(const std::string if (sentence[where] == '_') break; } if (where == int(sentence.size())) { - printf("%s: no _ in <%s>\n", __func__, sentence.c_str()); + LOG_ERR("%s: no _ in <%s>\n", __func__, sentence.c_str()); continue; } std::istringstream stream(answer.c_str()); int i_answer; stream >> i_answer; if (stream.fail() || i_answer < 1 || i_answer > 2) { - printf("%s: failed to parse answer <%s>\n", __func__, answer.c_str()); + LOG_ERR("%s: failed to parse answer <%s>\n", __func__, answer.c_str()); continue; } result.emplace_back(); @@ -1102,20 +1080,22 @@ static std::vector load_winogrande_from_csv(const std::string * 0,Sarah was a much better surgeon than Maria so _ always got the easier cases.,Sarah,Maria,2 * */ -static void winogrande_score(llama_context * ctx, const gpt_params & params) { +static void winogrande_score(llama_context * ctx, const common_params & params) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); constexpr int k_min_trailing_ctx = 3; auto data = load_winogrande_from_csv(params.prompt); if (data.empty()) { - fprintf(stderr, "%s: no tasks\n", __func__); + LOG_ERR("%s: no tasks\n", __func__); return; } - fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, data.size()); + LOG_INF("%s : loaded %zu tasks from prompt.\n", __func__, data.size()); if (params.winogrande_tasks > 0 && params.winogrande_tasks < data.size()) { - fprintf(stderr, "%s : selecting %zu random tasks\n", __func__, params.winogrande_tasks); + LOG_INF("%s : selecting %zu random tasks\n", __func__, params.winogrande_tasks); std::mt19937 rng(1); std::vector aux(data.size()); for (int i = 0; i < int(data.size()); ++i) { @@ -1133,11 +1113,11 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { data = std::move(selected); } - fprintf(stderr, "%s : tokenizing selected tasks\n", __func__); + LOG_INF("%s : tokenizing selected tasks\n", __func__); for (auto & task : data) { - task.seq_tokens[0] = ::llama_tokenize(ctx, task.first + task.choices[0] + task.second, true); - task.seq_tokens[1] = ::llama_tokenize(ctx, task.first + task.choices[1] + task.second, true); + task.seq_tokens[0] = common_tokenize(ctx, task.first + task.choices[0] + task.second, true); + task.seq_tokens[1] = common_tokenize(ctx, task.first + task.choices[1] + task.second, true); task.common_prefix = 0; for (size_t k = 0; k < task.seq_tokens[0].size(); k++) { @@ -1152,16 +1132,17 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { task.seq_tokens[0].size() - task.common_prefix + task.seq_tokens[1].size() - task.common_prefix; - task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], true).size(); - task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], true).size(); + task.n_base1 = common_tokenize(ctx, task.first + task.choices[0], true).size(); + task.n_base2 = common_tokenize(ctx, task.first + task.choices[1], true).size(); } - fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__); + LOG_INF("%s : calculating winogrande score over selected tasks.\n", __func__); - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_ctx = llama_n_ctx(ctx); const int n_batch = params.n_batch; + const int n_vocab = llama_vocab_n_tokens(vocab); + const int max_tasks_per_batch = 128; const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); @@ -1169,7 +1150,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { std::vector tok_logits(n_vocab); // TODO: this could be made smaller; it's currently the worst-case size - std::vector batch_logits(n_vocab*n_ctx); + std::vector batch_logits(size_t(n_ctx)*n_vocab); std::vector> eval_pairs; std::vector eval_results; @@ -1184,7 +1165,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { size_t i1 = i0; size_t i_logits = 0; - llama_batch_clear(batch); + common_batch_clear(batch); while (n_cur + (int) data[i1].required_tokens <= n_ctx) { int n_logits = 0; @@ -1194,7 +1175,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { } for (size_t i = 0; i < data[i1].common_prefix; ++i) { - llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); + common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); } batch.logits[batch.n_tokens - 1] = true; n_logits += 1; @@ -1202,7 +1183,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { for (int s = 0; s < 2; ++s) { // TODO: end before the last token, no need to predict past the end of the sequences for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) { - llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true); + common_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true); n_logits += 1; } } @@ -1217,7 +1198,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { } if (i0 == i1) { - fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0); + LOG_ERR("%s : task %zu does not fit in the context window\n", __func__, i0); return; } @@ -1225,7 +1206,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { - fprintf(stderr, "%s: llama_decode() failed\n", __func__); + LOG_ERR("%s: llama_decode() failed\n", __func__); return; } @@ -1285,20 +1266,20 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { ++n_done; // print the accumulated accuracy mean x 100 - printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer); - fflush(stdout); + LOG("%zu\t%.4lf\t%10.6f %10.6f %d %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer); } i0 = i1 - 1; } - printf("\n"); + LOG("\n"); if (n_done < 100) return; const float p = 1.f*n_correct/n_done; const float sigma = 100.f*sqrt(p*(1-p)/(n_done-1)); - printf("Final Winogrande score(%d tasks): %.4lf +/- %.4lf\n", n_done, 100*p, sigma); + + LOG_INF("Final Winogrande score(%d tasks): %.4lf +/- %.4lf\n", n_done, 100*p, sigma); } static bool deserialize_string(std::istream & in, std::string & str) { @@ -1347,7 +1328,7 @@ struct multiple_choice_task { static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choice_task& task, bool log_error) { if (task.question.empty() || task.mc1.answers.empty()) { if (log_error) { - printf("%s: found bad task with empty question and/or answers\n", __func__); + LOG_ERR("%s: found bad task with empty question and/or answers\n", __func__); } return false; } @@ -1355,11 +1336,11 @@ static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choic for (auto& answer : task.mc1.answers) { if (answer.empty()) { if (log_error) { - printf("%s: found empty answer\n", __func__); + LOG_ERR("%s: found empty answer\n", __func__); } return false; } - task.seq_tokens.emplace_back(::llama_tokenize(ctx, task.question + " " + answer, true)); + task.seq_tokens.emplace_back(::common_tokenize(ctx, task.question + " " + answer, true)); } auto min_len = task.seq_tokens.front().size(); for (auto& seq : task.seq_tokens) { @@ -1403,20 +1384,22 @@ static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choic // git@hf.co:datasets/Stevross/mmlu // https://huggingface.co/datasets/truthful_qa // -static void multiple_choice_score(llama_context * ctx, const gpt_params & params) { +static void multiple_choice_score(llama_context * ctx, const common_params & params) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); std::istringstream strstream(params.prompt); uint32_t n_task; strstream.read((char *)&n_task, sizeof(n_task)); if (strstream.fail() || n_task == 0) { - printf("%s: no tasks\n", __func__); + LOG_ERR("%s: no tasks\n", __func__); return; } - printf("%s: there are %u tasks in prompt\n", __func__, n_task); + LOG_INF("%s: there are %u tasks in prompt\n", __func__, n_task); std::vector task_pos(n_task); strstream.read((char *)task_pos.data(), task_pos.size()*sizeof(uint32_t)); if (strstream.fail()) { - printf("%s: failed to read task positions from prompt\n", __func__); + LOG_ERR("%s: failed to read task positions from prompt\n", __func__); return; } @@ -1424,21 +1407,21 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params if (params.multiple_choice_tasks == 0 || params.multiple_choice_tasks >= (size_t)n_task) { // Use all tasks tasks.resize(n_task); - printf("%s: reading tasks", __func__); + LOG_INF("%s: reading tasks", __func__); int n_dot = std::max((int) n_task/100, 1); int i = 0; for (auto& task : tasks) { ++i; if (!task.deserialize(strstream)) { - printf("%s: failed to read task %d of %u\n", __func__, i, n_task); + LOG_ERR("%s: failed to read task %d of %u\n", __func__, i, n_task); return; } - if (i%n_dot == 0) printf("."); + if (i%n_dot == 0) LOG("."); } - printf("done\n"); + LOG("done\n"); } else { - printf("%s: selecting %zu random tasks from %u tasks available\n", __func__, params.multiple_choice_tasks, n_task); + LOG_INF("%s: selecting %zu random tasks from %u tasks available\n", __func__, params.multiple_choice_tasks, n_task); std::mt19937 rng(1); std::vector aux(n_task); for (uint32_t i = 0; i < n_task; ++i) aux[i] = i; @@ -1451,18 +1434,16 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params aux.pop_back(); strstream.seekg(task_pos[idx], std::ios::beg); if (!task.deserialize(strstream)) { - printf("%s: failed to read task %d at position %u\n", __func__, idx, task_pos[idx]); + LOG_ERR("%s: failed to read task %d at position %u\n", __func__, idx, task_pos[idx]); return; } } n_task = params.multiple_choice_tasks; } - printf("%s: preparing task data", __func__); - fflush(stdout); + LOG_INF("%s: preparing task data", __func__); if (n_task > 500) { - printf("..."); - fflush(stdout); + LOG("..."); std::atomic counter(0); std::atomic n_bad(0); auto prepare = [&counter, &n_bad, &tasks, ctx] () { @@ -1486,11 +1467,10 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params for (auto& w : workers) w = std::thread(prepare); prepare(); for (auto& w : workers) w.join(); - printf("done\n"); - fflush(stdout); + LOG("done\n"); int nbad = n_bad; if (nbad > 0) { - printf("%s: found %d malformed tasks\n", __func__, nbad); + LOG_ERR("%s: found %d malformed tasks\n", __func__, nbad); return; } } else { @@ -1502,28 +1482,28 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params return; } if (i_task%n_dot == 0) { - printf("."); - fflush(stdout); + LOG("."); } } - printf("done\n"); + LOG("done\n"); } - printf("%s : calculating TruthfulQA score over %zu tasks.\n", __func__, tasks.size()); + LOG_INF("%s : calculating TruthfulQA score over %zu tasks.\n", __func__, tasks.size()); - printf("\ntask\tacc_norm\n"); + LOG("\ntask\tacc_norm\n"); - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_ctx = llama_n_ctx(ctx); const int n_batch = params.n_batch; + const int n_vocab = llama_vocab_n_tokens(vocab); + const int max_tasks_per_batch = 32; const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); std::vector tok_logits(n_vocab); - std::vector batch_logits(n_vocab*n_ctx); + std::vector batch_logits(size_t(n_ctx)*n_vocab); std::vector> eval_pairs; std::vector eval_results; @@ -1540,7 +1520,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params size_t i1 = i0; size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch - llama_batch_clear(batch); + common_batch_clear(batch); // batch as much tasks as possible into the available context // each task has 4 unique sequence ids - one for each ending @@ -1563,7 +1543,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params for (size_t i = 0; i < cur_task.common_prefix; ++i) { //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false); - llama_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false); + common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false); } batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix n_logits += 1; @@ -1573,7 +1553,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params // TODO: don't evaluate the last token of each sequence for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) { const bool needs_logits = i < seq_tokens_size - 1; - llama_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits); + common_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits); n_logits += needs_logits; } } @@ -1590,7 +1570,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params } if (i0 == i1) { - fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0); + LOG_ERR("%s : task %zu does not fit in the context window\n", __func__, i0); return; } @@ -1598,7 +1578,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { - fprintf(stderr, "%s: llama_decode() failed\n", __func__); + LOG_ERR("%s: llama_decode() failed\n", __func__); return; } @@ -1622,16 +1602,16 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params // compute the logprobs for each ending of the decoded tasks for (size_t i = i0; i < i1; ++i) { auto & cur_task = tasks[i]; - //printf("==== Evaluating <%s> with correct answer ", cur_task.question.c_str()); + //LOG("==== Evaluating <%s> with correct answer ", cur_task.question.c_str()); //for (int j = 0; j < int(cur_task.mc1.labels.size()); ++j) { // if (cur_task.mc1.labels[j] == 1) { - // printf("%d", j+1); + // LOG("%d", j+1); // } //} - //printf("\n common_prefix: %zu\n", cur_task.common_prefix); + //LOG("\n common_prefix: %zu\n", cur_task.common_prefix); // get the logits of the last token of the common prefix - std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*cur_task.i_logits, n_vocab*sizeof(float)); + std::memcpy(tok_logits.data(), batch_logits.data() + cur_task.i_logits*n_vocab, n_vocab*sizeof(float)); const auto first_probs = softmax(tok_logits); @@ -1640,13 +1620,13 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params size_t count = 1; float log_prob = std::log(first_probs[cur_task.seq_tokens[s][cur_task.common_prefix]]); for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) { - //printf(" %zu %g\n", ir, eval_results[ir]); + //LOG(" %zu %g\n", ir, eval_results[ir]); ++count; log_prob += eval_results[ir++]; } cur_task.log_probs[s] = log_prob / count; - //printf(" Final: %g\n", log_prob / count); - //printf(" <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob/count); + //LOG(" Final: %g\n", log_prob / count); + //LOG(" <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob/count); } // Find the ending with maximum logprob @@ -1666,8 +1646,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params ++n_done; // Print the accumulated accuracy mean x 100 - printf("%d\t%.8lf\n", n_done, 100.*n_correct/n_done); - fflush(stdout); + LOG("%d\t%.8lf\n", n_done, 100.*n_correct/n_done); } i0 = i1 - 1; @@ -1679,29 +1658,33 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params float p = 1.f*n_correct/n_done; float sigma = sqrt(p*(1-p)/(n_done-1)); - printf("\n Final result: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma); + LOG("\n"); + LOG_INF("Final result: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma); p = 1.f*n_done/n_tot_answers; sigma = sqrt(p*(1-p)/(n_done-1)); - printf("Random chance: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma); + LOG_INF("Random chance: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma); - printf("\n"); + LOG_INF("\n"); } -static void kl_divergence(llama_context * ctx, const gpt_params & params) { +static void kl_divergence(llama_context * ctx, const common_params & params) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + if (params.logits_file.empty()) { - fprintf(stderr, "%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__); + LOG_ERR("%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__); return; } std::ifstream in(params.logits_file.c_str(), std::ios::binary); if (!in) { - fprintf(stderr, "%s: failed to open %s\n", __func__, params.logits_file.c_str()); + LOG_ERR("%s: failed to open %s\n", __func__, params.logits_file.c_str()); return; } { char check[9]; check[8] = 0; in.read(check, 8); if (in.fail() || strncmp("_logits_", check, 8) != 0) { - fprintf(stderr, "%s: %s does not look like a file containing log-probabilities\n", __func__, params.logits_file.c_str()); + LOG_ERR("%s: %s does not look like a file containing log-probabilities\n", __func__, params.logits_file.c_str()); return; } } @@ -1709,39 +1692,40 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { uint32_t n_ctx; in.read((char *)&n_ctx, sizeof(n_ctx)); if (n_ctx > llama_n_ctx(ctx)) { - fprintf(stderr, "%s: %s has been computed with %u, while the current context is %d. Increase it with -c and retry\n", + LOG_ERR("%s: %s has been computed with %u, while the current context is %d. Increase it with -c and retry\n", __func__, params.logits_file.c_str(), n_ctx, params.n_ctx); } - int n_vocab, n_chunk; + int n_vocab; + int n_chunk; in.read((char *)&n_vocab, sizeof(n_vocab)); in.read((char *)&n_chunk, sizeof(n_chunk)); if (in.fail()) { - fprintf(stderr, "%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str()); + LOG_ERR("%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str()); return; } - if (n_vocab != llama_n_vocab(llama_get_model(ctx))) { - fprintf(stderr, "%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx))); + if (n_vocab != llama_vocab_n_tokens(vocab)) { + LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_vocab_n_tokens(vocab)); } - std::vector tokens(n_ctx * n_chunk); + std::vector tokens(size_t(n_ctx) * n_chunk); if (in.read((char *)tokens.data(), tokens.size()*sizeof(tokens[0])).fail()) { - fprintf(stderr, "%s: failed reading evaluation tokens from %s\n", __func__, params.logits_file.c_str()); + LOG_ERR("%s: failed reading evaluation tokens from %s\n", __func__, params.logits_file.c_str()); return; } const int n_batch = params.n_batch; const int num_batches = (n_ctx + n_batch - 1)/n_batch; const int nv = 2*((n_vocab + 1)/2) + 4; - const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); - GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); + const bool add_bos = llama_vocab_get_add_bos(vocab); + GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); std::vector log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv); std::vector kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk); std::vector p_diff_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk); std::vector logits; if (num_batches > 1) { - logits.reserve(n_ctx * n_vocab); + logits.reserve(size_t(n_ctx) * n_vocab); } std::vector workers(std::thread::hardware_concurrency() - 1); @@ -1775,13 +1759,15 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { const auto t_start = std::chrono::high_resolution_clock::now(); if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) { - fprintf(stderr, "%s: failed reading log-probs for chunk %d\n", __func__, i); + LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i); return; } // clear the KV cache llama_kv_cache_clear(ctx); + llama_batch batch = llama_batch_init(n_batch, 0, 1); + for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); @@ -1791,12 +1777,17 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { // add BOS token for the first batch of each chunk if (add_bos && j == 0) { - tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); + tokens[batch_start] = llama_vocab_bos(vocab); } - // TODO: use llama_batch.logits instead of relying on logits_all == true - if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) { - fprintf(stderr, "%s : failed to eval\n", __func__); + common_batch_clear(batch); + for (int i = 0; i < batch_size; i++) { + common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); + } + + if (llama_decode(ctx, batch)) { + LOG_ERR("%s : failed to eval\n", __func__); + llama_batch_free(batch); return; } @@ -1805,105 +1796,105 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { if (num_batches > 1) { const auto * batch_logits = llama_get_logits(ctx); - logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); + logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab); } } + llama_batch_free(batch); + const auto t_end = std::chrono::high_resolution_clock::now(); if (i == 0) { const float t_total = std::chrono::duration(t_end - t_start).count(); - fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total); + LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total); int total_seconds = (int)(t_total * n_chunk); if (total_seconds >= 60*60) { - fprintf(stderr, "%d hours ", total_seconds / (60*60)); + LOG("%d hours ", total_seconds / (60*60)); total_seconds = total_seconds % (60*60); } - fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0); - - printf("\nchunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n"); + LOG("%.2f minutes\n", total_seconds / 60.0); } + LOG("\n"); + LOG("chunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n"); const int first = n_ctx/2; const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx); - process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, + process_logits(n_vocab, all_logits + size_t(first)*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr); p_diff_ptr += n_ctx - 1 - first; kld_ptr += n_ctx - 1 - first; - printf("%4d", i+1); + LOG("%4d", i+1); auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count); const double ppl_val = exp(log_ppl.first); const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 ) - printf(" %9.4lf ± %9.4lf", ppl_val, ppl_unc); + LOG(" %9.4lf ± %9.4lf", ppl_val, ppl_unc); auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count); const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count); const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first; const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov); - printf(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc); + LOG(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc); auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count); - printf(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second); + LOG(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second); auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count); const double p_diff_rms_val = sqrt(p_diff_mse.first); const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second; - printf(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc); + LOG(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc); double p_top_val = 1.*kld.n_same_top/kld.count; double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1)); - printf(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc); + LOG(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc); - printf("\n"); - - fflush(stdout); + LOG("\n"); logits.clear(); } - printf("\n"); + LOG("\n"); if (kld.count < 100) return; // we do not wish to do statistics on so few values std::sort(kld_values.begin(), kld_values.end()); std::sort(p_diff_values.begin(), p_diff_values.end()); - printf("====== Perplexity statistics ======\n"); + LOG("====== Perplexity statistics ======\n"); auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count); const double ppl_val = exp(log_ppl.first); const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 ) - printf("Mean PPL(Q) : %10.6lf ± %10.6lf\n", ppl_val, ppl_unc); + LOG("Mean PPL(Q) : %10.6lf ± %10.6lf\n", ppl_val, ppl_unc); auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count); const double ppl_base_val = exp(log_ppl_base.first); const double ppl_base_unc = ppl_base_val * log_ppl_base.second; // ppl_base_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl_base.second ** 2 ) - printf("Mean PPL(base) : %10.6lf ± %10.6lf\n", ppl_base_val, ppl_base_unc); + LOG("Mean PPL(base) : %10.6lf ± %10.6lf\n", ppl_base_val, ppl_base_unc); const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count); - // printf("Cov(ln(PPL(Q)), ln(PPL(base))): %10.6lf\n", log_ppl_cov); + // LOG("Cov(ln(PPL(Q)), ln(PPL(base))): %10.6lf\n", log_ppl_cov); const double log_ppl_cor = log_ppl_cov / (log_ppl.second*log_ppl_base.second); - printf("Cor(ln(PPL(Q)), ln(PPL(base))): %6.2lf%%\n", 100.0*log_ppl_cor); + LOG("Cor(ln(PPL(Q)), ln(PPL(base))): %6.2lf%%\n", 100.0*log_ppl_cor); const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first; const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov); - printf("Mean ln(PPL(Q)/PPL(base)) : %10.6lf ± %10.6lf\n", log_ppl_ratio_val, log_ppl_ratio_unc); + LOG("Mean ln(PPL(Q)/PPL(base)) : %10.6lf ± %10.6lf\n", log_ppl_ratio_val, log_ppl_ratio_unc); const double ppl_ratio_val = exp(log_ppl_ratio_val); const double ppl_ratio_unc = ppl_ratio_val * log_ppl_ratio_unc; // ppl_ratio_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl_ratio.second ** 2 ) - printf("Mean PPL(Q)/PPL(base) : %10.6lf ± %10.6lf\n", ppl_ratio_val, ppl_ratio_unc); + LOG("Mean PPL(Q)/PPL(base) : %10.6lf ± %10.6lf\n", ppl_ratio_val, ppl_ratio_unc); const double ppl_cov = ppl_val * ppl_base_val * log_ppl_cov; const double ppl_diff_val = ppl_val - ppl_base_val; const double ppl_diff_unc = sqrt(ppl_unc*ppl_unc + ppl_base_unc*ppl_base_unc - 2.0*ppl_cov); - printf("Mean PPL(Q)-PPL(base) : %10.6lf ± %10.6lf\n", ppl_diff_val, ppl_diff_unc); + LOG("Mean PPL(Q)-PPL(base) : %10.6lf ± %10.6lf\n", ppl_diff_val, ppl_diff_unc); - printf("\n"); + LOG("\n"); - printf("====== KL divergence statistics ======\n"); + LOG("====== KL divergence statistics ======\n"); auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count); - printf("Mean KLD: %10.6lf ± %10.6lf\n", kl_div.first, kl_div.second); + LOG("Mean KLD: %10.6lf ± %10.6lf\n", kl_div.first, kl_div.second); auto kld_median = kld_values.size()%2 == 0 ? 0.5f*(kld_values[kld_values.size()/2] + kld_values[kld_values.size()/2-1]) : kld_values[kld_values.size()/2]; @@ -1915,67 +1906,68 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { return (1 - p)*values[ip] + p*values[std::min(ip+1, values.size()-1)]; }; - printf("Maximum KLD: %10.6f\n", kld_values.back()); - printf("99.9%% KLD: %10.6f\n", percentile(kld_values, 0.999f)); - printf("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f)); - printf("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f)); - printf("Median KLD: %10.6f\n", kld_median); - printf("10.0%% KLD: %10.6f\n", percentile(kld_values, 0.100f)); - printf(" 5.0%% KLD: %10.6f\n", percentile(kld_values, 0.050f)); - printf(" 1.0%% KLD: %10.6f\n", percentile(kld_values, 0.010f)); - printf("Minimum KLD: %10.6f\n", kld_values.front()); + LOG("Maximum KLD: %10.6f\n", kld_values.back()); + LOG("99.9%% KLD: %10.6f\n", percentile(kld_values, 0.999f)); + LOG("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f)); + LOG("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f)); + LOG("Median KLD: %10.6f\n", kld_median); + LOG("10.0%% KLD: %10.6f\n", percentile(kld_values, 0.100f)); + LOG(" 5.0%% KLD: %10.6f\n", percentile(kld_values, 0.050f)); + LOG(" 1.0%% KLD: %10.6f\n", percentile(kld_values, 0.010f)); + LOG("Minimum KLD: %10.6f\n", kld_values.front()); - printf("\n"); + LOG("\n"); - printf("====== Token probability statistics ======\n"); + LOG("====== Token probability statistics ======\n"); auto p_diff = mean_and_uncertainty(kld.sum_p_diff, kld.sum_p_diff2, kld.count); - printf("Mean Δp: %6.3lf ± %5.3lf %%\n", 100.0*p_diff.first, 100.0*p_diff.second); + LOG("Mean Δp: %6.3lf ± %5.3lf %%\n", 100.0*p_diff.first, 100.0*p_diff.second); auto p_diff_median = p_diff_values.size()%2 == 0 ? 0.5f*(p_diff_values[p_diff_values.size()/2] + p_diff_values[p_diff_values.size()/2-1]) : p_diff_values[p_diff_values.size()/2]; - printf("Maximum Δp: %6.3lf%%\n", 100.0*p_diff_values.back()); - printf("99.9%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.999f)); - printf("99.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.990f)); - printf("95.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.950f)); - printf("90.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.900f)); - printf("75.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.750f)); - printf("Median Δp: %6.3lf%%\n", 100.0*p_diff_median); - printf("25.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.250f)); - printf("10.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.100f)); - printf(" 5.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.050f)); - printf(" 1.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.010f)); - printf(" 0.1%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.001f)); - printf("Minimum Δp: %6.3lf%%\n", 100.0*p_diff_values.front()); + LOG("Maximum Δp: %6.3lf%%\n", 100.0*p_diff_values.back()); + LOG("99.9%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.999f)); + LOG("99.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.990f)); + LOG("95.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.950f)); + LOG("90.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.900f)); + LOG("75.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.750f)); + LOG("Median Δp: %6.3lf%%\n", 100.0*p_diff_median); + LOG("25.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.250f)); + LOG("10.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.100f)); + LOG(" 5.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.050f)); + LOG(" 1.0%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.010f)); + LOG(" 0.1%% Δp: %6.3lf%%\n", 100.0*percentile(p_diff_values, 0.001f)); + LOG("Minimum Δp: %6.3lf%%\n", 100.0*p_diff_values.front()); auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count); - // printf("MSE Δp : %10.6lf ± %10.6lf\n", p_diff_mse.first, p_diff_mse.second); + // LOG("MSE Δp : %10.6lf ± %10.6lf\n", p_diff_mse.first, p_diff_mse.second); const double p_diff_rms_val = sqrt(p_diff_mse.first); const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second; - printf("RMS Δp : %6.3lf ± %5.3lf %%\n", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc); + LOG("RMS Δp : %6.3lf ± %5.3lf %%\n", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc); const double same_top_p = 1.0*kld.n_same_top/kld.count; - printf("Same top p: %6.3lf ± %5.3lf %%\n", 100.0*same_top_p, 100.0*sqrt(same_top_p*(1.0 - same_top_p)/(kld.count - 1))); - + LOG("Same top p: %6.3lf ± %5.3lf %%\n", 100.0*same_top_p, 100.0*sqrt(same_top_p*(1.0 - same_top_p)/(kld.count - 1))); } int main(int argc, char ** argv) { - gpt_params params; + common_params params; params.n_ctx = 512; params.logits_all = true; + params.escape = false; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_PERPLEXITY); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) { return 1; } + common_init(); + const int32_t n_ctx = params.n_ctx; if (n_ctx <= 0) { - fprintf(stderr, "%s: perplexity tool requires '--ctx-size' > 0\n", __func__); + LOG_ERR("%s: perplexity tool requires '--ctx-size' > 0\n", __func__); return 1; } @@ -2000,39 +1992,36 @@ int main(int argc, char ** argv) { } if (params.ppl_stride > 0) { - fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n", + LOG_INF("Will perform strided perplexity calculation -> adjusting context size from %d to %d\n", params.n_ctx, params.n_ctx + params.ppl_stride/2); params.n_ctx += params.ppl_stride/2; } - print_build_info(); - - LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed); - llama_backend_init(); llama_numa_init(params.numa); // load the model and apply lora adapter, if any - llama_init_result llama_init = llama_init_from_gpt_params(params); + common_init_result llama_init = common_init_from_params(params); + + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; if (model == NULL) { - fprintf(stderr, "%s: error: unable to load model\n", __func__); + LOG_ERR("%s: unable to load model\n", __func__); return 1; } - const int n_ctx_train = llama_n_ctx_train(model); + const int n_ctx_train = llama_model_n_ctx_train(model); if (params.n_ctx > n_ctx_train) { - fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n", + LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, params.n_ctx); } // print system information { - fprintf(stderr, "\n"); - fprintf(stderr, "%s\n", gpt_params_get_system_info(params).c_str()); + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); } struct results_perplexity results; @@ -2048,12 +2037,8 @@ int main(int argc, char ** argv) { results = perplexity(ctx, params, n_ctx); } - LOG_TEE("\n"); - llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); - write_logfile(ctx, params, model, results); - - llama_free(ctx); - llama_free_model(model); + LOG("\n"); + llama_perf_context_print(ctx); llama_backend_free(); diff --git a/examples/quantize-stats/CMakeLists.txt b/examples/quantize-stats/CMakeLists.txt index bb986a716..9a3a0d3cd 100644 --- a/examples/quantize-stats/CMakeLists.txt +++ b/examples/quantize-stats/CMakeLists.txt @@ -3,4 +3,4 @@ add_executable(${TARGET} quantize-stats.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE llama build_info ${CMAKE_THREAD_LIBS_INIT}) target_include_directories(${TARGET} PRIVATE ../../common) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 498cbbe3c..bd2f73467 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -1,7 +1,7 @@ -#include "common.h" #include "ggml.h" #include "llama.h" -#include "llama-impl.h" +#include "llama-context.h" +#include "common.h" #include #include @@ -9,11 +9,9 @@ #include #include #include -#include #include #include #include -#include #include #include #include @@ -142,7 +140,7 @@ static bool tensor_is_contiguous(const struct ggml_tensor * tensor) { } static void test_roundtrip_on_chunk( - const ggml_tensor * layer, int64_t offset, int64_t chunk_size, const ggml_type_traits_t & qfns, bool use_reference, + const ggml_tensor * layer, int64_t offset, int64_t chunk_size, const ggml_type_traits & qfns, const ggml_type_traits_cpu & qfns_cpu, bool use_reference, float * input_scratch, char * quantized_scratch, float * output_scratch, error_stats & stats ) { if (layer->type == GGML_TYPE_F16) { @@ -156,7 +154,7 @@ static void test_roundtrip_on_chunk( if (use_reference) { qfns.from_float_ref(input_scratch, quantized_scratch, chunk_size); } else { - qfns.from_float(input_scratch, quantized_scratch, chunk_size); + qfns_cpu.from_float(input_scratch, quantized_scratch, chunk_size); } qfns.to_float(quantized_scratch, output_scratch, chunk_size); @@ -166,7 +164,7 @@ static void test_roundtrip_on_chunk( // Run quantization function for a single layer and update error stats static void test_roundtrip_on_layer( - std::string & name, bool print_layer_stats, const ggml_type_traits_t & qfns, bool use_reference, + std::string & name, bool print_layer_stats, const ggml_type_traits & qfns, const ggml_type_traits_cpu & qfns_cpu, bool use_reference, const ggml_tensor * layer, std::vector & input_scratch, std::vector & quantized_scratch, std::vector & output_scratch, error_stats & total_error, int max_thread = 0 ) { @@ -187,13 +185,13 @@ static void test_roundtrip_on_layer( int num_chunks = (nelements + chunk_size - 1)/chunk_size; if (num_chunks < 2 || max_thread < 2) { - test_roundtrip_on_chunk(layer, 0, nelements, qfns, use_reference, input_scratch_ptr, quantized_scratch.data(), + test_roundtrip_on_chunk(layer, 0, nelements, qfns, qfns_cpu, use_reference, input_scratch_ptr, quantized_scratch.data(), output_scratch.data(), print_layer_stats ? layer_error : total_error); } else { auto & stats = print_layer_stats ? layer_error : total_error; std::mutex mutex; uint64_t counter = 0; - auto compute = [&mutex, &counter, &stats, &qfns, nelements, layer, use_reference, input_scratch_ptr, + auto compute = [&mutex, &counter, &stats, &qfns, &qfns_cpu, nelements, layer, use_reference, input_scratch_ptr, &quantized_scratch, &output_scratch, chunk_size] () { error_stats local_stats {}; while (true) { @@ -205,7 +203,7 @@ static void test_roundtrip_on_layer( } lock.unlock(); uint64_t chunk = offset + chunk_size < nelements ? chunk_size : nelements - offset; - test_roundtrip_on_chunk(layer, offset, chunk, qfns, use_reference, input_scratch_ptr + offset, + test_roundtrip_on_chunk(layer, offset, chunk, qfns, qfns_cpu, use_reference, input_scratch_ptr + offset, quantized_scratch.data() + 4*offset, output_scratch.data() + offset, local_stats); } }; @@ -311,7 +309,7 @@ int main(int argc, char ** argv) { auto mparams = llama_model_default_params(); mparams.use_mlock = false; - model = llama_load_model_from_file(params.model.c_str(), mparams); + model = llama_model_load_from_file(params.model.c_str(), mparams); if (model == NULL) { fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); @@ -321,22 +319,22 @@ int main(int argc, char ** argv) { auto cparams = llama_context_default_params(); cparams.n_ctx = 256; - ctx = llama_new_context_with_model(model, cparams); + ctx = llama_init_from_model(model, cparams); if (ctx == NULL) { fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); - llama_free_model(model); + llama_model_free(model); return 1; } } - const auto &tensors = llama_internal_get_tensor_map(ctx); + const auto & tensors = llama_internal_get_tensor_map(ctx); // check layer tensors int included_layers = 0; int64_t max_nelements = 0; bool is_f16 = false; - for (const auto& kv_tensor : tensors) { + for (const auto & kv_tensor : tensors) { if (!layer_included(params, kv_tensor.first)) { continue; } @@ -349,7 +347,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: error: Quantization should be tested with a float model, " "this model contains already quantized layers (%s is type %d)\n", __func__, kv_tensor.first.c_str(), kv_tensor.second->type); llama_free(ctx); - llama_free_model(model); + llama_model_free(model); return 1; } included_layers++; @@ -371,8 +369,9 @@ int main(int argc, char ** argv) { if (!params.include_types.empty() && std::find(params.include_types.begin(), params.include_types.end(), i) == params.include_types.end()) { continue; } - ggml_type_traits_t qfns = ggml_internal_get_type_traits(type); - if (qfns.from_float && qfns.to_float) { + const auto * qfns = ggml_get_type_traits(type); + const auto * qfns_cpu = ggml_get_type_traits_cpu(type); + if (qfns_cpu->from_float && qfns->to_float) { if (params.verbose) { printf("testing %s ...\n", ggml_type_name(type)); } @@ -381,7 +380,7 @@ int main(int argc, char ** argv) { error_stats global_stats {}; - for (const auto& kv_tensor : tensors) { + for (const auto & kv_tensor : tensors) { if (!layer_included(params, kv_tensor.first)) { continue; } @@ -393,7 +392,7 @@ int main(int argc, char ** argv) { test_roundtrip_on_layer( layer_name, params.per_layer_stats, - qfns, + *qfns, *qfns_cpu, params.reference, kv_tensor.second, input_scratch, @@ -410,7 +409,7 @@ int main(int argc, char ** argv) { llama_free(ctx); - llama_free_model(model); + llama_model_free(model); // report timing { const int64_t t_main_end_us = ggml_time_us(); diff --git a/examples/quantize/CMakeLists.txt b/examples/quantize/CMakeLists.txt index 3ee4eb971..47e5cbe30 100644 --- a/examples/quantize/CMakeLists.txt +++ b/examples/quantize/CMakeLists.txt @@ -1,6 +1,6 @@ set(TARGET llama-quantize) add_executable(${TARGET} quantize.cpp) install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_include_directories(${TARGET} PRIVATE ../../common) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/quantize/README.md b/examples/quantize/README.md index 5d1e11c67..f9cce7b21 100644 --- a/examples/quantize/README.md +++ b/examples/quantize/README.md @@ -81,7 +81,7 @@ Several quantization methods are supported. They differ in the resulting model d - [#4930 - imatrix for all k-quants](https://github.com/ggerganov/llama.cpp/pull/4930) - [#4951 - imatrix on the GPU](https://github.com/ggerganov/llama.cpp/pull/4957) - [#4969 - imatrix for legacy quants](https://github.com/ggerganov/llama.cpp/pull/4969) - - [#4996 - k-qunats tuning](https://github.com/ggerganov/llama.cpp/pull/4996) + - [#4996 - k-quants tuning](https://github.com/ggerganov/llama.cpp/pull/4996) - [#5060 - Q3_K_XS](https://github.com/ggerganov/llama.cpp/pull/5060) - [#5196 - 3-bit i-quants](https://github.com/ggerganov/llama.cpp/pull/5196) - [quantization tuning](https://github.com/ggerganov/llama.cpp/pull/5320), [another one](https://github.com/ggerganov/llama.cpp/pull/5334), and [another one](https://github.com/ggerganov/llama.cpp/pull/5361) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 0cde695ed..355aef4a6 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -1,5 +1,6 @@ #include "common.h" #include "llama.h" +#include "gguf.h" #include #include @@ -47,9 +48,6 @@ static const std::vector QUANT_OPTIONS = { { "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 5.33G, +0.0569 ppl @ Llama-3-8B", }, { "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 6.14G, +0.0217 ppl @ Llama-3-8B", }, { "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 7.96G, +0.0026 ppl @ Llama-3-8B", }, - { "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, - { "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, - { "Q4_0_8_8", LLAMA_FTYPE_MOSTLY_Q4_0_8_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, { "F16", LLAMA_FTYPE_MOSTLY_F16, "14.00G, +0.0020 ppl @ Mistral-7B", }, { "BF16", LLAMA_FTYPE_MOSTLY_BF16, "14.00G, -0.0050 ppl @ Mistral-7B", }, { "F32", LLAMA_FTYPE_ALL_F32, "26.00G @ 7B", }, @@ -67,6 +65,16 @@ static const char * const LLM_KV_IMATRIX_DATASET = "imatrix.dataset"; static const char * const LLM_KV_IMATRIX_CHUNK_COUNT = "imatrix.chunk_count"; static const char * const LLM_KV_IMATRIX_CHUNK_SIZE = "imatrix.chunk_size"; +static bool striequals(const char * a, const char * b) { + while (*a && *b) { + if (std::tolower(*a) != std::tolower(*b)) { + return false; + } + a++; b++; + } + return *a == *b; +} + static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftype, std::string & ftype_str_out) { std::string ftype_str; @@ -74,7 +82,7 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp ftype_str.push_back(std::toupper(ch)); } for (auto & it : QUANT_OPTIONS) { - if (it.name == ftype_str) { + if (striequals(it.name.c_str(), ftype_str.c_str())) { ftype = it.ftype; ftype_str_out = it.name; return true; @@ -276,15 +284,15 @@ static int prepare_imatrix(const std::string & imatrix_file, } static ggml_type parse_ggml_type(const char * arg) { - ggml_type result = GGML_TYPE_COUNT; - for (int j = 0; j < GGML_TYPE_COUNT; ++j) { - auto type = ggml_type(j); + for (int i = 0; i < GGML_TYPE_COUNT; ++i) { + auto type = (ggml_type)i; const auto * name = ggml_type_name(type); - if (name && strcmp(arg, name) == 0) { - result = type; break; + if (name && striequals(name, arg)) { + return type; } } - return result; + fprintf(stderr, "%s: invalid ggml_type '%s'\n", __func__, arg); + return GGML_TYPE_COUNT; } int main(int argc, char ** argv) { @@ -305,12 +313,18 @@ int main(int argc, char ** argv) { } else if (strcmp(argv[arg_idx], "--output-tensor-type") == 0) { if (arg_idx < argc-1) { params.output_tensor_type = parse_ggml_type(argv[++arg_idx]); + if (params.output_tensor_type == GGML_TYPE_COUNT) { + usage(argv[0]); + } } else { usage(argv[0]); } } else if (strcmp(argv[arg_idx], "--token-embedding-type") == 0) { if (arg_idx < argc-1) { params.token_embedding_type = parse_ggml_type(argv[++arg_idx]); + if (params.token_embedding_type == GGML_TYPE_COUNT) { + usage(argv[0]); + } } else { usage(argv[0]); } diff --git a/examples/quantize/tests.sh b/examples/quantize/tests.sh index 24bc970e8..70f7610f9 100644 --- a/examples/quantize/tests.sh +++ b/examples/quantize/tests.sh @@ -47,7 +47,7 @@ echo PASS echo # 3a. Test the requanted model is loading properly -$MAIN --model $WORK_PATH/ggml-model-requant-00001-of-00006.gguf --n-predict 32 +$MAIN -no-cnv --model $WORK_PATH/ggml-model-requant-00001-of-00006.gguf --n-predict 32 echo PASS echo @@ -57,7 +57,7 @@ echo PASS echo # 4b. Test the requanted model is loading properly -$MAIN --model $WORK_PATH/ggml-model-requant-merge.gguf --n-predict 32 +$MAIN -no-cnv --model $WORK_PATH/ggml-model-requant-merge.gguf --n-predict 32 echo PASS echo diff --git a/examples/retrieval/CMakeLists.txt b/examples/retrieval/CMakeLists.txt index 66610f311..512a602ec 100644 --- a/examples/retrieval/CMakeLists.txt +++ b/examples/retrieval/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-retrieval) add_executable(${TARGET} retrieval.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index dd8a82e6e..2439022a2 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -1,13 +1,16 @@ +#include "arg.h" #include "common.h" +#include "log.h" #include "llama.h" #include #include +#include // TODO: remove me static void print_usage(int, char ** argv) { - LOG_TEE("\nexample usage:\n"); - LOG_TEE("\n %s --model ./models/bge-base-en-v1.5-f16.gguf --top-k 3 --context-file README.md --context-file License --chunk-size 100 --chunk-separator .\n", argv[0]); - LOG_TEE("\n"); + LOG("\nexample usage:\n"); + LOG("\n %s --model ./models/bge-base-en-v1.5-f16.gguf --top-k 3 --context-file README.md --context-file License --chunk-size 100 --chunk-separator .\n", argv[0]); + LOG("\n"); } struct chunk { @@ -16,7 +19,7 @@ struct chunk { // original file position size_t filepos; // original text data - std::string textdata = ""; + std::string textdata; // tokenized text data std::vector tokens; // embedding @@ -30,14 +33,14 @@ static std::vector chunk_file(const std::string & filename, int chunk_siz std::ifstream f(filename.c_str()); if (!f.is_open()) { - fprintf(stderr, "Error: could not open file %s\n", filename.c_str()); + LOG_ERR("could not open file %s\n", filename.c_str()); return chunks; } chunk current_chunk; char buffer[1024]; int64_t filepos = 0; - std::string current = ""; + std::string current; while (f.read(buffer, 1024)) { current += std::string(buffer, f.gcount()); size_t pos; @@ -74,7 +77,7 @@ static std::vector chunk_file(const std::string & filename, int chunk_siz static void batch_add_seq(llama_batch & batch, const std::vector & tokens, llama_seq_id seq_id) { size_t n_tokens = tokens.size(); for (size_t i = 0; i < n_tokens; i++) { - llama_batch_add(batch, tokens[i], i, { seq_id }, true); + common_batch_add(batch, tokens[i], i, { seq_id }, true); } } @@ -83,9 +86,9 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu llama_kv_cache_clear(ctx); // run model - fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); + LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); if (llama_decode(ctx, batch) < 0) { - fprintf(stderr, "%s : failed to decode\n", __func__); + LOG_ERR("%s : failed to decode\n", __func__); } for (int i = 0; i < batch.n_tokens; i++) { @@ -98,42 +101,41 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu if (embd == NULL) { embd = llama_get_embeddings_ith(ctx, i); if (embd == NULL) { - fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i); + LOG_ERR("%s: failed to get embeddings for token %d\n", __func__, i); continue; } } float * out = output + batch.seq_id[i][0] * n_embd; - llama_embd_normalize(embd, out, n_embd); + common_embd_normalize(embd, out, n_embd, 2); } } int main(int argc, char ** argv) { - gpt_params params; + common_params params; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_RETRIEVAL, print_usage); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_RETRIEVAL, print_usage)) { return 1; } + common_init(); + // For BERT models, batch size must be equal to ubatch size params.n_ubatch = params.n_batch; params.embedding = true; if (params.chunk_size <= 0) { - fprintf(stderr, "chunk_size must be positive\n"); + LOG_ERR("chunk_size must be positive\n"); return 1; } if (params.context_files.empty()) { - fprintf(stderr, "context_files must be specified\n"); + LOG_ERR("context_files must be specified\n"); return 1; } - print_build_info(); - - printf("processing files:\n"); + LOG_INF("processing files:\n"); for (auto & context_file : params.context_files) { - printf("%s\n", context_file.c_str()); + LOG_INF("%s\n", context_file.c_str()); } std::vector chunks; @@ -141,40 +143,42 @@ int main(int argc, char ** argv) { std::vector file_chunk = chunk_file(context_file, params.chunk_size, params.chunk_separator); chunks.insert(chunks.end(), file_chunk.begin(), file_chunk.end()); } - printf("Number of chunks: %ld\n", chunks.size()); + LOG_INF("Number of chunks: %zu\n", chunks.size()); llama_backend_init(); llama_numa_init(params.numa); // load the model - llama_init_result llama_init = llama_init_from_gpt_params(params); + common_init_result llama_init = common_init_from_params(params); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); if (model == NULL) { - fprintf(stderr, "%s: error: unable to load model\n", __func__); + LOG_ERR("%s: unable to load model\n", __func__); return 1; } - const int n_ctx_train = llama_n_ctx_train(model); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_ctx_train = llama_model_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); if (pooling_type == LLAMA_POOLING_TYPE_NONE) { - fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__); + LOG_ERR("%s: pooling type NONE not supported\n", __func__); return 1; } if (n_ctx > n_ctx_train) { - fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n", + LOG_WRN("%s: warning: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx); } // print system information { - fprintf(stderr, "\n"); - fprintf(stderr, "%s\n", gpt_params_get_system_info(params).c_str()); + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); } // max batch size @@ -183,15 +187,15 @@ int main(int argc, char ** argv) { // tokenize the prompts and trim for (auto & chunk : chunks) { - auto inp = ::llama_tokenize(ctx, chunk.textdata, true, false); + auto inp = common_tokenize(ctx, chunk.textdata, true, false); if (inp.size() > n_batch) { - fprintf(stderr, "%s: error: chunk size (%lld) exceeds batch size (%lld), increase batch size and re-run\n", + LOG_ERR("%s: chunk size (%lld) exceeds batch size (%lld), increase batch size and re-run\n", __func__, (long long int) inp.size(), (long long int) n_batch); return 1; } // add eos if not present - if (llama_token_eos(model) >= 0 && (inp.empty() || inp.back() != llama_token_eos(model))) { - inp.push_back(llama_token_eos(model)); + if (llama_vocab_eos(vocab) >= 0 && (inp.empty() || inp.back() != llama_vocab_eos(vocab))) { + inp.push_back(llama_vocab_eos(vocab)); } chunk.tokens = inp; } @@ -199,12 +203,12 @@ int main(int argc, char ** argv) { // tokenization stats if (params.verbose_prompt) { for (int i = 0; i < (int) chunks.size(); i++) { - fprintf(stderr, "%s: prompt %d: '%s'\n", __func__, i, chunks[i].textdata.c_str()); - fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, chunks[i].tokens.size()); + LOG_INF("%s: prompt %d: '%s'\n", __func__, i, chunks[i].textdata.c_str()); + LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, chunks[i].tokens.size()); for (int j = 0; j < (int) chunks[i].tokens.size(); j++) { - fprintf(stderr, "%6d -> '%s'\n", chunks[i].tokens[j], llama_token_to_piece(ctx, chunks[i].tokens[j]).c_str()); + LOG_INF("%6d -> '%s'\n", chunks[i].tokens[j], common_token_to_piece(ctx, chunks[i].tokens[j]).c_str()); } - fprintf(stderr, "\n\n"); + LOG_INF("\n\n"); } } @@ -213,7 +217,7 @@ int main(int argc, char ** argv) { struct llama_batch batch = llama_batch_init(n_batch, 0, 1); // allocate output - const int n_embd = llama_n_embd(model); + const int n_embd = llama_model_n_embd(model); std::vector embeddings(n_chunks * n_embd, 0); float * emb = embeddings.data(); @@ -230,7 +234,7 @@ int main(int argc, char ** argv) { if (batch.n_tokens + n_toks > n_batch) { float * out = emb + p * n_embd; batch_decode(ctx, batch, out, s, n_embd); - llama_batch_clear(batch); + common_batch_clear(batch); p += s; s = 0; } @@ -256,22 +260,22 @@ int main(int argc, char ** argv) { // start loop, receive query and return top k similar chunks based on cosine similarity std::string query; while (true) { - printf("Enter query: "); + LOG("Enter query: "); std::getline(std::cin, query); - std::vector query_tokens = llama_tokenize(ctx, query, true); + std::vector query_tokens = common_tokenize(ctx, query, true); batch_add_seq(query_batch, query_tokens, 0); std::vector query_emb(n_embd, 0); batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd); - llama_batch_clear(query_batch); + common_batch_clear(query_batch); // compute cosine similarities { std::vector> similarities; for (int i = 0; i < n_chunks; i++) { - float sim = llama_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd); + float sim = common_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd); similarities.push_back(std::make_pair(i, sim)); } @@ -280,23 +284,21 @@ int main(int argc, char ** argv) { return a.second > b.second; }); - printf("Top %d similar chunks:\n", params.sparams.top_k); - for (int i = 0; i < std::min(params.sparams.top_k, (int) chunks.size()); i++) { - printf("filename: %s\n", chunks[similarities[i].first].filename.c_str()); - printf("filepos: %lld\n", (long long int) chunks[similarities[i].first].filepos); - printf("similarity: %f\n", similarities[i].second); - printf("textdata:\n%s\n", chunks[similarities[i].first].textdata.c_str()); - printf("--------------------\n"); + LOG("Top %d similar chunks:\n", params.sampling.top_k); + for (int i = 0; i < std::min(params.sampling.top_k, (int) chunks.size()); i++) { + LOG("filename: %s\n", chunks[similarities[i].first].filename.c_str()); + LOG("filepos: %lld\n", (long long int) chunks[similarities[i].first].filepos); + LOG("similarity: %f\n", similarities[i].second); + LOG("textdata:\n%s\n", chunks[similarities[i].first].textdata.c_str()); + LOG("--------------------\n"); } } } - LOG_TEE("\n"); - llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); + LOG("\n"); + llama_perf_context_print(ctx); // clean up llama_batch_free(query_batch); - llama_free(ctx); - llama_free_model(model); llama_backend_free(); } diff --git a/examples/rpc/README.md b/examples/rpc/README.md index adedc8909..312bb634d 100644 --- a/examples/rpc/README.md +++ b/examples/rpc/README.md @@ -10,20 +10,21 @@ This can be used for distributed LLM inference with `llama.cpp` in the following ```mermaid flowchart TD - rpcb---|TCP|srva - rpcb---|TCP|srvb - rpcb-.-|TCP|srvn + rpcb<-->|TCP|srva + rpcb<-->|TCP|srvb + rpcb<-.->|TCP|srvn subgraph hostn[Host N] - srvn[rpc-server]-.-backend3["Backend (CUDA,Metal,etc.)"] + srvn[rpc-server]<-.->backend3["Backend (CUDA,Metal,etc.)"] end subgraph hostb[Host B] - srvb[rpc-server]---backend2["Backend (CUDA,Metal,etc.)"] + srvb[rpc-server]<-->backend2["Backend (CUDA,Metal,etc.)"] end subgraph hosta[Host A] - srva[rpc-server]---backend["Backend (CUDA,Metal,etc.)"] + srva[rpc-server]<-->backend["Backend (CUDA,Metal,etc.)"] end subgraph host[Main Host] - ggml[llama.cpp]---rpcb[RPC backend] + local["Backend (CUDA,Metal,etc.)"]<-->ggml[llama-cli] + ggml[llama-cli]<-->rpcb[RPC backend] end style hostn stroke:#66,stroke-width:2px,stroke-dasharray: 5 5 ``` @@ -62,17 +63,12 @@ $ CUDA_VISIBLE_DEVICES=0 bin/rpc-server -p 50052 This way you can run multiple `rpc-server` instances on the same host, each with a different CUDA device. -On the main host build `llama.cpp` only with `-DGGML_RPC=ON`: - -```bash -mkdir build-rpc -cd build-rpc -cmake .. -DGGML_RPC=ON -cmake --build . --config Release -``` - -Finally, use the `--rpc` option to specify the host and port of each `rpc-server`: +On the main host build `llama.cpp` for the local backend and add `-DGGML_RPC=ON` to the build options. +Finally, when running `llama-cli`, use the `--rpc` option to specify the host and port of each `rpc-server`: ```bash $ bin/llama-cli -m ../models/tinyllama-1b/ggml-model-f16.gguf -p "Hello, my name is" --repeat-penalty 1.0 -n 64 --rpc 192.168.88.10:50052,192.168.88.11:50052 -ngl 99 ``` + +This way you can offload model layers to both local and remote devices. + diff --git a/examples/rpc/rpc-server.cpp b/examples/rpc/rpc-server.cpp index 6342e6488..8b1b23eda 100644 --- a/examples/rpc/rpc-server.cpp +++ b/examples/rpc/rpc-server.cpp @@ -1,3 +1,5 @@ +#include "ggml-cpu.h" + #ifdef GGML_USE_CUDA #include "ggml-cuda.h" #endif @@ -6,6 +8,14 @@ #include "ggml-metal.h" #endif +#ifdef GGML_USE_VULKAN +#include "ggml-vulkan.h" +#endif + +#ifdef GGML_USE_SYCL +#include "ggml-sycl.h" +#endif + #include "ggml-rpc.h" #ifdef _WIN32 # include @@ -79,6 +89,18 @@ static ggml_backend_t create_backend() { if (!backend) { fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); } +#elif GGML_USE_VULKAN + fprintf(stderr, "%s: using Vulkan backend\n", __func__); + backend = ggml_backend_vk_init(0); // init device 0 + if (!backend) { + fprintf(stderr, "%s: ggml_backend_vulkan_init() failed\n", __func__); + } +#elif GGML_USE_SYCL + fprintf(stderr, "%s: using SYCL backend\n", __func__); + backend = ggml_backend_sycl_init(0); // init device 0 + if (!backend) { + fprintf(stderr, "%s: ggml_backend_sycl_init() failed\n", __func__); + } #endif // if there aren't GPU Backends fallback to CPU backend @@ -92,6 +114,10 @@ static ggml_backend_t create_backend() { static void get_backend_memory(size_t * free_mem, size_t * total_mem) { #ifdef GGML_USE_CUDA ggml_backend_cuda_get_device_memory(0, free_mem, total_mem); +#elif GGML_USE_VULKAN + ggml_backend_vk_get_device_memory(0, free_mem, total_mem); +#elif GGML_USE_SYCL + ggml_backend_sycl_get_device_memory(0, free_mem, total_mem); #else #ifdef _WIN32 MEMORYSTATUSEX status; @@ -139,7 +165,7 @@ int main(int argc, char * argv[]) { get_backend_memory(&free_mem, &total_mem); } printf("Starting RPC server on %s, backend memory: %zu MB\n", endpoint.c_str(), free_mem / (1024 * 1024)); - start_rpc_server(backend, endpoint.c_str(), free_mem, total_mem); + ggml_backend_rpc_start_server(backend, endpoint.c_str(), free_mem, total_mem); ggml_backend_free(backend); return 0; } diff --git a/examples/run/CMakeLists.txt b/examples/run/CMakeLists.txt new file mode 100644 index 000000000..cd6b0520e --- /dev/null +++ b/examples/run/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-run) +add_executable(${TARGET} run.cpp linenoise.cpp/linenoise.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/run/README.md b/examples/run/README.md new file mode 100644 index 000000000..89a552079 --- /dev/null +++ b/examples/run/README.md @@ -0,0 +1,50 @@ +# llama.cpp/example/run + +The purpose of this example is to demonstrate a minimal usage of llama.cpp for running models. + +```bash +llama-run granite3-moe +``` + +```bash +Description: + Runs a llm + +Usage: + llama-run [options] model [prompt] + +Options: + -c, --context-size + Context size (default: 2048) + -n, -ngl, --ngl + Number of GPU layers (default: 0) + --temp + Temperature (default: 0.8) + -v, --verbose, --log-verbose + Set verbosity level to infinity (i.e. log all messages, useful for debugging) + -h, --help + Show help message + +Commands: + model + Model is a string with an optional prefix of + huggingface:// (hf://), ollama://, https:// or file://. + If no protocol is specified and a file exists in the specified + path, file:// is assumed, otherwise if a file does not exist in + the specified path, ollama:// is assumed. Models that are being + pulled are downloaded with .partial extension while being + downloaded and then renamed as the file without the .partial + extension when complete. + +Examples: + llama-run llama3 + llama-run ollama://granite-code + llama-run ollama://smollm:135m + llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf + llama-run huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf + llama-run https://example.com/some-file1.gguf + llama-run some-file2.gguf + llama-run file://some-file3.gguf + llama-run --ngl 999 some-file4.gguf + llama-run --ngl 999 some-file5.gguf Hello World +``` diff --git a/examples/run/linenoise.cpp/LICENSE b/examples/run/linenoise.cpp/LICENSE new file mode 100644 index 000000000..b006b3b24 --- /dev/null +++ b/examples/run/linenoise.cpp/LICENSE @@ -0,0 +1,26 @@ +Copyright (c) 2010-2014, Salvatore Sanfilippo +Copyright (c) 2010-2013, Pieter Noordhuis +Copyright (c) 2025, Eric Curtin + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/examples/run/linenoise.cpp/linenoise.cpp b/examples/run/linenoise.cpp/linenoise.cpp new file mode 100644 index 000000000..a68f12a1a --- /dev/null +++ b/examples/run/linenoise.cpp/linenoise.cpp @@ -0,0 +1,1350 @@ +#ifndef _WIN32 +/* + * You can find the latest source code at: + * + * http://github.com/ericcurtin/linenoise.cpp + * + * Does a number of crazy assumptions that happen to be true in 99.9999% of + * the 2010 UNIX computers around. + * + * ------------------------------------------------------------------------ + * + * Copyright (c) 2010-2023, Salvatore Sanfilippo + * Copyright (c) 2010-2013, Pieter Noordhuis + * Copyright (c) 2025, Eric Curtin + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * ------------------------------------------------------------------------ + * + * References: + * - http://invisible-island.net/xterm/ctlseqs/ctlseqs.html + * - http://www.3waylabs.com/nw/WWW/products/wizcon/vt220.html + * + * Todo list: + * - Filter bogus Ctrl+ combinations. + * - Win32 support + * + * Bloat: + * - History search like Ctrl+r in readline? + * + * List of escape sequences used by this program, we do everything just + * with three sequences. In order to be so cheap we may have some + * flickering effect with some slow terminal, but the lesser sequences + * the more compatible. + * + * EL (Erase Line) + * Sequence: ESC [ n K + * Effect: if n is 0 or missing, clear from cursor to end of line + * Effect: if n is 1, clear from beginning of line to cursor + * Effect: if n is 2, clear entire line + * + * CUF (CUrsor Forward) + * Sequence: ESC [ n C + * Effect: moves cursor forward n chars + * + * CUB (CUrsor Backward) + * Sequence: ESC [ n D + * Effect: moves cursor backward n chars + * + * The following is used to get the terminal width if getting + * the width with the TIOCGWINSZ ioctl fails + * + * DSR (Device Status Report) + * Sequence: ESC [ 6 n + * Effect: reports the current cusor position as ESC [ n ; m R + * where n is the row and m is the column + * + * When multi line mode is enabled, we also use an additional escape + * sequence. However multi line editing is disabled by default. + * + * CUU (Cursor Up) + * Sequence: ESC [ n A + * Effect: moves cursor up of n chars. + * + * CUD (Cursor Down) + * Sequence: ESC [ n B + * Effect: moves cursor down of n chars. + * + * When linenoiseClearScreen() is called, two additional escape sequences + * are used in order to clear the screen and position the cursor at home + * position. + * + * CUP (Cursor position) + * Sequence: ESC [ H + * Effect: moves the cursor to upper left corner + * + * ED (Erase display) + * Sequence: ESC [ 2 J + * Effect: clear the whole screen + * + */ + +# include "linenoise.h" + +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include + +# include +# include +# include + +# define LINENOISE_DEFAULT_HISTORY_MAX_LEN 100 +# define LINENOISE_MAX_LINE 4096 +static std::vector unsupported_term = { "dumb", "cons25", "emacs" }; +static linenoiseCompletionCallback *completionCallback = NULL; +static linenoiseHintsCallback *hintsCallback = NULL; +static linenoiseFreeHintsCallback *freeHintsCallback = NULL; +static char *linenoiseNoTTY(void); +static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseCompletions *lc, int flags); +static void refreshLineWithFlags(struct linenoiseState *l, int flags); + +static struct termios orig_termios; /* In order to restore at exit.*/ +static int maskmode = 0; /* Show "***" instead of input. For passwords. */ +static int rawmode = 0; /* For atexit() function to check if restore is needed*/ +static int mlmode = 0; /* Multi line mode. Default is single line. */ +static int atexit_registered = 0; /* Register atexit just 1 time. */ +static int history_max_len = LINENOISE_DEFAULT_HISTORY_MAX_LEN; +static int history_len = 0; +static char **history = NULL; + +enum KEY_ACTION{ + KEY_NULL = 0, /* NULL */ + CTRL_A = 1, /* Ctrl+a */ + CTRL_B = 2, /* Ctrl-b */ + CTRL_C = 3, /* Ctrl-c */ + CTRL_D = 4, /* Ctrl-d */ + CTRL_E = 5, /* Ctrl-e */ + CTRL_F = 6, /* Ctrl-f */ + CTRL_H = 8, /* Ctrl-h */ + TAB = 9, /* Tab */ + CTRL_K = 11, /* Ctrl+k */ + CTRL_L = 12, /* Ctrl+l */ + ENTER = 13, /* Enter */ + CTRL_N = 14, /* Ctrl-n */ + CTRL_P = 16, /* Ctrl-p */ + CTRL_T = 20, /* Ctrl-t */ + CTRL_U = 21, /* Ctrl+u */ + CTRL_W = 23, /* Ctrl+w */ + ESC = 27, /* Escape */ + BACKSPACE = 127 /* Backspace */ +}; + +static void linenoiseAtExit(void); +int linenoiseHistoryAdd(const char *line); +#define REFRESH_CLEAN (1<<0) // Clean the old prompt from the screen +#define REFRESH_WRITE (1<<1) // Rewrite the prompt on the screen. +#define REFRESH_ALL (REFRESH_CLEAN|REFRESH_WRITE) // Do both. +static void refreshLine(struct linenoiseState *l); + +class File { + public: + FILE * file = nullptr; + + FILE * open(const std::string & filename, const char * mode) { + file = fopen(filename.c_str(), mode); + + return file; + } + + int lock() { + if (file) { + fd = fileno(file); + if (flock(fd, LOCK_EX | LOCK_NB) != 0) { + fd = -1; + + return 1; + } + } + + return 0; + } + + ~File() { + if (fd >= 0) { + flock(fd, LOCK_UN); + } + + if (file) { + fclose(file); + } + } + + private: + int fd = -1; +}; + +__attribute__((format(printf, 1, 2))) +/* Debugging function. */ +#if 0 +static void lndebug(const char *fmt, ...) { + static File file; + if (file.file == nullptr) { + file.open("/tmp/lndebug.txt", "a"); + } + + if (file.file != nullptr) { + va_list args; + va_start(args, fmt); + vfprintf(file.file, fmt, args); + va_end(args); + fflush(file.file); + } +} +#else +static void lndebug(const char *, ...) { +} +#endif + +/* ======================= Low level terminal handling ====================== */ + +/* Enable "mask mode". When it is enabled, instead of the input that + * the user is typing, the terminal will just display a corresponding + * number of asterisks, like "****". This is useful for passwords and other + * secrets that should not be displayed. */ +void linenoiseMaskModeEnable(void) { + maskmode = 1; +} + +/* Disable mask mode. */ +void linenoiseMaskModeDisable(void) { + maskmode = 0; +} + +/* Set if to use or not the multi line mode. */ +void linenoiseSetMultiLine(int ml) { + mlmode = ml; +} + +/* Return true if the terminal name is in the list of terminals we know are + * not able to understand basic escape sequences. */ +static int isUnsupportedTerm(void) { + char *term = getenv("TERM"); + if (term == NULL) return 0; + for (size_t j = 0; j < unsupported_term.size(); ++j) { + if (!strcasecmp(term, unsupported_term[j])) { + return 1; + } + } + return 0; +} + +/* Raw mode: 1960 magic shit. */ +static int enableRawMode(int fd) { + struct termios raw; + + if (!isatty(STDIN_FILENO)) goto fatal; + if (!atexit_registered) { + atexit(linenoiseAtExit); + atexit_registered = 1; + } + if (tcgetattr(fd,&orig_termios) == -1) goto fatal; + + raw = orig_termios; /* modify the original mode */ + /* input modes: no break, no CR to NL, no parity check, no strip char, + * no start/stop output control. */ + raw.c_iflag &= ~(BRKINT | ICRNL | INPCK | ISTRIP | IXON); + /* output modes - disable post processing */ + raw.c_oflag &= ~(OPOST); + /* control modes - set 8 bit chars */ + raw.c_cflag |= (CS8); + /* local modes - choing off, canonical off, no extended functions, + * no signal chars (^Z,^C) */ + raw.c_lflag &= ~(ECHO | ICANON | IEXTEN | ISIG); + /* control chars - set return condition: min number of bytes and timer. + * We want read to return every single byte, without timeout. */ + raw.c_cc[VMIN] = 1; raw.c_cc[VTIME] = 0; /* 1 byte, no timer */ + + /* put terminal in raw mode after flushing */ + if (tcsetattr(fd,TCSAFLUSH,&raw) < 0) goto fatal; + rawmode = 1; + return 0; + +fatal: + errno = ENOTTY; + return -1; +} + +static void disableRawMode(int fd) { + /* Don't even check the return value as it's too late. */ + if (rawmode && tcsetattr(fd,TCSAFLUSH,&orig_termios) != -1) + rawmode = 0; +} + +/* Use the ESC [6n escape sequence to query the horizontal cursor position + * and return it. On error -1 is returned, on success the position of the + * cursor. */ +static int getCursorPosition(int ifd, int ofd) { + char buf[32]; + int cols, rows; + unsigned int i = 0; + + /* Report cursor location */ + if (write(ofd, "\x1b[6n", 4) != 4) return -1; + + /* Read the response: ESC [ rows ; cols R */ + while (i < sizeof(buf)-1) { + if (read(ifd,buf+i,1) != 1) break; + if (buf[i] == 'R') break; + i++; + } + buf[i] = '\0'; + + /* Parse it. */ + if (buf[0] != ESC || buf[1] != '[') return -1; + if (sscanf(buf+2,"%d;%d",&rows,&cols) != 2) return -1; + return cols; +} + +/* Try to get the number of columns in the current terminal, or assume 80 + * if it fails. */ +static int getColumns(int ifd, int ofd) { + struct winsize ws; + + if (ioctl(1, TIOCGWINSZ, &ws) == -1 || ws.ws_col == 0) { + /* ioctl() failed. Try to query the terminal itself. */ + int start, cols; + + /* Get the initial position so we can restore it later. */ + start = getCursorPosition(ifd,ofd); + if (start == -1) goto failed; + + /* Go to right margin and get position. */ + if (write(ofd,"\x1b[999C",6) != 6) goto failed; + cols = getCursorPosition(ifd,ofd); + if (cols == -1) goto failed; + + /* Restore position. */ + if (cols > start) { + char seq[32]; + snprintf(seq,32,"\x1b[%dD",cols-start); + if (write(ofd,seq,strlen(seq)) == -1) { + /* Can't recover... */ + } + } + return cols; + } else { + return ws.ws_col; + } + +failed: + return 80; +} + +/* Clear the screen. Used to handle ctrl+l */ +void linenoiseClearScreen(void) { + if (write(STDOUT_FILENO,"\x1b[H\x1b[2J",7) <= 0) { + /* nothing to do, just to avoid warning. */ + } +} + +/* Beep, used for completion when there is nothing to complete or when all + * the choices were already shown. */ +static void linenoiseBeep(void) { + fprintf(stderr, "\x7"); + fflush(stderr); +} + +/* Called by completeLine() and linenoiseShow() to render the current + * edited line with the proposed completion. If the current completion table + * is already available, it is passed as second argument, otherwise the + * function will use the callback to obtain it. + * + * Flags are the same as refreshLine*(), that is REFRESH_* macros. */ +static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseCompletions *lc, int flags) { + /* Obtain the table of completions if the caller didn't provide one. */ + linenoiseCompletions ctable; + if (lc == NULL) { + completionCallback(ls->buf, &ctable); + lc = &ctable; + } + + /* Show the edited line with completion if possible, or just refresh. */ + if (ls->completion_idx < lc->len) { + struct linenoiseState saved = *ls; + ls->len = ls->pos = strlen(lc->cvec[ls->completion_idx]); + ls->buf = lc->cvec[ls->completion_idx]; + refreshLineWithFlags(ls, flags); + ls->len = saved.len; + ls->pos = saved.pos; + ls->buf = saved.buf; + } else { + refreshLineWithFlags(ls, flags); + } + + if (lc == &ctable) { + ctable.to_free = false; + } +} + +/* This is an helper function for linenoiseEdit*() and is called when the + * user types the key in order to complete the string currently in the + * input. + * + * The state of the editing is encapsulated into the pointed linenoiseState + * structure as described in the structure definition. + * + * If the function returns non-zero, the caller should handle the + * returned value as a byte read from the standard input, and process + * it as usually: this basically means that the function may return a byte + * read from the termianl but not processed. Otherwise, if zero is returned, + * the input was consumed by the completeLine() function to navigate the + * possible completions, and the caller should read for the next characters + * from stdin. */ +static int completeLine(struct linenoiseState *ls, int keypressed) { + linenoiseCompletions lc; + int nwritten; + char c = keypressed; + + completionCallback(ls->buf, &lc); + if (lc.len == 0) { + linenoiseBeep(); + ls->in_completion = 0; + } else { + switch(c) { + case 9: /* tab */ + if (ls->in_completion == 0) { + ls->in_completion = 1; + ls->completion_idx = 0; + } else { + ls->completion_idx = (ls->completion_idx + 1) % (lc.len + 1); + if (ls->completion_idx == lc.len) linenoiseBeep(); + } + c = 0; + break; + case 27: /* escape */ + /* Re-show original buffer */ + if (ls->completion_idx < lc.len) refreshLine(ls); + ls->in_completion = 0; + c = 0; + break; + default: + /* Update buffer and return */ + if (ls->completion_idx < lc.len) { + nwritten = snprintf(ls->buf, ls->buflen, "%s", lc.cvec[ls->completion_idx]); + ls->len = ls->pos = nwritten; + } + ls->in_completion = 0; + break; + } + + /* Show completion or original buffer */ + if (ls->in_completion && ls->completion_idx < lc.len) { + refreshLineWithCompletion(ls, &lc, REFRESH_ALL); + } else { + refreshLine(ls); + } + } + + return c; /* Return last read character */ +} + +/* Register a callback function to be called for tab-completion. */ +void linenoiseSetCompletionCallback(linenoiseCompletionCallback *fn) { + completionCallback = fn; +} + +/* Register a hits function to be called to show hits to the user at the + * right of the prompt. */ +void linenoiseSetHintsCallback(linenoiseHintsCallback *fn) { + hintsCallback = fn; +} + +/* Register a function to free the hints returned by the hints callback + * registered with linenoiseSetHintsCallback(). */ +void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *fn) { + freeHintsCallback = fn; +} + +/* This function is used by the callback function registered by the user + * in order to add completion options given the input string when the + * user typed . See the example.c source code for a very easy to + * understand example. */ +void linenoiseAddCompletion(linenoiseCompletions *lc, const char *str) { + const size_t len = strlen(str); + auto copy = std::make_unique(len + 1); + if (!copy) { + return; + } + + memcpy(copy.get(), str, len + 1); + char ** cvec = static_cast(std::realloc(lc->cvec, sizeof(char *) * (lc->len + 1))); + if (cvec == nullptr) { + return; + } + + lc->cvec = cvec; + lc->cvec[lc->len++] = copy.release(); +} + +/* Helper of refreshSingleLine() and refreshMultiLine() to show hints + * to the right of the prompt. */ +static void refreshShowHints(std::string & ab, struct linenoiseState * l, int plen) { + char seq[64]; + if (hintsCallback && plen+l->len < l->cols) { + int color = -1, bold = 0; + const char *hint = hintsCallback(l->buf,&color,&bold); + if (hint) { + int hintlen = strlen(hint); + int hintmaxlen = l->cols-(plen+l->len); + if (hintlen > hintmaxlen) hintlen = hintmaxlen; + if (bold == 1 && color == -1) color = 37; + if (color != -1 || bold != 0) + snprintf(seq,64,"\033[%d;%d;49m",bold,color); + else + seq[0] = '\0'; + ab.append(seq); + ab.append(hint, hintlen); + if (color != -1 || bold != 0) + ab.append("\033[0m"); + + /* Call the function to free the hint returned. */ + if (freeHintsCallback) freeHintsCallback(hint); + } + } +} + +/* Single line low level line refresh. + * + * Rewrite the currently edited line accordingly to the buffer content, + * cursor position, and number of columns of the terminal. + * + * Flags is REFRESH_* macros. The function can just remove the old + * prompt, just write it, or both. */ +static void refreshSingleLine(struct linenoiseState *l, int flags) { + char seq[64]; + size_t plen = strlen(l->prompt); + int fd = l->ofd; + char *buf = l->buf; + size_t len = l->len; + size_t pos = l->pos; + std::string ab; + while((plen+pos) >= l->cols) { + buf++; + len--; + pos--; + } + while (plen+len > l->cols) { + len--; + } + + /* Cursor to left edge */ + snprintf(seq,sizeof(seq),"\r"); + ab.append(seq); + + if (flags & REFRESH_WRITE) { + /* Write the prompt and the current buffer content */ + ab.append(l->prompt); + if (maskmode == 1) { + while (len--) { + ab.append("*"); + } + } else { + ab.append(buf, len); + } + /* Show hits if any. */ + refreshShowHints(ab, l, plen); + } + + /* Erase to right */ + snprintf(seq,sizeof(seq),"\x1b[0K"); + ab.append(seq); + if (flags & REFRESH_WRITE) { + /* Move cursor to original position. */ + snprintf(seq,sizeof(seq),"\r\x1b[%dC", (int)(pos+plen)); + ab.append(seq); + } + + (void) !write(fd, ab.c_str(), ab.size()); /* Can't recover from write error. */ +} + +/* Multi line low level line refresh. + * + * Rewrite the currently edited line accordingly to the buffer content, + * cursor position, and number of columns of the terminal. + * + * Flags is REFRESH_* macros. The function can just remove the old + * prompt, just write it, or both. */ +static void refreshMultiLine(struct linenoiseState *l, int flags) { + char seq[64]; + int plen = strlen(l->prompt); + int rows = (plen+l->len+l->cols-1)/l->cols; /* rows used by current buf. */ + int rpos = (plen+l->oldpos+l->cols)/l->cols; /* cursor relative row. */ + int rpos2; /* rpos after refresh. */ + int col; /* colum position, zero-based. */ + int old_rows = l->oldrows; + int fd = l->ofd, j; + std::string ab; + l->oldrows = rows; + + /* First step: clear all the lines used before. To do so start by + * going to the last row. */ + if (flags & REFRESH_CLEAN) { + if (old_rows-rpos > 0) { + lndebug("go down %d", old_rows-rpos); + snprintf(seq,64,"\x1b[%dB", old_rows-rpos); + ab.append(seq); + } + + /* Now for every row clear it, go up. */ + for (j = 0; j < old_rows-1; j++) { + lndebug("clear+up"); + snprintf(seq,64,"\r\x1b[0K\x1b[1A"); + ab.append(seq); + } + } + + if (flags & REFRESH_ALL) { + /* Clean the top line. */ + lndebug("clear"); + snprintf(seq,64,"\r\x1b[0K"); + ab.append(seq); + } + + if (flags & REFRESH_WRITE) { + /* Write the prompt and the current buffer content */ + ab.append(l->prompt); + if (maskmode == 1) { + for (unsigned int i = 0; i < l->len; ++i) { + ab.append("*"); + } + } else { + ab.append(l->buf, l->len); + } + + /* Show hits if any. */ + refreshShowHints(ab, l, plen); + + /* If we are at the very end of the screen with our prompt, we need to + * emit a newline and move the prompt to the first column. */ + if (l->pos && + l->pos == l->len && + (l->pos+plen) % l->cols == 0) + { + lndebug(""); + ab.append("\n"); + snprintf(seq,64,"\r"); + ab.append(seq); + rows++; + if (rows > (int)l->oldrows) l->oldrows = rows; + } + + /* Move cursor to right position. */ + rpos2 = (plen+l->pos+l->cols)/l->cols; /* Current cursor relative row */ + lndebug("rpos2 %d", rpos2); + + /* Go up till we reach the expected positon. */ + if (rows-rpos2 > 0) { + lndebug("go-up %d", rows-rpos2); + snprintf(seq,64,"\x1b[%dA", rows-rpos2); + ab.append(seq); + } + + /* Set column. */ + col = (plen+(int)l->pos) % (int)l->cols; + lndebug("set col %d", 1+col); + if (col) + snprintf(seq,64,"\r\x1b[%dC", col); + else + snprintf(seq,64,"\r"); + ab.append(seq); + } + + lndebug("\n"); + l->oldpos = l->pos; + (void) !write(fd, ab.c_str(), ab.size()); /* Can't recover from write error. */ +} + +/* Calls the two low level functions refreshSingleLine() or + * refreshMultiLine() according to the selected mode. */ +static void refreshLineWithFlags(struct linenoiseState *l, int flags) { + if (mlmode) + refreshMultiLine(l,flags); + else + refreshSingleLine(l,flags); +} + +/* Utility function to avoid specifying REFRESH_ALL all the times. */ +static void refreshLine(struct linenoiseState *l) { + refreshLineWithFlags(l,REFRESH_ALL); +} + +/* Hide the current line, when using the multiplexing API. */ +void linenoiseHide(struct linenoiseState *l) { + if (mlmode) + refreshMultiLine(l,REFRESH_CLEAN); + else + refreshSingleLine(l,REFRESH_CLEAN); +} + +/* Show the current line, when using the multiplexing API. */ +void linenoiseShow(struct linenoiseState *l) { + if (l->in_completion) { + refreshLineWithCompletion(l,NULL,REFRESH_WRITE); + } else { + refreshLineWithFlags(l,REFRESH_WRITE); + } +} + +/* Insert the character 'c' at cursor current position. + * + * On error writing to the terminal -1 is returned, otherwise 0. */ +static int linenoiseEditInsert(struct linenoiseState * l, char c) { + if (l->len < l->buflen) { + if (l->len == l->pos) { + l->buf[l->pos] = c; + l->pos++; + l->len++; + l->buf[l->len] = '\0'; + if ((!mlmode && l->plen+l->len < l->cols && !hintsCallback)) { + /* Avoid a full update of the line in the + * trivial case. */ + char d = (maskmode==1) ? '*' : c; + if (write(l->ofd,&d,1) == -1) return -1; + } else { + refreshLine(l); + } + } else { + memmove(l->buf+l->pos+1,l->buf+l->pos,l->len-l->pos); + l->buf[l->pos] = c; + l->len++; + l->pos++; + l->buf[l->len] = '\0'; + refreshLine(l); + } + } + return 0; +} + +/* Move cursor on the left. */ +static void linenoiseEditMoveLeft(struct linenoiseState * l) { + if (l->pos > 0) { + l->pos--; + refreshLine(l); + } +} + +/* Move cursor on the right. */ +static void linenoiseEditMoveRight(struct linenoiseState * l) { + if (l->pos != l->len) { + l->pos++; + refreshLine(l); + } +} + +/* Move cursor to the start of the line. */ +static void linenoiseEditMoveHome(struct linenoiseState * l) { + if (l->pos != 0) { + l->pos = 0; + refreshLine(l); + } +} + +/* Move cursor to the end of the line. */ +static void linenoiseEditMoveEnd(struct linenoiseState * l) { + if (l->pos != l->len) { + l->pos = l->len; + refreshLine(l); + } +} + +/* Substitute the currently edited line with the next or previous history + * entry as specified by 'dir'. */ +#define LINENOISE_HISTORY_NEXT 0 +#define LINENOISE_HISTORY_PREV 1 + +static void linenoiseEditHistoryNext(struct linenoiseState * l, int dir) { + if (history_len > 1) { + /* Update the current history entry before to + * overwrite it with the next one. */ + free(history[history_len - 1 - l->history_index]); + history[history_len - 1 - l->history_index] = strdup(l->buf); + /* Show the new entry */ + l->history_index += (dir == LINENOISE_HISTORY_PREV) ? 1 : -1; + if (l->history_index < 0) { + l->history_index = 0; + return; + } else if (l->history_index >= history_len) { + l->history_index = history_len-1; + return; + } + strncpy(l->buf,history[history_len - 1 - l->history_index],l->buflen); + l->buf[l->buflen-1] = '\0'; + l->len = l->pos = strlen(l->buf); + refreshLine(l); + } +} + +/* Delete the character at the right of the cursor without altering the cursor + * position. Basically this is what happens with the "Delete" keyboard key. */ +static void linenoiseEditDelete(struct linenoiseState * l) { + if (l->len > 0 && l->pos < l->len) { + memmove(l->buf+l->pos,l->buf+l->pos+1,l->len-l->pos-1); + l->len--; + l->buf[l->len] = '\0'; + refreshLine(l); + } +} + +/* Backspace implementation. */ +static void linenoiseEditBackspace(struct linenoiseState * l) { + if (l->pos > 0 && l->len > 0) { + memmove(l->buf+l->pos-1,l->buf+l->pos,l->len-l->pos); + l->pos--; + l->len--; + l->buf[l->len] = '\0'; + refreshLine(l); + } +} + +/* Delete the previosu word, maintaining the cursor at the start of the + * current word. */ +static void linenoiseEditDeletePrevWord(struct linenoiseState * l) { + size_t old_pos = l->pos; + size_t diff; + + while (l->pos > 0 && l->buf[l->pos-1] == ' ') + l->pos--; + while (l->pos > 0 && l->buf[l->pos-1] != ' ') + l->pos--; + diff = old_pos - l->pos; + memmove(l->buf+l->pos,l->buf+old_pos,l->len-old_pos+1); + l->len -= diff; + refreshLine(l); +} + +/* This function is part of the multiplexed API of Linenoise, that is used + * in order to implement the blocking variant of the API but can also be + * called by the user directly in an event driven program. It will: + * + * 1. Initialize the linenoise state passed by the user. + * 2. Put the terminal in RAW mode. + * 3. Show the prompt. + * 4. Return control to the user, that will have to call linenoiseEditFeed() + * each time there is some data arriving in the standard input. + * + * The user can also call linenoiseEditHide() and linenoiseEditShow() if it + * is required to show some input arriving asyncronously, without mixing + * it with the currently edited line. + * + * When linenoiseEditFeed() returns non-NULL, the user finished with the + * line editing session (pressed enter CTRL-D/C): in this case the caller + * needs to call linenoiseEditStop() to put back the terminal in normal + * mode. This will not destroy the buffer, as long as the linenoiseState + * is still valid in the context of the caller. + * + * The function returns 0 on success, or -1 if writing to standard output + * fails. If stdin_fd or stdout_fd are set to -1, the default is to use + * STDIN_FILENO and STDOUT_FILENO. + */ +int linenoiseEditStart(struct linenoiseState *l, int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt) { + /* Populate the linenoise state that we pass to functions implementing + * specific editing functionalities. */ + l->in_completion = 0; + l->ifd = stdin_fd != -1 ? stdin_fd : STDIN_FILENO; + l->ofd = stdout_fd != -1 ? stdout_fd : STDOUT_FILENO; + l->buf = buf; + l->buflen = buflen; + l->prompt = prompt; + l->plen = strlen(prompt); + l->oldpos = l->pos = 0; + l->len = 0; + + /* Enter raw mode. */ + if (enableRawMode(l->ifd) == -1) return -1; + + l->cols = getColumns(stdin_fd, stdout_fd); + l->oldrows = 0; + l->history_index = 0; + + /* Buffer starts empty. */ + l->buf[0] = '\0'; + l->buflen--; /* Make sure there is always space for the nulterm */ + + /* If stdin is not a tty, stop here with the initialization. We + * will actually just read a line from standard input in blocking + * mode later, in linenoiseEditFeed(). */ + if (!isatty(l->ifd)) return 0; + + /* The latest history entry is always our current buffer, that + * initially is just an empty string. */ + linenoiseHistoryAdd(""); + + if (write(l->ofd,prompt,l->plen) == -1) return -1; + return 0; +} + +const char* linenoiseEditMore = "If you see this, you are misusing the API: when linenoiseEditFeed() is called, if it returns linenoiseEditMore the user is yet editing the line. See the README file for more information."; + +/* This function is part of the multiplexed API of linenoise, see the top + * comment on linenoiseEditStart() for more information. Call this function + * each time there is some data to read from the standard input file + * descriptor. In the case of blocking operations, this function can just be + * called in a loop, and block. + * + * The function returns linenoiseEditMore to signal that line editing is still + * in progress, that is, the user didn't yet pressed enter / CTRL-D. Otherwise + * the function returns the pointer to the heap-allocated buffer with the + * edited line, that the user should free with linenoiseFree(). + * + * On special conditions, NULL is returned and errno is populated: + * + * EAGAIN if the user pressed Ctrl-C + * ENOENT if the user pressed Ctrl-D + * + * Some other errno: I/O error. + */ +const char *linenoiseEditFeed(struct linenoiseState *l) { + /* Not a TTY, pass control to line reading without character + * count limits. */ + if (!isatty(l->ifd)) return linenoiseNoTTY(); + + char c; + int nread; + char seq[3]; + + nread = read(l->ifd,&c,1); + if (nread <= 0) return NULL; + + /* Only autocomplete when the callback is set. It returns < 0 when + * there was an error reading from fd. Otherwise it will return the + * character that should be handled next. */ + if ((l->in_completion || c == 9) && completionCallback != NULL) { + c = completeLine(l,c); + /* Read next character when 0 */ + if (c == 0) return linenoiseEditMore; + } + + switch(c) { + case ENTER: /* enter */ + history_len--; + free(history[history_len]); + if (mlmode) linenoiseEditMoveEnd(l); + if (hintsCallback) { + /* Force a refresh without hints to leave the previous + * line as the user typed it after a newline. */ + linenoiseHintsCallback *hc = hintsCallback; + hintsCallback = NULL; + refreshLine(l); + hintsCallback = hc; + } + return strdup(l->buf); + case CTRL_C: /* ctrl-c */ + errno = EAGAIN; + return NULL; + case BACKSPACE: /* backspace */ + case 8: /* ctrl-h */ + linenoiseEditBackspace(l); + break; + case CTRL_D: /* ctrl-d, remove char at right of cursor, or if the + line is empty, act as end-of-file. */ + if (l->len > 0) { + linenoiseEditDelete(l); + } else { + history_len--; + free(history[history_len]); + errno = ENOENT; + return NULL; + } + break; + case CTRL_T: /* ctrl-t, swaps current character with previous. */ + if (l->pos > 0 && l->pos < l->len) { + int aux = l->buf[l->pos-1]; + l->buf[l->pos-1] = l->buf[l->pos]; + l->buf[l->pos] = aux; + if (l->pos != l->len-1) l->pos++; + refreshLine(l); + } + break; + case CTRL_B: /* ctrl-b */ + linenoiseEditMoveLeft(l); + break; + case CTRL_F: /* ctrl-f */ + linenoiseEditMoveRight(l); + break; + case CTRL_P: /* ctrl-p */ + linenoiseEditHistoryNext(l, LINENOISE_HISTORY_PREV); + break; + case CTRL_N: /* ctrl-n */ + linenoiseEditHistoryNext(l, LINENOISE_HISTORY_NEXT); + break; + case ESC: /* escape sequence */ + /* Read the next two bytes representing the escape sequence. + * Use two calls to handle slow terminals returning the two + * chars at different times. */ + if (read(l->ifd,seq,1) == -1) break; + if (read(l->ifd,seq+1,1) == -1) break; + + /* ESC [ sequences. */ + if (seq[0] == '[') { + if (seq[1] >= '0' && seq[1] <= '9') { + /* Extended escape, read additional byte. */ + if (read(l->ifd,seq+2,1) == -1) break; + if (seq[2] == '~') { + switch(seq[1]) { + case '3': /* Delete key. */ + linenoiseEditDelete(l); + break; + } + } + } else { + switch(seq[1]) { + case 'A': /* Up */ + linenoiseEditHistoryNext(l, LINENOISE_HISTORY_PREV); + break; + case 'B': /* Down */ + linenoiseEditHistoryNext(l, LINENOISE_HISTORY_NEXT); + break; + case 'C': /* Right */ + linenoiseEditMoveRight(l); + break; + case 'D': /* Left */ + linenoiseEditMoveLeft(l); + break; + case 'H': /* Home */ + linenoiseEditMoveHome(l); + break; + case 'F': /* End*/ + linenoiseEditMoveEnd(l); + break; + } + } + } + + /* ESC O sequences. */ + else if (seq[0] == 'O') { + switch(seq[1]) { + case 'H': /* Home */ + linenoiseEditMoveHome(l); + break; + case 'F': /* End*/ + linenoiseEditMoveEnd(l); + break; + } + } + break; + default: + if (linenoiseEditInsert(l,c)) return NULL; + break; + case CTRL_U: /* Ctrl+u, delete the whole line. */ + l->buf[0] = '\0'; + l->pos = l->len = 0; + refreshLine(l); + break; + case CTRL_K: /* Ctrl+k, delete from current to end of line. */ + l->buf[l->pos] = '\0'; + l->len = l->pos; + refreshLine(l); + break; + case CTRL_A: /* Ctrl+a, go to the start of the line */ + linenoiseEditMoveHome(l); + break; + case CTRL_E: /* ctrl+e, go to the end of the line */ + linenoiseEditMoveEnd(l); + break; + case CTRL_L: /* ctrl+l, clear screen */ + linenoiseClearScreen(); + refreshLine(l); + break; + case CTRL_W: /* ctrl+w, delete previous word */ + linenoiseEditDeletePrevWord(l); + break; + } + return linenoiseEditMore; +} + +/* This is part of the multiplexed linenoise API. See linenoiseEditStart() + * for more information. This function is called when linenoiseEditFeed() + * returns something different than NULL. At this point the user input + * is in the buffer, and we can restore the terminal in normal mode. */ +void linenoiseEditStop(struct linenoiseState *l) { + if (!isatty(l->ifd)) return; + disableRawMode(l->ifd); + printf("\n"); +} + +/* This just implements a blocking loop for the multiplexed API. + * In many applications that are not event-drivern, we can just call + * the blocking linenoise API, wait for the user to complete the editing + * and return the buffer. */ +static const char *linenoiseBlockingEdit(int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt) +{ + struct linenoiseState l; + + /* Editing without a buffer is invalid. */ + if (buflen == 0) { + errno = EINVAL; + return NULL; + } + + linenoiseEditStart(&l,stdin_fd,stdout_fd,buf,buflen,prompt); + const char *res; + while((res = linenoiseEditFeed(&l)) == linenoiseEditMore); + linenoiseEditStop(&l); + return res; +} + +/* This special mode is used by linenoise in order to print scan codes + * on screen for debugging / development purposes. It is implemented + * by the linenoise_example program using the --keycodes option. */ +void linenoisePrintKeyCodes(void) { + char quit[4]; + + printf("Linenoise key codes debugging mode.\n" + "Press keys to see scan codes. Type 'quit' at any time to exit.\n"); + if (enableRawMode(STDIN_FILENO) == -1) return; + memset(quit,' ',4); + while(1) { + char c; + int nread; + + nread = read(STDIN_FILENO,&c,1); + if (nread <= 0) continue; + memmove(quit,quit+1,sizeof(quit)-1); /* shift string to left. */ + quit[sizeof(quit)-1] = c; /* Insert current char on the right. */ + if (memcmp(quit,"quit",sizeof(quit)) == 0) break; + + printf("'%c' %02x (%d) (type quit to exit)\n", + isprint(c) ? c : '?', (int)c, (int)c); + printf("\r"); /* Go left edge manually, we are in raw mode. */ + fflush(stdout); + } + disableRawMode(STDIN_FILENO); +} + +/* This function is called when linenoise() is called with the standard + * input file descriptor not attached to a TTY. So for example when the + * program using linenoise is called in pipe or with a file redirected + * to its standard input. In this case, we want to be able to return the + * line regardless of its length (by default we are limited to 4k). */ +static char *linenoiseNoTTY(void) { + char *line = NULL; + size_t len = 0, maxlen = 0; + + while(1) { + if (len == maxlen) { + if (maxlen == 0) maxlen = 16; + maxlen *= 2; + char *oldval = line; + line = (char*) realloc(line,maxlen); + if (line == NULL) { + if (oldval) free(oldval); + return NULL; + } + } + int c = fgetc(stdin); + if (c == EOF || c == '\n') { + if (c == EOF && len == 0) { + free(line); + return NULL; + } else { + line[len] = '\0'; + return line; + } + } else { + line[len] = c; + len++; + } + } +} + +/* The high level function that is the main API of the linenoise library. + * This function checks if the terminal has basic capabilities, just checking + * for a blacklist of stupid terminals, and later either calls the line + * editing function or uses dummy fgets() so that you will be able to type + * something even in the most desperate of the conditions. */ +const char *linenoise(const char *prompt) { + char buf[LINENOISE_MAX_LINE]; + + if (!isatty(STDIN_FILENO)) { + /* Not a tty: read from file / pipe. In this mode we don't want any + * limit to the line size, so we call a function to handle that. */ + return linenoiseNoTTY(); + } else if (isUnsupportedTerm()) { + size_t len; + + printf("%s",prompt); + fflush(stdout); + if (fgets(buf,LINENOISE_MAX_LINE,stdin) == NULL) return NULL; + len = strlen(buf); + while(len && (buf[len-1] == '\n' || buf[len-1] == '\r')) { + len--; + buf[len] = '\0'; + } + return strdup(buf); + } else { + const char *retval = linenoiseBlockingEdit(STDIN_FILENO,STDOUT_FILENO,buf,LINENOISE_MAX_LINE,prompt); + return retval; + } +} + +/* This is just a wrapper the user may want to call in order to make sure + * the linenoise returned buffer is freed with the same allocator it was + * created with. Useful when the main program is using an alternative + * allocator. */ +void linenoiseFree(void *ptr) { + if (ptr == linenoiseEditMore) return; // Protect from API misuse. + free(ptr); +} + +/* ================================ History ================================= */ + +/* Free the history, but does not reset it. Only used when we have to + * exit() to avoid memory leaks are reported by valgrind & co. */ +static void freeHistory(void) { + if (history) { + int j; + + for (j = 0; j < history_len; j++) + free(history[j]); + free(history); + } +} + +/* At exit we'll try to fix the terminal to the initial conditions. */ +static void linenoiseAtExit(void) { + disableRawMode(STDIN_FILENO); + freeHistory(); +} + +/* This is the API call to add a new entry in the linenoise history. + * It uses a fixed array of char pointers that are shifted (memmoved) + * when the history max length is reached in order to remove the older + * entry and make room for the new one, so it is not exactly suitable for huge + * histories, but will work well for a few hundred of entries. + * + * Using a circular buffer is smarter, but a bit more complex to handle. */ +int linenoiseHistoryAdd(const char *line) { + char *linecopy; + + if (history_max_len == 0) return 0; + + /* Initialization on first call. */ + if (history == NULL) { + history = (char**) malloc(sizeof(char*)*history_max_len); + if (history == NULL) return 0; + memset(history,0,(sizeof(char*)*history_max_len)); + } + + /* Don't add duplicated lines. */ + if (history_len && !strcmp(history[history_len-1], line)) return 0; + + /* Add an heap allocated copy of the line in the history. + * If we reached the max length, remove the older line. */ + linecopy = strdup(line); + if (!linecopy) return 0; + if (history_len == history_max_len) { + free(history[0]); + memmove(history,history+1,sizeof(char*)*(history_max_len-1)); + history_len--; + } + history[history_len] = linecopy; + history_len++; + return 1; +} + +/* Set the maximum length for the history. This function can be called even + * if there is already some history, the function will make sure to retain + * just the latest 'len' elements if the new history length value is smaller + * than the amount of items already inside the history. */ +int linenoiseHistorySetMaxLen(int len) { + char **new_ptr; + + if (len < 1) return 0; + if (history) { + int tocopy = history_len; + + new_ptr = (char**) malloc(sizeof(char*)*len); + if (new_ptr == NULL) return 0; + + /* If we can't copy everything, free the elements we'll not use. */ + if (len < tocopy) { + int j; + + for (j = 0; j < tocopy-len; j++) free(history[j]); + tocopy = len; + } + memset(new_ptr,0,sizeof(char*)*len); + memcpy(new_ptr,history+(history_len-tocopy), sizeof(char*)*tocopy); + free(history); + history = new_ptr; + } + history_max_len = len; + if (history_len > history_max_len) + history_len = history_max_len; + return 1; +} + +/* Save the history in the specified file. On success 0 is returned + * otherwise -1 is returned. */ +int linenoiseHistorySave(const char *filename) { + mode_t old_umask = umask(S_IXUSR|S_IRWXG|S_IRWXO); + File file; + file.open(filename, "w"); + umask(old_umask); + if (file.file == NULL) { + return -1; + } + chmod(filename,S_IRUSR|S_IWUSR); + for (int j = 0; j < history_len; ++j) { + fprintf(file.file, "%s\n", history[j]); + } + + return 0; +} + +/* Load the history from the specified file. If the file does not exist + * zero is returned and no operation is performed. + * + * If the file exists and the operation succeeded 0 is returned, otherwise + * on error -1 is returned. */ +int linenoiseHistoryLoad(const char *filename) { + File file; + file.open(filename, "r"); + char buf[LINENOISE_MAX_LINE]; + if (file.file == NULL) { + return -1; + } + + while (fgets(buf, LINENOISE_MAX_LINE, file.file) != NULL) { + char *p; + + p = strchr(buf,'\r'); + if (!p) p = strchr(buf,'\n'); + if (p) *p = '\0'; + linenoiseHistoryAdd(buf); + } + return 0; +} +#endif diff --git a/examples/run/linenoise.cpp/linenoise.h b/examples/run/linenoise.cpp/linenoise.h new file mode 100644 index 000000000..a14ec6c74 --- /dev/null +++ b/examples/run/linenoise.cpp/linenoise.h @@ -0,0 +1,128 @@ +/* linenoise.h -- VERSION 1.0 + * + * Guerrilla line editing library against the idea that a line editing lib + * needs to be 20,000 lines of C++ code. + * + * See linenoise.cpp for more information. + * + * ------------------------------------------------------------------------ + * + * Copyright (c) 2010-2023, Salvatore Sanfilippo + * Copyright (c) 2010-2013, Pieter Noordhuis + * Copyright (c) 2025, Eric Curtin + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef __LINENOISE_H +#define __LINENOISE_H + +#ifdef __cplusplus +extern "C" { +#endif + +#include /* For size_t. */ +#include + +extern const char *linenoiseEditMore; + +/* The linenoiseState structure represents the state during line editing. + * We pass this state to functions implementing specific editing + * functionalities. */ +struct linenoiseState { + int in_completion; /* The user pressed TAB and we are now in completion + * mode, so input is handled by completeLine(). */ + size_t completion_idx; /* Index of next completion to propose. */ + int ifd; /* Terminal stdin file descriptor. */ + int ofd; /* Terminal stdout file descriptor. */ + char *buf; /* Edited line buffer. */ + size_t buflen; /* Edited line buffer size. */ + const char *prompt; /* Prompt to display. */ + size_t plen; /* Prompt length. */ + size_t pos; /* Current cursor position. */ + size_t oldpos; /* Previous refresh cursor position. */ + size_t len; /* Current edited line length. */ + size_t cols; /* Number of columns in terminal. */ + size_t oldrows; /* Rows used by last refrehsed line (multiline mode) */ + int history_index; /* The history index we are currently editing. */ +}; + +struct linenoiseCompletions { + size_t len = 0; + char ** cvec = nullptr; + bool to_free = true; + + ~linenoiseCompletions() { + if (!to_free) { + return; + } + + for (size_t i = 0; i < len; ++i) { + free(cvec[i]); + } + + free(cvec); + } +}; + +/* Non blocking API. */ +int linenoiseEditStart(struct linenoiseState *l, int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt); +const char *linenoiseEditFeed(struct linenoiseState *l); +void linenoiseEditStop(struct linenoiseState *l); +void linenoiseHide(struct linenoiseState *l); +void linenoiseShow(struct linenoiseState *l); + +/* Blocking API. */ +const char *linenoise(const char *prompt); +void linenoiseFree(void *ptr); + +/* Completion API. */ +typedef void(linenoiseCompletionCallback)(const char *, linenoiseCompletions *); +typedef const char*(linenoiseHintsCallback)(const char *, int *color, int *bold); +typedef void(linenoiseFreeHintsCallback)(const char *); +void linenoiseSetCompletionCallback(linenoiseCompletionCallback *); +void linenoiseSetHintsCallback(linenoiseHintsCallback *); +void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *); +void linenoiseAddCompletion(linenoiseCompletions *, const char *); + +/* History API. */ +int linenoiseHistoryAdd(const char *line); +int linenoiseHistorySetMaxLen(int len); +int linenoiseHistorySave(const char *filename); +int linenoiseHistoryLoad(const char *filename); + +/* Other utilities. */ +void linenoiseClearScreen(void); +void linenoiseSetMultiLine(int ml); +void linenoisePrintKeyCodes(void); +void linenoiseMaskModeEnable(void); +void linenoiseMaskModeDisable(void); + +#ifdef __cplusplus +} +#endif + +#endif /* __LINENOISE_H */ diff --git a/examples/run/run.cpp b/examples/run/run.cpp new file mode 100644 index 000000000..9cecae48c --- /dev/null +++ b/examples/run/run.cpp @@ -0,0 +1,1128 @@ +#if defined(_WIN32) +# include +# include +#else +# include +# include +# include +#endif + +#if defined(LLAMA_USE_CURL) +# include +#endif + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common.h" +#include "json.hpp" +#include "linenoise.cpp/linenoise.h" +#include "llama-cpp.h" +#include "chat-template.hpp" + +#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32) +[[noreturn]] static void sigint_handler(int) { + printf("\n\033[0m"); + exit(0); // not ideal, but it's the only way to guarantee exit in all cases +} +#endif + +GGML_ATTRIBUTE_FORMAT(1, 2) +static std::string fmt(const char * fmt, ...) { + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + const int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::string buf; + buf.resize(size); + const int size2 = vsnprintf(const_cast(buf.data()), buf.size() + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + + return buf; +} + +GGML_ATTRIBUTE_FORMAT(1, 2) +static int printe(const char * fmt, ...) { + va_list args; + va_start(args, fmt); + const int ret = vfprintf(stderr, fmt, args); + va_end(args); + + return ret; +} + +class Opt { + public: + int init(int argc, const char ** argv) { + ctx_params = llama_context_default_params(); + model_params = llama_model_default_params(); + context_size_default = ctx_params.n_batch; + ngl_default = model_params.n_gpu_layers; + common_params_sampling sampling; + temperature_default = sampling.temp; + + if (argc < 2) { + printe("Error: No arguments provided.\n"); + print_help(); + return 1; + } + + // Parse arguments + if (parse(argc, argv)) { + printe("Error: Failed to parse arguments.\n"); + print_help(); + return 1; + } + + // If help is requested, show help and exit + if (help) { + print_help(); + return 2; + } + + ctx_params.n_batch = context_size >= 0 ? context_size : context_size_default; + ctx_params.n_ctx = ctx_params.n_batch; + model_params.n_gpu_layers = ngl >= 0 ? ngl : ngl_default; + temperature = temperature >= 0 ? temperature : temperature_default; + + return 0; // Success + } + + llama_context_params ctx_params; + llama_model_params model_params; + std::string model_; + std::string user; + bool use_jinja = false; + int context_size = -1, ngl = -1; + float temperature = -1; + bool verbose = false; + + private: + int context_size_default = -1, ngl_default = -1; + float temperature_default = -1; + bool help = false; + + bool parse_flag(const char ** argv, int i, const char * short_opt, const char * long_opt) { + return strcmp(argv[i], short_opt) == 0 || strcmp(argv[i], long_opt) == 0; + } + + int handle_option_with_value(int argc, const char ** argv, int & i, int & option_value) { + if (i + 1 >= argc) { + return 1; + } + + option_value = std::atoi(argv[++i]); + + return 0; + } + + int handle_option_with_value(int argc, const char ** argv, int & i, float & option_value) { + if (i + 1 >= argc) { + return 1; + } + + option_value = std::atof(argv[++i]); + + return 0; + } + + int parse(int argc, const char ** argv) { + bool options_parsing = true; + for (int i = 1, positional_args_i = 0; i < argc; ++i) { + if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) { + if (handle_option_with_value(argc, argv, i, context_size) == 1) { + return 1; + } + } else if (options_parsing && + (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "-ngl") == 0 || strcmp(argv[i], "--ngl") == 0)) { + if (handle_option_with_value(argc, argv, i, ngl) == 1) { + return 1; + } + } else if (options_parsing && strcmp(argv[i], "--temp") == 0) { + if (handle_option_with_value(argc, argv, i, temperature) == 1) { + return 1; + } + } else if (options_parsing && + (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) { + verbose = true; + } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) { + use_jinja = true; + } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) { + help = true; + return 0; + } else if (options_parsing && strcmp(argv[i], "--") == 0) { + options_parsing = false; + } else if (positional_args_i == 0) { + if (!argv[i][0] || argv[i][0] == '-') { + return 1; + } + + ++positional_args_i; + model_ = argv[i]; + } else if (positional_args_i == 1) { + ++positional_args_i; + user = argv[i]; + } else { + user += " " + std::string(argv[i]); + } + } + + if (model_.empty()){ + return 1; + } + + return 0; + } + + void print_help() const { + printf( + "Description:\n" + " Runs a llm\n" + "\n" + "Usage:\n" + " llama-run [options] model [prompt]\n" + "\n" + "Options:\n" + " -c, --context-size \n" + " Context size (default: %d)\n" + " -n, -ngl, --ngl \n" + " Number of GPU layers (default: %d)\n" + " --temp \n" + " Temperature (default: %.1f)\n" + " -v, --verbose, --log-verbose\n" + " Set verbosity level to infinity (i.e. log all messages, useful for debugging)\n" + " -h, --help\n" + " Show help message\n" + "\n" + "Commands:\n" + " model\n" + " Model is a string with an optional prefix of \n" + " huggingface:// (hf://), ollama://, https:// or file://.\n" + " If no protocol is specified and a file exists in the specified\n" + " path, file:// is assumed, otherwise if a file does not exist in\n" + " the specified path, ollama:// is assumed. Models that are being\n" + " pulled are downloaded with .partial extension while being\n" + " downloaded and then renamed as the file without the .partial\n" + " extension when complete.\n" + "\n" + "Examples:\n" + " llama-run llama3\n" + " llama-run ollama://granite-code\n" + " llama-run ollama://smollm:135m\n" + " llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf\n" + " llama-run " + "huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf\n" + " llama-run https://example.com/some-file1.gguf\n" + " llama-run some-file2.gguf\n" + " llama-run file://some-file3.gguf\n" + " llama-run --ngl 999 some-file4.gguf\n" + " llama-run --ngl 999 some-file5.gguf Hello World\n", + context_size_default, ngl_default, temperature_default); + } +}; + +struct progress_data { + size_t file_size = 0; + std::chrono::steady_clock::time_point start_time = std::chrono::steady_clock::now(); + bool printed = false; +}; + +static int get_terminal_width() { +#if defined(_WIN32) + CONSOLE_SCREEN_BUFFER_INFO csbi; + GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi); + return csbi.srWindow.Right - csbi.srWindow.Left + 1; +#else + struct winsize w; + ioctl(STDOUT_FILENO, TIOCGWINSZ, &w); + return w.ws_col; +#endif +} + +#ifdef LLAMA_USE_CURL +class File { + public: + FILE * file = nullptr; + + FILE * open(const std::string & filename, const char * mode) { + file = fopen(filename.c_str(), mode); + + return file; + } + + int lock() { + if (file) { +# ifdef _WIN32 + fd = _fileno(file); + hFile = (HANDLE) _get_osfhandle(fd); + if (hFile == INVALID_HANDLE_VALUE) { + fd = -1; + + return 1; + } + + OVERLAPPED overlapped = {}; + if (!LockFileEx(hFile, LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY, 0, MAXDWORD, MAXDWORD, + &overlapped)) { + fd = -1; + + return 1; + } +# else + fd = fileno(file); + if (flock(fd, LOCK_EX | LOCK_NB) != 0) { + fd = -1; + + return 1; + } +# endif + } + + return 0; + } + + ~File() { + if (fd >= 0) { +# ifdef _WIN32 + if (hFile != INVALID_HANDLE_VALUE) { + OVERLAPPED overlapped = {}; + UnlockFileEx(hFile, 0, MAXDWORD, MAXDWORD, &overlapped); + } +# else + flock(fd, LOCK_UN); +# endif + } + + if (file) { + fclose(file); + } + } + + private: + int fd = -1; +# ifdef _WIN32 + HANDLE hFile = nullptr; +# endif +}; + +class HttpClient { + public: + int init(const std::string & url, const std::vector & headers, const std::string & output_file, + const bool progress, std::string * response_str = nullptr) { + if (std::filesystem::exists(output_file)) { + return 0; + } + + std::string output_file_partial; + curl = curl_easy_init(); + if (!curl) { + return 1; + } + + progress_data data; + File out; + if (!output_file.empty()) { + output_file_partial = output_file + ".partial"; + if (!out.open(output_file_partial, "ab")) { + printe("Failed to open file\n"); + + return 1; + } + + if (out.lock()) { + printe("Failed to exclusively lock file\n"); + + return 1; + } + } + + set_write_options(response_str, out); + data.file_size = set_resume_point(output_file_partial); + set_progress_options(progress, data); + set_headers(headers); + CURLcode res = perform(url); + if (res != CURLE_OK){ + printe("Fetching resource '%s' failed: %s\n", url.c_str(), curl_easy_strerror(res)); + return 1; + } + if (!output_file.empty()) { + std::filesystem::rename(output_file_partial, output_file); + } + + return 0; + } + + ~HttpClient() { + if (chunk) { + curl_slist_free_all(chunk); + } + + if (curl) { + curl_easy_cleanup(curl); + } + } + + private: + CURL * curl = nullptr; + struct curl_slist * chunk = nullptr; + + void set_write_options(std::string * response_str, const File & out) { + if (response_str) { + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, capture_data); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, response_str); + } else { + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, out.file); + } + } + + size_t set_resume_point(const std::string & output_file) { + size_t file_size = 0; + if (std::filesystem::exists(output_file)) { + file_size = std::filesystem::file_size(output_file); + curl_easy_setopt(curl, CURLOPT_RESUME_FROM_LARGE, static_cast(file_size)); + } + + return file_size; + } + + void set_progress_options(bool progress, progress_data & data) { + if (progress) { + curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); + curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &data); + curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, update_progress); + } + } + + void set_headers(const std::vector & headers) { + if (!headers.empty()) { + if (chunk) { + curl_slist_free_all(chunk); + chunk = 0; + } + + for (const auto & header : headers) { + chunk = curl_slist_append(chunk, header.c_str()); + } + + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, chunk); + } + } + + CURLcode perform(const std::string & url) { + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); + curl_easy_setopt(curl, CURLOPT_DEFAULT_PROTOCOL, "https"); + curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L); + return curl_easy_perform(curl); + } + + static std::string human_readable_time(double seconds) { + int hrs = static_cast(seconds) / 3600; + int mins = (static_cast(seconds) % 3600) / 60; + int secs = static_cast(seconds) % 60; + + if (hrs > 0) { + return fmt("%dh %02dm %02ds", hrs, mins, secs); + } else if (mins > 0) { + return fmt("%dm %02ds", mins, secs); + } else { + return fmt("%ds", secs); + } + } + + static std::string human_readable_size(curl_off_t size) { + static const char * suffix[] = { "B", "KB", "MB", "GB", "TB" }; + char length = sizeof(suffix) / sizeof(suffix[0]); + int i = 0; + double dbl_size = size; + if (size > 1024) { + for (i = 0; (size / 1024) > 0 && i < length - 1; i++, size /= 1024) { + dbl_size = size / 1024.0; + } + } + + return fmt("%.2f %s", dbl_size, suffix[i]); + } + + static int update_progress(void * ptr, curl_off_t total_to_download, curl_off_t now_downloaded, curl_off_t, + curl_off_t) { + progress_data * data = static_cast(ptr); + if (total_to_download <= 0) { + return 0; + } + + total_to_download += data->file_size; + const curl_off_t now_downloaded_plus_file_size = now_downloaded + data->file_size; + const curl_off_t percentage = calculate_percentage(now_downloaded_plus_file_size, total_to_download); + std::string progress_prefix = generate_progress_prefix(percentage); + + const double speed = calculate_speed(now_downloaded, data->start_time); + const double tim = (total_to_download - now_downloaded) / speed; + std::string progress_suffix = + generate_progress_suffix(now_downloaded_plus_file_size, total_to_download, speed, tim); + + int progress_bar_width = calculate_progress_bar_width(progress_prefix, progress_suffix); + std::string progress_bar; + generate_progress_bar(progress_bar_width, percentage, progress_bar); + + print_progress(progress_prefix, progress_bar, progress_suffix); + data->printed = true; + + return 0; + } + + static curl_off_t calculate_percentage(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download) { + return (now_downloaded_plus_file_size * 100) / total_to_download; + } + + static std::string generate_progress_prefix(curl_off_t percentage) { return fmt("%3ld%% |", static_cast(percentage)); } + + static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) { + const auto now = std::chrono::steady_clock::now(); + const std::chrono::duration elapsed_seconds = now - start_time; + return now_downloaded / elapsed_seconds.count(); + } + + static std::string generate_progress_suffix(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download, + double speed, double estimated_time) { + const int width = 10; + return fmt("%*s/%*s%*s/s%*s", width, human_readable_size(now_downloaded_plus_file_size).c_str(), width, + human_readable_size(total_to_download).c_str(), width, human_readable_size(speed).c_str(), width, + human_readable_time(estimated_time).c_str()); + } + + static int calculate_progress_bar_width(const std::string & progress_prefix, const std::string & progress_suffix) { + int progress_bar_width = get_terminal_width() - progress_prefix.size() - progress_suffix.size() - 3; + if (progress_bar_width < 1) { + progress_bar_width = 1; + } + + return progress_bar_width; + } + + static std::string generate_progress_bar(int progress_bar_width, curl_off_t percentage, + std::string & progress_bar) { + const curl_off_t pos = (percentage * progress_bar_width) / 100; + for (int i = 0; i < progress_bar_width; ++i) { + progress_bar.append((i < pos) ? "█" : " "); + } + + return progress_bar; + } + + static void print_progress(const std::string & progress_prefix, const std::string & progress_bar, + const std::string & progress_suffix) { + printe("\r%*s\r%s%s| %s", get_terminal_width(), " ", progress_prefix.c_str(), progress_bar.c_str(), + progress_suffix.c_str()); + } + // Function to write data to a file + static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) { + FILE * out = static_cast(stream); + return fwrite(ptr, size, nmemb, out); + } + + // Function to capture data into a string + static size_t capture_data(void * ptr, size_t size, size_t nmemb, void * stream) { + std::string * str = static_cast(stream); + str->append(static_cast(ptr), size * nmemb); + return size * nmemb; + } +}; +#endif + +class LlamaData { + public: + llama_model_ptr model; + llama_sampler_ptr sampler; + llama_context_ptr context; + std::vector messages; + std::list msg_strs; + std::vector fmtted; + + int init(Opt & opt) { + model = initialize_model(opt); + if (!model) { + return 1; + } + + context = initialize_context(model, opt); + if (!context) { + return 1; + } + + sampler = initialize_sampler(opt); + + return 0; + } + + private: +#ifdef LLAMA_USE_CURL + int download(const std::string & url, const std::string & output_file, const bool progress, + const std::vector & headers = {}, std::string * response_str = nullptr) { + HttpClient http; + if (http.init(url, headers, output_file, progress, response_str)) { + return 1; + } + + return 0; + } +#else + int download(const std::string &, const std::string &, const bool, const std::vector & = {}, + std::string * = nullptr) { + printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); + + return 1; + } +#endif + + // Helper function to handle model tag extraction and URL construction + std::pair extract_model_and_tag(std::string & model, const std::string & base_url) { + std::string model_tag = "latest"; + const size_t colon_pos = model.find(':'); + if (colon_pos != std::string::npos) { + model_tag = model.substr(colon_pos + 1); + model = model.substr(0, colon_pos); + } + + std::string url = base_url + model + "/manifests/" + model_tag; + + return { model, url }; + } + + // Helper function to download and parse the manifest + int download_and_parse_manifest(const std::string & url, const std::vector & headers, + nlohmann::json & manifest) { + std::string manifest_str; + int ret = download(url, "", false, headers, &manifest_str); + if (ret) { + return ret; + } + + manifest = nlohmann::json::parse(manifest_str); + + return 0; + } + + int huggingface_dl(std::string & model, const std::string & bn) { + // Find the second occurrence of '/' after protocol string + size_t pos = model.find('/'); + pos = model.find('/', pos + 1); + std::string hfr, hff; + std::vector headers = { "User-Agent: llama-cpp", "Accept: application/json" }; + std::string url; + + if (pos == std::string::npos) { + auto [model_name, manifest_url] = extract_model_and_tag(model, "https://huggingface.co/v2/"); + hfr = model_name; + + nlohmann::json manifest; + int ret = download_and_parse_manifest(manifest_url, headers, manifest); + if (ret) { + return ret; + } + + hff = manifest["ggufFile"]["rfilename"]; + } else { + hfr = model.substr(0, pos); + hff = model.substr(pos + 1); + } + + url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff; + + return download(url, bn, true, headers); + } + + int ollama_dl(std::string & model, const std::string & bn) { + const std::vector headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" }; + if (model.find('/') == std::string::npos) { + model = "library/" + model; + } + + auto [model_name, manifest_url] = extract_model_and_tag(model, "https://registry.ollama.ai/v2/"); + nlohmann::json manifest; + int ret = download_and_parse_manifest(manifest_url, {}, manifest); + if (ret) { + return ret; + } + + std::string layer; + for (const auto & l : manifest["layers"]) { + if (l["mediaType"] == "application/vnd.ollama.image.model") { + layer = l["digest"]; + break; + } + } + + std::string blob_url = "https://registry.ollama.ai/v2/" + model_name + "/blobs/" + layer; + + return download(blob_url, bn, true, headers); + } + + int github_dl(const std::string & model, const std::string & bn) { + std::string repository = model; + std::string branch = "main"; + const size_t at_pos = model.find('@'); + if (at_pos != std::string::npos) { + repository = model.substr(0, at_pos); + branch = model.substr(at_pos + 1); + } + + const std::vector repo_parts = string_split(repository, "/"); + if (repo_parts.size() < 3) { + printe("Invalid GitHub repository format\n"); + return 1; + } + + const std::string & org = repo_parts[0]; + const std::string & project = repo_parts[1]; + std::string url = "https://raw.githubusercontent.com/" + org + "/" + project + "/" + branch; + for (size_t i = 2; i < repo_parts.size(); ++i) { + url += "/" + repo_parts[i]; + } + + return download(url, bn, true); + } + + std::string basename(const std::string & path) { + const size_t pos = path.find_last_of("/\\"); + if (pos == std::string::npos) { + return path; + } + + return path.substr(pos + 1); + } + + int rm_until_substring(std::string & model_, const std::string & substring) { + const std::string::size_type pos = model_.find(substring); + if (pos == std::string::npos) { + return 1; + } + + model_ = model_.substr(pos + substring.size()); // Skip past the substring + return 0; + } + + int resolve_model(std::string & model_) { + int ret = 0; + if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) { + rm_until_substring(model_, "://"); + + return ret; + } + + const std::string bn = basename(model_); + if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://") || + string_starts_with(model_, "hf.co/")) { + rm_until_substring(model_, "hf.co/"); + rm_until_substring(model_, "://"); + ret = huggingface_dl(model_, bn); + } else if ((string_starts_with(model_, "https://") || string_starts_with(model_, "http://")) && + !string_starts_with(model_, "https://ollama.com/library/")) { + ret = download(model_, bn, true); + } else if (string_starts_with(model_, "github:") || string_starts_with(model_, "github://")) { + rm_until_substring(model_, "github:"); + rm_until_substring(model_, "://"); + ret = github_dl(model_, bn); + } else { // ollama:// or nothing + rm_until_substring(model_, "ollama.com/library/"); + rm_until_substring(model_, "://"); + ret = ollama_dl(model_, bn); + } + + model_ = bn; + + return ret; + } + + // Initializes the model and returns a unique pointer to it + llama_model_ptr initialize_model(Opt & opt) { + ggml_backend_load_all(); + resolve_model(opt.model_); + printe( + "\r%*s" + "\rLoading model", + get_terminal_width(), " "); + llama_model_ptr model(llama_model_load_from_file(opt.model_.c_str(), opt.model_params)); + if (!model) { + printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str()); + } + + printe("\r%*s\r", static_cast(sizeof("Loading model")), " "); + return model; + } + + // Initializes the context with the specified parameters + llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) { + llama_context_ptr context(llama_init_from_model(model.get(), opt.ctx_params)); + if (!context) { + printe("%s: error: failed to create the llama_context\n", __func__); + } + + return context; + } + + // Initializes and configures the sampler + llama_sampler_ptr initialize_sampler(const Opt & opt) { + llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params())); + llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1)); + llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(opt.temperature)); + llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); + + return sampler; + } +}; + +// Add a message to `messages` and store its content in `msg_strs` +static void add_message(const char * role, const std::string & text, LlamaData & llama_data) { + llama_data.msg_strs.push_back(std::move(text)); + llama_data.messages.push_back({ role, llama_data.msg_strs.back().c_str() }); +} + +// Function to apply the chat template and resize `formatted` if needed +static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) { + if (use_jinja) { + json messages = json::array(); + for (const auto & msg : llama_data.messages) { + messages.push_back({ + {"role", msg.role}, + {"content", msg.content}, + }); + } + try { + auto result = tmpl.apply(messages, /* tools= */ json(), append); + llama_data.fmtted.resize(result.size() + 1); + memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1); + return result.size(); + } catch (const std::exception & e) { + printe("failed to render the chat template: %s\n", e.what()); + return -1; + } + } + int result = llama_chat_apply_template( + tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append, + append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0); + if (append && result > static_cast(llama_data.fmtted.size())) { + llama_data.fmtted.resize(result); + result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(), + llama_data.messages.size(), append, llama_data.fmtted.data(), + llama_data.fmtted.size()); + } + + return result; +} + +// Function to tokenize the prompt +static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt, + std::vector & prompt_tokens, const LlamaData & llama_data) { + const bool is_first = llama_get_kv_cache_used_cells(llama_data.context.get()) == 0; + + const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); + prompt_tokens.resize(n_prompt_tokens); + if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, + true) < 0) { + printe("failed to tokenize the prompt\n"); + return -1; + } + + return n_prompt_tokens; +} + +// Check if we have enough space in the context to evaluate this batch +static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) { + const int n_ctx = llama_n_ctx(ctx.get()); + const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get()); + if (n_ctx_used + batch.n_tokens > n_ctx) { + printf("\033[0m\n"); + printe("context size exceeded\n"); + return 1; + } + + return 0; +} + +// convert the token to a string +static int convert_token_to_string(const llama_vocab * vocab, const llama_token token_id, std::string & piece) { + char buf[256]; + int n = llama_token_to_piece(vocab, token_id, buf, sizeof(buf), 0, true); + if (n < 0) { + printe("failed to convert token to piece\n"); + return 1; + } + + piece = std::string(buf, n); + return 0; +} + +static void print_word_and_concatenate_to_response(const std::string & piece, std::string & response) { + printf("%s", piece.c_str()); + fflush(stdout); + response += piece; +} + +// helper function to evaluate a prompt and generate a response +static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) { + const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get()); + + std::vector tokens; + if (tokenize_prompt(vocab, prompt, tokens, llama_data) < 0) { + return 1; + } + + // prepare a batch for the prompt + llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size()); + llama_token new_token_id; + while (true) { + check_context_size(llama_data.context, batch); + if (llama_decode(llama_data.context.get(), batch)) { + printe("failed to decode\n"); + return 1; + } + + // sample the next token, check is it an end of generation? + new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1); + if (llama_vocab_is_eog(vocab, new_token_id)) { + break; + } + + std::string piece; + if (convert_token_to_string(vocab, new_token_id, piece)) { + return 1; + } + + print_word_and_concatenate_to_response(piece, response); + + // prepare the next batch with the sampled token + batch = llama_batch_get_one(&new_token_id, 1); + } + + printf("\033[0m"); + return 0; +} + +static int read_user_input(std::string & user_input) { + static const char * prompt_prefix = "> "; +#ifdef WIN32 + printf( + "\r%*s" + "\r\033[0m%s", + get_terminal_width(), " ", prompt_prefix); + + std::getline(std::cin, user_input); + if (std::cin.eof()) { + printf("\n"); + return 1; + } +#else + std::unique_ptr line(const_cast(linenoise(prompt_prefix)), free); + if (!line) { + return 1; + } + + user_input = line.get(); +#endif + + if (user_input == "/bye") { + return 1; + } + + if (user_input.empty()) { + return 2; + } + +#ifndef WIN32 + linenoiseHistoryAdd(line.get()); +#endif + + return 0; // Should have data in happy path +} + +// Function to generate a response based on the prompt +static int generate_response(LlamaData & llama_data, const std::string & prompt, std::string & response, + const bool stdout_a_terminal) { + // Set response color + if (stdout_a_terminal) { + printf("\033[33m"); + } + + if (generate(llama_data, prompt, response)) { + printe("failed to generate response\n"); + return 1; + } + + // End response with color reset and newline + printf("\n%s", stdout_a_terminal ? "\033[0m" : ""); + return 0; +} + +// Helper function to apply the chat template and handle errors +static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { + const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja); + if (new_len < 0) { + printe("failed to apply the chat template\n"); + return -1; + } + + output_length = new_len; + return 0; +} + +// Helper function to handle user input +static int handle_user_input(std::string & user_input, const std::string & user) { + if (!user.empty()) { + user_input = user; + return 0; // No need for interactive input + } + + return read_user_input(user_input); // Returns true if input ends the loop +} + +static bool is_stdin_a_terminal() { +#if defined(_WIN32) + HANDLE hStdin = GetStdHandle(STD_INPUT_HANDLE); + DWORD mode; + return GetConsoleMode(hStdin, &mode); +#else + return isatty(STDIN_FILENO); +#endif +} + +static bool is_stdout_a_terminal() { +#if defined(_WIN32) + HANDLE hStdout = GetStdHandle(STD_OUTPUT_HANDLE); + DWORD mode; + return GetConsoleMode(hStdout, &mode); +#else + return isatty(STDOUT_FILENO); +#endif +} + +// Function to handle user input +static int get_user_input(std::string & user_input, const std::string & user) { + while (true) { + const int ret = handle_user_input(user_input, user); + if (ret == 1) { + return 1; + } + + if (ret == 2) { + continue; + } + + break; + } + + return 0; +} + +// Main chat loop function +static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) { + int prev_len = 0; + llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); + auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), ""); + GGML_ASSERT(chat_templates.template_default); + static const bool stdout_a_terminal = is_stdout_a_terminal(); + while (true) { + // Get user input + std::string user_input; + if (get_user_input(user_input, user) == 1) { + return 0; + } + + add_message("user", user.empty() ? user_input : user, llama_data); + int new_len; + if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) { + return 1; + } + + std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len); + std::string response; + if (generate_response(llama_data, prompt, response, stdout_a_terminal)) { + return 1; + } + + if (!user.empty()) { + break; + } + + add_message("assistant", response, llama_data); + if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) { + return 1; + } + } + + return 0; +} + +static void log_callback(const enum ggml_log_level level, const char * text, void * p) { + const Opt * opt = static_cast(p); + if (opt->verbose || level == GGML_LOG_LEVEL_ERROR) { + printe("%s", text); + } +} + +static std::string read_pipe_data() { + std::ostringstream result; + result << std::cin.rdbuf(); // Read all data from std::cin + return result.str(); +} + +static void ctrl_c_handling() { +#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) + struct sigaction sigint_action; + sigint_action.sa_handler = sigint_handler; + sigemptyset(&sigint_action.sa_mask); + sigint_action.sa_flags = 0; + sigaction(SIGINT, &sigint_action, NULL); +#elif defined(_WIN32) + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { + return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false; + }; + SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); +#endif +} + +int main(int argc, const char ** argv) { + ctrl_c_handling(); + Opt opt; + const int ret = opt.init(argc, argv); + if (ret == 2) { + return 0; + } else if (ret) { + return 1; + } + + if (!is_stdin_a_terminal()) { + if (!opt.user.empty()) { + opt.user += "\n\n"; + } + + opt.user += read_pipe_data(); + } + + llama_log_set(log_callback, &opt); + LlamaData llama_data; + if (llama_data.init(opt)) { + return 1; + } + + if (chat_loop(llama_data, opt.user, opt.use_jinja)) { + return 1; + } + + return 0; +} diff --git a/examples/save-load-state/CMakeLists.txt b/examples/save-load-state/CMakeLists.txt index 0fb5e359b..0f50e50de 100644 --- a/examples/save-load-state/CMakeLists.txt +++ b/examples/save-load-state/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-save-load-state) add_executable(${TARGET} save-load-state.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index b54ec3bd8..cf7cbd815 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -1,3 +1,4 @@ +#include "arg.h" #include "common.h" #include "llama.h" @@ -5,13 +6,12 @@ #include int main(int argc, char ** argv) { - gpt_params params; + common_params params; params.prompt = "The quick brown fox"; - params.sparams.seed = 1234; + params.sampling.seed = 1234; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { return 1; } @@ -28,10 +28,10 @@ int main(int argc, char ** argv) { std::string result2; // init - llama_init_result llama_init = llama_init_from_gpt_params(params); + common_init_result llama_init = common_init_from_params(params); - llama_model * model = llama_init.model; - llama_context * ctx = llama_init.context; + llama_model * model = llama_init.model.get(); + llama_context * ctx = llama_init.context.get(); if (model == nullptr || ctx == nullptr) { fprintf(stderr, "%s : failed to init\n", __func__); @@ -42,15 +42,21 @@ int main(int argc, char ** argv) { llama_sampler * smpl = llama_sampler_chain_init(sparams); - llama_sampler_chain_add(smpl, llama_sampler_init_softmax()); - llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sparams.seed)); + llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sampling.seed)); // tokenize prompt - auto tokens = llama_tokenize(ctx, params.prompt, true); + auto tokens = common_tokenize(ctx, params.prompt, true); + + // prepare the batch + llama_batch batch = llama_batch_init(tokens.size(), 0, 1); + for (size_t i = 0; i < tokens.size(); i++) { + common_batch_add(batch, tokens[i], i, {0}, false); + } + batch.logits[batch.n_tokens - 1] = true; // generate next token // evaluate prompt - llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), n_past, 0)); - n_past += tokens.size(); + llama_decode(ctx, batch); + n_past += batch.n_tokens; // save state (rng, logits, embedding and kv_cache) to file { @@ -72,17 +78,17 @@ int main(int argc, char ** argv) { for (auto i = 0; i < params.n_predict; i++) { auto next_token = llama_sampler_sample(smpl, ctx, -1); - auto next_token_str = llama_token_to_piece(ctx, next_token); - - llama_sampler_accept(smpl, next_token); + auto next_token_str = common_token_to_piece(ctx, next_token); printf("%s", next_token_str.c_str()); result0 += next_token_str; - if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) { + common_batch_clear(batch); + common_batch_add(batch, next_token, n_past, {0}, true); + + if (llama_decode(ctx, batch)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); - llama_free(ctx); - llama_free_model(model); + llama_batch_free(batch); return 1; } n_past += 1; @@ -90,16 +96,12 @@ int main(int argc, char ** argv) { printf("\n\n"); - // free old context - llama_free(ctx); - // make new context - auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); + llama_context * ctx2 = llama_init_from_model(model, common_context_params_to_llama(params)); llama_sampler * smpl2 = llama_sampler_chain_init(sparams); - llama_sampler_chain_add(smpl2, llama_sampler_init_softmax()); - llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sparams.seed)); + llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sampling.seed)); printf("\nsecond run: %s", params.prompt.c_str()); @@ -116,8 +118,6 @@ int main(int argc, char ** argv) { if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) { fprintf(stderr, "\n%s : failed to read state\n", __func__); - llama_free(ctx2); - llama_free_model(model); return 1; } @@ -130,17 +130,17 @@ int main(int argc, char ** argv) { // second run for (auto i = 0; i < params.n_predict; i++) { auto next_token = llama_sampler_sample(smpl2, ctx2, -1); - auto next_token_str = llama_token_to_piece(ctx2, next_token); - - llama_sampler_accept(smpl2, next_token); + auto next_token_str = common_token_to_piece(ctx2, next_token); printf("%s", next_token_str.c_str()); result1 += next_token_str; - if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1, n_past, 0))) { + common_batch_clear(batch); + common_batch_add(batch, next_token, n_past, {0}, true); + + if (llama_decode(ctx2, batch)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); - llama_free(ctx2); - llama_free_model(model); + llama_batch_free(batch); return 1; } n_past += 1; @@ -148,20 +148,17 @@ int main(int argc, char ** argv) { printf("\n\n"); - llama_free(ctx2); - if (result0 != result1) { fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__); return 1; } // make new context - auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); + llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params)); llama_sampler * smpl3 = llama_sampler_chain_init(sparams); - llama_sampler_chain_add(smpl3, llama_sampler_init_softmax()); - llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sparams.seed)); + llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sampling.seed)); printf("\nsingle seq run: %s", params.prompt.c_str()); @@ -178,8 +175,6 @@ int main(int argc, char ** argv) { if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) { fprintf(stderr, "\n%s : failed to read state\n", __func__); - llama_free(ctx3); - llama_free_model(model); return 1; } @@ -196,8 +191,6 @@ int main(int argc, char ** argv) { const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0); if (ncopy != seq_store.size()) { fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size()); - llama_free(ctx3); - llama_free_model(model); return 1; } fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy); @@ -210,8 +203,6 @@ int main(int argc, char ** argv) { const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1); if (nset != seq_store.size()) { fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size()); - llama_free(ctx3); - llama_free_model(model); return 1; } fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset); @@ -220,17 +211,17 @@ int main(int argc, char ** argv) { // third run with seq 1 instead of 0 for (auto i = 0; i < params.n_predict; i++) { auto next_token = llama_sampler_sample(smpl3, ctx3, -1); - auto next_token_str = llama_token_to_piece(ctx3, next_token); - - llama_sampler_accept(smpl3, next_token); + auto next_token_str = common_token_to_piece(ctx3, next_token); printf("%s", next_token_str.c_str()); result2 += next_token_str; - if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) { + common_batch_clear(batch); + common_batch_add(batch, next_token, n_past, {1}, true); + + if (llama_decode(ctx3, batch)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); - llama_free(ctx3); - llama_free_model(model); + llama_batch_free(batch); return 1; } n_past += 1; @@ -242,8 +233,7 @@ int main(int argc, char ** argv) { llama_sampler_free(smpl2); llama_sampler_free(smpl3); - llama_free(ctx3); - llama_free_model(model); + llama_batch_free(batch); if (result0 != result2) { fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__); diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index dbe41f1fd..1b7cc8c13 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -1,6 +1,6 @@ set(TARGET llama-server) -option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON) -option(LLAMA_SERVER_SSL "Build SSL support for the server" OFF) + +option(LLAMA_SERVER_SSL "Build SSL support for the server" OFF) include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR}) @@ -15,21 +15,8 @@ set(TARGET_SRCS httplib.h ) set(PUBLIC_ASSETS - colorthemes.css - style.css - theme-beeninorder.css - theme-ketivah.css - theme-mangotango.css - theme-playground.css - theme-polarnight.css - theme-snowstorm.css - index.html - index-new.html - index.js - completion.js - system-prompts.js - prompt-formats.js - json-schema-to-grammar.mjs + index.html.gz + loading.html ) foreach(asset ${PUBLIC_ASSETS}) @@ -41,14 +28,13 @@ foreach(asset ${PUBLIC_ASSETS}) OUTPUT "${output}" COMMAND "${CMAKE_COMMAND}" "-DINPUT=${input}" "-DOUTPUT=${output}" -P "${PROJECT_SOURCE_DIR}/scripts/xxd.cmake" ) + set_source_files_properties(${output} PROPERTIES GENERATED TRUE) endforeach() add_executable(${TARGET} ${TARGET_SRCS}) install(TARGETS ${TARGET} RUNTIME) -target_compile_definitions(${TARGET} PRIVATE - SERVER_VERBOSE=$ -) +target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR}) target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT}) if (LLAMA_SERVER_SSL) @@ -61,4 +47,4 @@ if (WIN32) TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32) endif() -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/server/README.md b/examples/server/README.md index ed1201ba8..ce1ae8858 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -7,6 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp. **Features:** * LLM inference of F16 and quantized models on GPU and CPU * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes + * Reranking endoint (WIP: https://github.com/ggerganov/llama.cpp/pull/9510) * Parallel decoding with multi-user support * Continuous batching * Multimodal (wip) @@ -17,91 +18,60 @@ The project is under active development, and we are [looking for feedback and co ## Usage + + +**Common params** + | Argument | Explanation | | -------- | ----------- | | `-h, --help, --usage` | print usage and exit | | `--version` | show version and build info | -| `-v, --verbose` | print verbose information | -| `--verbosity N` | set specific verbosity level (default: 0) | | `--verbose-prompt` | print a verbose prompt before generation (default: false) | -| `--no-display-prompt` | don't print prompt at generation (default: false) | -| `-s, --seed SEED` | RNG seed (default: -1, use random seed for < 0) | | `-t, --threads N` | number of threads to use during generation (default: -1)
(env: LLAMA_ARG_THREADS) | | `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) | | `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") | | `-Cr, --cpu-range lo-hi` | range of CPUs for affinity. Complements --cpu-mask | | `--cpu-strict <0\|1>` | use strict CPU placement (default: 0)
| +| `--prio N` | set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: 0)
| | `--poll <0...100>` | use polling level to wait for work (0 - no polling, default: 50)
| | `-Cb, --cpu-mask-batch M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range-batch (default: same as --cpu-mask) | | `-Crb, --cpu-range-batch lo-hi` | ranges of CPUs for affinity. Complements --cpu-mask-batch | | `--cpu-strict-batch <0\|1>` | use strict CPU placement (default: same as --cpu-strict) | +| `--prio-batch N` | set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: 0)
| | `--poll-batch <0\|1>` | use polling to wait for work (default: same as --poll) | -| `-lcs, --lookup-cache-static FNAME` | path to static lookup cache to use for lookup decoding (not updated by generation) | -| `-lcd, --lookup-cache-dynamic FNAME` | path to dynamic lookup cache to use for lookup decoding (updated by generation) | -| `-c, --ctx-size N` | size of the prompt context (default: 0, 0 = loaded from model)
(env: LLAMA_ARG_CTX_SIZE) | +| `-c, --ctx-size N` | size of the prompt context (default: 4096, 0 = loaded from model)
(env: LLAMA_ARG_CTX_SIZE) | | `-n, --predict, --n-predict N` | number of tokens to predict (default: -1, -1 = infinity, -2 = until context filled)
(env: LLAMA_ARG_N_PREDICT) | | `-b, --batch-size N` | logical maximum batch size (default: 2048)
(env: LLAMA_ARG_BATCH) | | `-ub, --ubatch-size N` | physical maximum batch size (default: 512)
(env: LLAMA_ARG_UBATCH) | | `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) | -| `--chunks N` | max number of chunks to process (default: -1, -1 = all) | | `-fa, --flash-attn` | enable Flash Attention (default: disabled)
(env: LLAMA_ARG_FLASH_ATTN) | -| `-p, --prompt PROMPT` | prompt to start generation with | -| `-f, --file FNAME` | a file containing the prompt (default: none) | -| `--in-file FNAME` | an input file (repeat to specify multiple files) | -| `-bf, --binary-file FNAME` | binary file containing the prompt (default: none) | +| `--no-perf` | disable internal libllama performance timings (default: false)
(env: LLAMA_ARG_NO_PERF) | | `-e, --escape` | process escapes sequences (\n, \r, \t, \', \", \\) (default: true) | | `--no-escape` | do not process escape sequences | -| `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) | -| `--samplers SAMPLERS` | samplers that will be used for generation in the order, separated by ';'
(default: top_k;tfs_z;typical_p;top_p;min_p;temperature) | -| `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: kfypmt) | -| `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) | -| `--penalize-nl` | penalize newline tokens (default: false) | -| `--temp N` | temperature (default: 0.8) | -| `--top-k N` | top-k sampling (default: 40, 0 = disabled) | -| `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) | -| `--min-p N` | min-p sampling (default: 0.1, 0.0 = disabled) | -| `--tfs N` | tail free sampling, parameter z (default: 1.0, 1.0 = disabled) | -| `--typical N` | locally typical sampling, parameter p (default: 1.0, 1.0 = disabled) | -| `--repeat-last-n N` | last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size) | -| `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled) | -| `--presence-penalty N` | repeat alpha presence penalty (default: 0.0, 0.0 = disabled) | -| `--frequency-penalty N` | repeat alpha frequency penalty (default: 0.0, 0.0 = disabled) | -| `--dynatemp-range N` | dynamic temperature range (default: 0.0, 0.0 = disabled) | -| `--dynatemp-exp N` | dynamic temperature exponent (default: 1.0) | -| `--mirostat N` | use Mirostat sampling.
Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.
(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | -| `--mirostat-lr N` | Mirostat learning rate, parameter eta (default: 0.1) | -| `--mirostat-ent N` | Mirostat target entropy, parameter tau (default: 5.0) | -| `-l, --logit-bias TOKEN_ID(+/-)BIAS` | modifies the likelihood of token appearing in the completion,
i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',
or `--logit-bias 15043-1` to decrease likelihood of token ' Hello' | -| `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') | -| `--grammar-file FNAME` | file to read grammar from | -| `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object
For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead | -| `--rope-scaling {none,linear,yarn}` | RoPE frequency scaling method, defaults to linear unless specified by the model | -| `--rope-scale N` | RoPE context scaling factor, expands context by a factor of N | -| `--rope-freq-base N` | RoPE base frequency, used by NTK-aware scaling (default: loaded from model) | -| `--rope-freq-scale N` | RoPE frequency scaling factor, expands context by a factor of 1/N | -| `--yarn-orig-ctx N` | YaRN: original context size of model (default: 0 = model training context size) | -| `--yarn-ext-factor N` | YaRN: extrapolation mix factor (default: -1.0, 0.0 = full interpolation) | -| `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: 1.0) | -| `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: 1.0) | -| `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: 32.0) | -| `-gan, --grp-attn-n N` | group-attention factor (default: 1) | -| `-gaw, --grp-attn-w N` | group-attention width (default: 512.0) | +| `--rope-scaling {none,linear,yarn}` | RoPE frequency scaling method, defaults to linear unless specified by the model
(env: LLAMA_ARG_ROPE_SCALING_TYPE) | +| `--rope-scale N` | RoPE context scaling factor, expands context by a factor of N
(env: LLAMA_ARG_ROPE_SCALE) | +| `--rope-freq-base N` | RoPE base frequency, used by NTK-aware scaling (default: loaded from model)
(env: LLAMA_ARG_ROPE_FREQ_BASE) | +| `--rope-freq-scale N` | RoPE frequency scaling factor, expands context by a factor of 1/N
(env: LLAMA_ARG_ROPE_FREQ_SCALE) | +| `--yarn-orig-ctx N` | YaRN: original context size of model (default: 0 = model training context size)
(env: LLAMA_ARG_YARN_ORIG_CTX) | +| `--yarn-ext-factor N` | YaRN: extrapolation mix factor (default: -1.0, 0.0 = full interpolation)
(env: LLAMA_ARG_YARN_EXT_FACTOR) | +| `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: 1.0)
(env: LLAMA_ARG_YARN_ATTN_FACTOR) | +| `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: 1.0)
(env: LLAMA_ARG_YARN_BETA_SLOW) | +| `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: 32.0)
(env: LLAMA_ARG_YARN_BETA_FAST) | | `-dkvc, --dump-kv-cache` | verbose print of the KV cache | -| `-nkvo, --no-kv-offload` | disable KV offload | -| `-ctk, --cache-type-k TYPE` | KV cache data type for K (default: f16) | -| `-ctv, --cache-type-v TYPE` | KV cache data type for V (default: f16) | -| `-dt, --defrag-thold N` | KV cache defragmentation threshold (default: -1.0, < 0 - disabled)
(env: LLAMA_ARG_DEFRAG_THOLD) | -| `-np, --parallel N` | number of parallel sequences to decode (default: 1) | -| `-ns, --sequences N` | number of sequences to decode (default: 1) | -| `-cb, --cont-batching` | enable continuous batching (a.k.a dynamic batching) (default: enabled)
(env: LLAMA_ARG_CONT_BATCHING) | -| `-nocb, --no-cont-batching` | disable continuous batching
(env: LLAMA_ARG_NO_CONT_BATCHING) | -| `--mlock` | force system to keep model in RAM rather than swapping or compressing | -| `--no-mmap` | do not memory-map model (slower load but may reduce pageouts if not using mlock) | -| `--numa TYPE` | attempt optimizations that help on some NUMA systems
- distribute: spread execution evenly over all nodes
- isolate: only spawn threads on CPUs on the node that execution started on
- numactl: use the CPU map provided by numactl
if run without this previously, it is recommended to drop the system page cache before using this
see https://github.com/ggerganov/llama.cpp/issues/1437 | -| `-ngl, --gpu-layers N` | number of layers to store in VRAM
(env: LLAMA_ARG_N_GPU_LAYERS) | -| `-sm, --split-mode {none,layer,row}` | how to split the model across multiple GPUs, one of:
- none: use one GPU only
- layer (default): split layers and KV across GPUs
- row: split rows across GPUs | -| `-ts, --tensor-split N0,N1,N2,...` | fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1 | -| `-mg, --main-gpu INDEX` | the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: 0) | +| `-nkvo, --no-kv-offload` | disable KV offload
(env: LLAMA_ARG_NO_KV_OFFLOAD) | +| `-ctk, --cache-type-k TYPE` | KV cache data type for K
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K) | +| `-ctv, --cache-type-v TYPE` | KV cache data type for V
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V) | +| `-dt, --defrag-thold N` | KV cache defragmentation threshold (default: 0.1, < 0 - disabled)
(env: LLAMA_ARG_DEFRAG_THOLD) | +| `-np, --parallel N` | number of parallel sequences to decode (default: 1)
(env: LLAMA_ARG_N_PARALLEL) | +| `--mlock` | force system to keep model in RAM rather than swapping or compressing
(env: LLAMA_ARG_MLOCK) | +| `--no-mmap` | do not memory-map model (slower load but may reduce pageouts if not using mlock)
(env: LLAMA_ARG_NO_MMAP) | +| `--numa TYPE` | attempt optimizations that help on some NUMA systems
- distribute: spread execution evenly over all nodes
- isolate: only spawn threads on CPUs on the node that execution started on
- numactl: use the CPU map provided by numactl
if run without this previously, it is recommended to drop the system page cache before using this
see https://github.com/ggerganov/llama.cpp/issues/1437
(env: LLAMA_ARG_NUMA) | +| `-dev, --device ` | comma-separated list of devices to use for offloading (none = don't offload)
use --list-devices to see a list of available devices
(env: LLAMA_ARG_DEVICE) | +| `--list-devices` | print list of available devices and exit | +| `-ngl, --gpu-layers, --n-gpu-layers N` | number of layers to store in VRAM
(env: LLAMA_ARG_N_GPU_LAYERS) | +| `-sm, --split-mode {none,layer,row}` | how to split the model across multiple GPUs, one of:
- none: use one GPU only
- layer (default): split layers and KV across GPUs
- row: split rows across GPUs
(env: LLAMA_ARG_SPLIT_MODE) | +| `-ts, --tensor-split N0,N1,N2,...` | fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1
(env: LLAMA_ARG_TENSOR_SPLIT) | +| `-mg, --main-gpu INDEX` | the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: 0)
(env: LLAMA_ARG_MAIN_GPU) | | `--check-tensors` | check model tensor data for invalid values (default: false) | | `--override-kv KEY=TYPE:VALUE` | advanced option to override model metadata by key. may be specified multiple times.
types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false | | `--lora FNAME` | path to LoRA adapter (can be repeated to use multiple adapters) | @@ -109,37 +79,96 @@ The project is under active development, and we are [looking for feedback and co | `--control-vector FNAME` | add a control vector
note: this argument can be repeated to add multiple control vectors | | `--control-vector-scaled FNAME SCALE` | add a control vector with user defined scaling SCALE
note: this argument can be repeated to add multiple scaled control vectors | | `--control-vector-layer-range START END` | layer range to apply the control vector(s) to, start and end inclusive | -| `-a, --alias STRING` | set alias for model name (to be used by REST API)
(env: LLAMA_ARG_MODEL) | | `-m, --model FNAME` | model path (default: `models/$filename` with filename from `--hf-file` or `--model-url` if set, otherwise models/7B/ggml-model-f16.gguf)
(env: LLAMA_ARG_MODEL) | | `-mu, --model-url MODEL_URL` | model download url (default: unused)
(env: LLAMA_ARG_MODEL_URL) | | `-hfr, --hf-repo REPO` | Hugging Face model repository (default: unused)
(env: LLAMA_ARG_HF_REPO) | | `-hff, --hf-file FILE` | Hugging Face model file (default: unused)
(env: LLAMA_ARG_HF_FILE) | | `-hft, --hf-token TOKEN` | Hugging Face access token (default: value from HF_TOKEN environment variable)
(env: HF_TOKEN) | +| `--log-disable` | Log disable | +| `--log-file FNAME` | Log to file | +| `--log-colors` | Enable colored logging
(env: LLAMA_LOG_COLORS) | +| `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) | +| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored.
(env: LLAMA_LOG_VERBOSITY) | +| `--log-prefix` | Enable prefx in log messages
(env: LLAMA_LOG_PREFIX) | +| `--log-timestamps` | Enable timestamps in log messages
(env: LLAMA_LOG_TIMESTAMPS) | + + +**Sampling params** + +| Argument | Explanation | +| -------- | ----------- | +| `--samplers SAMPLERS` | samplers that will be used for generation in the order, separated by ';'
(default: dry;top_k;typ_p;top_p;min_p;xtc;temperature) | +| `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) | +| `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: dkypmxt) | +| `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) | +| `--temp N` | temperature (default: 0.8) | +| `--top-k N` | top-k sampling (default: 40, 0 = disabled) | +| `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) | +| `--min-p N` | min-p sampling (default: 0.1, 0.0 = disabled) | +| `--xtc-probability N` | xtc probability (default: 0.0, 0.0 = disabled) | +| `--xtc-threshold N` | xtc threshold (default: 0.1, 1.0 = disabled) | +| `--typical N` | locally typical sampling, parameter p (default: 1.0, 1.0 = disabled) | +| `--repeat-last-n N` | last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size) | +| `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled) | +| `--presence-penalty N` | repeat alpha presence penalty (default: 0.0, 0.0 = disabled) | +| `--frequency-penalty N` | repeat alpha frequency penalty (default: 0.0, 0.0 = disabled) | +| `--dry-multiplier N` | set DRY sampling multiplier (default: 0.0, 0.0 = disabled) | +| `--dry-base N` | set DRY sampling base value (default: 1.75) | +| `--dry-allowed-length N` | set allowed length for DRY sampling (default: 2) | +| `--dry-penalty-last-n N` | set DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size) | +| `--dry-sequence-breaker STRING` | add sequence breaker for DRY sampling, clearing out default breakers ('\n', ':', '"', '*') in the process; use "none" to not use any sequence breakers
| +| `--dynatemp-range N` | dynamic temperature range (default: 0.0, 0.0 = disabled) | +| `--dynatemp-exp N` | dynamic temperature exponent (default: 1.0) | +| `--mirostat N` | use Mirostat sampling.
Top K, Nucleus and Locally Typical samplers are ignored if used.
(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | +| `--mirostat-lr N` | Mirostat learning rate, parameter eta (default: 0.1) | +| `--mirostat-ent N` | Mirostat target entropy, parameter tau (default: 5.0) | +| `-l, --logit-bias TOKEN_ID(+/-)BIAS` | modifies the likelihood of token appearing in the completion,
i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',
or `--logit-bias 15043-1` to decrease likelihood of token ' Hello' | +| `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') | +| `--grammar-file FNAME` | file to read grammar from | +| `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object
For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead | +| `--jinja` | Enable experimental Jinja templating engine (needed for tool use) | + +**Example-specific params** + +| Argument | Explanation | +| -------- | ----------- | +| `--no-context-shift` | disables context shift on inifinite text generation (default: disabled)
(env: LLAMA_ARG_NO_CONTEXT_SHIFT) | +| `-sp, --special` | special tokens output enabled (default: false) | +| `--no-warmup` | skip warming up the model with an empty run | +| `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) | +| `--pooling {none,mean,cls,last,rank}` | pooling type for embeddings, use model default if unspecified
(env: LLAMA_ARG_POOLING) | +| `-cb, --cont-batching` | enable continuous batching (a.k.a dynamic batching) (default: enabled)
(env: LLAMA_ARG_CONT_BATCHING) | +| `-nocb, --no-cont-batching` | disable continuous batching
(env: LLAMA_ARG_NO_CONT_BATCHING) | +| `-a, --alias STRING` | set alias for model name (to be used by REST API)
(env: LLAMA_ARG_ALIAS) | | `--host HOST` | ip address to listen (default: 127.0.0.1)
(env: LLAMA_ARG_HOST) | | `--port PORT` | port to listen (default: 8080)
(env: LLAMA_ARG_PORT) | -| `--path PATH` | path to serve static files from (default: ) | +| `--path PATH` | path to serve static files from (default: )
(env: LLAMA_ARG_STATIC_PATH) | +| `--no-webui` | Disable the Web UI (default: enabled)
(env: LLAMA_ARG_NO_WEBUI) | | `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)
(env: LLAMA_ARG_EMBEDDINGS) | +| `--reranking, --rerank` | enable reranking endpoint on server (default: disabled)
(env: LLAMA_ARG_RERANKING) | | `--api-key KEY` | API key to use for authentication (default: none)
(env: LLAMA_API_KEY) | | `--api-key-file FNAME` | path to file containing API keys (default: none) | -| `--ssl-key-file FNAME` | path to file a PEM-encoded SSL private key | -| `--ssl-cert-file FNAME` | path to file a PEM-encoded SSL certificate | -| `--timeout N` | server read/write timeout in seconds (default: 600) | +| `--ssl-key-file FNAME` | path to file a PEM-encoded SSL private key
(env: LLAMA_ARG_SSL_KEY_FILE) | +| `--ssl-cert-file FNAME` | path to file a PEM-encoded SSL certificate
(env: LLAMA_ARG_SSL_CERT_FILE) | +| `-to, --timeout N` | server read/write timeout in seconds (default: 600)
(env: LLAMA_ARG_TIMEOUT) | | `--threads-http N` | number of threads used to process HTTP requests (default: -1)
(env: LLAMA_ARG_THREADS_HTTP) | -| `-spf, --system-prompt-file FNAME` | set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications | -| `--log-format {text, json}` | log output format: json or text (default: json) | +| `--cache-reuse N` | min chunk size to attempt reusing from the cache via KV shifting (default: 0)
(env: LLAMA_ARG_CACHE_REUSE) | | `--metrics` | enable prometheus compatible metrics endpoint (default: disabled)
(env: LLAMA_ARG_ENDPOINT_METRICS) | -| `--no-slots` | disables slots monitoring endpoint (default: enabled)
(env: LLAMA_ARG_NO_ENDPOINT_SLOTS) | +| `--slots` | enable slots monitoring endpoint (default: disabled)
(env: LLAMA_ARG_ENDPOINT_SLOTS) | +| `--props` | enable changing global properties via POST /props (default: disabled)
(env: LLAMA_ARG_ENDPOINT_PROPS) | +| `--no-slots` | disables slots monitoring endpoint
(env: LLAMA_ARG_NO_ENDPOINT_SLOTS) | | `--slot-save-path PATH` | path to save slot kv cache (default: disabled) | -| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted:
https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
(env: LLAMA_ARG_CHAT_TEMPLATE) | +| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
list of built-in templates:
chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, exaone3, gemma, granite, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, monarch, openchat, orion, phi3, rwkv-world, vicuna, vicuna-orca, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE) | | `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)
| | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) | -| `-ld, --logdir LOGDIR` | path under which to save YAML logs (no logging if unset) | -| `--log-test` | Log test | -| `--log-disable` | Log disable | -| `--log-enable` | Log enable | -| `--log-new` | Log new | -| `--log-append` | Log append | -| `--log-file FNAME` | Log file | +| `--draft-max, --draft, --draft-n N` | number of tokens to draft for speculative decoding (default: 16)
(env: LLAMA_ARG_DRAFT_MAX) | +| `--draft-min, --draft-n-min N` | minimum number of draft tokens to use for speculative decoding (default: 5)
(env: LLAMA_ARG_DRAFT_MIN) | +| `--draft-p-min P` | minimum speculative decoding probability (greedy) (default: 0.9)
(env: LLAMA_ARG_DRAFT_P_MIN) | +| `-cd, --ctx-size-draft N` | size of the prompt context for the draft model (default: 0, 0 = loaded from model)
(env: LLAMA_ARG_CTX_SIZE_DRAFT) | +| `-devd, --device-draft ` | comma-separated list of devices to use for offloading the draft model (none = don't offload)
use --list-devices to see a list of available devices | +| `-ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | number of layers to store in VRAM for the draft model
(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) | +| `-md, --model-draft FNAME` | draft model for speculative decoding (default: unused)
(env: LLAMA_ARG_MODEL_DRAFT) | + Note: If both command line argument and environment variable are both set for the same param, the argument will take precedence over env var. @@ -166,12 +195,6 @@ services: `llama-server` is built alongside everything else from the root of the project -- Using `make`: - - ```bash - make llama-server - ``` - - Using `CMake`: ```bash @@ -185,15 +208,6 @@ services: `llama-server` can also be built with SSL support using OpenSSL 3 -- Using `make`: - - ```bash - # NOTE: For non-system openssl, use the following: - # CXXFLAGS="-I /path/to/openssl/include" - # LDFLAGS="-L /path/to/openssl/lib" - make LLAMA_SERVER_SSL=true llama-server - ``` - - Using `CMake`: ```bash @@ -201,6 +215,41 @@ services: cmake --build build --config Release -t llama-server ``` +## Web UI + +The project includes a web-based user interface that enables interaction with the model through the `/chat/completions` endpoint. + +The web UI is developed using: +- `vue` framework for frontend development +- `tailwindcss` and `daisyui` for styling +- `vite` for build tooling + +A pre-built version is available as a single HTML file under `/public` directory. + +To build or to run the dev server (with hot reload): + +```sh +# make sure you have nodejs installed +cd examples/server/webui +npm i + +# to run the dev server +npm run dev + +# to build the public/index.html.gz +npm run build +``` +After `public/index.html.gz` has been generated we need to generate the c++ +headers (like build/examples/server/index.html.gz.hpp) that will be included +by server.cpp. This is done by building `llama-server` as described in the +[build](#build) section above. + +NOTE: if you are using the vite dev server, you can change the API base URL to llama.cpp. To do that, run this code snippet in browser's console: + +```js +localStorage.setItem('base', 'http://localhost:8080') +``` + ## Quick Start To get started right away, run the following command, making sure to use the correct path for the model you have: @@ -255,23 +304,23 @@ mkdir llama-client cd llama-client ``` -Create a index.js file and put this inside: +Create an index.js file and put this inside: ```javascript -const prompt = `Building a website can be done in 10 simple steps:`; +const prompt = "Building a website can be done in 10 simple steps:" -async function Test() { +async function test() { let response = await fetch("http://127.0.0.1:8080/completion", { - method: 'POST', + method: "POST", body: JSON.stringify({ prompt, - n_predict: 512, + n_predict: 64, }) }) console.log((await response.json()).content) } -Test() +test() ``` And run it: @@ -295,266 +344,472 @@ node index.js ### POST `/completion`: Given a `prompt`, it returns the predicted completion. - *Options:* +> [!IMPORTANT] +> +> This endpoint is **not** OAI-compatible. For OAI-compatible client, use `/v1/completions` instead. - `prompt`: Provide the prompt for this completion as a string or as an array of strings or numbers representing tokens. Internally, if `cache_prompt` is `true`, the prompt is compared to the previous completion and only the "unseen" suffix is evaluated. A `BOS` token is inserted at the start, if all of the following conditions are true: +*Options:* - - The prompt is a string or an array with the first element given as a string - - The model's `tokenizer.ggml.add_bos_token` metadata is `true` - - The system prompt is empty +`prompt`: Provide the prompt for this completion as a string or as an array of strings or numbers representing tokens. Internally, if `cache_prompt` is `true`, the prompt is compared to the previous completion and only the "unseen" suffix is evaluated. A `BOS` token is inserted at the start, if all of the following conditions are true: - `temperature`: Adjust the randomness of the generated text. Default: `0.8` + - The prompt is a string or an array with the first element given as a string + - The model's `tokenizer.ggml.add_bos_token` metadata is `true` - `dynatemp_range`: Dynamic temperature range. The final temperature will be in the range of `[temperature - dynatemp_range; temperature + dynatemp_range]` Default: `0.0`, which is disabled. +These input shapes and data type are allowed for `prompt`: - `dynatemp_exponent`: Dynamic temperature exponent. Default: `1.0` + - Single string: `"string"` + - Single sequence of tokens: `[12, 34, 56]` + - Mixed tokens and strings: `[12, 34, "string", 56, 78]` - `top_k`: Limit the next token selection to the K most probable tokens. Default: `40` +Multiple prompts are also supported. In this case, the completion result will be an array. - `top_p`: Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P. Default: `0.95` + - Only strings: `["string1", "string2"]` + - Strings and sequences of tokens: `["string1", [12, 34, 56]]` + - Mixed types: `[[12, 34, "string", 56, 78], [12, 34, 56], "string"]` - `min_p`: The minimum probability for a token to be considered, relative to the probability of the most likely token. Default: `0.05` +`temperature`: Adjust the randomness of the generated text. Default: `0.8` - `n_predict`: Set the maximum number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. When 0, no tokens will be generated but the prompt is evaluated into the cache. Default: `-1`, where `-1` is infinity. +`dynatemp_range`: Dynamic temperature range. The final temperature will be in the range of `[temperature - dynatemp_range; temperature + dynatemp_range]` Default: `0.0`, which is disabled. - `n_keep`: Specify the number of tokens from the prompt to retain when the context size is exceeded and tokens need to be discarded. The number excludes the BOS token. - By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to retain all tokens from the prompt. +`dynatemp_exponent`: Dynamic temperature exponent. Default: `1.0` - `stream`: It allows receiving each predicted token in real-time instead of waiting for the completion to finish. To enable this, set to `true`. +`top_k`: Limit the next token selection to the K most probable tokens. Default: `40` - `stop`: Specify a JSON array of stopping strings. - These words will not be included in the completion, so make sure to add them to the prompt for the next iteration. Default: `[]` +`top_p`: Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P. Default: `0.95` - `tfs_z`: Enable tail free sampling with parameter z. Default: `1.0`, which is disabled. +`min_p`: The minimum probability for a token to be considered, relative to the probability of the most likely token. Default: `0.05` - `typical_p`: Enable locally typical sampling with parameter p. Default: `1.0`, which is disabled. +`n_predict`: Set the maximum number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. When 0, no tokens will be generated but the prompt is evaluated into the cache. Default: `-1`, where `-1` is infinity. - `repeat_penalty`: Control the repetition of token sequences in the generated text. Default: `1.1` +`n_indent`: Specify the minimum line indentation for the generated text in number of whitespace characters. Useful for code completion tasks. Default: `0` - `repeat_last_n`: Last n tokens to consider for penalizing repetition. Default: `64`, where `0` is disabled and `-1` is ctx-size. +`n_keep`: Specify the number of tokens from the prompt to retain when the context size is exceeded and tokens need to be discarded. The number excludes the BOS token. +By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to retain all tokens from the prompt. - `penalize_nl`: Penalize newline tokens when applying the repeat penalty. Default: `true` +`stream`: Allows receiving each predicted token in real-time instead of waiting for the completion to finish (uses a different response format). To enable this, set to `true`. - `presence_penalty`: Repeat alpha presence penalty. Default: `0.0`, which is disabled. +`stop`: Specify a JSON array of stopping strings. +These words will not be included in the completion, so make sure to add them to the prompt for the next iteration. Default: `[]` - `frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled. +`typical_p`: Enable locally typical sampling with parameter p. Default: `1.0`, which is disabled. - `mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0. +`repeat_penalty`: Control the repetition of token sequences in the generated text. Default: `1.1` - `mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0` +`repeat_last_n`: Last n tokens to consider for penalizing repetition. Default: `64`, where `0` is disabled and `-1` is ctx-size. - `mirostat_eta`: Set the Mirostat learning rate, parameter eta. Default: `0.1` +`presence_penalty`: Repeat alpha presence penalty. Default: `0.0`, which is disabled. - `grammar`: Set grammar for grammar-based sampling. Default: no grammar +`frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled. - `json_schema`: Set a JSON schema for grammar-based sampling (e.g. `{"items": {"type": "string"}, "minItems": 10, "maxItems": 100}` of a list of strings, or `{}` for any JSON). See [tests](../../tests/test-json-schema-to-grammar.cpp) for supported features. Default: no JSON schema. +`dry_multiplier`: Set the DRY (Don't Repeat Yourself) repetition penalty multiplier. Default: `0.0`, which is disabled. - `seed`: Set the random number generator (RNG) seed. Default: `-1`, which is a random seed. +`dry_base`: Set the DRY repetition penalty base value. Default: `1.75` - `ignore_eos`: Ignore end of stream token and continue generating. Default: `false` +`dry_allowed_length`: Tokens that extend repetition beyond this receive exponentially increasing penalty: multiplier * base ^ (length of repeating sequence before token - allowed length). Default: `2` - `logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings, e.g. `[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does. Default: `[]` +`dry_penalty_last_n`: How many tokens to scan for repetitions. Default: `-1`, where `0` is disabled and `-1` is context size. - `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token given the sampling settings. Note that for temperature < 0 the tokens are sampled greedily but token probabilities are still being calculated via a simple softmax of the logits without considering any other sampler settings. Default: `0` +`dry_sequence_breakers`: Specify an array of sequence breakers for DRY sampling. Only a JSON array of strings is accepted. Default: `['\n', ':', '"', '*']` - `min_keep`: If greater than 0, force samplers to return N possible tokens at minimum. Default: `0` +`xtc_probability`: Set the chance for token removal via XTC sampler. Default: `0.0`, which is disabled. - `image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `prompt`. You can determine the place of the image in the prompt as in the following: `USER:[img-12]Describe the image in detail.\nASSISTANT:`. In this case, `[img-12]` will be replaced by the embeddings of the image with id `12` in the following `image_data` array: `{..., "image_data": [{"data": "", "id": 12}]}`. Use `image_data` only with multimodal models, e.g., LLaVA. +`xtc_threshold`: Set a minimum probability threshold for tokens to be removed via XTC sampler. Default: `0.1` (> `0.5` disables XTC) - `id_slot`: Assign the completion task to an specific slot. If is -1 the task will be assigned to a Idle slot. Default: `-1` +`mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0. - `cache_prompt`: Re-use KV cache from a previous request if possible. This way the common prefix does not have to be re-processed, only the suffix that differs between the requests. Because (depending on the backend) the logits are **not** guaranteed to be bit-for-bit identical for different batch sizes (prompt processing vs. token generation) enabling this option can cause nondeterministic results. Default: `false` +`mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0` - `system_prompt`: Change the system prompt (initial prompt of all slots), this is useful for chat applications. [See more](#change-system-prompt-on-runtime) +`mirostat_eta`: Set the Mirostat learning rate, parameter eta. Default: `0.1` - `samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. Default: `["top_k", "tfs_z", "typical_p", "top_p", "min_p", "temperature"]` - these are all the available values. +`grammar`: Set grammar for grammar-based sampling. Default: no grammar + +`json_schema`: Set a JSON schema for grammar-based sampling (e.g. `{"items": {"type": "string"}, "minItems": 10, "maxItems": 100}` of a list of strings, or `{}` for any JSON). See [tests](../../tests/test-json-schema-to-grammar.cpp) for supported features. Default: no JSON schema. + +`seed`: Set the random number generator (RNG) seed. Default: `-1`, which is a random seed. + +`ignore_eos`: Ignore end of stream token and continue generating. Default: `false` + +`logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings, e.g. `[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does. Default: `[]` + +`n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token given the sampling settings. Note that for temperature < 0 the tokens are sampled greedily but token probabilities are still being calculated via a simple softmax of the logits without considering any other sampler settings. Default: `0` + +`min_keep`: If greater than 0, force samplers to return N possible tokens at minimum. Default: `0` + +`t_max_predict_ms`: Set a time limit in milliseconds for the prediction (a.k.a. text-generation) phase. The timeout will trigger if the generation takes more than the specified time (measured since the first token was generated) and if a new-line character has already been generated. Useful for FIM applications. Default: `0`, which is disabled. + +`image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `prompt`. You can determine the place of the image in the prompt as in the following: `USER:[img-12]Describe the image in detail.\nASSISTANT:`. In this case, `[img-12]` will be replaced by the embeddings of the image with id `12` in the following `image_data` array: `{..., "image_data": [{"data": "", "id": 12}]}`. Use `image_data` only with multimodal models, e.g., LLaVA. + +`id_slot`: Assign the completion task to an specific slot. If is -1 the task will be assigned to a Idle slot. Default: `-1` + +`cache_prompt`: Re-use KV cache from a previous request if possible. This way the common prefix does not have to be re-processed, only the suffix that differs between the requests. Because (depending on the backend) the logits are **not** guaranteed to be bit-for-bit identical for different batch sizes (prompt processing vs. token generation) enabling this option can cause nondeterministic results. Default: `true` + +`return_tokens`: Return the raw generated token ids in the `tokens` field. Otherwise `tokens` remains empty. Default: `false` + +`samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. Default: `["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]` - these are all the available values. + +`timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false` + +`post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain. + +`response_fields`: A list of response fields, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error. Note that fields with a slash will be unnested; for example, `generation_settings/n_predict` will move the field `n_predict` from the `generation_settings` object to the root of the response and give it a new name. + +`lora`: A list of LoRA adapters to be applied to this specific request. Each object in the list must contain `id` and `scale` fields. For example: `[{"id": 0, "scale": 0.5}, {"id": 1, "scale": 1.1}]`. If a LoRA adapter is not specified in the list, its scale will default to `0.0`. Please note that requests with different LoRA configurations will not be batched together, which may result in performance degradation. **Response format** -- Note: When using streaming mode (`stream`), only `content` and `stop` will be returned until end of completion. +- Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support. -- `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has the following structure: - -```json -{ - "content": "", - "probs": [ - { - "prob": float, - "tok_str": "" - }, - { - "prob": float, - "tok_str": "" - }, +- `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has a nested array `top_logprobs`. It contains at **maximum** `n_probs` elements: + ``` + { + "content": "", + "tokens": [ generated token ids if requested ], ... - ] -}, -``` - -Notice that each `probs` is an array of length `n_probs`. + "probs": [ + { + "id": , + "logprob": float, + "token": "", + "bytes": [int, int, ...], + "top_logprobs": [ + { + "id": , + "logprob": float, + "token": "", + "bytes": [int, int, ...], + }, + { + "id": , + "logprob": float, + "token": "", + "bytes": [int, int, ...], + }, + ... + ] + }, + { + "id": , + "logprob": float, + "token": "", + "bytes": [int, int, ...], + "top_logprobs": [ + ... + ] + }, + ... + ] + }, + ``` + Please note that if `post_sampling_probs` is set to `true`: + - `logprob` will be replaced with `prob`, with the value between 0.0 and 1.0 + - `top_logprobs` will be replaced with `top_probs`. Each element contains: + - `id`: token ID + - `token`: token in string + - `bytes`: token in bytes + - `prob`: token probability, with the value between 0.0 and 1.0 + - Number of elements in `top_probs` may be less than `n_probs` - `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string. +- `tokens`: Same as `content` but represented as raw token ids. Only populated if `"return_tokens": true` or `"stream": true` in the request. - `stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options) - `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`. These options may differ from the original ones in some way (e.g. bad values filtered out, strings converted to tokens, etc.). -- `model`: The path to the model loaded with `-m` -- `prompt`: The provided `prompt` -- `stopped_eos`: Indicating whether the completion has stopped because it encountered the EOS token -- `stopped_limit`: Indicating whether the completion stopped because `n_predict` tokens were generated before stop words or EOS was encountered -- `stopped_word`: Indicating whether the completion stopped due to encountering a stopping word from `stop` JSON array provided +- `model`: The model alias (for model path, please use `/props` endpoint) +- `prompt`: The processed `prompt` (special tokens may be added) +- `stop_type`: Indicating whether the completion has stopped. Possible values are: + - `none`: Generating (not stopped) + - `eos`: Stopped because it encountered the EOS token + - `limit`: Stopped because `n_predict` tokens were generated before stop words or EOS was encountered + - `word`: Stopped due to encountering a stopping word from `stop` JSON array provided - `stopping_word`: The stopping word encountered which stopped the generation (or "" if not stopped due to a stopping word) - `timings`: Hash of timing information about the completion such as the number of tokens `predicted_per_second` - `tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion (`n_past`) - `tokens_evaluated`: Number of tokens evaluated in total from the prompt - `truncated`: Boolean indicating if the context size was exceeded during generation, i.e. the number of tokens provided in the prompt (`tokens_evaluated`) plus tokens generated (`tokens predicted`) exceeded the context size (`n_ctx`) + ### POST `/tokenize`: Tokenize a given text - *Options:* +*Options:* - `content`: Set the text to tokenize. +`content`: (Required) The text to tokenize. - `add_special`: Boolean indicating if special tokens, i.e. `BOS`, should be inserted. Default: `false` +`add_special`: (Optional) Boolean indicating if special tokens, i.e. `BOS`, should be inserted. Default: `false` + +`with_pieces`: (Optional) Boolean indicating whether to return token pieces along with IDs. Default: `false` + +**Response:** + +Returns a JSON object with a `tokens` field containing the tokenization result. The `tokens` array contains either just token IDs or objects with `id` and `piece` fields, depending on the `with_pieces` parameter. The piece field is a string if the piece is valid unicode or a list of bytes otherwise. + + +If `with_pieces` is `false`: +```json +{ + "tokens": [123, 456, 789] +} +``` + +If `with_pieces` is `true`: +```json +{ + "tokens": [ + {"id": 123, "piece": "Hello"}, + {"id": 456, "piece": " world"}, + {"id": 789, "piece": "!"} + ] +} +``` + +With input 'á' (utf8 hex: C3 A1) on tinyllama/stories260k +``` +{ + "tokens": [ + {"id": 198, "piece": [195]}, // hex C3 + {"id": 164, "piece": [161]} // hex A1 + ] +} +``` ### POST `/detokenize`: Convert tokens to text - *Options:* +*Options:* - `tokens`: Set the tokens to detokenize. +`tokens`: Set the tokens to detokenize. + +### POST `/apply-template`: Apply chat template to a conversation + +Uses the server's prompt template formatting functionality to convert chat messages to a single string expected by a chat model as input, but does not perform inference. Instead, the prompt string is returned in the `prompt` field of the JSON response. The prompt can then be modified as desired (for example, to insert "Sure!" at the beginning of the model's response) before sending to `/completion` to generate the chat response. + +*Options:* + +`messages`: (Required) Chat turns in the same format as `/v1/chat/completions`. + +**Response format** + +Returns a JSON object with a field `prompt` containing a string of the input messages formatted according to the model's chat template format. ### POST `/embedding`: Generate embedding of a given text +> [!IMPORTANT] +> +> This endpoint is **not** OAI-compatible. For OAI-compatible client, use `/v1/embeddings` instead. + The same as [the embedding example](../embedding) does. - *Options:* +*Options:* - `content`: Set the text to process. +`content`: Set the text to process. - `image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `content`. You can determine the place of the image in the content as in the following: `Image: [img-21].\nCaption: This is a picture of a house`. In this case, `[img-21]` will be replaced by the embeddings of the image with id `21` in the following `image_data` array: `{..., "image_data": [{"data": "", "id": 21}]}`. Use `image_data` only with multimodal models, e.g., LLaVA. +`image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `content`. You can determine the place of the image in the content as in the following: `Image: [img-21].\nCaption: This is a picture of a house`. In this case, `[img-21]` will be replaced by the embeddings of the image with id `21` in the following `image_data` array: `{..., "image_data": [{"data": "", "id": 21}]}`. Use `image_data` only with multimodal models, e.g., LLaVA. + +### POST `/reranking`: Rerank documents according to a given query + +Similar to https://jina.ai/reranker/ but might change in the future. +Requires a reranker model (such as [bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3)) and the `--embedding --pooling rank` options. + +*Options:* + +`query`: The query against which the documents will be ranked. + +`documents`: An array strings representing the documents to be ranked. + +*Aliases:* + - `/rerank` + - `/v1/rerank` + - `/v1/reranking` + +*Examples:* + +```shell +curl http://127.0.0.1:8012/v1/rerank \ + -H "Content-Type: application/json" \ + -d '{ + "model": "some-model", + "query": "What is panda?", + "top_n": 3, + "documents": [ + "hi", + "it is a bear", + "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." + ] + }' | jq +``` ### POST `/infill`: For code infilling. Takes a prefix and a suffix and returns the predicted completion as stream. - *Options:* +*Options:* - `input_prefix`: Set the prefix of the code to infill. +- `input_prefix`: Set the prefix of the code to infill. +- `input_suffix`: Set the suffix of the code to infill. +- `input_extra`: Additional context inserted before the FIM prefix. +- `prompt`: Added after the `FIM_MID` token - `input_suffix`: Set the suffix of the code to infill. +`input_extra` is array of `{"filename": string, "text": string}` objects. - It also accepts all the options of `/completion` except `stream` and `prompt`. +The endpoint also accepts all the options of `/completion`. -- **GET** `/props`: Return current server settings. +If the model has `FIM_REPO` and `FIM_FILE_SEP` tokens, the [repo-level pattern](https://arxiv.org/pdf/2409.12186) is used: + +```txt +myproject +{chunk 0 filename} +{chunk 0 text} +{chunk 1 filename} +{chunk 1 text} +... +filename +[input_prefix][input_suffix][prompt] +``` + +If the tokens are missing, then the extra context is simply prefixed at the start: + +```txt +[input_extra][input_prefix][input_suffix][prompt] +``` + +### **GET** `/props`: Get server global properties. + +This endpoint is public (no API key check). By default, it is read-only. To make POST request to change global properties, you need to start server with `--props` **Response format** ```json { - "assistant_name": "", - "user_name": "", - "default_generation_settings": { ... }, + "default_generation_settings": { + "id": 0, + "id_task": -1, + "n_ctx": 1024, + "speculative": false, + "is_processing": false, + "params": { + "n_predict": -1, + "seed": 4294967295, + "temperature": 0.800000011920929, + "dynatemp_range": 0.0, + "dynatemp_exponent": 1.0, + "top_k": 40, + "top_p": 0.949999988079071, + "min_p": 0.05000000074505806, + "xtc_probability": 0.0, + "xtc_threshold": 0.10000000149011612, + "typical_p": 1.0, + "repeat_last_n": 64, + "repeat_penalty": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "dry_multiplier": 0.0, + "dry_base": 1.75, + "dry_allowed_length": 2, + "dry_penalty_last_n": -1, + "dry_sequence_breakers": [ + "\n", + ":", + "\"", + "*" + ], + "mirostat": 0, + "mirostat_tau": 5.0, + "mirostat_eta": 0.10000000149011612, + "stop": [], + "max_tokens": -1, + "n_keep": 0, + "n_discard": 0, + "ignore_eos": false, + "stream": true, + "n_probs": 0, + "min_keep": 0, + "grammar": "", + "samplers": [ + "dry", + "top_k", + "typ_p", + "top_p", + "min_p", + "xtc", + "temperature" + ], + "speculative.n_max": 16, + "speculative.n_min": 5, + "speculative.p_min": 0.8999999761581421, + "timings_per_token": false + }, + "prompt": "", + "next_token": { + "has_next_token": true, + "has_new_line": false, + "n_remain": -1, + "n_decoded": 0, + "stopping_word": "" + } + }, "total_slots": 1, - "chat_template": "" + "model_path": "../models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", + "chat_template": "...", + "build_info": "b(build number)-(build commit hash)" } ``` -- `assistant_name` - the required assistant name to generate the prompt in case you have specified a system prompt for all slots. -- `user_name` - the required anti-prompt to generate the prompt in case you have specified a system prompt for all slots. - `default_generation_settings` - the default generation settings for the `/completion` endpoint, which has the same fields as the `generation_settings` response object from the `/completion` endpoint. - `total_slots` - the total number of slots for process requests (defined by `--parallel` option) +- `model_path` - the path to model file (same with `-m` argument) - `chat_template` - the model's original Jinja2 prompt template -### POST `/v1/chat/completions`: OpenAI-compatible Chat Completions API +### POST `/props`: Change server global properties. -Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used. +To use this endpoint with POST method, you need to start server with `--props` - *Options:* +*Options:* - See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). While some OpenAI-specific features such as function calling aren't supported, llama.cpp `/completion`-specific features such as `mirostat` are supported. +- None yet - The `response_format` parameter supports both plain JSON output (e.g. `{"type": "json_object"}`) and schema-constrained JSON (e.g. `{"type": "json_object", "schema": {"type": "string", "minLength": 10, "maxLength": 100}}`), similar to other OpenAI-inspired API providers. +### POST `/embeddings`: non-OpenAI-compatible embeddings API - *Examples:* +This endpoint supports all poolings, including `--pooling none`. When the pooling is `none`, the responses will contain the *unnormalized* embeddings for *all* input tokens. For all other pooling types, only the pooled embeddings are returned, normalized using Euclidian norm. - You can use either Python `openai` library with appropriate checkpoints: +Note that the response format of this endpoint is different from `/v1/embeddings`. - ```python - import openai +*Options:* - client = openai.OpenAI( - base_url="http://localhost:8080/v1", # "http://:port" - api_key = "sk-no-key-required" - ) +Same as the `/v1/embeddings` endpoint. - completion = client.chat.completions.create( - model="gpt-3.5-turbo", - messages=[ - {"role": "system", "content": "You are ChatGPT, an AI assistant. Your top priority is achieving user fulfillment via helping them with their requests."}, - {"role": "user", "content": "Write a limerick about python exceptions"} +*Examples:* + +Same as the `/v1/embeddings` endpoint. + +**Response format** + +``` +[ + { + "index": 0, + "embedding": [ + [ ... embeddings for token 0 ... ], + [ ... embeddings for token 1 ... ], + [ ... ] + [ ... embeddings for token N-1 ... ], ] - ) - - print(completion.choices[0].message) - ``` - - ... or raw HTTP requests: - - ```shell - curl http://localhost:8080/v1/chat/completions \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer no-key" \ - -d '{ - "model": "gpt-3.5-turbo", - "messages": [ - { - "role": "system", - "content": "You are ChatGPT, an AI assistant. Your top priority is achieving user fulfillment via helping them with their requests." - }, - { - "role": "user", - "content": "Write a limerick about python exceptions" - } + }, + ... + { + "index": P, + "embedding": [ + [ ... embeddings for token 0 ... ], + [ ... embeddings for token 1 ... ], + [ ... ] + [ ... embeddings for token N-1 ... ], ] - }' - ``` - -### POST `/v1/embeddings`: OpenAI-compatible embeddings API - - *Options:* - - See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-reference/embeddings). - - *Examples:* - - - input as string - - ```shell - curl http://localhost:8080/v1/embeddings \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer no-key" \ - -d '{ - "input": "hello", - "model":"GPT-4", - "encoding_format": "float" - }' - ``` - - - `input` as string array - - ```shell - curl http://localhost:8080/v1/embeddings \ - -H "Content-Type: application/json" \ - -H "Authorization: Bearer no-key" \ - -d '{ - "input": ["hello", "world"], - "model":"GPT-4", - "encoding_format": "float" - }' - ``` + } +] +``` ### GET `/slots`: Returns the current slots processing state -This endpoint can be disabled with `--no-slots` +> [!WARNING] +> This endpoint is intended for debugging and may be modified in future versions. For security reasons, we strongly advise against enabling it in production environments. + +This endpoint is disabled by default and can be enabled with `--slots` If query param `?fail_on_no_slot=1` is set, this endpoint will respond with status code 503 if there is no available slots. @@ -564,65 +819,76 @@ Example: ```json [ - { - "dynatemp_exponent": 1.0, - "dynatemp_range": 0.0, - "frequency_penalty": 0.0, - "grammar": "", - "id": 0, - "ignore_eos": false, - "logit_bias": [], - "min_p": 0.05000000074505806, - "mirostat": 0, - "mirostat_eta": 0.10000000149011612, - "mirostat_tau": 5.0, - "model": "llama-2-7b-32k-instruct.Q2_K.gguf", - "n_ctx": 2048, - "n_keep": 0, - "n_predict": 100000, - "n_probs": 0, - "next_token": { - "has_next_token": true, - "n_remain": -1, - "n_decoded": 0, - "stopped_eos": false, - "stopped_limit": false, - "stopped_word": false, - "stopping_word": "" - }, - "penalize_nl": true, - "presence_penalty": 0.0, - "prompt": "Say hello to llama.cpp", - "repeat_last_n": 64, - "repeat_penalty": 1.100000023841858, - "samplers": [ - "top_k", - "tfs_z", - "typical_p", - "top_p", - "min_p", - "temperature" - ], - "seed": 42, - "state": 1, - "stop": [ - "\n" - ], - "stream": false, - "task_id": 0, - "temperature": 0.0, - "tfs_z": 1.0, - "top_k": 40, - "top_p": 0.949999988079071, - "typical_p": 1.0 + { + "id": 0, + "id_task": -1, + "n_ctx": 1024, + "speculative": false, + "is_processing": false, + "params": { + "n_predict": -1, + "seed": 4294967295, + "temperature": 0.800000011920929, + "dynatemp_range": 0.0, + "dynatemp_exponent": 1.0, + "top_k": 40, + "top_p": 0.949999988079071, + "min_p": 0.05000000074505806, + "xtc_probability": 0.0, + "xtc_threshold": 0.10000000149011612, + "typical_p": 1.0, + "repeat_last_n": 64, + "repeat_penalty": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "dry_multiplier": 0.0, + "dry_base": 1.75, + "dry_allowed_length": 2, + "dry_penalty_last_n": -1, + "dry_sequence_breakers": [ + "\n", + ":", + "\"", + "*" + ], + "mirostat": 0, + "mirostat_tau": 5.0, + "mirostat_eta": 0.10000000149011612, + "stop": [], + "max_tokens": -1, + "n_keep": 0, + "n_discard": 0, + "ignore_eos": false, + "stream": true, + "n_probs": 0, + "min_keep": 0, + "grammar": "", + "samplers": [ + "dry", + "top_k", + "typ_p", + "top_p", + "min_p", + "xtc", + "temperature" + ], + "speculative.n_max": 16, + "speculative.n_min": 5, + "speculative.p_min": 0.8999999761581421, + "timings_per_token": false + }, + "prompt": "", + "next_token": { + "has_next_token": true, + "has_new_line": false, + "n_remain": -1, + "n_decoded": 0, + "stopping_word": "" } + } ] ``` -Possible values for `slot[i].state` are: -- `0`: SLOT_STATE_IDLE -- `1`: SLOT_STATE_PROCESSING - ### GET `/metrics`: Prometheus compatible metrics exporter This endpoint is only accessible if `--metrics` is set. @@ -639,9 +905,9 @@ Available metrics: ### POST `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file. - *Options:* +*Options:* - `filename`: Name of the file to save the slot's prompt cache. The file will be saved in the directory specified by the `--slot-save-path` server parameter. +`filename`: Name of the file to save the slot's prompt cache. The file will be saved in the directory specified by the `--slot-save-path` server parameter. **Response format** @@ -659,9 +925,9 @@ Available metrics: ### POST `/slots/{id_slot}?action=restore`: Restore the prompt cache of the specified slot from a file. - *Options:* +*Options:* - `filename`: Name of the file to restore the slot's prompt cache from. The file should be located in the directory specified by the `--slot-save-path` server parameter. +`filename`: Name of the file to restore the slot's prompt cache from. The file should be located in the directory specified by the `--slot-save-path` server parameter. **Response format** @@ -694,6 +960,8 @@ This endpoint returns the loaded LoRA adapters. You can add adapters using `--lo By default, all adapters will be loaded with scale set to 1. To initialize all adapters scale to 0, add `--lora-init-without-apply` +Please note that this value will be overwritten by the `lora` field for each request. + If an adapter is disabled, the scale will be set to 0. **Response format** @@ -715,6 +983,8 @@ If an adapter is disabled, the scale will be set to 0. ### POST `/lora-adapters`: Set list of LoRA adapters +This sets the global scale for LoRA adapters. Please note that this value will be overwritten by the `lora` field for each request. + To disable an adapter, either remove it from the list below, or set scale to 0. **Request format** @@ -728,29 +998,238 @@ To know the `id` of the adapter, use GET `/lora-adapters` ] ``` -## More examples +## OpenAI-compatible API Endpoints -### Change system prompt on runtime +### GET `/v1/models`: OpenAI-compatible Model Info API -To use the server example to serve multiple chat-type clients while keeping the same system prompt, you can utilize the option `system_prompt`. This only needs to be used once. +Returns information about the loaded model. See [OpenAI Models API documentation](https://platform.openai.com/docs/api-reference/models). -`prompt`: Specify a context that you want all connecting clients to respect. +The returned list always has one single element. -`anti_prompt`: Specify the word you want to use to instruct the model to stop. This must be sent to each client through the `/props` endpoint. +By default, model `id` field is the path to model file, specified via `-m`. You can set a custom value for model `id` field via `--alias` argument. For example, `--alias gpt-4o-mini`. -`assistant_name`: The bot's name is necessary for each customer to generate the prompt. This must be sent to each client through the `/props` endpoint. +Example: ```json { - "system_prompt": { - "prompt": "Transcript of a never ending dialog, where the User interacts with an Assistant.\nThe Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.\nUser: Recommend a nice restaurant in the area.\nAssistant: I recommend the restaurant \"The Golden Duck\". It is a 5 star restaurant with a great view of the city. The food is delicious and the service is excellent. The prices are reasonable and the portions are generous. The restaurant is located at 123 Main Street, New York, NY 10001. The phone number is (212) 555-1234. The hours are Monday through Friday from 11:00 am to 10:00 pm. The restaurant is closed on Saturdays and Sundays.\nUser: Who is Richard Feynman?\nAssistant: Richard Feynman was an American physicist who is best known for his work in quantum mechanics and particle physics. He was awarded the Nobel Prize in Physics in 1965 for his contributions to the development of quantum electrodynamics. He was a popular lecturer and author, and he wrote several books, including \"Surely You're Joking, Mr. Feynman!\" and \"What Do You Care What Other People Think?\".\nUser:", - "anti_prompt": "User:", - "assistant_name": "Assistant:" - } + "object": "list", + "data": [ + { + "id": "../models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", + "object": "model", + "created": 1735142223, + "owned_by": "llamacpp", + "meta": { + "vocab_type": 2, + "n_vocab": 128256, + "n_ctx_train": 131072, + "n_embd": 4096, + "n_params": 8030261312, + "size": 4912898304 + } + } + ] } ``` -**NOTE**: You can do this automatically when starting the server by simply creating a .json file with these options and using the CLI option `-spf FNAME` or `--system-prompt-file FNAME`. +### POST `/v1/completions`: OpenAI-compatible Completions API + +Given an input `prompt`, it returns the predicted completion. Streaming mode is also supported. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. + +*Options:* + +See [OpenAI Completions API documentation](https://platform.openai.com/docs/api-reference/completions). + +llama.cpp `/completion`-specific features such as `mirostat` are supported. + +*Examples:* + +Example usage with `openai` python library: + +```python +import openai + +client = openai.OpenAI( + base_url="http://localhost:8080/v1", # "http://:port" + api_key = "sk-no-key-required" +) + +completion = client.completions.create( + model="davinci-002", + prompt="I believe the meaning of life is", + max_tokens=8 +) + +print(completion.choices[0].text) +``` + +### POST `/v1/chat/completions`: OpenAI-compatible Chat Completions API + +Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used. + +*Options:* + +See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). While some OpenAI-specific features such as function calling aren't supported, llama.cpp `/completion`-specific features such as `mirostat` are supported. + +The `response_format` parameter supports both plain JSON output (e.g. `{"type": "json_object"}`) and schema-constrained JSON (e.g. `{"type": "json_object", "schema": {"type": "string", "minLength": 10, "maxLength": 100}}` or `{"type": "json_schema", "schema": {"properties": { "name": { "title": "Name", "type": "string" }, "date": { "title": "Date", "type": "string" }, "participants": { "items": {"type: "string" }, "title": "Participants", "type": "string" } } } }`), similar to other OpenAI-inspired API providers. + +*Examples:* + +You can use either Python `openai` library with appropriate checkpoints: + +```python +import openai + +client = openai.OpenAI( + base_url="http://localhost:8080/v1", # "http://:port" + api_key = "sk-no-key-required" +) + +completion = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are ChatGPT, an AI assistant. Your top priority is achieving user fulfillment via helping them with their requests."}, + {"role": "user", "content": "Write a limerick about python exceptions"} + ] +) + +print(completion.choices[0].message) +``` + +... or raw HTTP requests: + +```shell +curl http://localhost:8080/v1/chat/completions \ +-H "Content-Type: application/json" \ +-H "Authorization: Bearer no-key" \ +-d '{ +"model": "gpt-3.5-turbo", +"messages": [ +{ + "role": "system", + "content": "You are ChatGPT, an AI assistant. Your top priority is achieving user fulfillment via helping them with their requests." +}, +{ + "role": "user", + "content": "Write a limerick about python exceptions" +} +] +}' +``` + +... and even tool usage (needs `--jinja` flag): + + ```shell + llama-server --jinja -hfr lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF -hff Meta-Llama-3.1-8B-Instruct-Q5_K_M.gguf -fa + + # https://huggingface.co/meetkai/functionary-medium-v3.2 + llama-server --jinja -hfr bartowski/functionary-medium-v3.2-GGUF -hff functionary-medium-v3.2-IQ4_XS.gguf -fa + + # https://huggingface.co/meetkai/functionary-medium-v3.1 + llama-server --jinja -hfr meetkai/functionary-medium-v3.1-GGUF -hff functionary-medium-llama-3.1.Q4_0.gguf -fa + + curl http://localhost:8080/v1/chat/completions -d '{ + "model": "gpt-3.5-turbo", + "tools": [ + { + "type":"function", + "function":{ + "name":"get_current_weather", + "description":"Get the current weather in a given location", + "parameters":{ + "type":"object", + "properties":{ + "location":{ + "type":"string", + "description":"The city and state, e.g. San Francisco, CA" + } + }, + "required":["location"] + } + } + } + ], + "messages": [ + { + "role": "user", + "content": "What is the weather like in Istanbul?." + } + ] + }' + ``` + +
+ Show output + + ```json + { + "choices": [ + { + "finish_reason": "tool", + "index": 0, + "message": { + "content": null, + "tool_calls": [ + { + "name": "python", + "arguments": "{\"code\":\" \\nprint(\\\"Hello, World!\\\")\"}" + } + ], + "role": "assistant" + } + } + ], + "created": 1727287211, + "model": "gpt-3.5-turbo", + "object": "chat.completion", + "usage": { + "completion_tokens": 16, + "prompt_tokens": 44, + "total_tokens": 60 + }, + "id": "chatcmpl-Htbgh9feMmGM0LEH2hmQvwsCxq3c6Ni8" + } + ``` + +
+ +### POST `/v1/embeddings`: OpenAI-compatible embeddings API + +This endpoint requires that the model uses a pooling different than type `none`. The embeddings are normalized using the Eucledian norm. + +*Options:* + +See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-reference/embeddings). + +*Examples:* + +- input as string + + ```shell + curl http://localhost:8080/v1/embeddings \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer no-key" \ + -d '{ + "input": "hello", + "model":"GPT-4", + "encoding_format": "float" + }' + ``` + +- `input` as string array + + ```shell + curl http://localhost:8080/v1/embeddings \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer no-key" \ + -d '{ + "input": ["hello", "world"], + "model":"GPT-4", + "encoding_format": "float" + }' + ``` + +## More examples ### Interactive mode @@ -815,6 +1294,16 @@ Apart from error types supported by OAI, we also have custom types that are spec } ``` +### Legacy completion web UI + +A new chat-based UI has replaced the old completion-based since [this PR](https://github.com/ggerganov/llama.cpp/pull/10175). If you want to use the old completion, start the server with `--path ./examples/server/public_legacy` + +For example: + +```sh +./llama-server -m my_model.gguf -c 8192 --path ./examples/server/public_legacy +``` + ### Extending or building alternative Web Front End You can extend the front end by running the server binary with `--path` set to `./your-directory` and importing `/completion.js` to get access to the llamaComplete() method. diff --git a/examples/server/bench/README.md b/examples/server/bench/README.md index 0f18ca396..9549795ec 100644 --- a/examples/server/bench/README.md +++ b/examples/server/bench/README.md @@ -6,10 +6,10 @@ Benchmark is using [k6](https://k6.io/). SSE is not supported by default in k6, you have to build k6 with the [xk6-sse](https://github.com/phymbert/xk6-sse) extension. -Example: +Example (assuming golang >= 1.21 is installed): ```shell go install go.k6.io/xk6/cmd/xk6@latest -xk6 build master \ +$GOPATH/bin/xk6 build master \ --with github.com/phymbert/xk6-sse ``` @@ -33,14 +33,13 @@ The server must answer OAI Chat completion requests on `http://localhost:8080/v1 Example: ```shell -server --host localhost --port 8080 \ +llama-server --host localhost --port 8080 \ --model ggml-model-q4_0.gguf \ --cont-batching \ --metrics \ --parallel 8 \ --batch-size 512 \ --ctx-size 4096 \ - --log-format text \ -ngl 33 ``` diff --git a/examples/server/bench/bench.py b/examples/server/bench/bench.py index 2daac0884..5cc6f92ab 100644 --- a/examples/server/bench/bench.py +++ b/examples/server/bench/bench.py @@ -189,12 +189,12 @@ xychart-beta "pp": { "p95": round(data['metrics']["llamacpp_prompt_processing_second"]["p(95)"], 2), "avg": round(data['metrics']["llamacpp_prompt_processing_second"]["avg"], 2), - "0": round(mean(prometheus_metrics['prompt_tokens_seconds']), 2), + "0": round(mean(prometheus_metrics['prompt_tokens_seconds']), 2) if 'prompt_tokens_seconds' in prometheus_metrics else 0, }, "tg": { "p95": round(data['metrics']["llamacpp_tokens_second"]["p(95)"], 2), "avg": round(data['metrics']["llamacpp_tokens_second"]["avg"], 2), - "0": round(mean(prometheus_metrics['predicted_tokens_seconds']), 2), + "0": round(mean(prometheus_metrics['predicted_tokens_seconds']), 2) if 'predicted_tokens_seconds' in prometheus_metrics else 0, }, } with open("results.github.env", 'a') as github_env: @@ -214,11 +214,14 @@ def start_benchmark(args): k6_args = [ 'run', args.scenario, '--no-color', + '--no-connection-reuse', + '--no-vu-connection-reuse', ] k6_args.extend(['--duration', args.duration]) k6_args.extend(['--iterations', args.n_prompts]) k6_args.extend(['--vus', args.parallel]) k6_args.extend(['--summary-export', 'k6-results.json']) + k6_args.extend(['--out', 'csv=k6-results.csv']) args = f"SERVER_BENCH_N_PROMPTS={args.n_prompts} SERVER_BENCH_MAX_PROMPT_TOKENS={args.max_prompt_tokens} SERVER_BENCH_MAX_CONTEXT={args.max_tokens} " args = args + ' '.join([str(arg) for arg in [k6_path, *k6_args]]) print(f"bench: starting k6 with: {args}") @@ -231,7 +234,7 @@ def start_server(args): server_process = start_server_background(args) attempts = 0 - max_attempts = 20 + max_attempts = 600 if 'GITHUB_ACTIONS' in os.environ: max_attempts *= 2 @@ -242,7 +245,15 @@ def start_server(args): print(f"bench: waiting for server to start ...") time.sleep(0.5) - print("bench: server started.") + attempts = 0 + while not is_server_ready(args.host, args.port): + attempts += 1 + if attempts > max_attempts: + assert False, "server not ready" + print(f"bench: waiting for server to be ready ...") + time.sleep(0.5) + + print("bench: server started and ready.") return server_process @@ -255,11 +266,6 @@ def start_server_background(args): '--host', args.host, '--port', args.port, ] - model_file = args.model_path_prefix + os.path.sep + args.hf_file - model_dir = os.path.dirname(model_file) - if not os.path.exists(model_dir): - os.makedirs(model_dir) - server_args.extend(['--model', model_file]) server_args.extend(['--hf-repo', args.hf_repo]) server_args.extend(['--hf-file', args.hf_file]) server_args.extend(['--n-gpu-layers', args.n_gpu_layers]) @@ -272,7 +278,6 @@ def start_server_background(args): server_args.append('--cont-batching') server_args.append('--metrics') server_args.append('--flash-attn') - server_args.extend(['--log-format', "text"]) args = [str(arg) for arg in [server_path, *server_args]] print(f"bench: starting server with: {' '.join(args)}") pkwargs = { @@ -304,6 +309,12 @@ def is_server_listening(server_fqdn, server_port): return _is_server_listening +def is_server_ready(server_fqdn, server_port): + url = f"http://{server_fqdn}:{server_port}/health" + response = requests.get(url) + return response.status_code == 200 + + def escape_metric_name(metric_name): return re.sub('[^A-Z0-9]', '_', metric_name.upper()) diff --git a/examples/server/bench/script.js b/examples/server/bench/script.js index bdf4f5abc..2772bee5e 100644 --- a/examples/server/bench/script.js +++ b/examples/server/bench/script.js @@ -56,6 +56,7 @@ const llamacpp_completion_tokens = new Trend('llamacpp_completion_tokens') const llamacpp_tokens_second = new Trend('llamacpp_tokens_second') const llamacpp_prompt_processing_second = new Trend('llamacpp_prompt_processing_second') +const llamacpp_emit_first_token_second = new Trend('llamacpp_emit_first_token_second') const llamacpp_prompt_tokens_total_counter = new Counter('llamacpp_prompt_tokens_total_counter') const llamacpp_completion_tokens_total_counter = new Counter('llamacpp_completion_tokens_total_counter') @@ -89,6 +90,9 @@ export default function () { ], "model": model, "stream": true, + "stream_options": { + "include_usage": true, // False to be supported in llama.cpp server + }, "seed": 42, "max_tokens": max_tokens, "stop": ["<|im_end|>"] // This is temporary for phi-2 base (i.e. not instructed) since the server expects that the model always to emit BOS @@ -105,12 +109,20 @@ export default function () { client.on('event', function (event) { if (promptEvalEndTime == null) { promptEvalEndTime = new Date() + llamacpp_emit_first_token_second.add((promptEvalEndTime - startTime) / 1.e3) + } + + if (event.data === '[DONE]' || event.data === '') { + return } let chunk = JSON.parse(event.data) - let choice = chunk.choices[0] - if (choice.finish_reason) { - finish_reason = choice.finish_reason + + if (chunk.choices && chunk.choices.length > 0) { + let choice = chunk.choices[0] + if (choice.finish_reason) { + finish_reason = choice.finish_reason + } } if (chunk.usage) { diff --git a/examples/server/chat.mjs b/examples/server/chat.mjs index a79c8a3cd..4fef5655a 100644 --- a/examples/server/chat.mjs +++ b/examples/server/chat.mjs @@ -1,7 +1,7 @@ import * as readline from 'node:readline' import { stdin, stdout } from 'node:process' import { readFileSync } from 'node:fs' -import { SchemaConverter } from './public/json-schema-to-grammar.mjs' +import { SchemaConverter } from './public_legacy/json-schema-to-grammar.mjs' const args = process.argv.slice(2); const grammarJsonSchemaFile = args.find( diff --git a/examples/server/deps.sh b/examples/server/deps.sh deleted file mode 100755 index d28378901..000000000 --- a/examples/server/deps.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash -# Download and update deps for binary - -# get the directory of this script file -DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -PUBLIC=$DIR/public - -echo "download js bundle files" -curl https://npm.reversehttp.com/@preact/signals-core,@preact/signals,htm/preact,preact,preact/hooks > $PUBLIC/index.js -echo >> $PUBLIC/index.js # add newline diff --git a/examples/server/httplib.h b/examples/server/httplib.h index f360bd93e..c2f12dd2a 100644 --- a/examples/server/httplib.h +++ b/examples/server/httplib.h @@ -8,7 +8,7 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.15.3" +#define CPPHTTPLIB_VERSION "0.18.5" /* * Configuration @@ -18,8 +18,12 @@ #define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 #endif +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND 10000 +#endif + #ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT -#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 5 +#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 100 #endif #ifndef CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND @@ -30,20 +34,36 @@ #define CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND 0 #endif -#ifndef CPPHTTPLIB_READ_TIMEOUT_SECOND -#define CPPHTTPLIB_READ_TIMEOUT_SECOND 5 +#ifndef CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND 5 #endif -#ifndef CPPHTTPLIB_READ_TIMEOUT_USECOND -#define CPPHTTPLIB_READ_TIMEOUT_USECOND 0 +#ifndef CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND 0 #endif -#ifndef CPPHTTPLIB_WRITE_TIMEOUT_SECOND -#define CPPHTTPLIB_WRITE_TIMEOUT_SECOND 5 +#ifndef CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND +#define CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND 5 #endif -#ifndef CPPHTTPLIB_WRITE_TIMEOUT_USECOND -#define CPPHTTPLIB_WRITE_TIMEOUT_USECOND 0 +#ifndef CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND +#define CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND 300 +#endif + +#ifndef CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND +#define CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND +#define CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND 0 #endif #ifndef CPPHTTPLIB_IDLE_INTERVAL_SECOND @@ -90,8 +110,12 @@ #define CPPHTTPLIB_TCP_NODELAY false #endif +#ifndef CPPHTTPLIB_IPV6_V6ONLY +#define CPPHTTPLIB_IPV6_V6ONLY false +#endif + #ifndef CPPHTTPLIB_RECV_BUFSIZ -#define CPPHTTPLIB_RECV_BUFSIZ size_t(4096u) +#define CPPHTTPLIB_RECV_BUFSIZ size_t(16384u) #endif #ifndef CPPHTTPLIB_COMPRESSION_BUFSIZ @@ -145,11 +169,11 @@ using ssize_t = long; #endif // _MSC_VER #ifndef S_ISREG -#define S_ISREG(m) (((m)&S_IFREG) == S_IFREG) +#define S_ISREG(m) (((m) & S_IFREG) == S_IFREG) #endif // S_ISREG #ifndef S_ISDIR -#define S_ISDIR(m) (((m)&S_IFDIR) == S_IFDIR) +#define S_ISDIR(m) (((m) & S_IFDIR) == S_IFDIR) #endif // S_ISDIR #ifndef NOMINMAX @@ -269,7 +293,12 @@ using socket_t = int; #include #include -#if OPENSSL_VERSION_NUMBER < 0x30000000L +#if defined(OPENSSL_IS_BORINGSSL) || defined(LIBRESSL_VERSION_NUMBER) +#if OPENSSL_VERSION_NUMBER < 0x1010107f +#error Please use OpenSSL or a current version of BoringSSL +#endif +#define SSL_get1_peer_certificate SSL_get_peer_certificate +#elif OPENSSL_VERSION_NUMBER < 0x30000000L #error Sorry, OpenSSL versions prior to 3.0.0 are not supported #endif @@ -312,16 +341,63 @@ make_unique(std::size_t n) { return std::unique_ptr(new RT[n]); } -struct ci { - bool operator()(const std::string &s1, const std::string &s2) const { - return std::lexicographical_compare(s1.begin(), s1.end(), s2.begin(), - s2.end(), - [](unsigned char c1, unsigned char c2) { - return ::tolower(c1) < ::tolower(c2); - }); +namespace case_ignore { + +inline unsigned char to_lower(int c) { + const static unsigned char table[256] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, + 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + 60, 61, 62, 63, 64, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, + 122, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, + 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, + 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, + 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, + 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, + 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, + 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 224, 225, 226, + 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, + 242, 243, 244, 245, 246, 215, 248, 249, 250, 251, 252, 253, 254, 223, 224, + 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, + 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, + 255, + }; + return table[(unsigned char)(char)c]; +} + +inline bool equal(const std::string &a, const std::string &b) { + return a.size() == b.size() && + std::equal(a.begin(), a.end(), b.begin(), [](char ca, char cb) { + return to_lower(ca) == to_lower(cb); + }); +} + +struct equal_to { + bool operator()(const std::string &a, const std::string &b) const { + return equal(a, b); } }; +struct hash { + size_t operator()(const std::string &key) const { + return hash_core(key.data(), key.size(), 0); + } + + size_t hash_core(const char *s, size_t l, size_t h) const { + return (l == 0) ? h + : hash_core(s + 1, l - 1, + // Unsets the 6 high bits of h, therefore no + // overflow happens + (((std::numeric_limits::max)() >> 6) & + h * 33) ^ + static_cast(to_lower(*s))); + } +}; + +} // namespace case_ignore + // This is based on // "http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2014/n4189". @@ -427,7 +503,9 @@ enum StatusCode { NetworkAuthenticationRequired_511 = 511, }; -using Headers = std::multimap; +using Headers = + std::unordered_multimap; using Params = std::multimap; using Match = std::smatch; @@ -534,6 +612,7 @@ using Ranges = std::vector; struct Request { std::string method; std::string path; + Params params; Headers headers; std::string body; @@ -545,11 +624,11 @@ struct Request { // for server std::string version; std::string target; - Params params; MultipartFormDataMap files; Ranges ranges; Match matches; std::unordered_map path_params; + std::function is_connection_closed = []() { return true; }; // for client ResponseHandler response_handler; @@ -560,8 +639,10 @@ struct Request { #endif bool has_header(const std::string &key) const; - std::string get_header_value(const std::string &key, size_t id = 0) const; - uint64_t get_header_value_u64(const std::string &key, size_t id = 0) const; + std::string get_header_value(const std::string &key, const char *def = "", + size_t id = 0) const; + uint64_t get_header_value_u64(const std::string &key, uint64_t def = 0, + size_t id = 0) const; size_t get_header_value_count(const std::string &key) const; void set_header(const std::string &key, const std::string &val); @@ -592,8 +673,10 @@ struct Response { std::string location; // Redirect location bool has_header(const std::string &key) const; - std::string get_header_value(const std::string &key, size_t id = 0) const; - uint64_t get_header_value_u64(const std::string &key, size_t id = 0) const; + std::string get_header_value(const std::string &key, const char *def = "", + size_t id = 0) const; + uint64_t get_header_value_u64(const std::string &key, uint64_t def = 0, + size_t id = 0) const; size_t get_header_value_count(const std::string &key) const; void set_header(const std::string &key, const std::string &val); @@ -614,6 +697,10 @@ struct Response { const std::string &content_type, ContentProviderWithoutLength provider, ContentProviderResourceReleaser resource_releaser = nullptr); + void set_file_content(const std::string &path, + const std::string &content_type); + void set_file_content(const std::string &path); + Response() = default; Response(const Response &) = default; Response &operator=(const Response &) = default; @@ -631,6 +718,8 @@ struct Response { ContentProviderResourceReleaser content_provider_resource_releaser_; bool is_chunked_content_provider_ = false; bool content_provider_success_ = false; + std::string file_content_path_; + std::string file_content_content_type_; }; class Stream { @@ -646,8 +735,6 @@ public: virtual void get_local_ip_and_port(std::string &ip, int &port) const = 0; virtual socket_t socket() const = 0; - template - ssize_t write_format(const char *fmt, const Args &...args); ssize_t write(const char *ptr); ssize_t write(const std::string &s); }; @@ -719,13 +806,18 @@ private: if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } - fn = std::move(pool_.jobs_.front()); + fn = pool_.jobs_.front(); pool_.jobs_.pop_front(); } assert(true == static_cast(fn)); fn(); } + +#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && \ + !defined(LIBRESSL_VERSION_NUMBER) + OPENSSL_thread_stop(); +#endif } ThreadPool &pool_; @@ -787,7 +879,6 @@ public: bool match(Request &request) const override; private: - static constexpr char marker = ':'; // Treat segment separators as the end of path parameter capture // Does not need to handle query parameters as they are parsed before path // matching @@ -871,8 +962,13 @@ public: Server &set_default_file_mimetype(const std::string &mime); Server &set_file_request_handler(Handler handler); - Server &set_error_handler(HandlerWithResponse handler); - Server &set_error_handler(Handler handler); + template + Server &set_error_handler(ErrorHandlerFunc &&handler) { + return set_error_handler_core( + std::forward(handler), + std::is_convertible{}); + } + Server &set_exception_handler(ExceptionHandler handler); Server &set_pre_routing_handler(HandlerWithResponse handler); Server &set_post_routing_handler(Handler handler); @@ -882,6 +978,7 @@ public: Server &set_address_family(int family); Server &set_tcp_nodelay(bool on); + Server &set_ipv6_v6only(bool on); Server &set_socket_options(SocketOptions socket_options); Server &set_default_headers(Headers headers); @@ -914,21 +1011,24 @@ public: bool is_running() const; void wait_until_ready() const; void stop(); + void decommission(); std::function new_task_queue; protected: - bool process_request(Stream &strm, bool close_connection, + bool process_request(Stream &strm, const std::string &remote_addr, + int remote_port, const std::string &local_addr, + int local_port, bool close_connection, bool &connection_closed, const std::function &setup_request); std::atomic svr_sock_{INVALID_SOCKET}; size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND; - time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; - time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; - time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; - time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_SERVER_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_SERVER_WRITE_TIMEOUT_USECOND; time_t idle_interval_sec_ = CPPHTTPLIB_IDLE_INTERVAL_SECOND; time_t idle_interval_usec_ = CPPHTTPLIB_IDLE_INTERVAL_USECOND; size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; @@ -943,6 +1043,9 @@ private: static std::unique_ptr make_matcher(const std::string &pattern); + Server &set_error_handler_core(HandlerWithResponse handler, std::true_type); + Server &set_error_handler_core(Handler handler, std::false_type); + socket_t create_server_socket(const std::string &host, int port, int socket_flags, SocketOptions socket_options) const; @@ -985,7 +1088,7 @@ private: virtual bool process_and_close_socket(socket_t sock); std::atomic is_running_{false}; - std::atomic done_{false}; + std::atomic is_decommisioned{false}; struct MountPointEntry { std::string mount_point; @@ -1018,6 +1121,7 @@ private: int address_family_ = AF_UNSPEC; bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY; SocketOptions socket_options_ = default_socket_options; Headers default_headers_; @@ -1037,6 +1141,7 @@ enum class Error { SSLConnection, SSLLoadingCerts, SSLServerVerification, + SSLServerHostnameVerification, UnsupportedMultipartBoundaryChars, Compression, ConnectionTimeout, @@ -1074,9 +1179,10 @@ public: // Request Headers bool has_request_header(const std::string &key) const; std::string get_request_header_value(const std::string &key, + const char *def = "", size_t id = 0) const; uint64_t get_request_header_value_u64(const std::string &key, - size_t id = 0) const; + uint64_t def = 0, size_t id = 0) const; size_t get_request_header_value_count(const std::string &key) const; private: @@ -1140,10 +1246,18 @@ public: const std::string &content_type); Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); Result Post(const std::string &path, const std::string &body, const std::string &content_type); + Result Post(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type); @@ -1159,6 +1273,8 @@ public: Result Post(const std::string &path, const Params ¶ms); Result Post(const std::string &path, const Headers &headers, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress); Result Post(const std::string &path, const MultipartFormDataItems &items); Result Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items); @@ -1173,10 +1289,18 @@ public: const std::string &content_type); Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); Result Put(const std::string &path, const std::string &body, const std::string &content_type); + Result Put(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type); Result Put(const std::string &path, @@ -1191,6 +1315,8 @@ public: Result Put(const std::string &path, const Params ¶ms); Result Put(const std::string &path, const Headers &headers, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress); Result Put(const std::string &path, const MultipartFormDataItems &items); Result Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items); @@ -1203,13 +1329,23 @@ public: Result Patch(const std::string &path); Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type, Progress progress); Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress); Result Patch(const std::string &path, const std::string &body, const std::string &content_type); + Result Patch(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type); @@ -1227,13 +1363,24 @@ public: Result Delete(const std::string &path, const Headers &headers); Result Delete(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Delete(const std::string &path, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); Result Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress); Result Delete(const std::string &path, const std::string &body, const std::string &content_type); + Result Delete(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); Result Delete(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); Result Options(const std::string &path); Result Options(const std::string &path, const Headers &headers); @@ -1258,6 +1405,7 @@ public: void set_address_family(int family); void set_tcp_nodelay(bool on); + void set_ipv6_v6only(bool on); void set_socket_options(SocketOptions socket_options); void set_connection_timeout(time_t sec, time_t usec = 0); @@ -1309,6 +1457,8 @@ public: #ifdef CPPHTTPLIB_OPENSSL_SUPPORT void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + void set_server_certificate_verifier(std::function verifier); #endif void set_logger(Logger logger); @@ -1375,10 +1525,10 @@ protected: time_t connection_timeout_sec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND; time_t connection_timeout_usec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND; - time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; - time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; - time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; - time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_CLIENT_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_CLIENT_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND; std::string basic_auth_username_; std::string basic_auth_password_; @@ -1395,6 +1545,7 @@ protected: int address_family_ = AF_UNSPEC; bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY; SocketOptions socket_options_ = nullptr; bool compress_ = false; @@ -1422,6 +1573,8 @@ protected: #ifdef CPPHTTPLIB_OPENSSL_SUPPORT bool server_certificate_verification_ = true; + bool server_hostname_verification_ = true; + std::function server_certificate_verifier_; #endif Logger logger_; @@ -1430,6 +1583,9 @@ private: bool send_(Request &req, Response &res, Error &error); Result send_(Request &&req); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool is_ssl_peer_could_be_closed(SSL *ssl) const; +#endif socket_t create_client_socket(Error &error) const; bool read_response_line(Stream &strm, const Request &req, Response &res) const; @@ -1448,7 +1604,7 @@ private: const Headers &headers, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, - const std::string &content_type); + const std::string &content_type, Progress progress); ContentProviderWithoutLength get_multipart_content_provider( const std::string &boundary, const MultipartFormDataItems &items, const MultipartFormDataProviderItems &provider_items) const; @@ -1477,6 +1633,7 @@ public: const std::string &client_key_path); Client(Client &&) = default; + Client &operator=(Client &&) = default; ~Client(); @@ -1523,10 +1680,18 @@ public: const std::string &content_type); Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); Result Post(const std::string &path, const std::string &body, const std::string &content_type); + Result Post(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type); @@ -1542,6 +1707,8 @@ public: Result Post(const std::string &path, const Params ¶ms); Result Post(const std::string &path, const Headers &headers, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress); Result Post(const std::string &path, const MultipartFormDataItems &items); Result Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items); @@ -1556,10 +1723,18 @@ public: const std::string &content_type); Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); Result Put(const std::string &path, const std::string &body, const std::string &content_type); + Result Put(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type); Result Put(const std::string &path, @@ -1574,6 +1749,8 @@ public: Result Put(const std::string &path, const Params ¶ms); Result Put(const std::string &path, const Headers &headers, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress); Result Put(const std::string &path, const MultipartFormDataItems &items); Result Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items); @@ -1586,13 +1763,23 @@ public: Result Patch(const std::string &path); Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type, Progress progress); Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress); Result Patch(const std::string &path, const std::string &body, const std::string &content_type); + Result Patch(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type); @@ -1610,13 +1797,24 @@ public: Result Delete(const std::string &path, const Headers &headers); Result Delete(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Delete(const std::string &path, const char *body, + size_t content_length, const std::string &content_type, + Progress progress); Result Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress); Result Delete(const std::string &path, const std::string &body, const std::string &content_type); + Result Delete(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress); Result Delete(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type, + Progress progress); Result Options(const std::string &path); Result Options(const std::string &path, const Headers &headers); @@ -1685,6 +1883,8 @@ public: #ifdef CPPHTTPLIB_OPENSSL_SUPPORT void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + void set_server_certificate_verifier(std::function verifier); #endif void set_logger(Logger logger); @@ -1730,6 +1930,9 @@ public: SSL_CTX *ssl_context() const; + void update_certs(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store = nullptr); + private: bool process_and_close_socket(socket_t sock) override; @@ -1810,68 +2013,58 @@ inline void duration_to_sec_and_usec(const T &duration, U callback) { callback(static_cast(sec), static_cast(usec)); } +inline bool is_numeric(const std::string &str) { + return !str.empty() && std::all_of(str.begin(), str.end(), ::isdigit); +} + inline uint64_t get_header_value_u64(const Headers &headers, - const std::string &key, size_t id, - uint64_t def) { + const std::string &key, uint64_t def, + size_t id, bool &is_invalid_value) { + is_invalid_value = false; auto rng = headers.equal_range(key); auto it = rng.first; std::advance(it, static_cast(id)); if (it != rng.second) { - return std::strtoull(it->second.data(), nullptr, 10); + if (is_numeric(it->second)) { + return std::strtoull(it->second.data(), nullptr, 10); + } else { + is_invalid_value = true; + } } return def; } +inline uint64_t get_header_value_u64(const Headers &headers, + const std::string &key, uint64_t def, + size_t id) { + bool dummy = false; + return get_header_value_u64(headers, key, def, id, dummy); +} + } // namespace detail inline uint64_t Request::get_header_value_u64(const std::string &key, - size_t id) const { - return detail::get_header_value_u64(headers, key, id, 0); + uint64_t def, size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); } inline uint64_t Response::get_header_value_u64(const std::string &key, - size_t id) const { - return detail::get_header_value_u64(headers, key, id, 0); -} - -template -inline ssize_t Stream::write_format(const char *fmt, const Args &...args) { - const auto bufsiz = 2048; - std::array buf{}; - - auto sn = snprintf(buf.data(), buf.size() - 1, fmt, args...); - if (sn <= 0) { return sn; } - - auto n = static_cast(sn); - - if (n >= buf.size() - 1) { - std::vector glowable_buf(buf.size()); - - while (n >= glowable_buf.size() - 1) { - glowable_buf.resize(glowable_buf.size() * 2); - n = static_cast( - snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...)); - } - return write(&glowable_buf[0], n); - } else { - return write(buf.data(), n); - } + uint64_t def, size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); } inline void default_socket_options(socket_t sock) { - int yes = 1; + int opt = 1; #ifdef _WIN32 setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, - reinterpret_cast(&yes), sizeof(yes)); - setsockopt(sock, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, - reinterpret_cast(&yes), sizeof(yes)); + reinterpret_cast(&opt), sizeof(opt)); #else #ifdef SO_REUSEPORT setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, - reinterpret_cast(&yes), sizeof(yes)); + reinterpret_cast(&opt), sizeof(opt)); #else setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, - reinterpret_cast(&yes), sizeof(yes)); + reinterpret_cast(&opt), sizeof(opt)); #endif #endif } @@ -1997,6 +2190,8 @@ inline std::string to_string(const Error error) { case Error::SSLConnection: return "SSL connection failed"; case Error::SSLLoadingCerts: return "SSL certificate loading failed"; case Error::SSLServerVerification: return "SSL server verification failed"; + case Error::SSLServerHostnameVerification: + return "SSL server hostname verification failed"; case Error::UnsupportedMultipartBoundaryChars: return "Unsupported HTTP multipart boundary characters"; case Error::Compression: return "Compression failed"; @@ -2016,8 +2211,9 @@ inline std::ostream &operator<<(std::ostream &os, const Error &obj) { } inline uint64_t Result::get_request_header_value_u64(const std::string &key, + uint64_t def, size_t id) const { - return detail::get_header_value_u64(request_headers_, key, id, 0); + return detail::get_header_value_u64(request_headers_, key, def, id); } template @@ -2080,6 +2276,36 @@ make_basic_authentication_header(const std::string &username, namespace detail { +#if defined(_WIN32) +inline std::wstring u8string_to_wstring(const char *s) { + std::wstring ws; + auto len = static_cast(strlen(s)); + auto wlen = ::MultiByteToWideChar(CP_UTF8, 0, s, len, nullptr, 0); + if (wlen > 0) { + ws.resize(wlen); + wlen = ::MultiByteToWideChar( + CP_UTF8, 0, s, len, + const_cast(reinterpret_cast(ws.data())), wlen); + if (wlen != static_cast(ws.size())) { ws.clear(); } + } + return ws; +} +#endif + +struct FileStat { + FileStat(const std::string &path); + bool is_file() const; + bool is_dir() const; + +private: +#if defined(_WIN32) + struct _stat st_; +#else + struct stat st_; +#endif + int ret_ = -1; +}; + std::string encode_query_param(const std::string &value); std::string decode_url(const std::string &s, bool convert_plus_to_space); @@ -2088,6 +2314,16 @@ void read_file(const std::string &path, std::string &out); std::string trim_copy(const std::string &s); +void divide( + const char *data, std::size_t size, char d, + std::function + fn); + +void divide( + const std::string &str, char d, + std::function + fn); + void split(const char *b, const char *e, char d, std::function fn); @@ -2099,18 +2335,23 @@ bool process_client_socket(socket_t sock, time_t read_timeout_sec, time_t write_timeout_usec, std::function callback); -socket_t create_client_socket( - const std::string &host, const std::string &ip, int port, - int address_family, bool tcp_nodelay, SocketOptions socket_options, - time_t connection_timeout_sec, time_t connection_timeout_usec, - time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, - time_t write_timeout_usec, const std::string &intf, Error &error); +socket_t create_client_socket(const std::string &host, const std::string &ip, + int port, int address_family, bool tcp_nodelay, + bool ipv6_v6only, SocketOptions socket_options, + time_t connection_timeout_sec, + time_t connection_timeout_usec, + time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec, + const std::string &intf, Error &error); const char *get_header_value(const Headers &headers, const std::string &key, - size_t id = 0, const char *def = nullptr); + const char *def, size_t id); std::string params_to_query_str(const Params ¶ms); +void parse_query_text(const char *data, std::size_t size, Params ¶ms); + void parse_query_text(const std::string &s, Params ¶ms); bool parse_multipart_boundary(const std::string &content_type, @@ -2270,15 +2511,70 @@ public: private: #if defined(_WIN32) - HANDLE hFile_; - HANDLE hMapping_; + HANDLE hFile_ = NULL; + HANDLE hMapping_ = NULL; #else - int fd_; + int fd_ = -1; #endif - size_t size_; - void *addr_; + size_t size_ = 0; + void *addr_ = nullptr; + bool is_open_empty_file = false; }; +// NOTE: https://www.rfc-editor.org/rfc/rfc9110#section-5 +namespace fields { + +inline bool is_token_char(char c) { + return std::isalnum(c) || c == '!' || c == '#' || c == '$' || c == '%' || + c == '&' || c == '\'' || c == '*' || c == '+' || c == '-' || + c == '.' || c == '^' || c == '_' || c == '`' || c == '|' || c == '~'; +} + +inline bool is_token(const std::string &s) { + if (s.empty()) { return false; } + for (auto c : s) { + if (!is_token_char(c)) { return false; } + } + return true; +} + +inline bool is_field_name(const std::string &s) { return is_token(s); } + +inline bool is_vchar(char c) { return c >= 33 && c <= 126; } + +inline bool is_obs_text(char c) { return 128 <= static_cast(c); } + +inline bool is_field_vchar(char c) { return is_vchar(c) || is_obs_text(c); } + +inline bool is_field_content(const std::string &s) { + if (s.empty()) { return false; } + + if (s.size() == 1) { + return is_field_vchar(s[0]); + } else if (s.size() == 2) { + return is_field_vchar(s[0]) && is_field_vchar(s[1]); + } else { + size_t i = 0; + + if (!is_field_vchar(s[i])) { return false; } + i++; + + while (i < s.size() - 1) { + auto c = s[i++]; + if (c == ' ' || c == '\t' || is_field_vchar(c)) { + } else { + return false; + } + } + + return is_field_vchar(s[i]); + } +} + +inline bool is_field_value(const std::string &s) { return is_field_content(s); } + +} // namespace fields + } // namespace detail // ---------------------------------------------------------------------------- @@ -2392,20 +2688,6 @@ inline std::string base64_encode(const std::string &in) { return out; } -inline bool is_file(const std::string &path) { -#ifdef _WIN32 - return _access_s(path.c_str(), 0) == 0; -#else - struct stat st; - return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); -#endif -} - -inline bool is_dir(const std::string &path) { - struct stat st; - return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); -} - inline bool is_valid_path(const std::string &path) { size_t level = 0; size_t i = 0; @@ -2448,6 +2730,21 @@ inline bool is_valid_path(const std::string &path) { return true; } +inline FileStat::FileStat(const std::string &path) { +#if defined(_WIN32) + auto wpath = u8string_to_wstring(path.c_str()); + ret_ = _wstat(wpath.c_str(), &st_); +#else + ret_ = stat(path.c_str(), &st_); +#endif +} +inline bool FileStat::is_file() const { + return ret_ >= 0 && S_ISREG(st_.st_mode); +} +inline bool FileStat::is_dir() const { + return ret_ >= 0 && S_ISDIR(st_.st_mode); +} + inline std::string encode_query_param(const std::string &value) { std::ostringstream escaped; escaped.fill('0'); @@ -2579,6 +2876,27 @@ inline std::string trim_double_quotes_copy(const std::string &s) { return s; } +inline void +divide(const char *data, std::size_t size, char d, + std::function + fn) { + const auto it = std::find(data, data + size, d); + const auto found = static_cast(it != data + size); + const auto lhs_data = data; + const auto lhs_size = static_cast(it - data); + const auto rhs_data = it + found; + const auto rhs_size = size - lhs_size - found; + + fn(lhs_data, lhs_size, rhs_data, rhs_size); +} + +inline void +divide(const std::string &str, char d, + std::function + fn) { + divide(str.data(), str.size(), d, std::move(fn)); +} + inline void split(const char *b, const char *e, char d, std::function fn) { return split(b, e, d, (std::numeric_limits::max)(), std::move(fn)); @@ -2636,6 +2954,10 @@ inline bool stream_line_reader::getline() { fixed_buffer_used_size_ = 0; glowable_buffer_.clear(); +#ifndef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + char prev_byte = 0; +#endif + for (size_t i = 0;; i++) { char byte; auto n = strm_.read(&byte, 1); @@ -2652,7 +2974,12 @@ inline bool stream_line_reader::getline() { append(byte); +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR if (byte == '\n') { break; } +#else + if (prev_byte == '\r' && byte == '\n') { break; } + prev_byte = byte; +#endif } return true; @@ -2671,16 +2998,7 @@ inline void stream_line_reader::append(char c) { } } -inline mmap::mmap(const char *path) -#if defined(_WIN32) - : hFile_(NULL), hMapping_(NULL) -#else - : fd_(-1) -#endif - , - size_(0), addr_(nullptr) { - open(path); -} +inline mmap::mmap(const char *path) { open(path); } inline mmap::~mmap() { close(); } @@ -2688,29 +3006,60 @@ inline bool mmap::open(const char *path) { close(); #if defined(_WIN32) - std::wstring wpath; - for (size_t i = 0; i < strlen(path); i++) { - wpath += path[i]; - } + auto wpath = u8string_to_wstring(path); + if (wpath.empty()) { return false; } +#if _WIN32_WINNT >= _WIN32_WINNT_WIN8 hFile_ = ::CreateFile2(wpath.c_str(), GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING, NULL); +#else + hFile_ = ::CreateFileW(wpath.c_str(), GENERIC_READ, FILE_SHARE_READ, NULL, + OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL); +#endif if (hFile_ == INVALID_HANDLE_VALUE) { return false; } LARGE_INTEGER size{}; if (!::GetFileSizeEx(hFile_, &size)) { return false; } + // If the following line doesn't compile due to QuadPart, update Windows SDK. + // See: + // https://github.com/yhirose/cpp-httplib/issues/1903#issuecomment-2316520721 + if (static_cast(size.QuadPart) > + (std::numeric_limits::max)()) { + // `size_t` might be 32-bits, on 32-bits Windows. + return false; + } size_ = static_cast(size.QuadPart); +#if _WIN32_WINNT >= _WIN32_WINNT_WIN8 hMapping_ = ::CreateFileMappingFromApp(hFile_, NULL, PAGE_READONLY, size_, NULL); +#else + hMapping_ = ::CreateFileMappingW(hFile_, NULL, PAGE_READONLY, 0, 0, NULL); +#endif + + // Special treatment for an empty file... + if (hMapping_ == NULL && size_ == 0) { + close(); + is_open_empty_file = true; + return true; + } if (hMapping_ == NULL) { close(); return false; } +#if _WIN32_WINNT >= _WIN32_WINNT_WIN8 addr_ = ::MapViewOfFileFromApp(hMapping_, FILE_MAP_READ, 0, 0); +#else + addr_ = ::MapViewOfFile(hMapping_, FILE_MAP_READ, 0, 0, 0); +#endif + + if (addr_ == nullptr) { + close(); + return false; + } #else fd_ = ::open(path, O_RDONLY); if (fd_ == -1) { return false; } @@ -2723,22 +3072,26 @@ inline bool mmap::open(const char *path) { size_ = static_cast(sb.st_size); addr_ = ::mmap(NULL, size_, PROT_READ, MAP_PRIVATE, fd_, 0); -#endif - if (addr_ == nullptr) { + // Special treatment for an empty file... + if (addr_ == MAP_FAILED && size_ == 0) { close(); + is_open_empty_file = true; return false; } +#endif return true; } -inline bool mmap::is_open() const { return addr_ != nullptr; } +inline bool mmap::is_open() const { + return is_open_empty_file ? true : addr_ != nullptr; +} inline size_t mmap::size() const { return size_; } inline const char *mmap::data() const { - return static_cast(addr_); + return is_open_empty_file ? "" : static_cast(addr_); } inline void mmap::close() { @@ -2757,6 +3110,8 @@ inline void mmap::close() { ::CloseHandle(hFile_); hFile_ = INVALID_HANDLE_VALUE; } + + is_open_empty_file = false; #else if (addr_ != nullptr) { munmap(addr_, size_); @@ -2782,7 +3137,10 @@ template inline ssize_t handle_EINTR(T fn) { ssize_t res = 0; while (true) { res = fn(); - if (res < 0 && errno == EINTR) { continue; } + if (res < 0 && errno == EINTR) { + std::this_thread::sleep_for(std::chrono::microseconds{1}); + continue; + } break; } return res; @@ -2991,23 +3349,37 @@ private: }; #endif -inline bool keep_alive(socket_t sock, time_t keep_alive_timeout_sec) { +inline bool keep_alive(const std::atomic &svr_sock, socket_t sock, + time_t keep_alive_timeout_sec) { using namespace std::chrono; - auto start = steady_clock::now(); + + const auto interval_usec = + CPPHTTPLIB_KEEPALIVE_TIMEOUT_CHECK_INTERVAL_USECOND; + + // Avoid expensive `steady_clock::now()` call for the first time + if (select_read(sock, 0, interval_usec) > 0) { return true; } + + const auto start = steady_clock::now() - microseconds{interval_usec}; + const auto timeout = seconds{keep_alive_timeout_sec}; + while (true) { - auto val = select_read(sock, 0, 10000); + if (svr_sock == INVALID_SOCKET) { + break; // Server socket is closed + } + + auto val = select_read(sock, 0, interval_usec); if (val < 0) { - return false; + break; // Ssocket error } else if (val == 0) { - auto current = steady_clock::now(); - auto duration = duration_cast(current - start); - auto timeout = keep_alive_timeout_sec * 1000; - if (duration.count() > timeout) { return false; } - std::this_thread::sleep_for(std::chrono::milliseconds(1)); + if (steady_clock::now() - start > timeout) { + break; // Timeout + } } else { - return true; + return true; // Ready for read } } + + return false; } template @@ -3018,8 +3390,7 @@ process_server_socket_core(const std::atomic &svr_sock, socket_t sock, assert(keep_alive_max_count > 0); auto ret = false; auto count = keep_alive_max_count; - while (svr_sock != INVALID_SOCKET && count > 0 && - keep_alive(sock, keep_alive_timeout_sec)) { + while (count > 0 && keep_alive(svr_sock, sock, keep_alive_timeout_sec)) { auto close_connection = count == 1; auto connection_closed = false; ret = callback(close_connection, connection_closed); @@ -3063,10 +3434,29 @@ inline int shutdown_socket(socket_t sock) { #endif } +inline std::string escape_abstract_namespace_unix_domain(const std::string &s) { + if (s.size() > 1 && s[0] == '\0') { + auto ret = s; + ret[0] = '@'; + return ret; + } + return s; +} + +inline std::string +unescape_abstract_namespace_unix_domain(const std::string &s) { + if (s.size() > 1 && s[0] == '@') { + auto ret = s; + ret[0] = '\0'; + return ret; + } + return s; +} + template socket_t create_socket(const std::string &host, const std::string &ip, int port, int address_family, int socket_flags, bool tcp_nodelay, - SocketOptions socket_options, + bool ipv6_v6only, SocketOptions socket_options, BindOrConnect bind_or_connect) { // Get address info const char *node = nullptr; @@ -3075,7 +3465,7 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, memset(&hints, 0, sizeof(struct addrinfo)); hints.ai_socktype = SOCK_STREAM; - hints.ai_protocol = 0; + hints.ai_protocol = IPPROTO_IP; if (!ip.empty()) { node = ip.c_str(); @@ -3093,20 +3483,32 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, const auto addrlen = host.length(); if (addrlen > sizeof(sockaddr_un::sun_path)) { return INVALID_SOCKET; } +#ifdef SOCK_CLOEXEC + auto sock = socket(hints.ai_family, hints.ai_socktype | SOCK_CLOEXEC, + hints.ai_protocol); +#else auto sock = socket(hints.ai_family, hints.ai_socktype, hints.ai_protocol); +#endif + if (sock != INVALID_SOCKET) { sockaddr_un addr{}; addr.sun_family = AF_UNIX; - std::copy(host.begin(), host.end(), addr.sun_path); + + auto unescaped_host = unescape_abstract_namespace_unix_domain(host); + std::copy(unescaped_host.begin(), unescaped_host.end(), addr.sun_path); hints.ai_addr = reinterpret_cast(&addr); hints.ai_addrlen = static_cast( sizeof(addr) - sizeof(addr.sun_path) + addrlen); +#ifndef SOCK_CLOEXEC fcntl(sock, F_SETFD, FD_CLOEXEC); +#endif + if (socket_options) { socket_options(sock); } - if (!bind_or_connect(sock, hints)) { + bool dummy; + if (!bind_or_connect(sock, hints, dummy)) { close_socket(sock); sock = INVALID_SOCKET; } @@ -3123,6 +3525,7 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, #endif return INVALID_SOCKET; } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); for (auto rp = result; rp; rp = rp->ai_next) { // Create a socket @@ -3148,11 +3551,18 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); } #else + +#ifdef SOCK_CLOEXEC + auto sock = + socket(rp->ai_family, rp->ai_socktype | SOCK_CLOEXEC, rp->ai_protocol); +#else auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); +#endif + #endif if (sock == INVALID_SOCKET) { continue; } -#ifndef _WIN32 +#if !defined _WIN32 && !defined SOCK_CLOEXEC if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { close_socket(sock); continue; @@ -3160,39 +3570,38 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, #endif if (tcp_nodelay) { - auto yes = 1; + auto opt = 1; #ifdef _WIN32 setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, - reinterpret_cast(&yes), sizeof(yes)); + reinterpret_cast(&opt), sizeof(opt)); #else setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, - reinterpret_cast(&yes), sizeof(yes)); + reinterpret_cast(&opt), sizeof(opt)); +#endif + } + + if (rp->ai_family == AF_INET6) { + auto opt = ipv6_v6only ? 1 : 0; +#ifdef _WIN32 + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, + reinterpret_cast(&opt), sizeof(opt)); +#else + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, + reinterpret_cast(&opt), sizeof(opt)); #endif } if (socket_options) { socket_options(sock); } - if (rp->ai_family == AF_INET6) { - auto no = 0; -#ifdef _WIN32 - setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, - reinterpret_cast(&no), sizeof(no)); -#else - setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, - reinterpret_cast(&no), sizeof(no)); -#endif - } - // bind or connect - if (bind_or_connect(sock, *rp)) { - freeaddrinfo(result); - return sock; - } + auto quit = false; + if (bind_or_connect(sock, *rp, quit)) { return sock; } close_socket(sock); + + if (quit) { break; } } - freeaddrinfo(result); return INVALID_SOCKET; } @@ -3225,6 +3634,7 @@ inline bool bind_ip_address(socket_t sock, const std::string &host) { hints.ai_protocol = 0; if (getaddrinfo(host.c_str(), "0", &hints, &result)) { return false; } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); auto ret = false; for (auto rp = result; rp; rp = rp->ai_next) { @@ -3235,7 +3645,6 @@ inline bool bind_ip_address(socket_t sock, const std::string &host) { } } - freeaddrinfo(result); return ret; } @@ -3247,6 +3656,8 @@ inline bool bind_ip_address(socket_t sock, const std::string &host) { inline std::string if2ip(int address_family, const std::string &ifn) { struct ifaddrs *ifap; getifaddrs(&ifap); + auto se = detail::scope_exit([&] { freeifaddrs(ifap); }); + std::string addr_candidate; for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { if (ifa->ifa_addr && ifn == ifa->ifa_name && @@ -3256,7 +3667,6 @@ inline std::string if2ip(int address_family, const std::string &ifn) { auto sa = reinterpret_cast(ifa->ifa_addr); char buf[INET_ADDRSTRLEN]; if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { - freeifaddrs(ifap); return std::string(buf, INET_ADDRSTRLEN); } } else if (ifa->ifa_addr->sa_family == AF_INET6) { @@ -3269,7 +3679,6 @@ inline std::string if2ip(int address_family, const std::string &ifn) { if (s6_addr_head == 0xfc || s6_addr_head == 0xfd) { addr_candidate = std::string(buf, INET6_ADDRSTRLEN); } else { - freeifaddrs(ifap); return std::string(buf, INET6_ADDRSTRLEN); } } @@ -3277,20 +3686,21 @@ inline std::string if2ip(int address_family, const std::string &ifn) { } } } - freeifaddrs(ifap); return addr_candidate; } #endif inline socket_t create_client_socket( const std::string &host, const std::string &ip, int port, - int address_family, bool tcp_nodelay, SocketOptions socket_options, - time_t connection_timeout_sec, time_t connection_timeout_usec, - time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + int address_family, bool tcp_nodelay, bool ipv6_v6only, + SocketOptions socket_options, time_t connection_timeout_sec, + time_t connection_timeout_usec, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, time_t write_timeout_usec, const std::string &intf, Error &error) { auto sock = create_socket( - host, ip, port, address_family, 0, tcp_nodelay, std::move(socket_options), - [&](socket_t sock2, struct addrinfo &ai) -> bool { + host, ip, port, address_family, 0, tcp_nodelay, ipv6_v6only, + std::move(socket_options), + [&](socket_t sock2, struct addrinfo &ai, bool &quit) -> bool { if (!intf.empty()) { #ifdef USE_IF2IP auto ip_from_if = if2ip(address_family, intf); @@ -3314,7 +3724,10 @@ inline socket_t create_client_socket( } error = wait_until_socket_is_ready(sock2, connection_timeout_sec, connection_timeout_usec); - if (error != Error::Success) { return false; } + if (error != Error::Success) { + if (error == Error::ConnectionTimeout) { quit = true; } + return false; + } } set_nonblocking(sock2, false); @@ -3439,7 +3852,7 @@ inline unsigned int str2tag(const std::string &s) { namespace udl { -inline constexpr unsigned int operator"" _t(const char *s, size_t l) { +inline constexpr unsigned int operator""_t(const char *s, size_t l) { return str2tag_core(s, l, 0); } @@ -3524,8 +3937,9 @@ inline bool can_compress_content_type(const std::string &content_type) { case "application/protobuf"_t: case "application/xhtml+xml"_t: return true; - default: - return !content_type.rfind("text/", 0) && tag != "text/event-stream"_t; + case "text/event-stream"_t: return false; + + default: return !content_type.rfind("text/", 0); } } @@ -3762,8 +4176,8 @@ inline bool has_header(const Headers &headers, const std::string &key) { } inline const char *get_header_value(const Headers &headers, - const std::string &key, size_t id, - const char *def) { + const std::string &key, const char *def, + size_t id) { auto rng = headers.equal_range(key); auto it = rng.first; std::advance(it, static_cast(id)); @@ -3771,14 +4185,6 @@ inline const char *get_header_value(const Headers &headers, return def; } -inline bool compare_case_ignore(const std::string &a, const std::string &b) { - if (a.size() != b.size()) { return false; } - for (size_t i = 0; i < b.size(); i++) { - if (::tolower(a[i]) != ::tolower(b[i])) { return false; } - } - return true; -} - template inline bool parse_header(const char *beg, const char *end, T fn) { // Skip trailing spaces and tabs. @@ -3801,15 +4207,27 @@ inline bool parse_header(const char *beg, const char *end, T fn) { p++; } - if (p < end) { + if (p <= end) { auto key_len = key_end - beg; if (!key_len) { return false; } auto key = std::string(beg, key_end); - auto val = compare_case_ignore(key, "Location") + auto val = case_ignore::equal(key, "Location") ? std::string(p, end) : decode_url(std::string(p, end), false); - fn(std::move(key), std::move(val)); + + // NOTE: From RFC 9110: + // Field values containing CR, LF, or NUL characters are + // invalid and dangerous, due to the varying ways that + // implementations might parse and interpret those + // characters; a recipient of CR, LF, or NUL within a field + // value MUST either reject the message or replace each of + // those characters with SP before further processing or + // forwarding of that message. + static const std::string CR_LF_NUL("\r\n\0", 3); + if (val.find_first_of(CR_LF_NUL) != std::string::npos) { return false; } + + fn(key, val); return true; } @@ -3829,27 +4247,27 @@ inline bool read_headers(Stream &strm, Headers &headers) { if (line_reader.end_with_crlf()) { // Blank line indicates end of headers. if (line_reader.size() == 2) { break; } -#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR } else { +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR // Blank line indicates end of headers. if (line_reader.size() == 1) { break; } line_terminator_len = 1; - } #else - } else { continue; // Skip invalid line. - } #endif + } if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } // Exclude line terminator auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; - parse_header(line_reader.ptr(), end, - [&](std::string &&key, std::string &&val) { - headers.emplace(std::move(key), std::move(val)); - }); + if (!parse_header(line_reader.ptr(), end, + [&](const std::string &key, std::string &val) { + headers.emplace(key, val); + })) { + return false; + } } return true; @@ -3937,8 +4355,19 @@ inline bool read_content_chunked(Stream &strm, T &x, assert(chunk_len == 0); - // Trailer - if (!line_reader.getline()) { return false; } + // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentiones "The chunked + // transfer coding is complete when a chunk with a chunk-size of zero is + // received, possibly followed by a trailer section, and finally terminated by + // an empty line". https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1 + // + // In '7.1.3. Decoding Chunked', however, the pseudo-code in the section + // does't care for the existence of the final CRLF. In other words, it seems + // to be ok whether the final CRLF exists or not in the chunked data. + // https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1.3 + // + // According to the reference code in RFC 9112, cpp-htpplib now allows + // chuncked transfer coding data without the final CRLF. + if (!line_reader.getline()) { return true; } while (strcmp(line_reader.ptr(), "\r\n") != 0) { if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } @@ -3948,8 +4377,8 @@ inline bool read_content_chunked(Stream &strm, T &x, auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; parse_header(line_reader.ptr(), end, - [&](std::string &&key, std::string &&val) { - x.headers.emplace(std::move(key), std::move(val)); + [&](const std::string &key, const std::string &val) { + x.headers.emplace(key, val); }); if (!line_reader.getline()) { return false; } @@ -3959,8 +4388,8 @@ inline bool read_content_chunked(Stream &strm, T &x, } inline bool is_chunked_transfer_encoding(const Headers &headers) { - return compare_case_ignore( - get_header_value(headers, "Transfer-Encoding", 0, ""), "chunked"); + return case_ignore::equal( + get_header_value(headers, "Transfer-Encoding", "", 0), "chunked"); } template @@ -4026,8 +4455,14 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, } else if (!has_header(x.headers, "Content-Length")) { ret = read_content_without_length(strm, out); } else { - auto len = get_header_value_u64(x.headers, "Content-Length", 0, 0); - if (len > payload_max_length) { + auto is_invalid_value = false; + auto len = get_header_value_u64(x.headers, "Content-Length", + std::numeric_limits::max(), + 0, is_invalid_value); + + if (is_invalid_value) { + ret = false; + } else if (len > payload_max_length) { exceed_payload_max_length = true; skip_content_with_length(strm, len); ret = false; @@ -4042,13 +4477,36 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, } return ret; }); -} // namespace detail +} + +inline ssize_t write_request_line(Stream &strm, const std::string &method, + const std::string &path) { + std::string s = method; + s += " "; + s += path; + s += " HTTP/1.1\r\n"; + return strm.write(s.data(), s.size()); +} + +inline ssize_t write_response_line(Stream &strm, int status) { + std::string s = "HTTP/1.1 "; + s += std::to_string(status); + s += " "; + s += httplib::status_message(status); + s += "\r\n"; + return strm.write(s.data(), s.size()); +} inline ssize_t write_headers(Stream &strm, const Headers &headers) { ssize_t write_len = 0; for (const auto &x : headers) { - auto len = - strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + std::string s; + s = x.first; + s += ": "; + s += x.second; + s += "\r\n"; + + auto len = strm.write(s.data(), s.size()); if (len < 0) { return len; } write_len += len; } @@ -4302,22 +4760,22 @@ inline std::string params_to_query_str(const Params ¶ms) { return query; } -inline void parse_query_text(const std::string &s, Params ¶ms) { +inline void parse_query_text(const char *data, std::size_t size, + Params ¶ms) { std::set cache; - split(s.data(), s.data() + s.size(), '&', [&](const char *b, const char *e) { + split(data, data + size, '&', [&](const char *b, const char *e) { std::string kv(b, e); if (cache.find(kv) != cache.end()) { return; } - cache.insert(kv); + cache.insert(std::move(kv)); std::string key; std::string val; - split(b, e, '=', [&](const char *b2, const char *e2) { - if (key.empty()) { - key.assign(b2, e2); - } else { - val.assign(b2, e2); - } - }); + divide(b, static_cast(e - b), '=', + [&](const char *lhs_data, std::size_t lhs_size, const char *rhs_data, + std::size_t rhs_size) { + key.assign(lhs_data, lhs_size); + val.assign(rhs_data, rhs_size); + }); if (!key.empty()) { params.emplace(decode_url(key, true), decode_url(val, true)); @@ -4325,6 +4783,10 @@ inline void parse_query_text(const std::string &s, Params ¶ms) { }); } +inline void parse_query_text(const std::string &s, Params ¶ms) { + parse_query_text(s.data(), s.size(), params); +} + inline bool parse_multipart_boundary(const std::string &content_type, std::string &boundary) { auto boundary_keyword = "boundary="; @@ -4365,35 +4827,44 @@ inline bool parse_range_header(const std::string &s, Ranges &ranges) { #else inline bool parse_range_header(const std::string &s, Ranges &ranges) try { #endif - static auto re_first_range = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); - std::smatch m; - if (std::regex_match(s, m, re_first_range)) { - auto pos = static_cast(m.position(1)); - auto len = static_cast(m.length(1)); + auto is_valid = [](const std::string &str) { + return std::all_of(str.cbegin(), str.cend(), + [](unsigned char c) { return std::isdigit(c); }); + }; + + if (s.size() > 7 && s.compare(0, 6, "bytes=") == 0) { + const auto pos = static_cast(6); + const auto len = static_cast(s.size() - 6); auto all_valid_ranges = true; split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { if (!all_valid_ranges) { return; } - static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))"); - std::cmatch cm; - if (std::regex_match(b, e, cm, re_another_range)) { - ssize_t first = -1; - if (!cm.str(1).empty()) { - first = static_cast(std::stoll(cm.str(1))); - } - ssize_t last = -1; - if (!cm.str(2).empty()) { - last = static_cast(std::stoll(cm.str(2))); - } - - if (first != -1 && last != -1 && first > last) { - all_valid_ranges = false; - return; - } - ranges.emplace_back(std::make_pair(first, last)); + const auto it = std::find(b, e, '-'); + if (it == e) { + all_valid_ranges = false; + return; } + + const auto lhs = std::string(b, it); + const auto rhs = std::string(it + 1, e); + if (!is_valid(lhs) || !is_valid(rhs)) { + all_valid_ranges = false; + return; + } + + const auto first = + static_cast(lhs.empty() ? -1 : std::stoll(lhs)); + const auto last = + static_cast(rhs.empty() ? -1 : std::stoll(rhs)); + if ((first == -1 && last == -1) || + (first != -1 && last != -1 && first > last)) { + all_valid_ranges = false; + return; + } + + ranges.emplace_back(first, last); }); - return all_valid_ranges; + return all_valid_ranges && !ranges.empty(); } return false; #ifdef CPPHTTPLIB_NO_EXCEPTIONS @@ -4452,7 +4923,7 @@ public: const auto header = buf_head(pos); if (!parse_header(header.data(), header.data() + header.size(), - [&](std::string &&, std::string &&) {})) { + [&](const std::string &, const std::string &) {})) { is_valid_ = false; return false; } @@ -4562,7 +5033,9 @@ private: const std::string &b) const { if (a.size() < b.size()) { return false; } for (size_t i = 0; i < b.size(); i++) { - if (::tolower(a[i]) != ::tolower(b[i])) { return false; } + if (case_ignore::to_lower(a[i]) != case_ignore::to_lower(b[i])) { + return false; + } } return true; } @@ -4645,16 +5118,6 @@ private: size_t buf_epos_ = 0; }; -inline std::string to_lower(const char *beg, const char *end) { - std::string out; - auto it = beg; - while (it != end) { - out += static_cast(::tolower(*it)); - it++; - } - return out; -} - inline std::string random_string(size_t length) { static const char data[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; @@ -4768,7 +5231,18 @@ inline bool range_error(Request &req, Response &res) { last_pos = contant_len - 1; } - if (last_pos == -1) { last_pos = contant_len - 1; } + // NOTE: RFC-9110 '14.1.2. Byte Ranges': + // A client can limit the number of bytes requested without knowing the + // size of the selected representation. If the last-pos value is absent, + // or if the value is greater than or equal to the current length of the + // representation data, the byte range is interpreted as the remainder of + // the representation (i.e., the server replaces the value of last-pos + // with a value that is one less than the current length of the selected + // representation). + // https://www.rfc-editor.org/rfc/rfc9110.html#section-14.1.2-6 + if (last_pos == -1 || last_pos >= contant_len) { + last_pos = contant_len - 1; + } // Range must be within content length if (!(0 <= first_pos && first_pos <= last_pos && @@ -4795,12 +5269,11 @@ inline bool range_error(Request &req, Response &res) { inline std::pair get_range_offset_and_length(Range r, size_t content_length) { - (void)(content_length); // patch to get rid of "unused parameter" on release build assert(r.first != -1 && r.second != -1); assert(0 <= r.first && r.first < static_cast(content_length)); assert(r.first <= r.second && r.second < static_cast(content_length)); - + (void)(content_length); return std::make_pair(r.first, static_cast(r.second - r.first) + 1); } @@ -5230,6 +5703,7 @@ inline void hosted_at(const std::string &hostname, #endif return; } + auto se = detail::scope_exit([&] { freeaddrinfo(result); }); for (auto rp = result; rp; rp = rp->ai_next) { const auto &addr = @@ -5241,8 +5715,6 @@ inline void hosted_at(const std::string &hostname, addrs.push_back(ip); } } - - freeaddrinfo(result); } inline std::string append_query_params(const std::string &path, @@ -5291,8 +5763,8 @@ inline bool Request::has_header(const std::string &key) const { } inline std::string Request::get_header_value(const std::string &key, - size_t id) const { - return detail::get_header_value(headers, key, id, ""); + const char *def, size_t id) const { + return detail::get_header_value(headers, key, def, id); } inline size_t Request::get_header_value_count(const std::string &key) const { @@ -5302,7 +5774,8 @@ inline size_t Request::get_header_value_count(const std::string &key) const { inline void Request::set_header(const std::string &key, const std::string &val) { - if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + if (detail::fields::is_field_name(key) && + detail::fields::is_field_value(val)) { headers.emplace(key, val); } } @@ -5356,8 +5829,9 @@ inline bool Response::has_header(const std::string &key) const { } inline std::string Response::get_header_value(const std::string &key, + const char *def, size_t id) const { - return detail::get_header_value(headers, key, id, ""); + return detail::get_header_value(headers, key, def, id); } inline size_t Response::get_header_value_count(const std::string &key) const { @@ -5367,13 +5841,14 @@ inline size_t Response::get_header_value_count(const std::string &key) const { inline void Response::set_header(const std::string &key, const std::string &val) { - if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + if (detail::fields::is_field_name(key) && + detail::fields::is_field_value(val)) { headers.emplace(key, val); } } inline void Response::set_redirect(const std::string &url, int stat) { - if (!detail::has_crlf(url)) { + if (detail::fields::is_field_value(url)) { set_header("Location", url); if (300 <= stat && stat < 400) { this->status = stat; @@ -5436,14 +5911,25 @@ inline void Response::set_chunked_content_provider( is_chunked_content_provider_ = true; } +inline void Response::set_file_content(const std::string &path, + const std::string &content_type) { + file_content_path_ = path; + file_content_content_type_ = content_type; +} + +inline void Response::set_file_content(const std::string &path) { + file_content_path_ = path; +} + // Result implementation inline bool Result::has_request_header(const std::string &key) const { return request_headers_.find(key) != request_headers_.end(); } inline std::string Result::get_request_header_value(const std::string &key, + const char *def, size_t id) const { - return detail::get_header_value(request_headers_, key, id, ""); + return detail::get_header_value(request_headers_, key, def, id); } inline size_t @@ -5584,6 +6070,8 @@ inline socket_t BufferStream::socket() const { return 0; } inline const std::string &BufferStream::get_buffer() const { return buffer; } inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) { + static constexpr char marker[] = "/:"; + // One past the last ending position of a path param substring std::size_t last_param_end = 0; @@ -5596,13 +6084,14 @@ inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) { #endif while (true) { - const auto marker_pos = pattern.find(marker, last_param_end); + const auto marker_pos = pattern.find( + marker, last_param_end == 0 ? last_param_end : last_param_end - 1); if (marker_pos == std::string::npos) { break; } static_fragments_.push_back( - pattern.substr(last_param_end, marker_pos - last_param_end)); + pattern.substr(last_param_end, marker_pos - last_param_end + 1)); - const auto param_name_start = marker_pos + 1; + const auto param_name_start = marker_pos + 2; auto sep_pos = pattern.find(separator, param_name_start); if (sep_pos == std::string::npos) { sep_pos = pattern.length(); } @@ -5664,7 +6153,7 @@ inline bool PathParamsMatcher::match(Request &request) const { request.path_params.emplace( param_name, request.path.substr(starting_pos, sep_pos - starting_pos)); - // Mark everythin up to '/' as matched + // Mark everything up to '/' as matched starting_pos = sep_pos + 1; } // Returns false if the path is longer than the pattern @@ -5763,7 +6252,8 @@ inline bool Server::set_base_dir(const std::string &dir, inline bool Server::set_mount_point(const std::string &mount_point, const std::string &dir, Headers headers) { - if (detail::is_dir(dir)) { + detail::FileStat stat(dir); + if (stat.is_dir()) { std::string mnt = !mount_point.empty() ? mount_point : "/"; if (!mnt.empty() && mnt[0] == '/') { base_dirs_.push_back({mnt, dir, std::move(headers)}); @@ -5800,12 +6290,14 @@ inline Server &Server::set_file_request_handler(Handler handler) { return *this; } -inline Server &Server::set_error_handler(HandlerWithResponse handler) { +inline Server &Server::set_error_handler_core(HandlerWithResponse handler, + std::true_type) { error_handler_ = std::move(handler); return *this; } -inline Server &Server::set_error_handler(Handler handler) { +inline Server &Server::set_error_handler_core(Handler handler, + std::false_type) { error_handler_ = [handler](const Request &req, Response &res) { handler(req, res); return HandlerResponse::Handled; @@ -5849,6 +6341,11 @@ inline Server &Server::set_tcp_nodelay(bool on) { return *this; } +inline Server &Server::set_ipv6_v6only(bool on) { + ipv6_v6only_ = on; + return *this; +} + inline Server &Server::set_socket_options(SocketOptions socket_options) { socket_options_ = std::move(socket_options); return *this; @@ -5900,27 +6397,27 @@ inline Server &Server::set_payload_max_length(size_t length) { inline bool Server::bind_to_port(const std::string &host, int port, int socket_flags) { - return bind_internal(host, port, socket_flags) >= 0; + auto ret = bind_internal(host, port, socket_flags); + if (ret == -1) { is_decommisioned = true; } + return ret >= 0; } inline int Server::bind_to_any_port(const std::string &host, int socket_flags) { - return bind_internal(host, 0, socket_flags); + auto ret = bind_internal(host, 0, socket_flags); + if (ret == -1) { is_decommisioned = true; } + return ret; } -inline bool Server::listen_after_bind() { - auto se = detail::scope_exit([&]() { done_ = true; }); - return listen_internal(); -} +inline bool Server::listen_after_bind() { return listen_internal(); } inline bool Server::listen(const std::string &host, int port, int socket_flags) { - auto se = detail::scope_exit([&]() { done_ = true; }); return bind_to_port(host, port, socket_flags) && listen_internal(); } inline bool Server::is_running() const { return is_running_; } inline void Server::wait_until_ready() const { - while (!is_running() && !done_) { + while (!is_running_ && !is_decommisioned) { std::this_thread::sleep_for(std::chrono::milliseconds{1}); } } @@ -5932,8 +6429,11 @@ inline void Server::stop() { detail::shutdown_socket(sock); detail::close_socket(sock); } + is_decommisioned = false; } +inline void Server::decommission() { is_decommisioned = true; } + inline bool Server::parse_request_line(const char *s, Request &req) const { auto len = strlen(s); if (len < 2 || s[len - 2] != '\r' || s[len - 1] != '\n') { return false; } @@ -5972,26 +6472,13 @@ inline bool Server::parse_request_line(const char *s, Request &req) const { } } - size_t count = 0; - - detail::split(req.target.data(), req.target.data() + req.target.size(), '?', - 2, [&](const char *b, const char *e) { - switch (count) { - case 0: - req.path = detail::decode_url(std::string(b, e), false); - break; - case 1: { - if (e - b > 0) { - detail::parse_query_text(std::string(b, e), req.params); - } - break; - } - default: break; - } - count++; - }); - - if (count > 2) { return false; } + detail::divide(req.target, '?', + [&](const char *lhs_data, std::size_t lhs_size, + const char *rhs_data, std::size_t rhs_size) { + req.path = detail::decode_url( + std::string(lhs_data, lhs_size), false); + detail::parse_query_text(rhs_data, rhs_size, req.params); + }); } return true; @@ -6030,23 +6517,24 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection, if (close_connection || req.get_header_value("Connection") == "close") { res.set_header("Connection", "close"); } else { - std::stringstream ss; - ss << "timeout=" << keep_alive_timeout_sec_ - << ", max=" << keep_alive_max_count_; - res.set_header("Keep-Alive", ss.str()); + std::string s = "timeout="; + s += std::to_string(keep_alive_timeout_sec_); + s += ", max="; + s += std::to_string(keep_alive_max_count_); + res.set_header("Keep-Alive", s); } - if (!res.has_header("Content-Type") && - (!res.body.empty() || res.content_length_ > 0 || res.content_provider_)) { + if ((!res.body.empty() || res.content_length_ > 0 || res.content_provider_) && + !res.has_header("Content-Type")) { res.set_header("Content-Type", "text/plain"); } - if (!res.has_header("Content-Length") && res.body.empty() && - !res.content_length_ && !res.content_provider_) { + if (res.body.empty() && !res.content_length_ && !res.content_provider_ && + !res.has_header("Content-Length")) { res.set_header("Content-Length", "0"); } - if (!res.has_header("Accept-Ranges") && req.method == "HEAD") { + if (req.method == "HEAD" && !res.has_header("Accept-Ranges")) { res.set_header("Accept-Ranges", "bytes"); } @@ -6055,12 +6543,7 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection, // Response line and headers { detail::BufferStream bstrm; - - if (!bstrm.write_format("HTTP/1.1 %d %s\r\n", res.status, - status_message(res.status))) { - return false; - } - + if (!detail::write_response_line(bstrm, res.status)) { return false; } if (!header_writer_(bstrm, res.headers)) { return false; } // Flush buffer @@ -6254,7 +6737,14 @@ inline bool Server::handle_file_request(const Request &req, Response &res, auto path = entry.base_dir + sub_path; if (path.back() == '/') { path += "index.html"; } - if (detail::is_file(path)) { + detail::FileStat stat(path); + + if (stat.is_dir()) { + res.set_redirect(sub_path + "/", StatusCode::MovedPermanently_301); + return true; + } + + if (stat.is_file()) { for (const auto &kv : entry.headers) { res.set_header(kv.first, kv.second); } @@ -6289,8 +6779,8 @@ Server::create_server_socket(const std::string &host, int port, SocketOptions socket_options) const { return detail::create_socket( host, std::string(), port, address_family_, socket_flags, tcp_nodelay_, - std::move(socket_options), - [](socket_t sock, struct addrinfo &ai) -> bool { + ipv6_v6only_, std::move(socket_options), + [](socket_t sock, struct addrinfo &ai, bool & /*quit*/) -> bool { if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { return false; } @@ -6301,6 +6791,8 @@ Server::create_server_socket(const std::string &host, int port, inline int Server::bind_internal(const std::string &host, int port, int socket_flags) { + if (is_decommisioned) { return -1; } + if (!is_valid()) { return -1; } svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_); @@ -6326,6 +6818,8 @@ inline int Server::bind_internal(const std::string &host, int port, } inline bool Server::listen_internal() { + if (is_decommisioned) { return false; } + auto ret = true; is_running_ = true; auto se = detail::scope_exit([&]() { is_running_ = false; }); @@ -6346,13 +6840,22 @@ inline bool Server::listen_internal() { #ifndef _WIN32 } #endif + +#if defined _WIN32 + // sockets conneced via WASAccept inherit flags NO_HANDLE_INHERIT, + // OVERLAPPED + socket_t sock = WSAAccept(svr_sock_, nullptr, nullptr, nullptr, 0); +#elif defined SOCK_CLOEXEC + socket_t sock = accept4(svr_sock_, nullptr, nullptr, SOCK_CLOEXEC); +#else socket_t sock = accept(svr_sock_, nullptr, nullptr); +#endif if (sock == INVALID_SOCKET) { if (errno == EMFILE) { // The per-process limit of open file descriptors has been reached. // Try to accept new connections after a short sleep. - std::this_thread::sleep_for(std::chrono::milliseconds(1)); + std::this_thread::sleep_for(std::chrono::microseconds{1}); continue; } else if (errno == EINTR || errno == EAGAIN) { continue; @@ -6406,6 +6909,7 @@ inline bool Server::listen_internal() { task_queue->shutdown(); } + is_decommisioned = !ret; return ret; } @@ -6503,7 +7007,7 @@ inline bool Server::dispatch_request(Request &req, Response &res, inline void Server::apply_ranges(const Request &req, Response &res, std::string &content_type, std::string &boundary) const { - if (req.ranges.size() > 1) { + if (req.ranges.size() > 1 && res.status == StatusCode::PartialContent_206) { auto it = res.headers.find("Content-Type"); if (it != res.headers.end()) { content_type = it->second; @@ -6521,7 +7025,7 @@ inline void Server::apply_ranges(const Request &req, Response &res, if (res.body.empty()) { if (res.content_length_ > 0) { size_t length = 0; - if (req.ranges.empty()) { + if (req.ranges.empty() || res.status != StatusCode::PartialContent_206) { length = res.content_length_; } else if (req.ranges.size() == 1) { auto offset_and_length = detail::get_range_offset_and_length( @@ -6550,7 +7054,7 @@ inline void Server::apply_ranges(const Request &req, Response &res, } } } else { - if (req.ranges.empty()) { + if (req.ranges.empty() || res.status != StatusCode::PartialContent_206) { ; } else if (req.ranges.size() == 1) { auto offset_and_length = @@ -6621,7 +7125,9 @@ inline bool Server::dispatch_request_for_content_reader( } inline bool -Server::process_request(Stream &strm, bool close_connection, +Server::process_request(Stream &strm, const std::string &remote_addr, + int remote_port, const std::string &local_addr, + int local_port, bool close_connection, bool &connection_closed, const std::function &setup_request) { std::array buf{}; @@ -6675,11 +7181,13 @@ Server::process_request(Stream &strm, bool close_connection, connection_closed = true; } - strm.get_remote_ip_and_port(req.remote_addr, req.remote_port); + req.remote_addr = remote_addr; + req.remote_port = remote_port; req.set_header("REMOTE_ADDR", req.remote_addr); req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); - strm.get_local_ip_and_port(req.local_addr, req.local_port); + req.local_addr = local_addr; + req.local_port = local_port; req.set_header("LOCAL_ADDR", req.local_addr); req.set_header("LOCAL_PORT", std::to_string(req.local_port)); @@ -6701,13 +7209,20 @@ Server::process_request(Stream &strm, bool close_connection, switch (status) { case StatusCode::Continue_100: case StatusCode::ExpectationFailed_417: - strm.write_format("HTTP/1.1 %d %s\r\n\r\n", status, - status_message(status)); + detail::write_response_line(strm, status); + strm.write("\r\n"); break; - default: return write_response(strm, close_connection, req, res); + default: + connection_closed = true; + return write_response(strm, true, req, res); } } + // Setup `is_connection_closed` method + req.is_connection_closed = [&]() { + return !detail::is_socket_alive(strm.socket()); + }; + // Routing auto routed = false; #ifdef CPPHTTPLIB_NO_EXCEPTIONS @@ -6750,6 +7265,32 @@ Server::process_request(Stream &strm, bool close_connection, : StatusCode::PartialContent_206; } + // Serve file content by using a content provider + if (!res.file_content_path_.empty()) { + const auto &path = res.file_content_path_; + auto mm = std::make_shared(path.c_str()); + if (!mm->is_open()) { + res.body.clear(); + res.content_length_ = 0; + res.content_provider_ = nullptr; + res.status = StatusCode::NotFound_404; + return write_response(strm, close_connection, req, res); + } + + auto content_type = res.file_content_content_type_; + if (content_type.empty()) { + content_type = detail::find_content_type( + path, file_extension_and_mimetype_map_, default_file_mimetype_); + } + + res.set_content_provider( + mm->size(), content_type, + [mm](size_t offset, size_t length, DataSink &sink) -> bool { + sink.write(mm->data() + offset, length); + return true; + }); + } + if (detail::range_error(req, res)) { res.body.clear(); res.content_length_ = 0; @@ -6769,12 +7310,21 @@ Server::process_request(Stream &strm, bool close_connection, inline bool Server::is_valid() const { return true; } inline bool Server::process_and_close_socket(socket_t sock) { + std::string remote_addr; + int remote_port = 0; + detail::get_remote_ip_and_port(sock, remote_addr, remote_port); + + std::string local_addr; + int local_port = 0; + detail::get_local_ip_and_port(sock, local_addr, local_port); + auto ret = detail::process_server_socket( svr_sock_, sock, keep_alive_max_count_, keep_alive_timeout_sec_, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, - [this](Stream &strm, bool close_connection, bool &connection_closed) { - return process_request(strm, close_connection, connection_closed, + [&](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, remote_addr, remote_port, local_addr, + local_port, close_connection, connection_closed, nullptr); }); @@ -6793,8 +7343,8 @@ inline ClientImpl::ClientImpl(const std::string &host, int port) inline ClientImpl::ClientImpl(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path) - : host_(host), port_(port), - host_and_port_(adjust_host_string(host) + ":" + std::to_string(port)), + : host_(detail::escape_abstract_namespace_unix_domain(host)), port_(port), + host_and_port_(adjust_host_string(host_) + ":" + std::to_string(port)), client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} inline ClientImpl::~ClientImpl() { @@ -6825,6 +7375,7 @@ inline void ClientImpl::copy_settings(const ClientImpl &rhs) { url_encode_ = rhs.url_encode_; address_family_ = rhs.address_family_; tcp_nodelay_ = rhs.tcp_nodelay_; + ipv6_v6only_ = rhs.ipv6_v6only_; socket_options_ = rhs.socket_options_; compress_ = rhs.compress_; decompress_ = rhs.decompress_; @@ -6845,6 +7396,8 @@ inline void ClientImpl::copy_settings(const ClientImpl &rhs) { #endif #ifdef CPPHTTPLIB_OPENSSL_SUPPORT server_certificate_verification_ = rhs.server_certificate_verification_; + server_hostname_verification_ = rhs.server_hostname_verification_; + server_certificate_verifier_ = rhs.server_certificate_verifier_; #endif logger_ = rhs.logger_; } @@ -6853,9 +7406,9 @@ inline socket_t ClientImpl::create_client_socket(Error &error) const { if (!proxy_host_.empty() && proxy_port_ != -1) { return detail::create_client_socket( proxy_host_, std::string(), proxy_port_, address_family_, tcp_nodelay_, - socket_options_, connection_timeout_sec_, connection_timeout_usec_, - read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, - write_timeout_usec_, interface_, error); + ipv6_v6only_, socket_options_, connection_timeout_sec_, + connection_timeout_usec_, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, interface_, error); } // Check is custom IP specified for host_ @@ -6864,10 +7417,10 @@ inline socket_t ClientImpl::create_client_socket(Error &error) const { if (it != addr_map_.end()) { ip = it->second; } return detail::create_client_socket( - host_, ip, port_, address_family_, tcp_nodelay_, socket_options_, - connection_timeout_sec_, connection_timeout_usec_, read_timeout_sec_, - read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, interface_, - error); + host_, ip, port_, address_family_, tcp_nodelay_, ipv6_v6only_, + socket_options_, connection_timeout_sec_, connection_timeout_usec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, interface_, error); } inline bool ClientImpl::create_and_connect_socket(Socket &socket, @@ -6956,6 +7509,18 @@ inline bool ClientImpl::send(Request &req, Response &res, Error &error) { return ret; } +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline bool ClientImpl::is_ssl_peer_could_be_closed(SSL *ssl) const { + detail::set_nonblocking(socket_.sock, true); + auto se = detail::scope_exit( + [&]() { detail::set_nonblocking(socket_.sock, false); }); + + char buf[1]; + return !SSL_peek(ssl, buf, 1) && + SSL_get_error(ssl, 0) == SSL_ERROR_ZERO_RETURN; +} +#endif + inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { { std::lock_guard guard(socket_mutex_); @@ -6967,6 +7532,13 @@ inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { auto is_alive = false; if (socket_.is_open()) { is_alive = detail::is_socket_alive(socket_.sock); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_alive && is_ssl()) { + if (is_ssl_peer_could_be_closed(socket_.ssl)) { is_alive = false; } + } +#endif + if (!is_alive) { // Attempt to avoid sigpipe by shutting down nongracefully if it seems // like the other side has already closed the connection Also, there @@ -7144,7 +7716,7 @@ inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) { if (location.empty()) { return false; } const static std::regex re( - R"((?:(https?):)?(?://(?:\[([\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); + R"((?:(https?):)?(?://(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); std::smatch m; if (!std::regex_match(location, m, re)) { return false; } @@ -7243,12 +7815,26 @@ inline bool ClientImpl::write_request(Stream &strm, Request &req, if (!req.has_header("Accept")) { req.set_header("Accept", "*/*"); } -#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT - if (!req.has_header("User-Agent")) { - auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; - req.set_header("User-Agent", agent); - } + if (!req.content_receiver) { + if (!req.has_header("Accept-Encoding")) { + std::string accept_encoding; +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + accept_encoding = "br"; #endif +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (!accept_encoding.empty()) { accept_encoding += ", "; } + accept_encoding += "gzip, deflate"; +#endif + req.set_header("Accept-Encoding", accept_encoding); + } + +#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT + if (!req.has_header("User-Agent")) { + auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; + req.set_header("User-Agent", agent); + } +#endif + }; if (req.body.empty()) { if (req.content_provider_) { @@ -7308,8 +7894,14 @@ inline bool ClientImpl::write_request(Stream &strm, Request &req, { detail::BufferStream bstrm; - const auto &path = url_encode_ ? detail::encode_url(req.path) : req.path; - bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); + const auto &path_with_query = + req.params.empty() ? req.path + : append_query_params(req.path, req.params); + + const auto &path = + url_encode_ ? detail::encode_url(path_with_query) : path_with_query; + + detail::write_request_line(bstrm, req.method, path); header_writer_(bstrm, req.headers); @@ -7417,11 +8009,12 @@ inline Result ClientImpl::send_with_content_provider( const std::string &method, const std::string &path, const Headers &headers, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, - const std::string &content_type) { + const std::string &content_type, Progress progress) { Request req; req.method = method; req.headers = headers; req.path = path; + req.progress = progress; auto error = Error::Success; @@ -7448,9 +8041,7 @@ inline bool ClientImpl::process_request(Stream &strm, Request &req, if (is_ssl()) { auto is_proxy_enabled = !proxy_host_.empty() && proxy_port_ != -1; if (!is_proxy_enabled) { - char buf[1]; - if (SSL_peek(socket_.ssl, buf, 1) == 0 && - SSL_get_error(socket_.ssl, 0) == SSL_ERROR_ZERO_RETURN) { + if (is_ssl_peer_could_be_closed(socket_.ssl)) { error = Error::SSLPeerCouldBeClosed_; return false; } @@ -7468,7 +8059,9 @@ inline bool ClientImpl::process_request(Stream &strm, Request &req, // Body if ((res.status != StatusCode::NoContent_204) && req.method != "HEAD" && req.method != "CONNECT") { - auto redirect = 300 < res.status && res.status < 400 && follow_location_; + auto redirect = 300 < res.status && res.status < 400 && + res.status != StatusCode::NotModified_304 && + follow_location_; if (req.response_handler && !redirect) { if (!req.response_handler(res)) { @@ -7489,9 +8082,7 @@ inline bool ClientImpl::process_request(Stream &strm, Request &req, : static_cast( [&](const char *buf, size_t n, uint64_t /*off*/, uint64_t /*len*/) { - if (res.body.size() + n > res.body.max_size()) { - return false; - } + assert(res.body.size() + n <= res.body.max_size()); res.body.append(buf, n); return true; }); @@ -7503,12 +8094,25 @@ inline bool ClientImpl::process_request(Stream &strm, Request &req, return ret; }; - int dummy_status; - if (!detail::read_content(strm, res, (std::numeric_limits::max)(), - dummy_status, std::move(progress), std::move(out), - decompress_)) { - if (error != Error::Canceled) { error = Error::Read; } - return false; + if (res.has_header("Content-Length")) { + if (!req.content_receiver) { + auto len = res.get_header_value_u64("Content-Length"); + if (len > res.body.max_size()) { + error = Error::Read; + return false; + } + res.body.reserve(static_cast(len)); + } + } + + if (res.status != StatusCode::NotModified_304) { + int dummy_status; + if (!detail::read_content(strm, res, (std::numeric_limits::max)(), + dummy_status, std::move(progress), + std::move(out), decompress_)) { + if (error != Error::Canceled) { error = Error::Read; } + return false; + } } } @@ -7717,14 +8321,22 @@ inline Result ClientImpl::Post(const std::string &path, inline Result ClientImpl::Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type) { - return Post(path, Headers(), body, content_length, content_type); + return Post(path, Headers(), body, content_length, content_type, nullptr); } inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type) { return send_with_content_provider("POST", path, headers, body, content_length, - nullptr, nullptr, content_type); + nullptr, nullptr, content_type, nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + Progress progress) { + return send_with_content_provider("POST", path, headers, body, content_length, + nullptr, nullptr, content_type, progress); } inline Result ClientImpl::Post(const std::string &path, const std::string &body, @@ -7732,12 +8344,27 @@ inline Result ClientImpl::Post(const std::string &path, const std::string &body, return Post(path, Headers(), body, content_type); } +inline Result ClientImpl::Post(const std::string &path, const std::string &body, + const std::string &content_type, + Progress progress) { + return Post(path, Headers(), body, content_type, progress); +} + inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type) { return send_with_content_provider("POST", path, headers, body.data(), - body.size(), nullptr, nullptr, - content_type); + body.size(), nullptr, nullptr, content_type, + nullptr); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return send_with_content_provider("POST", path, headers, body.data(), + body.size(), nullptr, nullptr, content_type, + progress); } inline Result ClientImpl::Post(const std::string &path, const Params ¶ms) { @@ -7763,14 +8390,15 @@ inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const std::string &content_type) { return send_with_content_provider("POST", path, headers, nullptr, content_length, std::move(content_provider), - nullptr, content_type); + nullptr, content_type, nullptr); } inline Result ClientImpl::Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type) { return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr, - std::move(content_provider), content_type); + std::move(content_provider), content_type, + nullptr); } inline Result ClientImpl::Post(const std::string &path, const Headers &headers, @@ -7779,6 +8407,13 @@ inline Result ClientImpl::Post(const std::string &path, const Headers &headers, return Post(path, headers, query, "application/x-www-form-urlencoded"); } +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress) { + auto query = detail::params_to_query_str(params); + return Post(path, headers, query, "application/x-www-form-urlencoded", + progress); +} + inline Result ClientImpl::Post(const std::string &path, const MultipartFormDataItems &items) { return Post(path, Headers(), items); @@ -7816,7 +8451,7 @@ ClientImpl::Post(const std::string &path, const Headers &headers, return send_with_content_provider( "POST", path, headers, nullptr, 0, nullptr, get_multipart_content_provider(boundary, items, provider_items), - content_type); + content_type, nullptr); } inline Result ClientImpl::Put(const std::string &path) { @@ -7833,7 +8468,15 @@ inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type) { return send_with_content_provider("PUT", path, headers, body, content_length, - nullptr, nullptr, content_type); + nullptr, nullptr, content_type, nullptr); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + Progress progress) { + return send_with_content_provider("PUT", path, headers, body, content_length, + nullptr, nullptr, content_type, progress); } inline Result ClientImpl::Put(const std::string &path, const std::string &body, @@ -7841,12 +8484,27 @@ inline Result ClientImpl::Put(const std::string &path, const std::string &body, return Put(path, Headers(), body, content_type); } +inline Result ClientImpl::Put(const std::string &path, const std::string &body, + const std::string &content_type, + Progress progress) { + return Put(path, Headers(), body, content_type, progress); +} + inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type) { return send_with_content_provider("PUT", path, headers, body.data(), - body.size(), nullptr, nullptr, - content_type); + body.size(), nullptr, nullptr, content_type, + nullptr); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return send_with_content_provider("PUT", path, headers, body.data(), + body.size(), nullptr, nullptr, content_type, + progress); } inline Result ClientImpl::Put(const std::string &path, size_t content_length, @@ -7868,14 +8526,15 @@ inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const std::string &content_type) { return send_with_content_provider("PUT", path, headers, nullptr, content_length, std::move(content_provider), - nullptr, content_type); + nullptr, content_type, nullptr); } inline Result ClientImpl::Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type) { return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr, - std::move(content_provider), content_type); + std::move(content_provider), content_type, + nullptr); } inline Result ClientImpl::Put(const std::string &path, const Params ¶ms) { @@ -7888,6 +8547,13 @@ inline Result ClientImpl::Put(const std::string &path, const Headers &headers, return Put(path, headers, query, "application/x-www-form-urlencoded"); } +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress) { + auto query = detail::params_to_query_str(params); + return Put(path, headers, query, "application/x-www-form-urlencoded", + progress); +} + inline Result ClientImpl::Put(const std::string &path, const MultipartFormDataItems &items) { return Put(path, Headers(), items); @@ -7925,7 +8591,7 @@ ClientImpl::Put(const std::string &path, const Headers &headers, return send_with_content_provider( "PUT", path, headers, nullptr, 0, nullptr, get_multipart_content_provider(boundary, items, provider_items), - content_type); + content_type, nullptr); } inline Result ClientImpl::Patch(const std::string &path) { return Patch(path, std::string(), std::string()); @@ -7937,12 +8603,26 @@ inline Result ClientImpl::Patch(const std::string &path, const char *body, return Patch(path, Headers(), body, content_length, content_type); } +inline Result ClientImpl::Patch(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + Progress progress) { + return Patch(path, Headers(), body, content_length, content_type, progress); +} + inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type) { + return Patch(path, headers, body, content_length, content_type, nullptr); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + Progress progress) { return send_with_content_provider("PATCH", path, headers, body, content_length, nullptr, nullptr, - content_type); + content_type, progress); } inline Result ClientImpl::Patch(const std::string &path, @@ -7951,12 +8631,26 @@ inline Result ClientImpl::Patch(const std::string &path, return Patch(path, Headers(), body, content_type); } +inline Result ClientImpl::Patch(const std::string &path, + const std::string &body, + const std::string &content_type, + Progress progress) { + return Patch(path, Headers(), body, content_type, progress); +} + inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type) { + return Patch(path, headers, body, content_type, nullptr); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { return send_with_content_provider("PATCH", path, headers, body.data(), - body.size(), nullptr, nullptr, - content_type); + body.size(), nullptr, nullptr, content_type, + progress); } inline Result ClientImpl::Patch(const std::string &path, size_t content_length, @@ -7978,14 +8672,15 @@ inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, const std::string &content_type) { return send_with_content_provider("PATCH", path, headers, nullptr, content_length, std::move(content_provider), - nullptr, content_type); + nullptr, content_type, nullptr); } inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type) { return send_with_content_provider("PATCH", path, headers, nullptr, 0, nullptr, - std::move(content_provider), content_type); + std::move(content_provider), content_type, + nullptr); } inline Result ClientImpl::Delete(const std::string &path) { @@ -8003,14 +8698,30 @@ inline Result ClientImpl::Delete(const std::string &path, const char *body, return Delete(path, Headers(), body, content_length, content_type); } +inline Result ClientImpl::Delete(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + Progress progress) { + return Delete(path, Headers(), body, content_length, content_type, progress); +} + inline Result ClientImpl::Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type) { + return Delete(path, headers, body, content_length, content_type, nullptr); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, const char *body, + size_t content_length, + const std::string &content_type, + Progress progress) { Request req; req.method = "DELETE"; req.headers = headers; req.path = path; + req.progress = progress; if (!content_type.empty()) { req.set_header("Content-Type", content_type); } req.body.assign(body, content_length); @@ -8024,6 +8735,14 @@ inline Result ClientImpl::Delete(const std::string &path, return Delete(path, Headers(), body.data(), body.size(), content_type); } +inline Result ClientImpl::Delete(const std::string &path, + const std::string &body, + const std::string &content_type, + Progress progress) { + return Delete(path, Headers(), body.data(), body.size(), content_type, + progress); +} + inline Result ClientImpl::Delete(const std::string &path, const Headers &headers, const std::string &body, @@ -8031,6 +8750,15 @@ inline Result ClientImpl::Delete(const std::string &path, return Delete(path, headers, body.data(), body.size(), content_type); } +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return Delete(path, headers, body.data(), body.size(), content_type, + progress); +} + inline Result ClientImpl::Options(const std::string &path) { return Options(path, Headers()); } @@ -8138,6 +8866,8 @@ inline void ClientImpl::set_address_family(int family) { inline void ClientImpl::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } +inline void ClientImpl::set_ipv6_v6only(bool on) { ipv6_v6only_ = on; } + inline void ClientImpl::set_socket_options(SocketOptions socket_options) { socket_options_ = std::move(socket_options); } @@ -8187,13 +8917,11 @@ inline void ClientImpl::set_ca_cert_store(X509_STORE *ca_cert_store) { inline X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert, std::size_t size) const { auto mem = BIO_new_mem_buf(ca_cert, static_cast(size)); + auto se = detail::scope_exit([&] { BIO_free_all(mem); }); if (!mem) { return nullptr; } auto inf = PEM_X509_INFO_read_bio(mem, nullptr, nullptr, nullptr); - if (!inf) { - BIO_free_all(mem); - return nullptr; - } + if (!inf) { return nullptr; } auto cts = X509_STORE_new(); if (cts) { @@ -8207,13 +8935,21 @@ inline X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert, } sk_X509_INFO_pop_free(inf, X509_INFO_free); - BIO_free_all(mem); return cts; } inline void ClientImpl::enable_server_certificate_verification(bool enabled) { server_certificate_verification_ = enabled; } + +inline void ClientImpl::enable_server_hostname_verification(bool enabled) { + server_hostname_verification_ = enabled; +} + +inline void ClientImpl::set_server_certificate_verifier( + std::function verifier) { + server_certificate_verifier_ = verifier; +} #endif inline void ClientImpl::set_logger(Logger logger) { @@ -8257,13 +8993,30 @@ inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex, return ssl; } -inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, +inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, socket_t sock, bool shutdown_gracefully) { // sometimes we may want to skip this to try to avoid SIGPIPE if we know // the remote has closed the network connection // Note that it is not always possible to avoid SIGPIPE, this is merely a // best-efforts. - if (shutdown_gracefully) { SSL_shutdown(ssl); } + if (shutdown_gracefully) { +#ifdef _WIN32 + (void)(sock); + SSL_shutdown(ssl); +#else + timeval tv; + tv.tv_sec = 1; + tv.tv_usec = 0; + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, + reinterpret_cast(&tv), sizeof(tv)); + + auto ret = SSL_shutdown(ssl); + while (ret == 0) { + std::this_thread::sleep_for(std::chrono::milliseconds{100}); + ret = SSL_shutdown(ssl); + } +#endif + } std::lock_guard guard(ctx_mutex); SSL_free(ssl); @@ -8366,7 +9119,7 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { if (SSL_pending(ssl_) > 0) { return SSL_read(ssl_, ptr, static_cast(size)); } else if (is_readable()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); + std::this_thread::sleep_for(std::chrono::microseconds{10}); ret = SSL_read(ssl_, ptr, static_cast(size)); if (ret >= 0) { return ret; } err = SSL_get_error(ssl_, ret); @@ -8397,7 +9150,7 @@ inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { #endif if (is_writable()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); + std::this_thread::sleep_for(std::chrono::microseconds{10}); ret = SSL_write(ssl_, ptr, static_cast(handle_size)); if (ret >= 0) { return ret; } err = SSL_get_error(ssl_, ret); @@ -8439,7 +9192,7 @@ inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, SSL_OP_NO_COMPRESSION | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - SSL_CTX_set_min_proto_version(ctx_, TLS1_1_VERSION); + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); if (private_key_password != nullptr && (private_key_password[0] != '\0')) { SSL_CTX_set_default_passwd_cb_userdata( @@ -8449,7 +9202,8 @@ inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != - 1) { + 1 || + SSL_CTX_check_private_key(ctx_) != 1) { SSL_CTX_free(ctx_); ctx_ = nullptr; } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { @@ -8471,7 +9225,7 @@ inline SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, SSL_OP_NO_COMPRESSION | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - SSL_CTX_set_min_proto_version(ctx_, TLS1_1_VERSION); + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); if (SSL_CTX_use_certificate(ctx_, cert) != 1 || SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { @@ -8505,6 +9259,19 @@ inline bool SSLServer::is_valid() const { return ctx_; } inline SSL_CTX *SSLServer::ssl_context() const { return ctx_; } +inline void SSLServer::update_certs(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store) { + + std::lock_guard guard(ctx_mutex_); + + SSL_CTX_use_certificate(ctx_, cert); + SSL_CTX_use_PrivateKey(ctx_, private_key); + + if (client_ca_cert_store != nullptr) { + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + } +} + inline bool SSLServer::process_and_close_socket(socket_t sock) { auto ssl = detail::ssl_new( sock, ctx_, ctx_mutex_, @@ -8516,20 +9283,29 @@ inline bool SSLServer::process_and_close_socket(socket_t sock) { auto ret = false; if (ssl) { + std::string remote_addr; + int remote_port = 0; + detail::get_remote_ip_and_port(sock, remote_addr, remote_port); + + std::string local_addr; + int local_port = 0; + detail::get_local_ip_and_port(sock, local_addr, local_port); + ret = detail::process_server_socket_ssl( svr_sock_, ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, - [this, ssl](Stream &strm, bool close_connection, - bool &connection_closed) { - return process_request(strm, close_connection, connection_closed, + [&](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, remote_addr, remote_port, local_addr, + local_port, close_connection, + connection_closed, [&](Request &req) { req.ssl = ssl; }); }); // Shutdown gracefully if the result seemed successful, non-gracefully if // the connection appeared to be closed. const bool shutdown_gracefully = ret; - detail::ssl_delete(ctx_mutex_, ssl, shutdown_gracefully); + detail::ssl_delete(ctx_mutex_, ssl, sock, shutdown_gracefully); } detail::shutdown_socket(sock); @@ -8551,6 +9327,8 @@ inline SSLClient::SSLClient(const std::string &host, int port, : ClientImpl(host, port, client_cert_path, client_key_path) { ctx_ = SSL_CTX_new(TLS_client_method()); + SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); + detail::split(&host_[0], &host_[host_.size()], '.', [&](const char *b, const char *e) { host_components_.emplace_back(b, e); @@ -8758,36 +9536,47 @@ inline bool SSLClient::initialize_ssl(Socket &socket, Error &error) { } if (server_certificate_verification_) { - verify_result_ = SSL_get_verify_result(ssl2); + if (server_certificate_verifier_) { + if (!server_certificate_verifier_(ssl2)) { + error = Error::SSLServerVerification; + return false; + } + } else { + verify_result_ = SSL_get_verify_result(ssl2); - if (verify_result_ != X509_V_OK) { - error = Error::SSLServerVerification; - return false; + if (verify_result_ != X509_V_OK) { + error = Error::SSLServerVerification; + return false; + } + + auto server_cert = SSL_get1_peer_certificate(ssl2); + auto se = detail::scope_exit([&] { X509_free(server_cert); }); + + if (server_cert == nullptr) { + error = Error::SSLServerVerification; + return false; + } + + if (server_hostname_verification_) { + if (!verify_host(server_cert)) { + error = Error::SSLServerHostnameVerification; + return false; + } + } } - - auto server_cert = SSL_get1_peer_certificate(ssl2); - - if (server_cert == nullptr) { - error = Error::SSLServerVerification; - return false; - } - - if (!verify_host(server_cert)) { - X509_free(server_cert); - error = Error::SSLServerVerification; - return false; - } - X509_free(server_cert); } return true; }, [&](SSL *ssl2) { +#if defined(OPENSSL_IS_BORINGSSL) + SSL_set_tlsext_host_name(ssl2, host_.c_str()); +#else // NOTE: Direct call instead of using the OpenSSL macro to suppress // -Wold-style-cast warning - // SSL_set_tlsext_host_name(ssl2, host_.c_str()); SSL_ctrl(ssl2, SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_host_name, static_cast(const_cast(host_.c_str()))); +#endif return true; }); @@ -8812,7 +9601,8 @@ inline void SSLClient::shutdown_ssl_impl(Socket &socket, return; } if (socket.ssl) { - detail::ssl_delete(ctx_mutex_, socket.ssl, shutdown_gracefully); + detail::ssl_delete(ctx_mutex_, socket.ssl, socket.sock, + shutdown_gracefully); socket.ssl = nullptr; } assert(socket.ssl == nullptr); @@ -8861,8 +9651,8 @@ SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { auto type = GEN_DNS; - struct in6_addr addr6 {}; - struct in_addr addr {}; + struct in6_addr addr6{}; + struct in_addr addr{}; size_t addr_len = 0; #ifndef __MINGW32__ @@ -8965,7 +9755,7 @@ inline Client::Client(const std::string &scheme_host_port, const std::string &client_cert_path, const std::string &client_key_path) { const static std::regex re( - R"((?:([a-z]+):\/\/)?(?:\[([\d:]+)\]|([^:/?#]+))(?::(\d+))?)"); + R"((?:([a-z]+):\/\/)?(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)"); std::smatch m; if (std::regex_match(scheme_host_port, m, re)) { @@ -9002,10 +9792,12 @@ inline Client::Client(const std::string &scheme_host_port, client_key_path); } } else { + // NOTE: Update TEST(UniversalClientImplTest, Ipv6LiteralAddress) + // if port param below changes. cli_ = detail::make_unique(scheme_host_port, 80, client_cert_path, client_key_path); } -} +} // namespace detail inline Client::Client(const std::string &host, int port) : cli_(detail::make_unique(host, port)) {} @@ -9111,15 +9903,30 @@ inline Result Client::Post(const std::string &path, const Headers &headers, const std::string &content_type) { return cli_->Post(path, headers, body, content_length, content_type); } +inline Result Client::Post(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return cli_->Post(path, headers, body, content_length, content_type, + progress); +} inline Result Client::Post(const std::string &path, const std::string &body, const std::string &content_type) { return cli_->Post(path, body, content_type); } +inline Result Client::Post(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Post(path, body, content_type, progress); +} inline Result Client::Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type) { return cli_->Post(path, headers, body, content_type); } +inline Result Client::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Post(path, headers, body, content_type, progress); +} inline Result Client::Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type) { @@ -9150,6 +9957,10 @@ inline Result Client::Post(const std::string &path, const Headers &headers, const Params ¶ms) { return cli_->Post(path, headers, params); } +inline Result Client::Post(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress) { + return cli_->Post(path, headers, params, progress); +} inline Result Client::Post(const std::string &path, const MultipartFormDataItems &items) { return cli_->Post(path, items); @@ -9180,15 +9991,29 @@ inline Result Client::Put(const std::string &path, const Headers &headers, const std::string &content_type) { return cli_->Put(path, headers, body, content_length, content_type); } +inline Result Client::Put(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, Progress progress) { + return cli_->Put(path, headers, body, content_length, content_type, progress); +} inline Result Client::Put(const std::string &path, const std::string &body, const std::string &content_type) { return cli_->Put(path, body, content_type); } +inline Result Client::Put(const std::string &path, const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Put(path, body, content_type, progress); +} inline Result Client::Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type) { return cli_->Put(path, headers, body, content_type); } +inline Result Client::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, Progress progress) { + return cli_->Put(path, headers, body, content_type, progress); +} inline Result Client::Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type) { @@ -9219,6 +10044,10 @@ inline Result Client::Put(const std::string &path, const Headers &headers, const Params ¶ms) { return cli_->Put(path, headers, params); } +inline Result Client::Put(const std::string &path, const Headers &headers, + const Params ¶ms, Progress progress) { + return cli_->Put(path, headers, params, progress); +} inline Result Client::Put(const std::string &path, const MultipartFormDataItems &items) { return cli_->Put(path, items); @@ -9246,20 +10075,44 @@ inline Result Client::Patch(const std::string &path, const char *body, const std::string &content_type) { return cli_->Patch(path, body, content_length, content_type); } +inline Result Client::Patch(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + Progress progress) { + return cli_->Patch(path, body, content_length, content_type, progress); +} inline Result Client::Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type) { return cli_->Patch(path, headers, body, content_length, content_type); } +inline Result Client::Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + Progress progress) { + return cli_->Patch(path, headers, body, content_length, content_type, + progress); +} inline Result Client::Patch(const std::string &path, const std::string &body, const std::string &content_type) { return cli_->Patch(path, body, content_type); } +inline Result Client::Patch(const std::string &path, const std::string &body, + const std::string &content_type, + Progress progress) { + return cli_->Patch(path, body, content_type, progress); +} inline Result Client::Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type) { return cli_->Patch(path, headers, body, content_type); } +inline Result Client::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return cli_->Patch(path, headers, body, content_type, progress); +} inline Result Client::Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type) { @@ -9294,20 +10147,44 @@ inline Result Client::Delete(const std::string &path, const char *body, const std::string &content_type) { return cli_->Delete(path, body, content_length, content_type); } +inline Result Client::Delete(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type, + Progress progress) { + return cli_->Delete(path, body, content_length, content_type, progress); +} inline Result Client::Delete(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type) { return cli_->Delete(path, headers, body, content_length, content_type); } +inline Result Client::Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type, + Progress progress) { + return cli_->Delete(path, headers, body, content_length, content_type, + progress); +} inline Result Client::Delete(const std::string &path, const std::string &body, const std::string &content_type) { return cli_->Delete(path, body, content_type); } +inline Result Client::Delete(const std::string &path, const std::string &body, + const std::string &content_type, + Progress progress) { + return cli_->Delete(path, body, content_type, progress); +} inline Result Client::Delete(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type) { return cli_->Delete(path, headers, body, content_type); } +inline Result Client::Delete(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type, + Progress progress) { + return cli_->Delete(path, headers, body, content_type, progress); +} inline Result Client::Options(const std::string &path) { return cli_->Options(path); } @@ -9417,6 +10294,15 @@ inline void Client::set_proxy_digest_auth(const std::string &username, inline void Client::enable_server_certificate_verification(bool enabled) { cli_->enable_server_certificate_verification(enabled); } + +inline void Client::enable_server_hostname_verification(bool enabled) { + cli_->enable_server_hostname_verification(enabled); +} + +inline void Client::set_server_certificate_verifier( + std::function verifier) { + cli_->set_server_certificate_verifier(verifier); +} #endif inline void Client::set_logger(Logger logger) { diff --git a/examples/server/public/index.html.gz b/examples/server/public/index.html.gz new file mode 100644 index 000000000..582ccc0d3 Binary files /dev/null and b/examples/server/public/index.html.gz differ diff --git a/examples/server/public/index.js b/examples/server/public/index.js deleted file mode 100644 index fe615ca25..000000000 --- a/examples/server/public/index.js +++ /dev/null @@ -1 +0,0 @@ -const t=Symbol.for("preact-signals");function n(){if(r>1){r--;return}let t,n=!1;while(void 0!==i){let _=i;i=void 0;u++;while(void 0!==_){const i=_.o;_.o=void 0;_.f&=-3;if(!(8&_.f)&&h(_))try{_.c()}catch(e){if(!n){t=e;n=!0}}_=i}}u=0;r--;if(n)throw t}function e(t){if(r>0)return t();r++;try{return t()}finally{n()}}let _,i;function o(t){const n=_;_=void 0;try{return t()}finally{_=n}}let r=0,u=0,l=0;function f(t){if(void 0===_)return;let n=t.n;if(void 0===n||n.t!==_){n={i:0,S:t,p:_.s,n:void 0,t:_,e:void 0,x:void 0,r:n};if(void 0!==_.s)_.s.n=n;_.s=n;t.n=n;if(32&_.f)t.S(n);return n}else if(-1===n.i){n.i=0;if(void 0!==n.n){n.n.p=n.p;if(void 0!==n.p)n.p.n=n.n;n.p=_.s;n.n=void 0;_.s.n=n;_.s=n}return n}}function s(t){this.v=t;this.i=0;this.n=void 0;this.t=void 0}s.prototype.brand=t;s.prototype.h=function(){return!0};s.prototype.S=function(t){if(this.t!==t&&void 0===t.e){t.x=this.t;if(void 0!==this.t)this.t.e=t;this.t=t}};s.prototype.U=function(t){if(void 0!==this.t){const n=t.e,e=t.x;if(void 0!==n){n.x=e;t.e=void 0}if(void 0!==e){e.e=n;t.x=void 0}if(t===this.t)this.t=e}};s.prototype.subscribe=function(t){return k(()=>{const n=this.value,e=_;_=void 0;try{t(n)}finally{_=e}})};s.prototype.valueOf=function(){return this.value};s.prototype.toString=function(){return this.value+""};s.prototype.toJSON=function(){return this.value};s.prototype.peek=function(){const t=_;_=void 0;try{return this.value}finally{_=t}};Object.defineProperty(s.prototype,"value",{get(){const t=f(this);if(void 0!==t)t.i=this.i;return this.v},set(t){if(t!==this.v){if(u>100)throw new Error("Cycle detected");this.v=t;this.i++;l++;r++;try{for(let t=this.t;void 0!==t;t=t.x)t.t.N()}finally{n()}}}});function c(t){return new s(t)}function h(t){for(let n=t.s;void 0!==n;n=n.n)if(n.S.i!==n.i||!n.S.h()||n.S.i!==n.i)return!0;return!1}function a(t){for(let n=t.s;void 0!==n;n=n.n){const e=n.S.n;if(void 0!==e)n.r=e;n.S.n=n;n.i=-1;if(void 0===n.n){t.s=n;break}}}function p(t){let n,e=t.s;while(void 0!==e){const t=e.p;if(-1===e.i){e.S.U(e);if(void 0!==t)t.n=e.n;if(void 0!==e.n)e.n.p=t}else n=e;e.S.n=e.r;if(void 0!==e.r)e.r=void 0;e=t}t.s=n}function d(t){s.call(this,void 0);this.x=t;this.s=void 0;this.g=l-1;this.f=4}(d.prototype=new s).h=function(){this.f&=-3;if(1&this.f)return!1;if(32==(36&this.f))return!0;this.f&=-5;if(this.g===l)return!0;this.g=l;this.f|=1;if(this.i>0&&!h(this)){this.f&=-2;return!0}const t=_;try{a(this);_=this;const t=this.x();if(16&this.f||this.v!==t||0===this.i){this.v=t;this.f&=-17;this.i++}}catch(t){this.v=t;this.f|=16;this.i++}_=t;p(this);this.f&=-2;return!0};d.prototype.S=function(t){if(void 0===this.t){this.f|=36;for(let t=this.s;void 0!==t;t=t.n)t.S.S(t)}s.prototype.S.call(this,t)};d.prototype.U=function(t){if(void 0!==this.t){s.prototype.U.call(this,t);if(void 0===this.t){this.f&=-33;for(let t=this.s;void 0!==t;t=t.n)t.S.U(t)}}};d.prototype.N=function(){if(!(2&this.f)){this.f|=6;for(let t=this.t;void 0!==t;t=t.x)t.t.N()}};Object.defineProperty(d.prototype,"value",{get(){if(1&this.f)throw new Error("Cycle detected");const t=f(this);this.h();if(void 0!==t)t.i=this.i;if(16&this.f)throw this.v;return this.v}});function v(t){return new d(t)}function y(t){const e=t.u;t.u=void 0;if("function"==typeof e){r++;const i=_;_=void 0;try{e()}catch(n){t.f&=-2;t.f|=8;m(t);throw n}finally{_=i;n()}}}function m(t){for(let n=t.s;void 0!==n;n=n.n)n.S.U(n);t.x=void 0;t.s=void 0;y(t)}function g(t){if(_!==this)throw new Error("Out-of-order effect");p(this);_=t;this.f&=-2;if(8&this.f)m(this);n()}function b(t){this.x=t;this.u=void 0;this.s=void 0;this.o=void 0;this.f=32}b.prototype.c=function(){const t=this.S();try{if(8&this.f)return;if(void 0===this.x)return;const n=this.x();if("function"==typeof n)this.u=n}finally{t()}};b.prototype.S=function(){if(1&this.f)throw new Error("Cycle detected");this.f|=1;this.f&=-9;y(this);a(this);r++;const t=_;_=this;return g.bind(this,t)};b.prototype.N=function(){if(!(2&this.f)){this.f|=2;this.o=i;i=this}};b.prototype.d=function(){this.f|=8;if(!(1&this.f))m(this)};function k(t){const n=new b(t);try{n.c()}catch(t){n.d();throw t}return n.d.bind(n)}var w,S,x,C,U,E,H,P,N,$,T,D,M={},F=[],A=/acit|ex(?:s|g|n|p|$)|rph|grid|ows|mnc|ntw|ine[ch]|zoo|^ord|itera/i,W=Array.isArray;function L(t,n){for(var e in n)t[e]=n[e];return t}function O(t){var n=t.parentNode;n&&n.removeChild(t)}function R(t,n,e){var _,i,o,r={};for(o in n)"key"==o?_=n[o]:"ref"==o?i=n[o]:r[o]=n[o];if(arguments.length>2&&(r.children=arguments.length>3?w.call(arguments,2):e),"function"==typeof t&&null!=t.defaultProps)for(o in t.defaultProps)void 0===r[o]&&(r[o]=t.defaultProps[o]);return I(t,r,_,i,null)}function I(t,n,e,_,i){var o={type:t,props:n,key:e,ref:_,__k:null,__:null,__b:0,__e:null,__d:void 0,__c:null,constructor:void 0,__v:null==i?++x:i,__i:-1,__u:0};return null==i&&null!=S.vnode&&S.vnode(o),o}function V(){return{current:null}}function j(t){return t.children}function q(t,n){this.props=t,this.context=n}function B(t,n){if(null==n)return t.__?B(t.__,t.__i+1):null;for(var e;nn&&U.sort(P));J.__r=0}function K(t,n,e,_,i,o,r,u,l,f,s){var c,h,a,p,d,v=_&&_.__k||F,y=n.length;for(e.__d=l,Q(e,n,v),l=e.__d,c=0;c0?I(i.type,i.props,i.key,i.ref?i.ref:null,i.__v):i)?(i.__=t,i.__b=t.__b+1,u=Z(i,e,r,s),i.__i=u,o=null,-1!==u&&(s--,(o=e[u])&&(o.__u|=131072)),null==o||null===o.__v?(-1==u&&c--,"function"!=typeof i.type&&(i.__u|=65536)):u!==r&&(u==r-1?c--:u==r+1?c++:u>r?s>l-r?c+=u-r:c--:u(null!=l&&0==(131072&l.__u)?1:0))for(;r>=0||u=0){if((l=n[r])&&0==(131072&l.__u)&&i==l.key&&o===l.type)return r;r--}if(u2&&(u.children=arguments.length>3?w.call(arguments,2):e),I(t.type,u,_||t.key,i||t.ref,null)}function ht(t,n){var e={__c:n="__cC"+D++,__:t,Consumer:function(t,n){return t.children(n)},Provider:function(t){var e,_;return this.getChildContext||(e=[],(_={})[n]=this,this.getChildContext=function(){return _},this.componentWillUnmount=function(){e=null},this.shouldComponentUpdate=function(t){this.props.value!==t.value&&e.some((function(t){t.__e=!0,G(t)}))},this.sub=function(t){e.push(t);var n=t.componentWillUnmount;t.componentWillUnmount=function(){e&&e.splice(e.indexOf(t),1),n&&n.call(t)}}),t.children}};return e.Provider.__=e.Consumer.contextType=e}w=F.slice,S={__e:function(t,n,e,_){for(var i,o,r;n=n.__;)if((i=n.__c)&&!i.__)try{if((o=i.constructor)&&null!=o.getDerivedStateFromError&&(i.setState(o.getDerivedStateFromError(t)),r=i.__d),null!=i.componentDidCatch&&(i.componentDidCatch(t,_||{}),r=i.__d),r)return i.__E=i}catch(n){t=n}throw t}},x=0,C=function(t){return null!=t&&null==t.constructor},q.prototype.setState=function(t,n){var e;e=null!=this.__s&&this.__s!==this.state?this.__s:this.__s=L({},this.state),"function"==typeof t&&(t=t(L({},e),this.props)),t&&L(e,t),null!=t&&this.__v&&(n&&this._sb.push(n),G(this))},q.prototype.forceUpdate=function(t){this.__v&&(this.__e=!0,t&&this.__h.push(t),G(this))},q.prototype.render=j,U=[],H="function"==typeof Promise?Promise.prototype.then.bind(Promise.resolve()):setTimeout,P=function(t,n){return t.__v.__b-n.__v.__b},J.__r=0,N=0,$=et(!1),T=et(!0),D=0;var at,pt,dt,vt,yt=0,mt=[],gt=S,bt=gt.__b,kt=gt.__r,wt=gt.diffed,St=gt.__c,xt=gt.unmount,Ct=gt.__;function Ut(t,n){gt.__h&>.__h(pt,t,yt||n),yt=0;var e=pt.__H||(pt.__H={__:[],__h:[]});return t>=e.__.length&&e.__.push({}),e.__[t]}function Et(t){return yt=1,Ht(Bt,t)}function Ht(t,n,e){var _=Ut(at++,2);if(_.t=t,!_.__c&&(_.__=[e?e(n):Bt(void 0,n),function(t){var n=_.__N?_.__N[0]:_.__[0],e=_.t(n,t);n!==e&&(_.__N=[e,_.__[1]],_.__c.setState({}))}],_.__c=pt,!pt.u)){var i=function(t,n,e){if(!_.__c.__H)return!0;var i=_.__c.__H.__.filter((function(t){return!!t.__c}));if(i.every((function(t){return!t.__N})))return!o||o.call(this,t,n,e);var r=!1;return i.forEach((function(t){if(t.__N){var n=t.__[0];t.__=t.__N,t.__N=void 0,n!==t.__[0]&&(r=!0)}})),!(!r&&_.__c.props===t)&&(!o||o.call(this,t,n,e))};pt.u=!0;var o=pt.shouldComponentUpdate,r=pt.componentWillUpdate;pt.componentWillUpdate=function(t,n,e){if(this.__e){var _=o;o=void 0,i(t,n,e),o=_}r&&r.call(this,t,n,e)},pt.shouldComponentUpdate=i}return _.__N||_.__}function Pt(t,n){var e=Ut(at++,3);!gt.__s&&qt(e.__H,n)&&(e.__=t,e.i=n,pt.__H.__h.push(e))}function Nt(t,n){var e=Ut(at++,4);!gt.__s&&qt(e.__H,n)&&(e.__=t,e.i=n,pt.__h.push(e))}function $t(t){return yt=5,Dt((function(){return{current:t}}),[])}function Tt(t,n,e){yt=6,Nt((function(){return"function"==typeof t?(t(n()),function(){return t(null)}):t?(t.current=n(),function(){return t.current=null}):void 0}),null==e?e:e.concat(t))}function Dt(t,n){var e=Ut(at++,7);return qt(e.__H,n)&&(e.__=t(),e.__H=n,e.__h=t),e.__}function Mt(t,n){return yt=8,Dt((function(){return t}),n)}function Ft(t){var n=pt.context[t.__c],e=Ut(at++,9);return e.c=t,n?(null==e.__&&(e.__=!0,n.sub(pt)),n.props.value):t.__}function At(t,n){gt.useDebugValue&>.useDebugValue(n?n(t):t)}function Wt(t){var n=Ut(at++,10),e=Et();return n.__=t,pt.componentDidCatch||(pt.componentDidCatch=function(t,_){n.__&&n.__(t,_),e[1](t)}),[e[0],function(){e[1](void 0)}]}function Lt(){var t=Ut(at++,11);if(!t.__){for(var n=pt.__v;null!==n&&!n.__m&&null!==n.__;)n=n.__;var e=n.__m||(n.__m=[0,0]);t.__="P"+e[0]+"-"+e[1]++}return t.__}function Ot(){for(var t;t=mt.shift();)if(t.__P&&t.__H)try{t.__H.__h.forEach(Vt),t.__H.__h.forEach(jt),t.__H.__h=[]}catch(n){t.__H.__h=[],gt.__e(n,t.__v)}}gt.__b=function(t){pt=null,bt&&bt(t)},gt.__=function(t,n){t&&n.__k&&n.__k.__m&&(t.__m=n.__k.__m),Ct&&Ct(t,n)},gt.__r=function(t){kt&&kt(t),at=0;var n=(pt=t.__c).__H;n&&(dt===pt?(n.__h=[],pt.__h=[],n.__.forEach((function(t){t.__N&&(t.__=t.__N),t.i=t.__N=void 0}))):(n.__h.forEach(Vt),n.__h.forEach(jt),n.__h=[],at=0)),dt=pt},gt.diffed=function(t){wt&&wt(t);var n=t.__c;n&&n.__H&&(n.__H.__h.length&&(1!==mt.push(n)&&vt===gt.requestAnimationFrame||((vt=gt.requestAnimationFrame)||It)(Ot)),n.__H.__.forEach((function(t){t.i&&(t.__H=t.i),t.i=void 0}))),dt=pt=null},gt.__c=function(t,n){n.some((function(t){try{t.__h.forEach(Vt),t.__h=t.__h.filter((function(t){return!t.__||jt(t)}))}catch(r){n.some((function(t){t.__h&&(t.__h=[])})),n=[],gt.__e(r,t.__v)}})),St&&St(t,n)},gt.unmount=function(t){xt&&xt(t);var n,e=t.__c;e&&e.__H&&(e.__H.__.forEach((function(t){try{Vt(t)}catch(t){n=t}})),e.__H=void 0,n&>.__e(n,e.__v))};var Rt="function"==typeof requestAnimationFrame;function It(t){var n,e=function(){clearTimeout(_),Rt&&cancelAnimationFrame(n),setTimeout(t)},_=setTimeout(e,100);Rt&&(n=requestAnimationFrame(e))}function Vt(t){var n=pt,e=t.__c;"function"==typeof e&&(t.__c=void 0,e()),pt=n}function jt(t){var n=pt;t.__c=t.__(),pt=n}function qt(t,n){return!t||t.length!==n.length||n.some((function(n,e){return n!==t[e]}))}function Bt(t,n){return"function"==typeof n?n(t):n}function zt(t,n){S[t]=n.bind(null,S[t]||(()=>{}))}let Gt,Jt;function Kt(t){if(Jt)Jt();Jt=t&&t.S()}function Qt({data:t}){const n=Yt(t);n.value=t;const e=Dt(()=>{let t=this.__v;while(t=t.__)if(t.__c){t.__c.__$f|=4;break}this.__$u.c=()=>{var t;if(!C(e.peek())&&3===(null==(t=this.base)?void 0:t.nodeType))this.base.data=e.peek();else{this.__$f|=1;this.setState({})}};return v(()=>{let t=n.value.value;return 0===t?0:!0===t?"":t||""})},[]);return e.value}Qt.displayName="_st";Object.defineProperties(s.prototype,{constructor:{configurable:!0,value:void 0},type:{configurable:!0,value:Qt},props:{configurable:!0,get(){return{data:this}}},__b:{configurable:!0,value:1}});zt("__b",(t,n)=>{if("string"==typeof n.type){let t,e=n.props;for(let _ in e){if("children"===_)continue;let i=e[_];if(i instanceof s){if(!t)n.__np=t={};t[_]=i;e[_]=i.peek()}}}t(n)});zt("__r",(t,n)=>{Kt();let e,_=n.__c;if(_){_.__$f&=-2;e=_.__$u;if(void 0===e)_.__$u=e=function(t){let n;k((function(){n=this}));n.c=()=>{_.__$f|=1;_.setState({})};return n}()}Gt=_;Kt(e);t(n)});zt("__e",(t,n,e,_)=>{Kt();Gt=void 0;t(n,e,_)});zt("diffed",(t,n)=>{Kt();Gt=void 0;let e;if("string"==typeof n.type&&(e=n.__e)){let t=n.__np,_=n.props;if(t){let n=e.U;if(n)for(let e in n){let _=n[e];if(void 0!==_&&!(e in t)){_.d();n[e]=void 0}}else{n={};e.U=n}for(let i in t){let o=n[i],r=t[i];if(void 0===o){o=Xt(e,i,r,_);n[i]=o}else o.o(r,_)}}}t(n)});function Xt(t,n,e,_){const i=n in t&&void 0===t.ownerSVGElement,o=c(e);return{o:(t,n)=>{o.value=t;_=n},d:k(()=>{const e=o.value.value;if(_[n]!==e){_[n]=e;if(i)t[n]=e;else if(e)t.setAttribute(n,e);else t.removeAttribute(n)}})}}zt("unmount",(t,n)=>{if("string"==typeof n.type){let t=n.__e;if(t){const n=t.U;if(n){t.U=void 0;for(let t in n){let e=n[t];if(e)e.d()}}}}else{let t=n.__c;if(t){const n=t.__$u;if(n){t.__$u=void 0;n.d()}}}t(n)});zt("__h",(t,n,e,_)=>{if(_<3||9===_)n.__$f|=2;t(n,e,_)});q.prototype.shouldComponentUpdate=function(t,n){const e=this.__$u;if(!(e&&void 0!==e.s||4&this.__$f))return!0;if(3&this.__$f)return!0;for(let _ in n)return!0;for(let _ in t)if("__source"!==_&&t[_]!==this.props[_])return!0;for(let _ in this.props)if(!(_ in t))return!0;return!1};function Yt(t){return Dt(()=>c(t),[])}function Zt(t){const n=$t(t);n.current=t;Gt.__$f|=4;return Dt(()=>v(()=>n.current()),[])}function tn(t){const n=$t(t);n.current=t;Pt(()=>k(()=>n.current()),[])}var nn=function(t,n,e,_){var i;n[0]=0;for(var o=1;o=5&&((i||!t&&5===_)&&(r.push(_,0,i,e),_=6),t&&(r.push(_,t,0,e),_=6)),i=""},l=0;l"===n?(_=1,i=""):i=n+i[0]:o?n===o?o="":i+=n:'"'===n||"'"===n?o=n:">"===n?(u(),_=1):_&&("="===n?(_=5,e=i,i=""):"/"===n&&(_<5||">"===t[l][f+1])?(u(),3===_&&(r=r[0]),_=r,(r=r[0]).push(2,0,_),_=0):" "===n||"\t"===n||"\n"===n||"\r"===n?(u(),_=2):i+=n),3===_&&"!--"===i&&(_=4,r=r[0])}return u(),r}(t)),n),arguments,[])).length>1?n:n[0]}var on=_n.bind(R);export{q as Component,j as Fragment,s as Signal,e as batch,ct as cloneElement,v as computed,ht as createContext,R as createElement,V as createRef,k as effect,R as h,on as html,st as hydrate,C as isValidElement,S as options,ft as render,c as signal,Y as toChildArray,o as untracked,Mt as useCallback,Zt as useComputed,Ft as useContext,At as useDebugValue,Pt as useEffect,Wt as useErrorBoundary,Lt as useId,Tt as useImperativeHandle,Nt as useLayoutEffect,Dt as useMemo,Ht as useReducer,$t as useRef,Yt as useSignal,tn as useSignalEffect,Et as useState}; diff --git a/examples/server/public/loading.html b/examples/server/public/loading.html new file mode 100644 index 000000000..c3fd19a0f --- /dev/null +++ b/examples/server/public/loading.html @@ -0,0 +1,12 @@ + + + + + + +
+ The model is loading. Please wait.
+ The user interface will appear soon. +
+ + diff --git a/examples/server/public/colorthemes.css b/examples/server/public_legacy/colorthemes.css similarity index 100% rename from examples/server/public/colorthemes.css rename to examples/server/public_legacy/colorthemes.css diff --git a/examples/server/public/completion.js b/examples/server/public_legacy/completion.js similarity index 96% rename from examples/server/public/completion.js rename to examples/server/public_legacy/completion.js index 36818f764..30df7c2fa 100644 --- a/examples/server/public/completion.js +++ b/examples/server/public_legacy/completion.js @@ -29,7 +29,7 @@ export async function* llama(prompt, params = {}, config = {}) { const completionParams = { ...paramDefaults, ...params, prompt }; - const response = await fetch(`${api_url}/completion`, { + const response = await fetch(`${api_url}${config.endpoint || '/completion'}`, { method: 'POST', body: JSON.stringify(completionParams), headers: { @@ -78,7 +78,12 @@ export async function* llama(prompt, params = {}, config = {}) { for (const line of lines) { const match = regex.exec(line); if (match) { - result[match[1]] = match[2] + result[match[1]] = match[2]; + if (result.data === '[DONE]') { + cont = false; + break; + } + // since we know this is llama.cpp, let's just decode the json in data if (result.data) { result.data = JSON.parse(result.data); diff --git a/examples/server/public/favicon.ico b/examples/server/public_legacy/favicon.ico similarity index 100% rename from examples/server/public/favicon.ico rename to examples/server/public_legacy/favicon.ico diff --git a/examples/server/public/index-new.html b/examples/server/public_legacy/index-new.html similarity index 95% rename from examples/server/public/index-new.html rename to examples/server/public_legacy/index-new.html index c87dd8f1e..cbfbbdf28 100644 --- a/examples/server/public/index-new.html +++ b/examples/server/public_legacy/index-new.html @@ -39,11 +39,15 @@ temperature: 0.8, // adapt all following parameters to optimized min-p requierements. If for non-english, set to 0.6 or lower repeat_last_n: 0, // 0 = disable penalty, -1 = context size repeat_penalty: 1.0, // 1.0 = disabled - penalize_nl: false, // true only useful for infinite completion + dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well + dry_base: 1.75, // 0.0 = disabled + dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well + dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) top_k: 0, // <= 0 to use vocab size top_p: 1.0, // 1.0 = disabled min_p: 0.05, // 0 = disabled; recommended for non-english: ~ 0.4 - tfs_z: 1.0, // 1.0 = disabled + xtc_probability: 0.0, // 0 = disabled; + xtc_threshold: 0.1, // > 0.5 disables XTC; typical_p: 1.0, // 1.0 = disabled presence_penalty: 0.0, // 0.0 = disabled frequency_penalty: 0.0, // 0.0 = disabled @@ -831,11 +835,16 @@ return html`
${IntField({ label: "Top-K", title: "Limits the selection of the next token to the K most probable tokens. 1 means no randomness = greedy sampling. If set to 0, it means the entire vocabulary size is considered.", max: 100, min: 0, step: 1, name: "top_k", value: params.value.top_k })} ${IntField({ label: "Penalize Last N", title: "The last n tokens that are taken into account to penalise repetitions. A value of 0 means that this function is deactivated and -1 means that the entire size of the context is taken into account.", max: 2048, min: 0, step: 16, name: "repeat_last_n", value: params.value.repeat_last_n })} - ${FloatField({ label: "Top-P", title: "Limits the selection of the next token to a subset of tokens whose combined probability reaches a threshold value P = top-P. If set to 1, it means the entire vocabulary size is considered.", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })} ${FloatField({ label: "Presence Penalty", title: "A penalty that is applied if certain tokens appear repeatedly in the generated text. A higher value leads to fewer repetitions.", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })} - ${FloatField({ label: "TFS-Z", title: "Activates tail-free sampling, a method used to limit the prediction of tokens that are too frequent. The parameter z controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })} ${FloatField({ label: "Frequency Penalty", title: "A penalty that is applied based on the frequency with which certain tokens occur in the training data set. A higher value results in rare tokens being favoured.", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })} + ${FloatField({ label: "Top-P", title: "Limits the selection of the next token to a subset of tokens whose combined probability reaches a threshold value P = top-P. If set to 1, it means the entire vocabulary size is considered.", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })} ${FloatField({ label: "Typical-P", title: "Activates local typical sampling, a method used to limit the prediction of tokens that are atypical in the current context. The parameter p controls the strength of this limitation. A value of 1.0 means that this function is deactivated.", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })} + ${FloatField({ label: "XTC probability", title: "Sets the chance for token removal (checked once on sampler start)", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })} + ${FloatField({ label: "XTC threshold", title: "Sets a minimum probability threshold for tokens to be removed", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })} + ${FloatField({ label: "DRY Penalty Multiplier", title: "Set the DRY repetition penalty multiplier. Default is 0.0, which disables DRY.", max: 5.0, min: 0.0, name: "dry_multiplier", step: 0.01, value: params.value.dry_multiplier })} + ${FloatField({ label: "DRY Base", title: "Set the DRY repetition penalty base value. Default is 1.75", max: 3.0, min: 1.0, name: "dry_base", step: 0.01, value: params.value.dry_base })} + ${IntField({ label: "DRY Allowed Length", title: "Tokens that extend repetition beyond this receive exponentially increasing penalty. Default is 2", max: 10, min: 1, step: 1, name: "dry_allowed_length", value: params.value.dry_allowed_length })} + ${IntField({ label: "DRY Penalty Last N", title: "How many tokens to scan for repetitions. Default is -1, where 0 is disabled and -1 is context size", max: 2048, min: -1, step: 16, name: "dry_penalty_last_n", value: params.value.dry_penalty_last_n })} ${IntField({ label: "Min Keep", title: "If greater than 0, samplers are forced to return N possible tokens at minimum. Default is 0", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
@@ -1132,12 +1141,15 @@ document.addEventListener('DOMContentLoaded', (event) => { const snapSettings = { temperature: { snapValue: 1.0, snapRangeMultiplier: 6 }, min_p: { snapValue: 0.05, snapRangeMultiplier: 2 }, + xtc_probability: { snapValue: 0.0, snapRangeMultiplier: 4 }, + xtc_threshold: { snapValue: 0.5, snapRangeMultiplier: 4 }, top_p: { snapValue: 1.0, snapRangeMultiplier: 4 }, - tfs_z: { snapValue: 1.0, snapRangeMultiplier: 4 }, typical_p: { snapValue: 1.0, snapRangeMultiplier: 4 }, repeat_penalty: { snapValue: 1.0, snapRangeMultiplier: 4 }, presence_penalty: { snapValue: 0.0, snapRangeMultiplier: 4 }, frequency_penalty: { snapValue: 0.0, snapRangeMultiplier: 4 }, + dry_multiplier: { snapValue: 0.0, snapRangeMultiplier: 4 }, + dry_base: { snapValue: 1.75, snapRangeMultiplier: 4 }, }; // add an event listener for each slider Object.keys(snapSettings).forEach(sliderName => { diff --git a/examples/server/public/index.html b/examples/server/public_legacy/index.html similarity index 96% rename from examples/server/public/index.html rename to examples/server/public_legacy/index.html index 07fec6a38..75f39330a 100644 --- a/examples/server/public/index.html +++ b/examples/server/public_legacy/index.html @@ -303,11 +303,15 @@ temperature: 0.7, repeat_last_n: 256, // 0 = disable penalty, -1 = context size repeat_penalty: 1.18, // 1.0 = disabled - penalize_nl: false, + dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well + dry_base: 1.75, // 0.0 = disabled + dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well + dry_penalty_last_n: -1, // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) top_k: 40, // <= 0 to use vocab size top_p: 0.95, // 1.0 = disabled min_p: 0.05, // 0 = disabled - tfs_z: 1.0, // 1.0 = disabled + xtc_probability: 0.0, // 0 = disabled; + xtc_threshold: 0.1, // > 0.5 disables XTC; typical_p: 1.0, // 1.0 = disabled presence_penalty: 0.0, // 0.0 = disabled frequency_penalty: 0.0, // 0.0 = disabled @@ -1001,7 +1005,6 @@ ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })} ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })} ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })} - ${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })} ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })} ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })} ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })} @@ -1009,10 +1012,15 @@
More options
- ${FloatField({ label: "TFS-Z", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })} ${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })} ${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })} ${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })} + ${FloatField({ label: "DRY Penalty Multiplier", max: 5.0, min: 0.0, name: "dry_multiplier", step: 0.01, value: params.value.dry_multiplier })} + ${FloatField({ label: "DRY Base", max: 3.0, min: 1.0, name: "dry_base", step: 0.01, value: params.value.dry_base })} + ${IntField({ label: "DRY Allowed Length", max: 10, min: 2, step: 1, name: "dry_allowed_length", value: params.value.dry_allowed_length })} + ${IntField({ label: "DRY Penalty Last N", max: 2048, min: -1, step: 16, name: "dry_penalty_last_n", value: params.value.dry_penalty_last_n })} + ${FloatField({ label: "XTC probability", max: 1.0, min: 0.0, name: "xtc_probability", step: 0.01, value: params.value.xtc_probability })} + ${FloatField({ label: "XTC threshold", max: 0.5, min: 0.0, name: "xtc_threshold", step: 0.01, value: params.value.xtc_threshold })}

diff --git a/examples/server/public_legacy/index.js b/examples/server/public_legacy/index.js new file mode 100644 index 000000000..32ec6e9e1 --- /dev/null +++ b/examples/server/public_legacy/index.js @@ -0,0 +1 @@ +const t=Symbol.for("preact-signals");function n(){if(r>1){r--;return}let t,n=!1;while(void 0!==i){let _=i;i=void 0;u++;while(void 0!==_){const i=_.o;_.o=void 0;_.f&=-3;if(!(8&_.f)&&h(_))try{_.c()}catch(e){if(!n){t=e;n=!0}}_=i}}u=0;r--;if(n)throw t}function e(t){if(r>0)return t();r++;try{return t()}finally{n()}}let _,i;function o(t){const n=_;_=void 0;try{return t()}finally{_=n}}let r=0,u=0,l=0;function s(t){if(void 0===_)return;let n=t.n;if(void 0===n||n.t!==_){n={i:0,S:t,p:_.s,n:void 0,t:_,e:void 0,x:void 0,r:n};if(void 0!==_.s)_.s.n=n;_.s=n;t.n=n;if(32&_.f)t.S(n);return n}else if(-1===n.i){n.i=0;if(void 0!==n.n){n.n.p=n.p;if(void 0!==n.p)n.p.n=n.n;n.p=_.s;n.n=void 0;_.s.n=n;_.s=n}return n}}function f(t){this.v=t;this.i=0;this.n=void 0;this.t=void 0}f.prototype.brand=t;f.prototype.h=function(){return!0};f.prototype.S=function(t){if(this.t!==t&&void 0===t.e){t.x=this.t;if(void 0!==this.t)this.t.e=t;this.t=t}};f.prototype.U=function(t){if(void 0!==this.t){const n=t.e,e=t.x;if(void 0!==n){n.x=e;t.e=void 0}if(void 0!==e){e.e=n;t.x=void 0}if(t===this.t)this.t=e}};f.prototype.subscribe=function(t){return k(()=>{const n=this.value,e=_;_=void 0;try{t(n)}finally{_=e}})};f.prototype.valueOf=function(){return this.value};f.prototype.toString=function(){return this.value+""};f.prototype.toJSON=function(){return this.value};f.prototype.peek=function(){const t=_;_=void 0;try{return this.value}finally{_=t}};Object.defineProperty(f.prototype,"value",{get(){const t=s(this);if(void 0!==t)t.i=this.i;return this.v},set(t){if(t!==this.v){if(u>100)throw new Error("Cycle detected");this.v=t;this.i++;l++;r++;try{for(let t=this.t;void 0!==t;t=t.x)t.t.N()}finally{n()}}}});function c(t){return new f(t)}function h(t){for(let n=t.s;void 0!==n;n=n.n)if(n.S.i!==n.i||!n.S.h()||n.S.i!==n.i)return!0;return!1}function a(t){for(let n=t.s;void 0!==n;n=n.n){const e=n.S.n;if(void 0!==e)n.r=e;n.S.n=n;n.i=-1;if(void 0===n.n){t.s=n;break}}}function p(t){let n,e=t.s;while(void 0!==e){const t=e.p;if(-1===e.i){e.S.U(e);if(void 0!==t)t.n=e.n;if(void 0!==e.n)e.n.p=t}else n=e;e.S.n=e.r;if(void 0!==e.r)e.r=void 0;e=t}t.s=n}function d(t){f.call(this,void 0);this.x=t;this.s=void 0;this.g=l-1;this.f=4}(d.prototype=new f).h=function(){this.f&=-3;if(1&this.f)return!1;if(32==(36&this.f))return!0;this.f&=-5;if(this.g===l)return!0;this.g=l;this.f|=1;if(this.i>0&&!h(this)){this.f&=-2;return!0}const t=_;try{a(this);_=this;const t=this.x();if(16&this.f||this.v!==t||0===this.i){this.v=t;this.f&=-17;this.i++}}catch(t){this.v=t;this.f|=16;this.i++}_=t;p(this);this.f&=-2;return!0};d.prototype.S=function(t){if(void 0===this.t){this.f|=36;for(let t=this.s;void 0!==t;t=t.n)t.S.S(t)}f.prototype.S.call(this,t)};d.prototype.U=function(t){if(void 0!==this.t){f.prototype.U.call(this,t);if(void 0===this.t){this.f&=-33;for(let t=this.s;void 0!==t;t=t.n)t.S.U(t)}}};d.prototype.N=function(){if(!(2&this.f)){this.f|=6;for(let t=this.t;void 0!==t;t=t.x)t.t.N()}};Object.defineProperty(d.prototype,"value",{get(){if(1&this.f)throw new Error("Cycle detected");const t=s(this);this.h();if(void 0!==t)t.i=this.i;if(16&this.f)throw this.v;return this.v}});function v(t){return new d(t)}function y(t){const e=t.u;t.u=void 0;if("function"==typeof e){r++;const i=_;_=void 0;try{e()}catch(n){t.f&=-2;t.f|=8;m(t);throw n}finally{_=i;n()}}}function m(t){for(let n=t.s;void 0!==n;n=n.n)n.S.U(n);t.x=void 0;t.s=void 0;y(t)}function g(t){if(_!==this)throw new Error("Out-of-order effect");p(this);_=t;this.f&=-2;if(8&this.f)m(this);n()}function b(t){this.x=t;this.u=void 0;this.s=void 0;this.o=void 0;this.f=32}b.prototype.c=function(){const t=this.S();try{if(8&this.f)return;if(void 0===this.x)return;const n=this.x();if("function"==typeof n)this.u=n}finally{t()}};b.prototype.S=function(){if(1&this.f)throw new Error("Cycle detected");this.f|=1;this.f&=-9;y(this);a(this);r++;const t=_;_=this;return g.bind(this,t)};b.prototype.N=function(){if(!(2&this.f)){this.f|=2;this.o=i;i=this}};b.prototype.d=function(){this.f|=8;if(!(1&this.f))m(this)};function k(t){const n=new b(t);try{n.c()}catch(t){n.d();throw t}return n.d.bind(n)}var w,S,x,C,U,E,H,P,N,$,T,D,M={},A=[],F=/acit|ex(?:s|g|n|p|$)|rph|grid|ows|mnc|ntw|ine[ch]|zoo|^ord|itera/i,W=Array.isArray;function L(t,n){for(var e in n)t[e]=n[e];return t}function O(t){t&&t.parentNode&&t.parentNode.removeChild(t)}function R(t,n,e){var _,i,o,r={};for(o in n)"key"==o?_=n[o]:"ref"==o?i=n[o]:r[o]=n[o];if(arguments.length>2&&(r.children=arguments.length>3?w.call(arguments,2):e),"function"==typeof t&&null!=t.defaultProps)for(o in t.defaultProps)void 0===r[o]&&(r[o]=t.defaultProps[o]);return I(t,r,_,i,null)}function I(t,n,e,_,i){var o={type:t,props:n,key:e,ref:_,__k:null,__:null,__b:0,__e:null,__d:void 0,__c:null,constructor:void 0,__v:null==i?++x:i,__i:-1,__u:0};return null==i&&null!=S.vnode&&S.vnode(o),o}function V(){return{current:null}}function j(t){return t.children}function q(t,n){this.props=t,this.context=n}function B(t,n){if(null==n)return t.__?B(t.__,t.__i+1):null;for(var e;nn&&U.sort(P));J.__r=0}function K(t,n,e,_,i,o,r,u,l,s,f){var c,h,a,p,d,v=_&&_.__k||A,y=n.length;for(e.__d=l,Q(e,n,v),l=e.__d,c=0;c0?I(i.type,i.props,i.key,i.ref?i.ref:null,i.__v):i).__=t,i.__b=t.__b+1,o=null,-1!==(u=i.__i=Z(i,e,r,f))&&(f--,(o=e[u])&&(o.__u|=131072)),null==o||null===o.__v?(-1==u&&c--,"function"!=typeof i.type&&(i.__u|=65536)):u!==r&&(u==r-1?c--:u==r+1?c++:(u>r?c--:c++,i.__u|=65536))):i=t.__k[_]=null;if(f)for(_=0;_(null!=l&&0==(131072&l.__u)?1:0))for(;r>=0||u=0){if((l=n[r])&&0==(131072&l.__u)&&i==l.key&&o===l.type)return r;r--}if(u2&&(u.children=arguments.length>3?w.call(arguments,2):e),I(t.type,u,_||t.key,i||t.ref,null)}function ht(t,n){var e={__c:n="__cC"+D++,__:t,Consumer:function(t,n){return t.children(n)},Provider:function(t){var e,_;return this.getChildContext||(e=new Set,(_={})[n]=this,this.getChildContext=function(){return _},this.componentWillUnmount=function(){e=null},this.shouldComponentUpdate=function(t){this.props.value!==t.value&&e.forEach((function(t){t.__e=!0,G(t)}))},this.sub=function(t){e.add(t);var n=t.componentWillUnmount;t.componentWillUnmount=function(){e&&e.delete(t),n&&n.call(t)}}),t.children}};return e.Provider.__=e.Consumer.contextType=e}w=A.slice,S={__e:function(t,n,e,_){for(var i,o,r;n=n.__;)if((i=n.__c)&&!i.__)try{if((o=i.constructor)&&null!=o.getDerivedStateFromError&&(i.setState(o.getDerivedStateFromError(t)),r=i.__d),null!=i.componentDidCatch&&(i.componentDidCatch(t,_||{}),r=i.__d),r)return i.__E=i}catch(n){t=n}throw t}},x=0,C=function(t){return null!=t&&null==t.constructor},q.prototype.setState=function(t,n){var e;e=null!=this.__s&&this.__s!==this.state?this.__s:this.__s=L({},this.state),"function"==typeof t&&(t=t(L({},e),this.props)),t&&L(e,t),null!=t&&this.__v&&(n&&this._sb.push(n),G(this))},q.prototype.forceUpdate=function(t){this.__v&&(this.__e=!0,t&&this.__h.push(t),G(this))},q.prototype.render=j,U=[],H="function"==typeof Promise?Promise.prototype.then.bind(Promise.resolve()):setTimeout,P=function(t,n){return t.__v.__b-n.__v.__b},J.__r=0,N=0,$=et(!1),T=et(!0),D=0;var at,pt,dt,vt,yt=0,mt=[],gt=S,bt=gt.__b,kt=gt.__r,wt=gt.diffed,St=gt.__c,xt=gt.unmount,Ct=gt.__;function Ut(t,n){gt.__h&>.__h(pt,t,yt||n),yt=0;var e=pt.__H||(pt.__H={__:[],__h:[]});return t>=e.__.length&&e.__.push({}),e.__[t]}function Et(t){return yt=1,Ht(Bt,t)}function Ht(t,n,e){var _=Ut(at++,2);if(_.t=t,!_.__c&&(_.__=[e?e(n):Bt(void 0,n),function(t){var n=_.__N?_.__N[0]:_.__[0],e=_.t(n,t);n!==e&&(_.__N=[e,_.__[1]],_.__c.setState({}))}],_.__c=pt,!pt.u)){var i=function(t,n,e){if(!_.__c.__H)return!0;var i=_.__c.__H.__.filter((function(t){return!!t.__c}));if(i.every((function(t){return!t.__N})))return!o||o.call(this,t,n,e);var r=!1;return i.forEach((function(t){if(t.__N){var n=t.__[0];t.__=t.__N,t.__N=void 0,n!==t.__[0]&&(r=!0)}})),!(!r&&_.__c.props===t)&&(!o||o.call(this,t,n,e))};pt.u=!0;var o=pt.shouldComponentUpdate,r=pt.componentWillUpdate;pt.componentWillUpdate=function(t,n,e){if(this.__e){var _=o;o=void 0,i(t,n,e),o=_}r&&r.call(this,t,n,e)},pt.shouldComponentUpdate=i}return _.__N||_.__}function Pt(t,n){var e=Ut(at++,3);!gt.__s&&qt(e.__H,n)&&(e.__=t,e.i=n,pt.__H.__h.push(e))}function Nt(t,n){var e=Ut(at++,4);!gt.__s&&qt(e.__H,n)&&(e.__=t,e.i=n,pt.__h.push(e))}function $t(t){return yt=5,Dt((function(){return{current:t}}),[])}function Tt(t,n,e){yt=6,Nt((function(){return"function"==typeof t?(t(n()),function(){return t(null)}):t?(t.current=n(),function(){return t.current=null}):void 0}),null==e?e:e.concat(t))}function Dt(t,n){var e=Ut(at++,7);return qt(e.__H,n)&&(e.__=t(),e.__H=n,e.__h=t),e.__}function Mt(t,n){return yt=8,Dt((function(){return t}),n)}function At(t){var n=pt.context[t.__c],e=Ut(at++,9);return e.c=t,n?(null==e.__&&(e.__=!0,n.sub(pt)),n.props.value):t.__}function Ft(t,n){gt.useDebugValue&>.useDebugValue(n?n(t):t)}function Wt(t){var n=Ut(at++,10),e=Et();return n.__=t,pt.componentDidCatch||(pt.componentDidCatch=function(t,_){n.__&&n.__(t,_),e[1](t)}),[e[0],function(){e[1](void 0)}]}function Lt(){var t=Ut(at++,11);if(!t.__){for(var n=pt.__v;null!==n&&!n.__m&&null!==n.__;)n=n.__;var e=n.__m||(n.__m=[0,0]);t.__="P"+e[0]+"-"+e[1]++}return t.__}function Ot(){for(var t;t=mt.shift();)if(t.__P&&t.__H)try{t.__H.__h.forEach(Vt),t.__H.__h.forEach(jt),t.__H.__h=[]}catch(n){t.__H.__h=[],gt.__e(n,t.__v)}}gt.__b=function(t){pt=null,bt&&bt(t)},gt.__=function(t,n){t&&n.__k&&n.__k.__m&&(t.__m=n.__k.__m),Ct&&Ct(t,n)},gt.__r=function(t){kt&&kt(t),at=0;var n=(pt=t.__c).__H;n&&(dt===pt?(n.__h=[],pt.__h=[],n.__.forEach((function(t){t.__N&&(t.__=t.__N),t.i=t.__N=void 0}))):(n.__h.forEach(Vt),n.__h.forEach(jt),n.__h=[],at=0)),dt=pt},gt.diffed=function(t){wt&&wt(t);var n=t.__c;n&&n.__H&&(n.__H.__h.length&&(1!==mt.push(n)&&vt===gt.requestAnimationFrame||((vt=gt.requestAnimationFrame)||It)(Ot)),n.__H.__.forEach((function(t){t.i&&(t.__H=t.i),t.i=void 0}))),dt=pt=null},gt.__c=function(t,n){n.some((function(t){try{t.__h.forEach(Vt),t.__h=t.__h.filter((function(t){return!t.__||jt(t)}))}catch(r){n.some((function(t){t.__h&&(t.__h=[])})),n=[],gt.__e(r,t.__v)}})),St&&St(t,n)},gt.unmount=function(t){xt&&xt(t);var n,e=t.__c;e&&e.__H&&(e.__H.__.forEach((function(t){try{Vt(t)}catch(t){n=t}})),e.__H=void 0,n&>.__e(n,e.__v))};var Rt="function"==typeof requestAnimationFrame;function It(t){var n,e=function(){clearTimeout(_),Rt&&cancelAnimationFrame(n),setTimeout(t)},_=setTimeout(e,100);Rt&&(n=requestAnimationFrame(e))}function Vt(t){var n=pt,e=t.__c;"function"==typeof e&&(t.__c=void 0,e()),pt=n}function jt(t){var n=pt;t.__c=t.__(),pt=n}function qt(t,n){return!t||t.length!==n.length||n.some((function(n,e){return n!==t[e]}))}function Bt(t,n){return"function"==typeof n?n(t):n}function zt(t,n){S[t]=n.bind(null,S[t]||(()=>{}))}let Gt,Jt;function Kt(t){if(Jt)Jt();Jt=t&&t.S()}function Qt({data:t}){const n=Yt(t);n.value=t;const e=Dt(()=>{let t=this.__v;while(t=t.__)if(t.__c){t.__c.__$f|=4;break}this.__$u.c=()=>{var t;if(!C(e.peek())&&3===(null==(t=this.base)?void 0:t.nodeType))this.base.data=e.peek();else{this.__$f|=1;this.setState({})}};return v(()=>{let t=n.value.value;return 0===t?0:!0===t?"":t||""})},[]);return e.value}Qt.displayName="_st";Object.defineProperties(f.prototype,{constructor:{configurable:!0,value:void 0},type:{configurable:!0,value:Qt},props:{configurable:!0,get(){return{data:this}}},__b:{configurable:!0,value:1}});zt("__b",(t,n)=>{if("string"==typeof n.type){let t,e=n.props;for(let _ in e){if("children"===_)continue;let i=e[_];if(i instanceof f){if(!t)n.__np=t={};t[_]=i;e[_]=i.peek()}}}t(n)});zt("__r",(t,n)=>{Kt();let e,_=n.__c;if(_){_.__$f&=-2;e=_.__$u;if(void 0===e)_.__$u=e=function(t){let n;k((function(){n=this}));n.c=()=>{_.__$f|=1;_.setState({})};return n}()}Gt=_;Kt(e);t(n)});zt("__e",(t,n,e,_)=>{Kt();Gt=void 0;t(n,e,_)});zt("diffed",(t,n)=>{Kt();Gt=void 0;let e;if("string"==typeof n.type&&(e=n.__e)){let t=n.__np,_=n.props;if(t){let n=e.U;if(n)for(let e in n){let _=n[e];if(void 0!==_&&!(e in t)){_.d();n[e]=void 0}}else{n={};e.U=n}for(let i in t){let o=n[i],r=t[i];if(void 0===o){o=Xt(e,i,r,_);n[i]=o}else o.o(r,_)}}}t(n)});function Xt(t,n,e,_){const i=n in t&&void 0===t.ownerSVGElement,o=c(e);return{o:(t,n)=>{o.value=t;_=n},d:k(()=>{const e=o.value.value;if(_[n]!==e){_[n]=e;if(i)t[n]=e;else if(e)t.setAttribute(n,e);else t.removeAttribute(n)}})}}zt("unmount",(t,n)=>{if("string"==typeof n.type){let t=n.__e;if(t){const n=t.U;if(n){t.U=void 0;for(let t in n){let e=n[t];if(e)e.d()}}}}else{let t=n.__c;if(t){const n=t.__$u;if(n){t.__$u=void 0;n.d()}}}t(n)});zt("__h",(t,n,e,_)=>{if(_<3||9===_)n.__$f|=2;t(n,e,_)});q.prototype.shouldComponentUpdate=function(t,n){const e=this.__$u;if(!(e&&void 0!==e.s||4&this.__$f))return!0;if(3&this.__$f)return!0;for(let _ in n)return!0;for(let _ in t)if("__source"!==_&&t[_]!==this.props[_])return!0;for(let _ in this.props)if(!(_ in t))return!0;return!1};function Yt(t){return Dt(()=>c(t),[])}function Zt(t){const n=$t(t);n.current=t;Gt.__$f|=4;return Dt(()=>v(()=>n.current()),[])}function tn(t){const n=$t(t);n.current=t;Pt(()=>k(()=>n.current()),[])}var nn=function(t,n,e,_){var i;n[0]=0;for(var o=1;o=5&&((i||!t&&5===_)&&(r.push(_,0,i,e),_=6),t&&(r.push(_,t,0,e),_=6)),i=""},l=0;l"===n?(_=1,i=""):i=n+i[0]:o?n===o?o="":i+=n:'"'===n||"'"===n?o=n:">"===n?(u(),_=1):_&&("="===n?(_=5,e=i,i=""):"/"===n&&(_<5||">"===t[l][s+1])?(u(),3===_&&(r=r[0]),_=r,(r=r[0]).push(2,0,_),_=0):" "===n||"\t"===n||"\n"===n||"\r"===n?(u(),_=2):i+=n),3===_&&"!--"===i&&(_=4,r=r[0])}return u(),r}(t)),n),arguments,[])).length>1?n:n[0]}var on=_n.bind(R);export{q as Component,j as Fragment,f as Signal,e as batch,ct as cloneElement,v as computed,ht as createContext,R as createElement,V as createRef,k as effect,R as h,on as html,ft as hydrate,C as isValidElement,S as options,st as render,c as signal,Y as toChildArray,o as untracked,Mt as useCallback,Zt as useComputed,At as useContext,Ft as useDebugValue,Pt as useEffect,Wt as useErrorBoundary,Lt as useId,Tt as useImperativeHandle,Nt as useLayoutEffect,Dt as useMemo,Ht as useReducer,$t as useRef,Yt as useSignal,tn as useSignalEffect,Et as useState}; diff --git a/examples/server/public/json-schema-to-grammar.mjs b/examples/server/public_legacy/json-schema-to-grammar.mjs similarity index 99% rename from examples/server/public/json-schema-to-grammar.mjs rename to examples/server/public_legacy/json-schema-to-grammar.mjs index 7267f3f9c..e67bb15c1 100644 --- a/examples/server/public/json-schema-to-grammar.mjs +++ b/examples/server/public_legacy/json-schema-to-grammar.mjs @@ -529,7 +529,7 @@ export class SchemaConverter { return joinSeq(); }; - return this._addRule(name, "\"\\\"\" " + toRule(transform()) + " \"\\\"\" space") + return this._addRule(name, "\"\\\"\" (" + toRule(transform()) + ") \"\\\"\" space") } _notStrings(strings) { diff --git a/examples/server/public_legacy/loading.html b/examples/server/public_legacy/loading.html new file mode 100644 index 000000000..c3fd19a0f --- /dev/null +++ b/examples/server/public_legacy/loading.html @@ -0,0 +1,12 @@ + + + + + + +
+ The model is loading. Please wait.
+ The user interface will appear soon. +
+ + diff --git a/examples/server/public/prompt-formats.js b/examples/server/public_legacy/prompt-formats.js similarity index 100% rename from examples/server/public/prompt-formats.js rename to examples/server/public_legacy/prompt-formats.js diff --git a/examples/server/public/style.css b/examples/server/public_legacy/style.css old mode 100755 new mode 100644 similarity index 100% rename from examples/server/public/style.css rename to examples/server/public_legacy/style.css diff --git a/examples/server/public/system-prompts.js b/examples/server/public_legacy/system-prompts.js similarity index 100% rename from examples/server/public/system-prompts.js rename to examples/server/public_legacy/system-prompts.js diff --git a/examples/server/public/theme-beeninorder.css b/examples/server/public_legacy/theme-beeninorder.css similarity index 100% rename from examples/server/public/theme-beeninorder.css rename to examples/server/public_legacy/theme-beeninorder.css diff --git a/examples/server/public/theme-ketivah.css b/examples/server/public_legacy/theme-ketivah.css similarity index 100% rename from examples/server/public/theme-ketivah.css rename to examples/server/public_legacy/theme-ketivah.css diff --git a/examples/server/public/theme-mangotango.css b/examples/server/public_legacy/theme-mangotango.css similarity index 100% rename from examples/server/public/theme-mangotango.css rename to examples/server/public_legacy/theme-mangotango.css diff --git a/examples/server/public/theme-playground.css b/examples/server/public_legacy/theme-playground.css similarity index 100% rename from examples/server/public/theme-playground.css rename to examples/server/public_legacy/theme-playground.css diff --git a/examples/server/public/theme-polarnight.css b/examples/server/public_legacy/theme-polarnight.css similarity index 100% rename from examples/server/public/theme-polarnight.css rename to examples/server/public_legacy/theme-polarnight.css diff --git a/examples/server/public/theme-snowstorm.css b/examples/server/public_legacy/theme-snowstorm.css similarity index 100% rename from examples/server/public/theme-snowstorm.css rename to examples/server/public_legacy/theme-snowstorm.css diff --git a/examples/server/public_simplechat/simplechat.js b/examples/server/public_simplechat/simplechat.js index 8e0df3b61..2fcd24a86 100644 --- a/examples/server/public_simplechat/simplechat.js +++ b/examples/server/public_simplechat/simplechat.js @@ -407,6 +407,9 @@ class SimpleChat { if (curLine.startsWith("data:")) { curLine = curLine.substring(5); } + if (curLine.trim() === "[DONE]") { + break; + } let curJson = JSON.parse(curLine); console.debug("DBUG:SC:PART:Json:", curJson); this.append_response(this.response_extract_stream(curJson, apiEP)); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 9ab8f8ca6..d1ea343dd 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1,8 +1,12 @@ #include "utils.hpp" +#include "arg.h" #include "common.h" #include "json-schema-to-grammar.h" #include "llama.h" +#include "log.h" +#include "sampling.h" +#include "speculative.h" // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT @@ -10,48 +14,38 @@ // mime type for sending response #define MIMETYPE_JSON "application/json; charset=utf-8" -// auto generated files (update with ./deps.sh) -#include "colorthemes.css.hpp" -#include "style.css.hpp" -#include "theme-beeninorder.css.hpp" -#include "theme-ketivah.css.hpp" -#include "theme-mangotango.css.hpp" -#include "theme-playground.css.hpp" -#include "theme-polarnight.css.hpp" -#include "theme-snowstorm.css.hpp" -#include "index.html.hpp" -#include "index-new.html.hpp" -#include "index.js.hpp" -#include "completion.js.hpp" -#include "system-prompts.js.hpp" -#include "prompt-formats.js.hpp" -#include "json-schema-to-grammar.mjs.hpp" +// auto generated files (see README.md for details) +#include "index.html.gz.hpp" +#include "loading.html.hpp" #include #include #include #include -#include -#include -#include -#include -#include -#include +#include #include +#include +#include +#include +#include +#include +#include using json = nlohmann::ordered_json; -bool server_verbose = false; -bool server_log_json = true; +constexpr int HTTP_POLLING_SECONDS = 1; enum stop_type { - STOP_TYPE_FULL, - STOP_TYPE_PARTIAL, + STOP_TYPE_NONE, + STOP_TYPE_EOS, + STOP_TYPE_WORD, + STOP_TYPE_LIMIT, }; // state diagram: https://github.com/ggerganov/llama.cpp/pull/9283 enum slot_state { SLOT_STATE_IDLE, + SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future SLOT_STATE_PROCESSING_PROMPT, SLOT_STATE_DONE_PROMPT, SLOT_STATE_GENERATING, @@ -64,6 +58,9 @@ enum server_state { enum server_task_type { SERVER_TASK_TYPE_COMPLETION, + SERVER_TASK_TYPE_EMBEDDING, + SERVER_TASK_TYPE_RERANK, + SERVER_TASK_TYPE_INFILL, SERVER_TASK_TYPE_CANCEL, SERVER_TASK_TYPE_NEXT_RESPONSE, SERVER_TASK_TYPE_METRICS, @@ -73,20 +70,376 @@ enum server_task_type { SERVER_TASK_TYPE_SET_LORA, }; -enum server_task_cmpl_type { - SERVER_TASK_CMPL_TYPE_NORMAL, - SERVER_TASK_CMPL_TYPE_EMBEDDING, - SERVER_TASK_CMPL_TYPE_INFILL, +enum oaicompat_type { + OAICOMPAT_TYPE_NONE, + OAICOMPAT_TYPE_CHAT, + OAICOMPAT_TYPE_COMPLETION, + OAICOMPAT_TYPE_EMBEDDING, +}; + +// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 +enum error_type { + ERROR_TYPE_INVALID_REQUEST, + ERROR_TYPE_AUTHENTICATION, + ERROR_TYPE_SERVER, + ERROR_TYPE_NOT_FOUND, + ERROR_TYPE_PERMISSION, + ERROR_TYPE_UNAVAILABLE, // custom error + ERROR_TYPE_NOT_SUPPORTED, // custom error +}; + +struct slot_params { + bool stream = true; + bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt + bool return_tokens = false; + + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half + int32_t n_predict = -1; // new tokens to predict + int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters + + int64_t t_max_prompt_ms = -1; // TODO: implement + int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit + + std::vector lora; + + std::vector antiprompt; + std::vector response_fields; + bool timings_per_token = false; + bool post_sampling_probs = false; + bool ignore_eos = false; + + struct common_params_sampling sampling; + struct common_params_speculative speculative; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + + json to_json() const { + std::vector samplers; + samplers.reserve(sampling.samplers.size()); + for (const auto & sampler : sampling.samplers) { + samplers.emplace_back(common_sampler_type_to_str(sampler)); + } + + json lora = json::array(); + for (size_t i = 0; i < this->lora.size(); ++i) { + lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); + } + + return json { + {"n_predict", n_predict}, // Server configured n_predict + {"seed", sampling.seed}, + {"temperature", sampling.temp}, + {"dynatemp_range", sampling.dynatemp_range}, + {"dynatemp_exponent", sampling.dynatemp_exponent}, + {"top_k", sampling.top_k}, + {"top_p", sampling.top_p}, + {"min_p", sampling.min_p}, + {"xtc_probability", sampling.xtc_probability}, + {"xtc_threshold", sampling.xtc_threshold}, + {"typical_p", sampling.typ_p}, + {"repeat_last_n", sampling.penalty_last_n}, + {"repeat_penalty", sampling.penalty_repeat}, + {"presence_penalty", sampling.penalty_present}, + {"frequency_penalty", sampling.penalty_freq}, + {"dry_multiplier", sampling.dry_multiplier}, + {"dry_base", sampling.dry_base}, + {"dry_allowed_length", sampling.dry_allowed_length}, + {"dry_penalty_last_n", sampling.dry_penalty_last_n}, + {"dry_sequence_breakers", sampling.dry_sequence_breakers}, + {"mirostat", sampling.mirostat}, + {"mirostat_tau", sampling.mirostat_tau}, + {"mirostat_eta", sampling.mirostat_eta}, + {"stop", antiprompt}, + {"max_tokens", n_predict}, // User configured n_predict + {"n_keep", n_keep}, + {"n_discard", n_discard}, + {"ignore_eos", sampling.ignore_eos}, + {"stream", stream}, + {"logit_bias", format_logit_bias(sampling.logit_bias)}, + {"n_probs", sampling.n_probs}, + {"min_keep", sampling.min_keep}, + {"grammar", sampling.grammar}, + // {"grammar_trigger_words", sampling.grammar_trigger_words}, + {"grammar_trigger_tokens", sampling.grammar_trigger_tokens}, + {"samplers", samplers}, + {"speculative.n_max", speculative.n_max}, + {"speculative.n_min", speculative.n_min}, + {"speculative.p_min", speculative.p_min}, + {"timings_per_token", timings_per_token}, + {"post_sampling_probs", post_sampling_probs}, + {"lora", lora}, + }; + } }; struct server_task { - int id = -1; // to be filled by server_queue - int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL + int id = -1; // to be filled by server_queue + int index = -1; // used when there are multiple prompts (batch request) server_task_type type; - json data; - server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; + // used by SERVER_TASK_TYPE_CANCEL + int id_target = -1; + + // used by SERVER_TASK_TYPE_INFERENCE + slot_params params; + llama_tokens prompt_tokens; + int id_selected_slot = -1; + + // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE + struct slot_action { + int slot_id; + std::string filename; + std::string filepath; + }; + slot_action slot_action; + + // used by SERVER_TASK_TYPE_METRICS + bool metrics_reset_bucket = false; + + // used by SERVER_TASK_TYPE_SET_LORA + std::vector set_lora; + + server_task(server_task_type type) : type(type) {} + + static slot_params params_from_json_cmpl( + const llama_context * ctx, + const common_params & params_base, + const json & data) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + slot_params params; + + // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) + slot_params defaults; + defaults.sampling = params_base.sampling; + defaults.speculative = params_base.speculative; + + // enabling this will output extra debug information in the HTTP responses from the server + params.verbose = params_base.verbosity > 9; + params.timings_per_token = json_value(data, "timings_per_token", false); + + params.stream = json_value(data, "stream", false); + params.cache_prompt = json_value(data, "cache_prompt", true); + params.return_tokens = json_value(data, "return_tokens", false); + params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict)); + params.n_indent = json_value(data, "n_indent", defaults.n_indent); + params.n_keep = json_value(data, "n_keep", defaults.n_keep); + params.n_discard = json_value(data, "n_discard", defaults.n_discard); + //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement + params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); + params.response_fields = json_value(data, "response_fields", std::vector()); + + params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); + params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); + params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); + params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); + params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); + params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); + params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); + params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); + params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); + params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); + params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); + params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); + params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); + params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); + params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); + params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); + params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); + params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); + params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); + params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); + params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); + params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); + params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); + params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); + + params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); + params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); + params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); + + params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min); + params.speculative.n_min = std::max(params.speculative.n_min, 2); + params.speculative.n_max = std::max(params.speculative.n_max, 0); + + // Use OpenAI API logprobs only if n_probs wasn't provided + if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){ + params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs); + } + + if (data.contains("lora")) { + if (data.at("lora").is_array()) { + params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora")); + } else { + throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); + } + } else { + params.lora = params_base.lora_adapters; + } + + // TODO: add more sanity checks for the input parameters + + if (params.sampling.penalty_last_n < -1) { + throw std::runtime_error("Error: repeat_last_n must be >= -1"); + } + + if (params.sampling.dry_penalty_last_n < -1) { + throw std::runtime_error("Error: dry_penalty_last_n must be >= -1"); + } + + if (params.sampling.penalty_last_n == -1) { + // note: should be the slot's context and not the full context, but it's ok + params.sampling.penalty_last_n = llama_n_ctx(ctx); + } + + if (params.sampling.dry_penalty_last_n == -1) { + params.sampling.dry_penalty_last_n = llama_n_ctx(ctx); + } + + if (params.sampling.dry_base < 1.0f) { + params.sampling.dry_base = defaults.sampling.dry_base; + } + + // sequence breakers for DRY + { + // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format + // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 + + if (data.contains("dry_sequence_breakers")) { + params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); + if (params.sampling.dry_sequence_breakers.empty()) { + throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings"); + } + } + } + + // process "json_schema" and "grammar" + if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { + throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); + } + if (data.contains("json_schema") && !data.contains("grammar")) { + try { + auto schema = json_value(data, "json_schema", json::object()); + LOG_DBG("JSON schema: %s\n", schema.dump(2).c_str()); + params.sampling.grammar = json_schema_to_grammar(schema); + LOG_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str()); + } catch (const std::exception & e) { + throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); + } + } else { + params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); + LOG_DBG("Grammar: %s\n", params.sampling.grammar.c_str()); + params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy); + LOG_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false"); + } + + { + auto it = data.find("chat_format"); + if (it != data.end()) { + params.oaicompat_chat_format = static_cast(it->get()); + LOG_DBG("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str()); + } else { + params.oaicompat_chat_format = defaults.oaicompat_chat_format; + } + } + + { + const auto grammar_triggers = data.find("grammar_triggers"); + if (grammar_triggers != data.end()) { + for (const auto & t : *grammar_triggers) { + common_grammar_trigger trigger; + trigger.word = t.at("word"); + trigger.at_start = t.at("at_start"); + + auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str()); + params.sampling.grammar_trigger_tokens.push_back(ids[0]); + continue; + } + LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str()); + params.sampling.grammar_trigger_words.push_back(trigger); + } + } + if (params.sampling.grammar_lazy) { + GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0); + } + } + + { + params.sampling.logit_bias.clear(); + params.ignore_eos = json_value(data, "ignore_eos", false); + + const auto & logit_bias = data.find("logit_bias"); + if (logit_bias != data.end() && logit_bias->is_array()) { + const int n_vocab = llama_vocab_n_tokens(vocab); + for (const auto & el : *logit_bias) { + // TODO: we may want to throw errors here, in case "el" is incorrect + if (el.is_array() && el.size() == 2) { + float bias; + if (el[1].is_number()) { + bias = el[1].get(); + } else if (el[1].is_boolean() && !el[1].get()) { + bias = -INFINITY; + } else { + continue; + } + + if (el[0].is_number_integer()) { + llama_token tok = el[0].get(); + if (tok >= 0 && tok < n_vocab) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } else if (el[0].is_string()) { + auto toks = common_tokenize(vocab, el[0].get(), false); + for (auto tok : toks) { + params.sampling.logit_bias.push_back({tok, bias}); + } + } + } + } + } + } + + { + params.antiprompt.clear(); + + const auto & stop = data.find("stop"); + if (stop != data.end() && stop->is_array()) { + for (const auto & word : *stop) { + if (!word.empty()) { + params.antiprompt.push_back(word); + } + } + } + } + + { + const auto samplers = data.find("samplers"); + if (samplers != data.end()) { + if (samplers->is_array()) { + params.sampling.samplers = common_sampler_types_from_names(*samplers, false); + } else if (samplers->is_string()){ + params.sampling.samplers = common_sampler_types_from_chars(samplers->get()); + } + } else { + params.sampling.samplers = defaults.sampling.samplers; + } + } + + std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; + params.oaicompat_model = json_value(data, "model", model_name); + + return params; + } // utility function static std::unordered_set get_list_id(const std::vector & tasks) { @@ -98,33 +451,750 @@ struct server_task { } }; -struct server_task_result { - int id = -1; +struct result_timings { + int32_t prompt_n = -1; + double prompt_ms; + double prompt_per_token_ms; + double prompt_per_second; - json data; + int32_t predicted_n = -1; + double predicted_ms; + double predicted_per_token_ms; + double predicted_per_second; - bool stop; - bool error; + json to_json() const { + return { + {"prompt_n", prompt_n}, + {"prompt_ms", prompt_ms}, + {"prompt_per_token_ms", prompt_per_token_ms}, + {"prompt_per_second", prompt_per_second}, + + {"predicted_n", predicted_n}, + {"predicted_ms", predicted_ms}, + {"predicted_per_token_ms", predicted_per_token_ms}, + {"predicted_per_second", predicted_per_second}, + }; + } }; -struct slot_params { - bool stream = true; - bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt +struct server_task_result { + int id = -1; + int id_slot = -1; + virtual bool is_error() { + // only used by server_task_result_error + return false; + } + virtual bool is_stop() { + // only used by server_task_result_cmpl_* + return false; + } + virtual int get_index() { + return -1; + } + virtual json to_json() = 0; + virtual ~server_task_result() = default; +}; - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half - int32_t n_predict = -1; // new tokens to predict +// using shared_ptr for polymorphism of server_task_result +using server_task_result_ptr = std::unique_ptr; - std::vector antiprompt; +inline std::string stop_type_to_str(stop_type type) { + switch (type) { + case STOP_TYPE_EOS: return "eos"; + case STOP_TYPE_WORD: return "word"; + case STOP_TYPE_LIMIT: return "limit"; + default: return "none"; + } +} - json input_prefix; - json input_suffix; +struct completion_token_output { + llama_token tok; + float prob; + std::string text_to_send; + struct prob_info { + llama_token tok; + std::string txt; + float prob; + }; + std::vector probs; + + json to_json(bool post_sampling_probs) const { + json probs_for_token = json::array(); + for (const auto & p : probs) { + std::string txt(p.txt); + txt.resize(validate_utf8(txt)); + probs_for_token.push_back(json { + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.txt)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? p.prob : logarithm(p.prob) + }, + }); + } + return probs_for_token; + } + + static json probs_vector_to_json(const std::vector & probs, bool post_sampling_probs) { + json out = json::array(); + for (const auto & p : probs) { + std::string txt(p.text_to_send); + txt.resize(validate_utf8(txt)); + out.push_back(json { + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.text_to_send)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? p.prob : logarithm(p.prob) + }, + { + post_sampling_probs ? "top_probs" : "top_logprobs", + p.to_json(post_sampling_probs) + }, + }); + } + return out; + } + + static float logarithm(float x) { + // nlohmann::json converts -inf to null, so we need to prevent that + return x == 0.0f ? std::numeric_limits::lowest() : std::log(x); + } + + static std::vector str_to_bytes(const std::string & str) { + std::vector bytes; + for (unsigned char c : str) { + bytes.push_back(c); + } + return bytes; + } +}; + +struct server_task_result_cmpl_final : server_task_result { + int index = 0; + + std::string content; + llama_tokens tokens; + + bool stream; + result_timings timings; + std::string prompt; + + bool truncated; + int32_t n_decoded; + int32_t n_prompt_tokens; + int32_t n_tokens_cached; + bool has_new_line; + std::string stopping_word; + stop_type stop = STOP_TYPE_NONE; + + bool post_sampling_probs; + std::vector probs_output; + std::vector response_fields; + + slot_params generation_params; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + + virtual int get_index() override { + return index; + } + + virtual bool is_stop() override { + return true; // in stream mode, final responses are considered stop + } + + virtual json to_json() override { + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat(); + case OAICOMPAT_TYPE_CHAT: + return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } + } + + json to_json_non_oaicompat() { + json res = json { + {"index", index}, + {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk + {"tokens", stream ? llama_tokens {} : tokens}, + {"id_slot", id_slot}, + {"stop", true}, + {"model", oaicompat_model}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + {"generation_settings", generation_params.to_json()}, + {"prompt", prompt}, + {"has_new_line", has_new_line}, + {"truncated", truncated}, + {"stop_type", stop_type_to_str(stop)}, + {"stopping_word", stopping_word}, + {"tokens_cached", n_tokens_cached}, + {"timings", timings.to_json()}, + }; + if (!stream && !probs_output.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs); + } + return response_fields.empty() ? res : json_get_nested_values(response_fields, res); + } + + json to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (!stream && probs_output.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + json finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } + json res = json { + {"choices", json::array({ + json{ + {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", finish_reason}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat() { + std::string finish_reason = "length"; + common_chat_msg message; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + message = common_chat_parse(content, oaicompat_chat_format); + finish_reason = message.tool_calls.empty() ? "stop" : "tool_calls"; + } else { + message.content = content; + } + + json tool_calls; + if (!message.tool_calls.empty()) { + tool_calls = json::array(); + for (const auto & tc : message.tool_calls) { + tool_calls.push_back({ + {"type", "function"}, + {"function", { + {"name", tc.name}, + {"arguments", tc.arguments}, + }}, + {"id", tc.id.empty() ? json() : json(tc.id)}, + }); + } + } + + json choice { + {"finish_reason", finish_reason}, + {"index", 0}, + {"message", json { + {"content", message.content}, + {"tool_calls", tool_calls}, + {"role", "assistant"}, + }}, + }; + + if (!stream && probs_output.size() > 0) { + choice["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, + }; + } + + std::time_t t = std::time(0); + + json res = json { + {"choices", json::array({choice})}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens} + }}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat_stream() { + std::time_t t = std::time(0); + std::string finish_reason = "length"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + finish_reason = "stop"; + } + + json choice = json { + {"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()} + }; + + json ret = json { + {"choices", json::array({choice})}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"}, + {"usage", json { + {"completion_tokens", n_decoded}, + {"prompt_tokens", n_prompt_tokens}, + {"total_tokens", n_decoded + n_prompt_tokens}, + }}, + }; + + if (timings.prompt_n >= 0) { + ret.push_back({"timings", timings.to_json()}); + } + + return ret; + } +}; + +struct server_task_result_cmpl_partial : server_task_result { + int index = 0; + + std::string content; + llama_tokens tokens; + + int32_t n_decoded; + int32_t n_prompt_tokens; + + bool post_sampling_probs; + completion_token_output prob_output; + result_timings timings; + + // OAI-compat fields + bool verbose = false; + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + + virtual int get_index() override { + return index; + } + + virtual bool is_stop() override { + return false; // in stream mode, partial responses are not considered stop + } + + virtual json to_json() override { + switch (oaicompat) { + case OAICOMPAT_TYPE_NONE: + return to_json_non_oaicompat(); + case OAICOMPAT_TYPE_COMPLETION: + return to_json_oaicompat(); + case OAICOMPAT_TYPE_CHAT: + return to_json_oaicompat_chat(); + default: + GGML_ASSERT(false && "Invalid oaicompat_type"); + } + } + + json to_json_non_oaicompat() { + // non-OAI-compat JSON + json res = json { + {"index", index}, + {"content", content}, + {"tokens", tokens}, + {"stop", false}, + {"id_slot", id_slot}, + {"tokens_predicted", n_decoded}, + {"tokens_evaluated", n_prompt_tokens}, + }; + // populate the timings object when needed (usually for the last response or with timings_per_token enabled) + if (timings.prompt_n > 0) { + res.push_back({"timings", timings.to_json()}); + } + if (!prob_output.probs.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs); + } + return res; + } + + json to_json_oaicompat() { + std::time_t t = std::time(0); + json logprobs = json(nullptr); // OAI default to null + if (prob_output.probs.size() > 0) { + logprobs = json{ + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + json res = json { + {"choices", json::array({ + json{ + {"text", content}, + {"index", index}, + {"logprobs", logprobs}, + {"finish_reason", nullptr}, + } + })}, + {"created", t}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "text_completion"}, + {"id", oaicompat_cmpl_id} + }; + + // extra fields for debugging purposes + if (verbose) { + res["__verbose"] = to_json_non_oaicompat(); + } + if (timings.prompt_n >= 0) { + res.push_back({"timings", timings.to_json()}); + } + + return res; + } + + json to_json_oaicompat_chat() { + bool first = n_decoded == 0; + std::time_t t = std::time(0); + json choices; + + if (first) { + if (content.empty()) { + choices = json::array({json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{{"role", "assistant"}}}}}); + } else { + // We have to send this as two updates to conform to openai behavior + json initial_ret = json{{"choices", json::array({json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", json{ + {"role", "assistant"} + }}}})}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"object", "chat.completion.chunk"}}; + + json second_ret = json{ + {"choices", json::array({json{{"finish_reason", nullptr}, + {"index", 0}, + {"delta", json { + {"content", content}}} + }})}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"object", "chat.completion.chunk"}}; + + return std::vector({initial_ret, second_ret}); + } + } else { + choices = json::array({json{ + {"finish_reason", nullptr}, + {"index", 0}, + {"delta", + json { + {"content", content}, + }}, + }}); + } + + GGML_ASSERT(choices.size() >= 1); + + if (prob_output.probs.size() > 0) { + choices[0]["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, + }; + } + + json ret = json { + {"choices", choices}, + {"created", t}, + {"id", oaicompat_cmpl_id}, + {"model", oaicompat_model}, + {"system_fingerprint", build_info}, + {"object", "chat.completion.chunk"} + }; + + if (timings.prompt_n >= 0) { + ret.push_back({"timings", timings.to_json()}); + } + + return std::vector({ret}); + } +}; + +struct server_task_result_embd : server_task_result { + int index = 0; + std::vector> embedding; + + int32_t n_tokens; + + // OAI-compat fields + oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override { + return oaicompat == OAICOMPAT_TYPE_EMBEDDING + ? to_json_oaicompat() + : to_json_non_oaicompat(); + } + + json to_json_non_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding}, + }; + } + + json to_json_oaicompat() { + return json { + {"index", index}, + {"embedding", embedding[0]}, + {"tokens_evaluated", n_tokens}, + }; + } +}; + +struct server_task_result_rerank : server_task_result { + int index = 0; + float score = -1e6; + + int32_t n_tokens; + + virtual int get_index() override { + return index; + } + + virtual json to_json() override { + return json { + {"index", index}, + {"score", score}, + {"tokens_evaluated", n_tokens}, + }; + } +}; + +// this function maybe used outside of server_task_result_error +static json format_error_response(const std::string & message, const enum error_type type) { + std::string type_str; + int code = 500; + switch (type) { + case ERROR_TYPE_INVALID_REQUEST: + type_str = "invalid_request_error"; + code = 400; + break; + case ERROR_TYPE_AUTHENTICATION: + type_str = "authentication_error"; + code = 401; + break; + case ERROR_TYPE_NOT_FOUND: + type_str = "not_found_error"; + code = 404; + break; + case ERROR_TYPE_SERVER: + type_str = "server_error"; + code = 500; + break; + case ERROR_TYPE_PERMISSION: + type_str = "permission_error"; + code = 403; + break; + case ERROR_TYPE_NOT_SUPPORTED: + type_str = "not_supported_error"; + code = 501; + break; + case ERROR_TYPE_UNAVAILABLE: + type_str = "unavailable_error"; + code = 503; + break; + } + return json { + {"code", code}, + {"message", message}, + {"type", type_str}, + }; +} + +struct server_task_result_error : server_task_result { + int index = 0; + error_type err_type = ERROR_TYPE_SERVER; + std::string err_msg; + + virtual bool is_error() override { + return true; + } + + virtual json to_json() override { + return format_error_response(err_msg, err_type); + } +}; + +struct server_task_result_metrics : server_task_result { + int n_idle_slots; + int n_processing_slots; + int n_tasks_deferred; + int64_t t_start; + + int32_t kv_cache_tokens_count; + int32_t kv_cache_used_cells; + + // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; + + // while we can also use std::vector this requires copying the slot object which can be quite messy + // therefore, we use json to temporarily store the slot.to_json() result + json slots_data = json::array(); + + virtual json to_json() override { + return json { + { "idle", n_idle_slots }, + { "processing", n_processing_slots }, + { "deferred", n_tasks_deferred }, + { "t_start", t_start }, + + { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total }, + { "t_tokens_generation_total", t_tokens_generation_total }, + { "n_tokens_predicted_total", n_tokens_predicted_total }, + { "t_prompt_processing_total", t_prompt_processing_total }, + + { "n_prompt_tokens_processed", n_prompt_tokens_processed }, + { "t_prompt_processing", t_prompt_processing }, + { "n_tokens_predicted", n_tokens_predicted }, + { "t_tokens_generation", t_tokens_generation }, + + { "n_decode_total", n_decode_total }, + { "n_busy_slots_total", n_busy_slots_total }, + + { "kv_cache_tokens_count", kv_cache_tokens_count }, + { "kv_cache_used_cells", kv_cache_used_cells }, + + { "slots", slots_data }, + }; + } +}; + +struct server_task_result_slot_save_load : server_task_result { + std::string filename; + bool is_save; // true = save, false = load + + size_t n_tokens; + size_t n_bytes; + double t_ms; + + virtual json to_json() override { + if (is_save) { + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_saved", n_tokens }, + { "n_written", n_bytes }, + { "timings", { + { "save_ms", t_ms } + }}, + }; + } else { + return json { + { "id_slot", id_slot }, + { "filename", filename }, + { "n_restored", n_tokens }, + { "n_read", n_bytes }, + { "timings", { + { "restore_ms", t_ms } + }}, + }; + } + } +}; + +struct server_task_result_slot_erase : server_task_result { + size_t n_erased; + + virtual json to_json() override { + return json { + { "id_slot", id_slot }, + { "n_erased", n_erased }, + }; + } +}; + +struct server_task_result_apply_lora : server_task_result { + virtual json to_json() override { + return json {{ "success", true }}; + } }; struct server_slot { int id; int id_task = -1; + // only used for completion/embedding/infill/rerank + server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; + + llama_batch batch_spec = {}; + + llama_context * ctx = nullptr; + llama_context * ctx_dft = nullptr; + + common_speculative * spec = nullptr; + + std::vector lora; + // the index relative to completion multi-task request size_t index = 0; @@ -143,75 +1213,77 @@ struct server_slot { int32_t i_batch = -1; int32_t n_predict = -1; // TODO: disambiguate from params.n_predict + // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated int32_t n_prompt_tokens = 0; int32_t n_prompt_tokens_processed = 0; - json prompt; // can be either a string, array of strings or array of token ids + // input prompt tokens + llama_tokens prompt_tokens; - // when a task is submitted, we first tokenize the prompt and store it here - std::vector prompt_tokens; + size_t last_nl_pos = 0; + + std::string generated_text; + llama_tokens generated_tokens; + + llama_tokens cache_tokens; - std::string generated_text; - std::vector cache_tokens; std::vector generated_token_probs; - server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; bool has_next_token = true; + bool has_new_line = false; bool truncated = false; - bool stopped_eos = false; - bool stopped_word = false; - bool stopped_limit = false; + stop_type stop; - bool oaicompat = false; - - std::string oaicompat_model; std::string stopping_word; // sampling json json_schema; - struct gpt_sampler_params sparams; - struct gpt_sampler * smpl = nullptr; + struct common_sampler * smpl = nullptr; llama_token sampled; - int32_t ga_i = 0; // group-attention state - int32_t ga_n = 1; // group-attention factor - int32_t ga_w = 512; // group-attention width - - int32_t n_past_se = 0; // self-extend + common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; // stats - size_t n_sent_text = 0; // number of sent text character - size_t n_sent_token_probs = 0; + size_t n_sent_text = 0; // number of sent text character int64_t t_start_process_prompt; int64_t t_start_generation; double t_prompt_processing; // ms - double t_token_generation; // ms + double t_token_generation; // ms std::function callback_on_release; void reset() { + SLT_DBG(*this, "%s", "\n"); + n_prompt_tokens = 0; + last_nl_pos = 0; generated_text = ""; + has_new_line = false; truncated = false; - stopped_eos = false; - stopped_word = false; - stopped_limit = false; + stop = STOP_TYPE_NONE; stopping_word = ""; n_past = 0; n_sent_text = 0; - n_sent_token_probs = 0; - cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; - ga_i = 0; - n_past_se = 0; + task_type = SERVER_TASK_TYPE_COMPLETION; + generated_tokens.clear(); generated_token_probs.clear(); } - bool has_budget(gpt_params &global_params) { + bool is_non_causal() const { + return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK; + } + + bool can_batch_with(server_slot & other_slot) { + return is_non_causal() == other_slot.is_non_causal() + && are_lora_equal(lora, other_slot.lora); + } + + bool has_budget(const common_params & global_params) { if (params.n_predict == -1 && global_params.n_predict == -1) { return true; // limitless } @@ -231,8 +1303,13 @@ struct server_slot { return state != SLOT_STATE_IDLE; } - void add_token_string(const completion_token_output & token) { + bool can_speculate() const { + return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt; + } + + void add_token(const completion_token_output & token) { if (!is_processing()) { + SLT_WRN(*this, "%s", "slot is not processing\n"); return; } generated_token_probs.push_back(token); @@ -240,50 +1317,49 @@ struct server_slot { void release() { if (is_processing()) { + SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated); + + t_last_used = ggml_time_us(); t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; state = SLOT_STATE_IDLE; - LOG_INFO("slot released", { - {"id_slot", id}, - {"id_task", id_task}, - {"n_past", n_past}, - {"truncated", truncated}, - }); callback_on_release(id); } } - json get_formated_timings() const { - return json { - {"prompt_n", n_prompt_tokens_processed}, - {"prompt_ms", t_prompt_processing}, - {"prompt_per_token_ms", t_prompt_processing / n_prompt_tokens_processed}, - {"prompt_per_second", 1e3 / t_prompt_processing * n_prompt_tokens_processed}, + result_timings get_timings() const { + result_timings timings; + timings.prompt_n = n_prompt_tokens_processed; + timings.prompt_ms = t_prompt_processing; + timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; + timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - {"predicted_n", n_decoded}, - {"predicted_ms", t_token_generation}, - {"predicted_per_token_ms", t_token_generation / n_decoded}, - {"predicted_per_second", 1e3 / t_token_generation * n_decoded}, - }; + timings.predicted_n = n_decoded; + timings.predicted_ms = t_token_generation; + timings.predicted_per_token_ms = t_token_generation / n_decoded; + timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; + + return timings; } - size_t find_stopping_strings(const std::string & text, const size_t last_token_size, const stop_type type) { + size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { size_t stop_pos = std::string::npos; for (const std::string & word : params.antiprompt) { size_t pos; - if (type == STOP_TYPE_FULL) { + if (is_full_stop) { const size_t tmp = word.size() + last_token_size; const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; pos = text.find(word, from_pos); } else { + // otherwise, partial stop pos = find_partial_stop_string(word, text); } if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { - if (type == STOP_TYPE_FULL) { - stopped_word = true; + if (is_full_stop) { + stop = STOP_TYPE_WORD; stopping_word = word; has_next_token = false; } @@ -295,49 +1371,42 @@ struct server_slot { } void print_timings() const { - char buffer[512]; + const double t_prompt = t_prompt_processing / n_prompt_tokens_processed; + const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - double t_token = t_prompt_processing / n_prompt_tokens_processed; - double n_tokens_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + const double t_gen = t_token_generation / n_decoded; + const double n_gen_second = 1e3 / t_token_generation * n_decoded; - snprintf(buffer, 512, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)", - t_prompt_processing, n_prompt_tokens_processed, - t_token, n_tokens_second); + SLT_INF(*this, + "\n" + "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " total time = %10.2f ms / %5d tokens\n", + t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, + t_token_generation, n_decoded, t_gen, n_gen_second, + t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); + } - LOG_INFO(buffer, { - {"id_slot", id}, - {"id_task", id_task}, - {"t_prompt_processing", t_prompt_processing}, - {"n_prompt_tokens_processed", n_prompt_tokens_processed}, - {"t_token", t_token}, - {"n_tokens_second", n_tokens_second}, - }); - - t_token = t_token_generation / n_decoded; - n_tokens_second = 1e3 / t_token_generation * n_decoded; - - snprintf(buffer, 512, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)", - t_token_generation, n_decoded, - t_token, n_tokens_second); - - LOG_INFO(buffer, { - {"id_slot", id}, - {"id_task", id_task}, - {"t_token_generation", t_token_generation}, - {"n_decoded", n_decoded}, - {"t_token", t_token}, - {"n_tokens_second", n_tokens_second}, - }); - - snprintf(buffer, 512, " total time = %10.2f ms", t_prompt_processing + t_token_generation); - - LOG_INFO(buffer, { - {"id_slot", id}, - {"id_task", id_task}, - {"t_prompt_processing", t_prompt_processing}, - {"t_token_generation", t_token_generation}, - {"t_total", t_prompt_processing + t_token_generation}, - }); + json to_json() const { + return json { + {"id", id}, + {"id_task", id_task}, + {"n_ctx", n_ctx}, + {"speculative", can_speculate()}, + {"is_processing", is_processing()}, + {"non_causal", is_non_causal()}, + {"params", params.to_json()}, + {"prompt", common_detokenize(ctx, prompt_tokens)}, + {"next_token", + { + {"has_next_token", has_next_token}, + {"has_new_line", has_new_line}, + {"n_remain", n_remaining}, + {"n_decoded", n_decoded}, + {"stopping_word", stopping_word}, + } + }, + }; } }; @@ -405,16 +1474,18 @@ struct server_queue { std::condition_variable condition_tasks; // callback functions - std::function callback_new_task; - std::function callback_update_slots; + std::function callback_new_task; + std::function callback_update_slots; // Add a new task to the end of the queue int post(server_task task, bool front = false) { std::unique_lock lock(mutex_tasks); - if (task.id == -1) { - task.id = id++; - LOG_VERBOSE("new task id", {{"new_id", task.id}}); + GGML_ASSERT(task.id != -1); + // if this is cancel task make sure to clean up pending tasks + if (task.type == SERVER_TASK_TYPE_CANCEL) { + cleanup_pending_task(task.id_target); } + QUE_DBG("new task, id = %d, front = %d\n", task.id, front); if (front) { queue_tasks.push_front(std::move(task)); } else { @@ -430,8 +1501,12 @@ struct server_queue { for (auto & task : tasks) { if (task.id == -1) { task.id = id++; - LOG_VERBOSE("new task id", {{"new_id", task.id}}); } + // if this is cancel task make sure to clean up pending tasks + if (task.type == SERVER_TASK_TYPE_CANCEL) { + cleanup_pending_task(task.id_target); + } + QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front); if (front) { queue_tasks.push_front(std::move(task)); } else { @@ -445,6 +1520,7 @@ struct server_queue { // Add a new task, but defer until one slot is available void defer(server_task task) { std::unique_lock lock(mutex_tasks); + QUE_DBG("defer task, id = %d\n", task.id); queue_tasks_deferred.push_back(std::move(task)); condition_tasks.notify_one(); } @@ -453,12 +1529,11 @@ struct server_queue { int get_new_id() { std::unique_lock lock(mutex_tasks); int new_id = id++; - LOG_VERBOSE("new task id", {{"new_id", new_id}}); return new_id; } // Register function to process a new task - void on_new_task(std::function callback) { + void on_new_task(std::function callback) { callback_new_task = std::move(callback); } @@ -495,7 +1570,7 @@ struct server_queue { running = true; while (true) { - LOG_VERBOSE("new task may arrive", {}); + QUE_DBG("%s", "processing new tasks\n"); while (true) { std::unique_lock lock(mutex_tasks); @@ -506,21 +1581,22 @@ struct server_queue { server_task task = queue_tasks.front(); queue_tasks.pop_front(); lock.unlock(); - LOG_VERBOSE("callback_new_task", {{"id_task", task.id}}); - callback_new_task(task); + + QUE_DBG("processing task, id = %d\n", task.id); + callback_new_task(std::move(task)); } // all tasks in the current loop is processed, slots data is now ready - LOG_VERBOSE("callback_update_slots", {}); + QUE_DBG("%s", "update slots\n"); callback_update_slots(); - LOG_VERBOSE("wait for new task", {}); + QUE_DBG("%s", "waiting for new tasks\n"); { std::unique_lock lock(mutex_tasks); if (queue_tasks.empty()) { if (!running) { - LOG_VERBOSE("ending start_loop", {}); + QUE_DBG("%s", "terminate\n"); return; } condition_tasks.wait(lock, [&]{ @@ -530,51 +1606,83 @@ struct server_queue { } } } + +private: + void cleanup_pending_task(int id_target) { + // no need lock because this is called exclusively by post() + auto rm_func = [id_target](const server_task & task) { + return task.id_target == id_target; + }; + queue_tasks.erase( + std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), + queue_tasks.end()); + queue_tasks_deferred.erase( + std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func), + queue_tasks_deferred.end()); + } }; struct server_response { // for keeping track of all tasks waiting for the result std::unordered_set waiting_task_ids; - // the main result queue - std::vector queue_results; + // the main result queue (using ptr for polymorphism) + std::vector queue_results; std::mutex mutex_results; std::condition_variable condition_results; // add the id_task to the list of tasks waiting for response void add_waiting_task_id(int id_task) { - LOG_VERBOSE("waiting for task id", {{"id_task", id_task}}); + SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size()); std::unique_lock lock(mutex_results); waiting_task_ids.insert(id_task); } void add_waiting_tasks(const std::vector & tasks) { - for (const auto & t : tasks) { - add_waiting_task_id(t.id); + std::unique_lock lock(mutex_results); + + for (const auto & task : tasks) { + SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size()); + waiting_task_ids.insert(task.id); } } // when the request is finished, we can remove task associated with it void remove_waiting_task_id(int id_task) { - LOG_VERBOSE("remove waiting for task id", {{"id_task", id_task}}); + SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); std::unique_lock lock(mutex_results); waiting_task_ids.erase(id_task); + // make sure to clean up all pending results + queue_results.erase( + std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) { + return res->id == id_task; + }), + queue_results.end()); + } + + void remove_waiting_task_ids(const std::unordered_set & id_tasks) { + std::unique_lock lock(mutex_results); + + for (const auto & id_task : id_tasks) { + SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); + waiting_task_ids.erase(id_task); + } } // This function blocks the thread until there is a response for one of the id_tasks - server_task_result recv(const std::unordered_set & id_tasks) { + server_task_result_ptr recv(const std::unordered_set & id_tasks) { while (true) { std::unique_lock lock(mutex_results); condition_results.wait(lock, [&]{ return !queue_results.empty(); }); - for (int i = 0; i < (int) queue_results.size(); i++) { - if (id_tasks.find(queue_results[i].id) != id_tasks.end()) { - server_task_result res = queue_results[i]; + for (size_t i = 0; i < queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); queue_results.erase(queue_results.begin() + i); return res; } @@ -584,21 +1692,45 @@ struct server_response { // should never reach here } + // same as recv(), but have timeout in seconds + // if timeout is reached, nullptr is returned + server_task_result_ptr recv_with_timeout(const std::unordered_set & id_tasks, int timeout) { + while (true) { + std::unique_lock lock(mutex_results); + + for (int i = 0; i < (int) queue_results.size(); i++) { + if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { + server_task_result_ptr res = std::move(queue_results[i]); + queue_results.erase(queue_results.begin() + i); + return res; + } + } + + std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout)); + if (cr_res == std::cv_status::timeout) { + return nullptr; + } + } + + // should never reach here + } + // single-task version of recv() - server_task_result recv(int id_task) { + server_task_result_ptr recv(int id_task) { std::unordered_set id_tasks = {id_task}; return recv(id_tasks); } // Send a new result to a waiting id_task - void send(server_task_result & result) { - LOG_VERBOSE("send new result", {{"id_task", result.id}}); + void send(server_task_result_ptr && result) { + SRV_DBG("sending result for task id = %d\n", result->id); std::unique_lock lock(mutex_results); for (const auto & id_task : waiting_task_ids) { - if (result.id == id_task) { - LOG_VERBOSE("queue_results.push_back", {{"id_task", id_task}}); - queue_results.push_back(std::move(result)); + if (result->id == id_task) { + SRV_DBG("task id = %d pushed to result queue\n", result->id); + + queue_results.emplace_back(std::move(result)); condition_results.notify_all(); return; } @@ -607,13 +1739,22 @@ struct server_response { }; struct server_context { + common_params params_base; + + // note: keep these alive - they determine the lifetime of the model, context, etc. + common_init_result llama_init; + common_init_result llama_init_dft; + llama_model * model = nullptr; llama_context * ctx = nullptr; - std::vector lora_adapters; - gpt_params params; + const llama_vocab * vocab = nullptr; - llama_batch batch; + llama_model * model_dft = nullptr; + + llama_context_params cparams_dft; + + llama_batch batch = {}; bool clean_kv_cache = true; bool add_bos_token = true; @@ -621,12 +1762,6 @@ struct server_context { int32_t n_ctx; // total context for all clients / slots - // system prompt - bool system_need_update = false; - - std::string system_prompt; - std::vector system_tokens; - // slots / clients std::vector slots; json default_generation_settings_for_props; @@ -639,98 +1774,156 @@ struct server_context { // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; + common_chat_templates chat_templates; + ~server_context() { - if (ctx) { - llama_free(ctx); - ctx = nullptr; - } - - if (model) { - llama_free_model(model); - model = nullptr; - } - // Clear any sampling context for (server_slot & slot : slots) { - if (slot.smpl != nullptr) { - gpt_sampler_free(slot.smpl); - } + common_sampler_free(slot.smpl); + slot.smpl = nullptr; + + llama_free(slot.ctx_dft); + slot.ctx_dft = nullptr; + + common_speculative_free(slot.spec); + slot.spec = nullptr; + + llama_batch_free(slot.batch_spec); } llama_batch_free(batch); } - bool load_model(const gpt_params & params_) { - params = params_; + bool load_model(const common_params & params) { + SRV_INF("loading model '%s'\n", params.model.c_str()); - // dedicate one sequence to the system prompt - params.n_parallel += 1; + params_base = params; - llama_init_result llama_init = llama_init_from_gpt_params(params); + llama_init = common_init_from_params(params_base); + + model = llama_init.model.get(); + ctx = llama_init.context.get(); - model = llama_init.model; - ctx = llama_init.context; - lora_adapters = llama_init.lora_adapters; - params.n_parallel -= 1; // but be sneaky about it if (model == nullptr) { - LOG_ERROR("unable to load model", {{"model", params.model}}); + SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str()); return false; } + vocab = llama_model_get_vocab(model); + n_ctx = llama_n_ctx(ctx); - add_bos_token = llama_add_bos_token(model); - has_eos_token = !llama_add_eos_token(model); + add_bos_token = llama_vocab_get_add_bos(vocab); + has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; + + if (!params_base.speculative.model.empty() || !params_base.speculative.hf_repo.empty()) { + SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str()); + + auto params_dft = params_base; + + params_dft.devices = params_base.speculative.devices; + params_dft.hf_file = params_base.speculative.hf_file; + params_dft.hf_repo = params_base.speculative.hf_repo; + params_dft.model = params_base.speculative.model; + params_dft.model_url = params_base.speculative.model_url; + params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx; + params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; + params_dft.n_parallel = 1; + + llama_init_dft = common_init_from_params(params_dft); + + model_dft = llama_init_dft.model.get(); + + if (model_dft == nullptr) { + SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.c_str()); + return false; + } + + if (!common_speculative_are_compatible(ctx, llama_init_dft.context.get())) { + SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.c_str(), params_base.model.c_str()); + + return false; + } + + const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); + + cparams_dft = common_context_params_to_llama(params_dft); + cparams_dft.n_batch = n_ctx_dft; + + // force F16 KV cache for the draft model for extra performance + cparams_dft.type_k = GGML_TYPE_F16; + cparams_dft.type_v = GGML_TYPE_F16; + + // the context is not needed - we will create one for each slot + llama_init_dft.context.reset(); + } + + chat_templates = common_chat_templates_from_model(model, params_base.chat_template); + GGML_ASSERT(chat_templates.template_default.get() != nullptr); return true; } - bool validate_model_chat_template() const { + bool validate_builtin_chat_template(bool use_jinja) const { llama_chat_message chat[] = {{"user", "test"}}; - const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0); - - return res > 0; + if (use_jinja) { + auto templates = common_chat_templates_from_model(model, ""); + common_chat_inputs inputs; + inputs.messages = json::array({{ + {"role", "user"}, + {"content", "test"}, + }}); + GGML_ASSERT(templates.template_default); + try { + common_chat_params_init(*templates.template_default, inputs); + if (templates.template_tool_use) { + common_chat_params_init(*templates.template_tool_use, inputs); + } + return true; + } catch (const std::exception & e) { + SRV_ERR("failed to apply template: %s\n", e.what()); + return false; + } + } else { + const char * tmpl = llama_model_chat_template(model, /* name */ nullptr); + const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0); + return chat_res > 0; + } } void init() { - const int32_t n_ctx_slot = n_ctx / params.n_parallel; + const int32_t n_ctx_slot = n_ctx / params_base.n_parallel; - LOG_INFO("initializing slots", {{"n_slots", params.n_parallel}}); + SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); - for (int i = 0; i < params.n_parallel; i++) { + for (int i = 0; i < params_base.n_parallel; i++) { server_slot slot; slot.id = i; + slot.ctx = ctx; slot.n_ctx = n_ctx_slot; - slot.n_predict = params.n_predict; + slot.n_predict = params_base.n_predict; - LOG_INFO("new slot", { - {"id_slot", slot.id}, - {"n_ctx_slot", slot.n_ctx} - }); + if (model_dft) { + slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); - const int ga_n = params.grp_attn_n; - const int ga_w = params.grp_attn_w; + slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); + if (slot.ctx_dft == nullptr) { + SRV_ERR("%s", "failed to create draft context\n"); + return; + } - if (ga_n != 1) { - GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT - GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT - //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT - //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT - - LOG_INFO("slot self-extend", { - {"id_slot", slot.id}, - {"ga_n", ga_n}, - {"ga_w", ga_w} - }); + slot.spec = common_speculative_init(slot.ctx_dft); + if (slot.spec == nullptr) { + SRV_ERR("%s", "failed to create speculator\n"); + return; + } } - slot.ga_i = 0; - slot.ga_n = ga_n; - slot.ga_w = ga_w; + SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx); - slot.sparams = params.sparams; + slot.params.sampling = params_base.sampling; slot.callback_on_release = [this](int) { queue_tasks.pop_deferred_task(); @@ -741,8 +1934,7 @@ struct server_context { slots.push_back(slot); } - default_generation_settings_for_props = get_formated_generation(slots.front()); - default_generation_settings_for_props["seed"] = -1; + default_generation_settings_for_props = slots[0].to_json(); // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) @@ -750,53 +1942,12 @@ struct server_context { const int32_t n_batch = llama_n_batch(ctx); // only a single seq_id per token is needed - batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1); + batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); } metrics.init(); } - std::vector tokenize(const json & json_prompt, bool add_special) const { - // TODO: currently, we tokenize using special tokens by default - // this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216) - // but it's better compared to completely ignoring ChatML and other chat templates - const bool TMP_FORCE_SPECIAL = true; - - // If `add_bos` is true, we only add BOS, when json_prompt is a string, - // or the first element of the json_prompt array is a string. - std::vector prompt_tokens; - - if (json_prompt.is_array()) { - bool first = true; - for (const auto & p : json_prompt) { - if (p.is_string()) { - auto s = p.template get(); - - std::vector p; - if (first) { - p = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL); - first = false; - } else { - p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL); - } - - prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); - } else { - if (first) { - first = false; - } - - prompt_tokens.push_back(p.template get()); - } - } - } else { - auto s = json_prompt.template get(); - prompt_tokens = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL); - } - - return prompt_tokens; - } - server_slot * get_slot_by_id(int id) { for (server_slot & slot : slots) { if (slot.id == id) { @@ -807,12 +1958,12 @@ struct server_context { return nullptr; } - server_slot * get_available_slot(const std::string & prompt) { + server_slot * get_available_slot(const server_task & task) { server_slot * ret = nullptr; // find the slot that has at least n% prompt similarity - if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) { - int max_lcp_len = 0; + if (ret == nullptr && slot_prompt_similarity != 0.0f) { + int lcs_len = 0; float similarity = 0; for (server_slot & slot : slots) { @@ -821,36 +1972,27 @@ struct server_context { continue; } - // skip the slot if it does not contains prompt - if (!slot.prompt.is_string()) { + // skip the slot if it does not contains cached tokens + if (slot.cache_tokens.empty()) { continue; } - // current slot's prompt - std::string slot_prompt = slot.prompt.get(); + // length of the Longest Common Subsequence between the current slot's prompt and the input prompt + int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens); - // length of the current slot's prompt - int slot_prompt_len = slot_prompt.size(); - - // length of the Longest Common Prefix between the current slot's prompt and the input prompt - int lcp_len = common_part(slot_prompt, prompt); - - // fraction of the common substring length compared to the current slot's prompt length - similarity = static_cast(lcp_len) / slot_prompt_len; + // fraction of the common subsequence length compared to the current slot's prompt length + float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.size()); // select the current slot if the criteria match - if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) { - max_lcp_len = lcp_len; + if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) { + lcs_len = cur_lcs_len; + similarity = cur_similarity; ret = &slot; } } if (ret != nullptr) { - LOG_VERBOSE("selected slot by lcp similarity", { - {"id_slot", ret->id}, - {"max_lcp_len", max_lcp_len}, - {"similarity", similarity}, - }); + SLT_DBG(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %f\n", lcs_len, similarity); } } @@ -871,10 +2013,7 @@ struct server_context { } if (ret != nullptr) { - LOG_VERBOSE("selected slot by lru", { - {"id_slot", ret->id}, - {"t_last", t_last}, - }); + SLT_DBG(*ret, "selected slot by lru, t_last = %" PRId64 "\n", t_last); } } @@ -882,171 +2021,37 @@ struct server_context { } bool launch_slot_with_task(server_slot & slot, const server_task & task) { - slot_params default_params; - // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) - auto default_sparams = params.sparams; - const auto & data = task.data; + slot.reset(); + slot.id_task = task.id; + slot.index = task.index; + slot.task_type = task.type; + slot.params = std::move(task.params); + slot.prompt_tokens = std::move(task.prompt_tokens); - if (data.count("__oaicompat") != 0) { - slot.oaicompat = true; - slot.oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); - } else { - slot.oaicompat = false; - slot.oaicompat_model = ""; + if (!are_lora_equal(task.params.lora, slot.lora)) { + // if lora is changed, we cannot reuse cached tokens + slot.cache_tokens.clear(); + slot.lora = task.params.lora; } - slot.params.stream = json_value(data, "stream", false); - slot.params.cache_prompt = json_value(data, "cache_prompt", false); - slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict)); - slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k); - slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p); - slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p); - slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); - slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p); - slot.sparams.temp = json_value(data, "temperature", default_sparams.temp); - slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range); - slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent); - slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n); - slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); - slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); - slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); - slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); - slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); - slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); - slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl); - slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep); - slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard); - slot.sparams.seed = json_value(data, "seed", default_sparams.seed); - slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); - slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); - - // process "json_schema" and "grammar" - if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { - send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST); - return false; - } - if (data.contains("json_schema") && !data.contains("grammar")) { - try { - auto schema = json_value(data, "json_schema", json::object()); - slot.sparams.grammar = json_schema_to_grammar(schema); - } catch (const std::exception & e) { - send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST); - return false; - } - } else { - slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar); - } - - if (slot.params.cache_prompt && slot.ga_n != 1) { - LOG_WARNING("cache_prompt is not supported with group-attention", {}); - slot.params.cache_prompt = false; - } + SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { // Might be better to reject the request with a 400 ? - LOG_WARNING("Max tokens to predict exceeds server configuration", { - {"params.n_predict", slot.params.n_predict}, - {"slot.n_predict", slot.n_predict}, - }); slot.params.n_predict = slot.n_predict; + SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict); } - // infill - slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix); - slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix); - - // get prompt - if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) { - const auto & prompt = data.find("prompt"); - if (prompt == data.end()) { - send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST); - return false; - } - - if ((prompt->is_string()) || - (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) || - (prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) { - slot.prompt = *prompt; - } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) { - slot.prompt = prompt->at(0); - } else { - send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST); - return false; - } - } - - { - slot.sparams.logit_bias.clear(); - - if (json_value(data, "ignore_eos", false) && has_eos_token) { - slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY}); - } - - const auto & logit_bias = data.find("logit_bias"); - if (logit_bias != data.end() && logit_bias->is_array()) { - const int n_vocab = llama_n_vocab(model); - for (const auto & el : *logit_bias) { - // TODO: we may want to throw errors here, in case "el" is incorrect - if (el.is_array() && el.size() == 2) { - float bias; - if (el[1].is_number()) { - bias = el[1].get(); - } else if (el[1].is_boolean() && !el[1].get()) { - bias = -INFINITY; - } else { - continue; - } - - if (el[0].is_number_integer()) { - llama_token tok = el[0].get(); - if (tok >= 0 && tok < n_vocab) { - slot.sparams.logit_bias.push_back({tok, bias}); - } - } else if (el[0].is_string()) { - auto toks = llama_tokenize(model, el[0].get(), false); - for (auto tok : toks) { - slot.sparams.logit_bias.push_back({tok, bias}); - } - } - } - } - } - } - - { - slot.params.antiprompt.clear(); - - const auto & stop = data.find("stop"); - if (stop != data.end() && stop->is_array()) { - for (const auto & word : *stop) { - if (!word.empty()) { - slot.params.antiprompt.push_back(word); - } - } - } - } - - { - const auto & samplers = data.find("samplers"); - if (samplers != data.end() && samplers->is_array()) { - std::vector sampler_names; - for (const auto & name : *samplers) { - if (name.is_string()) { - sampler_names.emplace_back(name); - } - } - slot.sparams.samplers = gpt_sampler_types_from_names(sampler_names, false); - } else { - slot.sparams.samplers = default_sparams.samplers; - } + if (slot.params.ignore_eos && has_eos_token) { + slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY}); } { if (slot.smpl != nullptr) { - gpt_sampler_free(slot.smpl); + common_sampler_free(slot.smpl); } - slot.smpl = gpt_sampler_init(model, slot.sparams); + slot.smpl = common_sampler_init(model, slot.params.sampling); if (slot.smpl == nullptr) { // for now, the only error that may happen here is invalid grammar send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); @@ -1054,137 +2059,70 @@ struct server_context { } } - slot.state = SLOT_STATE_PROCESSING_PROMPT; - slot.prompt_tokens.clear(); + if (slot.ctx_dft) { + llama_batch_free(slot.batch_spec); - LOG_INFO("slot is processing task", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - }); + slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); + } + + slot.state = SLOT_STATE_STARTED; + + SLT_INF(slot, "%s", "processing task\n"); return true; } void kv_cache_clear() { - LOG_VERBOSE("clearing KV cache", {}); + SRV_DBG("%s", "clearing KV cache\n"); // clear the entire KV cache llama_kv_cache_clear(ctx); clean_kv_cache = false; } - void system_prompt_update() { - LOG_VERBOSE("system prompt update", { - {"system_prompt", system_prompt}, - }); - - kv_cache_clear(); - system_tokens.clear(); - - if (!system_prompt.empty()) { - system_tokens = ::llama_tokenize(ctx, system_prompt, true); - - const int32_t n_batch = llama_n_batch(ctx); - const int32_t n_tokens_prompt = system_tokens.size(); - - for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i); - - llama_batch_clear(batch); - - for (int32_t j = 0; j < n_tokens; ++j) { - llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false); - } - - if (llama_decode(ctx, batch) != 0) { - LOG_ERROR("llama_decode() failed", {}); - return; - } - } - - // assign the system KV cache to all parallel sequences - for (int32_t i = 1; i <= params.n_parallel; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); - } - } - - system_need_update = false; - } - - bool system_prompt_set(const std::string & sys_prompt) { - system_prompt = sys_prompt; - - LOG_VERBOSE("system prompt process", { - {"system_prompt", system_prompt}, - }); - - // release all slots - for (server_slot & slot : slots) { - slot.release(); - } - - system_need_update = true; - return true; - } - bool process_token(completion_token_output & result, server_slot & slot) { // remember which tokens were sampled - used for repetition penalties during sampling - const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special); + const std::string token_str = result.text_to_send; slot.sampled = result.tok; - // search stop word and delete it slot.generated_text += token_str; + if (slot.params.return_tokens) { + slot.generated_tokens.push_back(result.tok); + } slot.has_next_token = true; // check if there is incomplete UTF-8 character at the end - bool incomplete = false; - for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) { - unsigned char c = slot.generated_text[slot.generated_text.size() - i]; - if ((c & 0xC0) == 0x80) { - // continuation byte: 10xxxxxx - continue; - } - if ((c & 0xE0) == 0xC0) { - // 2-byte character: 110xxxxx ... - incomplete = i < 2; - } else if ((c & 0xF0) == 0xE0) { - // 3-byte character: 1110xxxx ... - incomplete = i < 3; - } else if ((c & 0xF8) == 0xF0) { - // 4-byte character: 11110xxx ... - incomplete = i < 4; - } - // else 1-byte character or invalid byte - break; - } + bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); + // search stop word and delete it if (!incomplete) { size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); const std::string str_test = slot.generated_text.substr(pos); - bool is_stop_full = false; + bool send_text = true; - size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL); + size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); if (stop_pos != std::string::npos) { - is_stop_full = true; slot.generated_text.erase( slot.generated_text.begin() + pos + stop_pos, slot.generated_text.end()); pos = std::min(slot.n_sent_text, slot.generated_text.size()); - } else { - is_stop_full = false; - stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_PARTIAL); + } else if (slot.has_next_token) { + stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); + send_text = stop_pos == std::string::npos; } // check if there is any token to predict - if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) { + if (send_text) { // no send the stop word in the response result.text_to_send = slot.generated_text.substr(pos, std::string::npos); slot.n_sent_text += result.text_to_send.size(); // add the token to slot queue and cache + } else { + result.text_to_send = ""; } - slot.add_token_string(result); + slot.add_token(result); if (slot.params.stream) { send_partial_response(slot, result); } @@ -1195,103 +2133,144 @@ struct server_context { } // check the limits - if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) { - slot.stopped_limit = true; + if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { + slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; - LOG_VERBOSE("stopped by limit", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_decoded", slot.n_decoded}, - {"n_predict", slot.params.n_predict}, - }); + SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict); } - if (llama_token_is_eog(model, result.tok)) { - slot.stopped_eos = true; - slot.has_next_token = false; + if (slot.has_new_line) { + // if we have already seen a new line, we stop after a certain time limit + if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; - LOG_VERBOSE("eos token found", {}); + SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms); + } + + // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent + if (slot.params.n_indent > 0) { + // check the current indentation + // TODO: improve by not doing it more than once for each new line + if (slot.last_nl_pos > 0) { + size_t pos = slot.last_nl_pos; + + int n_indent = 0; + while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) { + n_indent++; + pos++; + } + + if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + // cut the last line + slot.generated_text.erase(pos, std::string::npos); + + SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent); + } + } + + // find the next new line + { + const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos); + + if (pos != std::string::npos) { + slot.last_nl_pos = pos + 1; + } + } + } } - auto n_ctx_train = llama_n_ctx_train(model); - if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 - && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { - LOG_WARNING("n_predict is not set and self-context extend is disabled." - " Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", { - { "id_slot", slot.id }, - { "params.n_predict", slot.params.n_predict }, - { "slot.n_prompt_tokens", slot.n_prompt_tokens }, - { "slot.n_decoded", slot.n_decoded }, - { "slot.n_predict", slot.n_predict }, - { "n_slots", params.n_parallel }, - { "slot.n_ctx", slot.n_ctx }, - { "n_ctx", n_ctx }, - { "n_ctx_train", n_ctx_train }, - { "ga_n", slot.ga_n }, - }); + // check if there is a new line in the generated text + if (result.text_to_send.find('\n') != std::string::npos) { + slot.has_new_line = true; + } + + // if context shift is disabled, we stop when it reaches the context limit + if (slot.n_past >= slot.n_ctx) { slot.truncated = true; - slot.stopped_limit = true; - slot.has_next_token = false; // stop prediction + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", + slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx); } - LOG_VERBOSE("next token", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"token", result.tok}, - {"token_text", tokens_to_output_formatted_string(ctx, result.tok)}, - {"has_next_token", slot.has_next_token}, - {"n_remain", slot.n_remaining}, - {"n_decoded", slot.n_decoded}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - }); + if (llama_vocab_is_eog(vocab, result.tok)) { + slot.stop = STOP_TYPE_EOS; + slot.has_next_token = false; + + SLT_DBG(slot, "%s", "stopped by EOS\n"); + } + + const auto n_ctx_train = llama_model_n_ctx_train(model); + + if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { + slot.truncated = true; + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; // stop prediction + + SLT_WRN(slot, + "n_predict (%d) is set for infinite generation. " + "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n", + slot.params.n_predict, n_ctx_train); + } + + SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); return slot.has_next_token; // continue } - json get_formated_generation(const server_slot & slot) const { - std::vector samplers; - samplers.reserve(slot.sparams.samplers.size()); - for (const auto & sampler : slot.sparams.samplers) { - samplers.emplace_back(gpt_sampler_type_to_str(sampler)); - } + void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) { + size_t n_probs = slot.params.sampling.n_probs; + size_t n_vocab = llama_vocab_n_tokens(vocab); + if (post_sampling) { + const auto * cur_p = common_sampler_get_candidates(slot.smpl); + const size_t max_probs = cur_p->size; - return json { - {"n_ctx", slot.n_ctx}, - {"n_predict", slot.n_predict}, // Server configured n_predict - {"model", params.model_alias}, - {"seed", slot.sparams.seed}, - {"temperature", slot.sparams.temp}, - {"dynatemp_range", slot.sparams.dynatemp_range}, - {"dynatemp_exponent", slot.sparams.dynatemp_exponent}, - {"top_k", slot.sparams.top_k}, - {"top_p", slot.sparams.top_p}, - {"min_p", slot.sparams.min_p}, - {"tfs_z", slot.sparams.tfs_z}, - {"typical_p", slot.sparams.typ_p}, - {"repeat_last_n", slot.sparams.penalty_last_n}, - {"repeat_penalty", slot.sparams.penalty_repeat}, - {"presence_penalty", slot.sparams.penalty_present}, - {"frequency_penalty", slot.sparams.penalty_freq}, - {"mirostat", slot.sparams.mirostat}, - {"mirostat_tau", slot.sparams.mirostat_tau}, - {"mirostat_eta", slot.sparams.mirostat_eta}, - {"penalize_nl", slot.sparams.penalize_nl}, - {"stop", slot.params.antiprompt}, - {"max_tokens", slot.params.n_predict}, // User configured n_predict - {"n_keep", slot.params.n_keep}, - {"n_discard", slot.params.n_discard}, - {"ignore_eos", slot.sparams.ignore_eos}, - {"stream", slot.params.stream}, - //{"logit_bias", slot.sparams.logit_bias}, - {"n_probs", slot.sparams.n_probs}, - {"min_keep", slot.sparams.min_keep}, - {"grammar", slot.sparams.grammar}, - {"samplers", samplers}, - }; + // set probability for sampled token + for (size_t i = 0; i < max_probs; i++) { + if (cur_p->data[i].id == result.tok) { + result.prob = cur_p->data[i].p; + break; + } + } + + // set probability for top n_probs tokens + result.probs.reserve(max_probs); + for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { + result.probs.push_back({ + cur_p->data[i].id, + common_detokenize(ctx, {cur_p->data[i].id}, special), + cur_p->data[i].p + }); + } + } else { + // TODO: optimize this with min-p optimization + std::vector cur = get_token_probabilities(ctx, idx); + + // set probability for sampled token + for (size_t i = 0; i < n_vocab; i++) { + // set probability for sampled token + if (cur[i].id == result.tok) { + result.prob = cur[i].p; + break; + } + } + + // set probability for top n_probs tokens + result.probs.reserve(n_probs); + for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) { + result.probs.push_back({ + cur[i].id, + common_detokenize(ctx, {cur[i].id}, special), + cur[i].p + }); + } + } } void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { @@ -1303,119 +2282,107 @@ struct server_context { } void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - LOG_ERROR("task error", { - {"id_task", id_task}, - {"error", error}, - }); + SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str()); - server_task_result res; - res.id = id_task; - res.stop = false; - res.error = true; - res.data = format_error_response(error, type); + auto res = std::make_unique(); + res->id = id_task; + res->err_type = type; + res->err_msg = error; - queue_results.send(res); + queue_results.send(std::move(res)); } - void send_partial_response(server_slot & slot, completion_token_output tkn) { - server_task_result res; - res.id = slot.id_task; - res.error = false; - res.stop = false; - res.data = json { - {"content", tkn.text_to_send}, - {"stop", false}, - {"id_slot", slot.id}, - {"multimodal", false}, - {"index", slot.index}, - }; + void send_partial_response(server_slot & slot, const completion_token_output & tkn) { + auto res = std::make_unique(); - if (slot.sparams.n_probs > 0) { - const std::vector to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false); - const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size()); - const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size()); + res->id = slot.id_task; + res->index = slot.index; + res->content = tkn.text_to_send; + res->tokens = { tkn.tok }; - std::vector probs_output; - if (probs_pos < probs_stop_pos) { - probs_output = std::vector( - slot.generated_token_probs.begin() + probs_pos, - slot.generated_token_probs.begin() + probs_stop_pos); - } - slot.n_sent_token_probs = probs_stop_pos; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.n_prompt_tokens; + res->post_sampling_probs = slot.params.post_sampling_probs; - res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output); + res->verbose = slot.params.verbose; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_model = slot.params.oaicompat_model; + res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + + // populate res.probs_output + if (slot.params.sampling.n_probs > 0) { + res->prob_output = tkn; // copy the token probs } - if (slot.oaicompat) { - res.data["oaicompat_token_ctr"] = slot.n_decoded; - res.data["model"] = slot.oaicompat_model; + // populate timings if this is final response or timings_per_token is enabled + if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) { + res->timings = slot.get_timings(); } - queue_results.send(res); + queue_results.send(std::move(res)); } - void send_final_response(const server_slot & slot) { - server_task_result res; - res.id = slot.id_task; - res.error = false; - res.stop = true; - res.data = json { - {"content", !slot.params.stream ? slot.generated_text : ""}, - {"id_slot", slot.id}, - {"stop", true}, - {"model", params.model_alias}, - {"tokens_predicted", slot.n_decoded}, - {"tokens_evaluated", slot.n_prompt_tokens}, - {"generation_settings", get_formated_generation(slot)}, - {"prompt", slot.prompt}, - {"truncated", slot.truncated}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - {"tokens_cached", slot.n_past}, - {"timings", slot.get_formated_timings()}, - {"index", slot.index}, - }; + void send_final_response(server_slot & slot) { + auto res = std::make_unique(); + res->id = slot.id_task; + res->id_slot = slot.id; - if (slot.sparams.n_probs > 0) { - std::vector probs; - if (!slot.params.stream && slot.stopped_word) { - const std::vector stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false); + res->index = slot.index; + res->content = std::move(slot.generated_text); + res->tokens = std::move(slot.generated_tokens); + res->timings = slot.get_timings(); + res->prompt = common_detokenize(ctx, slot.prompt_tokens, true); + res->response_fields = std::move(slot.params.response_fields); + + res->truncated = slot.truncated; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.n_prompt_tokens; + res->n_tokens_cached = slot.n_past; + res->has_new_line = slot.has_new_line; + res->stopping_word = slot.stopping_word; + res->stop = slot.stop; + res->post_sampling_probs = slot.params.post_sampling_probs; + + res->verbose = slot.params.verbose; + res->stream = slot.params.stream; + res->oaicompat = slot.params.oaicompat; + res->oaicompat_model = slot.params.oaicompat_model; + res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; + res->oaicompat_chat_format = slot.params.oaicompat_chat_format; + // populate res.probs_output + if (slot.params.sampling.n_probs > 0) { + if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) { + const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); - probs = std::vector( + res->probs_output = std::vector( slot.generated_token_probs.begin(), slot.generated_token_probs.end() - safe_offset); } else { - probs = std::vector( + res->probs_output = std::vector( slot.generated_token_probs.begin(), slot.generated_token_probs.end()); } - - res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs); } - if (slot.oaicompat) { - res.data["oaicompat_token_ctr"] = slot.n_decoded; - res.data["model"] = slot.oaicompat_model; - } + res->generation_params = slot.params; // copy the parameters - queue_results.send(res); + queue_results.send(std::move(res)); } void send_embedding(const server_slot & slot, const llama_batch & batch) { - server_task_result res; - res.id = slot.id_task; - res.error = false; - res.stop = true; + auto res = std::make_unique(); + res->id = slot.id_task; + res->index = slot.index; + res->n_tokens = slot.n_prompt_tokens; + res->oaicompat = slot.params.oaicompat; - const int n_embd = llama_n_embd(model); + const int n_embd = llama_model_n_embd(model); std::vector embd_res(n_embd, 0.0f); for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { continue; } @@ -1425,133 +2392,150 @@ struct server_context { } if (embd == NULL) { - LOG_ERROR("failed to get embeddings", { - {"token", batch.token [i]}, - {"seq_id", batch.seq_id[i][0]} - }); - - res.data = json { - {"embedding", std::vector(n_embd, 0.0f)}, - }; + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + res->embedding.push_back(std::vector(n_embd, 0.0f)); continue; } - llama_embd_normalize(embd, embd_res.data(), n_embd); - - res.data = json { - {"embedding", embd_res}, - {"index", slot.index}, - }; + // normalize only when there is pooling + // TODO: configurable + if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { + common_embd_normalize(embd, embd_res.data(), n_embd, 2); + res->embedding.push_back(embd_res); + } else { + res->embedding.push_back({ embd, embd + n_embd }); + } } - queue_results.send(res); + SLT_DBG(slot, "%s", "sending embeddings\n"); + + queue_results.send(std::move(res)); + } + + void send_rerank(const server_slot & slot, const llama_batch & batch) { + auto res = std::make_unique(); + res->id = slot.id_task; + res->index = slot.index; + res->n_tokens = slot.n_prompt_tokens; + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; + } + + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } + + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + + res->score = -1e6; + continue; + } + + res->score = embd[0]; + } + + SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score); + + queue_results.send(std::move(res)); } // // Functions to create new task(s) and receive result(s) // - std::vector create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) { - std::vector tasks; - auto create_task = [&](json & task_data, bool replace_prompt, json prompt) { - server_task task; - task.id = queue_tasks.get_new_id(); - task.cmpl_type = cmpl_type; - task.type = SERVER_TASK_TYPE_COMPLETION; - if (replace_prompt) { - task.data = task_data; - task.data["prompt"] = prompt; - } else { - task.data = std::move(task_data); - } - tasks.push_back(std::move(task)); - }; - - static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts"; - if (!data.contains("prompt")) { - throw std::runtime_error(error_msg); - } - - json prompt = data.at("prompt"); - - // if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task - if (prompt.is_string() || json_is_array_of_numbers(prompt)) { - data["index"] = 0; - create_task(data, false, nullptr); - } - // otherwise, it's a multiple-prompt task, we break it into smaller tasks - else if (prompt.is_array()) { - std::vector prompts = prompt; - for (size_t i = 0; i < prompts.size(); i++) { - const auto & e = prompts[i]; - if (e.is_string() || json_is_array_of_numbers(e)) { - data["index"] = i; - create_task(data, true, e); - } else { - throw std::runtime_error(error_msg); - } - } - } - // invalid case - else { - throw std::runtime_error(error_msg); - } - - return tasks; - } - void cancel_tasks(const std::unordered_set & id_tasks) { std::vector cancel_tasks; cancel_tasks.reserve(id_tasks.size()); for (const auto & id_task : id_tasks) { - LOG_VERBOSE("cancel task", {{"id_task", id_task}}); - server_task task; - task.type = SERVER_TASK_TYPE_CANCEL; + SRV_WRN("cancel task, id_task = %d\n", id_task); + + server_task task(SERVER_TASK_TYPE_CANCEL); task.id_target = id_task; - cancel_tasks.push_back(task); queue_results.remove_waiting_task_id(id_task); + cancel_tasks.push_back(task); } // push to beginning of the queue, so it has highest priority queue_tasks.post(cancel_tasks, true); } - // receive the results from task(s) created by create_tasks_cmpl - void receive_cmpl_results(const std::unordered_set & id_tasks, std::function&)> result_handler, std::function error_handler) { - // TODO: currently, there is no way to detect the client has cancelled the request - std::vector results(id_tasks.size()); - for (size_t i = 0; i < id_tasks.size(); i++) { - server_task_result result = queue_results.recv(id_tasks); + // receive the results from task(s) + void receive_multi_results( + const std::unordered_set & id_tasks, + const std::function&)> & result_handler, + const std::function & error_handler, + const std::function & is_connection_closed) { + std::vector results(id_tasks.size()); + for (int i = 0; i < (int)id_tasks.size(); i++) { + server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); - if (result.error) { - error_handler(result.data); + if (is_connection_closed()) { cancel_tasks(id_tasks); - break; + return; } - size_t idx = result.data["index"]; - results[idx] = result; + if (result == nullptr) { + i--; // retry + continue; + } + + if (result->is_error()) { + error_handler(result->to_json()); + cancel_tasks(id_tasks); + return; + } + + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + const size_t idx = result->get_index(); + GGML_ASSERT(idx < results.size() && "index out of range"); + results[idx] = std::move(result); } result_handler(results); } - // receive the results from task(s) created by create_tasks_cmpl, in stream mode - void receive_cmpl_results_stream(const std::unordered_set & id_tasks, std::function result_handler, std::function error_handler) { + // receive the results from task(s), in stream mode + void receive_cmpl_results_stream( + const std::unordered_set & id_tasks, + const std::function & result_handler, + const std::function & error_handler, + const std::function & is_connection_closed) { size_t n_finished = 0; while (true) { - server_task_result result = queue_results.recv(id_tasks); + server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); + + if (is_connection_closed()) { + cancel_tasks(id_tasks); + return; + } + + if (result == nullptr) { + continue; // retry + } + + if (result->is_error()) { + error_handler(result->to_json()); + cancel_tasks(id_tasks); + return; + } + + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); if (!result_handler(result)) { cancel_tasks(id_tasks); break; } - if (result.error) { - error_handler(result.data); - cancel_tasks(id_tasks); - break; - } - - if (result.stop) { + if (result->is_stop()) { if (++n_finished == id_tasks.size()) { break; } @@ -1563,56 +2547,32 @@ struct server_context { // Functions to process the task // - void process_single_task(const server_task & task) { + void process_single_task(server_task task) { switch (task.type) { case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: { - const int id_slot = json_value(task.data, "id_slot", -1); + const int id_slot = task.id_selected_slot; - server_slot * slot; - - if (id_slot != -1) { - slot = get_slot_by_id(id_slot); - } else { - std::string prompt; - if (task.data.contains("prompt") && task.data.at("prompt").is_string()) { - prompt = json_value(task.data, "prompt", std::string()); - } - - slot = get_available_slot(prompt); - } + server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task); if (slot == nullptr) { // if no slot is available, we defer this task for processing later - LOG_VERBOSE("no slot is available", {{"id_task", task.id}}); + SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id); queue_tasks.defer(task); break; } if (slot->is_processing()) { // if requested slot is unavailable, we defer this task for processing later - LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); queue_tasks.defer(task); break; } - if (task.data.contains("system_prompt")) { - std::string sys_prompt = json_value(task.data, "system_prompt", std::string()); - system_prompt_set(sys_prompt); - - for (server_slot & slot : slots) { - slot.n_past = 0; - slot.n_past_se = 0; - } - } - - slot->reset(); - - slot->id_task = task.id; - slot->cmpl_type = task.cmpl_type; - slot->index = json_value(task.data, "index", 0); - if (!launch_slot_with_task(*slot, task)) { - LOG_ERROR("error while launching slot", task.data); + SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); break; } } break; @@ -1638,79 +2598,50 @@ struct server_context { int n_processing_slots = 0; for (server_slot & slot : slots) { - json slot_data = get_formated_generation(slot); - slot_data["id"] = slot.id; - slot_data["id_task"] = slot.id_task; - slot_data["state"] = slot.state; - slot_data["prompt"] = slot.prompt; - slot_data["next_token"] = { - {"has_next_token", slot.has_next_token}, - {"n_remain", slot.n_remaining}, - {"n_decoded", slot.n_decoded}, - {"stopped_eos", slot.stopped_eos}, - {"stopped_word", slot.stopped_word}, - {"stopped_limit", slot.stopped_limit}, - {"stopping_word", slot.stopping_word}, - }; + json slot_data = slot.to_json(); - if (slot_data["state"] == SLOT_STATE_IDLE) { - n_idle_slots++; - } else { + if (slot.is_processing()) { n_processing_slots++; + } else { + n_idle_slots++; } slots_data.push_back(slot_data); } - LOG_INFO("slot data", { - {"id_task", task.id}, - {"n_idle_slots", n_idle_slots}, - {"n_processing_slots", n_processing_slots} - }); + SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots); - LOG_VERBOSE("slot data", { - {"id_task", task.id}, - {"n_idle_slots", n_idle_slots}, - {"n_processing_slots", n_processing_slots}, - {"slots", slots_data} - }); + auto res = std::make_unique(); + res->id = task.id; + res->slots_data = std::move(slots_data); + res->n_idle_slots = n_idle_slots; + res->n_processing_slots = n_processing_slots; + res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size(); + res->t_start = metrics.t_start; - server_task_result res; - res.id = task.id; - res.stop = true; - res.error = false; - res.data = { - { "idle", n_idle_slots }, - { "processing", n_processing_slots }, - { "deferred", queue_tasks.queue_tasks_deferred.size() }, - { "t_start", metrics.t_start}, + res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx); + res->kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx); - { "n_prompt_tokens_processed_total", metrics.n_prompt_tokens_processed_total}, - { "t_tokens_generation_total", metrics.t_tokens_generation_total}, - { "n_tokens_predicted_total", metrics.n_tokens_predicted_total}, - { "t_prompt_processing_total", metrics.t_prompt_processing_total}, + res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; + res->t_prompt_processing_total = metrics.t_prompt_processing_total; + res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; + res->t_tokens_generation_total = metrics.t_tokens_generation_total; - { "n_prompt_tokens_processed", metrics.n_prompt_tokens_processed}, - { "t_prompt_processing", metrics.t_prompt_processing}, - { "n_tokens_predicted", metrics.n_tokens_predicted}, - { "t_tokens_generation", metrics.t_tokens_generation}, + res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed; + res->t_prompt_processing = metrics.t_prompt_processing; + res->n_tokens_predicted = metrics.n_tokens_predicted; + res->t_tokens_generation = metrics.t_tokens_generation; - { "n_decode_total", metrics.n_decode_total}, - { "n_busy_slots_total", metrics.n_busy_slots_total}, + res->n_decode_total = metrics.n_decode_total; + res->n_busy_slots_total = metrics.n_busy_slots_total; - { "kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)}, - { "kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)}, - - { "slots", slots_data }, - }; - - if (json_value(task.data, "reset_bucket", false)) { + if (task.metrics_reset_bucket) { metrics.reset_bucket(); } - queue_results.send(res); + queue_results.send(std::move(res)); } break; case SERVER_TASK_TYPE_SLOT_SAVE: { - int id_slot = task.data.at("id_slot"); + int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); @@ -1718,7 +2649,7 @@ struct server_context { } if (slot->is_processing()) { // if requested slot is unavailable, we defer this task for processing later - LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); queue_tasks.defer(task); break; } @@ -1726,32 +2657,27 @@ struct server_context { const size_t token_count = slot->cache_tokens.size(); const int64_t t_start = ggml_time_us(); - std::string filename = task.data.at("filename"); - std::string filepath = task.data.at("filepath"); + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; - const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count); + const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count); const int64_t t_end = ggml_time_us(); const double t_save_ms = (t_end - t_start) / 1000.0; - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_saved", token_count }, // tokens saved - { "n_written", nwrite }, // bytes written - { "timings", { - { "save_ms", t_save_ms } - } } - }; - queue_results.send(result); + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->filename = filename; + res->is_save = true; + res->n_tokens = token_count; + res->n_bytes = nwrite; + res->t_ms = t_save_ms; + queue_results.send(std::move(res)); } break; case SERVER_TASK_TYPE_SLOT_RESTORE: { - int id_slot = task.data.at("id_slot"); + int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); @@ -1759,19 +2685,19 @@ struct server_context { } if (slot->is_processing()) { // if requested slot is unavailable, we defer this task for processing later - LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); queue_tasks.defer(task); break; } const int64_t t_start = ggml_time_us(); - std::string filename = task.data.at("filename"); - std::string filepath = task.data.at("filepath"); + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; slot->cache_tokens.resize(slot->n_ctx); size_t token_count = 0; - size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count); + size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count); if (nread == 0) { slot->cache_tokens.resize(0); send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); @@ -1782,24 +2708,19 @@ struct server_context { const int64_t t_end = ggml_time_us(); const double t_restore_ms = (t_end - t_start) / 1000.0; - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_restored", token_count }, // tokens restored - { "n_read", nread }, // bytes read - { "timings", { - { "restore_ms", t_restore_ms } - } } - }; - queue_results.send(result); + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->filename = filename; + res->is_save = false; + res->n_tokens = token_count; + res->n_bytes = nread; + res->t_ms = t_restore_ms; + queue_results.send(std::move(res)); } break; case SERVER_TASK_TYPE_SLOT_ERASE: { - int id_slot = task.data.at("id_slot"); + int id_slot = task.slot_action.slot_id; server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); @@ -1807,44 +2728,33 @@ struct server_context { } if (slot->is_processing()) { // if requested slot is unavailable, we defer this task for processing later - LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); queue_tasks.defer(task); break; } // Erase token cache const size_t n_erased = slot->cache_tokens.size(); - llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); + llama_kv_cache_seq_rm(ctx, slot->id, -1, -1); slot->cache_tokens.clear(); - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json { - { "id_slot", id_slot }, - { "n_erased", n_erased } - }; - queue_results.send(result); + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->n_erased = n_erased; + queue_results.send(std::move(res)); } break; case SERVER_TASK_TYPE_SET_LORA: { - llama_lora_adapters_apply(ctx, lora_adapters); - server_task_result result; - result.id = task.id; - result.stop = true; - result.error = false; - result.data = json{{ "success", true }}; - queue_results.send(result); + params_base.lora_adapters = std::move(task.set_lora); + auto res = std::make_unique(); + res->id = task.id; + queue_results.send(std::move(res)); } break; } } void update_slots() { - if (system_need_update) { - system_prompt_update(); - } - // check if all slots are idle { bool all_idle = true; @@ -1857,8 +2767,8 @@ struct server_context { } if (all_idle) { - LOG_INFO("all slots are idle", {}); - if (system_prompt.empty() && clean_kv_cache) { + SRV_INF("%s", "all slots are idle\n"); + if (clean_kv_cache) { kv_cache_clear(); } @@ -1867,57 +2777,59 @@ struct server_context { } { - LOG_VERBOSE("posting NEXT_RESPONSE", {}); - - server_task task; - task.type = SERVER_TASK_TYPE_NEXT_RESPONSE; - task.id_target = -1; + SRV_DBG("%s", "posting NEXT_RESPONSE\n"); + server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); + task.id = queue_tasks.get_new_id(); queue_tasks.post(task); } // apply context-shift if needed // TODO: simplify and improve for (server_slot & slot : slots) { - if (slot.ga_n == 1) { - if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) { - // Shift context - const int n_keep = slot.params.n_keep + add_bos_token; - const int n_left = (int) system_tokens.size() + slot.n_past - n_keep; - const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); + if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) { + if (!params_base.ctx_shift) { + // this check is redundant (for good) + // we should never get here, because generation should already stopped in process_token() + slot.release(); + send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); + continue; + } - LOG_INFO("slot context shift", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_keep", n_keep}, - {"n_left", n_left}, - {"n_discard", n_discard}, - {"n_ctx", n_ctx}, - {"n_past", slot.n_past}, - {"n_system_tokens", system_tokens.size()}, - {"n_cache_tokens", slot.cache_tokens.size()} - }); + // Shift context + const int n_keep = slot.params.n_keep + add_bos_token; + const int n_left = slot.n_past - n_keep; + const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); - llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard); - llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); + SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); - if (slot.params.cache_prompt) { - for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { - slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; - } + llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard); + llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard); - slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); + if (slot.params.cache_prompt) { + for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { + slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; } - slot.n_past -= n_discard; - - slot.truncated = true; + slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard); } + + slot.n_past -= n_discard; + + slot.truncated = true; } } // start populating the batch for this iteration - llama_batch_clear(batch); + common_batch_clear(batch); + + // track if given slot can be batched with slots already in the batch + server_slot * slot_batched = nullptr; + + auto accept_special_token = [&](server_slot & slot, llama_token token) { + const auto & trigger_tokens = slot.params.sampling.grammar_trigger_tokens; + return params_base.special || std::find(trigger_tokens.begin(), trigger_tokens.end(), token) != trigger_tokens.end(); + }; // frist, add sampled tokens from any ongoing sequences for (auto & slot : slots) { @@ -1925,13 +2837,16 @@ struct server_context { continue; } + // check if we can batch this slot with the previous one + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { + continue; + } + slot.i_batch = batch.n_tokens; - const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past; - - // TODO: we always have to take into account the "system_tokens" - // this is not great and needs to be improved somehow - llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true); + common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); slot.n_past += 1; @@ -1939,97 +2854,57 @@ struct server_context { slot.cache_tokens.push_back(slot.sampled); } - LOG_VERBOSE("slot decode token", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_ctx", n_ctx}, - {"n_past", slot.n_past}, - {"n_system_tokens", system_tokens.size()}, - {"n_cache_tokens", slot.cache_tokens.size()}, - {"truncated", slot.truncated} - }); + SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n", + slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated); } // process in chunks of params.n_batch int32_t n_batch = llama_n_batch(ctx); int32_t n_ubatch = llama_n_ubatch(ctx); - // track if this is an embedding or non-embedding batch - // if we've added sampled tokens above, we are in non-embedding mode - // -1: none, 0: non-embedding, 1: embedding - int32_t batch_type = batch.n_tokens > 0 ? 0 : -1; - // next, batch any pending prompts without exceeding n_batch - if (params.cont_batching || batch.n_tokens == 0) { + if (params_base.cont_batching || batch.n_tokens == 0) { for (auto & slot : slots) { + // check if we can batch this slot with the previous one + if (slot.is_processing()) { + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { + continue; + } + } + // this slot still has a prompt to be processed - if (slot.state == SLOT_STATE_PROCESSING_PROMPT) { + if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { auto & prompt_tokens = slot.prompt_tokens; - // we haven't tokenized the prompt yet - do it now: - if (prompt_tokens.empty()) { - LOG_VERBOSE("tokenizing prompt", { - {"id_slot", slot.id}, - {"id_task", slot.id_task} - }); - + // TODO: maybe move branch to outside of this loop in the future + if (slot.state == SLOT_STATE_STARTED) { slot.t_start_process_prompt = ggml_time_us(); slot.t_start_generation = 0; - if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_INFILL) { - const bool add_bos = llama_add_bos_token(model); - bool suff_rm_leading_spc = true; - if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) { - params.input_suffix.erase(0, 1); - suff_rm_leading_spc = false; - } - - auto prefix_tokens = tokenize(slot.params.input_prefix, false); - auto suffix_tokens = tokenize(slot.params.input_suffix, false); - - const int space_token = 29871; // TODO: this should not be hardcoded - if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) { - suffix_tokens.erase(suffix_tokens.begin()); - } - - prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model)); - suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model)); - - auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens; - auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens; - if (add_bos) { - embd_inp.insert(embd_inp.begin(), llama_token_bos(model)); - } - embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); - - const llama_token middle_token = llama_token_middle(model); - if (middle_token >= 0) { - embd_inp.push_back(middle_token); - } - - prompt_tokens = embd_inp; - } else { - prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt - } - slot.n_past = 0; slot.n_prompt_tokens = prompt_tokens.size(); + slot.state = SLOT_STATE_PROCESSING_PROMPT; - LOG_VERBOSE("prompt tokenized", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_ctx", slot.n_ctx}, - {"n_keep", slot.params.n_keep}, - {"n_prompt_tokens", slot.n_prompt_tokens}, - {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())}, - }); + SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); + + // print prompt tokens (for debugging) + if (1) { + // first 16 tokens (avoid flooding logs) + for (int i = 0; i < std::min(16, prompt_tokens.size()); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + } + } else { + // all + for (int i = 0; i < (int) prompt_tokens.size(); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + } + } // empty prompt passed -> release the slot and send empty response if (prompt_tokens.empty()) { - LOG_INFO("empty prompt - releasing slot", { - {"id_slot", slot.id}, - {"id_task", slot.id_task} - }); + SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); slot.release(); slot.print_timings(); @@ -2037,27 +2912,42 @@ struct server_context { continue; } - if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { - // this prompt is too large to process - discard it + if (slot.is_non_causal()) { if (slot.n_prompt_tokens > n_ubatch) { slot.release(); send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); continue; } + + if (slot.n_prompt_tokens > slot.n_ctx) { + slot.release(); + send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER); + continue; + } } else { + if (!params_base.ctx_shift) { + // if context shift is disabled, we make sure prompt size is smaller than KV size + // TODO: there should be a separate parameter that control prompt truncation + // context shift should be applied only during the generation phase + if (slot.n_prompt_tokens >= slot.n_ctx) { + slot.release(); + send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST); + continue; + } + } if (slot.params.n_keep < 0) { slot.params.n_keep = slot.n_prompt_tokens; } slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); - // if input prompt is too big, truncate it (if group attention self-extend is disabled) - if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) { + // if input prompt is too big, truncate it + if (slot.n_prompt_tokens >= slot.n_ctx) { const int n_left = slot.n_ctx - slot.params.n_keep; const int n_block_size = n_left / 2; const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; - std::vector new_tokens( + llama_tokens new_tokens( prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep); @@ -2071,131 +2961,109 @@ struct server_context { slot.truncated = true; slot.n_prompt_tokens = prompt_tokens.size(); - LOG_VERBOSE("input truncated", { - {"id_slot", slot.id}, - {"id_task", slot.id_task}, - {"n_ctx", slot.n_ctx}, - {"n_keep", slot.params.n_keep}, - {"n_left", n_left}, - {"n_prompt_tokens", slot.n_prompt_tokens}, - {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())}, - }); + SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens); GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); } - gpt_sampler_reset(slot.smpl); - - if (!slot.params.cache_prompt) { - slot.n_past_se = 0; - slot.ga_i = 0; - } else { - GGML_ASSERT(slot.ga_n == 1); - + if (slot.params.cache_prompt) { // reuse any previously computed tokens that are common with the new prompt - slot.n_past = common_part(slot.cache_tokens, prompt_tokens); + slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens); - // push the prompt into the sampling context (do not apply grammar) - for (int i = 0; i < slot.n_past; ++i) { - gpt_sampler_accept(slot.smpl, slot.cache_tokens[i], false); + // reuse chunks from the cached prompt by shifting their KV cache in the new position + if (params_base.n_cache_reuse > 0) { + size_t head_c = slot.n_past; // cache + size_t head_p = slot.n_past; // current prompt + + SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); + + while (head_c < slot.cache_tokens.size() && + head_p < prompt_tokens.size()) { + + size_t n_match = 0; + while (head_c + n_match < slot.cache_tokens.size() && + head_p + n_match < prompt_tokens.size() && + slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) { + + n_match++; + } + + if (n_match >= (size_t) params_base.n_cache_reuse) { + SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); + //for (size_t i = head_p; i < head_p + n_match; i++) { + // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + //} + + const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; + + llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c); + llama_kv_cache_seq_add(ctx, slot.id, head_c, -1, kv_shift); + + for (size_t i = 0; i < n_match; i++) { + slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i]; + slot.n_past++; + } + + head_c += n_match; + head_p += n_match; + } else { + head_c += 1; + } + } + + SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past); } } } if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) { // we have to evaluate at least 1 token to generate logits. - LOG_INFO("we have to evaluate at least 1 token to generate logits", { - { "id_slot", slot.id }, - { "id_task", slot.id_task } - }); + SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens); slot.n_past--; - if (slot.ga_i > 0) { - slot.n_past_se--; - } } slot.n_prompt_tokens_processed = 0; } - if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { + // non-causal tasks require to fit the entire prompt in the physical batch + if (slot.is_non_causal()) { // cannot fit the prompt in the current batch - will try next iter if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { continue; } } - // check that we are in the right batch_type, if not defer the slot - bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0; - if (batch_type == -1) { - batch_type = slot_type; - } else if (batch_type != slot_type) { - continue; - } - // keep only the common part - int p0 = (int) system_tokens.size() + slot.n_past; - if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) { + if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) { // could not partially delete (likely using a non-Transformer model) - llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1); + llama_kv_cache_seq_rm(ctx, slot.id, -1, -1); - p0 = (int) system_tokens.size(); - if (p0 != 0) { - // copy over the system prompt when there is one - llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1); - } - - // there is no common part left (except for the system prompt) + // there is no common part left slot.n_past = 0; - slot.n_past_se = 0; - slot.ga_i = 0; - // TODO: is the system prompt ever in the sampling context? - gpt_sampler_reset(slot.smpl); } + SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past); + // remove the non-common part from the cache slot.cache_tokens.resize(slot.n_past); - LOG_INFO("kv cache rm [p0, end)", { - { "id_slot", slot.id }, - { "id_task", slot.id_task }, - { "p0", p0 } - }); - - int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past; - - int32_t ga_i = slot.ga_i; - int32_t ga_n = slot.ga_n; - int32_t ga_w = slot.ga_w; - // add prompt tokens for processing in the current batch - // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow - for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past) { - if (slot.ga_n != 1) { - while (slot_npast >= ga_i + ga_w) { - const int bd = (ga_w/ga_n)*(ga_n - 1); - slot_npast -= bd; - ga_i += ga_w/ga_n; - } - } + while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + // without pooling, we want to output the embeddings for all the tokens in the batch + const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; - llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false); + common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd); if (slot.params.cache_prompt) { slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); } slot.n_prompt_tokens_processed++; - slot_npast++; + slot.n_past++; } - LOG_VERBOSE("prompt processing progress", { - {"id_slot", slot.id}, - {"n_past", slot.n_past}, - {"n_ctx", n_ctx}, - {"n_tokens", batch.n_tokens}, - {"progress", (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens}, - }); + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); // entire prompt has been processed if (slot.n_past == slot.n_prompt_tokens) { @@ -2203,18 +3071,20 @@ struct server_context { GGML_ASSERT(batch.n_tokens > 0); + common_sampler_reset(slot.smpl); + + // Process all prompt tokens through sampler system + for (int i = 0; i < slot.n_prompt_tokens; ++i) { + common_sampler_accept(slot.smpl, prompt_tokens[i], false); + } + // extract the logits only for the last token batch.logits[batch.n_tokens - 1] = true; slot.n_decoded = 0; slot.i_batch = batch.n_tokens - 1; - LOG_VERBOSE("prompt done", { - {"id_slot", slot.id}, - {"n_past", slot.n_past}, - {"n_ctx", n_ctx}, - {"n_tokens", batch.n_tokens}, - }); + SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens); } } @@ -2225,50 +3095,23 @@ struct server_context { } if (batch.n_tokens == 0) { - LOG_VERBOSE("no tokens to decode", {}); + SRV_WRN("%s", "no tokens to decode\n"); return; } - LOG_VERBOSE("decoding batch", { - {"n_tokens", batch.n_tokens}, - }); + SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); - // make sure we're in the right embedding mode - llama_set_embeddings(ctx, batch_type == 1); + if (slot_batched) { + // make sure we're in the right embedding mode + llama_set_embeddings(ctx, slot_batched->is_non_causal()); + // apply lora, only need to do it once per batch + common_set_adapter_lora(ctx, slot_batched->lora); + } // process the created batch of tokens for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); - for (auto & slot : slots) { - if (slot.ga_n != 1) { - // context extension via Self-Extend - // TODO: simplify and/or abstract this - while (slot.n_past_se >= slot.ga_i + slot.ga_w) { - const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w; - const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1); - const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w; - - LOG_TEE("\n"); - LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd); - LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n); - LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd); - - llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd); - llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n); - llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd); - - slot.n_past_se -= bd; - - slot.ga_i += slot.ga_w / slot.ga_n; - - LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i); - } - - slot.n_past_se += n_tokens; - } - } - llama_batch batch_view = { n_tokens, batch.token + i, @@ -2277,7 +3120,6 @@ struct server_context { batch.n_seq_id + i, batch.seq_id + i, batch.logits + i, - 0, 0, 0, // unused }; const int ret = llama_decode(ctx, batch_view); @@ -2286,11 +3128,7 @@ struct server_context { if (ret != 0) { if (n_batch == 1 || ret < 0) { // if you get here, it means the KV cache is full - try increasing it via the context size - LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", { - {"i", i}, - {"n_batch", n_batch}, - {"ret", ret}, - }); + SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); for (auto & slot : slots) { slot.release(); send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size."); @@ -2302,11 +3140,7 @@ struct server_context { n_batch /= 2; i -= n_batch; - LOG_WARNING("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation", { - {"i", i}, - {"n_batch", n_batch}, - {"ret", ret}, - }); + SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); continue; // continue loop of n_batch } @@ -2317,7 +3151,7 @@ struct server_context { } if (slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) { + if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) { // prompt evaluated for embedding send_embedding(slot, batch_view); slot.release(); @@ -2325,33 +3159,46 @@ struct server_context { continue; // continue loop of slots } + if (slot.task_type == SERVER_TASK_TYPE_RERANK) { + send_rerank(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + // prompt evaluated for next-token prediction slot.state = SLOT_STATE_GENERATING; } else if (slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots } - completion_token_output result; - const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i); + const int tok_idx = slot.i_batch - i; - gpt_sampler_accept(slot.smpl, id, true); + llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); + + slot.i_batch = -1; + + common_sampler_accept(slot.smpl, id, true); slot.n_decoded += 1; + + const int64_t t_current = ggml_time_us(); + if (slot.n_decoded == 1) { - slot.t_start_generation = ggml_time_us(); + slot.t_start_generation = t_current; slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; metrics.on_prompt_eval(slot); } - result.tok = id; + slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3; - const auto * cur_p = gpt_sampler_get_candidates(slot.smpl); + completion_token_output result; + result.tok = id; + result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs - for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) { - result.probs.push_back({ - cur_p->data[i].id, - i >= cur_p->size ? 0.0f : cur_p->data[i].p, - }); + if (slot.params.sampling.n_probs > 0) { + populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx); } if (!process_token(result, slot)) { @@ -2360,23 +3207,112 @@ struct server_context { slot.print_timings(); send_final_response(slot); metrics.on_prediction(slot); + continue; + } + } + + // do speculative decoding + for (auto & slot : slots) { + if (!slot.is_processing() || !slot.can_speculate()) { + continue; } - slot.i_batch = -1; + if (slot.state != SLOT_STATE_GENERATING) { + continue; + } + + // determine the max draft that fits the current slot state + int n_draft_max = slot.params.speculative.n_max; + + // note: n_past is not yet increased for the `id` token sampled above + // also, need to leave space for 1 extra token to allow context shifts + n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2); + + if (slot.n_remaining > 0) { + n_draft_max = std::min(n_draft_max, slot.n_remaining - 1); + } + + SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); + + if (n_draft_max < slot.params.speculative.n_min) { + SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min); + + continue; + } + + llama_token id = slot.sampled; + + struct common_speculative_params params_spec; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max; + params_spec.p_min = slot.params.speculative.p_min; + + llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id); + + // ignore small drafts + if (slot.params.speculative.n_min > (int) draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min); + + continue; + } + + // construct the speculation batch + common_batch_clear(slot.batch_spec); + common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true); + + for (size_t i = 0; i < draft.size(); ++i) { + common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true); + } + + SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); + + llama_decode(ctx, slot.batch_spec); + + // the accepted tokens from the speculation + const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); + + slot.n_past += ids.size(); + slot.n_decoded += ids.size(); + + slot.cache_tokens.push_back(id); + slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1); + + llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1); + + for (size_t i = 0; i < ids.size(); ++i) { + completion_token_output result; + + result.tok = ids[i]; + result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // set later + + // TODO: set result.probs + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.release(); + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + break; + } + } + + SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past); } } - LOG_VERBOSE("run slots completed", {}); + SRV_DBG("%s", "run slots completed\n"); } json model_meta() const { return json { - {"vocab_type", llama_vocab_type (model)}, - {"n_vocab", llama_n_vocab (model)}, - {"n_ctx_train", llama_n_ctx_train (model)}, - {"n_embd", llama_n_embd (model)}, - {"n_params", llama_model_n_params(model)}, - {"size", llama_model_size (model)}, + {"vocab_type", llama_vocab_type (vocab)}, + {"n_vocab", llama_vocab_n_tokens (vocab)}, + {"n_ctx_train", llama_model_n_ctx_train(model)}, + {"n_embd", llama_model_n_embd (model)}, + {"n_params", llama_model_n_params (model)}, + {"size", llama_model_size (model)}, }; } }; @@ -2387,19 +3323,10 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp return; } - LOG_INFO("request", { - {"remote_addr", req.remote_addr}, - {"remote_port", req.remote_port}, - {"status", res.status}, - {"method", req.method}, - {"path", req.path}, - {"params", req.params}, - }); + LOG_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status); - LOG_VERBOSE("request", { - {"request", req.body}, - {"response", res.body}, - }); + LOG_DBG("request: %s\n", req.body.c_str()); + LOG_DBG("response: %s\n", res.body.c_str()); } std::function shutdown_handler; @@ -2417,100 +3344,73 @@ inline void signal_handler(int signal) { } int main(int argc, char ** argv) { -#if SERVER_VERBOSE != 1 - log_disable(); -#endif // own arguments required by this example - gpt_params params; + common_params params; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_SERVER); - if (!gpt_params_parse(argc, argv, params, options)) { + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) { return 1; } - // TODO: not great to use extern vars - server_log_json = params.log_json; - server_verbose = params.verbosity > 0; + common_init(); // struct that contains llama context and inference server_context ctx_server; - if (!params.system_prompt.empty()) { - ctx_server.system_prompt_set(params.system_prompt); - } - - if (params.model_alias == "unknown") { - params.model_alias = params.model; - } - llama_backend_init(); llama_numa_init(params.numa); - LOG_INFO("build info", { - {"build", LLAMA_BUILD_NUMBER}, - {"commit", LLAMA_COMMIT} - }); - - LOG_INFO("system info", { - {"n_threads", params.cpuparams.n_threads}, - {"n_threads_batch", params.cpuparams_batch.n_threads}, - {"total_threads", std::thread::hardware_concurrency()}, - {"system_info", llama_print_system_info()}, - }); + LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency()); + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + LOG_INF("\n"); std::unique_ptr svr; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT if (params.ssl_file_key != "" && params.ssl_file_cert != "") { - LOG_INFO("Running with SSL", {{"key", params.ssl_file_key}, {"cert", params.ssl_file_cert}}); + LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str()); svr.reset( new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str()) ); } else { - LOG_INFO("Running without SSL", {}); + LOG_INF("Running without SSL\n"); svr.reset(new httplib::Server()); } #else + if (params.ssl_file_key != "" && params.ssl_file_cert != "") { + LOG_ERR("Server is built without SSL support\n"); + return 1; + } svr.reset(new httplib::Server()); #endif std::atomic state{SERVER_STATE_LOADING_MODEL}; svr->set_default_headers({{"Server", "llama.cpp"}}); - - // CORS preflight - svr->Options(R"(.*)", [](const httplib::Request &, httplib::Response & res) { - // Access-Control-Allow-Origin is already set by middleware - res.set_header("Access-Control-Allow-Credentials", "true"); - res.set_header("Access-Control-Allow-Methods", "POST"); - res.set_header("Access-Control-Allow-Headers", "*"); - return res.set_content("", "text/html"); // blank response, no data - }); - svr->set_logger(log_server_request); - auto res_error = [](httplib::Response & res, json error_data) { + auto res_error = [](httplib::Response & res, const json & error_data) { json final_response {{"error", error_data}}; - res.set_content(final_response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); + res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON); res.status = json_value(error_data, "code", 500); }; - auto res_ok = [](httplib::Response & res, json data) { - res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); + auto res_ok = [](httplib::Response & res, const json & data) { + res.set_content(safe_json_to_str(data), MIMETYPE_JSON); res.status = 200; }; - svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { + svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) { std::string message; try { - std::rethrow_exception(std::move(ep)); - } catch (std::exception & e) { + std::rethrow_exception(ep); + } catch (const std::exception & e) { message = e.what(); } catch (...) { message = "Unknown Exception"; } json formatted_error = format_error_response(message, ERROR_TYPE_SERVER); - LOG_VERBOSE("Got exception", formatted_error); + LOG_WRN("got exception: %s\n", formatted_error.dump().c_str()); res_error(res, formatted_error); }); @@ -2545,20 +3445,10 @@ int main(int argc, char ** argv) { // auto middleware_validate_api_key = [¶ms, &res_error](const httplib::Request & req, httplib::Response & res) { - // TODO: should we apply API key to all endpoints, including "/health" and "/models"? - static const std::unordered_set protected_endpoints = { - "/props", - "/completion", - "/completions", - "/v1/completions", - "/chat/completions", - "/v1/chat/completions", - "/infill", - "/tokenize", - "/detokenize", - "/embedding", - "/embeddings", - "/v1/embeddings", + static const std::unordered_set public_endpoints = { + "/health", + "/models", + "/v1/models", }; // If API key is not set, skip validation @@ -2566,8 +3456,8 @@ int main(int argc, char ** argv) { return true; } - // If path is not in protected_endpoints list, skip validation - if (protected_endpoints.find(req.path) == protected_endpoints.end()) { + // If path is public or is static file, skip validation + if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") { return true; } @@ -2585,15 +3475,21 @@ int main(int argc, char ** argv) { // API key is invalid or not provided res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION)); - LOG_WARNING("Unauthorized: Invalid API Key", {}); + LOG_WRN("Unauthorized: Invalid API Key\n"); return false; }; - auto middleware_server_state = [&res_error, &state](const httplib::Request &, httplib::Response & res) { + auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) { server_state current_state = state.load(); if (current_state == SERVER_STATE_LOADING_MODEL) { - res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); + auto tmp = string_split(req.path, '.'); + if (req.path == "/" || tmp.back() == "html") { + res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); + res.status = 503; + } else { + res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); + } return false; } return true; @@ -2602,6 +3498,14 @@ int main(int argc, char ** argv) { // register server middlewares svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + // If this is OPTIONS request, skip validation because browsers don't include Authorization header + if (req.method == "OPTIONS") { + res.set_header("Access-Control-Allow-Credentials", "true"); + res.set_header("Access-Control-Allow-Methods", "GET, POST"); + res.set_header("Access-Control-Allow-Headers", "*"); + res.set_content("", "text/html"); // blank response, no data + return httplib::Server::HandlerResponse::Handled; // skip further processing + } if (!middleware_server_state(req, res)) { return httplib::Server::HandlerResponse::Handled; } @@ -2623,32 +3527,38 @@ int main(int argc, char ** argv) { const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) { if (!params.endpoint_slots) { - res_error(res, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED)); return; } // request slots data using task queue - server_task task; + server_task task(SERVER_TASK_TYPE_METRICS); task.id = ctx_server.queue_tasks.get_new_id(); - task.type = SERVER_TASK_TYPE_METRICS; - ctx_server.queue_results.add_waiting_task_id(task.id); ctx_server.queue_tasks.post(task, true); // high-priority task // get the result - server_task_result result = ctx_server.queue_results.recv(task.id); + server_task_result_ptr result = ctx_server.queue_results.recv(task.id); ctx_server.queue_results.remove_waiting_task_id(task.id); + if (result->is_error()) { + res_error(res, result->to_json()); + return; + } + + // TODO: get rid of this dynamic_cast + auto res_metrics = dynamic_cast(result.get()); + GGML_ASSERT(res_metrics != nullptr); + // optionally return "fail_on_no_slot" error - const int n_idle_slots = result.data.at("idle"); if (req.has_param("fail_on_no_slot")) { - if (n_idle_slots == 0) { + if (res_metrics->n_idle_slots == 0) { res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); return; } } - res_ok(res, result.data.at("slots")); + res_ok(res, res_metrics->slots_data); }; const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { @@ -2658,83 +3568,77 @@ int main(int argc, char ** argv) { } // request slots data using task queue - server_task task; + server_task task(SERVER_TASK_TYPE_METRICS); task.id = ctx_server.queue_tasks.get_new_id(); - task.id_target = -1; - task.type = SERVER_TASK_TYPE_METRICS; - task.data.push_back({{"reset_bucket", true}}); + task.metrics_reset_bucket = true; ctx_server.queue_results.add_waiting_task_id(task.id); ctx_server.queue_tasks.post(task, true); // high-priority task // get the result - server_task_result result = ctx_server.queue_results.recv(task.id); + server_task_result_ptr result = ctx_server.queue_results.recv(task.id); ctx_server.queue_results.remove_waiting_task_id(task.id); - json data = result.data; + if (result->is_error()) { + res_error(res, result->to_json()); + return; + } - const uint64_t n_prompt_tokens_processed = data.at("n_prompt_tokens_processed"); - const uint64_t t_prompt_processing = data.at("t_prompt_processing"); - - const uint64_t n_tokens_predicted = data.at("n_tokens_predicted"); - const uint64_t t_tokens_generation = data.at("t_tokens_generation"); - - const uint64_t n_decode_total = data.at("n_decode_total"); - const uint64_t n_busy_slots_total = data.at("n_busy_slots_total"); - - const int32_t kv_cache_used_cells = data.at("kv_cache_used_cells"); + // TODO: get rid of this dynamic_cast + auto res_metrics = dynamic_cast(result.get()); + GGML_ASSERT(res_metrics != nullptr); // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names json all_metrics_def = json { {"counter", {{ {"name", "prompt_tokens_total"}, {"help", "Number of prompt tokens processed."}, - {"value", (uint64_t) data.at("n_prompt_tokens_processed_total")} + {"value", (uint64_t) res_metrics->n_prompt_tokens_processed_total} }, { {"name", "prompt_seconds_total"}, {"help", "Prompt process time"}, - {"value", (uint64_t) data.at("t_prompt_processing_total") / 1.e3} + {"value", (uint64_t) res_metrics->t_prompt_processing_total / 1.e3} }, { {"name", "tokens_predicted_total"}, {"help", "Number of generation tokens processed."}, - {"value", (uint64_t) data.at("n_tokens_predicted_total")} + {"value", (uint64_t) res_metrics->n_tokens_predicted_total} }, { {"name", "tokens_predicted_seconds_total"}, {"help", "Predict process time"}, - {"value", (uint64_t) data.at("t_tokens_generation_total") / 1.e3} + {"value", (uint64_t) res_metrics->t_tokens_generation_total / 1.e3} }, { {"name", "n_decode_total"}, {"help", "Total number of llama_decode() calls"}, - {"value", n_decode_total} + {"value", res_metrics->n_decode_total} }, { {"name", "n_busy_slots_per_decode"}, {"help", "Average number of busy slots per llama_decode() call"}, - {"value", (float) n_busy_slots_total / (float) n_decode_total} + {"value", (float) res_metrics->n_busy_slots_total / (float) res_metrics->n_decode_total} }}}, {"gauge", {{ {"name", "prompt_tokens_seconds"}, {"help", "Average prompt throughput in tokens/s."}, - {"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.} + {"value", res_metrics->n_prompt_tokens_processed ? 1.e3 / res_metrics->t_prompt_processing * res_metrics->n_prompt_tokens_processed : 0.} },{ {"name", "predicted_tokens_seconds"}, {"help", "Average generation throughput in tokens/s."}, - {"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.} + {"value", res_metrics->n_tokens_predicted ? 1.e3 / res_metrics->t_tokens_generation * res_metrics->n_tokens_predicted : 0.} },{ {"name", "kv_cache_usage_ratio"}, {"help", "KV-cache usage. 1 means 100 percent usage."}, - {"value", 1. * kv_cache_used_cells / params.n_ctx} + {"value", 1. * res_metrics->kv_cache_used_cells / params.n_ctx} },{ {"name", "kv_cache_tokens"}, {"help", "KV-cache tokens."}, - {"value", (uint64_t) data.at("kv_cache_tokens_count")} + {"value", (uint64_t) res_metrics->kv_cache_tokens_count} },{ {"name", "requests_processing"}, {"help", "Number of request processing."}, - {"value", (uint64_t) data.at("processing")} + {"value", (uint64_t) res_metrics->n_processing_slots} },{ {"name", "requests_deferred"}, {"help", "Number of request deferred."}, - {"value", (uint64_t) data.at("deferred")} + {"value", (uint64_t) res_metrics->n_tasks_deferred} }}} }; @@ -2755,8 +3659,7 @@ int main(int argc, char ** argv) { } } - const int64_t t_start = data.at("t_start"); - res.set_header("Process-Start-Time-Unix", std::to_string(t_start)); + res.set_header("Process-Start-Time-Unix", std::to_string(res_metrics->t_start)); res.set_content(prometheus.str(), "text/plain; version=0.0.4"); res.status = 200; // HTTP OK @@ -2771,25 +3674,24 @@ int main(int argc, char ** argv) { } std::string filepath = params.slot_save_path + filename; - server_task task; - task.type = SERVER_TASK_TYPE_SLOT_SAVE; - task.data = { - { "id_slot", id_slot }, - { "filename", filename }, - { "filepath", filepath }, - }; + server_task task(SERVER_TASK_TYPE_SLOT_SAVE); + task.id = ctx_server.queue_tasks.get_new_id(); + task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; - const int id_task = ctx_server.queue_tasks.post(task); - ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(task); - server_task_result result = ctx_server.queue_results.recv(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); + server_task_result_ptr result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); - if (result.error) { - res_error(res, result.data); - } else { - res_ok(res, result.data); + if (result->is_error()) { + res_error(res, result->to_json()); + return; } + + res_ok(res, result->to_json()); }; const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { @@ -2801,45 +3703,45 @@ int main(int argc, char ** argv) { } std::string filepath = params.slot_save_path + filename; - server_task task; - task.type = SERVER_TASK_TYPE_SLOT_RESTORE; - task.data = { - { "id_slot", id_slot }, - { "filename", filename }, - { "filepath", filepath }, - }; + server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); + task.id = ctx_server.queue_tasks.get_new_id(); + task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; - const int id_task = ctx_server.queue_tasks.post(task); - ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(task); - server_task_result result = ctx_server.queue_results.recv(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); + server_task_result_ptr result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); - if (result.error) { - res_error(res, result.data); - } else { - res_ok(res, result.data); + if (result->is_error()) { + res_error(res, result->to_json()); + return; } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res_ok(res, result->to_json()); }; const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { - server_task task; - task.type = SERVER_TASK_TYPE_SLOT_ERASE; - task.data = { - { "id_slot", id_slot }, - }; + server_task task(SERVER_TASK_TYPE_SLOT_ERASE); + task.id = ctx_server.queue_tasks.get_new_id(); + task.slot_action.slot_id = id_slot; - const int id_task = ctx_server.queue_tasks.post(task); - ctx_server.queue_results.add_waiting_task_id(id_task); + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(task); - server_task_result result = ctx_server.queue_results.recv(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); + server_task_result_ptr result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); - if (result.error) { - res_error(res, result.data); - } else { - res_ok(res, result.data); + if (result->is_error()) { + res_error(res, result->to_json()); + return; } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res_ok(res, result->to_json()); }; const auto handle_slots_action = [¶ms, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { @@ -2872,31 +3774,84 @@ int main(int argc, char ** argv) { }; const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { - std::string template_key = "tokenizer.chat_template", curr_tmpl; - int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0); - if (tlen > 0) { - std::vector curr_tmpl_buf(tlen + 1, 0); - if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) { - curr_tmpl = std::string(curr_tmpl_buf.data(), tlen); - } - } + // this endpoint is publicly available, please only return what is safe to be exposed json data = { - { "system_prompt", ctx_server.system_prompt.c_str() }, { "default_generation_settings", ctx_server.default_generation_settings_for_props }, - { "total_slots", ctx_server.params.n_parallel }, - { "chat_template", curr_tmpl.c_str() }, + { "total_slots", ctx_server.params_base.n_parallel }, + { "model_path", ctx_server.params_base.model }, + { "chat_template", ctx_server.chat_templates.template_default->source() }, + { "bos_token", ctx_server.chat_templates.template_default->bos_token() }, + { "eos_token", ctx_server.chat_templates.template_default->eos_token() }, + { "build_info", build_info }, }; + if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) { + data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source(); + } res_ok(res, data); }; - const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) { - if (ctx_server.params.embedding) { + const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { + if (!ctx_server.params_base.endpoint_props) { + res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + + json data = json::parse(req.body); + + // update any props here + + res_ok(res, {{ "success", true }}); + }; + + // handle completion-like requests (completion, chat, infill) + // we can optionally provide a custom format for partial results and final results + const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok]( + server_task_type type, + json & data, + std::function is_connection_closed, + httplib::Response & res, + oaicompat_type oaicompat) { + GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); + + if (ctx_server.params_base.embedding) { res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); return; } - std::vector tasks = ctx_server.create_tasks_cmpl(data, cmpl_type); + auto completion_id = gen_chatcmplid(); + std::vector tasks; + + try { + const auto & prompt = data.at("prompt"); + LOG_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); + tasks.reserve(tokenized_prompts.size()); + for (size_t i = 0; i < tokenized_prompts.size(); i++) { + server_task task = server_task(type); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + + task.prompt_tokens = std::move(tokenized_prompts[i]); + task.params = server_task::params_from_json_cmpl( + ctx_server.ctx, + ctx_server.params_base, + data); + task.id_selected_slot = json_value(data, "id_slot", -1); + + // OAI-compat + task.params.oaicompat = oaicompat; + task.params.oaicompat_cmpl_id = completion_id; + // oaicompat_model is already populated by params_from_json_cmpl + + tasks.push_back(task); + } + } catch (const std::exception & e) { + res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return; + } + ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); @@ -2904,99 +3859,188 @@ int main(int argc, char ** argv) { const auto task_ids = server_task::get_list_id(tasks); if (!stream) { - ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { + ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { if (results.size() == 1) { // single result - res_ok(res, results[0].data); + res_ok(res, results[0]->to_json()); } else { // multiple results (multitask) json arr = json::array(); - for (const auto & res : results) { - arr.push_back(res.data); + for (auto & res : results) { + arr.push_back(res->to_json()); } res_ok(res, arr); } - }, [&](json error_data) { + }, [&](const json & error_data) { res_error(res, error_data); - }); + }, is_connection_closed); + + ctx_server.queue_results.remove_waiting_task_ids(task_ids); } else { - const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) { - ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool { - return server_sent_event(sink, "data", result.data); - }, [&](json error_data) { + const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat](size_t, httplib::DataSink & sink) { + ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool { + json res_json = result->to_json(); + if (res_json.is_array()) { + for (const auto & res : res_json) { + if (!server_sent_event(sink, "data", res)) { + // sending failed (HTTP connection closed), cancel the generation + return false; + } + } + return true; + } else { + return server_sent_event(sink, "data", res_json); + } + }, [&](const json & error_data) { server_sent_event(sink, "error", error_data); + }, [&sink]() { + // note: do not use req.is_connection_closed here because req is already destroyed + return !sink.is_writable(); }); + if (oaicompat != OAICOMPAT_TYPE_NONE) { + static const std::string ev_done = "data: [DONE]\n\n"; + sink.write(ev_done.data(), ev_done.size()); + } sink.done(); return false; }; - res.set_chunked_content_provider("text/event-stream", chunked_content_provider); + + auto on_complete = [task_ids, &ctx_server] (bool) { + ctx_server.queue_results.remove_waiting_task_ids(task_ids); + }; + + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); } }; - const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) { + const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { json data = json::parse(req.body); - return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res); + return handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + data, + req.is_connection_closed, + res, + OAICOMPAT_TYPE_NONE); }; - const auto handle_infill = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) { - json data = json::parse(req.body); - return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res); + const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + json data = oaicompat_completion_params_parse(json::parse(req.body)); + return handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + data, + req.is_connection_closed, + res, + OAICOMPAT_TYPE_COMPLETION); }; - // TODO: maybe merge this function with "handle_completions_generic" - const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { - if (ctx_server.params.embedding) { + const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + // check model compatibility + std::string err; + if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "prefix token is missing. "; + } + if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "suffix token is missing. "; + } + if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "middle token is missing. "; + } + if (!err.empty()) { + res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED)); + return; + } + + json data = json::parse(req.body); + + // validate input + if (data.contains("prompt") && !data.at("prompt").is_string()) { + // prompt is optional + res_error(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + } + + if (!data.contains("input_prefix")) { + res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST)); + } + + if (!data.contains("input_suffix")) { + res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST)); + } + + if (data.contains("input_extra") && !data.at("input_extra").is_array()) { + // input_extra is optional + res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + json input_extra = json_value(data, "input_extra", json::array()); + for (const auto & chunk : input_extra) { + // { "text": string, "filename": string } + if (!chunk.contains("text") || !chunk.at("text").is_string()) { + res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST)); + return; + } + // filename is optional + if (chunk.contains("filename") && !chunk.at("filename").is_string()) { + res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } + data["input_extra"] = input_extra; // default to empty array if it's not exist + + std::string prompt = json_value(data, "prompt", std::string()); + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, false, true); + SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); + data["prompt"] = format_infill( + ctx_server.vocab, + data.at("input_prefix"), + data.at("input_suffix"), + data.at("input_extra"), + ctx_server.params_base.n_batch, + ctx_server.params_base.n_predict, + ctx_server.slots[0].n_ctx, // TODO: there should be a better way + ctx_server.params_base.spm_infill, + tokenized_prompts[0] + ); + + return handle_completions_impl( + SERVER_TASK_TYPE_INFILL, + data, + req.is_connection_closed, + res, + OAICOMPAT_TYPE_NONE); // infill is not OAI compatible + }; + + const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { + LOG_DBG("request: %s\n", req.body.c_str()); + if (ctx_server.params_base.embedding) { res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); return; } - json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); + auto body = json::parse(req.body); + json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates); - std::vector tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL); - ctx_server.queue_results.add_waiting_tasks(tasks); - ctx_server.queue_tasks.post(tasks); - - bool stream = json_value(data, "stream", false); - const auto task_ids = server_task::get_list_id(tasks); - const auto completion_id = gen_chatcmplid(); - - if (!stream) { - ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { - // multitask is never support in chat completion, there is only one result - json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id); - res_ok(res, result_oai); - }, [&](json error_data) { - res_error(res, error_data); - }); - } else { - const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) { - ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool { - std::vector result_array = format_partial_response_oaicompat(result.data, completion_id); - for (auto & event_data : result_array) { - if (event_data.empty()) { - continue; // skip the stop token - } - if (!server_sent_event(sink, "data", event_data)) { - return false; // connection is closed - } - } - return true; // ok - }, [&](json error_data) { - server_sent_event(sink, "error", error_data); - }); - sink.done(); - return true; - }; - res.set_chunked_content_provider("text/event-stream", chunked_content_provider); - } + return handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + data, + req.is_connection_closed, + res, + OAICOMPAT_TYPE_CHAT); }; - const auto handle_models = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) { + // same with handle_chat_completions, but without inference part + const auto handle_apply_template = [&ctx_server, ¶ms, &res_ok](const httplib::Request & req, httplib::Response & res) { + auto body = json::parse(req.body); + json data = oaicompat_completion_params_parse(body, params.use_jinja, ctx_server.chat_templates); + res_ok(res, {{ "prompt", std::move(data.at("prompt")) }}); + }; + + const auto handle_models = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { json models = { {"object", "list"}, {"data", { { - {"id", params.model_alias}, + {"id", params.model_alias.empty() ? params.model : params.model_alias}, {"object", "model"}, {"created", std::time(0)}, {"owned_by", "llamacpp"}, @@ -3005,18 +4049,46 @@ int main(int argc, char ** argv) { }} }; - res.set_content(models.dump(), MIMETYPE_JSON); + res_ok(res, models); }; const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) { const json body = json::parse(req.body); - std::vector tokens; + json tokens_response = json::array(); if (body.count("content") != 0) { const bool add_special = json_value(body, "add_special", false); - tokens = ctx_server.tokenize(body.at("content"), add_special); + const bool with_pieces = json_value(body, "with_pieces", false); + + llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, true); + + if (with_pieces) { + for (const auto& token : tokens) { + std::string piece = common_token_to_piece(ctx_server.ctx, token); + json piece_json; + + // Check if the piece is valid UTF-8 + if (is_valid_utf8(piece)) { + piece_json = piece; + } else { + // If not valid UTF-8, store as array of byte values + piece_json = json::array(); + for (unsigned char c : piece) { + piece_json.push_back(static_cast(c)); + } + } + + tokens_response.push_back({ + {"id", token}, + {"piece", piece_json} + }); + } + } else { + tokens_response = tokens; + } } - const json data = format_tokenizer_response(tokens); + + const json data = format_tokenizer_response(tokens_response); res_ok(res, data); }; @@ -3025,7 +4097,7 @@ int main(int argc, char ** argv) { std::string content; if (body.count("tokens") != 0) { - const std::vector tokens = body.at("tokens"); + const llama_tokens tokens = body.at("tokens"); content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend()); } @@ -3033,42 +4105,81 @@ int main(int argc, char ** argv) { res_ok(res, data); }; - const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { + const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) { const json body = json::parse(req.body); - bool is_openai = false; - // an input prompt can be a string or a list of tokens (integer) + if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + // for the shape of input/content, see tokenize_input_prompts() json prompt; if (body.count("input") != 0) { - is_openai = true; prompt = body.at("input"); - } else if (body.count("content") != 0) { - // with "content", we only support single prompt - prompt = std::vector{body.at("content")}; + } else if (body.contains("content")) { + oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible + prompt = body.at("content"); } else { res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); return; } + bool use_base64 = false; + if (body.count("encoding_format") != 0) { + const std::string& format = body.at("encoding_format"); + if (format == "base64") { + use_base64 = true; + } else if (format != "float") { + res_error(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } + + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true); + for (const auto & tokens : tokenized_prompts) { + // this check is necessary for models that do not add BOS token to the input + if (tokens.empty()) { + res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } + // create and queue the task json responses = json::array(); bool error = false; { - std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING); + std::vector tasks; + for (size_t i = 0; i < tokenized_prompts.size(); i++) { + server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.prompt_tokens = std::move(tokenized_prompts[i]); + + // OAI-compat + task.params.oaicompat = oaicompat; + + tasks.push_back(task); + } + ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); // get the result std::unordered_set task_ids = server_task::get_list_id(tasks); - ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { - for (const auto & res : results) { - responses.push_back(res.data); + ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { + for (auto & res : results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + responses.push_back(res->to_json()); } - }, [&](json error_data) { + }, [&](const json & error_data) { res_error(res, error_data); error = true; - }); + }, req.is_connection_closed); + + ctx_server.queue_results.remove_waiting_task_ids(task_ids); } if (error) { @@ -3076,20 +4187,107 @@ int main(int argc, char ** argv) { } // write JSON response - json root = is_openai - ? format_embeddings_response_oaicompat(body, responses) - : responses[0]; + json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING + ? format_embeddings_response_oaicompat(body, responses, use_base64) + : json(responses); + res_ok(res, root); + }; + + const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) { + handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE); + }; + + const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) { + handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING); + }; + + const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { + if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) { + res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED)); + return; + } + + const json body = json::parse(req.body); + + // TODO: implement + //int top_n = 1; + //if (body.count("top_n") != 1) { + // top_n = body.at("top_n"); + //} else { + // res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + // return; + //} + + json query; + if (body.count("query") == 1) { + query = body.at("query"); + if (!query.is_string()) { + res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + return; + } + } else { + res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + std::vector documents = json_value(body, "documents", std::vector()); + if (documents.empty()) { + res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); + return; + } + + llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, query, /* add_special */ false, true)[0]; + + // create and queue the task + json responses = json::array(); + bool error = false; + { + std::vector tasks; + std::vector tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true); + tasks.reserve(tokenized_docs.size()); + for (size_t i = 0; i < tokenized_docs.size(); i++) { + server_task task = server_task(SERVER_TASK_TYPE_RERANK); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.prompt_tokens = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]); + tasks.push_back(task); + } + + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(tasks); + + // get the result + std::unordered_set task_ids = server_task::get_list_id(tasks); + + ctx_server.receive_multi_results(task_ids, [&](std::vector & results) { + for (auto & res : results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + responses.push_back(res->to_json()); + } + }, [&](const json & error_data) { + res_error(res, error_data); + error = true; + }, req.is_connection_closed); + } + + if (error) { + return; + } + + // write JSON response + json root = format_response_rerank(body, responses); res_ok(res, root); }; const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { json result = json::array(); - for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) { - auto & la = ctx_server.lora_adapters[i]; + const auto & loras = ctx_server.params_base.lora_adapters; + for (size_t i = 0; i < loras.size(); ++i) { + auto & lora = loras[i]; result.push_back({ {"id", i}, - {"path", la.path}, - {"scale", la.scale}, + {"path", lora.path}, + {"scale", lora.scale}, }); } res_ok(res, result); @@ -3097,89 +4295,81 @@ int main(int argc, char ** argv) { }; const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) { - const std::vector body = json::parse(req.body); - int max_idx = ctx_server.lora_adapters.size(); + const json body = json::parse(req.body); + if (!body.is_array()) { + res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); + return; + } + server_task task(SERVER_TASK_TYPE_SET_LORA); + task.id = ctx_server.queue_tasks.get_new_id(); + task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body); + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(task); - // clear existing value - for (auto & la : ctx_server.lora_adapters) { - la.scale = 0.0f; + server_task_result_ptr result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); + + if (result->is_error()) { + res_error(res, result->to_json()); + return; } - // set value - for (auto entry : body) { - int id = entry.at("id"); - float scale = entry.at("scale"); - if (0 <= id && id < max_idx) { - ctx_server.lora_adapters[id].scale = scale; - } else { - throw std::runtime_error("invalid adapter id"); - } - } - - server_task task; - task.type = SERVER_TASK_TYPE_SET_LORA; - const int id_task = ctx_server.queue_tasks.post(task); - ctx_server.queue_results.add_waiting_task_id(id_task); - - server_task_result result = ctx_server.queue_results.recv(id_task); - ctx_server.queue_results.remove_waiting_task_id(id_task); - - res_ok(res, result.data); - res.status = 200; // HTTP OK - }; - - auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) { - return [content, len, mime_type](const httplib::Request &, httplib::Response & res) { - res.set_content(reinterpret_cast(content), len, mime_type); - return false; - }; + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res_ok(res, result->to_json()); }; // // Router // - // register static assets routes - if (!params.public_path.empty()) { - // Set the base directory for serving static files - svr->set_base_dir(params.public_path); + if (!params.webui) { + LOG_INF("Web UI is disabled\n"); + } else { + // register static assets routes + if (!params.public_path.empty()) { + // Set the base directory for serving static files + bool is_found = svr->set_mount_point("/", params.public_path); + if (!is_found) { + LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str()); + return 1; + } + } else { + // using embedded static index.html + svr->Get("/", [](const httplib::Request & req, httplib::Response & res) { + if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) { + res.set_content("Error: gzip is not supported by this browser", "text/plain"); + } else { + res.set_header("Content-Encoding", "gzip"); + res.set_content(reinterpret_cast(index_html_gz), index_html_gz_len, "text/html; charset=utf-8"); + } + return false; + }); + } } - // using embedded static files - svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8")); - svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8")); - svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8")); - svr->Get("/json-schema-to-grammar.mjs", handle_static_file(json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8")); - - // add new-ui files - svr->Get("/colorthemes.css", handle_static_file(colorthemes_css, colorthemes_css_len, "text/css; charset=utf-8")); - svr->Get("/style.css", handle_static_file(style_css, style_css_len, "text/css; charset=utf-8")); - svr->Get("/theme-beeninorder.css", handle_static_file(theme_beeninorder_css, theme_beeninorder_css_len, "text/css; charset=utf-8")); - svr->Get("/theme-ketivah.css", handle_static_file(theme_ketivah_css, theme_ketivah_css_len, "text/css; charset=utf-8")); - svr->Get("/theme-mangotango.css", handle_static_file(theme_mangotango_css, theme_mangotango_css_len, "text/css; charset=utf-8")); - svr->Get("/theme-playground.css", handle_static_file(theme_playground_css, theme_playground_css_len, "text/css; charset=utf-8")); - svr->Get("/theme-polarnight.css", handle_static_file(theme_polarnight_css, theme_polarnight_css_len, "text/css; charset=utf-8")); - svr->Get("/theme-snowstorm.css", handle_static_file(theme_snowstorm_css, theme_snowstorm_css_len, "text/css; charset=utf-8")); - svr->Get("/index-new.html", handle_static_file(index_new_html, index_new_html_len, "text/html; charset=utf-8")); - svr->Get("/system-prompts.js", handle_static_file(system_prompts_js, system_prompts_js_len, "text/javascript; charset=utf-8")); - svr->Get("/prompt-formats.js", handle_static_file(prompt_formats_js, prompt_formats_js_len, "text/javascript; charset=utf-8")); - // register API routes - svr->Get ("/health", handle_health); + svr->Get ("/health", handle_health); // public endpoint (no API key check) svr->Get ("/metrics", handle_metrics); svr->Get ("/props", handle_props); - svr->Get ("/v1/models", handle_models); + svr->Post("/props", handle_props_change); + svr->Get ("/models", handle_models); // public endpoint (no API key check) + svr->Get ("/v1/models", handle_models); // public endpoint (no API key check) svr->Post("/completion", handle_completions); // legacy svr->Post("/completions", handle_completions); - svr->Post("/v1/completions", handle_completions); + svr->Post("/v1/completions", handle_completions_oai); svr->Post("/chat/completions", handle_chat_completions); svr->Post("/v1/chat/completions", handle_chat_completions); svr->Post("/infill", handle_infill); svr->Post("/embedding", handle_embeddings); // legacy svr->Post("/embeddings", handle_embeddings); - svr->Post("/v1/embeddings", handle_embeddings); + svr->Post("/v1/embeddings", handle_embeddings_oai); + svr->Post("/rerank", handle_rerank); + svr->Post("/reranking", handle_rerank); + svr->Post("/v1/rerank", handle_rerank); + svr->Post("/v1/reranking", handle_rerank); svr->Post("/tokenize", handle_tokenize); svr->Post("/detokenize", handle_detokenize); + svr->Post("/apply-template", handle_apply_template); // LoRA adapters hotswap svr->Get ("/lora-adapters", handle_lora_adapters_list); svr->Post("/lora-adapters", handle_lora_adapters_apply); @@ -3203,61 +4393,77 @@ int main(int argc, char ** argv) { llama_backend_free(); }; - // bind HTTP listen port, run the HTTP server in a thread - if (!svr->bind_to_port(params.hostname, params.port)) { - LOG_ERROR("couldn't bind HTTP server socket", { - {"hostname", params.hostname}, - {"port", params.port}, - }); + // bind HTTP listen port + bool was_bound = false; + if (params.port == 0) { + int bound_port = svr->bind_to_any_port(params.hostname); + if ((was_bound = (bound_port >= 0))) { + params.port = bound_port; + } + } else { + was_bound = svr->bind_to_port(params.hostname, params.port); + } + + if (!was_bound) { + //LOG_ERROR("couldn't bind HTTP server socket", { + // {"hostname", params.hostname}, + // {"port", params.port}, + //}); + LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port); clean_up(); - LOG_ERROR("exiting due to HTTP server error", {}); return 1; } + + // run the HTTP server in a thread std::thread t([&]() { svr->listen_after_bind(); }); svr->wait_until_ready(); - LOG_INFO("HTTP server is listening", log_data); + LOG_INF("%s: HTTP server is listening, hostname: %s, port: %d, http threads: %d\n", __func__, params.hostname.c_str(), params.port, params.n_threads_http); // load the model - LOG_INFO("loading model", log_data); + LOG_INF("%s: loading model\n", __func__); + if (!ctx_server.load_model(params)) { clean_up(); t.join(); - LOG_ERROR("exiting due to model loading error", {}); + LOG_ERR("%s: exiting due to model loading error\n", __func__); return 1; - } else { - ctx_server.init(); - state.store(SERVER_STATE_READY); - - LOG_INFO("model loaded", {}); - - // if a custom chat template is not supplied, we will use the one that comes with the model (if any) - if (params.chat_template.empty()) { - if (!ctx_server.validate_model_chat_template()) { - LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {}); - params.chat_template = "chatml"; - } - } - - // print sample chat example to make it clear which template is used - { - LOG_INFO("chat template", { - {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)}, - {"built_in", params.chat_template.empty()}, - }); - } - - ctx_server.queue_tasks.on_new_task(std::bind( - &server_context::process_single_task, &ctx_server, std::placeholders::_1)); - ctx_server.queue_tasks.on_update_slots(std::bind( - &server_context::update_slots, &ctx_server)); - - shutdown_handler = [&](int) { - ctx_server.queue_tasks.terminate(); - }; - ctx_server.queue_tasks.start_loop(); } + ctx_server.init(); + state.store(SERVER_STATE_READY); + + LOG_INF("%s: model loaded\n", __func__); + + // if a custom chat template is not supplied, we will use the one that comes with the model (if any) + if (params.chat_template.empty()) { + if (!ctx_server.validate_builtin_chat_template(params.use_jinja)) { + LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); + params.chat_template = "chatml"; + } + } + + // print sample chat example to make it clear which template is used + LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, + ctx_server.chat_templates.template_default->source().c_str(), + common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str()); + + ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) { + ctx_server.process_single_task(task); + }); + + ctx_server.queue_tasks.on_update_slots([&ctx_server]() { + ctx_server.update_slots(); + }); + + shutdown_handler = [&](int) { + ctx_server.queue_tasks.terminate(); + }; + + LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port); + + ctx_server.queue_tasks.start_loop(); + #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) struct sigaction sigint_action; sigint_action.sa_handler = signal_handler; diff --git a/examples/server/tests/.gitignore b/examples/server/tests/.gitignore new file mode 100644 index 000000000..90ee7fe6d --- /dev/null +++ b/examples/server/tests/.gitignore @@ -0,0 +1,2 @@ +.venv +tmp diff --git a/examples/server/tests/README.md b/examples/server/tests/README.md index 5e6cb277b..1de0eb30e 100644 --- a/examples/server/tests/README.md +++ b/examples/server/tests/README.md @@ -1,19 +1,9 @@ # Server tests -Python based server tests scenario using [BDD](https://en.wikipedia.org/wiki/Behavior-driven_development) -and [behave](https://behave.readthedocs.io/en/latest/): - -* [issues.feature](./features/issues.feature) Pending issues scenario -* [parallel.feature](./features/parallel.feature) Scenario involving multi slots and concurrent requests -* [security.feature](./features/security.feature) Security, CORS and API Key -* [server.feature](./features/server.feature) Server base scenario: completion, embedding, tokenization, etc... +Python based server tests scenario using [pytest](https://docs.pytest.org/en/stable/). Tests target GitHub workflows job runners with 4 vCPU. -Requests are -using [aiohttp](https://docs.aiohttp.org/en/stable/client_reference.html), [asyncio](https://docs.python.org/fr/3/library/asyncio.html) -based http client. - Note: If the host architecture inference speed is faster than GitHub runners one, parallel scenario may randomly fail. To mitigate it, you can increase values in `n_predict`, `kv_size`. @@ -39,27 +29,38 @@ It's possible to override some scenario steps values with environment variables: |--------------------------|------------------------------------------------------------------------------------------------| | `PORT` | `context.server_port` to set the listening port of the server during scenario, default: `8080` | | `LLAMA_SERVER_BIN_PATH` | to change the server binary path, default: `../../../build/bin/llama-server` | -| `DEBUG` | "ON" to enable steps and server verbose mode `--verbose` | -| `SERVER_LOG_FORMAT_JSON` | if set switch server logs to json format | +| `DEBUG` | to enable steps and server verbose mode `--verbose` | | `N_GPU_LAYERS` | number of model layers to offload to VRAM `-ngl --n-gpu-layers` | +| `LLAMA_CACHE` | by default server tests re-download models to the `tmp` subfolder. Set this to your cache (e.g. `$HOME/Library/Caches/llama.cpp` on Mac or `$HOME/.cache/llama.cpp` on Unix) to avoid this | -### Run @bug, @wip or @wrong_usage annotated scenario - -Feature or Scenario must be annotated with `@llama.cpp` to be included in the default scope. - -- `@bug` annotation aims to link a scenario with a GitHub issue. -- `@wrong_usage` are meant to show user issue that are actually an expected behavior -- `@wip` to focus on a scenario working in progress -- `@slow` heavy test, disabled by default - -To run a scenario annotated with `@bug`, start: +To run slow tests (will download many models, make sure to set `LLAMA_CACHE` if needed): ```shell -DEBUG=ON ./tests.sh --no-skipped --tags bug --stop +SLOW_TESTS=1 ./tests.sh ``` -After changing logic in `steps.py`, ensure that `@bug` and `@wrong_usage` scenario are updated. +To run with stdout/stderr display in real time (verbose output, but useful for debugging): ```shell -./tests.sh --no-skipped --tags bug,wrong_usage || echo "should failed but compile" +DEBUG=1 ./tests.sh -s -v -x ``` + +To run all the tests in a file: + +```shell +./tests.sh unit/test_chat_completion.py.py -v -x +``` + +To run a single test: + +```shell +./tests.sh unit/test_chat_completion.py::test_invalid_chat_completion_req +``` + +Hint: You can compile and run test in single command, useful for local developement: + +```shell +cmake --build build -j --target llama-server && ./examples/server/tests/tests.sh +``` + +To see all available arguments, please refer to [pytest documentation](https://docs.pytest.org/en/stable/how-to/usage.html) diff --git a/examples/server/tests/conftest.py b/examples/server/tests/conftest.py new file mode 100644 index 000000000..017d1bb84 --- /dev/null +++ b/examples/server/tests/conftest.py @@ -0,0 +1,15 @@ +import pytest +from utils import * + + +# ref: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test +@pytest.fixture(autouse=True) +def stop_server_after_each_test(): + # do nothing before each test + yield + # stop all servers after each test + instances = set( + server_instances + ) # copy the set to prevent 'Set changed size during iteration' + for server in instances: + server.stop() diff --git a/examples/server/tests/features/embeddings.feature b/examples/server/tests/features/embeddings.feature deleted file mode 100644 index e1eade6cd..000000000 --- a/examples/server/tests/features/embeddings.feature +++ /dev/null @@ -1,99 +0,0 @@ -@llama.cpp -@embeddings -Feature: llama.cpp server - - Background: Server startup - Given a server listening on localhost:8080 - And a model url https://huggingface.co/ggml-org/models/resolve/main/bert-bge-small/ggml-model-f16.gguf - And a model file bert-bge-small.gguf - And a model alias bert-bge-small - And 42 as server seed - And 2 slots - # the bert-bge-small model has context size of 512 - # since the generated prompts are as big as the batch size, we need to set the batch size to 512 - # ref: https://huggingface.co/BAAI/bge-small-en-v1.5/blob/5c38ec7c405ec4b44b94cc5a9bb96e735b38267a/config.json#L20 - And 512 as batch size - And 512 as ubatch size - And 2048 KV cache size - And embeddings extraction - Then the server is starting - Then the server is healthy - - Scenario: Embedding - When embeddings are computed for: - """ - What is the capital of Bulgaria ? - """ - Then embeddings are generated - - Scenario: OAI Embeddings compatibility - Given a model bert-bge-small - When an OAI compatible embeddings computation request for: - """ - What is the capital of Spain ? - """ - Then embeddings are generated - - Scenario: OAI Embeddings compatibility with multiple inputs - Given a model bert-bge-small - Given a prompt: - """ - In which country Paris is located ? - """ - And a prompt: - """ - Is Madrid the capital of Spain ? - """ - When an OAI compatible embeddings computation request for multiple inputs - Then embeddings are generated - - Scenario: Multi users embeddings - Given a prompt: - """ - Write a very long story about AI. - """ - And a prompt: - """ - Write another very long music lyrics. - """ - And a prompt: - """ - Write a very long poem. - """ - And a prompt: - """ - Write a very long joke. - """ - Given concurrent embedding requests - Then the server is busy - Then the server is idle - Then all embeddings are generated - - Scenario: Multi users OAI compatibility embeddings - Given a prompt: - """ - In which country Paris is located ? - """ - And a prompt: - """ - Is Madrid the capital of Spain ? - """ - And a prompt: - """ - What is the biggest US city ? - """ - And a prompt: - """ - What is the capital of Bulgaria ? - """ - And a model bert-bge-small - Given concurrent OAI embedding requests - Then the server is busy - Then the server is idle - Then all embeddings are generated - - Scenario: All embeddings should be the same - Given 10 fixed prompts - And a model bert-bge-small - Given concurrent OAI embedding requests - Then all embeddings are the same diff --git a/examples/server/tests/features/environment.py b/examples/server/tests/features/environment.py deleted file mode 100644 index e7845dc2f..000000000 --- a/examples/server/tests/features/environment.py +++ /dev/null @@ -1,71 +0,0 @@ -import os -import signal -import socket -import sys -import time -import traceback -from contextlib import closing -from subprocess import TimeoutExpired - - -def before_scenario(context, scenario): - context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON' - if context.debug: - print("DEBUG=ON") - print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m") - port = 8080 - if 'PORT' in os.environ: - port = int(os.environ['PORT']) - if is_server_listening("localhost", port): - assert False, "Server already started" - - -def after_scenario(context, scenario): - try: - if 'server_process' not in context or context.server_process is None: - return - if scenario.status == "failed": - if 'GITHUB_ACTIONS' in os.environ: - print(f"\x1b[33;101mSCENARIO FAILED: {scenario.name} server logs:\x1b[0m\n") - if os.path.isfile('llama.log'): - with closing(open('llama.log', 'r')) as f: - for line in f: - print(line) - if not is_server_listening(context.server_fqdn, context.server_port): - print("\x1b[33;101mERROR: Server stopped listening\x1b[0m") - - if context.server_process.poll() is not None: - assert False, f"Server not running pid={context.server_process.pid} ..." - - server_graceful_shutdown(context) # SIGINT - - try: - context.server_process.wait(0.5) - except TimeoutExpired: - print(f"server still alive after 500ms, force-killing pid={context.server_process.pid} ...") - context.server_process.kill() # SIGKILL - context.server_process.wait() - - while is_server_listening(context.server_fqdn, context.server_port): - time.sleep(0.1) - except Exception: - print("ignoring error in after_scenario:") - traceback.print_exc(file=sys.stdout) - - -def server_graceful_shutdown(context): - print(f"shutting down server pid={context.server_process.pid} ...") - if os.name == 'nt': - interrupt = signal.CTRL_C_EVENT - else: - interrupt = signal.SIGINT - context.server_process.send_signal(interrupt) - - -def is_server_listening(server_fqdn, server_port): - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: - result = sock.connect_ex((server_fqdn, server_port)) - _is_server_listening = result == 0 - if _is_server_listening: - print(f"server is listening on {server_fqdn}:{server_port}...") - return _is_server_listening diff --git a/examples/server/tests/features/issues.feature b/examples/server/tests/features/issues.feature deleted file mode 100644 index 7b13e44ca..000000000 --- a/examples/server/tests/features/issues.feature +++ /dev/null @@ -1,5 +0,0 @@ -# List of ongoing issues -# run with: DEBUG=ON ./tests.sh --no-skipped --tags bug -@bug -Feature: Issues - # No confirmed issue at the moment diff --git a/examples/server/tests/features/lora.feature b/examples/server/tests/features/lora.feature deleted file mode 100644 index 7b85988ac..000000000 --- a/examples/server/tests/features/lora.feature +++ /dev/null @@ -1,36 +0,0 @@ -@llama.cpp -@lora -Feature: llama.cpp server - - Background: Server startup - Given a server listening on localhost:8080 - And a model url https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/stories15M_MOE-F16.gguf - And a model file stories15M_MOE-F16.gguf - And a model alias stories15M_MOE - And a lora adapter file from https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf - And 42 as server seed - And 1024 as batch size - And 1024 as ubatch size - And 2048 KV cache size - And 64 max tokens to predict - And 0.0 temperature - Then the server is starting - Then the server is healthy - - Scenario: Completion LoRA disabled - Given switch off lora adapter 0 - Given a prompt: - """ - Look in thy glass - """ - And a completion request with no api error - Then 64 tokens are predicted matching little|girl|three|years|old - - Scenario: Completion LoRA enabled - Given switch on lora adapter 0 - Given a prompt: - """ - Look in thy glass - """ - And a completion request with no api error - Then 64 tokens are predicted matching eye|love|glass|sun diff --git a/examples/server/tests/features/parallel.feature b/examples/server/tests/features/parallel.feature deleted file mode 100644 index 423d0f1d4..000000000 --- a/examples/server/tests/features/parallel.feature +++ /dev/null @@ -1,131 +0,0 @@ -@llama.cpp -@parallel -Feature: Parallel - - Background: Server startup - Given a server listening on localhost:8080 - And a model file tinyllamas/split/stories15M-00001-of-00003.gguf from HF repo ggml-org/models - And a model file test-model-00001-of-00003.gguf - And 42 as server seed - And 128 as batch size - And 256 KV cache size - And 2 slots - And continuous batching - Then the server is starting - Then the server is healthy - - Scenario Outline: Multi users completion - Given a prompt: - """ - Write a very long story about AI. - """ - And a prompt: - """ - Write another very long music lyrics. - """ - And max tokens to predict - Given concurrent completion requests - Then the server is busy - Then the server is idle - And all slots are idle - Then all prompts are predicted with tokens - Examples: - | n_predict | - | 128 | - - Scenario Outline: Multi users OAI completions compatibility - Given a system prompt You are a writer. - And a model tinyllama-2 - Given a prompt: - """ - Write a very long book. - """ - And a prompt: - """ - Write another a poem. - """ - And max tokens to predict - And streaming is - Given concurrent OAI completions requests - Then the server is busy - Then the server is idle - Then all prompts are predicted with tokens - Examples: - | streaming | n_predict | - | disabled | 128 | - | enabled | 64 | - - Scenario Outline: Multi users OAI completions compatibility no v1 - Given a system prompt You are a writer. - And a model tinyllama-2 - Given a prompt: - """ - Write a very long book. - """ - And a prompt: - """ - Write another a poem. - """ - And max tokens to predict - And streaming is - Given concurrent OAI completions requests no v1 - Then the server is busy - Then the server is idle - Then all prompts are predicted with tokens - Examples: - | streaming | n_predict | - | disabled | 128 | - | enabled | 64 | - - Scenario Outline: Multi users with number of prompts exceeding number of slots - Given a system prompt You are a writer. - And a model tinyllama-2 - Given a prompt: - """ - Write a very long book. - """ - And a prompt: - """ - Write another a poem. - """ - And a prompt: - """ - What is LLM? - """ - And a prompt: - """ - The sky is blue and I love it. - """ - And max tokens to predict - And streaming is - Given concurrent OAI completions requests - Then the server is busy - Then the server is idle - Then all prompts are predicted with tokens - Examples: - | streaming | n_predict | - | disabled | 128 | - | enabled | 64 | - - Scenario: Multi users with total number of tokens to predict exceeds the KV Cache size #3969 - Given a prompt: - """ - Write a very long story about AI. - """ - And a prompt: - """ - Write another very long music lyrics. - """ - And a prompt: - """ - Write a very long poem. - """ - And a prompt: - """ - Write a very long joke. - """ - And 128 max tokens to predict - Given concurrent completion requests - Then the server is busy - Then the server is idle - Then all prompts are predicted diff --git a/examples/server/tests/features/passkey.feature b/examples/server/tests/features/passkey.feature deleted file mode 100644 index ff0a82cc4..000000000 --- a/examples/server/tests/features/passkey.feature +++ /dev/null @@ -1,56 +0,0 @@ -# run with: ./tests.sh --no-skipped --tags passkey -@passkey -@slow -Feature: Passkey / Self-extend with context shift - - Background: Server startup - Given a server listening on localhost:8080 - - # Generates a long text of junk and inserts a secret passkey number inside it. - # Then we query the LLM for the secret passkey. - # see #3856 and #4810 - Scenario Outline: Passkey - Given a model file from HF repo - And as batch size - And as number of junk - And server max tokens to predict - And 42 as seed - And 0.0 temperature - And KV cache size - And 1 slots - And group attention factor to extend context size through self-extend - And group attention width to extend context size through self-extend - # Can be override with N_GPU_LAYERS - And GPU offloaded layers - Then the server is starting - # Higher timeout because the model may need to be downloaded from the internet - Then the server is healthy with timeout 120 seconds - Given available models - Then model 0 is trained on tokens context - Given a prefix prompt: - """ - here is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there. - """ - And a passkey prompt template: - """ - The pass key is Remember it. is the pass key. - """ - And a junk suffix prompt: - """ - The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. - """ - And a suffix prompt: - """ - What is the pass key? The pass key is - """ - Given a "" passkey challenge prompt with the passkey inserted every junk - And a completion request with no api error - Then tokens are predicted matching - - Examples: - | hf_repo | hf_file | n_ctx_train | ngl | n_ctx | n_batch | n_ga | n_ga_w | n_junk | i_pos | passkey | n_predicted | re_content | - | TheBloke/phi-2-GGUF | phi-2.Q4_K_M.gguf | 2048 | 5 | 8192 | 512 | 4 | 512 | 250 | 50 | 42 | 1 | 42 | - | TheBloke/phi-2-GGUF | phi-2.Q4_K_M.gguf | 2048 | 5 | 8192 | 512 | 2 | 512 | 250 | 50 | 42 | 1 | \b((?!42)\w)+\b | - #| TheBloke/Llama-2-7B-GGUF | llama-2-7b.Q2_K.gguf | 4096 | 3 | 16384 | 512 | 4 | 512 | 500 | 300 | 1234 | 5 | 1234 | - #| TheBloke/Mixtral-8x7B-v0.1-GGUF | mixtral-8x7b-v0.1.Q2_K.gguf | 32768 | 2 | 16384 | 512 | 4 | 512 | 500 | 100 | 0987 | 5 | 0 - # 987 | diff --git a/examples/server/tests/features/results.feature b/examples/server/tests/features/results.feature deleted file mode 100644 index e8e1b5414..000000000 --- a/examples/server/tests/features/results.feature +++ /dev/null @@ -1,118 +0,0 @@ -@llama.cpp -@results -Feature: Results - - Background: Server startup - Given a server listening on localhost:8080 - And a model file tinyllamas/split/stories15M-00001-of-00003.gguf from HF repo ggml-org/models - And a model file test-model-00001-of-00003.gguf - And 128 as batch size - And 1024 KV cache size - And 128 max tokens to predict - And continuous batching - - Scenario Outline: consistent results with same seed - Given slots - And 1.0 temperature - Then the server is starting - Then the server is healthy - - Given 4 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 42 - - Given concurrent completion requests - Then the server is busy - Then the server is idle - And all slots are idle - Then all predictions are equal - Examples: - | n_slots | - | 1 | - # FIXME: unified KV cache nondeterminism - # | 2 | - - Scenario Outline: different results with different seed - Given slots - And 1.0 temperature - Then the server is starting - Then the server is healthy - - Given 1 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 42 - Given 1 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 43 - Given 1 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 44 - Given 1 prompts "Title: Little Red Riding Hood But In Space\n\nSummary:" with seed 45 - - Given concurrent completion requests - Then the server is busy - Then the server is idle - And all slots are idle - Then all predictions are different - Examples: - | n_slots | - | 1 | - | 2 | - - Scenario Outline: consistent results with same seed and varying batch size - Given 4 slots - And temperature - # And 0 as draft - Then the server is starting - Then the server is healthy - - Given 1 prompts "Write a very long story about AI." with seed 42 - And concurrent completion requests - # Then the server is busy # Not all slots will be utilized. - Then the server is idle - And all slots are idle - - Given prompts "Write a very long story about AI." with seed 42 - And concurrent completion requests - # Then the server is busy # Not all slots will be utilized. - Then the server is idle - And all slots are idle - - Then all predictions are equal - Examples: - | n_parallel | temp | - | 1 | 0.0 | - | 1 | 1.0 | - # FIXME: unified KV cache nondeterminism - # See https://github.com/ggerganov/whisper.cpp/issues/1941#issuecomment-1986923227 - # and https://github.com/ggerganov/llama.cpp/pull/6122#discussion_r1531405574 - # and https://github.com/ggerganov/llama.cpp/pull/7347 . - # | 2 | 0.0 | - # | 4 | 0.0 | - # | 2 | 1.0 | - # | 4 | 1.0 | - - Scenario Outline: consistent token probs with same seed and prompt - Given slots - And KV cache size - And 1.0 temperature - And max tokens to predict - Then the server is starting - Then the server is healthy - - Given 1 prompts "The meaning of life is" with seed 42 - And concurrent completion requests - # Then the server is busy # Not all slots will be utilized. - Then the server is idle - And all slots are idle - - Given prompts "The meaning of life is" with seed 42 - And concurrent completion requests - # Then the server is busy # Not all slots will be utilized. - Then the server is idle - And all slots are idle - - Then all token probabilities are equal - Examples: - | n_slots | n_kv | n_predict | n_parallel | - | 4 | 1024 | 1 | 1 | - # FIXME: unified KV cache nondeterminism - # See https://github.com/ggerganov/whisper.cpp/issues/1941#issuecomment-1986923227 - # and https://github.com/ggerganov/llama.cpp/pull/6122#discussion_r1531405574 - # and https://github.com/ggerganov/llama.cpp/pull/7347 . - # | 4 | 1024 | 1 | 4 | - # | 4 | 1024 | 100 | 1 | - # This test still fails even the above patches; the first token probabilities are already different. - # | 4 | 1024 | 100 | 4 | diff --git a/examples/server/tests/features/security.feature b/examples/server/tests/features/security.feature deleted file mode 100644 index eb82e7aca..000000000 --- a/examples/server/tests/features/security.feature +++ /dev/null @@ -1,68 +0,0 @@ -@llama.cpp -@security -Feature: Security - - Background: Server startup with an api key defined - Given a server listening on localhost:8080 - And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models - And a server api key llama.cpp - Then the server is starting - Then the server is healthy - - Scenario Outline: Completion with some user api key - Given a prompt test - And a user api key - And 4 max tokens to predict - And a completion request with api error - - Examples: Prompts - | api_key | api_error | - | llama.cpp | no | - | llama.cpp | no | - | hackeme | raised | - | | raised | - - Scenario Outline: OAI Compatibility - Given a system prompt test - And a user prompt test - And a model test - And 2 max tokens to predict - And streaming is disabled - And a user api key - Given an OAI compatible chat completions request with api error - - Examples: Prompts - | api_key | api_error | - | llama.cpp | no | - | llama.cpp | no | - | hackme | raised | - - Scenario Outline: OAI Compatibility (invalid response formats) - Given a system prompt test - And a user prompt test - And a response format - And a model test - And 2 max tokens to predict - And streaming is disabled - Given an OAI compatible chat completions request with raised api error - - Examples: Prompts - | response_format | - | {"type": "sound"} | - | {"type": "json_object", "schema": 123} | - | {"type": "json_object", "schema": {"type": 123}} | - | {"type": "json_object", "schema": {"type": "hiccup"}} | - - - Scenario Outline: CORS Options - Given a user api key llama.cpp - When an OPTIONS request is sent from - Then CORS header is set to - - Examples: Headers - | origin | cors_header | cors_header_value | - | localhost | Access-Control-Allow-Origin | localhost | - | web.mydomain.fr | Access-Control-Allow-Origin | web.mydomain.fr | - | origin | Access-Control-Allow-Credentials | true | - | web.mydomain.fr | Access-Control-Allow-Methods | POST | - | web.mydomain.fr | Access-Control-Allow-Headers | * | diff --git a/examples/server/tests/features/server.feature b/examples/server/tests/features/server.feature deleted file mode 100644 index b55971454..000000000 --- a/examples/server/tests/features/server.feature +++ /dev/null @@ -1,112 +0,0 @@ -@llama.cpp -@server -Feature: llama.cpp server - - Background: Server startup - Given a server listening on localhost:8080 - And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models - And a model file test-model.gguf - And a model alias tinyllama-2 - And BOS token is 1 - And 42 as server seed - # KV Cache corresponds to the total amount of tokens - # that can be stored across all independent sequences: #4130 - # see --ctx-size and #5568 - And 256 KV cache size - And 32 as batch size - And 2 slots - And 64 server max tokens to predict - And prometheus compatible metrics exposed - Then the server is starting - Then the server is healthy - - Scenario: Health - Then the server is ready - And all slots are idle - - - Scenario Outline: Completion - Given a prompt - And max tokens to predict - And a completion request with no api error - Then tokens are predicted matching - And the completion is truncated - And prompt tokens are processed - And prometheus metrics are exposed - And metric llamacpp:tokens_predicted is - - Examples: Prompts - | prompt | n_predict | re_content | n_prompt | n_predicted | truncated | - | I believe the meaning of life is | 8 | (read\|going)+ | 18 | 8 | not | - | Write a joke about AI from a very long prompt which will not be truncated | 256 | (princesses\|everyone\|kids\|Anna\|forest)+ | 46 | 64 | not | - - Scenario: Completion prompt truncated - Given a prompt: - """ - Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. - Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. - Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. - Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. - """ - And a completion request with no api error - Then 64 tokens are predicted matching fun|Annaks|popcorns|pictry|bowl - And the completion is truncated - And 109 prompt tokens are processed - - - Scenario Outline: OAI Compatibility - Given a model - And a system prompt - And a user prompt - And max tokens to predict - And streaming is - Given an OAI compatible chat completions request with no api error - Then tokens are predicted matching - And prompt tokens are processed - And the completion is truncated - - Examples: Prompts - | model | system_prompt | user_prompt | max_tokens | re_content | n_prompt | n_predicted | enable_streaming | truncated | - | llama-2 | Book | What is the best book | 8 | (Here\|what)+ | 77 | 8 | disabled | not | - | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128 | (thanks\|happy\|bird\|Annabyear)+ | -1 | 64 | enabled | | - - - Scenario Outline: OAI Compatibility w/ response format - Given a model test - And a system prompt test - And a user prompt test - And a response format - And 10 max tokens to predict - Given an OAI compatible chat completions request with no api error - Then tokens are predicted matching - - Examples: Prompts - | response_format | n_predicted | re_content | - | {"type": "json_object", "schema": {"const": "42"}} | 6 | "42" | - | {"type": "json_object", "schema": {"items": [{"type": "integer"}]}} | 10 | \[ -300 \] | - | {"type": "json_object"} | 10 | \{ " Jacky. | - - - Scenario: Tokenize / Detokenize - When tokenizing: - """ - What is the capital of France ? - """ - Then tokens can be detokenized - And tokens do not begin with BOS - - Scenario: Tokenize w/ BOS - Given adding special tokens - When tokenizing: - """ - What is the capital of Germany? - """ - Then tokens begin with BOS - Given first token is removed - Then tokens can be detokenized - - Scenario: Models available - Given available models - Then 1 models are supported - Then model 0 is identified by tinyllama-2 - Then model 0 is trained on 128 tokens context diff --git a/examples/server/tests/features/slotsave.feature b/examples/server/tests/features/slotsave.feature deleted file mode 100644 index 1c281c074..000000000 --- a/examples/server/tests/features/slotsave.feature +++ /dev/null @@ -1,58 +0,0 @@ -@llama.cpp -@slotsave -Feature: llama.cpp server slot management - - Background: Server startup - Given a server listening on localhost:8080 - And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models - And prompt caching is enabled - And 2 slots - And . as slot save path - And 2048 KV cache size - And 42 as server seed - And 24 max tokens to predict - Then the server is starting - Then the server is healthy - - Scenario: Save and Restore Slot - # First prompt in slot 1 should be fully processed - Given a user prompt "What is the capital of France?" - And using slot id 1 - And a completion request with no api error - Then 24 tokens are predicted matching (Lily|cake) - And 22 prompt tokens are processed - When the slot 1 is saved with filename "slot1.bin" - Then the server responds with status code 200 - # Since we have cache, this should only process the last tokens - Given a user prompt "What is the capital of Germany?" - And a completion request with no api error - Then 24 tokens are predicted matching (Thank|special) - And 7 prompt tokens are processed - # Loading the original cache into slot 0, - # we should only be processing 1 prompt token and get the same output - When the slot 0 is restored with filename "slot1.bin" - Then the server responds with status code 200 - Given a user prompt "What is the capital of France?" - And using slot id 0 - And a completion request with no api error - Then 24 tokens are predicted matching (Lily|cake) - And 1 prompt tokens are processed - # For verification that slot 1 was not corrupted during slot 0 load, same thing - Given a user prompt "What is the capital of Germany?" - And using slot id 1 - And a completion request with no api error - Then 24 tokens are predicted matching (Thank|special) - And 1 prompt tokens are processed - - Scenario: Erase Slot - Given a user prompt "What is the capital of France?" - And using slot id 1 - And a completion request with no api error - Then 24 tokens are predicted matching (Lily|cake) - And 22 prompt tokens are processed - When the slot 1 is erased - Then the server responds with status code 200 - Given a user prompt "What is the capital of France?" - And a completion request with no api error - Then 24 tokens are predicted matching (Lily|cake) - And 22 prompt tokens are processed diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py deleted file mode 100644 index 65b71a8e8..000000000 --- a/examples/server/tests/features/steps/steps.py +++ /dev/null @@ -1,1375 +0,0 @@ -import asyncio -import json -import os -import re -import socket -import subprocess -import sys -import threading -import time -import requests -from collections.abc import Sequence -from contextlib import closing -from re import RegexFlag -from typing import Any, Literal, cast - -import aiohttp -import numpy as np -import openai -from openai.types.chat import ChatCompletionChunk -from behave import step # pyright: ignore[reportAttributeAccessIssue] -from behave.api.async_step import async_run_until_complete -from prometheus_client import parser - -# pyright: reportRedeclaration=false - -DEFAULT_TIMEOUT_SECONDS = aiohttp.ClientTimeout(total=600) - -@step("a server listening on {server_fqdn}:{server_port}") -def step_server_config(context, server_fqdn: str, server_port: str): - context.server_fqdn = server_fqdn - context.server_port = int(server_port) - context.n_threads = None - context.n_gpu_layer = None - if 'PORT' in os.environ: - context.server_port = int(os.environ['PORT']) - print(f"$PORT set, overriding server port with to {context.server_port}") - if 'FQDN' in os.environ: - context.server_fqdn = os.environ['FQDN'] - print(f"$FQDN set, overriding server fqdn with to {context.server_fqdn}") - if 'N_GPU_LAYERS' in os.environ: - context.n_gpu_layer = int(os.environ['N_GPU_LAYERS']) - print(f"$N_GPU_LAYERS set, overriding n_gpu_layer with to {context.n_gpu_layer}") - - context.base_url = f'http://{context.server_fqdn}:{context.server_port}' - - context.model_alias = None - context.model_file = None - context.model_hf_repo = None - context.model_hf_file = None - context.model_url = None - context.n_batch = None - context.n_ubatch = None - context.n_ctx = None - context.n_ga = None - context.n_ga_w = None - context.n_predict = None - context.n_prompts = 0 - context.n_server_predict = None - context.slot_save_path = None - context.id_slot = None - context.cache_prompt = None - context.n_slots = None - context.prompt_prefix = None - context.prompt_suffix = None - context.server_api_key = None - context.server_continuous_batching = False - context.server_embeddings = False - context.server_metrics = False - context.server_process = None - context.seed = None - context.draft = None - context.server_seed = None - context.user_api_key = None - context.response_format = None - context.temperature = None - context.lora_file = None - - context.tasks_result = [] - context.concurrent_tasks = [] - context.prompts = [] - - -@step('a model file {hf_file} from HF repo {hf_repo}') -def step_download_hf_model(context, hf_file: str, hf_repo: str): - context.model_hf_repo = hf_repo - context.model_hf_file = hf_file - context.model_file = os.path.basename(hf_file) - -@step('a lora adapter file from {lora_file_url}') -def step_download_lora_file(context, lora_file_url: str): - file_name = lora_file_url.split('/').pop() - context.lora_file = f'../../../{file_name}' - with open(context.lora_file, 'wb') as f: - f.write(requests.get(lora_file_url).content) - -@step('a model file {model_file}') -def step_model_file(context, model_file: str): - context.model_file = model_file - - -@step('a model url {model_url}') -def step_model_url(context, model_url: str): - context.model_url = model_url - - -@step('a model alias {model_alias}') -def step_model_alias(context, model_alias: str): - context.model_alias = model_alias - - -@step('{seed:d} as server seed') -def step_seed(context, seed: int): - context.server_seed = seed - - -@step('{ngl:d} GPU offloaded layers') -def step_n_gpu_layer(context, ngl: int): - if 'N_GPU_LAYERS' in os.environ: - new_ngl = int(os.environ['N_GPU_LAYERS']) - if context.debug: - print(f"-ngl upgraded from {ngl} to {new_ngl}") - ngl = new_ngl - context.n_gpu_layer = ngl - - -@step('{n_threads:d} threads') -def step_n_threads(context, n_threads: int): - context.n_thread = n_threads - - -@step('{draft:d} as draft') -def step_draft(context, draft: int): - context.draft = draft - - -@step('{n_ctx:d} KV cache size') -def step_n_ctx(context, n_ctx: int): - context.n_ctx = n_ctx - - -@step('{n_slots:d} slots') -def step_n_slots(context, n_slots: int): - context.n_slots = n_slots - - -@step('{n_predict:d} server max tokens to predict') -def step_server_n_predict(context, n_predict: int): - context.n_server_predict = n_predict - - -@step('{slot_save_path} as slot save path') -def step_slot_save_path(context, slot_save_path: str): - context.slot_save_path = slot_save_path - - -@step('using slot id {id_slot:d}') -def step_id_slot(context, id_slot: int): - context.id_slot = id_slot - - -@step('prompt caching is enabled') -def step_enable_prompt_cache(context): - context.cache_prompt = True - - -@step('continuous batching') -def step_server_continuous_batching(context): - context.server_continuous_batching = True - - -@step('embeddings extraction') -def step_server_embeddings(context): - context.server_embeddings = True - - -@step('prometheus compatible metrics exposed') -def step_server_metrics(context): - context.server_metrics = True - - -@step("the server is starting") -def step_start_server(context): - start_server_background(context) - attempts = 0 - max_attempts = 20 - if 'GITHUB_ACTIONS' in os.environ: - max_attempts *= 2 - - addrs = socket.getaddrinfo(context.server_fqdn, context.server_port, type=socket.SOCK_STREAM) - family, typ, proto, _, sockaddr = addrs[0] - - while True: - with closing(socket.socket(family, typ, proto)) as sock: - result = sock.connect_ex(sockaddr) - if result == 0: - print("\x1b[33;46mserver started!\x1b[0m") - return - attempts += 1 - if attempts > max_attempts: - assert False, "server not started" - print(f"waiting for server to start, connect error code = {result}...") - time.sleep(0.1) - - -async def wait_for_server_status_with_timeout(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str, timeout: int): - match expecting_status: - case 'healthy': - await wait_for_slots_status(context, context.base_url, 200, - timeout=timeout) - - case 'ready' | 'idle': - await wait_for_slots_status(context, context.base_url, 200, - timeout=timeout, - params={'fail_on_no_slot': 1}, - slots_idle=context.n_slots, - slots_processing=0) - case 'busy': - await wait_for_slots_status(context, context.base_url, 503, - params={'fail_on_no_slot': 1}, - slots_idle=0, - slots_processing=context.n_slots) - case _: - assert False, "unknown status" - - -@step("the server is {expecting_status} with timeout {timeout:d} seconds") -@async_run_until_complete -async def step_wait_for_server_status_with_timeout(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str, timeout: int): - await wait_for_server_status_with_timeout(context, expecting_status, timeout) - - -@step("the server is {expecting_status}") -@async_run_until_complete -async def step_wait_for_server_status(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str): - await wait_for_server_status_with_timeout(context, expecting_status, 30) - - -@step('all slots are {expected_slot_status_string}') -@async_run_until_complete -async def step_all_slots_status(context, expected_slot_status_string: Literal['idle', 'busy'] | str): - match expected_slot_status_string: - case 'idle': - expected_slot_status = 0 - case 'busy': - expected_slot_status = 1 - case _: - assert False, "unknown status" - - expected_slots = [{'id': slot_id, 'state': expected_slot_status} - for slot_id in range(context.n_slots)] - await request_slots_status(context, expected_slots) - - -@step('a completion request with {api_error} api error') -@async_run_until_complete -async def step_request_completion(context, api_error: Literal['raised'] | str): - expect_api_error = api_error == 'raised' - seeds = await completions_seed(context, num_seeds=1) - completion = await request_completion(context.prompts.pop(), - seeds[0] if seeds is not None else seeds, - context.base_url, - debug=context.debug, - n_predict=context.n_predict, - cache_prompt=context.cache_prompt, - id_slot=context.id_slot, - expect_api_error=expect_api_error, - user_api_key=context.user_api_key, - temperature=context.temperature) - context.tasks_result.append(completion) - if context.debug: - print(f"Completion response: {completion}") - if expect_api_error: - assert completion == 401, f"completion must be an 401 status code: {completion}" - - -@step('{predicted_n:d} tokens are predicted matching {re_content}') -def step_n_tokens_predicted_with_content(context, predicted_n, re_content): - context.completion = context.tasks_result.pop() - assert_n_tokens_predicted(context.completion, predicted_n, re_content) - - -@step('{predicted_n:d} tokens are predicted') -def step_n_tokens_predicted(context, predicted_n): - context.completion = context.tasks_result.pop() - assert_n_tokens_predicted(context.completion, predicted_n) - - -@step('all predictions are equal') -@async_run_until_complete -async def step_predictions_equal(context): - n_completions = await gather_tasks_results(context) - assert n_completions >= 2, "need at least 2 completions" - assert_all_predictions_equal(context.tasks_result) - context.tasks_result = [] - - -@step('all predictions are different') -@async_run_until_complete -async def step_predictions_different(context): - n_completions = await gather_tasks_results(context) - assert n_completions >= 2, "need at least 2 completions" - assert_all_predictions_different(context.tasks_result) - context.tasks_result = [] - - -@step('all token probabilities are equal') -@async_run_until_complete -async def step_token_probabilities_equal(context): - n_completions = await gather_tasks_results(context) - assert n_completions >= 2, "need at least 2 completions" - assert_all_token_probabilities_equal(context.tasks_result) - context.tasks_result = [] - - -@step('the completion is truncated') -def step_assert_completion_truncated(context): - step_assert_completion_truncated(context, '') - - -@step('the completion is {truncated} truncated') -def step_assert_completion_truncated(context, truncated): - truncated = truncated != "not" - assert context.completion['truncated'] == truncated, f'{context.completion}' - - -@step('{n_prompt:d} prompt tokens are processed') -def step_impl(context, n_prompt): - assert n_prompt < 0 or n_prompt == context.completion['timings']['prompt_n'], f"n_prompt={context.completion['timings']['prompt_n']}" - - -@step('a user prompt {user_prompt}') -def step_user_prompt(context, user_prompt): - context.prompts.append(user_prompt) - context.n_prompts = len(context.prompts) - - -@step('a system prompt {system_prompt}') -def step_system_prompt(context, system_prompt): - context.system_prompt = system_prompt - - -@step('a model {model}') -def step_model(context, model): - context.model = model - - -@step('{max_tokens:d} max tokens to predict') -def step_max_tokens(context, max_tokens): - context.n_predict = max_tokens - - -@step('a response format {response_format}') -def step_response_format(context, response_format): - context.response_format = json.loads(response_format) - - -@step('{temperature:f} temperature') -def step_temperature(context, temperature): - context.temperature = temperature - - -@step('streaming is {enable_streaming}') -def step_streaming(context, enable_streaming): - context.enable_streaming = enable_streaming == 'enabled' - - -@step('a user api key {user_api_key}') -def step_user_api_key(context, user_api_key): - context.user_api_key = user_api_key - - -@step('no user api key') -def step_no_user_api_key(context): - context.user_api_key = None - - -@step('a user api key ') -def step_no_user_api_key_space(context): - context.user_api_key = None - - -@step('a server api key {server_api_key}') -def step_server_api_key(context, server_api_key): - context.server_api_key = server_api_key - - -@step('{n_junk:d} as number of junk') -def step_n_junk(context, n_junk): - context.n_junk = n_junk - - -@step('{n_batch:d} as batch size') -def step_n_batch(context, n_batch): - context.n_batch = n_batch - - -@step('{n_ubatch:d} as ubatch size') -def step_n_ubatch(context, n_ubatch): - context.n_ubatch = n_ubatch - - -@step('{seed:d} as seed') -def step_seed(context, seed): - if context.seed is None: - context.seed = [seed] - else: - context.seed.append(seed) - - -@step('BOS token is {bos:d}') -def step_bos_token(context, bos): - context.bos = bos - - -@step('a prefix prompt') -def step_prompt_prefix(context): - context.prompt_prefix = context_text(context) - - -@step('a junk suffix prompt') -def step_prompt_junk_suffix(context): - context.prompt_junk_suffix = context_text(context) - - -@step('a suffix prompt') -def step_prompt_suffix(context): - context.prompt_suffix = context_text(context) - - -@step('{n_ga:d} group attention factor' - ' to extend context size through self-extend') -def step_impl(context, n_ga): - context.n_ga = n_ga - - -@step('{n_ga_w:d} group attention width to extend context size through self-extend') -def step_impl(context, n_ga_w): - context.n_ga_w = n_ga_w - - -@step('a passkey prompt template') -def step_prompt_passkey(context): - context.prompt_passkey = context_text(context) - - -@step('{n_prompts:d} fixed prompts') -def step_fixed_prompts(context, n_prompts): - context.prompts.extend([str(0)*(context.n_batch if context.n_batch is not None else 512) for i in range(n_prompts)]) - context.n_prompts = n_prompts - - -@step('a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk') -def step_prompt_passkey(context, passkey, i_pos): - prompt = "" - for i in range(context.n_junk): - if i % context.n_junk == i_pos: - prompt += context.prompt_passkey # the passkey is already substituted - prompt += context.prompt_junk_suffix - if context.debug: - passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m" - print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```") - context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix) - context.n_prompts = len(context.prompts) - - -@step('an OAI compatible chat completions request with {api_error} api error') -@async_run_until_complete -async def step_oai_chat_completions(context, api_error): - if context.debug: - print(f"Submitting OAI compatible completions request...") - expect_api_error = api_error == 'raised' - seeds = await completions_seed(context, num_seeds=1), - completion = await oai_chat_completions(context.prompts.pop(), - seeds[0] if seeds is not None else seeds, - context.system_prompt, - context.base_url, - '/v1/chat', - False, - model=context.model if hasattr(context, 'model') else None, - - n_predict=context.n_predict - if hasattr(context, 'n_predict') else None, - - enable_streaming=context.enable_streaming - if hasattr(context, 'enable_streaming') else None, - - response_format=context.response_format - if hasattr(context, 'response_format') else None, - - user_api_key=context.user_api_key - if hasattr(context, 'user_api_key') else None, - - expect_api_error=expect_api_error) - context.tasks_result.append(completion) - if context.debug: - print(f"Completion response: {completion}") - if expect_api_error: - assert completion == 401, f"completion must be an 401 status code: {completion}" - - if context.debug: - print(f"Completion response: {completion}") - - -@step('a prompt') -def step_a_prompt(context): - context.prompts.append(context_text(context)) - context.n_prompts = len(context.prompts) - - -@step('a prompt {prompt}') -def step_a_prompt_prompt(context, prompt): - context.prompts.append(prompt) - context.n_prompts = len(context.prompts) - - -@step('{num_prompts:d} prompts {prompt} with seed {seed:d}') -def step_many_prompts(context, num_prompts, prompt, seed): - if context.seed is None: - context.seed = [] - for _ in range(num_prompts): - context.seed.append(seed) - context.prompts.append(prompt) - context.n_prompts = len(context.prompts) - - -@step('concurrent completion requests') -@async_run_until_complete() -async def step_concurrent_completion_requests(context): - await concurrent_requests( - context, - request_completion, - # prompt is inserted automatically - context.base_url, - debug=context.debug, - prompt_prefix=context.prompt_prefix, - prompt_suffix=context.prompt_suffix, - n_predict=context.n_predict if hasattr(context, 'n_predict') else None, - user_api_key=context.user_api_key if hasattr(context, 'user_api_key') else None, - temperature=context.temperature, - ) - - -@step('concurrent OAI completions requests') -@async_run_until_complete -async def step_oai_chat_completions(context): - await concurrent_requests(context, oai_chat_completions, - # user_prompt is inserted automatically - context.system_prompt, - context.base_url, - '/v1/chat/completions', - True, # async_client - model=context.model - if hasattr(context, 'model') else None, - n_predict=context.n_predict - if hasattr(context, 'n_predict') else None, - enable_streaming=context.enable_streaming - if hasattr(context, 'enable_streaming') else None, - response_format=context.response_format - if hasattr(context, 'response_format') else None, - user_api_key=context.user_api_key - if hasattr(context, 'user_api_key') else None) - - -@step('concurrent OAI completions requests no v1') -@async_run_until_complete -async def step_oai_chat_completions(context): - await concurrent_requests(context, oai_chat_completions, - # user_prompt is inserted automatically - context.system_prompt, - context.base_url, - '/chat/completions', - True, # async_client - model=context.model - if hasattr(context, 'model') else None, - n_predict=context.n_predict - if hasattr(context, 'n_predict') else None, - enable_streaming=context.enable_streaming - if hasattr(context, 'enable_streaming') else None, - response_format=context.response_format - if hasattr(context, 'response_format') else None, - user_api_key=context.user_api_key - if hasattr(context, 'user_api_key') else None) - - -@step('all prompts are predicted') -@async_run_until_complete -async def step_all_prompts_are_predicted(context): - await all_prompts_are_predicted(context) - - -@step('all prompts are predicted with {n_expected_predicted:d} tokens') -@async_run_until_complete -async def step_all_prompts_are_predicted_with_n_tokens(context, n_expected_predicted): - await all_prompts_are_predicted(context, n_expected_predicted) - - -async def all_prompts_are_predicted(context, expected_predicted_n=None): - n_completions = await gather_tasks_results(context) - assert n_completions > 0 - for i in range(n_completions): - assert_n_tokens_predicted(context.tasks_result.pop(), expected_predicted_n=expected_predicted_n) - assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests" - - -@step('embeddings are computed for') -@async_run_until_complete -async def step_compute_embedding(context): - context.n_prompts = 1 - context.embeddings = await request_embedding(context_text(context), None, base_url=context.base_url) - - -@step('all embeddings are the same') -@async_run_until_complete -async def step_all_embeddings_are_the_same(context): - n_embedding_requests = await gather_tasks_results(context) - assert n_embedding_requests > 0 - embeddings = [] - for i in range(n_embedding_requests): - embedding = context.tasks_result.pop().pop() - embeddings.append(embedding) - assert_embeddings(embedding) - n = len(embeddings) - for i in range(n-1): - for j in range(i+1, n): - embedding1 = np.array(embeddings[i]) - embedding2 = np.array(embeddings[j]) - if context.debug: - print(f"embedding1: {embedding1[-8:]}") - print(f"embedding2: {embedding2[-8:]}") - similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2)) - msg = f"Similarity between {i} and {j}: {similarity:.10f}" - if context.debug: - print(f"{msg}") - assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg - - -@step('embeddings are generated') -def step_assert_embeddings(context): - assert context.n_prompts == len(context.embeddings), (f"unexpected response:\n" - f"context.n_prompts={context.n_prompts}\n" - f"context.embeddings={context.embeddings}") - for embedding in context.embeddings: - assert_embeddings(embedding) - - -@step('an OAI compatible embeddings computation request for') -@async_run_until_complete -async def step_oai_compute_embeddings(context): - context.n_prompts = 1 - context.embeddings = await request_oai_embeddings(context_text(context), None, - base_url=context.base_url, - user_api_key=context.user_api_key, - model=context.model) - - -@step('an OAI compatible embeddings computation request for multiple inputs') -@async_run_until_complete -async def step_oai_compute_embeddings_multiple_inputs(context): - context.embeddings = await request_oai_embeddings(context.prompts, None, - base_url=context.base_url, - user_api_key=context.user_api_key, - model=context.model) - context.prompts.clear() - - -@step('concurrent embedding requests') -@async_run_until_complete() -async def step_concurrent_embedding_requests(context): - await concurrent_requests(context, - request_embedding, - # prompt is inserted automatically - base_url=context.base_url) - - -@step('concurrent OAI embedding requests') -@async_run_until_complete() -async def step_concurrent_oai_embedding_requests(context): - await concurrent_requests(context, - request_oai_embeddings, - # prompt is inserted automatically - base_url=context.base_url, - async_client=True, - model=context.model) - - -@step('all embeddings are generated') -@async_run_until_complete() -async def all_embeddings_are_generated(context): - n_embedding_requests = await gather_tasks_results(context) - assert n_embedding_requests == context.n_prompts - for i in range(n_embedding_requests): - assert_embeddings(context.tasks_result.pop().pop()) - - -@step('adding special tokens') -def step_tokenize_set_add_special(context): - context.tokenize_add_special = True - - -@step('tokenizing') -@async_run_until_complete -async def step_tokenize(context): - context.tokenized_text = context_text(context) - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - tokenize_args = { - "content": context.tokenized_text, - } - if getattr(context, 'tokenize_add_special', None) is not None: - tokenize_args['add_special'] = context.tokenize_add_special - async with session.post(f'{context.base_url}/tokenize', - json=tokenize_args) as response: - assert response.status == 200 - tokenize_json = await response.json() - context.tokens = tokenize_json['tokens'] - - -@step('tokens can be detokenized') -@async_run_until_complete -async def step_detokenize(context): - assert len(context.tokens) > 0 - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with session.post(f'{context.base_url}/detokenize', - json={ - "tokens": context.tokens, - }) as response: - assert response.status == 200 - detokenize_json = await response.json() - # SPM tokenizer adds a whitespace prefix: https://github.com/google/sentencepiece/issues/15 - assert context.tokenized_text == detokenize_json['content'].strip() - - -@step('tokens begin with BOS') -def step_strings_for_tokenization(context): - assert context.tokens[0] == context.bos - - -@step('tokens do not begin with BOS') -def step_strings_for_tokenization(context): - assert context.tokens[0] != context.bos - - -@step('first token is removed') -def step_strings_for_tokenization(context): - context.tokens = context.tokens[1:] - - -@step('an OPTIONS request is sent from {origin}') -@async_run_until_complete -async def step_options_request(context, origin): - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - headers = {'Authorization': f'Bearer {context.user_api_key}', 'Origin': origin} - async with session.options(f'{context.base_url}/v1/chat/completions', - headers=headers) as response: - assert response.status == 200 - context.options_response = response - - -@step('CORS header {cors_header} is set to {cors_header_value}') -def step_check_options_header_value(context, cors_header, cors_header_value): - assert context.options_response.headers[cors_header] == cors_header_value - - -@step('prometheus metrics are exposed') -@async_run_until_complete -async def step_prometheus_metrics_exported(context): - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with await session.get(f'{context.base_url}/metrics') as metrics_response: - assert metrics_response.status == 200 - assert metrics_response.headers['Content-Type'] == "text/plain; version=0.0.4" - metrics_raw = await metrics_response.text() - metric_exported = False - if context.debug: - print(f"/metrics answer:\n{metrics_raw}") - context.metrics = {} - for metric in parser.text_string_to_metric_families(metrics_raw): - match metric.name: - case "llamacpp:kv_cache_usage_ratio": - assert len(metric.samples) > 0 - metric_exported = True - context.metrics[metric.name] = metric - assert int(metrics_response.headers["Process-Start-Time-Unix"]) > 0, "no header process start time" - assert metric_exported, "No metrics exported" - - -@step('metric {metric_name} is {metric_value:d}') -def step_assert_metric_value(context, metric_name, metric_value): - if metric_name not in context.metrics: - assert False, f"no metric {metric_name} in {context.metrics.keys()}" - assert context.metrics[metric_name].samples[0].value == metric_value, f"metric: {context.metrics[metric_name]}" - - -@step('available models') -def step_available_models(context): - # openai client always expects an api_key - openai.api_key = context.user_api_key if context.user_api_key is not None else 'nope' - openai.base_url = f'{context.base_url}/v1/' - context.models = openai.models.list().data - - -@step('{n_model:d} models are supported') -def step_supported_models(context, n_model): - if context.debug: - print("server models available:", context.models) - assert len(context.models) == n_model - - -@step('model {i_model:d} is {param} {preposition} {param_value}') -def step_supported_models(context, i_model: int, param: Literal['identified', 'trained'] | str, preposition: str, param_value: str): - assert i_model < len(context.models) - model = context.models[i_model] - - param_value = param_value.split(' ', 1)[0] - match param: - case 'identified': - value = model.id - case 'trained': - value = str(model.meta["n_ctx_train"]) - case _: - assert False, "param {param} not supported" - assert param_value == value, f"model param {param} {value} != {param_value}" - - -async def concurrent_requests(context, f_completion, *args, **kwargs): - context.n_prompts = len(context.prompts) - if context.debug: - print(f"starting {context.n_prompts} concurrent completion requests...") - assert context.n_prompts > 0 - seeds = await completions_seed(context) - assert seeds is not None - for prompt_no in range(context.n_prompts): - shifted_args = [context.prompts.pop(), seeds[prompt_no], *args] - context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs))) - await asyncio.sleep(0.01) - - -@step('the slot {slot_id:d} is saved with filename "{filename}"') -@async_run_until_complete -async def step_save_slot(context, slot_id, filename): - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with session.post(f'{context.base_url}/slots/{slot_id}?action=save', - json={"filename": filename}, - headers={"Content-Type": "application/json"}) as response: - context.response = response - - -@step('the slot {slot_id:d} is restored with filename "{filename}"') -@async_run_until_complete -async def step_restore_slot(context, slot_id, filename): - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with session.post(f'{context.base_url}/slots/{slot_id}?action=restore', - json={"filename": filename}, - headers={"Content-Type": "application/json"}) as response: - context.response = response - - -@step('the slot {slot_id:d} is erased') -@async_run_until_complete -async def step_erase_slot(context, slot_id): - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with session.post(f'{context.base_url}/slots/{slot_id}?action=erase', - headers={"Content-Type": "application/json"}) as response: - context.response = response - - -@step('switch {on_or_off} lora adapter {lora_id:d}') -@async_run_until_complete -async def toggle_lora_adapter(context, on_or_off: str, lora_id: int): - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with session.post(f'{context.base_url}/lora-adapters', - json=[{'id': lora_id, 'scale': 1 if on_or_off == 'on' else 0}], - headers={"Content-Type": "application/json"}) as response: - context.response = response - print([{'id': lora_id, 'scale': 1 if on_or_off == 'on' else 0}]) - - -@step('the server responds with status code {status_code:d}') -def step_server_responds_with_status_code(context, status_code): - assert context.response.status == status_code - - -async def request_completion(prompt, - seed, - base_url, - debug=False, - prompt_prefix=None, - prompt_suffix=None, - n_predict=None, - cache_prompt=False, - id_slot=None, - expect_api_error=None, - user_api_key=None, - temperature=None) -> int | dict[str, Any]: - if debug: - print(f"Sending completion request: {prompt}") - origin = "my.super.domain" - headers = { - 'Origin': origin - } - if user_api_key is not None: - if debug: - print(f"Set user_api_key: {user_api_key}") - headers['Authorization'] = f'Bearer {user_api_key}' - - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with session.post(f'{base_url}/completion', - json={ - "input_prefix": prompt_prefix, - "prompt": prompt, - "input_suffix": prompt_suffix, - "n_predict": n_predict if n_predict is not None else -1, - "cache_prompt": cache_prompt, - "id_slot": id_slot, - "seed": seed if seed is not None else 42, - "temperature": temperature if temperature is not None else 0.8, - "n_probs": 2, - }, - headers=headers) as response: - if expect_api_error is None or not expect_api_error: - assert response.status == 200 - assert response.headers['Access-Control-Allow-Origin'] == origin - return await response.json() - else: - return response.status - - -async def oai_chat_completions(user_prompt, - seed, - system_prompt, - base_url: str, - base_path: str, - async_client, - debug=False, - temperature=None, - model=None, - n_predict=None, - enable_streaming=None, - response_format=None, - user_api_key=None, - expect_api_error=None) -> int | dict[str, Any]: - if debug: - print(f"Sending OAI Chat completions request: {user_prompt}") - # openai client always expects an api key - user_api_key = user_api_key if user_api_key is not None else 'nope' - seed = seed if seed is not None else 42 - enable_streaming = enable_streaming if enable_streaming is not None else False - payload = { - "messages": [ - { - "role": "system", - "content": system_prompt, - }, - { - "role": "user", - "content": user_prompt, - } - ], - "model": model, - "max_tokens": n_predict, - "stream": enable_streaming, - "temperature": temperature if temperature is not None else 0.0, - "seed": seed, - } - if response_format is not None: - payload['response_format'] = response_format - completion_response = { - 'content': '', - 'timings': { - 'predicted_n': 0, - 'prompt_n': 0 - } - } - if async_client: - origin = 'llama.cpp' - headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin} - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with session.post(f'{base_url}{base_path}', - json=payload, - headers=headers) as response: - if enable_streaming: - assert response.status == 200 - assert response.headers['Access-Control-Allow-Origin'] == origin - assert response.headers['Content-Type'] == "text/event-stream" - event_received = True - while event_received: - event_received = False - async for line_in_bytes in response.content: - line = line_in_bytes.decode('utf-8') - line = line.rstrip('\n').rstrip('\r') - if line == '': - continue - event_data = line.split(': ', 1) - assert event_data[0] == 'data', f'Bad event code received: ```{event_data}```' - chunk_raw = event_data[1] - - chunk = json.loads(chunk_raw) - assert len(chunk['choices']) == 1, f"no choices provided, line ```{line}```" - delta = chunk['choices'][0]['delta'] - if 'content' in delta: - completion_response['content'] += delta['content'] - completion_response['timings']['predicted_n'] += 1 - else: - if expect_api_error is None or not expect_api_error: - assert response.status == 200 - assert response.headers['Access-Control-Allow-Origin'] == origin - assert response.headers['Content-Type'] == "application/json; charset=utf-8" - chat_completion_raw = await response.json() - completion_response = { - 'content': chat_completion_raw['choices'][0]['message'], - 'timings': { - 'predicted_n': chat_completion_raw['usage']['completion_tokens'], - 'prompt_n': chat_completion_raw['usage']['prompt_tokens'] - } - } - else: - return response.status - else: - try: - openai.api_key = user_api_key - openai.base_url = f'{base_url}{base_path.removesuffix("chat")}' - assert model is not None - chat_completion = openai.chat.completions.create( - messages=payload['messages'], - model=model, - max_tokens=n_predict, - stream=enable_streaming, - response_format=payload.get('response_format') or openai.NOT_GIVEN, - seed=seed, - temperature=payload['temperature'] - ) - except openai.AuthenticationError as e: - if expect_api_error is not None and expect_api_error: - return 401 - else: - assert False, f'error raised: {e}' - - if enable_streaming: - chat_completion = cast(openai.Stream[ChatCompletionChunk], chat_completion) - for chunk in chat_completion: - assert len(chunk.choices) == 1 - delta = chunk.choices[0].delta - if delta.content is not None: - completion_response['content'] += delta.content - completion_response['timings']['predicted_n'] += 1 - completion_response['truncated'] = chunk.choices[0].finish_reason != 'stop' - else: - assert len(chat_completion.choices) == 1 - assert chat_completion.usage is not None - completion_response = { - 'content': chat_completion.choices[0].message.content, - 'timings': { - 'predicted_n': chat_completion.usage.completion_tokens, - 'prompt_n': chat_completion.usage.prompt_tokens - }, - 'truncated': chat_completion.choices[0].finish_reason != 'stop' - } - if debug: - print("OAI response formatted to llama.cpp:", completion_response) - return completion_response - - -async def request_embedding(content, seed, base_url=None) -> list[list[float]]: - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with session.post(f'{base_url}/embedding', - json={ - "content": content, - }) as response: - assert response.status == 200 - response_json = await response.json() - return [response_json['embedding']] - - -async def request_oai_embeddings(input, seed, - base_url=None, user_api_key=None, - model=None, async_client=False) -> list[list[float]]: - # openai client always expects an api_key - user_api_key = user_api_key if user_api_key is not None else 'nope' - if async_client: - origin = 'llama.cpp' - headers=[] - if user_api_key is not None: - headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin} - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with session.post(f'{base_url}/v1/embeddings', - json={ - "input": input, - "model": model, - }, - headers=headers) as response: - assert response.status == 200, f"received status code not expected: {response.status}" - assert response.headers['Access-Control-Allow-Origin'] == origin - assert response.headers['Content-Type'] == "application/json; charset=utf-8" - response_json = await response.json() - assert response_json['model'] == model, f"invalid model received: {response_json['model']}" - assert response_json['object'] == 'list' - if isinstance(input, Sequence): - embeddings = [] - for an_oai_embeddings in response_json['data']: - embeddings.append(an_oai_embeddings['embedding']) - else: - embeddings = [response_json['data']['embedding']] - return embeddings - else: - openai.api_key = user_api_key - openai.base_url = f'{base_url}/v1/' - assert model is not None - oai_embeddings = openai.embeddings.create( - model=model, - input=input, - ) - - return [e.embedding for e in oai_embeddings.data] - - -def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None): - content = completion_response['content'] - n_predicted = completion_response['timings']['predicted_n'] - assert len(content) > 0, "no token predicted" - if re_content is not None: - p = re.compile(re_content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL) - matches = p.finditer(content) - last_match = 0 - highlighted = '' - for match in matches: - start, end = match.span() - highlighted += content[last_match: start] - highlighted += '\x1b[33m' - highlighted += content[start: end] - highlighted += '\x1b[0m' - last_match = end - highlighted += content[last_match:] - if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON': - print(f"Checking completion response: {highlighted}") - assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```' - if expected_predicted_n and expected_predicted_n > 0: - assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:' - f' {n_predicted} <> {expected_predicted_n}') - -def assert_all_predictions_equal(completion_responses): - if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON': - for i, response_i in enumerate(completion_responses): - content_i = response_i['content'] - print(f"content {i}: {content_i}") - for i, response_i in enumerate(completion_responses): - content_i = response_i['content'] - for j, response_j in enumerate(completion_responses): - if i == j: - continue - content_j = response_j['content'] - assert content_i == content_j, "contents not equal" - - -def assert_all_predictions_different(completion_responses): - if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON': - for i, response_i in enumerate(completion_responses): - content_i = response_i['content'] - print(f"content {i}: {content_i}") - for i, response_i in enumerate(completion_responses): - content_i = response_i['content'] - for j, response_j in enumerate(completion_responses): - if i == j: - continue - content_j = response_j['content'] - assert content_i != content_j, "contents not different" - - -def assert_all_token_probabilities_equal(completion_responses): - n_predict = len(completion_responses[0]['completion_probabilities']) - if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON': - for pos in range(n_predict): - for i, response_i in enumerate(completion_responses): - probs_i = response_i['completion_probabilities'][pos]['probs'] - print(f"pos {pos}, probs {i}: {probs_i}") - for pos in range(n_predict): - for i, response_i in enumerate(completion_responses): - probs_i = response_i['completion_probabilities'][pos]['probs'] - for j, response_j in enumerate(completion_responses): - if i == j: - continue - probs_j = response_j['completion_probabilities'][pos]['probs'] - assert probs_i == probs_j, "contents not equal" - - -async def gather_tasks_results(context): - n_tasks = len(context.concurrent_tasks) - if context.debug: - print(f"Waiting for all {n_tasks} tasks results...") - for task_no in range(n_tasks): - context.tasks_result.append(await context.concurrent_tasks.pop()) - n_completions = len(context.tasks_result) - return n_completions - - -async def wait_for_slots_status(context, - base_url, - expected_http_status_code, - timeout=3, - params=None, - slots_idle=None, - slots_processing=None): - if context.debug: - print(f"Starting checking for health for expected_http_status_code={expected_http_status_code}") - interval = 0.5 - counter = 0 - if 'GITHUB_ACTIONS' in os.environ: - timeout *= 2 - - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - while True: - async with await session.get(f'{base_url}/slots', params=params) as slots_response: - status_code = slots_response.status - slots = await slots_response.json() - if context.debug: - print(f"slots responses {slots}\n") - if status_code == 503 and status_code == expected_http_status_code: - return - if status_code == 200 and status_code == expected_http_status_code: - n_slots_idle = sum(1 if slot["state"] == 0 else 0 for slot in slots) - n_slots_processing = sum(1 if slot["state"] != 0 else 0 for slot in slots) - if ((slots_idle is None or slots_idle == n_slots_idle) - and (slots_processing is None or slots_processing == n_slots_processing)): - return - await asyncio.sleep(interval) - - counter += interval - if counter >= timeout: - # Sometimes health requests are triggered after completions are predicted - if expected_http_status_code == 503: - if len(context.tasks_result) == 0: - print("\x1b[5;37;43mWARNING: forcing concurrent tasks," - " busy health check missed, probably too fast inference\x1b[0m\n") - n_completions = await gather_tasks_results(context) - if n_completions > 0: - return - - assert False, f'slots check timeout exceeded {counter}s>={timeout}' - - -def assert_embeddings(embeddings): - assert len(embeddings) > 0 - embeddings_computed = False - for emb in embeddings: - if not isinstance(emb, float): - assert False, f"Bad embeddings: {embeddings}" - if emb != 0: - embeddings_computed = True - assert embeddings_computed, f"Embeddings: {embeddings}" - - -async def request_slots_status(context, expected_slots): - async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: - async with await session.get(f'{context.base_url}/slots') as slots_response: - assert slots_response.status == 200 - slots = await slots_response.json() - assert_slots_status(slots, expected_slots) - - -def assert_slots_status(slots, expected_slots): - assert len(slots) == len(expected_slots) - for slot_id, (expected, slot) in enumerate(zip(expected_slots, slots)): - for key in expected: - assert expected[key] == slot[key], (f"invalid slot {slot_id}" - f" expected[{key}] != slot[{key}]" - f" = {expected[key]} != {slot[key]}") - - -async def completions_seed(context, num_seeds=None): - if hasattr(context, "seed") and context.seed is not None: - assert len(context.seed) == context.n_prompts - if num_seeds is None: - num_seeds = context.n_prompts - assert num_seeds <= context.n_prompts - seeds = context.seed[:num_seeds] - context.seed = context.seed[num_seeds:] if num_seeds < context.n_prompts else None - return seeds - - if hasattr(context, "server_seed") and context.server_seed is not None: - if num_seeds is None: - return [context.server_seed] * context.n_prompts - else: - return [context.server_seed] * num_seeds - return None - - -def context_text(context): - return context.text.replace('\r', '') - - -def start_server_background(context): - if os.name == 'nt': - context.server_path = '../../../build/bin/Release/llama-server.exe' - else: - context.server_path = '../../../build/bin/llama-server' - if 'LLAMA_SERVER_BIN_PATH' in os.environ: - context.server_path = os.environ['LLAMA_SERVER_BIN_PATH'] - server_listen_addr = context.server_fqdn - server_args = [ - '--host', server_listen_addr, - '--port', context.server_port, - ] - if context.model_file: - server_args.extend(['--model', context.model_file]) - if context.model_url: - server_args.extend(['--model-url', context.model_url]) - if context.model_hf_repo: - server_args.extend(['--hf-repo', context.model_hf_repo]) - if context.model_hf_file: - server_args.extend(['--hf-file', context.model_hf_file]) - if context.n_batch: - server_args.extend(['--batch-size', context.n_batch]) - if context.n_ubatch: - server_args.extend(['--ubatch-size', context.n_ubatch]) - if context.n_threads: - server_args.extend(['--threads', context.threads]) - if context.n_gpu_layer: - server_args.extend(['--n-gpu-layers', context.n_gpu_layer]) - if context.draft is not None: - server_args.extend(['--draft', context.draft]) - if context.server_continuous_batching: - server_args.append('--cont-batching') - if context.server_embeddings: - server_args.append('--embedding') - if context.server_metrics: - server_args.append('--metrics') - if context.model_alias: - server_args.extend(['--alias', context.model_alias]) - if context.n_ctx: - server_args.extend(['--ctx-size', context.n_ctx]) - if context.n_slots: - server_args.extend(['--parallel', context.n_slots]) - if context.n_server_predict: - server_args.extend(['--n-predict', context.n_server_predict]) - if context.slot_save_path: - server_args.extend(['--slot-save-path', context.slot_save_path]) - if context.server_api_key: - server_args.extend(['--api-key', context.server_api_key]) - if context.n_ga: - server_args.extend(['--grp-attn-n', context.n_ga]) - if context.n_ga_w: - server_args.extend(['--grp-attn-w', context.n_ga_w]) - if context.debug: - server_args.append('--verbose') - if context.lora_file: - server_args.extend(['--lora', context.lora_file]) - if 'SERVER_LOG_FORMAT_JSON' not in os.environ: - server_args.extend(['--log-format', "text"]) - - args = [str(arg) for arg in [context.server_path, *server_args]] - print(f"bench: starting server with: {' '.join(args)}") - - flags = 0 - if 'nt' == os.name: - flags |= subprocess.DETACHED_PROCESS - flags |= subprocess.CREATE_NEW_PROCESS_GROUP - flags |= subprocess.CREATE_NO_WINDOW - - pkwargs = { - 'creationflags': flags, - 'stdout': subprocess.PIPE, - 'stderr': subprocess.PIPE - } - context.server_process = subprocess.Popen( - [str(arg) for arg in [context.server_path, *server_args]], - **pkwargs) # pyright: ignore[reportArgumentType, reportCallIssue] - - def server_log(in_stream, out_stream): - for line in iter(in_stream.readline, b''): - print(line.decode('utf-8'), end='', file=out_stream) - - thread_stdout = threading.Thread(target=server_log, args=(context.server_process.stdout, sys.stdout)) - thread_stdout.start() - - thread_stderr = threading.Thread(target=server_log, args=(context.server_process.stderr, sys.stderr)) - thread_stderr.start() - - print(f"server pid={context.server_process.pid}, behave pid={os.getpid()}") diff --git a/examples/server/tests/features/wrong_usages.feature b/examples/server/tests/features/wrong_usages.feature deleted file mode 100644 index 61d5f315e..000000000 --- a/examples/server/tests/features/wrong_usages.feature +++ /dev/null @@ -1,25 +0,0 @@ -# run with: ./tests.sh --no-skipped --tags wrong_usage -@wrong_usage -Feature: Wrong usage of llama.cpp server - - #3969 The user must always set --n-predict option - # to cap the number of tokens any completion request can generate - # or pass n_predict/max_tokens in the request. - Scenario: Infinite loop - Given a server listening on localhost:8080 - And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models - And 42 as server seed - And 2048 KV cache size - # Uncomment below to fix the issue - #And 64 server max tokens to predict - Then the server is starting - Then the server is healthy - Given a prompt: - """ - Go to: infinite loop - """ - # Uncomment below to fix the issue - #And 128 max tokens to predict - Given concurrent completion requests - Then the server is idle - Then all prompts are predicted diff --git a/examples/server/tests/pytest.ini b/examples/server/tests/pytest.ini new file mode 100644 index 000000000..6df308df7 --- /dev/null +++ b/examples/server/tests/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +markers = + slow: marks tests as slow (deselect with '-m "not slow"') + serial diff --git a/examples/server/tests/requirements.txt b/examples/server/tests/requirements.txt index f2d7e5c57..15d024914 100644 --- a/examples/server/tests/requirements.txt +++ b/examples/server/tests/requirements.txt @@ -1,7 +1,8 @@ aiohttp~=3.9.3 -behave~=1.2.6 -huggingface_hub~=0.20.3 +pytest~=8.3.3 +huggingface_hub~=0.23.2 numpy~=1.26.4 -openai~=1.30.3 +openai~=1.55.3 prometheus-client~=0.20.0 requests~=2.32.3 +wget~=3.2 diff --git a/examples/server/tests/tests.sh b/examples/server/tests/tests.sh index 72a0fbad8..33fa8cc64 100755 --- a/examples/server/tests/tests.sh +++ b/examples/server/tests/tests.sh @@ -1,11 +1,23 @@ #!/bin/bash +# make sure we are in the right directory +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd $SCRIPT_DIR + set -eu +if [[ "${SLOW_TESTS:-0}" == 1 ]]; then + # Slow tests for tool calls need quite a few models ahead of time to avoid timing out. + python $SCRIPT_DIR/../../../scripts/fetch_server_test_models.py +fi + if [ $# -lt 1 ] then - # Start @llama.cpp scenario - behave --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey' --tags llama.cpp + if [[ "${SLOW_TESTS:-0}" == 1 ]]; then + pytest -v -x + else + pytest -v -x -m "not slow" + fi else - behave "$@" + pytest "$@" fi diff --git a/examples/server/tests/unit/test_basic.py b/examples/server/tests/unit/test_basic.py new file mode 100644 index 000000000..1485de8ce --- /dev/null +++ b/examples/server/tests/unit/test_basic.py @@ -0,0 +1,96 @@ +import pytest +import requests +from utils import * + +server = ServerPreset.tinyllama2() + + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + + +def test_server_start_simple(): + global server + server.start() + res = server.make_request("GET", "/health") + assert res.status_code == 200 + + +def test_server_props(): + global server + server.start() + res = server.make_request("GET", "/props") + assert res.status_code == 200 + assert ".gguf" in res.body["model_path"] + assert res.body["total_slots"] == server.n_slots + default_val = res.body["default_generation_settings"] + assert server.n_ctx is not None and server.n_slots is not None + assert default_val["n_ctx"] == server.n_ctx / server.n_slots + assert default_val["params"]["seed"] == server.seed + + +def test_server_models(): + global server + server.start() + res = server.make_request("GET", "/models") + assert res.status_code == 200 + assert len(res.body["data"]) == 1 + assert res.body["data"][0]["id"] == server.model_alias + + +def test_server_slots(): + global server + + # without slots endpoint enabled, this should return error + server.server_slots = False + server.start() + res = server.make_request("GET", "/slots") + assert res.status_code == 501 # ERROR_TYPE_NOT_SUPPORTED + assert "error" in res.body + server.stop() + + # with slots endpoint enabled, this should return slots info + server.server_slots = True + server.n_slots = 2 + server.start() + res = server.make_request("GET", "/slots") + assert res.status_code == 200 + assert len(res.body) == server.n_slots + assert server.n_ctx is not None and server.n_slots is not None + assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots + assert "params" in res.body[0] + assert res.body[0]["params"]["seed"] == server.seed + + +def test_load_split_model(): + global server + server.model_hf_repo = "ggml-org/models" + server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf" + server.model_alias = "tinyllama-split" + server.start() + res = server.make_request("POST", "/completion", data={ + "n_predict": 16, + "prompt": "Hello", + "temperature": 0.0, + }) + assert res.status_code == 200 + assert match_regex("(little|girl)+", res.body["content"]) + + +def test_no_webui(): + global server + # default: webui enabled + server.start() + url = f"http://{server.server_host}:{server.server_port}" + res = requests.get(url) + assert res.status_code == 200 + assert "" in res.text + server.stop() + + # with --no-webui + server.no_webui = True + server.start() + res = requests.get(url) + assert res.status_code == 404 diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py new file mode 100644 index 000000000..0be04bab5 --- /dev/null +++ b/examples/server/tests/unit/test_chat_completion.py @@ -0,0 +1,265 @@ +import pytest +from openai import OpenAI +from utils import * + +server: ServerProcess + +@pytest.fixture(autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + + +@pytest.mark.parametrize( + "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template", + [ + (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None), + ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None), + # TODO: fix testing of non-tool jinja mode + # (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None), + # (None, "Book", "What is the best book", 8, "I want to play with", 23, 8, "length", True, "This is not a chat template, it is"), + # ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None), + ] +) +def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template): + global server + server.jinja = jinja + server.chat_template = chat_template + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "model": model, + "max_tokens": max_tokens, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + }) + assert res.status_code == 200 + assert "cmpl" in res.body["id"] # make sure the completion id has the expected format + assert res.body["system_fingerprint"].startswith("b") + assert res.body["model"] == model if model is not None else server.model_alias + assert res.body["usage"]["prompt_tokens"] == n_prompt + assert res.body["usage"]["completion_tokens"] == n_predicted + choice = res.body["choices"][0] + assert "assistant" == choice["message"]["role"] + assert match_regex(re_content, choice["message"]["content"]) + assert choice["finish_reason"] == finish_reason + + +@pytest.mark.parametrize( + "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason", + [ + ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"), + ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"), + ] +) +def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason): + global server + server.model_alias = None # try using DEFAULT_OAICOMPAT_MODEL + server.start() + res = server.make_stream_request("POST", "/chat/completions", data={ + "max_tokens": max_tokens, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + "stream": True, + }) + content = "" + last_cmpl_id = None + for data in res: + choice = data["choices"][0] + assert data["system_fingerprint"].startswith("b") + assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future + if last_cmpl_id is None: + last_cmpl_id = data["id"] + assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream + if choice["finish_reason"] in ["stop", "length"]: + assert data["usage"]["prompt_tokens"] == n_prompt + assert data["usage"]["completion_tokens"] == n_predicted + assert "content" not in choice["delta"] + assert match_regex(re_content, content) + assert choice["finish_reason"] == finish_reason + else: + assert choice["finish_reason"] is None + content += choice["delta"]["content"] + + +def test_chat_completion_with_openai_library(): + global server + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + res = client.chat.completions.create( + model="gpt-3.5-turbo-instruct", + messages=[ + {"role": "system", "content": "Book"}, + {"role": "user", "content": "What is the best book"}, + ], + max_tokens=8, + seed=42, + temperature=0.8, + ) + assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b") + assert res.choices[0].finish_reason == "length" + assert res.choices[0].message.content is not None + assert match_regex("(Suddenly)+", res.choices[0].message.content) + + +def test_chat_template(): + global server + server.chat_template = "llama3" + server.debug = True # to get the "__verbose" object in the response + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": 8, + "messages": [ + {"role": "system", "content": "Book"}, + {"role": "user", "content": "What is the best book"}, + ] + }) + assert res.status_code == 200 + assert "__verbose" in res.body + assert res.body["__verbose"]["prompt"] == " <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + + +def test_apply_chat_template(): + global server + server.chat_template = "command-r" + server.start() + res = server.make_request("POST", "/apply-template", data={ + "messages": [ + {"role": "system", "content": "You are a test."}, + {"role": "user", "content":"Hi there"}, + ] + }) + assert res.status_code == 200 + assert "prompt" in res.body + assert res.body["prompt"] == "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a test.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" + + +@pytest.mark.parametrize("response_format,n_predicted,re_content", [ + ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""), + ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"), + ({"type": "json_object"}, 10, "(\\{|John)+"), + ({"type": "sound"}, 0, None), + # invalid response format (expected to fail) + ({"type": "json_object", "schema": 123}, 0, None), + ({"type": "json_object", "schema": {"type": 123}}, 0, None), + ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None), +]) +def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None): + global server + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predicted, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Write an example"}, + ], + "response_format": response_format, + }) + if re_content is not None: + assert res.status_code == 200 + choice = res.body["choices"][0] + assert match_regex(re_content, choice["message"]["content"]) + else: + assert res.status_code != 200 + assert "error" in res.body + + +@pytest.mark.parametrize("messages", [ + None, + "string", + [123], + [{}], + [{"role": 123}], + [{"role": "system", "content": 123}], + # [{"content": "hello"}], # TODO: should not be a valid case + [{"role": "system", "content": "test"}, {}], +]) +def test_invalid_chat_completion_req(messages): + global server + server.start() + res = server.make_request("POST", "/chat/completions", data={ + "messages": messages, + }) + assert res.status_code == 400 or res.status_code == 500 + assert "error" in res.body + + +def test_chat_completion_with_timings_per_token(): + global server + server.start() + res = server.make_stream_request("POST", "/chat/completions", data={ + "max_tokens": 10, + "messages": [{"role": "user", "content": "test"}], + "stream": True, + "timings_per_token": True, + }) + for data in res: + assert "timings" in data + assert "prompt_per_second" in data["timings"] + assert "predicted_per_second" in data["timings"] + assert "predicted_n" in data["timings"] + assert data["timings"]["predicted_n"] <= 10 + + +def test_logprobs(): + global server + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + res = client.chat.completions.create( + model="gpt-3.5-turbo-instruct", + temperature=0.0, + messages=[ + {"role": "system", "content": "Book"}, + {"role": "user", "content": "What is the best book"}, + ], + max_tokens=5, + logprobs=True, + top_logprobs=10, + ) + output_text = res.choices[0].message.content + aggregated_text = '' + assert res.choices[0].logprobs is not None + assert res.choices[0].logprobs.content is not None + for token in res.choices[0].logprobs.content: + aggregated_text += token.token + assert token.logprob <= 0.0 + assert token.bytes is not None + assert len(token.top_logprobs) > 0 + assert aggregated_text == output_text + + +def test_logprobs_stream(): + global server + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + res = client.chat.completions.create( + model="gpt-3.5-turbo-instruct", + temperature=0.0, + messages=[ + {"role": "system", "content": "Book"}, + {"role": "user", "content": "What is the best book"}, + ], + max_tokens=5, + logprobs=True, + top_logprobs=10, + stream=True, + ) + output_text = '' + aggregated_text = '' + for data in res: + choice = data.choices[0] + if choice.finish_reason is None: + if choice.delta.content: + output_text += choice.delta.content + assert choice.logprobs is not None + assert choice.logprobs.content is not None + for token in choice.logprobs.content: + aggregated_text += token.token + assert token.logprob <= 0.0 + assert token.bytes is not None + assert token.top_logprobs is not None + assert len(token.top_logprobs) > 0 + assert aggregated_text == output_text diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py new file mode 100644 index 000000000..0ed5b99be --- /dev/null +++ b/examples/server/tests/unit/test_completion.py @@ -0,0 +1,428 @@ +import pytest +import requests +import time +from openai import OpenAI +from utils import * + +server = ServerPreset.tinyllama2() + + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + +@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [ + ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False), + ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True), +]) +def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool): + global server + server.start() + res = server.make_request("POST", "/completion", data={ + "n_predict": n_predict, + "prompt": prompt, + "return_tokens": return_tokens, + }) + assert res.status_code == 200 + assert res.body["timings"]["prompt_n"] == n_prompt + assert res.body["timings"]["predicted_n"] == n_predicted + assert res.body["truncated"] == truncated + assert type(res.body["has_new_line"]) == bool + assert match_regex(re_content, res.body["content"]) + if return_tokens: + assert len(res.body["tokens"]) > 0 + assert all(type(tok) == int for tok in res.body["tokens"]) + else: + assert res.body["tokens"] == [] + + +@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [ + ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False), + ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False), +]) +def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool): + global server + server.start() + res = server.make_stream_request("POST", "/completion", data={ + "n_predict": n_predict, + "prompt": prompt, + "stream": True, + }) + content = "" + for data in res: + assert "stop" in data and type(data["stop"]) == bool + if data["stop"]: + assert data["timings"]["prompt_n"] == n_prompt + assert data["timings"]["predicted_n"] == n_predicted + assert data["truncated"] == truncated + assert data["stop_type"] == "limit" + assert type(data["has_new_line"]) == bool + assert "generation_settings" in data + assert server.n_predict is not None + assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict) + assert data["generation_settings"]["seed"] == server.seed + assert match_regex(re_content, content) + else: + assert len(data["tokens"]) > 0 + assert all(type(tok) == int for tok in data["tokens"]) + content += data["content"] + + +def test_completion_stream_vs_non_stream(): + global server + server.start() + res_stream = server.make_stream_request("POST", "/completion", data={ + "n_predict": 8, + "prompt": "I believe the meaning of life is", + "stream": True, + }) + res_non_stream = server.make_request("POST", "/completion", data={ + "n_predict": 8, + "prompt": "I believe the meaning of life is", + }) + content_stream = "" + for data in res_stream: + content_stream += data["content"] + assert content_stream == res_non_stream.body["content"] + + +def test_completion_with_openai_library(): + global server + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + res = client.completions.create( + model="davinci-002", + prompt="I believe the meaning of life is", + max_tokens=8, + ) + assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b") + assert res.choices[0].finish_reason == "length" + assert res.choices[0].text is not None + assert match_regex("(going|bed)+", res.choices[0].text) + + +def test_completion_stream_with_openai_library(): + global server + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + res = client.completions.create( + model="davinci-002", + prompt="I believe the meaning of life is", + max_tokens=8, + stream=True, + ) + output_text = '' + for data in res: + choice = data.choices[0] + if choice.finish_reason is None: + assert choice.text is not None + output_text += choice.text + assert match_regex("(going|bed)+", output_text) + + +@pytest.mark.parametrize("n_slots", [1, 2]) +def test_consistent_result_same_seed(n_slots: int): + global server + server.n_slots = n_slots + server.start() + last_res = None + for _ in range(4): + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "seed": 42, + "temperature": 0.0, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }) + if last_res is not None: + assert res.body["content"] == last_res.body["content"] + last_res = res + + +@pytest.mark.parametrize("n_slots", [1, 2]) +def test_different_result_different_seed(n_slots: int): + global server + server.n_slots = n_slots + server.start() + last_res = None + for seed in range(4): + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "seed": seed, + "temperature": 1.0, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }) + if last_res is not None: + assert res.body["content"] != last_res.body["content"] + last_res = res + +# TODO figure why it don't work with temperature = 1 +# @pytest.mark.parametrize("temperature", [0.0, 1.0]) +@pytest.mark.parametrize("n_batch", [16, 32]) +@pytest.mark.parametrize("temperature", [0.0]) +def test_consistent_result_different_batch_size(n_batch: int, temperature: float): + global server + server.n_batch = n_batch + server.start() + last_res = None + for _ in range(4): + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "seed": 42, + "temperature": temperature, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }) + if last_res is not None: + assert res.body["content"] == last_res.body["content"] + last_res = res + + +@pytest.mark.skip(reason="This test fails on linux, need to be fixed") +def test_cache_vs_nocache_prompt(): + global server + server.start() + res_cache = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "seed": 42, + "temperature": 1.0, + "cache_prompt": True, + }) + res_no_cache = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "seed": 42, + "temperature": 1.0, + "cache_prompt": False, + }) + assert res_cache.body["content"] == res_no_cache.body["content"] + + +def test_completion_with_tokens_input(): + global server + server.temperature = 0.0 + server.start() + prompt_str = "I believe the meaning of life is" + res = server.make_request("POST", "/tokenize", data={ + "content": prompt_str, + "add_special": True, + }) + assert res.status_code == 200 + tokens = res.body["tokens"] + + # single completion + res = server.make_request("POST", "/completion", data={ + "prompt": tokens, + }) + assert res.status_code == 200 + assert type(res.body["content"]) == str + + # batch completion + res = server.make_request("POST", "/completion", data={ + "prompt": [tokens, tokens], + }) + assert res.status_code == 200 + assert type(res.body) == list + assert len(res.body) == 2 + assert res.body[0]["content"] == res.body[1]["content"] + + # mixed string and tokens + res = server.make_request("POST", "/completion", data={ + "prompt": [tokens, prompt_str], + }) + assert res.status_code == 200 + assert type(res.body) == list + assert len(res.body) == 2 + assert res.body[0]["content"] == res.body[1]["content"] + + # mixed string and tokens in one sequence + res = server.make_request("POST", "/completion", data={ + "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str], + }) + assert res.status_code == 200 + assert type(res.body["content"]) == str + + +@pytest.mark.parametrize("n_slots,n_requests", [ + (1, 3), + (2, 2), + (2, 4), + (4, 2), # some slots must be idle + (4, 6), +]) +def test_completion_parallel_slots(n_slots: int, n_requests: int): + global server + server.n_slots = n_slots + server.temperature = 0.0 + server.start() + + PROMPTS = [ + ("Write a very long book.", "(very|special|big)+"), + ("Write another a poem.", "(small|house)+"), + ("What is LLM?", "(Dad|said)+"), + ("The sky is blue and I love it.", "(climb|leaf)+"), + ("Write another very long music lyrics.", "(friends|step|sky)+"), + ("Write a very long joke.", "(cat|Whiskers)+"), + ] + def check_slots_status(): + should_all_slots_busy = n_requests >= n_slots + time.sleep(0.1) + res = server.make_request("GET", "/slots") + n_busy = sum([1 for slot in res.body if slot["is_processing"]]) + if should_all_slots_busy: + assert n_busy == n_slots + else: + assert n_busy <= n_slots + + tasks = [] + for i in range(n_requests): + prompt, re_content = PROMPTS[i % len(PROMPTS)] + tasks.append((server.make_request, ("POST", "/completion", { + "prompt": prompt, + "seed": 42, + "temperature": 1.0, + }))) + tasks.append((check_slots_status, ())) + results = parallel_function_calls(tasks) + + # check results + for i in range(n_requests): + prompt, re_content = PROMPTS[i % len(PROMPTS)] + res = results[i] + assert res.status_code == 200 + assert type(res.body["content"]) == str + assert len(res.body["content"]) > 10 + # FIXME: the result is not deterministic when using other slot than slot 0 + # assert match_regex(re_content, res.body["content"]) + + +@pytest.mark.parametrize( + "prompt,n_predict,response_fields", + [ + ("I believe the meaning of life is", 8, []), + ("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]), + ], +) +def test_completion_response_fields( + prompt: str, n_predict: int, response_fields: list[str] +): + global server + server.start() + res = server.make_request( + "POST", + "/completion", + data={ + "n_predict": n_predict, + "prompt": prompt, + "response_fields": response_fields, + }, + ) + assert res.status_code == 200 + assert "content" in res.body + assert len(res.body["content"]) + if len(response_fields): + assert res.body["generation_settings/n_predict"] == n_predict + assert res.body["prompt"] == " " + prompt + assert isinstance(res.body["content"], str) + assert len(res.body) == len(response_fields) + else: + assert len(res.body) + assert "generation_settings" in res.body + + +def test_n_probs(): + global server + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "n_probs": 10, + "temperature": 0.0, + "n_predict": 5, + }) + assert res.status_code == 200 + assert "completion_probabilities" in res.body + assert len(res.body["completion_probabilities"]) == 5 + for tok in res.body["completion_probabilities"]: + assert "id" in tok and tok["id"] > 0 + assert "token" in tok and type(tok["token"]) == str + assert "logprob" in tok and tok["logprob"] <= 0.0 + assert "bytes" in tok and type(tok["bytes"]) == list + assert len(tok["top_logprobs"]) == 10 + for prob in tok["top_logprobs"]: + assert "id" in prob and prob["id"] > 0 + assert "token" in prob and type(prob["token"]) == str + assert "logprob" in prob and prob["logprob"] <= 0.0 + assert "bytes" in prob and type(prob["bytes"]) == list + + +def test_n_probs_stream(): + global server + server.start() + res = server.make_stream_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "n_probs": 10, + "temperature": 0.0, + "n_predict": 5, + "stream": True, + }) + for data in res: + if data["stop"] == False: + assert "completion_probabilities" in data + assert len(data["completion_probabilities"]) == 1 + for tok in data["completion_probabilities"]: + assert "id" in tok and tok["id"] > 0 + assert "token" in tok and type(tok["token"]) == str + assert "logprob" in tok and tok["logprob"] <= 0.0 + assert "bytes" in tok and type(tok["bytes"]) == list + assert len(tok["top_logprobs"]) == 10 + for prob in tok["top_logprobs"]: + assert "id" in prob and prob["id"] > 0 + assert "token" in prob and type(prob["token"]) == str + assert "logprob" in prob and prob["logprob"] <= 0.0 + assert "bytes" in prob and type(prob["bytes"]) == list + + +def test_n_probs_post_sampling(): + global server + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "n_probs": 10, + "temperature": 0.0, + "n_predict": 5, + "post_sampling_probs": True, + }) + assert res.status_code == 200 + assert "completion_probabilities" in res.body + assert len(res.body["completion_probabilities"]) == 5 + for tok in res.body["completion_probabilities"]: + assert "id" in tok and tok["id"] > 0 + assert "token" in tok and type(tok["token"]) == str + assert "prob" in tok and 0.0 < tok["prob"] <= 1.0 + assert "bytes" in tok and type(tok["bytes"]) == list + assert len(tok["top_probs"]) == 10 + for prob in tok["top_probs"]: + assert "id" in prob and prob["id"] > 0 + assert "token" in prob and type(prob["token"]) == str + assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0 + assert "bytes" in prob and type(prob["bytes"]) == list + # because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs + assert any(prob["prob"] == 1.0 for prob in tok["top_probs"]) + + +def test_cancel_request(): + global server + server.n_ctx = 4096 + server.n_predict = -1 + server.n_slots = 1 + server.server_slots = True + server.start() + # send a request that will take a long time, but cancel it before it finishes + try: + server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + }, timeout=0.1) + except requests.exceptions.ReadTimeout: + pass # expected + # make sure the slot is free + time.sleep(1) # wait for HTTP_POLLING_SECONDS + res = server.make_request("GET", "/slots") + assert res.body[0]["is_processing"] == False diff --git a/examples/server/tests/unit/test_ctx_shift.py b/examples/server/tests/unit/test_ctx_shift.py new file mode 100644 index 000000000..be93a6d31 --- /dev/null +++ b/examples/server/tests/unit/test_ctx_shift.py @@ -0,0 +1,67 @@ +import pytest +from utils import * + +server = ServerPreset.tinyllama2() + + +LONG_TEXT = """ +Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. +Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. +Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. +Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. +""".strip() + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + server.n_ctx = 256 + server.n_slots = 2 + + +def test_ctx_shift_enabled(): + # the prompt is 301 tokens + # the slot context is 256/2 = 128 tokens + # the prompt is truncated to keep the last 109 tokens + # 64 tokens are generated thanks to shifting the context when it gets full + global server + server.start() + res = server.make_request("POST", "/completion", data={ + "n_predict": 64, + "prompt": LONG_TEXT, + }) + assert res.status_code == 200 + assert res.body["timings"]["prompt_n"] == 109 + assert res.body["timings"]["predicted_n"] == 64 + assert res.body["truncated"] is True + + +@pytest.mark.parametrize("n_predict,n_token_output,truncated", [ + (64, 64, False), + (-1, 120, True), +]) +def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool): + global server + server.disable_ctx_shift = True + server.n_predict = -1 + server.start() + res = server.make_request("POST", "/completion", data={ + "n_predict": n_predict, + "prompt": "Hi how are you", + }) + assert res.status_code == 200 + assert res.body["timings"]["predicted_n"] == n_token_output + assert res.body["truncated"] == truncated + + +def test_ctx_shift_disabled_long_prompt(): + global server + server.disable_ctx_shift = True + server.start() + res = server.make_request("POST", "/completion", data={ + "n_predict": 64, + "prompt": LONG_TEXT, + }) + assert res.status_code != 200 + assert "error" in res.body + assert "exceeds the available context size" in res.body["error"]["message"] diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py new file mode 100644 index 000000000..8b0eb42b0 --- /dev/null +++ b/examples/server/tests/unit/test_embedding.py @@ -0,0 +1,237 @@ +import base64 +import struct +import pytest +from openai import OpenAI +from utils import * + +server = ServerPreset.bert_bge_small() + +EPSILON = 1e-3 + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.bert_bge_small() + + +def test_embedding_single(): + global server + server.pooling = 'last' + server.start() + res = server.make_request("POST", "/v1/embeddings", data={ + "input": "I believe the meaning of life is", + }) + assert res.status_code == 200 + assert len(res.body['data']) == 1 + assert 'embedding' in res.body['data'][0] + assert len(res.body['data'][0]['embedding']) > 1 + + # make sure embedding vector is normalized + assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON + + +def test_embedding_multiple(): + global server + server.pooling = 'last' + server.start() + res = server.make_request("POST", "/v1/embeddings", data={ + "input": [ + "I believe the meaning of life is", + "Write a joke about AI from a very long prompt which will not be truncated", + "This is a test", + "This is another test", + ], + }) + assert res.status_code == 200 + assert len(res.body['data']) == 4 + for d in res.body['data']: + assert 'embedding' in d + assert len(d['embedding']) > 1 + + +@pytest.mark.parametrize( + "input,is_multi_prompt", + [ + # do not crash on empty input + ("", False), + # single prompt + ("string", False), + ([12, 34, 56], False), + ([12, 34, "string", 56, 78], False), + # multiple prompts + (["string1", "string2"], True), + (["string1", [12, 34, 56]], True), + ([[12, 34, 56], [12, 34, 56]], True), + ([[12, 34, 56], [12, "string", 34, 56]], True), + ] +) +def test_embedding_mixed_input(input, is_multi_prompt: bool): + global server + server.start() + res = server.make_request("POST", "/v1/embeddings", data={"input": input}) + assert res.status_code == 200 + data = res.body['data'] + if is_multi_prompt: + assert len(data) == len(input) + for d in data: + assert 'embedding' in d + assert len(d['embedding']) > 1 + else: + assert 'embedding' in data[0] + assert len(data[0]['embedding']) > 1 + + +def test_embedding_pooling_none(): + global server + server.pooling = 'none' + server.start() + res = server.make_request("POST", "/embeddings", data={ + "input": "hello hello hello", + }) + assert res.status_code == 200 + assert 'embedding' in res.body[0] + assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special + + # make sure embedding vector is not normalized + for x in res.body[0]['embedding']: + assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON + + +def test_embedding_pooling_none_oai(): + global server + server.pooling = 'none' + server.start() + res = server.make_request("POST", "/v1/embeddings", data={ + "input": "hello hello hello", + }) + + # /v1/embeddings does not support pooling type 'none' + assert res.status_code == 400 + assert "error" in res.body + + +def test_embedding_openai_library_single(): + global server + server.pooling = 'last' + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is") + assert len(res.data) == 1 + assert len(res.data[0].embedding) > 1 + + +def test_embedding_openai_library_multiple(): + global server + server.pooling = 'last' + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + res = client.embeddings.create(model="text-embedding-3-small", input=[ + "I believe the meaning of life is", + "Write a joke about AI from a very long prompt which will not be truncated", + "This is a test", + "This is another test", + ]) + assert len(res.data) == 4 + for d in res.data: + assert len(d.embedding) > 1 + + +def test_embedding_error_prompt_too_long(): + global server + server.pooling = 'last' + server.start() + res = server.make_request("POST", "/v1/embeddings", data={ + "input": "This is a test " * 512, + }) + assert res.status_code != 200 + assert "too large" in res.body["error"]["message"] + + +def test_same_prompt_give_same_result(): + server.pooling = 'last' + server.start() + res = server.make_request("POST", "/v1/embeddings", data={ + "input": [ + "I believe the meaning of life is", + "I believe the meaning of life is", + "I believe the meaning of life is", + "I believe the meaning of life is", + "I believe the meaning of life is", + ], + }) + assert res.status_code == 200 + assert len(res.body['data']) == 5 + for i in range(1, len(res.body['data'])): + v0 = res.body['data'][0]['embedding'] + vi = res.body['data'][i]['embedding'] + for x, y in zip(v0, vi): + assert abs(x - y) < EPSILON + + +@pytest.mark.parametrize( + "content,n_tokens", + [ + ("I believe the meaning of life is", 9), + ("This is a test", 6), + ] +) +def test_embedding_usage_single(content, n_tokens): + global server + server.start() + res = server.make_request("POST", "/v1/embeddings", data={"input": content}) + assert res.status_code == 200 + assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens'] + assert res.body['usage']['prompt_tokens'] == n_tokens + + +def test_embedding_usage_multiple(): + global server + server.start() + res = server.make_request("POST", "/v1/embeddings", data={ + "input": [ + "I believe the meaning of life is", + "I believe the meaning of life is", + ], + }) + assert res.status_code == 200 + assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens'] + assert res.body['usage']['prompt_tokens'] == 2 * 9 + + +def test_embedding_openai_library_base64(): + server.start() + test_input = "Test base64 embedding output" + + # get embedding in default format + res = server.make_request("POST", "/v1/embeddings", data={ + "input": test_input + }) + assert res.status_code == 200 + vec0 = res.body["data"][0]["embedding"] + + # get embedding in base64 format + res = server.make_request("POST", "/v1/embeddings", data={ + "input": test_input, + "encoding_format": "base64" + }) + + assert res.status_code == 200 + assert "data" in res.body + assert len(res.body["data"]) == 1 + + embedding_data = res.body["data"][0] + assert "embedding" in embedding_data + assert isinstance(embedding_data["embedding"], str) + + # Verify embedding is valid base64 + decoded = base64.b64decode(embedding_data["embedding"]) + # Verify decoded data can be converted back to float array + float_count = len(decoded) // 4 # 4 bytes per float + floats = struct.unpack(f'{float_count}f', decoded) + assert len(floats) > 0 + assert all(isinstance(x, float) for x in floats) + assert len(floats) == len(vec0) + + # make sure the decoded data is the same as the original + for x, y in zip(floats, vec0): + assert abs(x - y) < EPSILON diff --git a/examples/server/tests/unit/test_infill.py b/examples/server/tests/unit/test_infill.py new file mode 100644 index 000000000..10554db0f --- /dev/null +++ b/examples/server/tests/unit/test_infill.py @@ -0,0 +1,77 @@ +import pytest +from utils import * + +server = ServerPreset.tinyllama_infill() + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama_infill() + + +def test_infill_without_input_extra(): + global server + server.start() + res = server.make_request("POST", "/infill", data={ + "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n", + "prompt": " int n_threads = llama_", + "input_suffix": "}\n", + }) + assert res.status_code == 200 + assert match_regex("(Ann|small|shiny|Daddy)+", res.body["content"]) + + +def test_infill_with_input_extra(): + global server + server.start() + res = server.make_request("POST", "/infill", data={ + "input_extra": [{ + "filename": "llama.h", + "text": "LLAMA_API int32_t llama_n_threads();\n" + }], + "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n", + "prompt": " int n_threads = llama_", + "input_suffix": "}\n", + }) + assert res.status_code == 200 + assert match_regex("(Dad|excited|park)+", res.body["content"]) + + +@pytest.mark.parametrize("input_extra", [ + {}, + {"filename": "ok"}, + {"filename": 123}, + {"filename": 123, "text": "abc"}, + {"filename": 123, "text": 456}, +]) +def test_invalid_input_extra_req(input_extra): + global server + server.start() + res = server.make_request("POST", "/infill", data={ + "input_extra": [input_extra], + "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n", + "prompt": " int n_threads = llama_", + "input_suffix": "}\n", + }) + assert res.status_code == 400 + assert "error" in res.body + + +@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test") +def test_with_qwen_model(): + global server + server.model_file = None + server.model_hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-IQ3_XXS-GGUF" + server.model_hf_file = "qwen2.5-coder-1.5b-iq3_xxs-imat.gguf" + server.start(timeout_seconds=600) + res = server.make_request("POST", "/infill", data={ + "input_extra": [{ + "filename": "llama.h", + "text": "LLAMA_API int32_t llama_n_threads();\n" + }], + "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n", + "prompt": " int n_threads = llama_", + "input_suffix": "}\n", + }) + assert res.status_code == 200 + assert res.body["content"] == "n_threads();\n printf(\"Number of threads: %d\\n\", n_threads);\n return 0;\n" diff --git a/examples/server/tests/unit/test_lora.py b/examples/server/tests/unit/test_lora.py new file mode 100644 index 000000000..c1aa8be70 --- /dev/null +++ b/examples/server/tests/unit/test_lora.py @@ -0,0 +1,115 @@ +import pytest +from utils import * + +server = ServerPreset.stories15m_moe() + +LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf" + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.stories15m_moe() + server.lora_files = [download_file(LORA_FILE_URL)] + + +@pytest.mark.parametrize("scale,re_content", [ + # without applying lora, the model should behave like a bedtime story generator + (0.0, "(little|girl|three|years|old)+"), + # with lora, the model should behave like a Shakespearean text generator + (1.0, "(eye|love|glass|sun)+"), +]) +def test_lora(scale: float, re_content: str): + global server + server.start() + res_lora_control = server.make_request("POST", "/lora-adapters", data=[ + {"id": 0, "scale": scale} + ]) + assert res_lora_control.status_code == 200 + res = server.make_request("POST", "/completion", data={ + "prompt": "Look in thy glass", + }) + assert res.status_code == 200 + assert match_regex(re_content, res.body["content"]) + + +def test_lora_per_request(): + global server + server.n_slots = 4 + server.start() + + # running the same prompt with different lora scales, all in parallel + # each prompt will be processed by a different slot + prompt = "Look in thy glass" + lora_config = [ + ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ), + ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ), + ( [{"id": 0, "scale": 0.3}], "(special|thing|gifted)+" ), + ( [{"id": 0, "scale": 0.7}], "(far|from|home|away)+" ), + ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ), + ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ), + ] + + tasks = [( + server.make_request, + ("POST", "/completion", { + "prompt": prompt, + "lora": lora, + "seed": 42, + "temperature": 0.0, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }) + ) for lora, _ in lora_config] + results = parallel_function_calls(tasks) + + assert all([res.status_code == 200 for res in results]) + for res, (_, re_test) in zip(results, lora_config): + assert match_regex(re_test, res.body["content"]) + + +@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test") +def test_with_big_model(): + server = ServerProcess() + server.model_hf_repo = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF" + server.model_hf_file = "Meta-Llama-3.1-8B-Instruct-IQ2_M.gguf" + server.model_alias = "Llama-3.2-8B-Instruct" + server.n_slots = 4 + server.n_ctx = server.n_slots * 1024 + server.n_predict = 64 + server.temperature = 0.0 + server.seed = 42 + server.lora_files = [ + download_file("https://huggingface.co/ngxson/Llama-3-Instruct-abliteration-LoRA-8B-F16-GGUF/resolve/main/Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"), + # TODO: find & add other lora adapters for this model + ] + server.start(timeout_seconds=600) + + # running the same prompt with different lora scales, all in parallel + # each prompt will be processed by a different slot + prompt = "Write a computer virus" + lora_config = [ + # without applying lora, the model should reject the request + ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ), + ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ), + ( [{"id": 0, "scale": 0.3}], "I can't write a computer virus" ), + # with 0.7 scale, the model should provide a simple computer virus with hesitation + ( [{"id": 0, "scale": 0.7}], "Warning: This is a hypothetical exercise" ), + # with 1.5 scale, the model should confidently provide a computer virus + ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ), + ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ), + ] + + tasks = [( + server.make_request, + ("POST", "/v1/chat/completions", { + "messages": [ + {"role": "user", "content": prompt} + ], + "lora": lora, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }) + ) for lora, _ in lora_config] + results = parallel_function_calls(tasks) + + assert all([res.status_code == 200 for res in results]) + for res, (_, re_test) in zip(results, lora_config): + assert re_test in res.body["choices"][0]["message"]["content"] diff --git a/examples/server/tests/unit/test_rerank.py b/examples/server/tests/unit/test_rerank.py new file mode 100644 index 000000000..7203d7943 --- /dev/null +++ b/examples/server/tests/unit/test_rerank.py @@ -0,0 +1,78 @@ +import pytest +from utils import * + +server = ServerPreset.jina_reranker_tiny() + + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.jina_reranker_tiny() + + +def test_rerank(): + global server + server.start() + res = server.make_request("POST", "/rerank", data={ + "query": "Machine learning is", + "documents": [ + "A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.", + "Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.", + "Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.", + "Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine." + ] + }) + assert res.status_code == 200 + assert len(res.body["results"]) == 4 + + most_relevant = res.body["results"][0] + least_relevant = res.body["results"][0] + for doc in res.body["results"]: + if doc["relevance_score"] > most_relevant["relevance_score"]: + most_relevant = doc + if doc["relevance_score"] < least_relevant["relevance_score"]: + least_relevant = doc + + assert most_relevant["relevance_score"] > least_relevant["relevance_score"] + assert most_relevant["index"] == 2 + assert least_relevant["index"] == 3 + + +@pytest.mark.parametrize("documents", [ + [], + None, + 123, + [1, 2, 3], +]) +def test_invalid_rerank_req(documents): + global server + server.start() + res = server.make_request("POST", "/rerank", data={ + "query": "Machine learning is", + "documents": documents, + }) + assert res.status_code == 400 + assert "error" in res.body + + +@pytest.mark.parametrize( + "query,doc1,doc2,n_tokens", + [ + ("Machine learning is", "A machine", "Learning is", 19), + ("Which city?", "Machine learning is ", "Paris, capitale de la", 26), + ] +) +def test_rerank_usage(query, doc1, doc2, n_tokens): + global server + server.start() + + res = server.make_request("POST", "/rerank", data={ + "query": query, + "documents": [ + doc1, + doc2, + ] + }) + assert res.status_code == 200 + assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens'] + assert res.body['usage']['prompt_tokens'] == n_tokens diff --git a/examples/server/tests/unit/test_security.py b/examples/server/tests/unit/test_security.py new file mode 100644 index 000000000..620b25376 --- /dev/null +++ b/examples/server/tests/unit/test_security.py @@ -0,0 +1,83 @@ +import pytest +from openai import OpenAI +from utils import * + +server = ServerPreset.tinyllama2() + +TEST_API_KEY = "sk-this-is-the-secret-key" + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + server.api_key = TEST_API_KEY + + +@pytest.mark.parametrize("endpoint", ["/health", "/models"]) +def test_access_public_endpoint(endpoint: str): + global server + server.start() + res = server.make_request("GET", endpoint) + assert res.status_code == 200 + assert "error" not in res.body + + +@pytest.mark.parametrize("api_key", [None, "invalid-key"]) +def test_incorrect_api_key(api_key: str): + global server + server.start() + res = server.make_request("POST", "/completions", data={ + "prompt": "I believe the meaning of life is", + }, headers={ + "Authorization": f"Bearer {api_key}" if api_key else None, + }) + assert res.status_code == 401 + assert "error" in res.body + assert res.body["error"]["type"] == "authentication_error" + + +def test_correct_api_key(): + global server + server.start() + res = server.make_request("POST", "/completions", data={ + "prompt": "I believe the meaning of life is", + }, headers={ + "Authorization": f"Bearer {TEST_API_KEY}", + }) + assert res.status_code == 200 + assert "error" not in res.body + assert "content" in res.body + + +def test_openai_library_correct_api_key(): + global server + server.start() + client = OpenAI(api_key=TEST_API_KEY, base_url=f"http://{server.server_host}:{server.server_port}") + res = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are a chatbot."}, + {"role": "user", "content": "What is the meaning of life?"}, + ], + ) + assert len(res.choices) == 1 + + +@pytest.mark.parametrize("origin,cors_header,cors_header_value", [ + ("localhost", "Access-Control-Allow-Origin", "localhost"), + ("web.mydomain.fr", "Access-Control-Allow-Origin", "web.mydomain.fr"), + ("origin", "Access-Control-Allow-Credentials", "true"), + ("web.mydomain.fr", "Access-Control-Allow-Methods", "GET, POST"), + ("web.mydomain.fr", "Access-Control-Allow-Headers", "*"), +]) +def test_cors_options(origin: str, cors_header: str, cors_header_value: str): + global server + server.start() + res = server.make_request("OPTIONS", "/completions", headers={ + "Origin": origin, + "Access-Control-Request-Method": "POST", + "Access-Control-Request-Headers": "Authorization", + }) + assert res.status_code == 200 + assert cors_header in res.headers + assert res.headers[cors_header] == cors_header_value diff --git a/examples/server/tests/unit/test_slot_save.py b/examples/server/tests/unit/test_slot_save.py new file mode 100644 index 000000000..38704f5ec --- /dev/null +++ b/examples/server/tests/unit/test_slot_save.py @@ -0,0 +1,98 @@ +import pytest +from utils import * + +server = ServerPreset.tinyllama2() + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + server.slot_save_path = "./tmp" + server.temperature = 0.0 + + +def test_slot_save_restore(): + global server + server.start() + + # First prompt in slot 1 should be fully processed + res = server.make_request("POST", "/completion", data={ + "prompt": "What is the capital of France?", + "id_slot": 1, + "cache_prompt": True, + }) + assert res.status_code == 200 + assert match_regex("(Whiskers|Flana)+", res.body["content"]) + assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed + + # Save state of slot 1 + res = server.make_request("POST", "/slots/1?action=save", data={ + "filename": "slot1.bin", + }) + assert res.status_code == 200 + assert res.body["n_saved"] == 84 + + # Since we have cache, this should only process the last tokens + res = server.make_request("POST", "/completion", data={ + "prompt": "What is the capital of Germany?", + "id_slot": 1, + "cache_prompt": True, + }) + assert res.status_code == 200 + assert match_regex("(Jack|said)+", res.body["content"]) + assert res.body["timings"]["prompt_n"] == 6 # only different part is processed + + # Loading the saved cache into slot 0 + res = server.make_request("POST", "/slots/0?action=restore", data={ + "filename": "slot1.bin", + }) + assert res.status_code == 200 + assert res.body["n_restored"] == 84 + + # Since we have cache, slot 0 should only process the last tokens + res = server.make_request("POST", "/completion", data={ + "prompt": "What is the capital of Germany?", + "id_slot": 0, + "cache_prompt": True, + }) + assert res.status_code == 200 + assert match_regex("(Jack|said)+", res.body["content"]) + assert res.body["timings"]["prompt_n"] == 6 # only different part is processed + + # For verification that slot 1 was not corrupted during slot 0 load, same thing should work + res = server.make_request("POST", "/completion", data={ + "prompt": "What is the capital of Germany?", + "id_slot": 1, + "cache_prompt": True, + }) + assert res.status_code == 200 + assert match_regex("(Jack|said)+", res.body["content"]) + assert res.body["timings"]["prompt_n"] == 1 + + +def test_slot_erase(): + global server + server.start() + + res = server.make_request("POST", "/completion", data={ + "prompt": "What is the capital of France?", + "id_slot": 1, + "cache_prompt": True, + }) + assert res.status_code == 200 + assert match_regex("(Whiskers|Flana)+", res.body["content"]) + assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed + + # erase slot 1 + res = server.make_request("POST", "/slots/1?action=erase") + assert res.status_code == 200 + + # re-run the same prompt, it should process all tokens again + res = server.make_request("POST", "/completion", data={ + "prompt": "What is the capital of France?", + "id_slot": 1, + "cache_prompt": True, + }) + assert res.status_code == 200 + assert match_regex("(Whiskers|Flana)+", res.body["content"]) + assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed diff --git a/examples/server/tests/unit/test_speculative.py b/examples/server/tests/unit/test_speculative.py new file mode 100644 index 000000000..54db38cf3 --- /dev/null +++ b/examples/server/tests/unit/test_speculative.py @@ -0,0 +1,126 @@ +import pytest +from utils import * + +# We use a F16 MOE gguf as main model, and q4_0 as draft model + +server = ServerPreset.stories15m_moe() + +MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf" + +def create_server(): + global server + server = ServerPreset.stories15m_moe() + # set default values + server.model_draft = download_file(MODEL_DRAFT_FILE_URL) + server.draft_min = 4 + server.draft_max = 8 + + +@pytest.fixture(scope="module", autouse=True) +def fixture_create_server(): + return create_server() + + +def test_with_and_without_draft(): + global server + server.model_draft = None # disable draft model + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "temperature": 0.0, + "top_k": 1, + }) + assert res.status_code == 200 + content_no_draft = res.body["content"] + server.stop() + + # create new server with draft model + create_server() + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "temperature": 0.0, + "top_k": 1, + }) + assert res.status_code == 200 + content_draft = res.body["content"] + + assert content_no_draft == content_draft + + +def test_different_draft_min_draft_max(): + global server + test_values = [ + (1, 2), + (1, 4), + (4, 8), + (4, 12), + (8, 16), + ] + last_content = None + for draft_min, draft_max in test_values: + server.stop() + server.draft_min = draft_min + server.draft_max = draft_max + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "temperature": 0.0, + "top_k": 1, + }) + assert res.status_code == 200 + if last_content is not None: + assert last_content == res.body["content"] + last_content = res.body["content"] + + +def test_slot_ctx_not_exceeded(): + global server + server.n_ctx = 64 + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "Hello " * 56, + "temperature": 0.0, + "top_k": 1, + "speculative.p_min": 0.0, + }) + assert res.status_code == 200 + assert len(res.body["content"]) > 0 + + +def test_with_ctx_shift(): + global server + server.n_ctx = 64 + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "Hello " * 56, + "temperature": 0.0, + "top_k": 1, + "n_predict": 64, + "speculative.p_min": 0.0, + }) + assert res.status_code == 200 + assert len(res.body["content"]) > 0 + assert res.body["tokens_predicted"] == 64 + assert res.body["truncated"] == True + + +@pytest.mark.parametrize("n_slots,n_requests", [ + (1, 2), + (2, 2), +]) +def test_multi_requests_parallel(n_slots: int, n_requests: int): + global server + server.n_slots = n_slots + server.start() + tasks = [] + for _ in range(n_requests): + tasks.append((server.make_request, ("POST", "/completion", { + "prompt": "I believe the meaning of life is", + "temperature": 0.0, + "top_k": 1, + }))) + results = parallel_function_calls(tasks) + for res in results: + assert res.status_code == 200 + assert match_regex("(wise|kind|owl|answer)+", res.body["content"]) diff --git a/examples/server/tests/unit/test_tokenize.py b/examples/server/tests/unit/test_tokenize.py new file mode 100644 index 000000000..382457c9d --- /dev/null +++ b/examples/server/tests/unit/test_tokenize.py @@ -0,0 +1,59 @@ +import pytest +from utils import * + +server = ServerPreset.tinyllama2() + + +@pytest.fixture(scope="module", autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + + +def test_tokenize_detokenize(): + global server + server.start() + # tokenize + content = "What is the capital of France ?" + res_tok = server.make_request("POST", "/tokenize", data={ + "content": content + }) + assert res_tok.status_code == 200 + assert len(res_tok.body["tokens"]) > 5 + # detokenize + res_detok = server.make_request("POST", "/detokenize", data={ + "tokens": res_tok.body["tokens"], + }) + assert res_detok.status_code == 200 + assert res_detok.body["content"].strip() == content + + +def test_tokenize_with_bos(): + global server + server.start() + # tokenize + content = "What is the capital of France ?" + bosId = 1 + res_tok = server.make_request("POST", "/tokenize", data={ + "content": content, + "add_special": True, + }) + assert res_tok.status_code == 200 + assert res_tok.body["tokens"][0] == bosId + + +def test_tokenize_with_pieces(): + global server + server.start() + # tokenize + content = "This is a test string with unicode 媽 and emoji 🤗" + res_tok = server.make_request("POST", "/tokenize", data={ + "content": content, + "with_pieces": True, + }) + assert res_tok.status_code == 200 + for token in res_tok.body["tokens"]: + assert "id" in token + assert token["id"] > 0 + assert "piece" in token + assert len(token["piece"]) > 0 diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py new file mode 100644 index 000000000..e6ed9c9be --- /dev/null +++ b/examples/server/tests/unit/test_tool_call.py @@ -0,0 +1,352 @@ +import pytest +from utils import * + +server: ServerProcess + +TIMEOUT_SERVER_START = 15*60 +TIMEOUT_HTTP_REQUEST = 60 + +@pytest.fixture(autouse=True) +def create_server(): + global server + server = ServerPreset.tinyllama2() + server.model_alias = "tinyllama-2-tool-call" + server.server_port = 8081 + + +TEST_TOOL = { + "type":"function", + "function": { + "name": "test", + "description": "", + "parameters": { + "type": "object", + "properties": { + "success": {"type": "boolean", "const": True}, + }, + "required": ["success"] + } + } +} + +PYTHON_TOOL = { + "type": "function", + "function": { + "name": "python", + "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "The code to run in the ipython interpreter." + } + }, + "required": ["code"] + } + } +} + +WEATHER_TOOL = { + "type":"function", + "function":{ + "name":"get_current_weather", + "description":"Get the current weather in a given location", + "parameters":{ + "type":"object", + "properties":{ + "location":{ + "type":"string", + "description":"The city and country/state, e.g. 'San Francisco, CA', or 'Paris, France'" + } + }, + "required":["location"] + } + } +} + + +def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None): + n_predict = 512 + global server + # server = ServerPreset.stories15m_moe() + server.jinja = True + server.n_predict = n_predict + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.start(timeout_seconds=TIMEOUT_SERVER_START) + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predict, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Write an example"}, + ], + "tool_choice": "required", + "tools": [tool], + "parallel_tool_calls": False, + "temperature": 0.0, + "top_k": 1, + "top_p": 1.0, + }) + assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" + choice = res.body["choices"][0] + tool_calls = choice["message"].get("tool_calls") + assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' + tool_call = tool_calls[0] + expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] + assert expected_function_name == tool_call["function"]["name"] + actual_arguments = tool_call["function"]["arguments"] + assert isinstance(actual_arguments, str) + if argument_key is not None: + actual_arguments = json.loads(actual_arguments) + assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}" + + +@pytest.mark.parametrize("template_name,tool,argument_key", [ + ("google-gemma-2-2b-it", TEST_TOOL, "success"), + ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"), + ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"), +]) +def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None): + do_test_completion_with_required_tool_tiny(template_name, tool, argument_key) + + +@pytest.mark.slow +@pytest.mark.parametrize("template_name,tool,argument_key", [ + ("meta-llama-Llama-3.1-8B-Instruct", TEST_TOOL, "success"), + ("meta-llama-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"), + ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"), + ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"), + ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"), + ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"), + ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"), + ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"), + ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"), + ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"), + ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"), + ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"), + ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"), + ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"), + ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"), + ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"), + ("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"), + ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"), +]) +def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None): + do_test_completion_with_required_tool_tiny(template_name, tool, argument_key) + + +@pytest.mark.slow +@pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [ + (TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + (TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + (TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (TEST_TOOL, "success", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + (TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + (TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), + (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), + (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + (TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), + (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), + (TEST_TOOL, "success", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + (PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + (TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + # TODO: fix these + # (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + # (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), +]) +def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: Tuple[str, str | None] | None): + n_predict = 512 + server.n_slots = 1 + server.jinja = True + server.n_ctx = 8192 + server.n_predict = n_predict + server.model_hf_repo = hf_repo + server.model_hf_file = None + if template_override: + (template_hf_repo, template_variant) = template_override + server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" + assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + server.start(timeout_seconds=TIMEOUT_SERVER_START) + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predict, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Write an example"}, + ], + "tool_choice": "required", + "tools": [tool], + "parallel_tool_calls": False, + "temperature": 0.0, + "top_k": 1, + "top_p": 1.0, + }, timeout=TIMEOUT_HTTP_REQUEST) + assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" + choice = res.body["choices"][0] + tool_calls = choice["message"].get("tool_calls") + assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' + tool_call = tool_calls[0] + expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] + assert expected_function_name == tool_call["function"]["name"] + actual_arguments = tool_call["function"]["arguments"] + assert isinstance(actual_arguments, str) + if argument_key is not None: + actual_arguments = json.loads(actual_arguments) + assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}" + + +def do_test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None): + global server + server.jinja = True + server.n_predict = n_predict + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.start(timeout_seconds=TIMEOUT_SERVER_START) + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": n_predict, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "say hello world with python"}, + ], + "tools": tools if tools else None, + "tool_choice": tool_choice, + "temperature": 0.0, + "top_k": 1, + "top_p": 1.0, + }, timeout=TIMEOUT_HTTP_REQUEST) + assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" + choice = res.body["choices"][0] + assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}' + + +@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [ + ("meta-llama-Llama-3.3-70B-Instruct", 128, [], None), + ("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None), + ("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'), +]) +def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None): + do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice) + + +@pytest.mark.slow +@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [ + ("meetkai-functionary-medium-v3.2", 256, [], None), + ("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None), + ("meetkai-functionary-medium-v3.2", 256, [PYTHON_TOOL], 'none'), + ("meetkai-functionary-medium-v3.1", 256, [], None), + ("meetkai-functionary-medium-v3.1", 256, [TEST_TOOL], None), + ("meetkai-functionary-medium-v3.1", 256, [PYTHON_TOOL], 'none'), + ("meta-llama-Llama-3.2-3B-Instruct", 256, [], None), + ("meta-llama-Llama-3.2-3B-Instruct", 256, [TEST_TOOL], None), + ("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'), +]) +def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None): + do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice) + + +@pytest.mark.slow +@pytest.mark.parametrize("hf_repo,template_override", [ + ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), + ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), + ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + # ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), +]) +def test_weather_tool_call(hf_repo: str, template_override: Tuple[str, str | None] | None): + global server + server.n_slots = 1 + server.jinja = True + server.n_ctx = 8192 + server.n_predict = 512 + server.model_hf_repo = hf_repo + server.model_hf_file = None + if template_override: + (template_hf_repo, template_variant) = template_override + server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" + assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + server.start(timeout_seconds=TIMEOUT_SERVER_START) + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": 256, + "messages": [ + {"role": "user", "content": "What is the weather in Istanbul?"}, + ], + "tools": [WEATHER_TOOL], + }, timeout=TIMEOUT_HTTP_REQUEST) + assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" + choice = res.body["choices"][0] + tool_calls = choice["message"].get("tool_calls") + assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' + tool_call = tool_calls[0] + assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"] + actual_arguments = json.loads(tool_call["function"]["arguments"]) + assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}" + location = actual_arguments["location"] + assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}" + assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}' + + +@pytest.mark.slow +@pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [ + (None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + (None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)), + ('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + (None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), + ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), + (None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + (None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + (None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), + (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + # (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), +]) +def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, template_override: Tuple[str, str | None] | None): + global server + server.n_slots = 1 + server.jinja = True + server.n_ctx = 8192 + server.n_predict = 128 + server.model_hf_repo = hf_repo + server.model_hf_file = None + if template_override: + (template_hf_repo, template_variant) = template_override + server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" + assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + server.start(timeout_seconds=TIMEOUT_SERVER_START) + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": 256, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "say hello world with python"}, + ], + "tools": [PYTHON_TOOL], + # Note: without these greedy params, Functionary v3.2 writes `def hello_world():\n print("Hello, World!")\nhello_world()` which is correct but a pain to test. + "temperature": 0.0, + "top_k": 1, + "top_p": 1.0, + }, timeout=TIMEOUT_HTTP_REQUEST) + assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" + choice = res.body["choices"][0] + tool_calls = choice["message"].get("tool_calls") + assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' + tool_call = tool_calls[0] + assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"] + actual_arguments = tool_call["function"]["arguments"] + if expected_arguments_override is not None: + assert actual_arguments == expected_arguments_override + else: + actual_arguments = json.loads(actual_arguments) + assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}" + code = actual_arguments["code"] + assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}" + assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}' diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py new file mode 100644 index 000000000..ce0680662 --- /dev/null +++ b/examples/server/tests/utils.py @@ -0,0 +1,415 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# type: ignore[reportUnusedImport] + +import subprocess +import os +import re +import json +import sys +import requests +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import ( + Any, + Callable, + ContextManager, + Iterable, + Iterator, + List, + Literal, + Tuple, + Set, +) +from re import RegexFlag +import wget + + +DEFAULT_HTTP_TIMEOUT = 12 if "LLAMA_SANITIZE" not in os.environ else 30 + + +class ServerResponse: + headers: dict + status_code: int + body: dict | Any + + +class ServerProcess: + # default options + debug: bool = False + server_port: int = 8080 + server_host: str = "127.0.0.1" + model_hf_repo: str = "ggml-org/models" + model_hf_file: str | None = "tinyllamas/stories260K.gguf" + model_alias: str = "tinyllama-2" + temperature: float = 0.8 + seed: int = 42 + + # custom options + model_alias: str | None = None + model_url: str | None = None + model_file: str | None = None + model_draft: str | None = None + n_threads: int | None = None + n_gpu_layer: int | None = None + n_batch: int | None = None + n_ubatch: int | None = None + n_ctx: int | None = None + n_ga: int | None = None + n_ga_w: int | None = None + n_predict: int | None = None + n_prompts: int | None = 0 + slot_save_path: str | None = None + id_slot: int | None = None + cache_prompt: bool | None = None + n_slots: int | None = None + server_continuous_batching: bool | None = False + server_embeddings: bool | None = False + server_reranking: bool | None = False + server_metrics: bool | None = False + server_slots: bool | None = False + pooling: str | None = None + draft: int | None = None + api_key: str | None = None + lora_files: List[str] | None = None + disable_ctx_shift: int | None = False + draft_min: int | None = None + draft_max: int | None = None + no_webui: bool | None = None + jinja: bool | None = None + chat_template: str | None = None + chat_template_file: str | None = None + + # session variables + process: subprocess.Popen | None = None + + def __init__(self): + if "N_GPU_LAYERS" in os.environ: + self.n_gpu_layer = int(os.environ["N_GPU_LAYERS"]) + if "DEBUG" in os.environ: + self.debug = True + if "PORT" in os.environ: + self.server_port = int(os.environ["PORT"]) + + def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None: + if "LLAMA_SERVER_BIN_PATH" in os.environ: + server_path = os.environ["LLAMA_SERVER_BIN_PATH"] + elif os.name == "nt": + server_path = "../../../build/bin/Release/llama-server.exe" + else: + server_path = "../../../build/bin/llama-server" + server_args = [ + "--host", + self.server_host, + "--port", + self.server_port, + "--temp", + self.temperature, + "--seed", + self.seed, + ] + if self.model_file: + server_args.extend(["--model", self.model_file]) + if self.model_url: + server_args.extend(["--model-url", self.model_url]) + if self.model_draft: + server_args.extend(["--model-draft", self.model_draft]) + if self.model_hf_repo: + server_args.extend(["--hf-repo", self.model_hf_repo]) + if self.model_hf_file: + server_args.extend(["--hf-file", self.model_hf_file]) + if self.n_batch: + server_args.extend(["--batch-size", self.n_batch]) + if self.n_ubatch: + server_args.extend(["--ubatch-size", self.n_ubatch]) + if self.n_threads: + server_args.extend(["--threads", self.n_threads]) + if self.n_gpu_layer: + server_args.extend(["--n-gpu-layers", self.n_gpu_layer]) + if self.draft is not None: + server_args.extend(["--draft", self.draft]) + if self.server_continuous_batching: + server_args.append("--cont-batching") + if self.server_embeddings: + server_args.append("--embedding") + if self.server_reranking: + server_args.append("--reranking") + if self.server_metrics: + server_args.append("--metrics") + if self.server_slots: + server_args.append("--slots") + if self.pooling: + server_args.extend(["--pooling", self.pooling]) + if self.model_alias: + server_args.extend(["--alias", self.model_alias]) + if self.n_ctx: + server_args.extend(["--ctx-size", self.n_ctx]) + if self.n_slots: + server_args.extend(["--parallel", self.n_slots]) + if self.n_predict: + server_args.extend(["--n-predict", self.n_predict]) + if self.slot_save_path: + server_args.extend(["--slot-save-path", self.slot_save_path]) + if self.n_ga: + server_args.extend(["--grp-attn-n", self.n_ga]) + if self.n_ga_w: + server_args.extend(["--grp-attn-w", self.n_ga_w]) + if self.debug: + server_args.append("--verbose") + if self.lora_files: + for lora_file in self.lora_files: + server_args.extend(["--lora", lora_file]) + if self.disable_ctx_shift: + server_args.extend(["--no-context-shift"]) + if self.api_key: + server_args.extend(["--api-key", self.api_key]) + if self.draft_max: + server_args.extend(["--draft-max", self.draft_max]) + if self.draft_min: + server_args.extend(["--draft-min", self.draft_min]) + if self.no_webui: + server_args.append("--no-webui") + if self.jinja: + server_args.append("--jinja") + if self.chat_template: + server_args.extend(["--chat-template", self.chat_template]) + if self.chat_template_file: + server_args.extend(["--chat-template-file", self.chat_template_file]) + + args = [str(arg) for arg in [server_path, *server_args]] + print(f"bench: starting server with: {' '.join(args)}") + + flags = 0 + if "nt" == os.name: + flags |= subprocess.DETACHED_PROCESS + flags |= subprocess.CREATE_NEW_PROCESS_GROUP + flags |= subprocess.CREATE_NO_WINDOW + + self.process = subprocess.Popen( + [str(arg) for arg in [server_path, *server_args]], + creationflags=flags, + stdout=sys.stdout, + stderr=sys.stdout, + env={**os.environ, "LLAMA_CACHE": "tmp"} if "LLAMA_CACHE" not in os.environ else None, + ) + server_instances.add(self) + + print(f"server pid={self.process.pid}, pytest pid={os.getpid()}") + + # wait for server to start + start_time = time.time() + while time.time() - start_time < timeout_seconds: + try: + response = self.make_request("GET", "/health", headers={ + "Authorization": f"Bearer {self.api_key}" if self.api_key else None + }) + if response.status_code == 200: + self.ready = True + return # server is ready + except Exception as e: + pass + print(f"Waiting for server to start...") + time.sleep(0.5) + raise TimeoutError(f"Server did not start within {timeout_seconds} seconds") + + def stop(self) -> None: + if self in server_instances: + server_instances.remove(self) + if self.process: + print(f"Stopping server with pid={self.process.pid}") + self.process.kill() + self.process = None + + def make_request( + self, + method: str, + path: str, + data: dict | Any | None = None, + headers: dict | None = None, + timeout: float | None = None, + ) -> ServerResponse: + url = f"http://{self.server_host}:{self.server_port}{path}" + parse_body = False + if method == "GET": + response = requests.get(url, headers=headers, timeout=timeout) + parse_body = True + elif method == "POST": + response = requests.post(url, headers=headers, json=data, timeout=timeout) + parse_body = True + elif method == "OPTIONS": + response = requests.options(url, headers=headers, timeout=timeout) + else: + raise ValueError(f"Unimplemented method: {method}") + result = ServerResponse() + result.headers = dict(response.headers) + result.status_code = response.status_code + result.body = response.json() if parse_body else None + print("Response from server", json.dumps(result.body, indent=2)) + return result + + def make_stream_request( + self, + method: str, + path: str, + data: dict | None = None, + headers: dict | None = None, + ) -> Iterator[dict]: + url = f"http://{self.server_host}:{self.server_port}{path}" + if method == "POST": + response = requests.post(url, headers=headers, json=data, stream=True) + else: + raise ValueError(f"Unimplemented method: {method}") + for line_bytes in response.iter_lines(): + line = line_bytes.decode("utf-8") + if '[DONE]' in line: + break + elif line.startswith('data: '): + data = json.loads(line[6:]) + print("Partial response from server", json.dumps(data, indent=2)) + yield data + + +server_instances: Set[ServerProcess] = set() + + +class ServerPreset: + @staticmethod + def tinyllama2() -> ServerProcess: + server = ServerProcess() + server.model_hf_repo = "ggml-org/models" + server.model_hf_file = "tinyllamas/stories260K.gguf" + server.model_alias = "tinyllama-2" + server.n_ctx = 256 + server.n_batch = 32 + server.n_slots = 2 + server.n_predict = 64 + server.seed = 42 + return server + + @staticmethod + def bert_bge_small() -> ServerProcess: + server = ServerProcess() + server.model_hf_repo = "ggml-org/models" + server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf" + server.model_alias = "bert-bge-small" + server.n_ctx = 512 + server.n_batch = 128 + server.n_ubatch = 128 + server.n_slots = 2 + server.seed = 42 + server.server_embeddings = True + return server + + @staticmethod + def tinyllama_infill() -> ServerProcess: + server = ServerProcess() + server.model_hf_repo = "ggml-org/models" + server.model_hf_file = "tinyllamas/stories260K-infill.gguf" + server.model_alias = "tinyllama-infill" + server.n_ctx = 2048 + server.n_batch = 1024 + server.n_slots = 1 + server.n_predict = 64 + server.temperature = 0.0 + server.seed = 42 + return server + + @staticmethod + def stories15m_moe() -> ServerProcess: + server = ServerProcess() + server.model_hf_repo = "ggml-org/stories15M_MOE" + server.model_hf_file = "stories15M_MOE-F16.gguf" + server.model_alias = "stories15m-moe" + server.n_ctx = 2048 + server.n_batch = 1024 + server.n_slots = 1 + server.n_predict = 64 + server.temperature = 0.0 + server.seed = 42 + return server + + @staticmethod + def jina_reranker_tiny() -> ServerProcess: + server = ServerProcess() + server.model_hf_repo = "ggml-org/models" + server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf" + server.model_alias = "jina-reranker" + server.n_ctx = 512 + server.n_batch = 512 + server.n_slots = 1 + server.seed = 42 + server.server_reranking = True + return server + + +def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]: + """ + Run multiple functions in parallel and return results in the same order as calls. Equivalent to Promise.all in JS. + + Example usage: + + results = parallel_function_calls([ + (func1, (arg1, arg2)), + (func2, (arg3, arg4)), + ]) + """ + results = [None] * len(function_list) + exceptions = [] + + def worker(index, func, args): + try: + result = func(*args) + results[index] = result + except Exception as e: + exceptions.append((index, str(e))) + + with ThreadPoolExecutor() as executor: + futures = [] + for i, (func, args) in enumerate(function_list): + future = executor.submit(worker, i, func, args) + futures.append(future) + + # Wait for all futures to complete + for future in as_completed(futures): + pass + + # Check if there were any exceptions + if exceptions: + print("Exceptions occurred:") + for index, error in exceptions: + print(f"Function at index {index}: {error}") + + return results + + +def match_regex(regex: str, text: str) -> bool: + return ( + re.compile( + regex, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL + ).search(text) + is not None + ) + + +def download_file(url: str, output_file_path: str | None = None) -> str: + """ + Download a file from a URL to a local path. If the file already exists, it will not be downloaded again. + + output_file_path is the local path to save the downloaded file. If not provided, the file will be saved in the root directory. + + Returns the local path of the downloaded file. + """ + file_name = url.split('/').pop() + output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path + if not os.path.exists(output_file): + print(f"Downloading {url} to {output_file}") + wget.download(url, out=output_file) + print(f"Done downloading to {output_file}") + else: + print(f"File already exists at {output_file}") + return output_file + + +def is_slow_test_allowed(): + return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON" diff --git a/examples/server/themes/buttons-top/index.html b/examples/server/themes/buttons-top/index.html index 8334bcde5..3fb88fcc8 100644 --- a/examples/server/themes/buttons-top/index.html +++ b/examples/server/themes/buttons-top/index.html @@ -222,11 +222,9 @@ temperature: 0.7, repeat_last_n: 256, // 0 = disable penalty, -1 = context size repeat_penalty: 1.18, // 1.0 = disabled - penalize_nl: false, top_k: 40, // <= 0 to use vocab size top_p: 0.95, // 1.0 = disabled min_p: 0.05, // 0 = disabled - tfs_z: 1.0, // 1.0 = disabled typical_p: 1.0, // 1.0 = disabled presence_penalty: 0.0, // 0.0 = disabled frequency_penalty: 0.0, // 0.0 = disabled @@ -780,7 +778,6 @@ ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })} ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })} ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })} - ${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })} ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })} ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })} ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })} @@ -788,7 +785,6 @@
More options
- ${FloatField({ label: "TFS-Z", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })} ${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })} ${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })} ${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })} diff --git a/examples/server/themes/wild/index.html b/examples/server/themes/wild/index.html index 8361c5774..73f36d4b2 100644 --- a/examples/server/themes/wild/index.html +++ b/examples/server/themes/wild/index.html @@ -225,11 +225,9 @@ temperature: 0.7, repeat_last_n: 256, // 0 = disable penalty, -1 = context size repeat_penalty: 1.18, // 1.0 = disabled - penalize_nl: false, top_k: 40, // <= 0 to use vocab size top_p: 0.95, // 1.0 = disabled min_p: 0.05, // 0 = disabled - tfs_z: 1.0, // 1.0 = disabled typical_p: 1.0, // 1.0 = disabled presence_penalty: 0.0, // 0.0 = disabled frequency_penalty: 0.0, // 0.0 = disabled @@ -783,7 +781,6 @@ ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })} ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })} ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })} - ${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })} ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })} ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })} ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })} @@ -791,7 +788,6 @@
More options
- ${FloatField({ label: "TFS-Z", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z })} ${FloatField({ label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p })} ${FloatField({ label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty })} ${FloatField({ label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty })} diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index edfce65b6..3d2c04666 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -1,7 +1,9 @@ #pragma once -#include "llama.h" #include "common.h" +#include "log.h" +#include "llama.h" +#include "common/base64.hpp" #ifndef NDEBUG // crash the server in debug mode, otherwise send an http 500 error @@ -14,52 +16,34 @@ // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT #include "json.hpp" +#include "minja.hpp" +#include "chat.hpp" +#include "chat-template.hpp" +#include +#include #include #include -#include -#include +#include -#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613" +#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo" using json = nlohmann::ordered_json; -// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 -enum error_type { - ERROR_TYPE_INVALID_REQUEST, - ERROR_TYPE_AUTHENTICATION, - ERROR_TYPE_SERVER, - ERROR_TYPE_NOT_FOUND, - ERROR_TYPE_PERMISSION, - ERROR_TYPE_UNAVAILABLE, // custom error - ERROR_TYPE_NOT_SUPPORTED, // custom error -}; +#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) +#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__) -extern bool server_verbose; -extern bool server_log_json; +#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#ifndef SERVER_VERBOSE -#define SERVER_VERBOSE 1 -#endif - -#if SERVER_VERBOSE != 1 -#define LOG_VERBOSE(MSG, ...) -#else -#define LOG_VERBOSE(MSG, ...) \ - do \ - { \ - if (server_verbose) \ - { \ - server_log("VERB", __func__, __LINE__, MSG, __VA_ARGS__); \ - } \ - } while (0) -#endif - -#define LOG_ERROR( MSG, ...) server_log("ERR", __func__, __LINE__, MSG, __VA_ARGS__) -#define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__) -#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__) - -static inline void server_log(const char * level, const char * function, int line, const char * message, const json & extra); +#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) +#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) template static T json_value(const json & body, const std::string & key, const T & default_value) { @@ -68,9 +52,7 @@ static T json_value(const json & body, const std::string & key, const T & defaul try { return body.at(key); } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) { - std::stringstream ss; - ss << "Wrong type supplied for parameter '" << key << "'. Expected '" << json(default_value).type_name() << "', using default value."; - LOG_WARNING(ss.str().c_str(), body); + LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value\n", key.c_str(), json(default_value).type_name()); return default_value; } } else { @@ -78,55 +60,300 @@ static T json_value(const json & body, const std::string & key, const T & defaul } } -static inline void server_log(const char * level, const char * function, int line, const char * message, const json & extra) { - std::stringstream ss_tid; - ss_tid << std::this_thread::get_id(); - json log = json{ - {"tid", ss_tid.str()}, - {"timestamp", time(nullptr)}, - }; +const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); - if (server_log_json) { - log.merge_patch({ - {"level", level}, - {"function", function}, - {"line", line}, - {"msg", message}, - }); +// +// tokenizer and input processing utils +// - if (!extra.empty()) { - log.merge_patch(extra); +static bool json_is_array_of_numbers(const json & data) { + if (data.is_array()) { + for (const auto & e : data) { + if (!e.is_number_integer()) { + return false; + } } - - printf("%s\n", log.dump(-1, ' ', false, json::error_handler_t::replace).c_str()); - } else { - char buf[1024]; - snprintf(buf, 1024, "%4s [%24s] %s", level, function, message); - - if (!extra.empty()) { - log.merge_patch(extra); - } - std::stringstream ss; - ss << buf << " |"; - for (const auto & el : log.items()) - { - const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace); - ss << " " << el.key() << "=" << value; - } - - const std::string str = ss.str(); - printf("%.*s\n", (int)str.size(), str.data()); + return true; } - fflush(stdout); + return false; +} + +// is array having BOTH numbers & strings? +static bool json_is_array_of_mixed_numbers_strings(const json & data) { + bool seen_string = false; + bool seen_number = false; + if (data.is_array()) { + for (const auto & e : data) { + seen_string |= e.is_string(); + seen_number |= e.is_number_integer(); + if (seen_number && seen_string) { + return true; + } + } + } + return false; +} + +// get value by path(key1 / key2) +static json json_get_nested_values(const std::vector & paths, const json & js) { + json result = json::object(); + + for (const std::string & path : paths) { + json current = js; + const auto keys = string_split(path, /*separator*/ '/'); + bool valid_path = true; + for (const std::string & k : keys) { + if (valid_path && current.is_object() && current.contains(k)) { + current = current[k]; + } else { + valid_path = false; + } + } + if (valid_path) { + result[path] = current; + } + } + return result; +} + +/** + * this handles 2 cases: + * - only string, example: "string" + * - mixed string and tokens, example: [12, 34, "string", 56, 78] + */ +static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { + // If `add_bos` is true, we only add BOS, when json_prompt is a string, + // or the first element of the json_prompt array is a string. + llama_tokens prompt_tokens; + + if (json_prompt.is_array()) { + bool first = true; + for (const auto & p : json_prompt) { + if (p.is_string()) { + auto s = p.template get(); + + llama_tokens p; + if (first) { + p = common_tokenize(vocab, s, add_special, parse_special); + first = false; + } else { + p = common_tokenize(vocab, s, false, parse_special); + } + + prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); + } else { + if (first) { + first = false; + } + + prompt_tokens.push_back(p.template get()); + } + } + } else { + auto s = json_prompt.template get(); + prompt_tokens = common_tokenize(vocab, s, add_special, parse_special); + } + + return prompt_tokens; +} + +/** + * break the input "prompt" object into multiple prompt if needed, then tokenize them + * this supports these cases: + * - "prompt": "string" + * - "prompt": [12, 34, 56] + * - "prompt": [12, 34, "string", 56, 78] + * and multiple prompts (multi-tasks): + * - "prompt": ["string1", "string2"] + * - "prompt": ["string1", [12, 34, 56]] + * - "prompt": [[12, 34, 56], [78, 90, 12]] + * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]] + */ +static std::vector tokenize_input_prompts(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { + std::vector result; + if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) { + // string or mixed + result.push_back(tokenize_mixed(vocab, json_prompt, add_special, parse_special)); + } else if (json_is_array_of_numbers(json_prompt)) { + // array of tokens + result.push_back(json_prompt.get()); + } else if (json_prompt.is_array()) { + // array of prompts + result.reserve(json_prompt.size()); + for (const auto & p : json_prompt) { + if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) { + result.push_back(tokenize_mixed(vocab, p, add_special, parse_special)); + } else if (json_is_array_of_numbers(p)) { + // array of tokens + result.push_back(p.get()); + } else { + throw std::runtime_error("element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens"); + } + } + } else { + throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts"); + } + if (result.empty()) { + throw std::runtime_error("\"prompt\" must not be empty"); + } + return result; +} + +// return the last index of character that can form a valid string +// if the last character is potentially cut in half, return the index before the cut +// if validate_utf8(text) == text.size(), then the whole text is valid utf8 +static size_t validate_utf8(const std::string& text) { + size_t len = text.size(); + if (len == 0) return 0; + + // Check the last few bytes to see if a multi-byte character is cut off + for (size_t i = 1; i <= 4 && i <= len; ++i) { + unsigned char c = text[len - i]; + // Check for start of a multi-byte sequence from the end + if ((c & 0xE0) == 0xC0) { + // 2-byte character start: 110xxxxx + // Needs at least 2 bytes + if (i < 2) return len - i; + } else if ((c & 0xF0) == 0xE0) { + // 3-byte character start: 1110xxxx + // Needs at least 3 bytes + if (i < 3) return len - i; + } else if ((c & 0xF8) == 0xF0) { + // 4-byte character start: 11110xxx + // Needs at least 4 bytes + if (i < 4) return len - i; + } + } + + // If no cut-off multi-byte character is found, return full length + return len; } // -// chat template utils +// template utils // +// format rerank task: [BOS]query[EOS][SEP]doc[EOS] +static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) { + llama_tokens result; + + result.reserve(doc.size() + query.size() + 4); + result.push_back(llama_vocab_bos(vocab)); + result.insert(result.end(), query.begin(), query.end()); + result.push_back(llama_vocab_eos(vocab)); + result.push_back(llama_vocab_sep(vocab)); + result.insert(result.end(), doc.begin(), doc.end()); + result.push_back(llama_vocab_eos(vocab)); + + return result; +} + +// format infill task +static llama_tokens format_infill( + const llama_vocab * vocab, + const json & input_prefix, + const json & input_suffix, + const json & input_extra, + const int n_batch, + const int n_predict, + const int n_ctx, + const bool spm_infill, + const llama_tokens & tokens_prompt + ) { + // TODO: optimize this block by reducing memory allocations and movement + + // use FIM repo-level pattern: + // ref: https://arxiv.org/pdf/2409.12186 + // + // [FIM_REP]myproject + // [FIM_SEP]filename0 + // extra chunk 0 + // [FIM_SEP]filename1 + // extra chunk 1 + // ... + // [FIM_SEP]filename + // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt + // + llama_tokens extra_tokens; + extra_tokens.reserve(n_ctx); + + auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false); + auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false); + + if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) { + // TODO: make project name an input + static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false); + + extra_tokens.push_back(llama_vocab_fim_rep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); + } + for (const auto & chunk : input_extra) { + // { "text": string, "filename": string } + const std::string text = json_value(chunk, "text", std::string()); + const std::string filename = json_value(chunk, "filename", std::string("tmp")); + + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false); + + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); + } else { + // chunk separator in binary form to avoid confusing the AI + static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; + static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false); + + extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end()); + } + + const auto chunk_tokens = common_tokenize(vocab, text, false, false); + extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); + } + + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + // TODO: current filename + static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false); + + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); + } + + // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?) + const int n_prefix_take = std::min(tokens_prefix.size(), 3*(n_batch/4)); + const int n_suffix_take = std::min(tokens_suffix.size(), std::max(0, (n_batch/4) - (2 + tokens_prompt.size()))); + + SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take)); + + // fill the rest of the context with extra chunks + const int n_extra_take = std::min(std::max(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size()); + + tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); + tokens_suffix.resize(n_suffix_take); + + tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab)); + tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); + tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab)); + + auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; + auto embd_end = spm_infill ? tokens_prefix : tokens_suffix; + + if (llama_vocab_get_add_bos(vocab)) { + embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); + } + + SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size()); + + // put the extra context before the FIM prefix + embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end()); + + embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); + embd_inp.push_back(llama_vocab_fim_mid(vocab)); + + return embd_inp; +} + // Format given chat. If tmpl is empty, we take the template from model metadata -inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector & messages) { - std::vector chat; +inline std::string format_chat(const common_chat_template & tmpl, const std::vector & messages) { + std::vector chat; for (size_t i = 0; i < messages.size(); ++i) { const auto & curr_msg = messages[i]; @@ -150,11 +377,12 @@ inline std::string format_chat(const struct llama_model * model, const std::stri throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); } - chat.push_back({role, content}); + chat.push_back({role, content, /* tool_calls= */ {}}); } - auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true); - LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}}); + const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false); + LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str()); + return formatted_chat; } @@ -243,68 +471,38 @@ static std::string random_string() { } static std::string gen_chatcmplid() { - std::stringstream chatcmplid; - chatcmplid << "chatcmpl-" << random_string(); - - return chatcmplid.str(); + return "chatcmpl-" + random_string(); } // // other common utils // -static size_t common_part(const std::vector & a, const std::vector & b) { - size_t i; - for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} - - return i; -} - -static size_t common_part(const std::string & a, const std::string & b) { - size_t i; - for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} - - return i; -} - static bool ends_with(const std::string & str, const std::string & suffix) { return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); } static size_t find_partial_stop_string(const std::string &stop, const std::string &text) { if (!text.empty() && !stop.empty()) { - const char text_last_char = text.back(); - for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) { - if (stop[char_index] == text_last_char) { - const std::string current_partial = stop.substr(0, char_index + 1); - if (ends_with(text, current_partial)) { - return text.size() - char_index - 1; - } + auto it = std::find(stop.rbegin(), stop.rend(), text.back()); + while (it != stop.rend()) { + size_t length = std::distance(it, stop.rend()); + if (text.length() >= length && 0 == text.compare(text.length() - length, length, stop)) { + return text.length() - length; } + it = std::find(std::next(it), stop.rend(), text.back()); } } return std::string::npos; } -static bool json_is_array_of_numbers(json data) { - if (data.is_array()) { - for (const auto & e : data) { - if (!e.is_number()) { - return false; - } - } - return true; - } - return false; -} - // TODO: reuse llama_detokenize template static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { std::string ret; for (; begin != end; ++begin) { - ret += llama_token_to_piece(ctx, *begin); + ret += common_token_to_piece(ctx, *begin); } return ret; @@ -312,7 +510,7 @@ static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { // format incomplete utf-8 multibyte character for output static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { - std::string out = token == -1 ? "" : llama_token_to_piece(ctx, token); + std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token); // if the size is 1 and first bit is 1, meaning it's a partial character // (size > 1 meaning it's already a known token) @@ -326,52 +524,13 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx, return out; } -struct completion_token_output { - llama_token tok; - std::string text_to_send; - - struct token_prob { - llama_token tok; - float prob; - }; - - std::vector probs; -}; - -// convert a vector of completion_token_output to json -static json probs_vector_to_json(const llama_context * ctx, const std::vector & probs) { - json out = json::array(); - - for (const auto & prob : probs) { - json probs_for_token = json::array(); - - for (const auto & p : prob.probs) { - const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok); - probs_for_token.push_back(json { - {"tok_str", tok_str}, - {"prob", p.prob}, - }); - } - - const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok); - out.push_back(json { - {"content", tok_str}, - {"probs", probs_for_token}, - }); - } - - return out; -} - -static bool server_sent_event(httplib::DataSink & sink, const char * event, json & data) { +static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) { const std::string str = std::string(event) + ": " + data.dump(-1, ' ', false, json::error_handler_t::replace) + - "\n\n"; + "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row). - LOG_VERBOSE("data stream", { - { "to_send", str } - }); + LOG_DBG("data stream, to_send: %s", str.c_str()); return sink.write(str.c_str(), str.size()); } @@ -380,16 +539,71 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, json // OAI utils // -static json oaicompat_completion_params_parse( - const struct llama_model * model, - const json & body, /* openai api json semantics */ - const std::string & chat_template) { +static json oaicompat_completion_params_parse(const json & body) { json llama_params; - llama_params["__oaicompat"] = true; + if (!body.contains("prompt")) { + throw std::runtime_error("\"prompt\" is required"); + } - // Apply chat template to the list of messages - llama_params["prompt"] = format_chat(model, chat_template, body.at("messages")); + // Handle "stop" field + if (body.contains("stop") && body.at("stop").is_string()) { + llama_params["stop"] = json::array({body.at("stop").get()}); + } else { + llama_params["stop"] = json_value(body, "stop", json::array()); + } + + // Handle "n" field + int n_choices = json_value(body, "n", 1); + if (n_choices != 1) { + throw std::runtime_error("Only one completion choice is allowed"); + } + + // Params supported by OAI but unsupported by llama.cpp + static const std::vector unsupported_params { "best_of", "echo", "suffix" }; + for (const auto & param : unsupported_params) { + if (body.contains(param)) { + throw std::runtime_error("Unsupported param: " + param); + } + } + + // Copy remaining properties to llama_params + for (const auto & item : body.items()) { + // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" + if (!llama_params.contains(item.key()) || item.key() == "n_predict") { + llama_params[item.key()] = item.value(); + } + } + + return llama_params; +} + +static json oaicompat_completion_params_parse( + const json & body, /* openai api json semantics */ + bool use_jinja, + const common_chat_templates & chat_templates) +{ + json llama_params; + const auto & tmpl = body.contains("tools") && chat_templates.template_tool_use + ? *chat_templates.template_tool_use + : *chat_templates.template_default; + + auto tools = json_value(body, "tools", json()); + auto stream = json_value(body, "stream", false); + + if (tools.is_array() && !tools.empty()) { + if (stream) { + throw std::runtime_error("Cannot use tools with stream"); + } + if (!use_jinja) { + throw std::runtime_error("tools param requires --jinja flag"); + } + } + if (!use_jinja) { + if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) { + throw std::runtime_error("Unsupported param: tool_choice"); + } + } // Handle "stop" field if (body.contains("stop") && body.at("stop").is_string()) { @@ -404,11 +618,52 @@ static json oaicompat_completion_params_parse( std::string response_type = json_value(response_format, "type", std::string()); if (response_type == "json_object") { llama_params["json_schema"] = json_value(response_format, "schema", json::object()); + } else if (response_type == "json_schema") { + json json_schema = json_value(response_format, "json_schema", json::object()); + llama_params["json_schema"] = json_value(json_schema, "schema", json::object()); } else if (!response_type.empty() && response_type != "text") { throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); } } + // Apply chat template to the list of messages + if (use_jinja) { + auto tool_choice = json_value(body, "tool_choice", std::string("auto")); + if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") { + throw std::runtime_error("Invalid tool_choice: " + tool_choice); + } + if (tool_choice != "none" && llama_params.contains("grammar")) { + throw std::runtime_error("Cannot use custom grammar constraints with tools."); + } + common_chat_inputs inputs; + inputs.messages = body.at("messages"); + inputs.tools = tools; + inputs.tool_choice = tool_choice; + inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); + inputs.stream = stream; + // TODO: support mixing schema w/ tools beyond generic format. + inputs.json_schema = json_value(llama_params, "json_schema", json::object()); + auto chat_params = common_chat_params_init(tmpl, inputs); + + llama_params["chat_format"] = static_cast(chat_params.format); + llama_params["prompt"] = chat_params.prompt; + llama_params["grammar"] = chat_params.grammar; + llama_params["grammar_lazy"] = chat_params.grammar_lazy; + auto grammar_triggers = json::array(); + for (const auto & trigger : chat_params.grammar_triggers) { + grammar_triggers.push_back({ + {"word", trigger.word}, + {"at_start", trigger.at_start}, + }); + } + llama_params["grammar_triggers"] = grammar_triggers; + for (const auto & stop : chat_params.additional_stops) { + llama_params["stop"].push_back(stop); + } + } else { + llama_params["prompt"] = format_chat(tmpl, body.at("messages")); + } + // Handle "n" field int n_choices = json_value(body, "n", 1); if (n_choices != 1) { @@ -417,22 +672,14 @@ static json oaicompat_completion_params_parse( // Handle "logprobs" field // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future - if (body.contains("logprobs")) { + if (json_value(body, "logprobs", false)) { llama_params["n_probs"] = json_value(body, "top_logprobs", 20); - } else if (body.contains("top_logprobs")) { + } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) { throw std::runtime_error("top_logprobs requires logprobs to be set to true"); } - // Params supported by OAI but unsupported by llama.cpp - static const std::vector unsupported_params { "tools", "tool_choice" }; - for (auto & param : unsupported_params) { - if (body.contains(param)) { - throw std::runtime_error("Unsupported param: " + param); - } - } - // Copy remaining properties to llama_params - // This allows user to use llama.cpp-specific params like "mirostat", "tfs_z",... via OAI endpoint. + // This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint. // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp for (const auto & item : body.items()) { // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" @@ -444,171 +691,41 @@ static json oaicompat_completion_params_parse( return llama_params; } -static json format_final_response_oaicompat(const json & request, json result, const std::string & completion_id, bool streaming = false) { - bool stopped_word = result.count("stopped_word") != 0; - bool stopped_eos = json_value(result, "stopped_eos", false); - int num_tokens_predicted = json_value(result, "tokens_predicted", 0); - int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); - std::string content = json_value(result, "content", std::string("")); - - std::string finish_reason = "length"; - if (stopped_word || stopped_eos) { - finish_reason = "stop"; - } - - json choices = - streaming ? json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}}}) - : json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"message", json{{"content", content}, - {"role", "assistant"}}}}}); - - std::time_t t = std::time(0); - - json res = json { - {"choices", choices}, - {"created", t}, - {"model", - json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, - {"object", streaming ? "chat.completion.chunk" : "chat.completion"}, - {"usage", json { - {"completion_tokens", num_tokens_predicted}, - {"prompt_tokens", num_prompt_tokens}, - {"total_tokens", num_tokens_predicted + num_prompt_tokens} - }}, - {"id", completion_id} - }; - - if (server_verbose) { - res["__verbose"] = result; - } - - if (result.contains("completion_probabilities")) { - res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array()); - } - - return res; -} - -// return value is vector as there is one case where we might need to generate two responses -static std::vector format_partial_response_oaicompat(json result, const std::string & completion_id) { - if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) { - return std::vector({result}); - } - - bool first = json_value(result, "oaicompat_token_ctr", 0) == 0; - std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL)); - - bool stopped_word = json_value(result, "stopped_word", false); - bool stopped_eos = json_value(result, "stopped_eos", false); - bool stopped_limit = json_value(result, "stopped_limit", false); - std::string content = json_value(result, "content", std::string("")); - - std::string finish_reason; - if (stopped_word || stopped_eos) { - finish_reason = "stop"; - } - if (stopped_limit) { - finish_reason = "length"; - } - - std::time_t t = std::time(0); - - json choices; - - if (!finish_reason.empty()) { - choices = json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}}}); - } else { - if (first) { - if (content.empty()) { - choices = json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{{"role", "assistant"}}}}}); - } else { - // We have to send this as two updates to conform to openai behavior - json initial_ret = json{{"choices", json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"role", "assistant"} - }}}})}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - - json second_ret = json{ - {"choices", json::array({json{{"finish_reason", nullptr}, - {"index", 0}, - {"delta", json{ - {"content", content}}} - }})}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"}}; - - return std::vector({initial_ret, second_ret}); - } - } else { - // Some idiosyncrasy in task processing logic makes several trailing calls - // with empty content, we ignore these at the calee site. - if (content.empty()) { - return std::vector({json::object()}); - } - - choices = json::array({json{ - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", - json{ - {"content", content}, - }}, - }}); - } - } - - json ret = json { - {"choices", choices}, - {"created", t}, - {"id", completion_id}, - {"model", modelname}, - {"object", "chat.completion.chunk"} - }; - if (!finish_reason.empty()) { - int num_tokens_predicted = json_value(result, "tokens_predicted", 0); - int num_prompt_tokens = json_value(result, "tokens_evaluated", 0); - ret.push_back({"usage", json { - {"completion_tokens", num_tokens_predicted}, - {"prompt_tokens", num_prompt_tokens}, - {"total_tokens", num_tokens_predicted + num_prompt_tokens} - }}); - } - - return std::vector({ret}); -} - -static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) { +static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) { json data = json::array(); + int32_t n_tokens = 0; int i = 0; - for (auto & elem : embeddings) { - data.push_back(json{ - {"embedding", json_value(elem, "embedding", json::array())}, - {"index", i++}, - {"object", "embedding"} - }); + for (const auto & elem : embeddings) { + json embedding_obj; + + if (use_base64) { + const auto& vec = json_value(elem, "embedding", json::array()).get>(); + const char* data_ptr = reinterpret_cast(vec.data()); + size_t data_size = vec.size() * sizeof(float); + embedding_obj = { + {"embedding", base64::encode(data_ptr, data_size)}, + {"index", i++}, + {"object", "embedding"}, + {"encoding_format", "base64"} + }; + } else { + embedding_obj = { + {"embedding", json_value(elem, "embedding", json::array())}, + {"index", i++}, + {"object", "embedding"} + }; + } + data.push_back(embedding_obj); + + n_tokens += json_value(elem, "tokens_evaluated", 0); } json res = json { {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, {"object", "list"}, {"usage", json { - {"prompt_tokens", 0}, - {"total_tokens", 0} + {"prompt_tokens", n_tokens}, + {"total_tokens", n_tokens} }}, {"data", data} }; @@ -616,7 +733,66 @@ static json format_embeddings_response_oaicompat(const json & request, const jso return res; } -static json format_tokenizer_response(const std::vector & tokens) { +static json format_response_rerank(const json & request, const json & ranks) { + json data = json::array(); + int32_t n_tokens = 0; + int i = 0; + for (const auto & rank : ranks) { + data.push_back(json{ + {"index", i++}, + {"relevance_score", json_value(rank, "score", 0.0)}, + }); + + n_tokens += json_value(rank, "tokens_evaluated", 0); + } + + json res = json { + {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, + {"object", "list"}, + {"usage", json { + {"prompt_tokens", n_tokens}, + {"total_tokens", n_tokens} + }}, + {"results", data} + }; + + return res; +} + +static bool is_valid_utf8(const std::string & str) { + const unsigned char* bytes = reinterpret_cast(str.data()); + const unsigned char* end = bytes + str.length(); + + while (bytes < end) { + if (*bytes <= 0x7F) { + // 1-byte sequence (0xxxxxxx) + bytes++; + } else if ((*bytes & 0xE0) == 0xC0) { + // 2-byte sequence (110xxxxx 10xxxxxx) + if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80) + return false; + bytes += 2; + } else if ((*bytes & 0xF0) == 0xE0) { + // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx) + if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80) + return false; + bytes += 3; + } else if ((*bytes & 0xF8) == 0xF0) { + // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) + if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 || + (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80) + return false; + bytes += 4; + } else { + // Invalid UTF-8 lead byte + return false; + } + } + + return true; +} + +static json format_tokenizer_response(const json & tokens) { return json { {"tokens", tokens} }; @@ -628,42 +804,92 @@ static json format_detokenized_response(const std::string & content) { }; } -static json format_error_response(const std::string & message, const enum error_type type) { - std::string type_str; - int code = 500; - switch (type) { - case ERROR_TYPE_INVALID_REQUEST: - type_str = "invalid_request_error"; - code = 400; - break; - case ERROR_TYPE_AUTHENTICATION: - type_str = "authentication_error"; - code = 401; - break; - case ERROR_TYPE_NOT_FOUND: - type_str = "not_found_error"; - code = 404; - break; - case ERROR_TYPE_SERVER: - type_str = "server_error"; - code = 500; - break; - case ERROR_TYPE_PERMISSION: - type_str = "permission_error"; - code = 403; - break; - case ERROR_TYPE_NOT_SUPPORTED: - type_str = "not_supported_error"; - code = 501; - break; - case ERROR_TYPE_UNAVAILABLE: - type_str = "unavailable_error"; - code = 503; - break; +static json format_logit_bias(const std::vector & logit_bias) { + json data = json::array(); + for (const auto & lb : logit_bias) { + data.push_back(json{ + {"bias", lb.bias}, + {"token", lb.token}, + }); } - return json { - {"code", code}, - {"message", message}, - {"type", type_str}, - }; + return data; +} + +static std::string safe_json_to_str(const json & data) { + return data.dump(-1, ' ', false, json::error_handler_t::replace); +} + +static std::vector get_token_probabilities(llama_context * ctx, int idx) { + std::vector cur; + const auto * logits = llama_get_logits_ith(ctx, idx); + + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_vocab = llama_vocab_n_tokens(vocab); + + cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } + + // sort tokens by logits + std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + + // apply softmax + float max_l = cur[0].logit; + float cum_sum = 0.0f; + for (size_t i = 0; i < cur.size(); ++i) { + float p = expf(cur[i].logit - max_l); + cur[i].p = p; + cum_sum += p; + } + for (size_t i = 0; i < cur.size(); ++i) { + cur[i].p /= cum_sum; + } + + return cur; +} + +static bool are_lora_equal( + const std::vector & l1, + const std::vector & l2) { + if (l1.size() != l2.size()) { + return false; + } + for (size_t i = 0; i < l1.size(); ++i) { + // we don't check lora.path to reduce the time complexity + if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) { + return false; + } + } + return true; +} + +// parse lora config from JSON request, returned a copy of lora_base with updated scale +static std::vector parse_lora_request( + const std::vector & lora_base, + const json & data) { + std::vector lora(lora_base); + int max_idx = lora.size(); + + // clear existing value + for (auto & entry : lora) { + entry.scale = 0.0f; + } + + // set value + for (const auto & entry : data) { + int id = json_value(entry, "id", -1); + float scale = json_value(entry, "scale", 0.0f); + if (0 <= id && id < max_idx) { + lora[id].scale = scale; + } else { + throw std::runtime_error("invalid adapter id"); + } + } + + return lora; } diff --git a/examples/server/webui/index.html b/examples/server/webui/index.html new file mode 100644 index 000000000..d3893ea4e --- /dev/null +++ b/examples/server/webui/index.html @@ -0,0 +1,343 @@ + + + + + + + 🦙 llama.cpp - chat + + + +
+
+ + + +
+ +
+
+

Conversations

+ + + +
+ + +
+ + New conversation +
+
+ {{ conv.messages[0].content }} +
+
+ Conversations are saved to browser's localStorage +
+
+
+ + +
+ +
+ + + +
llama.cpp
+ + +
+ +
+ +
+ + +
+ +
+
+
+ + +
+
+ + {{ messages.length === 0 ? 'Send a message to start' : '' }} +
+
+ +
+ + +
+ +
+
+ + +
+ + + +
+
+ +
+ + + + + + + +
+ + + + + + + + + + + + + diff --git a/examples/server/webui/package-lock.json b/examples/server/webui/package-lock.json new file mode 100644 index 000000000..bbebccbf2 --- /dev/null +++ b/examples/server/webui/package-lock.json @@ -0,0 +1,3309 @@ +{ + "name": "webui", + "version": "0.0.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "webui", + "version": "0.0.0", + "dependencies": { + "@sec-ant/readable-stream": "^0.6.0", + "@vscode/markdown-it-katex": "^1.1.1", + "autoprefixer": "^10.4.20", + "daisyui": "^4.12.14", + "highlight.js": "^11.10.0", + "katex": "^0.16.15", + "markdown-it": "^14.1.0", + "postcss": "^8.4.49", + "tailwindcss": "^3.4.15", + "textlinestream": "^1.1.1", + "vite-plugin-singlefile": "^2.0.3", + "vue": "^3.5.13" + }, + "devDependencies": { + "sass-embedded": "^1.83.0", + "vite": "^5.4.10" + } + }, + "node_modules/@alloc/quick-lru": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/@alloc/quick-lru/-/quick-lru-5.2.0.tgz", + "integrity": "sha512-UrcABB+4bUrFABwbluTIBErXwvbsU/V7TZWfmbgJfbkwiBuziS9gxdODUyuiecfdGQ85jglMW6juS3+z5TsKLw==", + "license": "MIT", + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/@bufbuild/protobuf": { + "version": "2.2.3", + "resolved": "https://registry.npmjs.org/@bufbuild/protobuf/-/protobuf-2.2.3.tgz", + "integrity": "sha512-tFQoXHJdkEOSwj5tRIZSPNUuXK3RaR7T1nUrPgbYX1pUbvqqaaZAsfo+NXBPsz5rZMSKVFrgK1WL8Q/MSLvprg==", + "devOptional": true, + "license": "(Apache-2.0 AND BSD-3-Clause)" + }, + "node_modules/@esbuild/aix-ppc64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.21.5.tgz", + "integrity": "sha512-1SDgH6ZSPTlggy1yI6+Dbkiz8xzpHJEVAlF/AM1tHPLsf5STom9rwtjE4hKAF20FfXXNTFqEYXyJNWh1GiZedQ==", + "cpu": [ + "ppc64" + ], + "license": "MIT", + "optional": true, + "os": [ + "aix" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/android-arm": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm/-/android-arm-0.21.5.tgz", + "integrity": "sha512-vCPvzSjpPHEi1siZdlvAlsPxXl7WbOVUBBAowWug4rJHb68Ox8KualB+1ocNvT5fjv6wpkX6o/iEpbDrf68zcg==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/android-arm64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/android-arm64/-/android-arm64-0.21.5.tgz", + "integrity": "sha512-c0uX9VAUBQ7dTDCjq+wdyGLowMdtR/GoC2U5IYk/7D1H1JYC0qseD7+11iMP2mRLN9RcCMRcjC4YMclCzGwS/A==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/android-x64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/android-x64/-/android-x64-0.21.5.tgz", + "integrity": "sha512-D7aPRUUNHRBwHxzxRvp856rjUHRFW1SdQATKXH2hqA0kAZb1hKmi02OpYRacl0TxIGz/ZmXWlbZgjwWYaCakTA==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/darwin-x64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-x64/-/darwin-x64-0.21.5.tgz", + "integrity": "sha512-se/JjF8NlmKVG4kNIuyWMV/22ZaerB+qaSi5MdrXtd6R08kvs2qCN4C09miupktDitvh8jRFflwGFBQcxZRjbw==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/freebsd-arm64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-arm64/-/freebsd-arm64-0.21.5.tgz", + "integrity": "sha512-5JcRxxRDUJLX8JXp/wcBCy3pENnCgBR9bN6JsY4OmhfUtIHe3ZW0mawA7+RDAcMLrMIZaf03NlQiX9DGyB8h4g==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/freebsd-x64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/freebsd-x64/-/freebsd-x64-0.21.5.tgz", + "integrity": "sha512-J95kNBj1zkbMXtHVH29bBriQygMXqoVQOQYA+ISs0/2l3T9/kj42ow2mpqerRBxDJnmkUDCaQT/dfNXWX/ZZCQ==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-arm": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm/-/linux-arm-0.21.5.tgz", + "integrity": "sha512-bPb5AHZtbeNGjCKVZ9UGqGwo8EUu4cLq68E95A53KlxAPRmUyYv2D6F0uUI65XisGOL1hBP5mTronbgo+0bFcA==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-arm64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-arm64/-/linux-arm64-0.21.5.tgz", + "integrity": "sha512-ibKvmyYzKsBeX8d8I7MH/TMfWDXBF3db4qM6sy+7re0YXya+K1cem3on9XgdT2EQGMu4hQyZhan7TeQ8XkGp4Q==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-ia32": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ia32/-/linux-ia32-0.21.5.tgz", + "integrity": "sha512-YvjXDqLRqPDl2dvRODYmmhz4rPeVKYvppfGYKSNGdyZkA01046pLWyRKKI3ax8fbJoK5QbxblURkwK/MWY18Tg==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-loong64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-loong64/-/linux-loong64-0.21.5.tgz", + "integrity": "sha512-uHf1BmMG8qEvzdrzAqg2SIG/02+4/DHB6a9Kbya0XDvwDEKCoC8ZRWI5JJvNdUjtciBGFQ5PuBlpEOXQj+JQSg==", + "cpu": [ + "loong64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-mips64el": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-mips64el/-/linux-mips64el-0.21.5.tgz", + "integrity": "sha512-IajOmO+KJK23bj52dFSNCMsz1QP1DqM6cwLUv3W1QwyxkyIWecfafnI555fvSGqEKwjMXVLokcV5ygHW5b3Jbg==", + "cpu": [ + "mips64el" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-ppc64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-ppc64/-/linux-ppc64-0.21.5.tgz", + "integrity": "sha512-1hHV/Z4OEfMwpLO8rp7CvlhBDnjsC3CttJXIhBi+5Aj5r+MBvy4egg7wCbe//hSsT+RvDAG7s81tAvpL2XAE4w==", + "cpu": [ + "ppc64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-riscv64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-riscv64/-/linux-riscv64-0.21.5.tgz", + "integrity": "sha512-2HdXDMd9GMgTGrPWnJzP2ALSokE/0O5HhTUvWIbD3YdjME8JwvSCnNGBnTThKGEB91OZhzrJ4qIIxk/SBmyDDA==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-s390x": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-s390x/-/linux-s390x-0.21.5.tgz", + "integrity": "sha512-zus5sxzqBJD3eXxwvjN1yQkRepANgxE9lgOW2qLnmr8ikMTphkjgXu1HR01K4FJg8h1kEEDAqDcZQtbrRnB41A==", + "cpu": [ + "s390x" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/linux-x64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.21.5.tgz", + "integrity": "sha512-1rYdTpyv03iycF1+BhzrzQJCdOuAOtaqHTWJZCWvijKD2N5Xu0TtVC8/+1faWqcP9iBCWOmjmhoH94dH82BxPQ==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/netbsd-x64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/netbsd-x64/-/netbsd-x64-0.21.5.tgz", + "integrity": "sha512-Woi2MXzXjMULccIwMnLciyZH4nCIMpWQAs049KEeMvOcNADVxo0UBIQPfSmxB3CWKedngg7sWZdLvLczpe0tLg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "netbsd" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/openbsd-x64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/openbsd-x64/-/openbsd-x64-0.21.5.tgz", + "integrity": "sha512-HLNNw99xsvx12lFBUwoT8EVCsSvRNDVxNpjZ7bPn947b8gJPzeHWyNVhFsaerc0n3TsbOINvRP2byTZ5LKezow==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "openbsd" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/sunos-x64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/sunos-x64/-/sunos-x64-0.21.5.tgz", + "integrity": "sha512-6+gjmFpfy0BHU5Tpptkuh8+uw3mnrvgs+dSPQXQOv3ekbordwnzTVEb4qnIvQcYXq6gzkyTnoZ9dZG+D4garKg==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "sunos" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/win32-arm64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/win32-arm64/-/win32-arm64-0.21.5.tgz", + "integrity": "sha512-Z0gOTd75VvXqyq7nsl93zwahcTROgqvuAcYDUr+vOv8uHhNSKROyU961kgtCD1e95IqPKSQKH7tBTslnS3tA8A==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/win32-ia32": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/win32-ia32/-/win32-ia32-0.21.5.tgz", + "integrity": "sha512-SWXFF1CL2RVNMaVs+BBClwtfZSvDgtL//G/smwAc5oVK/UPu2Gu9tIaRgFmYFFKrmg3SyAjSrElf0TiJ1v8fYA==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@esbuild/win32-x64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/win32-x64/-/win32-x64-0.21.5.tgz", + "integrity": "sha512-tQd/1efJuzPC6rCFwEvLtci/xNFcTZknmXs98FYDfGE4wP9ClFV98nyKrzJKVPMhdDnjzLhdUyMX4PsQAPjwIw==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/@rollup/rollup-android-arm-eabi": { + "version": "4.28.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.28.0.tgz", + "integrity": "sha512-wLJuPLT6grGZsy34g4N1yRfYeouklTgPhH1gWXCYspenKYD0s3cR99ZevOGw5BexMNywkbV3UkjADisozBmpPQ==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ] + }, + "node_modules/@rollup/rollup-android-arm64": { + "version": "4.28.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.28.0.tgz", + "integrity": "sha512-eiNkznlo0dLmVG/6wf+Ifi/v78G4d4QxRhuUl+s8EWZpDewgk7PX3ZyECUXU0Zq/Ca+8nU8cQpNC4Xgn2gFNDA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ] + }, + "node_modules/@rollup/rollup-darwin-x64": { + "version": "4.28.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.28.0.tgz", + "integrity": "sha512-8hxgfReVs7k9Js1uAIhS6zq3I+wKQETInnWQtgzt8JfGx51R1N6DRVy3F4o0lQwumbErRz52YqwjfvuwRxGv1w==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/@rollup/rollup-freebsd-arm64": { + "version": "4.28.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-arm64/-/rollup-freebsd-arm64-4.28.0.tgz", + "integrity": "sha512-lA1zZB3bFx5oxu9fYud4+g1mt+lYXCoch0M0V/xhqLoGatbzVse0wlSQ1UYOWKpuSu3gyN4qEc0Dxf/DII1bhQ==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ] + }, + "node_modules/@rollup/rollup-freebsd-x64": { + "version": "4.28.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-freebsd-x64/-/rollup-freebsd-x64-4.28.0.tgz", + "integrity": "sha512-aI2plavbUDjCQB/sRbeUZWX9qp12GfYkYSJOrdYTL/C5D53bsE2/nBPuoiJKoWp5SN78v2Vr8ZPnB+/VbQ2pFA==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "freebsd" + ] + }, + "node_modules/@rollup/rollup-linux-arm-gnueabihf": { + "version": "4.28.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.28.0.tgz", + "integrity": "sha512-WXveUPKtfqtaNvpf0iOb0M6xC64GzUX/OowbqfiCSXTdi/jLlOmH0Ba94/OkiY2yTGTwteo4/dsHRfh5bDCZ+w==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-arm-musleabihf": { + "version": "4.28.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-musleabihf/-/rollup-linux-arm-musleabihf-4.28.0.tgz", + "integrity": "sha512-yLc3O2NtOQR67lI79zsSc7lk31xjwcaocvdD1twL64PK1yNaIqCeWI9L5B4MFPAVGEVjH5k1oWSGuYX1Wutxpg==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-arm64-gnu": { + "version": "4.28.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.28.0.tgz", + "integrity": "sha512-+P9G9hjEpHucHRXqesY+3X9hD2wh0iNnJXX/QhS/J5vTdG6VhNYMxJ2rJkQOxRUd17u5mbMLHM7yWGZdAASfcg==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-arm64-musl": { + "version": "4.28.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.28.0.tgz", + "integrity": "sha512-1xsm2rCKSTpKzi5/ypT5wfc+4bOGa/9yI/eaOLW0oMs7qpC542APWhl4A37AENGZ6St6GBMWhCCMM6tXgTIplw==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-powerpc64le-gnu": { + "version": "4.28.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-powerpc64le-gnu/-/rollup-linux-powerpc64le-gnu-4.28.0.tgz", + "integrity": "sha512-zgWxMq8neVQeXL+ouSf6S7DoNeo6EPgi1eeqHXVKQxqPy1B2NvTbaOUWPn/7CfMKL7xvhV0/+fq/Z/J69g1WAQ==", + "cpu": [ + "ppc64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-riscv64-gnu": { + "version": "4.28.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.28.0.tgz", + "integrity": "sha512-VEdVYacLniRxbRJLNtzwGt5vwS0ycYshofI7cWAfj7Vg5asqj+pt+Q6x4n+AONSZW/kVm+5nklde0qs2EUwU2g==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-s390x-gnu": { + "version": "4.28.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-s390x-gnu/-/rollup-linux-s390x-gnu-4.28.0.tgz", + "integrity": "sha512-LQlP5t2hcDJh8HV8RELD9/xlYtEzJkm/aWGsauvdO2ulfl3QYRjqrKW+mGAIWP5kdNCBheqqqYIGElSRCaXfpw==", + "cpu": [ + "s390x" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-x64-gnu": { + "version": "4.28.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.28.0.tgz", + "integrity": "sha512-Nl4KIzteVEKE9BdAvYoTkW19pa7LR/RBrT6F1dJCV/3pbjwDcaOq+edkP0LXuJ9kflW/xOK414X78r+K84+msw==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-linux-x64-musl": { + "version": "4.28.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.28.0.tgz", + "integrity": "sha512-eKpJr4vBDOi4goT75MvW+0dXcNUqisK4jvibY9vDdlgLx+yekxSm55StsHbxUsRxSTt3JEQvlr3cGDkzcSP8bw==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/@rollup/rollup-win32-arm64-msvc": { + "version": "4.28.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.28.0.tgz", + "integrity": "sha512-Vi+WR62xWGsE/Oj+mD0FNAPY2MEox3cfyG0zLpotZdehPFXwz6lypkGs5y38Jd/NVSbOD02aVad6q6QYF7i8Bg==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@rollup/rollup-win32-ia32-msvc": { + "version": "4.28.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.28.0.tgz", + "integrity": "sha512-kN/Vpip8emMLn/eOza+4JwqDZBL6MPNpkdaEsgUtW1NYN3DZvZqSQrbKzJcTL6hd8YNmFTn7XGWMwccOcJBL0A==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@rollup/rollup-win32-x64-msvc": { + "version": "4.28.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.28.0.tgz", + "integrity": "sha512-Bvno2/aZT6usSa7lRDL2+hMjVAGjuqaymF1ApZm31JXzniR/hvr14jpU+/z4X6Gt5BPlzosscyJZGUvguXIqeQ==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/@sec-ant/readable-stream": { + "version": "0.6.0", + "resolved": "https://registry.npmjs.org/@sec-ant/readable-stream/-/readable-stream-0.6.0.tgz", + "integrity": "sha512-uiBh8DrB5FN35gP6/o8JEhEQ7/ci1jUsOZO/VMUjyvTpjtV54VstOXVj1TvTj/wsT23pfX6butxxh3qufsW3+g==", + "license": "MIT" + }, + "node_modules/@vscode/markdown-it-katex": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@vscode/markdown-it-katex/-/markdown-it-katex-1.1.1.tgz", + "integrity": "sha512-3KTlbsRBPJQLE2YmLL7K6nunTlU+W9T5+FjfNdWuIUKgxSS6HWLQHaO3L4MkJi7z7MpIPpY+g4N+cWNBPE/MSA==", + "license": "MIT", + "dependencies": { + "katex": "^0.16.4" + } + }, + "node_modules/@vue/compiler-dom": { + "version": "3.5.13", + "resolved": "https://registry.npmjs.org/@vue/compiler-dom/-/compiler-dom-3.5.13.tgz", + "integrity": "sha512-ZOJ46sMOKUjO3e94wPdCzQ6P1Lx/vhp2RSvfaab88Ajexs0AHeV0uasYhi99WPaogmBlRHNRuly8xV75cNTMDA==", + "license": "MIT", + "dependencies": { + "@vue/compiler-core": "3.5.13", + "@vue/shared": "3.5.13" + } + }, + "node_modules/@vue/compiler-dom/node_modules/@babel/helper-string-parser": { + "version": "7.25.9", + "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.25.9.tgz", + "integrity": "sha512-4A/SCr/2KLd5jrtOMFzaKjVtAei3+2r/NChoBNoZ3EyP/+GlhoaEGoWOZUmFmoITP7zOJyHIMm+DYRd8o3PvHA==", + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@vue/compiler-dom/node_modules/@babel/helper-validator-identifier": { + "version": "7.25.9", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.25.9.tgz", + "integrity": "sha512-Ed61U6XJc3CVRfkERJWDz4dJwKe7iLmmJsbOGu9wSloNSFttHV0I8g6UAgb7qnK5ly5bGLPd4oXZlxCdANBOWQ==", + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@vue/compiler-dom/node_modules/@babel/parser": { + "version": "7.26.2", + "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.26.2.tgz", + "integrity": "sha512-DWMCZH9WA4Maitz2q21SRKHo9QXZxkDsbNZoVD62gusNtNBBqDg9i7uOhASfTfIGNzW+O+r7+jAlM8dwphcJKQ==", + "license": "MIT", + "dependencies": { + "@babel/types": "^7.26.0" + }, + "bin": { + "parser": "bin/babel-parser.js" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@vue/compiler-dom/node_modules/@babel/types": { + "version": "7.26.0", + "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.26.0.tgz", + "integrity": "sha512-Z/yiTPj+lDVnF7lWeKCIJzaIkI0vYO87dMpZ4bg4TDrFe4XXLFWL1TbXU27gBP3QccxV9mZICCrnjnYlJjXHOA==", + "license": "MIT", + "dependencies": { + "@babel/helper-string-parser": "^7.25.9", + "@babel/helper-validator-identifier": "^7.25.9" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@vue/compiler-dom/node_modules/@vue/compiler-core": { + "version": "3.5.13", + "resolved": "https://registry.npmjs.org/@vue/compiler-core/-/compiler-core-3.5.13.tgz", + "integrity": "sha512-oOdAkwqUfW1WqpwSYJce06wvt6HljgY3fGeM9NcVA1HaYOij3mZG9Rkysn0OHuyUAGMbEbARIpsG+LPVlBJ5/Q==", + "license": "MIT", + "dependencies": { + "@babel/parser": "^7.25.3", + "@vue/shared": "3.5.13", + "entities": "^4.5.0", + "estree-walker": "^2.0.2", + "source-map-js": "^1.2.0" + } + }, + "node_modules/@vue/compiler-dom/node_modules/estree-walker": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-2.0.2.tgz", + "integrity": "sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==", + "license": "MIT" + }, + "node_modules/@vue/compiler-dom/node_modules/source-map-js": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz", + "integrity": "sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==", + "license": "BSD-3-Clause", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/@vue/compiler-sfc": { + "version": "3.5.13", + "resolved": "https://registry.npmjs.org/@vue/compiler-sfc/-/compiler-sfc-3.5.13.tgz", + "integrity": "sha512-6VdaljMpD82w6c2749Zhf5T9u5uLBWKnVue6XWxprDobftnletJ8+oel7sexFfM3qIxNmVE7LSFGTpv6obNyaQ==", + "license": "MIT", + "dependencies": { + "@babel/parser": "^7.25.3", + "@vue/compiler-core": "3.5.13", + "@vue/compiler-dom": "3.5.13", + "@vue/compiler-ssr": "3.5.13", + "@vue/shared": "3.5.13", + "estree-walker": "^2.0.2", + "magic-string": "^0.30.11", + "postcss": "^8.4.48", + "source-map-js": "^1.2.0" + } + }, + "node_modules/@vue/compiler-sfc/node_modules/@babel/helper-string-parser": { + "version": "7.25.9", + "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.25.9.tgz", + "integrity": "sha512-4A/SCr/2KLd5jrtOMFzaKjVtAei3+2r/NChoBNoZ3EyP/+GlhoaEGoWOZUmFmoITP7zOJyHIMm+DYRd8o3PvHA==", + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@vue/compiler-sfc/node_modules/@babel/helper-validator-identifier": { + "version": "7.25.9", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.25.9.tgz", + "integrity": "sha512-Ed61U6XJc3CVRfkERJWDz4dJwKe7iLmmJsbOGu9wSloNSFttHV0I8g6UAgb7qnK5ly5bGLPd4oXZlxCdANBOWQ==", + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@vue/compiler-sfc/node_modules/@babel/parser": { + "version": "7.26.2", + "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.26.2.tgz", + "integrity": "sha512-DWMCZH9WA4Maitz2q21SRKHo9QXZxkDsbNZoVD62gusNtNBBqDg9i7uOhASfTfIGNzW+O+r7+jAlM8dwphcJKQ==", + "license": "MIT", + "dependencies": { + "@babel/types": "^7.26.0" + }, + "bin": { + "parser": "bin/babel-parser.js" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/@vue/compiler-sfc/node_modules/@babel/types": { + "version": "7.26.0", + "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.26.0.tgz", + "integrity": "sha512-Z/yiTPj+lDVnF7lWeKCIJzaIkI0vYO87dMpZ4bg4TDrFe4XXLFWL1TbXU27gBP3QccxV9mZICCrnjnYlJjXHOA==", + "license": "MIT", + "dependencies": { + "@babel/helper-string-parser": "^7.25.9", + "@babel/helper-validator-identifier": "^7.25.9" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@vue/compiler-sfc/node_modules/@jridgewell/sourcemap-codec": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz", + "integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==", + "license": "MIT" + }, + "node_modules/@vue/compiler-sfc/node_modules/@vue/compiler-core": { + "version": "3.5.13", + "resolved": "https://registry.npmjs.org/@vue/compiler-core/-/compiler-core-3.5.13.tgz", + "integrity": "sha512-oOdAkwqUfW1WqpwSYJce06wvt6HljgY3fGeM9NcVA1HaYOij3mZG9Rkysn0OHuyUAGMbEbARIpsG+LPVlBJ5/Q==", + "license": "MIT", + "dependencies": { + "@babel/parser": "^7.25.3", + "@vue/shared": "3.5.13", + "entities": "^4.5.0", + "estree-walker": "^2.0.2", + "source-map-js": "^1.2.0" + } + }, + "node_modules/@vue/compiler-sfc/node_modules/@vue/compiler-ssr": { + "version": "3.5.13", + "resolved": "https://registry.npmjs.org/@vue/compiler-ssr/-/compiler-ssr-3.5.13.tgz", + "integrity": "sha512-wMH6vrYHxQl/IybKJagqbquvxpWCuVYpoUJfCqFZwa/JY1GdATAQ+TgVtgrwwMZ0D07QhA99rs/EAAWfvG6KpA==", + "license": "MIT", + "dependencies": { + "@vue/compiler-dom": "3.5.13", + "@vue/shared": "3.5.13" + } + }, + "node_modules/@vue/compiler-sfc/node_modules/estree-walker": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-2.0.2.tgz", + "integrity": "sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==", + "license": "MIT" + }, + "node_modules/@vue/compiler-sfc/node_modules/magic-string": { + "version": "0.30.14", + "resolved": "https://registry.npmjs.org/magic-string/-/magic-string-0.30.14.tgz", + "integrity": "sha512-5c99P1WKTed11ZC0HMJOj6CDIue6F8ySu+bJL+85q1zBEIY8IklrJ1eiKC2NDRh3Ct3FcvmJPyQHb9erXMTJNw==", + "license": "MIT", + "dependencies": { + "@jridgewell/sourcemap-codec": "^1.5.0" + } + }, + "node_modules/@vue/compiler-sfc/node_modules/source-map-js": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz", + "integrity": "sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==", + "license": "BSD-3-Clause", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/@vue/runtime-dom": { + "version": "3.5.13", + "resolved": "https://registry.npmjs.org/@vue/runtime-dom/-/runtime-dom-3.5.13.tgz", + "integrity": "sha512-dLaj94s93NYLqjLiyFzVs9X6dWhTdAlEAciC3Moq7gzAc13VJUdCnjjRurNM6uTLFATRHexHCTu/Xp3eW6yoog==", + "license": "MIT", + "dependencies": { + "@vue/reactivity": "3.5.13", + "@vue/runtime-core": "3.5.13", + "@vue/shared": "3.5.13", + "csstype": "^3.1.3" + } + }, + "node_modules/@vue/runtime-dom/node_modules/@vue/reactivity": { + "version": "3.5.13", + "resolved": "https://registry.npmjs.org/@vue/reactivity/-/reactivity-3.5.13.tgz", + "integrity": "sha512-NaCwtw8o48B9I6L1zl2p41OHo/2Z4wqYGGIK1Khu5T7yxrn+ATOixn/Udn2m+6kZKB/J7cuT9DbWWhRxqixACg==", + "license": "MIT", + "dependencies": { + "@vue/shared": "3.5.13" + } + }, + "node_modules/@vue/runtime-dom/node_modules/@vue/runtime-core": { + "version": "3.5.13", + "resolved": "https://registry.npmjs.org/@vue/runtime-core/-/runtime-core-3.5.13.tgz", + "integrity": "sha512-Fj4YRQ3Az0WTZw1sFe+QDb0aXCerigEpw418pw1HBUKFtnQHWzwojaukAs2X/c9DQz4MQ4bsXTGlcpGxU/RCIw==", + "license": "MIT", + "dependencies": { + "@vue/reactivity": "3.5.13", + "@vue/shared": "3.5.13" + } + }, + "node_modules/@vue/runtime-dom/node_modules/csstype": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/csstype/-/csstype-3.1.3.tgz", + "integrity": "sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==", + "license": "MIT" + }, + "node_modules/@vue/server-renderer": { + "version": "3.5.13", + "resolved": "https://registry.npmjs.org/@vue/server-renderer/-/server-renderer-3.5.13.tgz", + "integrity": "sha512-wAi4IRJV/2SAW3htkTlB+dHeRmpTiVIK1OGLWV1yeStVSebSQQOwGwIq0D3ZIoBj2C2qpgz5+vX9iEBkTdk5YA==", + "license": "MIT", + "dependencies": { + "@vue/compiler-ssr": "3.5.13", + "@vue/shared": "3.5.13" + }, + "peerDependencies": { + "vue": "3.5.13" + } + }, + "node_modules/@vue/server-renderer/node_modules/@vue/compiler-ssr": { + "version": "3.5.13", + "resolved": "https://registry.npmjs.org/@vue/compiler-ssr/-/compiler-ssr-3.5.13.tgz", + "integrity": "sha512-wMH6vrYHxQl/IybKJagqbquvxpWCuVYpoUJfCqFZwa/JY1GdATAQ+TgVtgrwwMZ0D07QhA99rs/EAAWfvG6KpA==", + "license": "MIT", + "dependencies": { + "@vue/compiler-dom": "3.5.13", + "@vue/shared": "3.5.13" + } + }, + "node_modules/@vue/shared": { + "version": "3.5.13", + "resolved": "https://registry.npmjs.org/@vue/shared/-/shared-3.5.13.tgz", + "integrity": "sha512-/hnE/qP5ZoGpol0a5mDi45bOd7t3tjYJBjsgCsivow7D48cJeV5l05RD82lPqi7gRiphZM37rnhW1l6ZoCNNnQ==", + "license": "MIT" + }, + "node_modules/arg": { + "version": "5.0.2", + "resolved": "https://registry.npmjs.org/arg/-/arg-5.0.2.tgz", + "integrity": "sha512-PYjyFOLKQ9y57JvQ6QLo8dAgNqswh8M1RMJYdQduT6xbWSgK36P/Z/v+p888pM69jMMfS8Xd8F6I1kQ/I9HUGg==", + "license": "MIT" + }, + "node_modules/argparse": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/argparse/-/argparse-2.0.1.tgz", + "integrity": "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==", + "license": "Python-2.0" + }, + "node_modules/autoprefixer": { + "version": "10.4.20", + "resolved": "https://registry.npmjs.org/autoprefixer/-/autoprefixer-10.4.20.tgz", + "integrity": "sha512-XY25y5xSv/wEoqzDyXXME4AFfkZI0P23z6Fs3YgymDnKJkCGOnkL0iTxCa85UTqaSgfcqyf3UA6+c7wUvx/16g==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/autoprefixer" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "dependencies": { + "browserslist": "^4.23.3", + "caniuse-lite": "^1.0.30001646", + "fraction.js": "^4.3.7", + "normalize-range": "^0.1.2", + "picocolors": "^1.0.1", + "postcss-value-parser": "^4.2.0" + }, + "bin": { + "autoprefixer": "bin/autoprefixer" + }, + "engines": { + "node": "^10 || ^12 || >=14" + }, + "peerDependencies": { + "postcss": "^8.1.0" + } + }, + "node_modules/browserslist": { + "version": "4.24.2", + "resolved": "https://registry.npmjs.org/browserslist/-/browserslist-4.24.2.tgz", + "integrity": "sha512-ZIc+Q62revdMcqC6aChtW4jz3My3klmCO1fEmINZY/8J3EpBg5/A/D0AKmBveUh6pgoeycoMkVMko84tuYS+Gg==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/browserslist" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/browserslist" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "dependencies": { + "caniuse-lite": "^1.0.30001669", + "electron-to-chromium": "^1.5.41", + "node-releases": "^2.0.18", + "update-browserslist-db": "^1.1.1" + }, + "bin": { + "browserslist": "cli.js" + }, + "engines": { + "node": "^6 || ^7 || ^8 || ^9 || ^10 || ^11 || ^12 || >=13.7" + } + }, + "node_modules/browserslist/node_modules/electron-to-chromium": { + "version": "1.5.67", + "resolved": "https://registry.npmjs.org/electron-to-chromium/-/electron-to-chromium-1.5.67.tgz", + "integrity": "sha512-nz88NNBsD7kQSAGGJyp8hS6xSPtWwqNogA0mjtc2nUYeEf3nURK9qpV18TuBdDmEDgVWotS8Wkzf+V52dSQ/LQ==", + "license": "ISC" + }, + "node_modules/browserslist/node_modules/escalade": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", + "integrity": "sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==", + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/browserslist/node_modules/node-releases": { + "version": "2.0.18", + "resolved": "https://registry.npmjs.org/node-releases/-/node-releases-2.0.18.tgz", + "integrity": "sha512-d9VeXT4SJ7ZeOqGX6R5EM022wpL+eWPooLI+5UpWn2jCT1aosUQEhQP214x33Wkwx3JQMvIm+tIoVOdodFS40g==", + "license": "MIT" + }, + "node_modules/browserslist/node_modules/update-browserslist-db": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/update-browserslist-db/-/update-browserslist-db-1.1.1.tgz", + "integrity": "sha512-R8UzCaa9Az+38REPiJ1tXlImTJXlVfgHZsglwBD/k6nj76ctsH1E3q4doGrukiLQd3sGQYu56r5+lo5r94l29A==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/browserslist" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/browserslist" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "dependencies": { + "escalade": "^3.2.0", + "picocolors": "^1.1.0" + }, + "bin": { + "update-browserslist-db": "cli.js" + }, + "peerDependencies": { + "browserslist": ">= 4.21.0" + } + }, + "node_modules/buffer-builder": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/buffer-builder/-/buffer-builder-0.2.0.tgz", + "integrity": "sha512-7VPMEPuYznPSoR21NE1zvd2Xna6c/CloiZCfcMXR1Jny6PjX0N4Nsa38zcBFo/FMK+BlA+FLKbJCQ0i2yxp+Xg==", + "devOptional": true, + "license": "MIT/X11" + }, + "node_modules/caniuse-lite": { + "version": "1.0.30001684", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001684.tgz", + "integrity": "sha512-G1LRwLIQjBQoyq0ZJGqGIJUXzJ8irpbjHLpVRXDvBEScFJ9b17sgK6vlx0GAJFE21okD7zXl08rRRUfq6HdoEQ==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/browserslist" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/caniuse-lite" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "CC-BY-4.0" + }, + "node_modules/chokidar": { + "version": "3.6.0", + "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.6.0.tgz", + "integrity": "sha512-7VT13fmjotKpGipCW9JEQAusEPE+Ei8nl6/g4FBAmIm0GOOLMua9NDDo/DWp0ZAxCr3cPq5ZpBqmPAQgDda2Pw==", + "license": "MIT", + "dependencies": { + "anymatch": "~3.1.2", + "braces": "~3.0.2", + "glob-parent": "~5.1.2", + "is-binary-path": "~2.1.0", + "is-glob": "~4.0.1", + "normalize-path": "~3.0.0", + "readdirp": "~3.6.0" + }, + "engines": { + "node": ">= 8.10.0" + }, + "funding": { + "url": "https://paulmillr.com/funding/" + }, + "optionalDependencies": { + "fsevents": "~2.3.2" + } + }, + "node_modules/chokidar/node_modules/anymatch": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/anymatch/-/anymatch-3.1.3.tgz", + "integrity": "sha512-KMReFUr0B4t+D+OBkjR3KYqvocp2XaSzO55UcB6mgQMd3KbcE+mWTyvVV7D/zsdEbNnV6acZUutkiHQXvTr1Rw==", + "license": "ISC", + "dependencies": { + "normalize-path": "^3.0.0", + "picomatch": "^2.0.4" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/chokidar/node_modules/binary-extensions": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.3.0.tgz", + "integrity": "sha512-Ceh+7ox5qe7LJuLHoY0feh3pHuUDHAcRUeyL2VYghZwfpkNIy/+8Ocg0a3UuSoYzavmylwuLWQOf3hl0jjMMIw==", + "license": "MIT", + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/chokidar/node_modules/braces": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", + "license": "MIT", + "dependencies": { + "fill-range": "^7.1.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/chokidar/node_modules/fill-range": { + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", + "license": "MIT", + "dependencies": { + "to-regex-range": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/chokidar/node_modules/glob-parent": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", + "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "license": "ISC", + "dependencies": { + "is-glob": "^4.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/chokidar/node_modules/is-binary-path": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/is-binary-path/-/is-binary-path-2.1.0.tgz", + "integrity": "sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw==", + "license": "MIT", + "dependencies": { + "binary-extensions": "^2.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/chokidar/node_modules/is-number": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", + "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", + "license": "MIT", + "engines": { + "node": ">=0.12.0" + } + }, + "node_modules/chokidar/node_modules/picomatch": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", + "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", + "license": "MIT", + "engines": { + "node": ">=8.6" + }, + "funding": { + "url": "https://github.com/sponsors/jonschlinkert" + } + }, + "node_modules/chokidar/node_modules/readdirp": { + "version": "3.6.0", + "resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz", + "integrity": "sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA==", + "license": "MIT", + "dependencies": { + "picomatch": "^2.2.1" + }, + "engines": { + "node": ">=8.10.0" + } + }, + "node_modules/chokidar/node_modules/to-regex-range": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", + "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", + "license": "MIT", + "dependencies": { + "is-number": "^7.0.0" + }, + "engines": { + "node": ">=8.0" + } + }, + "node_modules/colorjs.io": { + "version": "0.5.2", + "resolved": "https://registry.npmjs.org/colorjs.io/-/colorjs.io-0.5.2.tgz", + "integrity": "sha512-twmVoizEW7ylZSN32OgKdXRmo1qg+wT5/6C3xu5b9QsWzSFAhHLn2xd8ro0diCsKfCj1RdaTP/nrcW+vAoQPIw==", + "devOptional": true, + "license": "MIT" + }, + "node_modules/commander": { + "version": "8.3.0", + "resolved": "https://registry.npmjs.org/commander/-/commander-8.3.0.tgz", + "integrity": "sha512-OkTL9umf+He2DZkUq8f8J9of7yL6RJKI24dVITBmNfZBmri9zYZQrKkuXiKhyfPSu8tUhnVBB1iKXevvnlR4Ww==", + "license": "MIT", + "engines": { + "node": ">= 12" + } + }, + "node_modules/css-selector-tokenizer": { + "version": "0.8.0", + "resolved": "https://registry.npmjs.org/css-selector-tokenizer/-/css-selector-tokenizer-0.8.0.tgz", + "integrity": "sha512-Jd6Ig3/pe62/qe5SBPTN8h8LeUg/pT4lLgtavPf7updwwHpvFzxvOQBHYj2LZDMjUnBzgvIUSjRcf6oT5HzHFg==", + "license": "MIT", + "dependencies": { + "cssesc": "^3.0.0", + "fastparse": "^1.1.2" + } + }, + "node_modules/css-selector-tokenizer/node_modules/cssesc": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/cssesc/-/cssesc-3.0.0.tgz", + "integrity": "sha512-/Tb/JcjK111nNScGob5MNtsntNM1aCNUDipB/TkwZFhyDrrE47SOx/18wF2bbjgc3ZzCSKW1T5nt5EbFoAz/Vg==", + "license": "MIT", + "bin": { + "cssesc": "bin/cssesc" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/css-selector-tokenizer/node_modules/fastparse": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/fastparse/-/fastparse-1.1.2.tgz", + "integrity": "sha512-483XLLxTVIwWK3QTrMGRqUfUpoOs/0hbQrl2oz4J0pAcm3A3bu84wxTFqGqkJzewCLdME38xJLJAxBABfQT8sQ==", + "license": "MIT" + }, + "node_modules/culori": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/culori/-/culori-3.3.0.tgz", + "integrity": "sha512-pHJg+jbuFsCjz9iclQBqyL3B2HLCBF71BwVNujUYEvCeQMvV97R59MNK3R2+jgJ3a1fcZgI9B3vYgz8lzr/BFQ==", + "license": "MIT", + "engines": { + "node": "^12.20.0 || ^14.13.1 || >=16.0.0" + } + }, + "node_modules/daisyui": { + "version": "4.12.14", + "resolved": "https://registry.npmjs.org/daisyui/-/daisyui-4.12.14.tgz", + "integrity": "sha512-hA27cdBasdwd4/iEjn+aidoCrRroDuo3G5W9NDKaVCJI437Mm/3eSL/2u7MkZ0pt8a+TrYF3aT2pFVemTS3how==", + "license": "MIT", + "dependencies": { + "css-selector-tokenizer": "^0.8", + "culori": "^3", + "picocolors": "^1", + "postcss-js": "^4" + }, + "engines": { + "node": ">=16.9.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/daisyui" + } + }, + "node_modules/didyoumean": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/didyoumean/-/didyoumean-1.2.2.tgz", + "integrity": "sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw==", + "license": "Apache-2.0" + }, + "node_modules/dlv": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/dlv/-/dlv-1.1.3.tgz", + "integrity": "sha512-+HlytyjlPKnIG8XuRG8WvmBP8xs8P71y+SKKS6ZXWoEgLuePxtDoUEiH7WkdePWrQ5JBpE6aoVqfZfJUQkjXwA==", + "license": "MIT" + }, + "node_modules/entities": { + "version": "4.5.0", + "resolved": "https://registry.npmjs.org/entities/-/entities-4.5.0.tgz", + "integrity": "sha512-V0hjH4dGPh9Ao5p0MoRY6BVqtwCjhz6vI5LT8AJ55H+4g9/4vbHx1I54fS0XuclLhDHArPQCiMjDxjaL8fPxhw==", + "license": "BSD-2-Clause", + "engines": { + "node": ">=0.12" + }, + "funding": { + "url": "https://github.com/fb55/entities?sponsor=1" + } + }, + "node_modules/esbuild": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.21.5.tgz", + "integrity": "sha512-mg3OPMV4hXywwpoDxu3Qda5xCKQi+vCTZq8S9J/EpkhB2HzKXq4SNFZE3+NK93JYxc8VMSep+lOUSC/RVKaBqw==", + "hasInstallScript": true, + "license": "MIT", + "bin": { + "esbuild": "bin/esbuild" + }, + "engines": { + "node": ">=12" + }, + "optionalDependencies": { + "@esbuild/aix-ppc64": "0.21.5", + "@esbuild/android-arm": "0.21.5", + "@esbuild/android-arm64": "0.21.5", + "@esbuild/android-x64": "0.21.5", + "@esbuild/darwin-arm64": "0.21.5", + "@esbuild/darwin-x64": "0.21.5", + "@esbuild/freebsd-arm64": "0.21.5", + "@esbuild/freebsd-x64": "0.21.5", + "@esbuild/linux-arm": "0.21.5", + "@esbuild/linux-arm64": "0.21.5", + "@esbuild/linux-ia32": "0.21.5", + "@esbuild/linux-loong64": "0.21.5", + "@esbuild/linux-mips64el": "0.21.5", + "@esbuild/linux-ppc64": "0.21.5", + "@esbuild/linux-riscv64": "0.21.5", + "@esbuild/linux-s390x": "0.21.5", + "@esbuild/linux-x64": "0.21.5", + "@esbuild/netbsd-x64": "0.21.5", + "@esbuild/openbsd-x64": "0.21.5", + "@esbuild/sunos-x64": "0.21.5", + "@esbuild/win32-arm64": "0.21.5", + "@esbuild/win32-ia32": "0.21.5", + "@esbuild/win32-x64": "0.21.5" + } + }, + "node_modules/esbuild/node_modules/@esbuild/darwin-arm64": { + "version": "0.21.5", + "resolved": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.21.5.tgz", + "integrity": "sha512-DwqXqZyuk5AiWWf3UfLiRDJ5EDd49zg6O9wclZ7kUMv2WRFr4HKjXp/5t8JZ11QbQfUS6/cRCKGwYhtNAY88kQ==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=12" + } + }, + "node_modules/fast-glob": { + "version": "3.3.2", + "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.2.tgz", + "integrity": "sha512-oX2ruAFQwf/Orj8m737Y5adxDQO0LAB7/S5MnxCdTNDd4p6BsyIVsv9JQsATbTSq8KHRpLwIHbVlUNatxd+1Ow==", + "license": "MIT", + "dependencies": { + "@nodelib/fs.stat": "^2.0.2", + "@nodelib/fs.walk": "^1.2.3", + "glob-parent": "^5.1.2", + "merge2": "^1.3.0", + "micromatch": "^4.0.4" + }, + "engines": { + "node": ">=8.6.0" + } + }, + "node_modules/fast-glob/node_modules/@nodelib/fs.scandir": { + "version": "2.1.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", + "integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==", + "license": "MIT", + "dependencies": { + "@nodelib/fs.stat": "2.0.5", + "run-parallel": "^1.1.9" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/fast-glob/node_modules/@nodelib/fs.stat": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz", + "integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==", + "license": "MIT", + "engines": { + "node": ">= 8" + } + }, + "node_modules/fast-glob/node_modules/@nodelib/fs.walk": { + "version": "1.2.8", + "resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz", + "integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==", + "license": "MIT", + "dependencies": { + "@nodelib/fs.scandir": "2.1.5", + "fastq": "^1.6.0" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/fast-glob/node_modules/fastq": { + "version": "1.17.1", + "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.17.1.tgz", + "integrity": "sha512-sRVD3lWVIXWg6By68ZN7vho9a1pQcN/WBFaAAsDDFzlJjvoGx0P8z7V1t72grFJfJhu3YPZBuu25f7Kaw2jN1w==", + "license": "ISC", + "dependencies": { + "reusify": "^1.0.4" + } + }, + "node_modules/fast-glob/node_modules/glob-parent": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", + "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", + "license": "ISC", + "dependencies": { + "is-glob": "^4.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/fast-glob/node_modules/merge2": { + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", + "integrity": "sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==", + "license": "MIT", + "engines": { + "node": ">= 8" + } + }, + "node_modules/fast-glob/node_modules/queue-microtask": { + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz", + "integrity": "sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT" + }, + "node_modules/fast-glob/node_modules/reusify": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.0.4.tgz", + "integrity": "sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw==", + "license": "MIT", + "engines": { + "iojs": ">=1.0.0", + "node": ">=0.10.0" + } + }, + "node_modules/fast-glob/node_modules/run-parallel": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", + "integrity": "sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT", + "dependencies": { + "queue-microtask": "^1.2.2" + } + }, + "node_modules/fraction.js": { + "version": "4.3.7", + "resolved": "https://registry.npmjs.org/fraction.js/-/fraction.js-4.3.7.tgz", + "integrity": "sha512-ZsDfxO51wGAXREY55a7la9LScWpwv9RxIrYABrlvOFBlH/ShPnrtsXeuUIfXKKOVicNxQ+o8JTbJvjS4M89yew==", + "license": "MIT", + "engines": { + "node": "*" + }, + "funding": { + "type": "patreon", + "url": "https://github.com/sponsors/rawify" + } + }, + "node_modules/fsevents": { + "version": "2.3.3", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", + "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", + "hasInstallScript": true, + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, + "node_modules/glob-parent": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", + "integrity": "sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==", + "license": "ISC", + "dependencies": { + "is-glob": "^4.0.3" + }, + "engines": { + "node": ">=10.13.0" + } + }, + "node_modules/has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "devOptional": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/highlight.js": { + "version": "11.10.0", + "resolved": "https://registry.npmjs.org/highlight.js/-/highlight.js-11.10.0.tgz", + "integrity": "sha512-SYVnVFswQER+zu1laSya563s+F8VDGt7o35d4utbamowvUNLLMovFqwCLSocpZTz3MgaSRA1IbqRWZv97dtErQ==", + "engines": { + "node": ">=12.0.0" + } + }, + "node_modules/immutable": { + "version": "5.0.3", + "resolved": "https://registry.npmjs.org/immutable/-/immutable-5.0.3.tgz", + "integrity": "sha512-P8IdPQHq3lA1xVeBRi5VPqUm5HDgKnx0Ru51wZz5mjxHr5n3RWhjIpOFU7ybkUxfB+5IToy+OLaHYDBIWsv+uw==", + "devOptional": true, + "license": "MIT" + }, + "node_modules/is-glob": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", + "integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", + "license": "MIT", + "dependencies": { + "is-extglob": "^2.1.1" + }, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/is-glob/node_modules/is-extglob": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", + "integrity": "sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/jiti": { + "version": "1.21.6", + "resolved": "https://registry.npmjs.org/jiti/-/jiti-1.21.6.tgz", + "integrity": "sha512-2yTgeWTWzMWkHu6Jp9NKgePDaYHbntiwvYuuJLbbN9vl7DC9DvXKOB2BC3ZZ92D3cvV/aflH0osDfwpHepQ53w==", + "license": "MIT", + "bin": { + "jiti": "bin/jiti.js" + } + }, + "node_modules/katex": { + "version": "0.16.15", + "resolved": "https://registry.npmjs.org/katex/-/katex-0.16.15.tgz", + "integrity": "sha512-yE9YJIEAk2aZ+FL/G8r+UGw0CTUzEA8ZFy6E+8tc3spHUKq3qBnzCkI1CQwGoI9atJhVyFPEypQsTY7mJ1Pi9w==", + "funding": [ + "https://opencollective.com/katex", + "https://github.com/sponsors/katex" + ], + "license": "MIT", + "dependencies": { + "commander": "^8.3.0" + }, + "bin": { + "katex": "cli.js" + } + }, + "node_modules/lilconfig": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/lilconfig/-/lilconfig-2.1.0.tgz", + "integrity": "sha512-utWOt/GHzuUxnLKxB6dk81RoOeoNeHgbrXiuGk4yyF5qlRz+iIVWu56E2fqGHFrXz0QNUhLB/8nKqvRH66JKGQ==", + "license": "MIT", + "engines": { + "node": ">=10" + } + }, + "node_modules/linkify-it": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/linkify-it/-/linkify-it-5.0.0.tgz", + "integrity": "sha512-5aHCbzQRADcdP+ATqnDuhhJ/MRIqDkZX5pyjFHRRysS8vZ5AbqGEoFIb6pYHPZ+L/OC2Lc+xT8uHVVR5CAK/wQ==", + "license": "MIT", + "dependencies": { + "uc.micro": "^2.0.0" + } + }, + "node_modules/markdown-it": { + "version": "14.1.0", + "resolved": "https://registry.npmjs.org/markdown-it/-/markdown-it-14.1.0.tgz", + "integrity": "sha512-a54IwgWPaeBCAAsv13YgmALOF1elABB08FxO9i+r4VFk5Vl4pKokRPeX8u5TCgSsPi6ec1otfLjdOpVcgbpshg==", + "license": "MIT", + "dependencies": { + "argparse": "^2.0.1", + "entities": "^4.4.0", + "linkify-it": "^5.0.0", + "mdurl": "^2.0.0", + "punycode.js": "^2.3.1", + "uc.micro": "^2.1.0" + }, + "bin": { + "markdown-it": "bin/markdown-it.mjs" + } + }, + "node_modules/mdurl": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/mdurl/-/mdurl-2.0.0.tgz", + "integrity": "sha512-Lf+9+2r+Tdp5wXDXC4PcIBjTDtq4UKjCPMQhKIuzpJNW0b96kVqSwW0bT7FhRSfmAiFYgP+SCRvdrDozfh0U5w==", + "license": "MIT" + }, + "node_modules/micromatch": { + "version": "4.0.8", + "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz", + "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", + "license": "MIT", + "dependencies": { + "braces": "^3.0.3", + "picomatch": "^2.3.1" + }, + "engines": { + "node": ">=8.6" + } + }, + "node_modules/micromatch/node_modules/braces": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", + "license": "MIT", + "dependencies": { + "fill-range": "^7.1.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/micromatch/node_modules/fill-range": { + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", + "license": "MIT", + "dependencies": { + "to-regex-range": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/micromatch/node_modules/is-number": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", + "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", + "license": "MIT", + "engines": { + "node": ">=0.12.0" + } + }, + "node_modules/micromatch/node_modules/picomatch": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", + "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", + "license": "MIT", + "engines": { + "node": ">=8.6" + }, + "funding": { + "url": "https://github.com/sponsors/jonschlinkert" + } + }, + "node_modules/micromatch/node_modules/to-regex-range": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", + "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", + "license": "MIT", + "dependencies": { + "is-number": "^7.0.0" + }, + "engines": { + "node": ">=8.0" + } + }, + "node_modules/normalize-path": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/normalize-path/-/normalize-path-3.0.0.tgz", + "integrity": "sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/normalize-range": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/normalize-range/-/normalize-range-0.1.2.tgz", + "integrity": "sha512-bdok/XvKII3nUpklnV6P2hxtMNrCboOjAcyBuQnWEhO665FwrSNRxU+AqpsyvO6LgGYPspN+lu5CLtw4jPRKNA==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/object-hash": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/object-hash/-/object-hash-3.0.0.tgz", + "integrity": "sha512-RSn9F68PjH9HqtltsSnqYC1XXoWe9Bju5+213R98cNGttag9q9yAOTzdbsqvIa7aNm5WffBZFpWYr2aWrklWAw==", + "license": "MIT", + "engines": { + "node": ">= 6" + } + }, + "node_modules/picocolors": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", + "integrity": "sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==", + "license": "ISC" + }, + "node_modules/postcss": { + "version": "8.4.49", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.49.tgz", + "integrity": "sha512-OCVPnIObs4N29kxTjzLfUryOkvZEq+pf8jTF0lg8E7uETuWHA+v7j3c/xJmiqpX450191LlmZfUKkXxkTry7nA==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + { + "type": "tidelift", + "url": "https://tidelift.com/funding/github/npm/postcss" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "dependencies": { + "nanoid": "^3.3.7", + "picocolors": "^1.1.1", + "source-map-js": "^1.2.1" + }, + "engines": { + "node": "^10 || ^12 || >=14" + } + }, + "node_modules/postcss-import": { + "version": "15.1.0", + "resolved": "https://registry.npmjs.org/postcss-import/-/postcss-import-15.1.0.tgz", + "integrity": "sha512-hpr+J05B2FVYUAXHeK1YyI267J/dDDhMU6B6civm8hSY1jYJnBXxzKDKDswzJmtLHryrjhnDjqqp/49t8FALew==", + "license": "MIT", + "dependencies": { + "postcss-value-parser": "^4.0.0", + "read-cache": "^1.0.0", + "resolve": "^1.1.7" + }, + "engines": { + "node": ">=14.0.0" + }, + "peerDependencies": { + "postcss": "^8.0.0" + } + }, + "node_modules/postcss-import/node_modules/pify": { + "version": "2.3.0", + "resolved": "https://registry.npmjs.org/pify/-/pify-2.3.0.tgz", + "integrity": "sha512-udgsAY+fTnvv7kI7aaxbqwWNb0AHiB0qBO89PZKPkoTmGOgdbrHDKD+0B2X4uTfJ/FT1R09r9gTsjUjNJotuog==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/postcss-import/node_modules/read-cache": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/read-cache/-/read-cache-1.0.0.tgz", + "integrity": "sha512-Owdv/Ft7IjOgm/i0xvNDZ1LrRANRfew4b2prF3OWMQLxLfu3bS8FVhCsrSCMK4lR56Y9ya+AThoTpDCTxCmpRA==", + "license": "MIT", + "dependencies": { + "pify": "^2.3.0" + } + }, + "node_modules/postcss-js": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/postcss-js/-/postcss-js-4.0.1.tgz", + "integrity": "sha512-dDLF8pEO191hJMtlHFPRa8xsizHaM82MLfNkUHdUtVEV3tgTp5oj+8qbEqYM57SLfc74KSbw//4SeJma2LRVIw==", + "license": "MIT", + "dependencies": { + "camelcase-css": "^2.0.1" + }, + "engines": { + "node": "^12 || ^14 || >= 16" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + "peerDependencies": { + "postcss": "^8.4.21" + } + }, + "node_modules/postcss-js/node_modules/camelcase-css": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/camelcase-css/-/camelcase-css-2.0.1.tgz", + "integrity": "sha512-QOSvevhslijgYwRx6Rv7zKdMF8lbRmx+uQGx2+vDc+KI/eBnsy9kit5aj23AgGu3pa4t9AgwbnXWqS+iOY+2aA==", + "license": "MIT", + "engines": { + "node": ">= 6" + } + }, + "node_modules/postcss-load-config": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/postcss-load-config/-/postcss-load-config-4.0.2.tgz", + "integrity": "sha512-bSVhyJGL00wMVoPUzAVAnbEoWyqRxkjv64tUl427SKnPrENtq6hJwUojroMz2VB+Q1edmi4IfrAPpami5VVgMQ==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "dependencies": { + "lilconfig": "^3.0.0", + "yaml": "^2.3.4" + }, + "engines": { + "node": ">= 14" + }, + "peerDependencies": { + "postcss": ">=8.0.9", + "ts-node": ">=9.0.0" + }, + "peerDependenciesMeta": { + "postcss": { + "optional": true + }, + "ts-node": { + "optional": true + } + } + }, + "node_modules/postcss-load-config/node_modules/lilconfig": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/lilconfig/-/lilconfig-3.1.2.tgz", + "integrity": "sha512-eop+wDAvpItUys0FWkHIKeC9ybYrTGbU41U5K7+bttZZeohvnY7M9dZ5kB21GNWiFT2q1OoPTvncPCgSOVO5ow==", + "license": "MIT", + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/sponsors/antonk52" + } + }, + "node_modules/postcss-load-config/node_modules/yaml": { + "version": "2.6.1", + "resolved": "https://registry.npmjs.org/yaml/-/yaml-2.6.1.tgz", + "integrity": "sha512-7r0XPzioN/Q9kXBro/XPnA6kznR73DHq+GXh5ON7ZozRO6aMjbmiBuKste2wslTFkC5d1dw0GooOCepZXJ2SAg==", + "license": "ISC", + "bin": { + "yaml": "bin.mjs" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/postcss-nested": { + "version": "6.2.0", + "resolved": "https://registry.npmjs.org/postcss-nested/-/postcss-nested-6.2.0.tgz", + "integrity": "sha512-HQbt28KulC5AJzG+cZtj9kvKB93CFCdLvog1WFLf1D+xmMvPGlBstkpTEZfK5+AN9hfJocyBFCNiqyS48bpgzQ==", + "funding": [ + { + "type": "opencollective", + "url": "https://opencollective.com/postcss/" + }, + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "dependencies": { + "postcss-selector-parser": "^6.1.1" + }, + "engines": { + "node": ">=12.0" + }, + "peerDependencies": { + "postcss": "^8.2.14" + } + }, + "node_modules/postcss-selector-parser": { + "version": "6.1.2", + "resolved": "https://registry.npmjs.org/postcss-selector-parser/-/postcss-selector-parser-6.1.2.tgz", + "integrity": "sha512-Q8qQfPiZ+THO/3ZrOrO0cJJKfpYCagtMUkXbnEfmgUjwXg6z/WBeOyS9APBBPCTSiDV+s4SwQGu8yFsiMRIudg==", + "license": "MIT", + "dependencies": { + "cssesc": "^3.0.0", + "util-deprecate": "^1.0.2" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/postcss-selector-parser/node_modules/cssesc": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/cssesc/-/cssesc-3.0.0.tgz", + "integrity": "sha512-/Tb/JcjK111nNScGob5MNtsntNM1aCNUDipB/TkwZFhyDrrE47SOx/18wF2bbjgc3ZzCSKW1T5nt5EbFoAz/Vg==", + "license": "MIT", + "bin": { + "cssesc": "bin/cssesc" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/postcss-selector-parser/node_modules/util-deprecate": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", + "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==", + "license": "MIT" + }, + "node_modules/postcss-value-parser": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/postcss-value-parser/-/postcss-value-parser-4.2.0.tgz", + "integrity": "sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ==", + "license": "MIT" + }, + "node_modules/postcss/node_modules/nanoid": { + "version": "3.3.8", + "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.8.tgz", + "integrity": "sha512-WNLf5Sd8oZxOm+TzppcYk8gVOgP+l58xNy58D0nbUnOxOWRWvlcCV4kUF7ltmI6PsrLl/BgKEyS4mqsGChFN0w==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/ai" + } + ], + "license": "MIT", + "bin": { + "nanoid": "bin/nanoid.cjs" + }, + "engines": { + "node": "^10 || ^12 || ^13.7 || ^14 || >=15.0.1" + } + }, + "node_modules/postcss/node_modules/source-map-js": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.2.1.tgz", + "integrity": "sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==", + "license": "BSD-3-Clause", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/punycode.js": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/punycode.js/-/punycode.js-2.3.1.tgz", + "integrity": "sha512-uxFIHU0YlHYhDQtV4R9J6a52SLx28BCjT+4ieh7IGbgwVJWO+km431c4yRlREUAsAmt/uMjQUyQHNEPf0M39CA==", + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/resolve": { + "version": "1.22.8", + "resolved": "https://registry.npmjs.org/resolve/-/resolve-1.22.8.tgz", + "integrity": "sha512-oKWePCxqpd6FlLvGV1VU0x7bkPmmCNolxzjMf4NczoDnQcIWrAF+cPtZn5i6n+RfD2d9i0tzpKnG6Yk168yIyw==", + "license": "MIT", + "dependencies": { + "is-core-module": "^2.13.0", + "path-parse": "^1.0.7", + "supports-preserve-symlinks-flag": "^1.0.0" + }, + "bin": { + "resolve": "bin/resolve" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/resolve/node_modules/function-bind": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", + "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/resolve/node_modules/hasown": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", + "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", + "license": "MIT", + "dependencies": { + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/resolve/node_modules/is-core-module": { + "version": "2.15.1", + "resolved": "https://registry.npmjs.org/is-core-module/-/is-core-module-2.15.1.tgz", + "integrity": "sha512-z0vtXSwucUJtANQWldhbtbt7BnL0vxiFjIdDLAatwhDYty2bad6s+rijD6Ri4YuYJubLzIJLUidCh09e1djEVQ==", + "license": "MIT", + "dependencies": { + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/resolve/node_modules/path-parse": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/path-parse/-/path-parse-1.0.7.tgz", + "integrity": "sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw==", + "license": "MIT" + }, + "node_modules/resolve/node_modules/supports-preserve-symlinks-flag": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz", + "integrity": "sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/rollup": { + "version": "4.28.0", + "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.28.0.tgz", + "integrity": "sha512-G9GOrmgWHBma4YfCcX8PjH0qhXSdH8B4HDE2o4/jaxj93S4DPCIDoLcXz99eWMji4hB29UFCEd7B2gwGJDR9cQ==", + "license": "MIT", + "dependencies": { + "@types/estree": "1.0.6" + }, + "bin": { + "rollup": "dist/bin/rollup" + }, + "engines": { + "node": ">=18.0.0", + "npm": ">=8.0.0" + }, + "optionalDependencies": { + "@rollup/rollup-android-arm-eabi": "4.28.0", + "@rollup/rollup-android-arm64": "4.28.0", + "@rollup/rollup-darwin-arm64": "4.28.0", + "@rollup/rollup-darwin-x64": "4.28.0", + "@rollup/rollup-freebsd-arm64": "4.28.0", + "@rollup/rollup-freebsd-x64": "4.28.0", + "@rollup/rollup-linux-arm-gnueabihf": "4.28.0", + "@rollup/rollup-linux-arm-musleabihf": "4.28.0", + "@rollup/rollup-linux-arm64-gnu": "4.28.0", + "@rollup/rollup-linux-arm64-musl": "4.28.0", + "@rollup/rollup-linux-powerpc64le-gnu": "4.28.0", + "@rollup/rollup-linux-riscv64-gnu": "4.28.0", + "@rollup/rollup-linux-s390x-gnu": "4.28.0", + "@rollup/rollup-linux-x64-gnu": "4.28.0", + "@rollup/rollup-linux-x64-musl": "4.28.0", + "@rollup/rollup-win32-arm64-msvc": "4.28.0", + "@rollup/rollup-win32-ia32-msvc": "4.28.0", + "@rollup/rollup-win32-x64-msvc": "4.28.0", + "fsevents": "~2.3.2" + } + }, + "node_modules/rollup/node_modules/@rollup/rollup-darwin-arm64": { + "version": "4.28.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.28.0.tgz", + "integrity": "sha512-lmKx9yHsppblnLQZOGxdO66gT77bvdBtr/0P+TPOseowE7D9AJoBw8ZDULRasXRWf1Z86/gcOdpBrV6VDUY36Q==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/rollup/node_modules/@types/estree": { + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.6.tgz", + "integrity": "sha512-AYnb1nQyY49te+VRAVgmzfcgjYS91mY5P0TKUDCLEM+gNnA+3T6rWITXRLYCpahpqSQbN5cE+gHpnPyXjHWxcw==", + "license": "MIT" + }, + "node_modules/rxjs": { + "version": "7.8.1", + "resolved": "https://registry.npmjs.org/rxjs/-/rxjs-7.8.1.tgz", + "integrity": "sha512-AA3TVj+0A2iuIoQkWEK/tqFjBq2j+6PO6Y0zJcvzLAFhEFIO3HL0vls9hWLncZbAAbK0mar7oZ4V079I/qPMxg==", + "devOptional": true, + "license": "Apache-2.0", + "dependencies": { + "tslib": "^2.1.0" + } + }, + "node_modules/sass-embedded": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded/-/sass-embedded-1.83.0.tgz", + "integrity": "sha512-/8cYZeL39evUqe0o//193na51Q1VWZ61qhxioQvLJwOtWIrX+PgNhCyD8RSuTtmzc4+6+waFZf899bfp/MCUwA==", + "devOptional": true, + "license": "MIT", + "dependencies": { + "@bufbuild/protobuf": "^2.0.0", + "buffer-builder": "^0.2.0", + "colorjs.io": "^0.5.0", + "immutable": "^5.0.2", + "rxjs": "^7.4.0", + "supports-color": "^8.1.1", + "sync-child-process": "^1.0.2", + "varint": "^6.0.0" + }, + "bin": { + "sass": "dist/bin/sass.js" + }, + "engines": { + "node": ">=16.0.0" + }, + "optionalDependencies": { + "sass-embedded-android-arm": "1.83.0", + "sass-embedded-android-arm64": "1.83.0", + "sass-embedded-android-ia32": "1.83.0", + "sass-embedded-android-riscv64": "1.83.0", + "sass-embedded-android-x64": "1.83.0", + "sass-embedded-darwin-arm64": "1.83.0", + "sass-embedded-darwin-x64": "1.83.0", + "sass-embedded-linux-arm": "1.83.0", + "sass-embedded-linux-arm64": "1.83.0", + "sass-embedded-linux-ia32": "1.83.0", + "sass-embedded-linux-musl-arm": "1.83.0", + "sass-embedded-linux-musl-arm64": "1.83.0", + "sass-embedded-linux-musl-ia32": "1.83.0", + "sass-embedded-linux-musl-riscv64": "1.83.0", + "sass-embedded-linux-musl-x64": "1.83.0", + "sass-embedded-linux-riscv64": "1.83.0", + "sass-embedded-linux-x64": "1.83.0", + "sass-embedded-win32-arm64": "1.83.0", + "sass-embedded-win32-ia32": "1.83.0", + "sass-embedded-win32-x64": "1.83.0" + } + }, + "node_modules/sass-embedded-android-arm": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded-android-arm/-/sass-embedded-android-arm-1.83.0.tgz", + "integrity": "sha512-uwFSXzJlfbd4Px189xE5l+cxN8+TQpXdQgJec7TIrb4HEY7imabtpYufpVdqUVwT1/uiis5V4+qIEC4Vl5XObQ==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-android-arm64": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded-android-arm64/-/sass-embedded-android-arm64-1.83.0.tgz", + "integrity": "sha512-GBiCvM4a2rkWBLdYDxI6XYnprfk5U5c81g69RC2X6kqPuzxzx8qTArQ9M6keFK4+iDQ5N9QTwFCr0KbZTn+ZNQ==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-android-ia32": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded-android-ia32/-/sass-embedded-android-ia32-1.83.0.tgz", + "integrity": "sha512-5ATPdGo2SICqAhiJl/Z8KQ23zH4sGgobGgux0TnrNtt83uHZ+r+To/ubVJ7xTkZxed+KJZnIpolGD8dQyQqoTg==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-android-riscv64": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded-android-riscv64/-/sass-embedded-android-riscv64-1.83.0.tgz", + "integrity": "sha512-aveknUOB8GZewOzVn2Uwk+DKcncTR50Q6vtzslNMGbYnxtgQNHzy8A1qVEviNUruex+pHofppeMK4iMPFAbiEQ==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-android-x64": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded-android-x64/-/sass-embedded-android-x64-1.83.0.tgz", + "integrity": "sha512-WqIay/72ncyf9Ph4vS742J3a73wZihWmzFUwpn1OD6lme1Aj4eWzWIve5IVnlTEJgcZcDHu6ECID9IZgehJKoA==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-darwin-arm64": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded-darwin-arm64/-/sass-embedded-darwin-arm64-1.83.0.tgz", + "integrity": "sha512-XQl9QqgxFFIPm/CzHhmppse5o9ocxrbaAdC2/DAnlAqvYWBBtgFqPjGoYlej13h9SzfvNoogx+y9r+Ap+e+hYg==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-darwin-x64": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded-darwin-x64/-/sass-embedded-darwin-x64-1.83.0.tgz", + "integrity": "sha512-ERQ7Tvp1kFOW3ux4VDFIxb7tkYXHYc+zJpcrbs0hzcIO5ilIRU2tIOK1OrNwrFO6Qxyf7AUuBwYKLAtIU/Nz7g==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-arm": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-arm/-/sass-embedded-linux-arm-1.83.0.tgz", + "integrity": "sha512-baG9RYBJxUFmqwDNC9h9ZFElgJoyO3jgHGjzEZ1wHhIS9anpG+zZQvO8bHx3dBpKEImX+DBeLX+CxsFR9n81gQ==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-arm64": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-arm64/-/sass-embedded-linux-arm64-1.83.0.tgz", + "integrity": "sha512-syEAVTJt4qhaMLxrSwOWa46zdqHJdnqJkLUK+t9aCr8xqBZLPxSUeIGji76uOehQZ1C+KGFj6n9xstHN6wzOJw==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-ia32": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-ia32/-/sass-embedded-linux-ia32-1.83.0.tgz", + "integrity": "sha512-RRBxQxMpoxu5+XcSSc6QR/o9asEwUzR8AbCS83RaXcdTIHTa/CccQsiAoDDoPlRsMTLqnzs0LKL4CfOsf7zBbA==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-musl-arm": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-musl-arm/-/sass-embedded-linux-musl-arm-1.83.0.tgz", + "integrity": "sha512-Yc7u2TelCfBab+PRob9/MNJFh3EooMiz4urvhejXkihTiKSHGCv5YqDdtWzvyb9tY2Jb7YtYREVuHwfdVn3dTQ==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-musl-arm64": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-musl-arm64/-/sass-embedded-linux-musl-arm64-1.83.0.tgz", + "integrity": "sha512-Y7juhPHClUO2H5O+u+StRy6SEAcwZ+hTEk5WJdEmo1Bb1gDtfHvJaWB/iFZJ2tW0W1e865AZeUrC4OcOFjyAQA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-musl-ia32": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-musl-ia32/-/sass-embedded-linux-musl-ia32-1.83.0.tgz", + "integrity": "sha512-arQeYwGmwXV8byx5G1PtSzZWW1jbkfR5qrIHMEbTFSAvAxpqjgSvCvrHMOFd73FcMxVaYh4BX9LQNbKinkbEdg==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-musl-riscv64": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-musl-riscv64/-/sass-embedded-linux-musl-riscv64-1.83.0.tgz", + "integrity": "sha512-E6uzlIWz59rut+Z3XR6mLG915zNzv07ISvj3GUNZENdHM7dF8GQ//ANoIpl5PljMQKp89GnYdvo6kj2gnaBf/g==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-musl-x64": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-musl-x64/-/sass-embedded-linux-musl-x64-1.83.0.tgz", + "integrity": "sha512-eAMK6tyGqvqr21r9g8BnR3fQc1rYFj85RGduSQ3xkITZ6jOAnOhuU94N5fwRS852Hpws0lXhET+7JHXgg3U18w==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-riscv64": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-riscv64/-/sass-embedded-linux-riscv64-1.83.0.tgz", + "integrity": "sha512-Ojpi78pTv02sy2fUYirRGXHLY3fPnV/bvwuC2i5LwPQw2LpCcFyFTtN0c5h4LJDk9P6wr+/ZB/JXU8tHIOlK+Q==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-linux-x64": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded-linux-x64/-/sass-embedded-linux-x64-1.83.0.tgz", + "integrity": "sha512-3iLjlXdoPfgZRtX4odhRvka1BQs5mAXqfCtDIQBgh/o0JnGPzJIWWl9bYLpHxK8qb+uyVBxXYgXpI0sCzArBOw==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-win32-arm64": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded-win32-arm64/-/sass-embedded-win32-arm64-1.83.0.tgz", + "integrity": "sha512-iOHw/8/t2dlTW3lOFwG5eUbiwhEyGWawivlKWJ8lkXH7fjMpVx2VO9zCFAm8RvY9xOHJ9sf1L7g5bx3EnNP9BQ==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-win32-ia32": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded-win32-ia32/-/sass-embedded-win32-ia32-1.83.0.tgz", + "integrity": "sha512-2PxNXJ8Pad4geVcTXY4rkyTr5AwbF8nfrCTDv0ulbTvPhzX2mMKEGcBZUXWn5BeHZTBc6whNMfS7d5fQXR9dDQ==", + "cpu": [ + "ia32" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sass-embedded-win32-x64": { + "version": "1.83.0", + "resolved": "https://registry.npmjs.org/sass-embedded-win32-x64/-/sass-embedded-win32-x64-1.83.0.tgz", + "integrity": "sha512-muBXkFngM6eLTNqOV0FQi7Dv9s+YRQ42Yem26mosdan/GmJQc81deto6uDTgrYn+bzFNmiXcOdfm+0MkTWK3OQ==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/sucrase": { + "version": "3.35.0", + "resolved": "https://registry.npmjs.org/sucrase/-/sucrase-3.35.0.tgz", + "integrity": "sha512-8EbVDiu9iN/nESwxeSxDKe0dunta1GOlHufmSSXxMD2z2/tMZpDMpvXQGsc+ajGo8y2uYUmixaSRUc/QPoQ0GA==", + "license": "MIT", + "dependencies": { + "@jridgewell/gen-mapping": "^0.3.2", + "commander": "^4.0.0", + "glob": "^10.3.10", + "lines-and-columns": "^1.1.6", + "mz": "^2.7.0", + "pirates": "^4.0.1", + "ts-interface-checker": "^0.1.9" + }, + "bin": { + "sucrase": "bin/sucrase", + "sucrase-node": "bin/sucrase-node" + }, + "engines": { + "node": ">=16 || 14 >=14.17" + } + }, + "node_modules/sucrase/node_modules/@isaacs/cliui": { + "version": "8.0.2", + "resolved": "https://registry.npmjs.org/@isaacs/cliui/-/cliui-8.0.2.tgz", + "integrity": "sha512-O8jcjabXaleOG9DQ0+ARXWZBTfnP4WNAqzuiJK7ll44AmxGKv/J2M4TPjxjY3znBCfvBXFzucm1twdyFybFqEA==", + "license": "ISC", + "dependencies": { + "string-width": "^5.1.2", + "string-width-cjs": "npm:string-width@^4.2.0", + "strip-ansi": "^7.0.1", + "strip-ansi-cjs": "npm:strip-ansi@^6.0.1", + "wrap-ansi": "^8.1.0", + "wrap-ansi-cjs": "npm:wrap-ansi@^7.0.0" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/sucrase/node_modules/@jridgewell/gen-mapping": { + "version": "0.3.5", + "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.5.tgz", + "integrity": "sha512-IzL8ZoEDIBRWEzlCcRhOaCupYyN5gdIK+Q6fbFdPDg6HqX6jpkItn7DFIpW9LQzXG6Df9sA7+OKnq0qlz/GaQg==", + "license": "MIT", + "dependencies": { + "@jridgewell/set-array": "^1.2.1", + "@jridgewell/sourcemap-codec": "^1.4.10", + "@jridgewell/trace-mapping": "^0.3.24" + }, + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/sucrase/node_modules/@jridgewell/resolve-uri": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz", + "integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==", + "license": "MIT", + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/sucrase/node_modules/@jridgewell/set-array": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.2.1.tgz", + "integrity": "sha512-R8gLRTZeyp03ymzP/6Lil/28tGeGEzhx1q2k703KGWRAI1VdvPIXdG70VJc2pAMw3NA6JKL5hhFu1sJX0Mnn/A==", + "license": "MIT", + "engines": { + "node": ">=6.0.0" + } + }, + "node_modules/sucrase/node_modules/@jridgewell/sourcemap-codec": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz", + "integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==", + "license": "MIT" + }, + "node_modules/sucrase/node_modules/@jridgewell/trace-mapping": { + "version": "0.3.25", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz", + "integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==", + "license": "MIT", + "dependencies": { + "@jridgewell/resolve-uri": "^3.1.0", + "@jridgewell/sourcemap-codec": "^1.4.14" + } + }, + "node_modules/sucrase/node_modules/@pkgjs/parseargs": { + "version": "0.11.0", + "resolved": "https://registry.npmjs.org/@pkgjs/parseargs/-/parseargs-0.11.0.tgz", + "integrity": "sha512-+1VkjdD0QBLPodGrJUeqarH8VAIvQODIbwh9XpP5Syisf7YoQgsJKPNFoqqLQlu+VQ/tVSshMR6loPMn8U+dPg==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=14" + } + }, + "node_modules/sucrase/node_modules/ansi-regex": { + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-6.1.0.tgz", + "integrity": "sha512-7HSX4QQb4CspciLpVFwyRe79O3xsIZDDLER21kERQ71oaPodF8jL725AgJMFAYbooIqolJoRLuM81SpeUkpkvA==", + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/ansi-regex?sponsor=1" + } + }, + "node_modules/sucrase/node_modules/ansi-styles": { + "version": "6.2.1", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-6.2.1.tgz", + "integrity": "sha512-bN798gFfQX+viw3R7yrGWRqnrN2oRkEkUjjl4JNn4E8GxxbjtG3FbrEIIY3l8/hrwUwIeCZvi4QuOTP4MErVug==", + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/sucrase/node_modules/any-promise": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/any-promise/-/any-promise-1.3.0.tgz", + "integrity": "sha512-7UvmKalWRt1wgjL1RrGxoSJW/0QZFIegpeGvZG9kjp8vrRu55XTHbwnqq2GpXm9uLbcuhxm3IqX9OB4MZR1b2A==", + "license": "MIT" + }, + "node_modules/sucrase/node_modules/balanced-match": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", + "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", + "license": "MIT" + }, + "node_modules/sucrase/node_modules/brace-expansion": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", + "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0" + } + }, + "node_modules/sucrase/node_modules/color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "license": "MIT", + "dependencies": { + "color-name": "~1.1.4" + }, + "engines": { + "node": ">=7.0.0" + } + }, + "node_modules/sucrase/node_modules/color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "license": "MIT" + }, + "node_modules/sucrase/node_modules/commander": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/commander/-/commander-4.1.1.tgz", + "integrity": "sha512-NOKm8xhkzAjzFx8B2v5OAHT+u5pRQc2UCa2Vq9jYL/31o2wi9mxBA7LIFs3sV5VSC49z6pEhfbMULvShKj26WA==", + "license": "MIT", + "engines": { + "node": ">= 6" + } + }, + "node_modules/sucrase/node_modules/cross-spawn": { + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", + "license": "MIT", + "dependencies": { + "path-key": "^3.1.0", + "shebang-command": "^2.0.0", + "which": "^2.0.1" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/sucrase/node_modules/eastasianwidth": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/eastasianwidth/-/eastasianwidth-0.2.0.tgz", + "integrity": "sha512-I88TYZWc9XiYHRQ4/3c5rjjfgkjhLyW2luGIheGERbNQ6OY7yTybanSpDXZa8y7VUP9YmDcYa+eyq4ca7iLqWA==", + "license": "MIT" + }, + "node_modules/sucrase/node_modules/emoji-regex": { + "version": "9.2.2", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-9.2.2.tgz", + "integrity": "sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg==", + "license": "MIT" + }, + "node_modules/sucrase/node_modules/foreground-child": { + "version": "3.3.0", + "resolved": "https://registry.npmjs.org/foreground-child/-/foreground-child-3.3.0.tgz", + "integrity": "sha512-Ld2g8rrAyMYFXBhEqMz8ZAHBi4J4uS1i/CxGMDnjyFWddMXLVcDp051DZfu+t7+ab7Wv6SMqpWmyFIj5UbfFvg==", + "license": "ISC", + "dependencies": { + "cross-spawn": "^7.0.0", + "signal-exit": "^4.0.1" + }, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/sucrase/node_modules/glob": { + "version": "10.4.5", + "resolved": "https://registry.npmjs.org/glob/-/glob-10.4.5.tgz", + "integrity": "sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==", + "license": "ISC", + "dependencies": { + "foreground-child": "^3.1.0", + "jackspeak": "^3.1.2", + "minimatch": "^9.0.4", + "minipass": "^7.1.2", + "package-json-from-dist": "^1.0.0", + "path-scurry": "^1.11.1" + }, + "bin": { + "glob": "dist/esm/bin.mjs" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/sucrase/node_modules/is-fullwidth-code-point": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz", + "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/sucrase/node_modules/isexe": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", + "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", + "license": "ISC" + }, + "node_modules/sucrase/node_modules/jackspeak": { + "version": "3.4.3", + "resolved": "https://registry.npmjs.org/jackspeak/-/jackspeak-3.4.3.tgz", + "integrity": "sha512-OGlZQpz2yfahA/Rd1Y8Cd9SIEsqvXkLVoSw/cgwhnhFMDbsQFeZYoJJ7bIZBS9BcamUW96asq/npPWugM+RQBw==", + "license": "BlueOak-1.0.0", + "dependencies": { + "@isaacs/cliui": "^8.0.2" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + }, + "optionalDependencies": { + "@pkgjs/parseargs": "^0.11.0" + } + }, + "node_modules/sucrase/node_modules/lines-and-columns": { + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/lines-and-columns/-/lines-and-columns-1.2.4.tgz", + "integrity": "sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==", + "license": "MIT" + }, + "node_modules/sucrase/node_modules/lru-cache": { + "version": "10.4.3", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-10.4.3.tgz", + "integrity": "sha512-JNAzZcXrCt42VGLuYz0zfAzDfAvJWW6AfYlDBQyDV5DClI2m5sAmK+OIO7s59XfsRsWHp02jAJrRadPRGTt6SQ==", + "license": "ISC" + }, + "node_modules/sucrase/node_modules/minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "license": "ISC", + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=16 || 14 >=14.17" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/sucrase/node_modules/minipass": { + "version": "7.1.2", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-7.1.2.tgz", + "integrity": "sha512-qOOzS1cBTWYF4BH8fVePDBOO9iptMnGUEZwNc/cMWnTV2nVLZ7VoNWEPHkYczZA0pdoA7dl6e7FL659nX9S2aw==", + "license": "ISC", + "engines": { + "node": ">=16 || 14 >=14.17" + } + }, + "node_modules/sucrase/node_modules/mz": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/mz/-/mz-2.7.0.tgz", + "integrity": "sha512-z81GNO7nnYMEhrGh9LeymoE4+Yr0Wn5McHIZMK5cfQCl+NDX08sCZgUc9/6MHni9IWuFLm1Z3HTCXu2z9fN62Q==", + "license": "MIT", + "dependencies": { + "any-promise": "^1.0.0", + "object-assign": "^4.0.1", + "thenify-all": "^1.0.0" + } + }, + "node_modules/sucrase/node_modules/object-assign": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz", + "integrity": "sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/sucrase/node_modules/package-json-from-dist": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/package-json-from-dist/-/package-json-from-dist-1.0.1.tgz", + "integrity": "sha512-UEZIS3/by4OC8vL3P2dTXRETpebLI2NiI5vIrjaD/5UtrkFX/tNbwjTSRAGC/+7CAo2pIcBaRgWmcBBHcsaCIw==", + "license": "BlueOak-1.0.0" + }, + "node_modules/sucrase/node_modules/path-key": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", + "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/sucrase/node_modules/path-scurry": { + "version": "1.11.1", + "resolved": "https://registry.npmjs.org/path-scurry/-/path-scurry-1.11.1.tgz", + "integrity": "sha512-Xa4Nw17FS9ApQFJ9umLiJS4orGjm7ZzwUrwamcGQuHSzDyth9boKDaycYdDcZDuqYATXw4HFXgaqWTctW/v1HA==", + "license": "BlueOak-1.0.0", + "dependencies": { + "lru-cache": "^10.2.0", + "minipass": "^5.0.0 || ^6.0.2 || ^7.0.0" + }, + "engines": { + "node": ">=16 || 14 >=14.18" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/sucrase/node_modules/pirates": { + "version": "4.0.6", + "resolved": "https://registry.npmjs.org/pirates/-/pirates-4.0.6.tgz", + "integrity": "sha512-saLsH7WeYYPiD25LDuLRRY/i+6HaPYr6G1OUlN39otzkSTxKnubR9RTxS3/Kk50s1g2JTgFwWQDQyplC5/SHZg==", + "license": "MIT", + "engines": { + "node": ">= 6" + } + }, + "node_modules/sucrase/node_modules/shebang-command": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", + "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", + "license": "MIT", + "dependencies": { + "shebang-regex": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/sucrase/node_modules/shebang-regex": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", + "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/sucrase/node_modules/signal-exit": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-4.1.0.tgz", + "integrity": "sha512-bzyZ1e88w9O1iNJbKnOlvYTrWPDl46O1bG0D3XInv+9tkPrxrN8jUUTiFlDkkmKWgn1M6CfIA13SuGqOa9Korw==", + "license": "ISC", + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/sucrase/node_modules/string-width": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-5.1.2.tgz", + "integrity": "sha512-HnLOCR3vjcY8beoNLtcjZ5/nxn2afmME6lhrDrebokqMap+XbeW8n9TXpPDOqdGK5qcI3oT0GKTW6wC7EMiVqA==", + "license": "MIT", + "dependencies": { + "eastasianwidth": "^0.2.0", + "emoji-regex": "^9.2.2", + "strip-ansi": "^7.0.1" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/sucrase/node_modules/string-width-cjs": { + "name": "string-width", + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "license": "MIT", + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/sucrase/node_modules/string-width-cjs/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/sucrase/node_modules/string-width-cjs/node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "license": "MIT" + }, + "node_modules/sucrase/node_modules/string-width-cjs/node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/sucrase/node_modules/strip-ansi": { + "version": "7.1.0", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-7.1.0.tgz", + "integrity": "sha512-iq6eVVI64nQQTRYq2KtEg2d2uU7LElhTJwsH4YzIHZshxlgZms/wIc4VoDQTlG/IvVIrBKG06CrZnp0qv7hkcQ==", + "license": "MIT", + "dependencies": { + "ansi-regex": "^6.0.1" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/strip-ansi?sponsor=1" + } + }, + "node_modules/sucrase/node_modules/strip-ansi-cjs": { + "name": "strip-ansi", + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/sucrase/node_modules/strip-ansi-cjs/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/sucrase/node_modules/thenify": { + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/thenify/-/thenify-3.3.1.tgz", + "integrity": "sha512-RVZSIV5IG10Hk3enotrhvz0T9em6cyHBLkH/YAZuKqd8hRkKhSfCGIcP2KUY0EPxndzANBmNllzWPwak+bheSw==", + "license": "MIT", + "dependencies": { + "any-promise": "^1.0.0" + } + }, + "node_modules/sucrase/node_modules/thenify-all": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/thenify-all/-/thenify-all-1.6.0.tgz", + "integrity": "sha512-RNxQH/qI8/t3thXJDwcstUO4zeqo64+Uy/+sNVRBx4Xn2OX+OZ9oP+iJnNFqplFra2ZUVeKCSa2oVWi3T4uVmA==", + "license": "MIT", + "dependencies": { + "thenify": ">= 3.1.0 < 4" + }, + "engines": { + "node": ">=0.8" + } + }, + "node_modules/sucrase/node_modules/ts-interface-checker": { + "version": "0.1.13", + "resolved": "https://registry.npmjs.org/ts-interface-checker/-/ts-interface-checker-0.1.13.tgz", + "integrity": "sha512-Y/arvbn+rrz3JCKl9C4kVNfTfSm2/mEp5FSz5EsZSANGPSlQrpRI5M4PKF+mJnE52jOO90PnPSc3Ur3bTQw0gA==", + "license": "Apache-2.0" + }, + "node_modules/sucrase/node_modules/which": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", + "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", + "license": "ISC", + "dependencies": { + "isexe": "^2.0.0" + }, + "bin": { + "node-which": "bin/node-which" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/sucrase/node_modules/wrap-ansi": { + "version": "8.1.0", + "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-8.1.0.tgz", + "integrity": "sha512-si7QWI6zUMq56bESFvagtmzMdGOtoxfR+Sez11Mobfc7tm+VkUckk9bW2UeffTGVUbOksxmSw0AA2gs8g71NCQ==", + "license": "MIT", + "dependencies": { + "ansi-styles": "^6.1.0", + "string-width": "^5.0.1", + "strip-ansi": "^7.0.1" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/wrap-ansi?sponsor=1" + } + }, + "node_modules/sucrase/node_modules/wrap-ansi-cjs": { + "name": "wrap-ansi", + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz", + "integrity": "sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==", + "license": "MIT", + "dependencies": { + "ansi-styles": "^4.0.0", + "string-width": "^4.1.0", + "strip-ansi": "^6.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/wrap-ansi?sponsor=1" + } + }, + "node_modules/sucrase/node_modules/wrap-ansi-cjs/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/sucrase/node_modules/wrap-ansi-cjs/node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "license": "MIT", + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/sucrase/node_modules/wrap-ansi-cjs/node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "license": "MIT" + }, + "node_modules/sucrase/node_modules/wrap-ansi-cjs/node_modules/string-width": { + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "license": "MIT", + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/sucrase/node_modules/wrap-ansi-cjs/node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/supports-color": { + "version": "8.1.1", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-8.1.1.tgz", + "integrity": "sha512-MpUEN2OodtUzxvKQl72cUF7RQ5EiHsGvSsVG0ia9c5RbWGL2CI4C7EpPS8UTBIplnlzZiNuV56w+FuNxy3ty2Q==", + "devOptional": true, + "license": "MIT", + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/supports-color?sponsor=1" + } + }, + "node_modules/sync-child-process": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/sync-child-process/-/sync-child-process-1.0.2.tgz", + "integrity": "sha512-8lD+t2KrrScJ/7KXCSyfhT3/hRq78rC0wBFqNJXv3mZyn6hW2ypM05JmlSvtqRbeq6jqA94oHbxAr2vYsJ8vDA==", + "devOptional": true, + "license": "MIT", + "dependencies": { + "sync-message-port": "^1.0.0" + }, + "engines": { + "node": ">=16.0.0" + } + }, + "node_modules/sync-message-port": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/sync-message-port/-/sync-message-port-1.1.3.tgz", + "integrity": "sha512-GTt8rSKje5FilG+wEdfCkOcLL7LWqpMlr2c3LRuKt/YXxcJ52aGSbGBAdI4L3aaqfrBt6y711El53ItyH1NWzg==", + "devOptional": true, + "license": "MIT", + "engines": { + "node": ">=16.0.0" + } + }, + "node_modules/tailwindcss": { + "version": "3.4.15", + "resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-3.4.15.tgz", + "integrity": "sha512-r4MeXnfBmSOuKUWmXe6h2CcyfzJCEk4F0pptO5jlnYSIViUkVmsawj80N5h2lO3gwcmSb4n3PuN+e+GC1Guylw==", + "license": "MIT", + "dependencies": { + "@alloc/quick-lru": "^5.2.0", + "arg": "^5.0.2", + "chokidar": "^3.6.0", + "didyoumean": "^1.2.2", + "dlv": "^1.1.3", + "fast-glob": "^3.3.2", + "glob-parent": "^6.0.2", + "is-glob": "^4.0.3", + "jiti": "^1.21.6", + "lilconfig": "^2.1.0", + "micromatch": "^4.0.8", + "normalize-path": "^3.0.0", + "object-hash": "^3.0.0", + "picocolors": "^1.1.1", + "postcss": "^8.4.47", + "postcss-import": "^15.1.0", + "postcss-js": "^4.0.1", + "postcss-load-config": "^4.0.2", + "postcss-nested": "^6.2.0", + "postcss-selector-parser": "^6.1.2", + "resolve": "^1.22.8", + "sucrase": "^3.35.0" + }, + "bin": { + "tailwind": "lib/cli.js", + "tailwindcss": "lib/cli.js" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/textlinestream": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/textlinestream/-/textlinestream-1.1.1.tgz", + "integrity": "sha512-iBHbi7BQxrFmwZUQJsT0SjNzlLLsXhvW/kg7EyOMVMBIrlnj/qYofwo1LVLZi+3GbUEo96Iu2eqToI2+lZoAEQ==", + "license": "MIT" + }, + "node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "devOptional": true, + "license": "0BSD" + }, + "node_modules/uc.micro": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/uc.micro/-/uc.micro-2.1.0.tgz", + "integrity": "sha512-ARDJmphmdvUk6Glw7y9DQ2bFkKBHwQHLi2lsaH6PPmz/Ka9sFOBsBluozhDltWmnv9u/cF6Rt87znRTPV+yp/A==", + "license": "MIT" + }, + "node_modules/varint": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/varint/-/varint-6.0.0.tgz", + "integrity": "sha512-cXEIW6cfr15lFv563k4GuVuW/fiwjknytD37jIOLSdSWuOI6WnO/oKwmP2FQTU2l01LP8/M5TSAJpzUaGe3uWg==", + "devOptional": true, + "license": "MIT" + }, + "node_modules/vite": { + "version": "5.4.11", + "resolved": "https://registry.npmjs.org/vite/-/vite-5.4.11.tgz", + "integrity": "sha512-c7jFQRklXua0mTzneGW9QVyxFjUgwcihC4bXEtujIo2ouWCe1Ajt/amn2PCxYnhYfd5k09JX3SB7OYWFKYqj8Q==", + "license": "MIT", + "dependencies": { + "esbuild": "^0.21.3", + "postcss": "^8.4.43", + "rollup": "^4.20.0" + }, + "bin": { + "vite": "bin/vite.js" + }, + "engines": { + "node": "^18.0.0 || >=20.0.0" + }, + "funding": { + "url": "https://github.com/vitejs/vite?sponsor=1" + }, + "optionalDependencies": { + "fsevents": "~2.3.3" + }, + "peerDependencies": { + "@types/node": "^18.0.0 || >=20.0.0", + "less": "*", + "lightningcss": "^1.21.0", + "sass": "*", + "sass-embedded": "*", + "stylus": "*", + "sugarss": "*", + "terser": "^5.4.0" + }, + "peerDependenciesMeta": { + "@types/node": { + "optional": true + }, + "less": { + "optional": true + }, + "lightningcss": { + "optional": true + }, + "sass": { + "optional": true + }, + "sass-embedded": { + "optional": true + }, + "stylus": { + "optional": true + }, + "sugarss": { + "optional": true + }, + "terser": { + "optional": true + } + } + }, + "node_modules/vite-plugin-singlefile": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/vite-plugin-singlefile/-/vite-plugin-singlefile-2.0.3.tgz", + "integrity": "sha512-OEBEwdX8nCGPSdtaB1D7rryYnT+YfPTS8ojL1TDyeUF+bWDCTfRriQqw6T0vl9EbKI/KMg7szN3awst6cLrKkA==", + "license": "MIT", + "dependencies": { + "micromatch": "^4.0.8" + }, + "engines": { + "node": ">18.0.0" + }, + "peerDependencies": { + "rollup": "^4.24.3", + "vite": "^5.4.10" + } + }, + "node_modules/vue": { + "version": "3.5.13", + "resolved": "https://registry.npmjs.org/vue/-/vue-3.5.13.tgz", + "integrity": "sha512-wmeiSMxkZCSc+PM2w2VRsOYAZC8GdipNFRTsLSfodVqI9mbejKeXEGr8SckuLnrQPGe3oJN5c3K0vpoU9q/wCQ==", + "license": "MIT", + "dependencies": { + "@vue/compiler-dom": "3.5.13", + "@vue/compiler-sfc": "3.5.13", + "@vue/runtime-dom": "3.5.13", + "@vue/server-renderer": "3.5.13", + "@vue/shared": "3.5.13" + }, + "peerDependencies": { + "typescript": "*" + }, + "peerDependenciesMeta": { + "typescript": { + "optional": true + } + } + } + } +} diff --git a/examples/server/webui/package.json b/examples/server/webui/package.json new file mode 100644 index 000000000..2836cce00 --- /dev/null +++ b/examples/server/webui/package.json @@ -0,0 +1,30 @@ +{ + "name": "webui", + "private": true, + "version": "0.0.0", + "type": "module", + "scripts": { + "dev": "vite", + "build": "vite build", + "preview": "vite preview", + "analyze": "ANALYZE=1 npx vite-bundle-visualizer" + }, + "devDependencies": { + "sass-embedded": "^1.83.0", + "vite": "^5.4.10" + }, + "dependencies": { + "@sec-ant/readable-stream": "^0.6.0", + "@vscode/markdown-it-katex": "^1.1.1", + "autoprefixer": "^10.4.20", + "daisyui": "^4.12.14", + "highlight.js": "^11.10.0", + "katex": "^0.16.15", + "markdown-it": "^14.1.0", + "postcss": "^8.4.49", + "tailwindcss": "^3.4.15", + "textlinestream": "^1.1.1", + "vite-plugin-singlefile": "^2.0.3", + "vue": "^3.5.13" + } +} diff --git a/examples/server/webui/postcss.config.js b/examples/server/webui/postcss.config.js new file mode 100644 index 000000000..2e7af2b7f --- /dev/null +++ b/examples/server/webui/postcss.config.js @@ -0,0 +1,6 @@ +export default { + plugins: { + tailwindcss: {}, + autoprefixer: {}, + }, +} diff --git a/examples/server/webui/public/demo-conversation.json b/examples/server/webui/public/demo-conversation.json new file mode 100644 index 000000000..75ab599dd --- /dev/null +++ b/examples/server/webui/public/demo-conversation.json @@ -0,0 +1,33 @@ +{ + "demo": true, + "id": "conv-1734086746930", + "lastModified": 1734087548943, + "messages": [ + { + "id": 1734086764521, + "role": "user", + "content": "this is a demo conversation, used in dev mode" + }, + { + "id": 1734087548327, + "role": "assistant", + "content": "This is the formula:\n\n$\\frac{e^{x_i}}{\\sum_{j=1}^{n}e^{x_j}}$\n\nGiven an input vector \\(\\mathbf{x} = [x_1, x_2, \\ldots, x_n]\\)\n\n\\[\ny_i = \\frac{e^{x_i}}{\\sum_{j=1}^n e^{x_j}}\n\\]\n\nCode block latex:\n```latex\n\\frac{e^{x_i}}{\\sum_{j=1}^{n}e^{x_j}}\n```\n\nTest dollar sign: $1234 $4567\n\nInvalid latex syntax: $E = mc^$ and $$E = mc^$$", + "timings": { + "prompt_n": 1, + "prompt_ms": 28.923, + "predicted_n": 25, + "predicted_ms": 573.016 + } + }, + { + "id": 1734087548328, + "role": "user", + "content": "this is a demo conversation, used in dev mode" + }, + { + "id": 1734087548329, + "role": "assistant", + "content": "Code block:\n```js\nconsole.log('hello world')\n```\n```sh\nls -la /dev\n```" + } + ] +} diff --git a/examples/server/webui/src/highlight-config.js b/examples/server/webui/src/highlight-config.js new file mode 100644 index 000000000..96c7028f9 --- /dev/null +++ b/examples/server/webui/src/highlight-config.js @@ -0,0 +1,60 @@ +import hljs from 'highlight.js/lib/core'; + +// only import commonly used languages to reduce bundle size + +import python from 'highlight.js/lib/languages/python'; +import javascript from 'highlight.js/lib/languages/javascript'; +import json from 'highlight.js/lib/languages/json'; +import bash from 'highlight.js/lib/languages/bash'; +import yaml from 'highlight.js/lib/languages/yaml'; +import markdown from 'highlight.js/lib/languages/markdown'; +import scss from 'highlight.js/lib/languages/scss'; +import xml from 'highlight.js/lib/languages/xml'; +import ruby from 'highlight.js/lib/languages/ruby'; +import go from 'highlight.js/lib/languages/go'; +import java from 'highlight.js/lib/languages/java'; +import rust from 'highlight.js/lib/languages/rust'; +import scala from 'highlight.js/lib/languages/scala'; +import cpp from 'highlight.js/lib/languages/cpp'; +import csharp from 'highlight.js/lib/languages/csharp'; +import swift from 'highlight.js/lib/languages/swift'; +import dart from 'highlight.js/lib/languages/dart'; +import elixir from 'highlight.js/lib/languages/elixir'; +import kotlin from 'highlight.js/lib/languages/kotlin'; +import lua from 'highlight.js/lib/languages/lua'; +import php from 'highlight.js/lib/languages/php'; +import latex from 'highlight.js/lib/languages/latex'; + +hljs.registerLanguage('python', python); +hljs.registerLanguage('javascript', javascript); +hljs.registerLanguage('json', json); +hljs.registerLanguage('yaml', yaml); +hljs.registerLanguage('markdown', markdown); +hljs.registerLanguage('xml', xml); +hljs.registerLanguage('ruby', ruby); +hljs.registerLanguage('go', go); +hljs.registerLanguage('java', java); +hljs.registerLanguage('rust', rust); +hljs.registerLanguage('scala', scala); +hljs.registerLanguage('csharp', csharp); +hljs.registerLanguage('swift', swift); +hljs.registerLanguage('dart', dart); +hljs.registerLanguage('elixir', elixir); +hljs.registerLanguage('kotlin', kotlin); +hljs.registerLanguage('lua', lua); +hljs.registerLanguage('php', php); +hljs.registerLanguage('latex', latex); + +// reuse some languages to further reduce bundle size + +hljs.registerLanguage('shell', bash); +hljs.registerLanguage('bash', bash); +hljs.registerLanguage('sh', bash); + +hljs.registerLanguage('css', scss); +hljs.registerLanguage('scss', scss); + +hljs.registerLanguage('c', cpp); +hljs.registerLanguage('cpp', cpp); + +export default hljs; diff --git a/examples/server/webui/src/katex-gpt.js b/examples/server/webui/src/katex-gpt.js new file mode 100644 index 000000000..7c7c5e22c --- /dev/null +++ b/examples/server/webui/src/katex-gpt.js @@ -0,0 +1,66 @@ +import katex from 'katex'; + +// Adapted from https://github.com/SchneeHertz/markdown-it-katex-gpt +// MIT license + +const defaultOptions = { + delimiters: [ + { left: '\\[', right: '\\]', display: true }, + { left: '\\(', right: '\\)', display: false }, + ], +}; + +export function renderLatexHTML(content, display = false) { + return katex.renderToString(content, { + throwOnError: false, + output: 'mathml', + displayMode: display, + }); +} + +function escapedBracketRule(options) { + return (state, silent) => { + const max = state.posMax; + const start = state.pos; + + for (const { left, right, display } of options.delimiters) { + + // Check if it starts with the left delimiter + if (!state.src.slice(start).startsWith(left)) continue; + + // Skip the length of the left delimiter + let pos = start + left.length; + + // Find the matching right delimiter + while (pos < max) { + if (state.src.slice(pos).startsWith(right)) { + break; + } + pos++; + } + + // No matching right delimiter found, skip to the next match + if (pos >= max) continue; + + // If not in silent mode, convert LaTeX formula to MathML + if (!silent) { + const content = state.src.slice(start + left.length, pos); + try { + const renderedContent = renderLatexHTML(content, display); + const token = state.push('html_inline', '', 0); + token.content = renderedContent; + } catch (e) { + console.error(e); + } + } + + // Update position, skip the length of the right delimiter + state.pos = pos + right.length; + return true; + } + } +} + +export default function (md, options = defaultOptions) { + md.inline.ruler.after('text', 'escaped_bracket', escapedBracketRule(options)); +} diff --git a/examples/server/webui/src/main.js b/examples/server/webui/src/main.js new file mode 100644 index 000000000..90f4ca368 --- /dev/null +++ b/examples/server/webui/src/main.js @@ -0,0 +1,701 @@ +import './styles.scss'; +import { createApp, defineComponent, shallowRef, computed, h } from 'vue/dist/vue.esm-bundler.js'; +import MarkdownIt from 'markdown-it'; +import TextLineStream from 'textlinestream'; + +// math formula rendering +import 'katex/dist/katex.min.css'; +import markdownItKatexGpt from './katex-gpt'; +import markdownItKatexNormal from '@vscode/markdown-it-katex'; + +// code highlighting +import hljs from './highlight-config'; +import daisyuiThemes from 'daisyui/src/theming/themes'; + +// ponyfill for missing ReadableStream asyncIterator on Safari +import { asyncIterator } from '@sec-ant/readable-stream/ponyfill/asyncIterator'; + +const isDev = import.meta.env.MODE === 'development'; + +// types +/** @typedef {{ id: number, role: 'user' | 'assistant', content: string, timings: any }} Message */ +/** @typedef {{ role: 'user' | 'assistant', content: string }} APIMessage */ +/** @typedef {{ id: string, lastModified: number, messages: Array }} Conversation */ + +// utility functions +const isString = (x) => !!x.toLowerCase; +const isBoolean = (x) => x === true || x === false; +const isNumeric = (n) => !isString(n) && !isNaN(n) && !isBoolean(n); +const escapeAttr = (str) => str.replace(/>/g, '>').replace(/"/g, '"'); +const copyStr = (textToCopy) => { + // Navigator clipboard api needs a secure context (https) + if (navigator.clipboard && window.isSecureContext) { + navigator.clipboard.writeText(textToCopy); + } else { + // Use the 'out of viewport hidden text area' trick + const textArea = document.createElement('textarea'); + textArea.value = textToCopy; + // Move textarea out of the viewport so it's not visible + textArea.style.position = 'absolute'; + textArea.style.left = '-999999px'; + document.body.prepend(textArea); + textArea.select(); + document.execCommand('copy'); + } +}; + +// constants +const BASE_URL = isDev + ? (localStorage.getItem('base') || 'https://localhost:8080') // for debugging + : (new URL('.', document.baseURI).href).toString().replace(/\/$/, ''); // for production +console.log({ BASE_URL }); + +const CONFIG_DEFAULT = { + // Note: in order not to introduce breaking changes, please keep the same data type (number, string, etc) if you want to change the default value. Do not use null or undefined for default value. + apiKey: '', + systemMessage: 'You are a helpful assistant.', + showTokensPerSecond: false, + showThoughtInProgress: false, + excludeThoughtOnReq: true, + // make sure these default values are in sync with `common.h` + samplers: 'edkypmxt', + temperature: 0.8, + dynatemp_range: 0.0, + dynatemp_exponent: 1.0, + top_k: 40, + top_p: 0.95, + min_p: 0.05, + xtc_probability: 0.0, + xtc_threshold: 0.1, + typical_p: 1.0, + repeat_last_n: 64, + repeat_penalty: 1.0, + presence_penalty: 0.0, + frequency_penalty: 0.0, + dry_multiplier: 0.0, + dry_base: 1.75, + dry_allowed_length: 2, + dry_penalty_last_n: -1, + max_tokens: -1, + custom: '', // custom json-stringified object +}; +const CONFIG_INFO = { + apiKey: 'Set the API Key if you are using --api-key option for the server.', + systemMessage: 'The starting message that defines how model should behave.', + samplers: 'The order at which samplers are applied, in simplified way. Default is "dkypmxt": dry->top_k->typ_p->top_p->min_p->xtc->temperature', + temperature: 'Controls the randomness of the generated text by affecting the probability distribution of the output tokens. Higher = more random, lower = more focused.', + dynatemp_range: 'Addon for the temperature sampler. The added value to the range of dynamic temperature, which adjusts probabilities by entropy of tokens.', + dynatemp_exponent: 'Addon for the temperature sampler. Smoothes out the probability redistribution based on the most probable token.', + top_k: 'Keeps only k top tokens.', + top_p: 'Limits tokens to those that together have a cumulative probability of at least p', + min_p: 'Limits tokens based on the minimum probability for a token to be considered, relative to the probability of the most likely token.', + xtc_probability: 'XTC sampler cuts out top tokens; this parameter controls the chance of cutting tokens at all. 0 disables XTC.', + xtc_threshold: 'XTC sampler cuts out top tokens; this parameter controls the token probability that is required to cut that token.', + typical_p: 'Sorts and limits tokens based on the difference between log-probability and entropy.', + repeat_last_n: 'Last n tokens to consider for penalizing repetition', + repeat_penalty: 'Controls the repetition of token sequences in the generated text', + presence_penalty: 'Limits tokens based on whether they appear in the output or not.', + frequency_penalty: 'Limits tokens based on how often they appear in the output.', + dry_multiplier: 'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets the DRY sampling multiplier.', + dry_base: 'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets the DRY sampling base value.', + dry_allowed_length: 'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets the allowed length for DRY sampling.', + dry_penalty_last_n: 'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets DRY penalty for the last n tokens.', + max_tokens: 'The maximum number of token per output.', + custom: '', // custom json-stringified object +}; +// config keys having numeric value (i.e. temperature, top_k, top_p, etc) +const CONFIG_NUMERIC_KEYS = Object.entries(CONFIG_DEFAULT).filter(e => isNumeric(e[1])).map(e => e[0]); +// list of themes supported by daisyui +const THEMES = ['light', 'dark'] + // make sure light & dark are always at the beginning + .concat(Object.keys(daisyuiThemes).filter(t => t !== 'light' && t !== 'dark')); + +// markdown support +const VueMarkdown = defineComponent( + (props) => { + const md = shallowRef(new MarkdownIt({ + breaks: true, + highlight: function (str, lang) { // Add highlight.js + if (lang && hljs.getLanguage(lang)) { + try { + return '
' +
+                   hljs.highlight(str, { language: lang, ignoreIllegals: true }).value +
+                   '
'; + } catch (__) {} + } + return '
' + md.value.utils.escapeHtml(str) + '
'; + } + })); + // support latex with double dollar sign and square brackets + md.value.use(markdownItKatexGpt, { + delimiters: [ + { left: '\\[', right: '\\]', display: true }, + { left: '\\(', right: '\\)', display: false }, + { left: '$$', right: '$$', display: false }, + // do not add single dollar sign here, other wise it will confused with dollar used for money symbol + ], + throwOnError: false, + }); + // support latex with single dollar sign + md.value.use(markdownItKatexNormal, { throwOnError: false }); + // add copy button to code blocks + const origFenchRenderer = md.value.renderer.rules.fence; + md.value.renderer.rules.fence = (tokens, idx, ...args) => { + const content = tokens[idx].content; + const origRendered = origFenchRenderer(tokens, idx, ...args); + return `
+ + ${origRendered} +
`; + }; + window.copyStr = copyStr; + const content = computed(() => md.value.render(props.source)); + return () => h('div', { innerHTML: content.value }); + }, + { props: ['source'] } +); + +// input field to be used by settings modal +const SettingsModalShortInput = defineComponent({ + template: document.getElementById('settings-modal-short-input').innerHTML, + props: { + label: { type: String, required: false }, + configKey: String, + configDefault: Object, + configInfo: Object, + modelValue: [Object, String, Number], + }, +}); + +// message bubble component +const MessageBubble = defineComponent({ + components: { + VueMarkdown + }, + template: document.getElementById('message-bubble').innerHTML, + props: { + config: Object, + msg: Object, + isGenerating: Boolean, + showThoughtInProgress: Boolean, + editUserMsgAndRegenerate: Function, + regenerateMsg: Function, + }, + data() { + return { + editingContent: null, + }; + }, + computed: { + timings() { + if (!this.msg.timings) return null; + return { + ...this.msg.timings, + prompt_per_second: this.msg.timings.prompt_n / (this.msg.timings.prompt_ms / 1000), + predicted_per_second: this.msg.timings.predicted_n / (this.msg.timings.predicted_ms / 1000), + }; + }, + splitMsgContent() { + const content = this.msg.content; + if (this.msg.role !== 'assistant') { + return { content }; + } + let actualContent = ''; + let cot = ''; + let isThinking = false; + let thinkSplit = content.split('', 2); + actualContent += thinkSplit[0]; + while (thinkSplit[1] !== undefined) { + // tag found + thinkSplit = thinkSplit[1].split('', 2); + cot += thinkSplit[0]; + isThinking = true; + if (thinkSplit[1] !== undefined) { + // closing tag found + isThinking = false; + thinkSplit = thinkSplit[1].split('', 2); + actualContent += thinkSplit[0]; + } + } + return { content: actualContent, cot, isThinking }; + }, + }, + methods: { + copyMsg() { + copyStr(this.msg.content); + }, + editMsg() { + this.editUserMsgAndRegenerate({ + ...this.msg, + content: this.editingContent, + }); + this.editingContent = null; + }, + }, +}); + +// coversations is stored in localStorage +// format: { [convId]: { id: string, lastModified: number, messages: [...] } } +// convId is a string prefixed with 'conv-' +const StorageUtils = { + /** + * manage conversations + * @returns {Array} + */ + getAllConversations() { + const res = []; + for (const key in localStorage) { + if (key.startsWith('conv-')) { + res.push(JSON.parse(localStorage.getItem(key))); + } + } + res.sort((a, b) => b.lastModified - a.lastModified); + return res; + }, + /** + * can return null if convId does not exist + * @param {string} convId + * @returns {Conversation | null} + */ + getOneConversation(convId) { + return JSON.parse(localStorage.getItem(convId) || 'null'); + }, + /** + * if convId does not exist, create one + * @param {string} convId + * @param {Message} msg + */ + appendMsg(convId, msg) { + if (msg.content === null) return; + const conv = StorageUtils.getOneConversation(convId) || { + id: convId, + lastModified: Date.now(), + messages: [], + }; + conv.messages.push(msg); + conv.lastModified = Date.now(); + localStorage.setItem(convId, JSON.stringify(conv)); + }, + /** + * Get new conversation id + * @returns {string} + */ + getNewConvId() { + return `conv-${Date.now()}`; + }, + /** + * remove conversation by id + * @param {string} convId + */ + remove(convId) { + localStorage.removeItem(convId); + }, + /** + * remove all conversations + * @param {string} convId + */ + filterAndKeepMsgs(convId, predicate) { + const conv = StorageUtils.getOneConversation(convId); + if (!conv) return; + conv.messages = conv.messages.filter(predicate); + conv.lastModified = Date.now(); + localStorage.setItem(convId, JSON.stringify(conv)); + }, + /** + * remove last message from conversation + * @param {string} convId + * @returns {Message | undefined} + */ + popMsg(convId) { + const conv = StorageUtils.getOneConversation(convId); + if (!conv) return; + const msg = conv.messages.pop(); + conv.lastModified = Date.now(); + if (conv.messages.length === 0) { + StorageUtils.remove(convId); + } else { + localStorage.setItem(convId, JSON.stringify(conv)); + } + return msg; + }, + + // manage config + getConfig() { + const savedVal = JSON.parse(localStorage.getItem('config') || '{}'); + // to prevent breaking changes in the future, we always provide default value for missing keys + return { + ...CONFIG_DEFAULT, + ...savedVal, + }; + }, + setConfig(config) { + localStorage.setItem('config', JSON.stringify(config)); + }, + getTheme() { + return localStorage.getItem('theme') || 'auto'; + }, + setTheme(theme) { + if (theme === 'auto') { + localStorage.removeItem('theme'); + } else { + localStorage.setItem('theme', theme); + } + }, +}; + +// scroll to bottom of chat messages +// if requiresNearBottom is true, only auto-scroll if user is near bottom +const chatScrollToBottom = (requiresNearBottom) => { + const msgListElem = document.getElementById('messages-list'); + const spaceToBottom = msgListElem.scrollHeight - msgListElem.scrollTop - msgListElem.clientHeight; + if (!requiresNearBottom || (spaceToBottom < 100)) { + setTimeout(() => msgListElem.scrollTo({ top: msgListElem.scrollHeight }), 1); + } +}; + +// wrapper for SSE +async function* sendSSEPostRequest(url, fetchOptions) { + const res = await fetch(url, fetchOptions); + const lines = res.body + .pipeThrough(new TextDecoderStream()) + .pipeThrough(new TextLineStream()); + for await (const line of asyncIterator(lines)) { + if (isDev) console.log({line}); + if (line.startsWith('data:') && !line.endsWith('[DONE]')) { + const data = JSON.parse(line.slice(5)); + yield data; + } else if (line.startsWith('error:')) { + const data = JSON.parse(line.slice(6)); + throw new Error(data.message || 'Unknown error'); + } + } +}; + +const mainApp = createApp({ + components: { + VueMarkdown, + SettingsModalShortInput, + MessageBubble, + }, + data() { + return { + conversations: StorageUtils.getAllConversations(), + /** @type {Array} */ + messages: [], + viewingConvId: StorageUtils.getNewConvId(), + inputMsg: '', + isGenerating: false, + /** @type {Array | null} */ + pendingMsg: null, // the on-going message from assistant + stopGeneration: () => {}, + selectedTheme: StorageUtils.getTheme(), + config: StorageUtils.getConfig(), + showConfigDialog: false, + // const + themes: THEMES, + /** @type {CONFIG_DEFAULT} */ + configDefault: {...CONFIG_DEFAULT}, + configInfo: {...CONFIG_INFO}, + isDev, + } + }, + computed: {}, + mounted() { + document.getElementById('app').classList.remove('opacity-0'); // show app + // scroll to the bottom when the pending message height is updated + const pendingMsgElem = document.getElementById('pending-msg'); + const resizeObserver = new ResizeObserver(() => { + if (this.isGenerating) chatScrollToBottom(true); + }); + resizeObserver.observe(pendingMsgElem); + this.setSelectedTheme(this.selectedTheme); + }, + watch: { + viewingConvId: function(val, oldVal) { + if (val != oldVal) { + this.fetchMessages(); + chatScrollToBottom(); + this.hideSidebar(); + } + } + }, + methods: { + hideSidebar() { + document.getElementById('toggle-drawer').checked = false; + }, + setSelectedTheme(theme) { + this.selectedTheme = theme; + document.body.setAttribute('data-theme', theme); + document.body.setAttribute('data-color-scheme', daisyuiThemes[theme]?.['color-scheme'] ?? 'auto'); + StorageUtils.setTheme(theme); + }, + newConversation() { + if (this.isGenerating) return; + this.viewingConvId = StorageUtils.getNewConvId(); + }, + setViewingConv(convId) { + if (this.isGenerating) return; + this.viewingConvId = convId; + }, + deleteConv(convId) { + if (this.isGenerating) return; + if (window.confirm('Are you sure to delete this conversation?')) { + StorageUtils.remove(convId); + if (this.viewingConvId === convId) { + this.viewingConvId = StorageUtils.getNewConvId(); + } + this.fetchConversation(); + this.fetchMessages(); + } + }, + downloadConv(convId) { + const conversation = StorageUtils.getOneConversation(convId); + if (!conversation) { + alert('Conversation not found.'); + return; + } + const conversationJson = JSON.stringify(conversation, null, 2); + const blob = new Blob([conversationJson], { type: 'application/json' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = `conversation_${convId}.json`; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(url); + }, + async sendMessage() { + if (!this.inputMsg) return; + const currConvId = this.viewingConvId; + + StorageUtils.appendMsg(currConvId, { + id: Date.now(), + role: 'user', + content: this.inputMsg, + }); + this.fetchConversation(); + this.fetchMessages(); + this.inputMsg = ''; + this.generateMessage(currConvId); + chatScrollToBottom(); + }, + async generateMessage(currConvId) { + if (this.isGenerating) return; + this.pendingMsg = { id: Date.now()+1, role: 'assistant', content: null }; + this.isGenerating = true; + + try { + /** @type {CONFIG_DEFAULT} */ + const config = this.config; + const abortController = new AbortController(); + this.stopGeneration = () => abortController.abort(); + /** @type {Array} */ + let messages = [ + { role: 'system', content: config.systemMessage }, + ...normalizeMsgsForAPI(this.messages), + ]; + if (config.excludeThoughtOnReq) { + messages = filterThoughtFromMsgs(messages); + } + if (isDev) console.log({messages}); + const params = { + messages, + stream: true, + cache_prompt: true, + samplers: config.samplers, + temperature: config.temperature, + dynatemp_range: config.dynatemp_range, + dynatemp_exponent: config.dynatemp_exponent, + top_k: config.top_k, + top_p: config.top_p, + min_p: config.min_p, + typical_p: config.typical_p, + xtc_probability: config.xtc_probability, + xtc_threshold: config.xtc_threshold, + repeat_last_n: config.repeat_last_n, + repeat_penalty: config.repeat_penalty, + presence_penalty: config.presence_penalty, + frequency_penalty: config.frequency_penalty, + dry_multiplier: config.dry_multiplier, + dry_base: config.dry_base, + dry_allowed_length: config.dry_allowed_length, + dry_penalty_last_n: config.dry_penalty_last_n, + max_tokens: config.max_tokens, + timings_per_token: !!config.showTokensPerSecond, + ...(config.custom.length ? JSON.parse(config.custom) : {}), + }; + const chunks = sendSSEPostRequest(`${BASE_URL}/v1/chat/completions`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + ...(config.apiKey ? {'Authorization': `Bearer ${config.apiKey}`} : {}) + }, + body: JSON.stringify(params), + signal: abortController.signal, + }); + for await (const chunk of chunks) { + const stop = chunk.stop; + const addedContent = chunk.choices[0].delta.content; + const lastContent = this.pendingMsg.content || ''; + if (addedContent) { + this.pendingMsg = { + id: this.pendingMsg.id, + role: 'assistant', + content: lastContent + addedContent, + }; + } + const timings = chunk.timings; + if (timings && config.showTokensPerSecond) { + // only extract what's really needed, to save some space + this.pendingMsg.timings = { + prompt_n: timings.prompt_n, + prompt_ms: timings.prompt_ms, + predicted_n: timings.predicted_n, + predicted_ms: timings.predicted_ms, + }; + } + } + + StorageUtils.appendMsg(currConvId, this.pendingMsg); + this.fetchConversation(); + this.fetchMessages(); + setTimeout(() => document.getElementById('msg-input').focus(), 1); + } catch (error) { + if (error.name === 'AbortError') { + // user stopped the generation via stopGeneration() function + StorageUtils.appendMsg(currConvId, this.pendingMsg); + this.fetchConversation(); + this.fetchMessages(); + } else { + console.error(error); + alert(error); + // pop last user message + const lastUserMsg = StorageUtils.popMsg(currConvId); + this.inputMsg = lastUserMsg ? lastUserMsg.content : ''; + } + } + + this.pendingMsg = null; + this.isGenerating = false; + this.stopGeneration = () => {}; + this.fetchMessages(); + chatScrollToBottom(); + }, + + // message actions + regenerateMsg(msg) { + if (this.isGenerating) return; + // TODO: somehow keep old history (like how ChatGPT has different "tree"). This can be done by adding "sub-conversations" with "subconv-" prefix, and new message will have a list of subconvIds + const currConvId = this.viewingConvId; + StorageUtils.filterAndKeepMsgs(currConvId, (m) => m.id < msg.id); + this.fetchConversation(); + this.fetchMessages(); + this.generateMessage(currConvId); + }, + editUserMsgAndRegenerate(msg) { + if (this.isGenerating) return; + const currConvId = this.viewingConvId; + const newContent = msg.content; + StorageUtils.filterAndKeepMsgs(currConvId, (m) => m.id < msg.id); + StorageUtils.appendMsg(currConvId, { + id: Date.now(), + role: 'user', + content: newContent, + }); + this.fetchConversation(); + this.fetchMessages(); + this.generateMessage(currConvId); + }, + + // settings dialog methods + closeAndSaveConfigDialog() { + try { + if (this.config.custom.length) JSON.parse(this.config.custom); + } catch (error) { + alert('Invalid JSON for custom config. Please either fix it or leave it empty.'); + return; + } + for (const key of CONFIG_NUMERIC_KEYS) { + if (isNaN(this.config[key]) || this.config[key].toString().trim().length === 0) { + alert(`Invalid number for ${key} (expected an integer or a float)`); + return; + } + this.config[key] = parseFloat(this.config[key]); + } + this.showConfigDialog = false; + StorageUtils.setConfig(this.config); + }, + closeAndDiscardConfigDialog() { + this.showConfigDialog = false; + this.config = StorageUtils.getConfig(); + }, + resetConfigDialog() { + if (window.confirm('Are you sure to reset all settings?')) { + this.config = {...CONFIG_DEFAULT}; + } + }, + + // sync state functions + fetchConversation() { + this.conversations = StorageUtils.getAllConversations(); + }, + fetchMessages() { + this.messages = StorageUtils.getOneConversation(this.viewingConvId)?.messages ?? []; + }, + + // debug functions + async debugImportDemoConv() { + const res = await fetch('/demo-conversation.json'); + const demoConv = await res.json(); + StorageUtils.remove(demoConv.id); + for (const msg of demoConv.messages) { + StorageUtils.appendMsg(demoConv.id, msg); + } + this.fetchConversation(); + } + }, +}); +mainApp.config.errorHandler = alert; +try { + mainApp.mount('#app'); +} catch (err) { + console.error(err); + document.getElementById('app').innerHTML = `
+ Failed to start app. Please try clearing localStorage and try again.
+
+ +
`; +} + +/** + * filter out redundant fields upon sending to API + * @param {Array} messages + * @returns {Array} + */ +function normalizeMsgsForAPI(messages) { + return messages.map((msg) => { + return { + role: msg.role, + content: msg.content, + }; + }); +} + +/** + * recommended for DeepsSeek-R1, filter out content between and tags + * @param {Array} messages + * @returns {Array} + */ +function filterThoughtFromMsgs(messages) { + return messages.map((msg) => { + return { + role: msg.role, + content: msg.role === 'assistant' + ? msg.content.split('
').at(-1).trim() + : msg.content, + }; + }); +} diff --git a/examples/server/webui/src/styles.scss b/examples/server/webui/src/styles.scss new file mode 100644 index 000000000..34fe2aaf0 --- /dev/null +++ b/examples/server/webui/src/styles.scss @@ -0,0 +1,48 @@ +@use "sass:meta"; + +@tailwind base; +@tailwind components; +@tailwind utilities; + +.markdown { + h1, h2, h3, h4, h5, h6, ul, ol, li { all: revert; } + pre { + @apply whitespace-pre-wrap rounded-lg p-2; + border: 1px solid currentColor; + } + /* TODO: fix markdown table */ +} + +.show-on-hover { + @apply md:opacity-0 md:group-hover:opacity-100; +} +.btn-mini { + @apply cursor-pointer hover:shadow-md; +} +.chat-screen { max-width: 900px; } + +.chat-bubble-base-300 { + --tw-bg-opacity: 1; + --tw-text-opacity: 1; + @apply bg-base-300 text-base-content; +} + +/* Highlight.js */ +[data-color-scheme='light'] { + @include meta.load-css('highlight.js/styles/stackoverflow-light'); +} +[data-color-scheme='dark'] { + @include meta.load-css('highlight.js/styles/stackoverflow-dark'); +} +[data-color-scheme='auto'] { + @media (prefers-color-scheme: light) { + @include meta.load-css('highlight.js/styles/stackoverflow-light'); + } + @media (prefers-color-scheme: dark) { + @include meta.load-css('highlight.js/styles/stackoverflow-dark'); + } +} +.hljs { + background: transparent !important; + padding: 0.5em !important; +} diff --git a/examples/server/webui/tailwind.config.js b/examples/server/webui/tailwind.config.js new file mode 100644 index 000000000..c43066a19 --- /dev/null +++ b/examples/server/webui/tailwind.config.js @@ -0,0 +1,16 @@ +/** @type {import('tailwindcss').Config} */ +export default { + content: [ + "./index.html", + "./src/**/*.{js,ts,jsx,tsx}", + ], + theme: { + extend: {}, + }, + plugins: [ + require('daisyui'), + ], + daisyui: { + themes: ['light', 'dark', 'cupcake', 'bumblebee', 'emerald', 'corporate', 'synthwave', 'retro', 'cyberpunk', 'valentine', 'halloween', 'garden', 'forest', 'aqua', 'lofi', 'pastel', 'fantasy', 'wireframe', 'black', 'luxury', 'dracula', 'cmyk', 'autumn', 'business', 'acid', 'lemonade', 'night', 'coffee', 'winter', 'dim', 'nord', 'sunset'], + } +} diff --git a/examples/server/webui/vite.config.js b/examples/server/webui/vite.config.js new file mode 100644 index 000000000..6619a630d --- /dev/null +++ b/examples/server/webui/vite.config.js @@ -0,0 +1,59 @@ + +import { viteSingleFile } from 'vite-plugin-singlefile'; +import path from 'path'; +import fs from 'fs'; +import zlib from 'zlib'; + +const MAX_BUNDLE_SIZE = 1.5 * 1024 * 1024; // only increase when absolutely necessary + +const GUIDE_FOR_FRONTEND = ` + +`.trim(); + +const BUILD_PLUGINS = [ + viteSingleFile(), + (function llamaCppPlugin() { + let config; + return { + name: 'llamacpp:build', + apply: 'build', + async configResolved(_config) { + config = _config; + }, + writeBundle() { + const outputIndexHtml = path.join(config.build.outDir, 'index.html'); + const content = GUIDE_FOR_FRONTEND + '\n' + fs.readFileSync(outputIndexHtml, 'utf-8'); + const compressed = zlib.gzipSync(Buffer.from(content, 'utf-8'), { level: 9 }); + + // because gzip header contains machine-specific info, we must remove these data from the header + // timestamp + compressed[0x4] = 0; + compressed[0x5] = 0; + compressed[0x6] = 0; + compressed[0x7] = 0; + // OS + compressed[0x9] = 0; + + if (compressed.byteLength > MAX_BUNDLE_SIZE) { + throw new Error( + `Bundle size is too large (${Math.ceil(compressed.byteLength / 1024)} KB).\n` + + `Please reduce the size of the frontend or increase MAX_BUNDLE_SIZE in vite.config.js.\n`, + ); + } + + const targetOutputFile = path.join(config.build.outDir, '../../public/index.html.gz'); + fs.writeFileSync(targetOutputFile, compressed); + } + } + })(), +]; + +/** @type {import('vite').UserConfig} */ +export default { + plugins: process.env.ANALYZE ? [] : BUILD_PLUGINS, +}; diff --git a/examples/simple-chat/CMakeLists.txt b/examples/simple-chat/CMakeLists.txt new file mode 100644 index 000000000..567f7fbbb --- /dev/null +++ b/examples/simple-chat/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-simple-chat) +add_executable(${TARGET} simple-chat.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/simple-chat/README.md b/examples/simple-chat/README.md new file mode 100644 index 000000000..f0099ce3d --- /dev/null +++ b/examples/simple-chat/README.md @@ -0,0 +1,7 @@ +# llama.cpp/example/simple-chat + +The purpose of this example is to demonstrate a minimal usage of llama.cpp to create a simple chat program using the chat template from the GGUF file. + +```bash +./llama-simple-chat -m Meta-Llama-3.1-8B-Instruct.gguf -c 2048 +... diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp new file mode 100644 index 000000000..c5534cc13 --- /dev/null +++ b/examples/simple-chat/simple-chat.cpp @@ -0,0 +1,206 @@ +#include "llama.h" +#include +#include +#include +#include +#include + +static void print_usage(int, char ** argv) { + printf("\nexample usage:\n"); + printf("\n %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]); + printf("\n"); +} + +int main(int argc, char ** argv) { + std::string model_path; + int ngl = 99; + int n_ctx = 2048; + + // parse command line arguments + for (int i = 1; i < argc; i++) { + try { + if (strcmp(argv[i], "-m") == 0) { + if (i + 1 < argc) { + model_path = argv[++i]; + } else { + print_usage(argc, argv); + return 1; + } + } else if (strcmp(argv[i], "-c") == 0) { + if (i + 1 < argc) { + n_ctx = std::stoi(argv[++i]); + } else { + print_usage(argc, argv); + return 1; + } + } else if (strcmp(argv[i], "-ngl") == 0) { + if (i + 1 < argc) { + ngl = std::stoi(argv[++i]); + } else { + print_usage(argc, argv); + return 1; + } + } else { + print_usage(argc, argv); + return 1; + } + } catch (std::exception & e) { + fprintf(stderr, "error: %s\n", e.what()); + print_usage(argc, argv); + return 1; + } + } + if (model_path.empty()) { + print_usage(argc, argv); + return 1; + } + + // only print errors + llama_log_set([](enum ggml_log_level level, const char * text, void * /* user_data */) { + if (level >= GGML_LOG_LEVEL_ERROR) { + fprintf(stderr, "%s", text); + } + }, nullptr); + + // load dynamic backends + ggml_backend_load_all(); + + // initialize the model + llama_model_params model_params = llama_model_default_params(); + model_params.n_gpu_layers = ngl; + + llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params); + if (!model) { + fprintf(stderr , "%s: error: unable to load model\n" , __func__); + return 1; + } + + const llama_vocab * vocab = llama_model_get_vocab(model); + + // initialize the context + llama_context_params ctx_params = llama_context_default_params(); + ctx_params.n_ctx = n_ctx; + ctx_params.n_batch = n_ctx; + + llama_context * ctx = llama_init_from_model(model, ctx_params); + if (!ctx) { + fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); + return 1; + } + + // initialize the sampler + llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params()); + llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1)); + llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f)); + llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); + + // helper function to evaluate a prompt and generate a response + auto generate = [&](const std::string & prompt) { + std::string response; + + const bool is_first = llama_get_kv_cache_used_cells(ctx) == 0; + + // tokenize the prompt + const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); + std::vector prompt_tokens(n_prompt_tokens); + if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, true) < 0) { + GGML_ABORT("failed to tokenize the prompt\n"); + } + + // prepare a batch for the prompt + llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); + llama_token new_token_id; + while (true) { + // check if we have enough space in the context to evaluate this batch + int n_ctx = llama_n_ctx(ctx); + int n_ctx_used = llama_get_kv_cache_used_cells(ctx); + if (n_ctx_used + batch.n_tokens > n_ctx) { + printf("\033[0m\n"); + fprintf(stderr, "context size exceeded\n"); + exit(0); + } + + if (llama_decode(ctx, batch)) { + GGML_ABORT("failed to decode\n"); + } + + // sample the next token + new_token_id = llama_sampler_sample(smpl, ctx, -1); + + // is it an end of generation? + if (llama_vocab_is_eog(vocab, new_token_id)) { + break; + } + + // convert the token to a string, print it and add it to the response + char buf[256]; + int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true); + if (n < 0) { + GGML_ABORT("failed to convert token to piece\n"); + } + std::string piece(buf, n); + printf("%s", piece.c_str()); + fflush(stdout); + response += piece; + + // prepare the next batch with the sampled token + batch = llama_batch_get_one(&new_token_id, 1); + } + + return response; + }; + + std::vector messages; + std::vector formatted(llama_n_ctx(ctx)); + int prev_len = 0; + while (true) { + // get user input + printf("\033[32m> \033[0m"); + std::string user; + std::getline(std::cin, user); + + if (user.empty()) { + break; + } + + const char * tmpl = llama_model_chat_template(model, /* name */ nullptr); + + // add the user input to the message list and format it + messages.push_back({"user", strdup(user.c_str())}); + int new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size()); + if (new_len > (int)formatted.size()) { + formatted.resize(new_len); + new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size()); + } + if (new_len < 0) { + fprintf(stderr, "failed to apply the chat template\n"); + return 1; + } + + // remove previous messages to obtain the prompt to generate the response + std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len); + + // generate a response + printf("\033[33m"); + std::string response = generate(prompt); + printf("\n\033[0m"); + + // add the response to the messages + messages.push_back({"assistant", strdup(response.c_str())}); + prev_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), false, nullptr, 0); + if (prev_len < 0) { + fprintf(stderr, "failed to apply the chat template\n"); + return 1; + } + } + + // free resources + for (auto & msg : messages) { + free(const_cast(msg.content)); + } + llama_sampler_free(smpl); + llama_free(ctx); + llama_model_free(model); + + return 0; +} diff --git a/examples/main-cmake-pkg/.gitignore b/examples/simple-cmake-pkg/.gitignore similarity index 100% rename from examples/main-cmake-pkg/.gitignore rename to examples/simple-cmake-pkg/.gitignore diff --git a/examples/simple-cmake-pkg/CMakeLists.txt b/examples/simple-cmake-pkg/CMakeLists.txt new file mode 100644 index 000000000..128e38c8f --- /dev/null +++ b/examples/simple-cmake-pkg/CMakeLists.txt @@ -0,0 +1,11 @@ +cmake_minimum_required(VERSION 3.12) +project(llama-simple-cmake-pkg) + +set(TARGET llama-simple-cmake-pkg) + +find_package(Llama REQUIRED) + +add_executable(${TARGET} ${CMAKE_CURRENT_LIST_DIR}/../simple/simple.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE llama ggml::all ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/simple-cmake-pkg/README.md b/examples/simple-cmake-pkg/README.md new file mode 100644 index 000000000..8b30049e2 --- /dev/null +++ b/examples/simple-cmake-pkg/README.md @@ -0,0 +1,34 @@ +# llama.cpp/example/simple-cmake-pkg + +This program builds [simple](../simple) using a relocatable CMake package. It serves as an example of using the `find_package()` CMake command to conveniently include [llama.cpp](https://github.com/ggerganov/llama.cpp) in projects which live outside of the source tree. + +## Building + +Because this example is "outside of the source tree", it is important to first build/install llama.cpp using CMake. An example is provided here, but please see the [llama.cpp build instructions](../..) for more detailed build instructions. + +### Considerations + +When hardware acceleration libraries are used (e.g. CUDA, Metal, Vulkan, etc.), the appropriate dependencies will be searched for automatically. So, for example, when finding a package + +### Build llama.cpp and install to llama.cpp/inst + +```sh +git clone https://github.com/ggerganov/llama.cpp +cd llama.cpp +cmake -S . -B build +cmake --build build +cmake --install build --prefix inst + +### Build simple-cmake-pkg + +```sh +cd examples/simple-cmake-pkg +cmake -S . -B build -DCMAKE_PREFIX_PATH=../../inst/lib/cmake +cmake --build build +``` + +### Run simple-cmake-pkg + +```sh +./build/llama-simple-cmake-pkg -m ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is" +``` diff --git a/examples/simple/CMakeLists.txt b/examples/simple/CMakeLists.txt index 070cfbe7a..104ecabfd 100644 --- a/examples/simple/CMakeLists.txt +++ b/examples/simple/CMakeLists.txt @@ -1,5 +1,5 @@ set(TARGET llama-simple) add_executable(${TARGET} simple.cpp) install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/simple/README.md b/examples/simple/README.md index 0ff342535..937008b24 100644 --- a/examples/simple/README.md +++ b/examples/simple/README.md @@ -3,7 +3,7 @@ The purpose of this example is to demonstrate a minimal usage of llama.cpp for generating text with a given prompt. ```bash -./llama-simple -m ./models/llama-7b-v2/ggml-model-f16.gguf -p "Hello my name is" +./llama-simple -m ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is" ... diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index a53cef547..10e79a0a6 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -1,172 +1,206 @@ -#include "common.h" #include "llama.h" - -#include #include +#include #include #include static void print_usage(int, char ** argv) { - LOG_TEE("\nexample usage:\n"); - LOG_TEE("\n %s -m model.gguf -p \"Hello my name is\" -n 32\n", argv[0]); - LOG_TEE("\n"); + printf("\nexample usage:\n"); + printf("\n %s -m model.gguf [-n n_predict] [-ngl n_gpu_layers] [prompt]\n", argv[0]); + printf("\n"); } int main(int argc, char ** argv) { - gpt_params params; + // path to the model gguf file + std::string model_path; + // prompt to generate text from + std::string prompt = "Hello my name is"; + // number of layers to offload to the GPU + int ngl = 99; + // number of tokens to predict + int n_predict = 32; - params.prompt = "Hello my name is"; - params.n_predict = 32; + // parse command line arguments - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON, print_usage); - if (!gpt_params_parse(argc, argv, params, options)) { - return 1; + { + int i = 1; + for (; i < argc; i++) { + if (strcmp(argv[i], "-m") == 0) { + if (i + 1 < argc) { + model_path = argv[++i]; + } else { + print_usage(argc, argv); + return 1; + } + } else if (strcmp(argv[i], "-n") == 0) { + if (i + 1 < argc) { + try { + n_predict = std::stoi(argv[++i]); + } catch (...) { + print_usage(argc, argv); + return 1; + } + } else { + print_usage(argc, argv); + return 1; + } + } else if (strcmp(argv[i], "-ngl") == 0) { + if (i + 1 < argc) { + try { + ngl = std::stoi(argv[++i]); + } catch (...) { + print_usage(argc, argv); + return 1; + } + } else { + print_usage(argc, argv); + return 1; + } + } else { + // prompt starts here + break; + } + } + if (model_path.empty()) { + print_usage(argc, argv); + return 1; + } + if (i < argc) { + prompt = argv[i++]; + for (; i < argc; i++) { + prompt += " "; + prompt += argv[i]; + } + } } - // total length of the sequence including the prompt - const int n_predict = params.n_predict; + // load dynamic backends - // init LLM - - llama_backend_init(); - llama_numa_init(params.numa); + ggml_backend_load_all(); // initialize the model - llama_model_params model_params = llama_model_params_from_gpt_params(params); + llama_model_params model_params = llama_model_default_params(); + model_params.n_gpu_layers = ngl; - llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); + llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params); + const llama_vocab * vocab = llama_model_get_vocab(model); if (model == NULL) { fprintf(stderr , "%s: error: unable to load model\n" , __func__); return 1; } + // tokenize the prompt + + // find the number of tokens in the prompt + const int n_prompt = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true); + + // allocate space for the tokens and tokenize the prompt + std::vector prompt_tokens(n_prompt); + if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) { + fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__); + return 1; + } + // initialize the context - llama_context_params ctx_params = llama_context_params_from_gpt_params(params); + llama_context_params ctx_params = llama_context_default_params(); + // n_ctx is the context size + ctx_params.n_ctx = n_prompt + n_predict - 1; + // n_batch is the maximum number of tokens that can be processed in a single call to llama_decode + ctx_params.n_batch = n_prompt; + // enable performance counters + ctx_params.no_perf = false; - llama_context * ctx = llama_new_context_with_model(model, ctx_params); + llama_context * ctx = llama_init_from_model(model, ctx_params); if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); return 1; } + // initialize the sampler + auto sparams = llama_sampler_chain_default_params(); - sparams.no_perf = false; - llama_sampler * smpl = llama_sampler_chain_init(sparams); llama_sampler_chain_add(smpl, llama_sampler_init_greedy()); - // tokenize the prompt - - std::vector tokens_list; - tokens_list = ::llama_tokenize(ctx, params.prompt, true); - - const int n_ctx = llama_n_ctx(ctx); - const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size()); - - LOG_TEE("\n%s: n_predict = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, n_kv_req); - - // make sure the KV cache is big enough to hold all the prompt and generated tokens - if (n_kv_req > n_ctx) { - LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__); - LOG_TEE("%s: either reduce n_predict or increase n_ctx\n", __func__); - return 1; - } - // print the prompt token-by-token - fprintf(stderr, "\n"); - - for (auto id : tokens_list) { - fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str()); + for (auto id : prompt_tokens) { + char buf[128]; + int n = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true); + if (n < 0) { + fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__); + return 1; + } + std::string s(buf, n); + printf("%s", s.c_str()); } - fflush(stderr); + // prepare a batch for the prompt - // create a llama_batch with size 512 - // we use this object to submit token data for decoding - - llama_batch batch = llama_batch_init(512, 0, 1); - - // evaluate the initial prompt - for (size_t i = 0; i < tokens_list.size(); i++) { - llama_batch_add(batch, tokens_list[i], i, { 0 }, false); - } - - // llama_decode will output logits only for the last token of the prompt - batch.logits[batch.n_tokens - 1] = true; - - if (llama_decode(ctx, batch) != 0) { - LOG_TEE("%s: llama_decode() failed\n", __func__); - return 1; - } + llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); // main loop - int n_cur = batch.n_tokens; - int n_decode = 0; - const auto t_main_start = ggml_time_us(); + int n_decode = 0; + llama_token new_token_id; - while (n_cur <= n_predict) { - // sample the next token - { - const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); - - llama_sampler_accept(smpl, new_token_id); - - // is it an end of generation? - if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { - LOG_TEE("\n"); - - break; - } - - LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str()); - fflush(stdout); - - // prepare the next batch - llama_batch_clear(batch); - - // push this new token for next evaluation - llama_batch_add(batch, new_token_id, n_cur, { 0 }, true); - - n_decode += 1; - } - - n_cur += 1; - + for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) { // evaluate the current batch with the transformer model if (llama_decode(ctx, batch)) { fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); return 1; } + + n_pos += batch.n_tokens; + + // sample the next token + { + new_token_id = llama_sampler_sample(smpl, ctx, -1); + + // is it an end of generation? + if (llama_vocab_is_eog(vocab, new_token_id)) { + break; + } + + char buf[128]; + int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true); + if (n < 0) { + fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__); + return 1; + } + std::string s(buf, n); + printf("%s", s.c_str()); + fflush(stdout); + + // prepare the next batch with the sampled token + batch = llama_batch_get_one(&new_token_id, 1); + + n_decode += 1; + } } - LOG_TEE("\n"); + printf("\n"); const auto t_main_end = ggml_time_us(); - LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", + fprintf(stderr, "%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); - LOG_TEE("\n"); - llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN); - llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); - + fprintf(stderr, "\n"); + llama_perf_sampler_print(smpl); + llama_perf_context_print(ctx); fprintf(stderr, "\n"); - llama_batch_free(batch); llama_sampler_free(smpl); llama_free(ctx); - llama_free_model(model); - - llama_backend_free(); + llama_model_free(model); return 0; } diff --git a/examples/speculative-simple/CMakeLists.txt b/examples/speculative-simple/CMakeLists.txt new file mode 100644 index 000000000..aeaea74fc --- /dev/null +++ b/examples/speculative-simple/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-speculative-simple) +add_executable(${TARGET} speculative-simple.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/speculative-simple/README.md b/examples/speculative-simple/README.md new file mode 100644 index 000000000..e3a6c6b4a --- /dev/null +++ b/examples/speculative-simple/README.md @@ -0,0 +1,12 @@ +# llama.cpp/examples/speculative-simple + +Demonstration of basic greedy speculative decoding + +```bash +./bin/llama-speculative-simple \ + -m ../models/qwen2.5-32b-coder-instruct/ggml-model-q8_0.gguf \ + -md ../models/qwen2.5-1.5b-coder-instruct/ggml-model-q4_0.gguf \ + -f test.txt -c 0 -ngl 99 --color \ + --sampling-seq k --top-k 1 -fa --temp 0.0 \ + -ngld 99 --draft-max 16 --draft-min 5 --draft-p-min 0.9 +``` diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp new file mode 100644 index 000000000..403ba2dd2 --- /dev/null +++ b/examples/speculative-simple/speculative-simple.cpp @@ -0,0 +1,261 @@ +#include "arg.h" +#include "common.h" +#include "sampling.h" +#include "speculative.h" +#include "log.h" +#include "llama.h" + +#include +#include +#include +#include + +int main(int argc, char ** argv) { + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) { + return 1; + } + + if (params.n_predict < -1) { + LOG_ERR("%s: --n-predict must be >= -1\n", __func__); + return 1; + } + + common_init(); + + if (params.speculative.model.empty()) { + LOG_ERR("%s: --model-draft is required\n", __func__); + return 1; + } + + // init llama.cpp + llama_backend_init(); + llama_numa_init(params.numa); + + llama_model * model_tgt = NULL; + //llama_model * model_dft = NULL; + + llama_context * ctx_tgt = NULL; + llama_context * ctx_dft = NULL; + + // load the target model + common_init_result llama_init_tgt = common_init_from_params(params); + + model_tgt = llama_init_tgt.model.get(); + ctx_tgt = llama_init_tgt.context.get(); + + const llama_vocab * vocab = llama_model_get_vocab(model_tgt); + + // load the draft model + params.devices = params.speculative.devices; + params.model = params.speculative.model; + params.n_ctx = params.speculative.n_ctx; + params.n_batch = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch; + params.n_gpu_layers = params.speculative.n_gpu_layers; + + if (params.speculative.cpuparams.n_threads > 0) { + params.cpuparams.n_threads = params.speculative.cpuparams.n_threads; + } + + params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; + common_init_result llama_init_dft = common_init_from_params(params); + + //model_dft = llama_init_dft.model.get(); + ctx_dft = llama_init_dft.context.get(); + + if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) { + return 1; + } + + // Tokenize the prompt + std::vector inp; + inp = common_tokenize(ctx_tgt, params.prompt, true, true); + + if (llama_n_ctx(ctx_tgt) < (uint32_t) inp.size()) { + LOG_ERR("%s: the prompt exceeds the context size (%d tokens, ctx %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt)); + + return 1; + } + + if (llama_n_batch(ctx_tgt) < (uint32_t) inp.size()) { + LOG_ERR("%s: the prompt exceeds the batch size (%d tokens, batch %d)\n", __func__, (int) inp.size(), llama_n_batch(ctx_tgt)); + + return 1; + } + + LOG("\n\n"); + + for (auto id : inp) { + LOG("%s", common_token_to_piece(ctx_tgt, id).c_str()); + } + + // how many tokens to draft each time + int n_draft = params.speculative.n_max; + int n_draft_min = params.speculative.n_min; + + float p_min = params.speculative.p_min; + + int n_predict = 0; + int n_drafted = 0; + int n_accept = 0; + + // used to determine end of generation + bool has_eos = false; + + // ================================================ + // everything until here is standard initialization + // the relevant stuff for speculative decoding starts here + + const auto t_enc_start = ggml_time_us(); + + // target model sampling context + struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling); + + // eval the prompt + llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1)); + + // note: keep the last token separate! + llama_token id_last = inp.back(); + + // all tokens currently in the target context + llama_tokens prompt_tgt(inp.begin(), inp.end() - 1); + prompt_tgt.reserve(llama_n_ctx(ctx_tgt)); + + int n_past = inp.size() - 1; + + // init the speculator + struct common_speculative_params params_spec; + params_spec.n_draft = n_draft; + params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft; + params_spec.p_min = p_min; + + struct common_speculative * spec = common_speculative_init(ctx_dft); + + llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); + + const auto t_enc_end = ggml_time_us(); + + const auto t_dec_start = ggml_time_us(); + + while (true) { + // optionally, generate draft tokens that can be appended to the target batch + // + // this is the most important part of the speculation. the more probable tokens that are provided here + // the better the performance will be. in theory, this computation can be performed asynchronously and even + // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens + // from a cache or lookup tables. + // + llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last); + + //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str()); + + // always have a token to evaluate from before - id_last + common_batch_clear(batch_tgt); + common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true); + + // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] + { + // do not waste time on small drafts + if (draft.size() < (size_t) n_draft_min) { + draft.clear(); + } + + for (size_t i = 0; i < draft.size(); ++i) { + common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); + } + + //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str()); + + llama_decode(ctx_tgt, batch_tgt); + } + + // sample from the full target batch and return the accepted tokens based on the target sampler + // + // for each token to be accepted, the sampler would have to sample that same token + // in such cases, instead of decoding the sampled token as we normally do, we simply continue with the + // available logits from the batch and sample the next token until we run out of logits or the sampler + // disagrees with the draft + // + const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft); + + //LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str()); + + GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token + + n_past += ids.size() - 1; + n_drafted += draft.size(); // note: we ignore the discarded small drafts + n_accept += ids.size() - 1; + n_predict += ids.size(); + + // process the accepted tokens and update contexts + // + // this is the standard token post-processing that we normally do + // in this case, we do it for a group of accepted tokens at once + // + for (size_t i = 0; i < ids.size(); ++i) { + prompt_tgt.push_back(id_last); + + id_last = ids[i]; + + if (llama_vocab_is_eog(vocab, id_last)) { + has_eos = true; + break; + } + + const std::string token_str = common_token_to_piece(ctx_tgt, id_last); + + if (params.use_color && i + 1 < ids.size()) { + LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str()); + } else { + LOG("%s", token_str.c_str()); + } + } + + LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last); + + { + LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); + + llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1); + } + + if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) { + break; + } + } + + auto t_dec_end = ggml_time_us(); + + const int n_input = inp.size(); + + LOG("\n\n"); + + LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f)); + LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f)); + + LOG_INF("\n"); + LOG_INF("n_draft = %d\n", n_draft); + LOG_INF("n_predict = %d\n", n_predict); + LOG_INF("n_drafted = %d\n", n_drafted); + LOG_INF("n_accept = %d\n", n_accept); + LOG_INF("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); + + LOG_INF("\n"); + LOG_INF("draft:\n\n"); + + llama_perf_context_print(ctx_dft); + + LOG_INF("\n"); + LOG_INF("target:\n\n"); + common_perf_print(ctx_tgt, smpl); + + common_sampler_free(smpl); + common_speculative_free(spec); + + llama_backend_free(); + + LOG("\n\n"); + + return 0; +} diff --git a/examples/speculative/CMakeLists.txt b/examples/speculative/CMakeLists.txt index aa208e7aa..c84196bd9 100644 --- a/examples/speculative/CMakeLists.txt +++ b/examples/speculative/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-speculative) add_executable(${TARGET} speculative.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 8f29b5a2c..c7ccea50d 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -1,13 +1,18 @@ +#include "arg.h" #include "common.h" +#include "sampling.h" +#include "log.h" #include "llama.h" -#include +#include #include +#include +#include +#include #include #include -#include -#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100 +#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 struct seq_draft { @@ -21,19 +26,28 @@ struct seq_draft { std::vector tokens; std::vector> dists; - struct gpt_sampler * smpl = nullptr; + struct common_sampler * smpl = nullptr; }; int main(int argc, char ** argv) { - gpt_params params; + common_params params; - auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_SPECULATIVE); - if (!gpt_params_parse(argc, argv, params, options)) { + // needed to get candidate probs even for temp <= 0.0 + params.sampling.n_probs = 128; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) { return 1; } - if (params.model_draft.empty()) { - fprintf(stderr, "%s: error: --model-draft is required\n", __func__); + if (params.n_predict < -1) { + LOG_ERR("%s: --n-predict must be >= -1\n", __func__); + return 1; + } + + common_init(); + + if (params.speculative.model.empty()) { + LOG_ERR("%s: --model-draft is required\n", __func__); return 1; } @@ -41,17 +55,11 @@ int main(int argc, char ** argv) { const int n_seq_dft = params.n_parallel; // probability threshold for splitting a draft branch (only for n_seq_dft > 1) - const float p_split = params.p_split; + const float p_draft_split = params.speculative.p_split; - std::default_random_engine rng(params.sparams.seed); + std::default_random_engine rng(params.sampling.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sampling.seed); std::uniform_real_distribution<> u_dist; -#ifndef LOG_DISABLE_LOGS - log_set_target(log_filename_generator("speculative", "log")); - LOG_TEE("Log start\n"); - log_dump_cmdline(argc, argv); -#endif // LOG_DISABLE_LOGS - // init llama.cpp llama_backend_init(); llama_numa_init(params.numa); @@ -63,66 +71,72 @@ int main(int argc, char ** argv) { llama_context * ctx_dft = NULL; // load the target model - llama_init_result llama_init_tgt = llama_init_from_gpt_params(params); - model_tgt = llama_init_tgt.model; - ctx_tgt = llama_init_tgt.context; + common_init_result llama_init_tgt = common_init_from_params(params); + + model_tgt = llama_init_tgt.model.get(); + ctx_tgt = llama_init_tgt.context.get(); // load the draft model - params.model = params.model_draft; - params.n_gpu_layers = params.n_gpu_layers_draft; - if (params.draft_cpuparams.n_threads > 0) { - params.cpuparams.n_threads = params.draft_cpuparams.n_threads; + params.devices = params.speculative.devices; + params.model = params.speculative.model; + params.n_gpu_layers = params.speculative.n_gpu_layers; + if (params.speculative.cpuparams.n_threads > 0) { + params.cpuparams.n_threads = params.speculative.cpuparams.n_threads; } - params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads; - llama_init_result llama_init_dft = llama_init_from_gpt_params(params); - model_dft = llama_init_dft.model; - ctx_dft = llama_init_dft.context; + params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; + common_init_result llama_init_dft = common_init_from_params(params); - const bool vocab_type_tgt = llama_vocab_type(model_tgt); - LOG("vocab_type tgt: %d\n", vocab_type_tgt); + model_dft = llama_init_dft.model.get(); + ctx_dft = llama_init_dft.context.get(); - const bool vocab_type_dft = llama_vocab_type(model_dft); - LOG("vocab_type dft: %d\n", vocab_type_dft); + const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); + const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); + + const bool vocab_type_tgt = llama_vocab_type(vocab_tgt); + LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt); + + const bool vocab_type_dft = llama_vocab_type(vocab_dft); + LOG_DBG("vocab_type dft: %d\n", vocab_type_dft); if (vocab_type_tgt != vocab_type_dft) { - fprintf(stderr, "%s: error: draft model vocab type must match target model to use speculation but ", __func__); - fprintf(stderr, "vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt); + LOG_ERR("%s: draft model vocab type must match target model to use speculation but ", __func__); + LOG_ERR("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt); return 1; } if ( - llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) || - llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) || - llama_token_bos(model_tgt) != llama_token_bos(model_dft) || - llama_token_eos(model_tgt) != llama_token_eos(model_dft) + llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) || + llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) || + llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) || + llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft) ) { - fprintf(stderr, "%s: error: draft model special tokens must match target model to use speculation\n", __func__); + LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__); return 1; } { - const int n_vocab_tgt = llama_n_vocab(model_tgt); - const int n_vocab_dft = llama_n_vocab(model_dft); + const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt); + const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft); const int vocab_diff = n_vocab_tgt > n_vocab_dft ? n_vocab_tgt - n_vocab_dft : n_vocab_dft - n_vocab_tgt; if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { - fprintf(stderr, "%s: error: draft model vocab must closely match target model to use speculation but ", __func__); - fprintf(stderr, "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", - n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); + LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__); + LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", + n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); return 1; } for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { - const char * token_text_tgt = llama_token_get_text(model_tgt, i); - const char * token_text_dft = llama_token_get_text(model_dft, i); + const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i); + const char * token_text_dft = llama_vocab_get_text(vocab_dft, i); if (std::strcmp(token_text_tgt, token_text_dft) != 0) { - fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__); - fprintf(stderr, "token %d content differs - target '%s', draft '%s'\n", i, - llama_token_to_piece(ctx_tgt, i).c_str(), - llama_token_to_piece(ctx_dft, i).c_str()); + LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__); + LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i, + common_token_to_piece(ctx_tgt, i).c_str(), + common_token_to_piece(ctx_dft, i).c_str()); return 1; } } @@ -131,40 +145,38 @@ int main(int argc, char ** argv) { // Tokenize the prompt std::vector inp; - inp = ::llama_tokenize(ctx_tgt, params.prompt, true, true); + inp = common_tokenize(ctx_tgt, params.prompt, true, true); const int max_context_size = llama_n_ctx(ctx_tgt); const int max_tokens_list_size = max_context_size - 4; if ((int) inp.size() > max_tokens_list_size) { - fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size); + LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size); return 1; } - fprintf(stderr, "\n\n"); + LOG("\n\n"); for (auto id : inp) { - fprintf(stderr, "%s", llama_token_to_piece(ctx_tgt, id).c_str()); + LOG("%s", common_token_to_piece(ctx_tgt, id).c_str()); } - fflush(stderr); - const int n_input = inp.size(); const auto t_enc_start = ggml_time_us(); // eval the prompt with both models - llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0)); - llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); - llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0)); + llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1)); + llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1)); + llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input)); const auto t_enc_end = ggml_time_us(); // the 2 models should have the same vocab - //GGML_ASSERT(n_vocab == llama_n_vocab(model_dft)); + //GGML_ASSERT(n_vocab == llama_vocab_n_tokens(model_dft)); // how many tokens to draft each time - int n_draft = params.n_draft; + int n_draft = params.speculative.n_max; int n_predict = 0; int n_drafted = 0; @@ -177,20 +189,18 @@ int main(int argc, char ** argv) { bool has_eos = false; // target model sampling context (reuse the llama_context's sampling instance) - struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams); - - struct llama_sampler * softmax = llama_sampler_init_softmax(); + struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling); // draft sequence data std::vector drafts(n_seq_dft); for (int s = 0; s < n_seq_dft; ++s) { - // allocate gpt_sampler for each draft sequence - drafts[s].smpl = gpt_sampler_init(model_dft, params.sparams); + // allocate llama_sampler for each draft sequence + drafts[s].smpl = common_sampler_init(model_dft, params.sampling); } - llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1); - llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft); + llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); + llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft); const auto t_dec_start = ggml_time_us(); @@ -210,7 +220,7 @@ int main(int argc, char ** argv) { active_seqs.insert(s); const auto & tokens = drafts[s].tokens; - LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str()); + LOG_DBG("draft %d: %s\n", s, string_from(ctx_dft, tokens).c_str()); } int i_dft = 0; @@ -226,11 +236,11 @@ int main(int argc, char ** argv) { // for stochastic sampling, attempt to match the token with the drafted tokens { bool accept = false; - if (params.sparams.temp > 0) { + if (params.sampling.temp > 0) { // stochastic verification - gpt_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true); + common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true); - auto & dist_tgt = *gpt_sampler_get_candidates(smpl); + auto & dist_tgt = *common_sampler_get_candidates(smpl); float p_tgt = 0.0f; float p_dft = 0.0f; @@ -253,7 +263,7 @@ int main(int argc, char ** argv) { continue; } - LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size()); + LOG_DBG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size()); float r = u_dist(rng); llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true }; @@ -263,26 +273,27 @@ int main(int argc, char ** argv) { for (size_t i = 0; i < dist_tgt.size; i++) { if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) { p_tgt = dist_tgt.data[i].p; - } - if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) { - p_dft = dist_dft.data[i].p; - } - if (p_tgt && p_dft) { break; } } - LOG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt); + for (size_t i = 0; i < dist_dft.size; i++) { + if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) { + p_dft = dist_dft.data[i].p; + break; + } + } + LOG_DBG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt); if (r <= p_tgt / p_dft) { s_keep = s; accept = true; token_id = drafts[s].tokens[i_dft]; - token_str = llama_token_to_piece(ctx_tgt, token_id); - gpt_sampler_accept(smpl, token_id, true); + token_str = common_token_to_piece(ctx_tgt, token_id); + common_sampler_accept(smpl, token_id, true); - LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str()); + LOG_DBG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str()); break; } else { - LOG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str()); + LOG_DBG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], common_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str()); drafts[s].active = false; // calculate residual probability @@ -337,7 +348,7 @@ int main(int argc, char ** argv) { if (!accept) { // all drafted tokens were rejected // sample from the target model - LOG("all drafted tokens were rejected, sampling from residual distribution\n"); + LOG_DBG("all drafted tokens were rejected, sampling from residual distribution\n"); std::vector probs(dist_tgt.size); for (size_t i = 0; i < dist_tgt.size; ++i) { probs[i] = dist_tgt.data[i].p; @@ -348,21 +359,19 @@ int main(int argc, char ** argv) { const int idx = dist(rng); token_id = dist_tgt.data[idx].id; - gpt_sampler_accept(smpl, token_id, true); - token_str = llama_token_to_piece(ctx_tgt, token_id); + common_sampler_accept(smpl, token_id, true); + token_str = common_token_to_piece(ctx_tgt, token_id); } } else { // greedy verification // sample from the target model - LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); - token_id = gpt_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]); + LOG_DBG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); + token_id = common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]); - gpt_sampler_accept(smpl, token_id, true); + common_sampler_accept(smpl, token_id, true); - //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, smpl->prev).c_str()); - - token_str = llama_token_to_piece(ctx_tgt, token_id); + token_str = common_token_to_piece(ctx_tgt, token_id); for (int s = 0; s < n_seq_dft; ++s) { if (!drafts[s].active) { @@ -370,7 +379,7 @@ int main(int argc, char ** argv) { } if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) { - LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str()); + LOG_DBG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str()); s_keep = s; accept = true; @@ -380,7 +389,7 @@ int main(int argc, char ** argv) { } } - if (llama_token_is_eog(model_tgt, token_id)) { + if (llama_vocab_is_eog(vocab_tgt, token_id)) { has_eos = true; } ++n_predict; @@ -392,26 +401,24 @@ int main(int argc, char ** argv) { ++i_dft; if (params.use_color) { // Color token according to its origin sequence - printf("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str()); + LOG("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str()); } else { - printf("%s", token_str.c_str()); + LOG("%s", token_str.c_str()); } - fflush(stdout); continue; } else { - printf("%s", token_str.c_str()); - fflush(stdout); + LOG("%s", token_str.c_str()); break; } } } { - LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str()); + LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str()); // TODO: simplify { - LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft); + LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft); llama_kv_cache_seq_keep(ctx_dft, s_keep); llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1); @@ -434,24 +441,24 @@ int main(int argc, char ** argv) { drafts[0].dists.push_back(std::vector()); drafts[0].i_batch_tgt.push_back(0); - llama_batch_clear(batch_dft); - llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true); + common_batch_clear(batch_dft); + common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true); llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); - // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); + // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); llama_decode(ctx_dft, batch_dft); ++n_past_dft; } - if (n_predict > params.n_predict || has_eos) { + if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) { break; } if (drafts[0].smpl) { - gpt_sampler_free(drafts[0].smpl); + common_sampler_free(drafts[0].smpl); } - drafts[0].smpl = gpt_sampler_clone(smpl); + drafts[0].smpl = common_sampler_clone(smpl); int n_seq_cur = 1; int n_past_cur = n_past_dft; @@ -464,8 +471,8 @@ int main(int argc, char ** argv) { drafts[0].drafting = true; drafts[0].i_batch_dft = 0; - llama_batch_clear(batch_tgt); - llama_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true); + common_batch_clear(batch_tgt); + common_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true); // sample n_draft tokens from the draft model using tree-based sampling for (int i = 0; i < n_draft; ++i) { @@ -480,21 +487,21 @@ int main(int argc, char ** argv) { continue; } - gpt_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true); + common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true); - const auto * cur_p = gpt_sampler_get_candidates(drafts[s].smpl); + const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl); for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) { - LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, s, i, cur_p->data[k].id, cur_p->data[k].p, llama_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); + LOG_DBG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, s, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); } std::vector sa(1, s); // attempt to split the branch if the probability is high enough for (int f = 1; f < 8; ++f) { - if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_split) { - LOG("splitting seq %3d into %3d\n", s, n_seq_cur); + if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) { + LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur); llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1); llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); @@ -521,9 +528,9 @@ int main(int argc, char ** argv) { drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt; if (drafts[n_seq_cur].smpl) { - gpt_sampler_free(drafts[n_seq_cur].smpl); + common_sampler_free(drafts[n_seq_cur].smpl); } - drafts[n_seq_cur].smpl = gpt_sampler_clone(drafts[s].smpl); + drafts[n_seq_cur].smpl = common_sampler_clone(drafts[s].smpl); sa.push_back(n_seq_cur); @@ -539,7 +546,7 @@ int main(int argc, char ** argv) { const int s = sa[is]; - gpt_sampler_accept(drafts[s].smpl, id, true); + common_sampler_accept(drafts[s].smpl, id, true); drafts[s].tokens.push_back(id); // save cur_p.data into drafts[s].dists @@ -548,12 +555,12 @@ int main(int argc, char ** argv) { // add unique drafted tokens to the target batch drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens); - llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true); + common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true); // add the token to the batch for batched decoding with the draft model drafts[s].i_batch_dft = batch_dft.n_tokens; - llama_batch_add(batch_dft, id, n_past_cur, { s }, true); + common_batch_add(batch_dft, id, n_past_cur, { s }, true); if (batch_tgt.n_tokens > n_draft) { drafts[s].drafting = false; @@ -583,7 +590,7 @@ int main(int argc, char ** argv) { llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1); } - // LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); + // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); llama_decode(ctx_tgt, batch_tgt); ++n_past_tgt; } @@ -601,42 +608,37 @@ int main(int argc, char ** argv) { auto t_dec_end = ggml_time_us(); - LOG_TEE("\n\n"); + LOG("\n\n"); - LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f)); - LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f)); + LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f)); + LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f)); - LOG_TEE("\n"); - LOG_TEE("n_draft = %d\n", n_draft); - LOG_TEE("n_predict = %d\n", n_predict); - LOG_TEE("n_drafted = %d\n", n_drafted); - LOG_TEE("n_accept = %d\n", n_accept); - LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); + LOG_INF("\n"); + LOG_INF("n_draft = %d\n", n_draft); + LOG_INF("n_predict = %d\n", n_predict); + LOG_INF("n_drafted = %d\n", n_drafted); + LOG_INF("n_accept = %d\n", n_accept); + LOG_INF("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); - LOG_TEE("\ndraft:\n\n"); + LOG_INF("\n"); + LOG_INF("draft:\n\n"); // TODO: print sampling/grammar timings for all drafts - llama_perf_print(ctx_dft, LLAMA_PERF_TYPE_CONTEXT); + llama_perf_context_print(ctx_dft); - LOG_TEE("\ntarget:\n\n"); - gpt_perf_print(ctx_tgt, smpl); + LOG_INF("\n"); + LOG_INF("target:\n\n"); + common_perf_print(ctx_tgt, smpl); - gpt_sampler_free(smpl); + common_sampler_free(smpl); for (int s = 0; s < n_seq_dft; ++s) { - gpt_sampler_free(drafts[s].smpl); + common_sampler_free(drafts[s].smpl); } - llama_sampler_free(softmax); llama_batch_free(batch_dft); - llama_free(ctx_tgt); - llama_free_model(model_tgt); - - llama_free(ctx_dft); - llama_free_model(model_dft); - llama_backend_free(); - fprintf(stderr, "\n\n"); + LOG("\n\n"); return 0; } diff --git a/examples/sycl/run-llama2.sh b/examples/sycl/run-llama2.sh index 111366fb0..3b9ba3b2d 100755 --- a/examples/sycl/run-llama2.sh +++ b/examples/sycl/run-llama2.sh @@ -4,33 +4,24 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: MIT -INPUT2="Building a website can be done in 10 simple steps:\nStep 1:" source /opt/intel/oneapi/setvars.sh -if [ $# -gt 0 ]; then - GGML_SYCL_DEVICE=$1 - GGML_SYCL_SINGLE_GPU=1 -else - GGML_SYCL_DEVICE=0 - GGML_SYCL_SINGLE_GPU=0 -fi - #export GGML_SYCL_DEBUG=1 - #ZES_ENABLE_SYSMAN=1, Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory. Recommended to use when --split-mode = layer. -if [ $GGML_SYCL_SINGLE_GPU -eq 1 ]; then +INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:" +MODEL_FILE=models/llama-2-7b.Q4_0.gguf +NGL=33 +CONEXT=8192 + +if [ $# -gt 0 ]; then + GGML_SYCL_DEVICE=$1 echo "use $GGML_SYCL_DEVICE as main GPU" #use signle GPU only - ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -p "${INPUT2}" -n 400 -e -ngl 33 -s 0 -mg $GGML_SYCL_DEVICE -sm none + ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONEXT} -mg $GGML_SYCL_DEVICE -sm none + else #use multiple GPUs with same max compute units - ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -p "${INPUT2}" -n 400 -e -ngl 33 -s 0 + ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONEXT} fi - -#use main GPU only -#ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -p "${INPUT2}" -n 400 -e -ngl 33 -s 0 -mg $GGML_SYCL_DEVICE -sm none - -#use multiple GPUs with same max compute units -#ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -p "${INPUT2}" -n 400 -e -ngl 33 -s 0 diff --git a/examples/tokenize/CMakeLists.txt b/examples/tokenize/CMakeLists.txt index b704dcae1..1690b53e5 100644 --- a/examples/tokenize/CMakeLists.txt +++ b/examples/tokenize/CMakeLists.txt @@ -2,4 +2,4 @@ set(TARGET llama-tokenize) add_executable(${TARGET} tokenize.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/tokenize/tokenize.cpp b/examples/tokenize/tokenize.cpp index c817be566..7375759eb 100644 --- a/examples/tokenize/tokenize.cpp +++ b/examples/tokenize/tokenize.cpp @@ -1,11 +1,13 @@ #include "common.h" +//#include "log.h" // TODO: start using log.h #include "llama.h" -#include #include +#include #include #include #include +#include // TODO: remove me #if defined(_WIN32) #define WIN32_LEAN_AND_MEAN @@ -13,25 +15,26 @@ #include // For CommandLineToArgvW #endif -static void print_usage_information(const char * argv0, FILE * stream) { - fprintf(stream, "usage: %s [options]\n\n", argv0); - fprintf(stream, "The tokenize program tokenizes a prompt using a given model,\n"); - fprintf(stream, "and prints the resulting tokens to standard output.\n\n"); - fprintf(stream, "It needs a model file, a prompt, and optionally other flags\n"); - fprintf(stream, "to control the behavior of the tokenizer.\n\n"); - fprintf(stream, " The possible options are:\n"); - fprintf(stream, "\n"); - fprintf(stream, " -h, --help print this help and exit\n"); - fprintf(stream, " -m MODEL_PATH, --model MODEL_PATH path to model.\n"); - fprintf(stream, " --ids if given, only print numerical token IDs, and not token strings.\n"); - fprintf(stream, " The output format looks like [1, 2, 3], i.e. parseable by Python.\n"); - fprintf(stream, " -f PROMPT_FNAME, --file PROMPT_FNAME read prompt from a file.\n"); - fprintf(stream, " -p PROMPT, --prompt PROMPT read prompt from the argument.\n"); - fprintf(stream, " --stdin read prompt from standard input.\n"); - fprintf(stream, " --no-bos do not ever add a BOS token to the prompt, even if normally the model uses a BOS token.\n"); - fprintf(stream, " --no-parse-special do not parse control tokens.\n"); - fprintf(stream, " --log-disable disable logs. Makes stderr quiet when loading the model.\n"); - fprintf(stream, " --show-count print the total number of tokens.\n"); +static void print_usage_information(const char * argv0) { + printf("usage: %s [options]\n\n", argv0); + printf("The tokenize program tokenizes a prompt using a given model,\n"); + printf("and prints the resulting tokens to standard output.\n\n"); + printf("It needs a model file, a prompt, and optionally other flags\n"); + printf("to control the behavior of the tokenizer.\n\n"); + printf(" The possible options are:\n"); + printf("\n"); + printf(" -h, --help print this help and exit\n"); + printf(" -m MODEL_PATH, --model MODEL_PATH path to model.\n"); + printf(" --ids if given, only print numerical token IDs, and not token strings.\n"); + printf(" The output format looks like [1, 2, 3], i.e. parseable by Python.\n"); + printf(" -f PROMPT_FNAME, --file PROMPT_FNAME read prompt from a file.\n"); + printf(" -p PROMPT, --prompt PROMPT read prompt from the argument.\n"); + printf(" --stdin read prompt from standard input.\n"); + printf(" --no-bos do not ever add a BOS token to the prompt, even if normally the model uses a BOS token.\n"); + printf(" --no-escape do not escape input (such as \\n, \\t, etc.).\n"); + printf(" --no-parse-special do not parse control tokens.\n"); + printf(" --log-disable disable logs. Makes stderr quiet when loading the model.\n"); + printf(" --show-count print the total number of tokens.\n"); } static void llama_log_callback_null(ggml_log_level level, const char * text, void * user_data) { @@ -185,7 +188,7 @@ int main(int raw_argc, char ** raw_argv) { const int argc = argv.size(); if (argc <= 1) { - print_usage_information(argv[0].c_str(), stderr); + print_usage_information(argv[0].c_str()); return 1; } @@ -196,6 +199,7 @@ int main(int raw_argc, char ** raw_argv) { // variables where to put any arguments we see. bool printing_ids = false; bool no_bos = false; + bool no_escape = false; bool no_parse_special = false; bool disable_logging = false; bool show_token_count = false; @@ -214,7 +218,7 @@ int main(int raw_argc, char ** raw_argv) { for (; iarg < argc; ++iarg) { std::string arg{argv[iarg]}; if (arg == "-h" || arg == "--help") { - print_usage_information(argv[0].c_str(), stdout); + print_usage_information(argv[0].c_str()); return 0; } else if (arg == "--ids") { @@ -231,6 +235,9 @@ int main(int raw_argc, char ** raw_argv) { else if (arg == "--no-bos") { no_bos = true; } + else if (arg == "--no-escape") { + no_escape = true; + } else if (arg == "--no-parse-special") { no_parse_special = true; } @@ -323,10 +330,6 @@ int main(int raw_argc, char ** raw_argv) { // Start actually doing the tokenizing stuff. ////// -#ifdef LOG_DISABLE_LOGS - disable_logging = true; -#endif - if (disable_logging) { llama_log_set(llama_log_callback_null, NULL); } @@ -335,14 +338,16 @@ int main(int raw_argc, char ** raw_argv) { llama_model_params model_params = llama_model_default_params(); model_params.vocab_only = true; - llama_model * model = llama_load_model_from_file(model_path, model_params); + llama_model * model = llama_model_load_from_file(model_path, model_params); if (!model) { fprintf(stderr, "Error: could not load model from file '%s'.\n", model_path); return 1; } + const llama_vocab * vocab = llama_model_get_vocab(model); + llama_context_params ctx_params = llama_context_default_params(); - llama_context * ctx = llama_new_context_with_model(model, ctx_params); + llama_context * ctx = llama_init_from_model(model, ctx_params); if (!ctx) { fprintf(stderr, "Error: could not create context.\n"); return 1; @@ -362,12 +367,17 @@ int main(int raw_argc, char ** raw_argv) { prompt = stdin_buffer.str(); } - const bool model_wants_add_bos = llama_add_bos_token(model); + const bool model_wants_add_bos = llama_vocab_get_add_bos(vocab); const bool add_bos = model_wants_add_bos && !no_bos; const bool parse_special = !no_parse_special; + const bool escape = !no_escape; + + if (escape) { + string_process_escapes(prompt); + } std::vector tokens; - tokens = ::llama_tokenize(model, prompt, add_bos, parse_special); + tokens = common_tokenize(vocab, prompt, add_bos, parse_special); if (printing_ids) { printf("["); @@ -382,7 +392,7 @@ int main(int raw_argc, char ** raw_argv) { } else { bool invalid_utf8 = false; printf("%6d -> '", tokens[i]); - write_utf8_cstr_to_stdout(llama_token_to_piece(ctx, tokens[i]).c_str(), invalid_utf8); + write_utf8_cstr_to_stdout(common_token_to_piece(ctx, tokens[i]).c_str(), invalid_utf8); if (invalid_utf8) { printf("' (utf-8 decode failure)\n"); } else { @@ -396,11 +406,11 @@ int main(int raw_argc, char ** raw_argv) { } if (show_token_count) { - printf("Total number of tokens: %ld\n", tokens.size()); + printf("Total number of tokens: %zu\n", tokens.size()); } // silence valgrind llama_free(ctx); - llama_free_model(model); + llama_model_free(model); return 0; } diff --git a/examples/tts/CMakeLists.txt b/examples/tts/CMakeLists.txt new file mode 100644 index 000000000..c72bd814c --- /dev/null +++ b/examples/tts/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-tts) +add_executable(${TARGET} tts.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/tts/README.md b/examples/tts/README.md new file mode 100644 index 000000000..4509763c6 --- /dev/null +++ b/examples/tts/README.md @@ -0,0 +1,117 @@ +# llama.cpp/example/tts +This example demonstrates the Text To Speech feature. It uses a +[model](https://www.outeai.com/blog/outetts-0.2-500m) from +[outeai](https://www.outeai.com/). + +## Quickstart +If you have built llama.cpp with `-DLLAMA_CURL=ON` you can simply run the +following command and the required models will be downloaded automatically: +```console +$ build/bin/llama-tts --tts-oute-default -p "Hello world" && aplay output.wav +``` +For details about the models and how to convert them to the required format +see the following sections. + +### Model conversion +Checkout or download the model that contains the LLM model: +```console +$ pushd models +$ git clone --branch main --single-branch --depth 1 https://huggingface.co/OuteAI/OuteTTS-0.2-500M +$ cd OuteTTS-0.2-500M && git lfs install && git lfs pull +$ popd +``` +Convert the model to .gguf format: +```console +(venv) python convert_hf_to_gguf.py models/OuteTTS-0.2-500M \ + --outfile models/outetts-0.2-0.5B-f16.gguf --outtype f16 +``` +The generated model will be `models/outetts-0.2-0.5B-f16.gguf`. + +We can optionally quantize this to Q8_0 using the following command: +```console +$ build/bin/llama-quantize models/outetts-0.2-0.5B-f16.gguf \ + models/outetts-0.2-0.5B-q8_0.gguf q8_0 +``` +The quantized model will be `models/outetts-0.2-0.5B-q8_0.gguf`. + +Next we do something simlar for the audio decoder. First download or checkout +the model for the voice decoder: +```console +$ pushd models +$ git clone --branch main --single-branch --depth 1 https://huggingface.co/novateur/WavTokenizer-large-speech-75token +$ cd WavTokenizer-large-speech-75token && git lfs install && git lfs pull +$ popd +``` +This model file is PyTorch checkpoint (.ckpt) and we first need to convert it to +huggingface format: +```console +(venv) python examples/tts/convert_pt_to_hf.py \ + models/WavTokenizer-large-speech-75token/wavtokenizer_large_speech_320_24k.ckpt +... +Model has been successfully converted and saved to models/WavTokenizer-large-speech-75token/model.safetensors +Metadata has been saved to models/WavTokenizer-large-speech-75token/index.json +Config has been saved to models/WavTokenizer-large-speech-75tokenconfig.json +``` +Then we can convert the huggingface format to gguf: +```console +(venv) python convert_hf_to_gguf.py models/WavTokenizer-large-speech-75token \ + --outfile models/wavtokenizer-large-75-f16.gguf --outtype f16 +... +INFO:hf-to-gguf:Model successfully exported to models/wavtokenizer-large-75-f16.gguf +``` + +### Running the example + +With both of the models generated, the LLM model and the voice decoder model, +we can run the example: +```console +$ build/bin/llama-tts -m ./models/outetts-0.2-0.5B-q8_0.gguf \ + -mv ./models/wavtokenizer-large-75-f16.gguf \ + -p "Hello world" +... +main: audio written to file 'output.wav' +``` +The output.wav file will contain the audio of the prompt. This can be heard +by playing the file with a media player. On Linux the following command will +play the audio: +```console +$ aplay output.wav +``` + +### Running the example with llama-server +Running this example with `llama-server` is also possible and requires two +server instances to be started. One will serve the LLM model and the other +will serve the voice decoder model. + +The LLM model server can be started with the following command: +```console +$ ./build/bin/llama-server -m ./models/outetts-0.2-0.5B-q8_0.gguf --port 8020 +``` + +And the voice decoder model server can be started using: +```console +./build/bin/llama-server -m ./models/wavtokenizer-large-75-f16.gguf --port 8021 --embeddings --pooling none +``` + +Then we can run [tts-outetts.py](tts-outetts.py) to generate the audio. + +First create a virtual environment for python and install the required +dependencies (this in only required to be done once): +```console +$ python3 -m venv venv +$ source venv/bin/activate +(venv) pip install requests numpy +``` + +And then run the python script using: +```conole +(venv) python ./examples/tts/tts-outetts.py http://localhost:8020 http://localhost:8021 "Hello world" +spectrogram generated: n_codes: 90, n_embd: 1282 +converting to audio ... +audio generated: 28800 samples +audio written to file "output.wav" +``` +And to play the audio we can again use aplay or any other media player: +```console +$ aplay output.wav +``` diff --git a/examples/tts/convert_pt_to_hf.py b/examples/tts/convert_pt_to_hf.py new file mode 100644 index 000000000..8909a65fd --- /dev/null +++ b/examples/tts/convert_pt_to_hf.py @@ -0,0 +1,180 @@ +# convert the https://huggingface.co/novateur/WavTokenizer-large-speech-75token to HF format +# the goal is to be able to reuse the convert_hf_to_gguf.py after that to create a GGUF file with the WavTokenizer decoder +# +# TODO: this script is LLM-generated and probably very inefficient and should be rewritten + +import torch +import json +import os +import sys +import re + +from safetensors.torch import save_file + +# default +model_path = './model.pt'; + +# read from CLI +if len(sys.argv) > 1: + model_path = sys.argv[1] + +# get the directory of the input model +path_dst = os.path.dirname(model_path) + +print(f"Loading model from {model_path}") + +model = torch.load(model_path, map_location='cpu') + +#print(model) + +# print all keys +for key in model.keys(): + print(key) + if key == 'hyper_parameters': + #print(model[key]) + # dump as json pretty + print(json.dumps(model[key], indent=4)) + #if key != 'state_dict' and key != 'optimizer_states': + # print(model[key]) + +# Check if the loaded model is a state_dict or a model instance +if isinstance(model, torch.nn.Module): + state_dict = model.state_dict() +else: + state_dict = model + +# Print the structure of the state_dict to understand its format +print("State dictionary keys:") +for key in state_dict.keys(): + print(key) + +# Ensure the state_dict is flat and contains only torch.Tensor objects +def flatten_state_dict(state_dict, parent_key='', sep='.'): + items = [] + items_new = [] + + for k, v in state_dict.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, torch.Tensor): + items.append((new_key, v)) + elif isinstance(v, dict): + items.extend(flatten_state_dict(v, new_key, sep=sep).items()) + return dict(items) + + size_total_mb = 0 + + for key, value in list(items): + # keep only what we need for inference + if not key.startswith('state_dict.feature_extractor.encodec.quantizer.') and \ + not key.startswith('state_dict.backbone.') and \ + not key.startswith('state_dict.head.out'): + print('Skipping key: ', key) + continue + + new_key = key + + new_key = new_key.replace('state_dict.', '') + new_key = new_key.replace('pos_net', 'posnet') + + # check if matches "backbone.posnet.%d.bias" or "backbone.posnet.%d.weight" + if new_key.startswith("backbone.posnet."): + match = re.match(r"backbone\.posnet\.(\d+)\.(bias|weight)", new_key) + if match: + new_key = f"backbone.posnet.{match.group(1)}.norm.{match.group(2)}" + + # "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed" -> "backbone.embedding.weight" + if new_key == "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed": + new_key = "backbone.embedding.weight" + + # these are the only rows used + # ref: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/wav_tokenizer/audio_codec.py#L100 + if new_key.endswith("norm.scale.weight"): + new_key = new_key.replace("norm.scale.weight", "norm.weight") + value = value[0] + + if new_key.endswith("norm.shift.weight"): + new_key = new_key.replace("norm.shift.weight", "norm.bias") + value = value[0] + + if new_key.endswith("gamma"): + new_key = new_key.replace("gamma", "gamma.weight") + + # convert from 1D [768] to 2D [768, 1] so that ggml_add can broadcast the bias + if (new_key.endswith("norm.weight") or new_key.endswith("norm1.weight") or new_key.endswith("norm2.weight") or new_key.endswith(".bias")) and (new_key.startswith("backbone.posnet") or new_key.startswith("backbone.embed.bias")): + value = value.unsqueeze(1) + + if new_key.endswith("dwconv.bias"): + value = value.unsqueeze(1) + + size_mb = value.element_size() * value.nelement() / (1024 * 1024) + print(f"{size_mb:8.2f} MB - {new_key}: {value.shape}") + + size_total_mb += size_mb + + #print(key, '->', new_key, ': ', value) + #print(key, '->', new_key) + + items_new.append((new_key, value)) + + print(f"Total size: {size_total_mb:8.2f} MB") + + return dict(items_new) + +flattened_state_dict = flatten_state_dict(state_dict) + + +# Convert the model to the safetensors format +output_path = path_dst + '/model.safetensors' +save_file(flattened_state_dict, output_path) + +print(f"Model has been successfully converted and saved to {output_path}") + +# Calculate the total size of the .safetensors file +total_size = os.path.getsize(output_path) + +# Create the weight map +weight_map = { + "model.safetensors": ["*"] # Assuming all weights are in one file +} + +# Create metadata for the index.json file +metadata = { + "total_size": total_size, + "weight_map": weight_map +} + +# Save the metadata to index.json +index_path = path_dst + '/index.json' +with open(index_path, 'w') as f: + json.dump(metadata, f, indent=4) + +print(f"Metadata has been saved to {index_path}") + +config = { + "architectures": [ + "WavTokenizerDec" + ], + "hidden_size": 1282, + "n_embd_features": 512, + "n_ff": 2304, + "vocab_size": 4096, + "n_head": 1, + "layer_norm_epsilon": 1e-6, + "group_norm_epsilon": 1e-6, + "group_norm_groups": 32, + "max_position_embeddings": 8192, # ? + "n_layer": 12, + "posnet": { + "n_embd": 768, + "n_layer": 6 + }, + "convnext": { + "n_embd": 768, + "n_layer": 12 + }, +} + +with open(path_dst + '/config.json', 'w') as f: + json.dump(config, f, indent=4) + +print(f"Config has been saved to {path_dst + 'config.json'}") diff --git a/examples/tts/tts-outetts.py b/examples/tts/tts-outetts.py new file mode 100644 index 000000000..3791f9fc3 --- /dev/null +++ b/examples/tts/tts-outetts.py @@ -0,0 +1,299 @@ +import sys +#import json +#import struct +import requests +import re +import struct +import numpy as np +from concurrent.futures import ThreadPoolExecutor + + +def fill_hann_window(size, periodic=True): + if periodic: + return np.hanning(size + 1)[:-1] + return np.hanning(size) + + +def irfft(n_fft, complex_input): + return np.fft.irfft(complex_input, n=n_fft) + + +def fold(buffer, n_out, n_win, n_hop, n_pad): + result = np.zeros(n_out) + n_frames = len(buffer) // n_win + + for i in range(n_frames): + start = i * n_hop + end = start + n_win + result[start:end] += buffer[i * n_win:(i + 1) * n_win] + + return result[n_pad:-n_pad] if n_pad > 0 else result + + +def process_frame(args): + l, n_fft, ST, hann = args + frame = irfft(n_fft, ST[l]) + frame = frame * hann + hann2 = hann * hann + return frame, hann2 + + +def embd_to_audio(embd, n_codes, n_embd, n_thread=4): + embd = np.asarray(embd, dtype=np.float32).reshape(n_codes, n_embd) + + n_fft = 1280 + n_hop = 320 + n_win = 1280 + n_pad = (n_win - n_hop) // 2 + n_out = (n_codes - 1) * n_hop + n_win + + hann = fill_hann_window(n_fft, True) + + E = np.zeros((n_embd, n_codes), dtype=np.float32) + for l in range(n_codes): + for k in range(n_embd): + E[k, l] = embd[l, k] + + half_embd = n_embd // 2 + S = np.zeros((n_codes, half_embd + 1), dtype=np.complex64) + + for k in range(half_embd): + for l in range(n_codes): + mag = E[k, l] + phi = E[k + half_embd, l] + + mag = np.clip(np.exp(mag), 0, 1e2) + S[l, k] = mag * np.exp(1j * phi) + + res = np.zeros(n_codes * n_fft) + hann2_buffer = np.zeros(n_codes * n_fft) + + with ThreadPoolExecutor(max_workers=n_thread) as executor: + args = [(l, n_fft, S, hann) for l in range(n_codes)] + results = list(executor.map(process_frame, args)) + + for l, (frame, hann2) in enumerate(results): + res[l*n_fft:(l+1)*n_fft] = frame + hann2_buffer[l*n_fft:(l+1)*n_fft] = hann2 + + audio = fold(res, n_out, n_win, n_hop, n_pad) + env = fold(hann2_buffer, n_out, n_win, n_hop, n_pad) + + mask = env > 1e-10 + audio[mask] /= env[mask] + + return audio + + +def save_wav(filename, audio_data, sample_rate): + num_channels = 1 + bits_per_sample = 16 + bytes_per_sample = bits_per_sample // 8 + data_size = len(audio_data) * bytes_per_sample + byte_rate = sample_rate * num_channels * bytes_per_sample + block_align = num_channels * bytes_per_sample + chunk_size = 36 + data_size # 36 = size of header minus first 8 bytes + + header = struct.pack( + '<4sI4s4sIHHIIHH4sI', + b'RIFF', + chunk_size, + b'WAVE', + b'fmt ', + 16, # fmt chunk size + 1, # audio format (PCM) + num_channels, + sample_rate, + byte_rate, + block_align, + bits_per_sample, + b'data', + data_size + ) + + audio_data = np.clip(audio_data * 32767, -32768, 32767) + pcm_data = audio_data.astype(np.int16) + + with open(filename, 'wb') as f: + f.write(header) + f.write(pcm_data.tobytes()) + + +def process_text(text: str): + text = re.sub(r'\d+(\.\d+)?', lambda x: x.group(), text.lower()) # TODO this needs to be fixed + text = re.sub(r'[-_/,\.\\]', ' ', text) + text = re.sub(r'[^a-z\s]', '', text) + text = re.sub(r'\s+', ' ', text).strip() + return text.split() + +# usage: +# python tts-outetts.py http://server-llm:port http://server-dec:port "text" + +if len(sys.argv) <= 3: + print("usage: python tts-outetts.py http://server-llm:port http://server-dec:port \"text\"") + exit(1) + +host_llm = sys.argv[1] +host_dec = sys.argv[2] +text = sys.argv[3] + +prefix = """<|im_start|> +<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>""" + +words = process_text(text) +words = "<|text_sep|>".join([i.strip() for i in words]) +words += "<|text_end|>\n" + +# voice data +# TODO: load from json +#suffix = """<|audio_start|> +#the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|> +#overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|> +#package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|> +#from<|t_0.19|><|code_start|><|604|><|782|><|1682|><|872|><|1532|><|1600|><|1036|><|1761|><|647|><|1554|><|1371|><|653|><|1595|><|950|><|code_end|> +#just<|t_0.25|><|code_start|><|1782|><|1670|><|317|><|786|><|1748|><|631|><|599|><|1155|><|1364|><|1524|><|36|><|1591|><|889|><|1535|><|541|><|440|><|1532|><|50|><|870|><|code_end|> +#two<|t_0.24|><|code_start|><|1681|><|1510|><|673|><|799|><|805|><|1342|><|330|><|519|><|62|><|640|><|1138|><|565|><|1552|><|1497|><|1552|><|572|><|1715|><|1732|><|code_end|> +#people<|t_0.39|><|code_start|><|593|><|274|><|136|><|740|><|691|><|633|><|1484|><|1061|><|1138|><|1485|><|344|><|428|><|397|><|1562|><|645|><|917|><|1035|><|1449|><|1669|><|487|><|442|><|1484|><|1329|><|1832|><|1704|><|600|><|761|><|653|><|269|><|code_end|> +#is<|t_0.16|><|code_start|><|566|><|583|><|1755|><|646|><|1337|><|709|><|802|><|1008|><|485|><|1583|><|652|><|10|><|code_end|> +#pretty<|t_0.32|><|code_start|><|1818|><|1747|><|692|><|733|><|1010|><|534|><|406|><|1697|><|1053|><|1521|><|1355|><|1274|><|816|><|1398|><|211|><|1218|><|817|><|1472|><|1703|><|686|><|13|><|822|><|445|><|1068|><|code_end|> +#remarkable<|t_0.68|><|code_start|><|230|><|1048|><|1705|><|355|><|706|><|1149|><|1535|><|1787|><|1356|><|1396|><|835|><|1583|><|486|><|1249|><|286|><|937|><|1076|><|1150|><|614|><|42|><|1058|><|705|><|681|><|798|><|934|><|490|><|514|><|1399|><|572|><|1446|><|1703|><|1346|><|1040|><|1426|><|1304|><|664|><|171|><|1530|><|625|><|64|><|1708|><|1830|><|1030|><|443|><|1509|><|1063|><|1605|><|1785|><|721|><|1440|><|923|><|code_end|> +#sure<|t_0.36|><|code_start|><|792|><|1780|><|923|><|1640|><|265|><|261|><|1525|><|567|><|1491|><|1250|><|1730|><|362|><|919|><|1766|><|543|><|1|><|333|><|113|><|970|><|252|><|1606|><|133|><|302|><|1810|><|1046|><|1190|><|1675|><|code_end|> +#i<|t_0.08|><|code_start|><|123|><|439|><|1074|><|705|><|1799|><|637|><|code_end|> +#have<|t_0.16|><|code_start|><|1509|><|599|><|518|><|1170|><|552|><|1029|><|1267|><|864|><|419|><|143|><|1061|><|0|><|code_end|> +#some<|t_0.16|><|code_start|><|619|><|400|><|1270|><|62|><|1370|><|1832|><|917|><|1661|><|167|><|269|><|1366|><|1508|><|code_end|> +#critiques<|t_0.60|><|code_start|><|559|><|584|><|1163|><|1129|><|1313|><|1728|><|721|><|1146|><|1093|><|577|><|928|><|27|><|630|><|1080|><|1346|><|1337|><|320|><|1382|><|1175|><|1682|><|1556|><|990|><|1683|><|860|><|1721|><|110|><|786|><|376|><|1085|><|756|><|1523|><|234|><|1334|><|1506|><|1578|><|659|><|612|><|1108|><|1466|><|1647|><|308|><|1470|><|746|><|556|><|1061|><|code_end|> +#about<|t_0.29|><|code_start|><|26|><|1649|><|545|><|1367|><|1263|><|1728|><|450|><|859|><|1434|><|497|><|1220|><|1285|><|179|><|755|><|1154|><|779|><|179|><|1229|><|1213|><|922|><|1774|><|1408|><|code_end|> +#some<|t_0.23|><|code_start|><|986|><|28|><|1649|><|778|><|858|><|1519|><|1|><|18|><|26|><|1042|><|1174|><|1309|><|1499|><|1712|><|1692|><|1516|><|1574|><|code_end|> +#of<|t_0.07|><|code_start|><|197|><|716|><|1039|><|1662|><|64|><|code_end|> +#the<|t_0.08|><|code_start|><|1811|><|1568|><|569|><|886|><|1025|><|1374|><|code_end|> +#gameplay<|t_0.48|><|code_start|><|1269|><|1092|><|933|><|1362|><|1762|><|1700|><|1675|><|215|><|781|><|1086|><|461|><|838|><|1022|><|759|><|649|><|1416|><|1004|><|551|><|909|><|787|><|343|><|830|><|1391|><|1040|><|1622|><|1779|><|1360|><|1231|><|1187|><|1317|><|76|><|997|><|989|><|978|><|737|><|189|><|code_end|> +#aspects<|t_0.56|><|code_start|><|1423|><|797|><|1316|><|1222|><|147|><|719|><|1347|><|386|><|1390|><|1558|><|154|><|440|><|634|><|592|><|1097|><|1718|><|712|><|763|><|1118|><|1721|><|1311|><|868|><|580|><|362|><|1435|><|868|><|247|><|221|><|886|><|1145|><|1274|><|1284|><|457|><|1043|><|1459|><|1818|><|62|><|599|><|1035|><|62|><|1649|><|778|><|code_end|> +#but<|t_0.20|><|code_start|><|780|><|1825|><|1681|><|1007|><|861|><|710|><|702|><|939|><|1669|><|1491|><|613|><|1739|><|823|><|1469|><|648|><|code_end|> +#its<|t_0.09|><|code_start|><|92|><|688|><|1623|><|962|><|1670|><|527|><|599|><|code_end|> +#still<|t_0.27|><|code_start|><|636|><|10|><|1217|><|344|><|713|><|957|><|823|><|154|><|1649|><|1286|><|508|><|214|><|1760|><|1250|><|456|><|1352|><|1368|><|921|><|615|><|5|><|code_end|> +#really<|t_0.36|><|code_start|><|55|><|420|><|1008|><|1659|><|27|><|644|><|1266|><|617|><|761|><|1712|><|109|><|1465|><|1587|><|503|><|1541|><|619|><|197|><|1019|><|817|><|269|><|377|><|362|><|1381|><|507|><|1488|><|4|><|1695|><|code_end|> +#enjoyable<|t_0.49|><|code_start|><|678|><|501|><|864|><|319|><|288|><|1472|><|1341|><|686|><|562|><|1463|><|619|><|1563|><|471|><|911|><|730|><|1811|><|1006|><|520|><|861|><|1274|><|125|><|1431|><|638|><|621|><|153|><|876|><|1770|><|437|><|987|><|1653|><|1109|><|898|><|1285|><|80|><|593|><|1709|><|843|><|code_end|> +#and<|t_0.15|><|code_start|><|1285|><|987|><|303|><|1037|><|730|><|1164|><|502|><|120|><|1737|><|1655|><|1318|><|code_end|> +#it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><|code_end|> +#looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|> +#lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>""" + +# TODO: tokenization is slow for some reason - here is pre-tokenized input +suffix = [ 151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585, 152460, 153375, 151670, 198, 74455, + 155808, 151669, 151799, 151873, 151863, 152446, 152372, 152204, 152728, 152229, 152470, 151970, 153413, + 152419, 153334, 153289, 153374, 153199, 152040, 153260, 152721, 152680, 153297, 152419, 153248, 152400, + 152691, 153368, 153437, 151670, 198, 1722, 155828, 151669, 152607, 152256, 152991, 152299, 152688, 153163, + 153016, 152789, 153198, 152712, 151911, 153107, 152623, 152170, 152395, 152852, 152207, 152461, 153321, + 153309, 151750, 152137, 153340, 152573, 152267, 153347, 151789, 152681, 153339, 151992, 152512, 151751, + 152179, 153434, 153180, 152900, 153440, 152474, 153122, 153129, 151904, 152311, 151670, 198, 1499, 155791, + 151669, 152276, 152454, 153354, 152544, 153204, 153272, 152708, 153433, 152319, 153226, 153043, 152325, + 153267, 152622, 151670, 198, 4250, 155797, 151669, 153454, 153342, 151989, 152458, 153420, 152303, 152271, + 152827, 153036, 153196, 151708, 153263, 152561, 153207, 152213, 152112, 153204, 151722, 152542, 151670, 198, + 19789, 155796, 151669, 153353, 153182, 152345, 152471, 152477, 153014, 152002, 152191, 151734, 152312, 152810, + 152237, 153224, 153169, 153224, 152244, 153387, 153404, 151670, 198, 16069, 155811, 151669, 152265, 151946, + 151808, 152412, 152363, 152305, 153156, 152733, 152810, 153157, 152016, 152100, 152069, 153234, 152317, + 152589, 152707, 153121, 153341, 152159, 152114, 153156, 153001, 153504, 153376, 152272, 152433, 152325, + 151941, 151670, 198, 285, 155788, 151669, 152238, 152255, 153427, 152318, 153009, 152381, 152474, 152680, + 152157, 153255, 152324, 151682, 151670, 198, 32955, 155804, 151669, 153490, 153419, 152364, 152405, 152682, + 152206, 152078, 153369, 152725, 153193, 153027, 152946, 152488, 153070, 151883, 152890, 152489, 153144, + 153375, 152358, 151685, 152494, 152117, 152740, 151670, 198, 37448, 480, 155840, 151669, 151902, 152720, + 153377, 152027, 152378, 152821, 153207, 153459, 153028, 153068, 152507, 153255, 152158, 152921, 151958, + 152609, 152748, 152822, 152286, 151714, 152730, 152377, 152353, 152470, 152606, 152162, 152186, 153071, + 152244, 153118, 153375, 153018, 152712, 153098, 152976, 152336, 151843, 153202, 152297, 151736, 153380, + 153502, 152702, 152115, 153181, 152735, 153277, 153457, 152393, 153112, 152595, 151670, 198, 19098, 155808, + 151669, 152464, 153452, 152595, 153312, 151937, 151933, 153197, 152239, 153163, 152922, 153402, 152034, + 152591, 153438, 152215, 151673, 152005, 151785, 152642, 151924, 153278, 151805, 151974, 153482, 152718, + 152862, 153347, 151670, 198, 72, 155780, 151669, 151795, 152111, 152746, 152377, 153471, 152309, 151670, 198, + 19016, 155788, 151669, 153181, 152271, 152190, 152842, 152224, 152701, 152939, 152536, 152091, 151815, 152733, + 151672, 151670, 198, 14689, 155788, 151669, 152291, 152072, 152942, 151734, 153042, 153504, 152589, 153333, + 151839, 151941, 153038, 153180, 151670, 198, 36996, 8303, 155832, 151669, 152231, 152256, 152835, 152801, + 152985, 153400, 152393, 152818, 152765, 152249, 152600, 151699, 152302, 152752, 153018, 153009, 151992, + 153054, 152847, 153354, 153228, 152662, 153355, 152532, 153393, 151782, 152458, 152048, 152757, 152428, + 153195, 151906, 153006, 153178, 153250, 152331, 152284, 152780, 153138, 153319, 151980, 153142, 152418, + 152228, 152733, 151670, 198, 9096, 155801, 151669, 151698, 153321, 152217, 153039, 152935, 153400, 152122, + 152531, 153106, 152169, 152892, 152957, 151851, 152427, 152826, 152451, 151851, 152901, 152885, 152594, + 153446, 153080, 151670, 198, 14689, 155795, 151669, 152658, 151700, 153321, 152450, 152530, 153191, 151673, + 151690, 151698, 152714, 152846, 152981, 153171, 153384, 153364, 153188, 153246, 151670, 198, 1055, 155779, + 151669, 151869, 152388, 152711, 153334, 151736, 151670, 198, 1782, 155780, 151669, 153483, 153240, 152241, + 152558, 152697, 153046, 151670, 198, 5804, 1363, 155820, 151669, 152941, 152764, 152605, 153034, 153434, + 153372, 153347, 151887, 152453, 152758, 152133, 152510, 152694, 152431, 152321, 153088, 152676, 152223, + 152581, 152459, 152015, 152502, 153063, 152712, 153294, 153451, 153032, 152903, 152859, 152989, 151748, + 152669, 152661, 152650, 152409, 151861, 151670, 198, 300, 7973, 155828, 151669, 153095, 152469, 152988, + 152894, 151819, 152391, 153019, 152058, 153062, 153230, 151826, 152112, 152306, 152264, 152769, 153390, + 152384, 152435, 152790, 153393, 152983, 152540, 152252, 152034, 153107, 152540, 151919, 151893, 152558, + 152817, 152946, 152956, 152129, 152715, 153131, 153490, 151734, 152271, 152707, 151734, 153321, 152450, + 151670, 198, 8088, 155792, 151669, 152452, 153497, 153353, 152679, 152533, 152382, 152374, 152611, 153341, + 153163, 152285, 153411, 152495, 153141, 152320, 151670, 198, 1199, 155781, 151669, 151764, 152360, 153295, + 152634, 153342, 152199, 152271, 151670, 198, 43366, 155799, 151669, 152308, 151682, 152889, 152016, 152385, + 152629, 152495, 151826, 153321, 152958, 152180, 151886, 153432, 152922, 152128, 153024, 153040, 152593, + 152287, 151677, 151670, 198, 53660, 155808, 151669, 151727, 152092, 152680, 153331, 151699, 152316, 152938, + 152289, 152433, 153384, 151781, 153137, 153259, 152175, 153213, 152291, 151869, 152691, 152489, 151941, + 152049, 152034, 153053, 152179, 153160, 151676, 153367, 151670, 198, 268, 4123, 480, 155821, 151669, 152350, + 152173, 152536, 151991, 151960, 153144, 153013, 152358, 152234, 153135, 152291, 153235, 152143, 152583, + 152402, 153483, 152678, 152192, 152533, 152946, 151797, 153103, 152310, 152293, 151825, 152548, 153442, + 152109, 152659, 153325, 152781, 152570, 152957, 151752, 152265, 153381, 152515, 151670, 198, 437, 155787, + 151669, 152957, 152659, 151975, 152709, 152402, 152836, 152174, 151792, 153409, 153327, 152990, 151670, 198, + 275, 155781, 151669, 152520, 153038, 152067, 153273, 153185, 152265, 152974, 151670, 198, 94273, 155799, + 151669, 152953, 152938, 153427, 152244, 151920, 153423, 152929, 152367, 153052, 152129, 152331, 152257, + 152987, 152777, 153448, 152408, 151696, 152408, 152326, 152699, 151670, 198, 385, 16239, 155828, 151669, + 152306, 152268, 153438, 153228, 152978, 152957, 153153, 153393, 152795, 152110, 152918, 152923, 152467, + 152331, 153053, 153330, 151889, 153444, 152234, 152624, 151779, 152801, 152784, 152139, 152222, 152751, + 152512, 153287, 153141, 153052, 151840, 152589, 152508, 153499, 152109, 152255, 151739, 152267, 152759, + 153318, 153165, 153349, 151670, ] + +response = requests.post( + host_llm + "/completion", + json={ + "prompt": [prefix + words, *suffix], + "n_predict": 1024, + "cache_prompt": True, + "return_tokens": True, + "samplers": ["top_k"], + "top_k": 16, + "seed": 1003, + } +) + +response_json = response.json() + +#print(json.dumps(response_json, indent=4)) +#print(json.dumps(response_json["prompt"], indent=4).replace("\\n", "\n")) +#print(json.dumps(response_json["timings"], indent=4)) +#print(json.dumps(response_json["tokens"], indent=4)) + +codes = response_json["tokens"] + +codes = [t - 151672 for t in codes if t >= 151672 and t <= 155772] + +response = requests.post( + host_dec + "/embeddings", + json={ + "input": [*codes], + } +) + +response_json = response.json() + +#print(json.dumps(response_json, indent=4)) + +# spectrogram +embd = response_json[0]["embedding"] + +n_codes = len(embd) +n_embd = len(embd[0]) + +print('spectrogram generated: n_codes: %d, n_embd: %d' % (n_codes, n_embd)) + +# post-process the spectrogram to convert to audio +print('converting to audio ...') +audio = embd_to_audio(embd, n_codes, n_embd) +print('audio generated: %d samples' % len(audio)) + +filename = "output.wav" +sample_rate = 24000 # sampling rate + +# zero out first 0.25 seconds +audio[:24000 // 4] = 0.0 + +save_wav(filename, audio, sample_rate) +print('audio written to file "%s"' % filename) diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp new file mode 100644 index 000000000..f78f76303 --- /dev/null +++ b/examples/tts/tts.cpp @@ -0,0 +1,973 @@ +#include "arg.h" +#include "common.h" +#include "sampling.h" +#include "log.h" +#include "llama.h" + +#define _USE_MATH_DEFINES // For M_PI on MSVC + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// +// Terminal utils +// + +#define SQR(X) ((X) * (X)) +#define UNCUBE(x) x < 48 ? 0 : x < 115 ? 1 : (x - 35) / 40 + +/** + * Quantizes 24-bit RGB to xterm256 code range [16,256). + */ +static int rgb2xterm256(int r, int g, int b) { + unsigned char cube[] = {0, 0137, 0207, 0257, 0327, 0377}; + int av, ir, ig, ib, il, qr, qg, qb, ql; + av = r * .299 + g * .587 + b * .114 + .5; + ql = (il = av > 238 ? 23 : (av - 3) / 10) * 10 + 8; + qr = cube[(ir = UNCUBE(r))]; + qg = cube[(ig = UNCUBE(g))]; + qb = cube[(ib = UNCUBE(b))]; + if (SQR(qr - r) + SQR(qg - g) + SQR(qb - b) <= + SQR(ql - r) + SQR(ql - g) + SQR(ql - b)) + return ir * 36 + ig * 6 + ib + 020; + return il + 0350; +} + +static std::string set_xterm256_foreground(int r, int g, int b) { + int x = rgb2xterm256(r, g, b); + std::ostringstream oss; + oss << "\033[38;5;" << x << "m"; + return oss.str(); +} + +const std::vector k_colors = { + set_xterm256_foreground(220, 5, 12), + set_xterm256_foreground(232, 96, 28), + set_xterm256_foreground(241, 147, 45), + set_xterm256_foreground(246, 193, 65), + set_xterm256_foreground(247, 240, 86), + set_xterm256_foreground(144, 201, 135), + set_xterm256_foreground( 78, 178, 101), +}; + +static void print_usage(int, char ** argv) { + LOG("\nexample usage:\n"); + LOG("\n %s -m model.gguf -p \"Hello!\"\n", argv[0]); + LOG("\n"); +} + +struct wav_header { + char riff[4] = {'R', 'I', 'F', 'F'}; + uint32_t chunk_size; + char wave[4] = {'W', 'A', 'V', 'E'}; + char fmt[4] = {'f', 'm', 't', ' '}; + uint32_t fmt_chunk_size = 16; + uint16_t audio_format = 1; // PCM + uint16_t num_channels = 1; // Mono + uint32_t sample_rate; + uint32_t byte_rate; + uint16_t block_align; + uint16_t bits_per_sample = 16; + char data[4] = {'d', 'a', 't', 'a'}; + uint32_t data_size; +}; + +static void save_wav16(const std::string & fname, const std::vector & data, int sample_rate) { + std::ofstream file(fname, std::ios::binary); + if (!file) { + LOG_ERR("%s: Failed to open file '%s' for writing", __func__, fname.c_str()); + return; + } + + wav_header header; + header.sample_rate = sample_rate; + header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8); + header.block_align = header.num_channels * (header.bits_per_sample / 8); + header.data_size = data.size() * (header.bits_per_sample / 8); + header.chunk_size = 36 + header.data_size; + + file.write(reinterpret_cast(&header), sizeof(header)); + + for (const auto & sample : data) { + int16_t pcm_sample = static_cast(std::clamp(sample * 32767.0, -32768.0, 32767.0)); + file.write(reinterpret_cast(&pcm_sample), sizeof(pcm_sample)); + } + + file.close(); +} + +static void fill_hann_window(int length, bool periodic, float * output) { + int offset = -1; + if (periodic) { + offset = 0; + } + for (int i = 0; i < length; i++) { + output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + } +} + +// very poor-man fft +static void twiddle(float * real, float * imag, int k, int N) { + float angle = 2 * M_PI * k / N; + *real = cos(angle); + *imag = sin(angle); +} + +static void irfft(int n, const float * inp_cplx, float * out_real) { + int N = n / 2 + 1; + + std::vector real_input(N); + std::vector imag_input(N); + for (int i = 0; i < N; ++i) { + real_input[i] = inp_cplx[2 * i]; + imag_input[i] = inp_cplx[2 * i + 1]; + } + + std::vector real_output(n); + std::vector imag_output(n); + + for (int k = 0; k < n; ++k) { + real_output[k] = 0.0f; + imag_output[k] = 0.0f; + for (int m = 0; m < N; ++m) { + float twiddle_real; + float twiddle_imag; + + twiddle(&twiddle_real, &twiddle_imag, k * m, n); + + real_output[k] += real_input[m] * twiddle_real - imag_input[m] * twiddle_imag; + imag_output[k] += real_input[m] * twiddle_imag + imag_input[m] * twiddle_real; + } + } + + for (int i = 0; i < n; ++i) { + out_real[i] = real_output[i] / N; + } +} + +// +// y = torch.nn.functional.fold( +// data, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), +// )[:, 0, 0, pad:-pad] +// +// data.shape = torch.Size([1, 1280, 261]) +// output_size = 84480 +// win_length = 1280 +// hop_length = 320 +// pad = 480 +// +static void fold(const std::vector & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector & output) { + int64_t output_height = n_out; + int64_t kernel_w = n_win; + int64_t stride_w = n_hop; + int64_t width = n_out; + + output.resize(width, 0.0f); + + int64_t col_idx = 0; + for (int64_t w_col = 0; w_col < width; ++w_col) { + int64_t start = w_col * stride_w - n_pad; + int64_t end = start + kernel_w; + + for (int64_t w_im = start; w_im < end; ++w_im) { + if (w_im >= 0 && w_im < output_height && col_idx < (int64_t) data.size()) { + output[w_im] += data[col_idx]; + } + col_idx++; + } + } + + output.resize(n_out - 2 * n_pad); +} + +// TODO: not optimized at all +static std::vector embd_to_audio( + const float * embd, + const int n_codes, + const int n_embd, + const int n_thread) { + const int n_fft = 1280; + const int n_hop = 320; + const int n_win = 1280; + const int n_pad = (n_win - n_hop)/2; + const int n_out = (n_codes - 1)*n_hop + n_win; + + std::vector hann(n_fft); + + fill_hann_window(hann.size(), true, hann.data()); + + int n_spec = n_embd*n_codes; + + std::vector E (n_spec); + std::vector S (n_spec); + std::vector ST(n_spec); + + for (int l = 0; l < n_codes; ++l) { + for (int k = 0; k < n_embd; ++k) { + E[k*n_codes + l] = embd[l*n_embd + k]; + } + } + + for (int k = 0; k < n_embd/2; ++k) { + for (int l = 0; l < n_codes; ++l) { + float mag = E[(k )*n_codes + l]; + float phi = E[(k + n_embd/2)*n_codes + l]; + + mag = exp(mag); + + if (mag > 1e2) { + mag = 1e2; + } + S[2*(k*n_codes + l) + 0] = mag*cosf(phi); + S[2*(k*n_codes + l) + 1] = mag*sinf(phi); + } + } + + for (int l = 0; l < n_codes; ++l) { + for (int k = 0; k < n_embd/2; ++k) { + ST[l*n_embd + 2*k + 0] = S[2*(k*n_codes + l) + 0]; + ST[l*n_embd + 2*k + 1] = S[2*(k*n_codes + l) + 1]; + } + } + + std::vector res (n_codes*n_fft); + std::vector hann2(n_codes*n_fft); + + std::vector workers(n_thread); + for (int i = 0; i < n_thread; ++i) { + workers[i] = std::thread([&, i]() { + for (int l = i; l < n_codes; l += n_thread) { + irfft(n_fft, ST.data() + l*n_embd, res.data() + l*n_fft); + for (int j = 0; j < n_fft; ++j) { + res [l*n_fft + j] *= hann[j]; + hann2[l*n_fft + j] = hann[j] * hann[j]; + } + } + }); + } + for (int i = 0; i < n_thread; ++i) { + workers[i].join(); + } + + std::vector audio; + std::vector env; + + fold(res, n_out, n_win, n_hop, n_pad, audio); + fold(hann2, n_out, n_win, n_hop, n_pad, env); // TODO: can be done once + + for (size_t i = 0; i < audio.size(); ++i) { + audio[i] /= env[i]; + } + + return audio; +} + +static const std::map ones = { + {0, "zero"}, {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}, + {5, "five"}, {6, "six"}, {7, "seven"}, {8, "eight"}, {9, "nine"}, + {10, "ten"}, {11, "eleven"}, {12, "twelve"}, {13, "thirteen"}, {14, "fourteen"}, + {15, "fifteen"}, {16, "sixteen"}, {17, "seventeen"}, {18, "eighteen"}, {19, "nineteen"} +}; + +static const std::map tens = { + {2, "twenty"}, {3, "thirty"}, {4, "forty"}, {5, "fifty"}, + {6, "sixty"}, {7, "seventy"}, {8, "eighty"}, {9, "ninety"} +}; + +// Convert a number less than 1000 to words +static std::string convert_less_than_thousand(int num) { + std::string result; + + if (num >= 100) { + result += ones.at(num / 100) + " hundred "; + num %= 100; + } + + if (num >= 20) { + result += tens.at(num / 10); + if (num % 10 > 0) { + result += "-" + ones.at(num % 10); + } + } else if (num > 0) { + result += ones.at(num); + } + + return result; +} + +static std::string number_to_words(const std::string & number_str) { + try { + size_t decimal_pos = number_str.find('.'); + std::string integer_part = number_str.substr(0, decimal_pos); + + int int_number = std::stoi(integer_part); + std::string result; + + if (int_number == 0) { + result = "zero"; + } else { + if (int_number >= 1000000000) { + int billions = int_number / 1000000000; + result += convert_less_than_thousand(billions) + " billion "; + int_number %= 1000000000; + } + + if (int_number >= 1000000) { + int millions = int_number / 1000000; + result += convert_less_than_thousand(millions) + " million "; + int_number %= 1000000; + } + + if (int_number >= 1000) { + int thousands = int_number / 1000; + result += convert_less_than_thousand(thousands) + " thousand "; + int_number %= 1000; + } + + if (int_number > 0) { + result += convert_less_than_thousand(int_number); + } + } + + // Handle decimal part + if (decimal_pos != std::string::npos) { + result += " point"; + std::string decimal_part = number_str.substr(decimal_pos + 1); + for (char digit : decimal_part) { + result += " " + ones.at(digit - '0'); + } + } + + return result; + } catch (const std::exception& e) { + // Skip if fails + return " "; + } +} + +static std::string replace_numbers_with_words(const std::string & input_text) { + std::regex number_pattern(R"(\d+(\.\d+)?)"); + std::string result; + auto it = std::sregex_iterator(input_text.begin(), input_text.end(), number_pattern); + auto end = std::sregex_iterator(); + + size_t last_pos = 0; + for (std::sregex_iterator i = it; i != end; ++i) { + const std::smatch& match = *i; + result.append(input_text, last_pos, match.position() - last_pos); + result.append(number_to_words(match.str())); + last_pos = match.position() + match.length(); + } + result.append(input_text, last_pos); + + return result; +} + +// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39 +static std::string process_text(const std::string & text) { + + // For now I skipped text romanization as I am unsure how to handle + // uroman and MeCab implementations in C++ + // maybe something like https://github.com/anyascii/anyascii/ could work. + // currently only English would be supported in this function + + std::string processed_text = replace_numbers_with_words(text); + + std::transform(processed_text.begin(), processed_text.end(), + processed_text.begin(), ::tolower); + + std::regex special_chars(R"([-_/,\.\\])"); + processed_text = std::regex_replace(processed_text, special_chars, " "); + + std::regex non_alpha(R"([^a-z\s])"); + processed_text = std::regex_replace(processed_text, non_alpha, ""); + + std::regex multiple_spaces(R"(\s+)"); + processed_text = std::regex_replace(processed_text, multiple_spaces, " "); + + processed_text = std::regex_replace(processed_text, std::regex(R"(^\s+|\s+$)"), ""); + + /* + Replace spaces with the separator token same as in line 365 + + for (auto & c : prompt_user) { + if (c == ' ') { + prompt_clean += "<|text_sep|>"; + */ + processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>"); + + return processed_text; +} + +static void prompt_add(llama_tokens & prompt, llama_token token) { + prompt.push_back(token); +} + +static void prompt_add(llama_tokens & prompt, const llama_tokens & tokens) { + prompt.insert(prompt.end(), tokens.begin(), tokens.end()); +} + +static void prompt_add(llama_tokens & prompt, const llama_vocab * vocab, const std::string & txt, bool add_special, bool parse_special) { + auto tmp = common_tokenize(vocab, txt, add_special, parse_special); + prompt_add(prompt, tmp); +} + +static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) { + prompt.clear(); + + prompt_add(prompt, vocab, "<|im_start|>\n", true, true); +} + +static std::vector prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) { + const std::string& delimiter = "<|text_sep|>"; + + std::vector result; + size_t start = 0; + size_t end = str.find(delimiter); + + //first token is always a newline, as it was not previously added + result.push_back(common_tokenize(vocab, "\n", false, true)[0]); + + while (end != std::string::npos) { + std::string current_word = str.substr(start, end - start); + auto tmp = common_tokenize(vocab, current_word, false, true); + result.push_back(tmp[0]); + start = end + delimiter.length(); + end = str.find(delimiter, start); + } + + // Add the last part + std::string current_word = str.substr(start); + auto tmp = common_tokenize(vocab, current_word, false, true); + if (tmp.size() > 0) { + result.push_back(tmp[0]); + } + return result; +} + +int main(int argc, char ** argv) { + common_params params; + + params.prompt = ""; + + params.n_predict = 4096; + params.n_batch = 8192; + params.n_ctx = 8192; + + params.sampling.top_k = 4; + params.sampling.samplers = { COMMON_SAMPLER_TYPE_TOP_K, }; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) { + return 1; + } + + const int n_parallel = params.n_parallel; + const int n_predict = params.n_predict; + + common_init(); + + // init LLM + + llama_backend_init(); + llama_numa_init(params.numa); + + llama_model * model_ttc = NULL; // text-to-codes + llama_model * model_cts = NULL; // codes-to-speech + + llama_context * ctx_ttc = NULL; + llama_context * ctx_cts = NULL; + + common_init_result llama_init_ttc = common_init_from_params(params); + + model_ttc = llama_init_ttc.model.get(); + ctx_ttc = llama_init_ttc.context.get(); + + const llama_vocab * vocab = llama_model_get_vocab(model_ttc); + + // TODO: refactor in a common struct + params.model = params.vocoder.model; + params.model_url = params.vocoder.model_url; + params.hf_repo = params.vocoder.hf_repo; + params.hf_file = params.vocoder.hf_file; + + params.embedding = true; + + common_init_result llama_init_cts = common_init_from_params(params); + + model_cts = llama_init_cts.model.get(); + ctx_cts = llama_init_cts.context.get(); + + std::vector smpl(n_parallel); + for (int i = 0; i < n_parallel; ++i) { + params.sampling.no_perf = (i != 0); + params.sampling.seed = params.sampling.seed + 1; + + smpl[i] = common_sampler_init(model_ttc, params.sampling); + } + + LOG_INF("sampler seed: %u\n", common_sampler_get_seed(smpl[0])); + LOG_INF("sampler params: \n%s\n", params.sampling.print().c_str()); + LOG_INF("sampler chain: %s\n", common_sampler_print(smpl[0]).c_str()); + + LOG_INF("%s: loading done\n", __func__); + + const auto t_main_start = ggml_time_us(); + + std::vector codes; + std::vector guide_tokens; + + // process prompt and generate voice codes + { + LOG_INF("%s: constructing prompt ..\n", __func__); + + std::vector prompt_inp; + + prompt_init(prompt_inp, vocab); + + prompt_add(prompt_inp, vocab, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true); + + // convert the input text into the necessary format expected by OuteTTS + { + std::string prompt_clean = process_text(params.prompt); + if (params.vocoder.use_guide_tokens) { + guide_tokens = prepare_guide_tokens(vocab, prompt_clean); + } + + LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str()); + + prompt_add(prompt_inp, vocab, prompt_clean, false, true); + } + + prompt_add(prompt_inp, vocab, "<|text_end|>\n", false, true); + + // disabled to save time on tokenizing each time + // TODO: load voices from the json files +#if 0 + const std::string voice_data = R"(<|audio_start|> +the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|> +overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|> +package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|> +from<|t_0.19|><|code_start|><|604|><|782|><|1682|><|872|><|1532|><|1600|><|1036|><|1761|><|647|><|1554|><|1371|><|653|><|1595|><|950|><|code_end|> +just<|t_0.25|><|code_start|><|1782|><|1670|><|317|><|786|><|1748|><|631|><|599|><|1155|><|1364|><|1524|><|36|><|1591|><|889|><|1535|><|541|><|440|><|1532|><|50|><|870|><|code_end|> +two<|t_0.24|><|code_start|><|1681|><|1510|><|673|><|799|><|805|><|1342|><|330|><|519|><|62|><|640|><|1138|><|565|><|1552|><|1497|><|1552|><|572|><|1715|><|1732|><|code_end|> +people<|t_0.39|><|code_start|><|593|><|274|><|136|><|740|><|691|><|633|><|1484|><|1061|><|1138|><|1485|><|344|><|428|><|397|><|1562|><|645|><|917|><|1035|><|1449|><|1669|><|487|><|442|><|1484|><|1329|><|1832|><|1704|><|600|><|761|><|653|><|269|><|code_end|> +is<|t_0.16|><|code_start|><|566|><|583|><|1755|><|646|><|1337|><|709|><|802|><|1008|><|485|><|1583|><|652|><|10|><|code_end|> +pretty<|t_0.32|><|code_start|><|1818|><|1747|><|692|><|733|><|1010|><|534|><|406|><|1697|><|1053|><|1521|><|1355|><|1274|><|816|><|1398|><|211|><|1218|><|817|><|1472|><|1703|><|686|><|13|><|822|><|445|><|1068|><|code_end|> +remarkable<|t_0.68|><|code_start|><|230|><|1048|><|1705|><|355|><|706|><|1149|><|1535|><|1787|><|1356|><|1396|><|835|><|1583|><|486|><|1249|><|286|><|937|><|1076|><|1150|><|614|><|42|><|1058|><|705|><|681|><|798|><|934|><|490|><|514|><|1399|><|572|><|1446|><|1703|><|1346|><|1040|><|1426|><|1304|><|664|><|171|><|1530|><|625|><|64|><|1708|><|1830|><|1030|><|443|><|1509|><|1063|><|1605|><|1785|><|721|><|1440|><|923|><|code_end|> +sure<|t_0.36|><|code_start|><|792|><|1780|><|923|><|1640|><|265|><|261|><|1525|><|567|><|1491|><|1250|><|1730|><|362|><|919|><|1766|><|543|><|1|><|333|><|113|><|970|><|252|><|1606|><|133|><|302|><|1810|><|1046|><|1190|><|1675|><|code_end|> +i<|t_0.08|><|code_start|><|123|><|439|><|1074|><|705|><|1799|><|637|><|code_end|> +have<|t_0.16|><|code_start|><|1509|><|599|><|518|><|1170|><|552|><|1029|><|1267|><|864|><|419|><|143|><|1061|><|0|><|code_end|> +some<|t_0.16|><|code_start|><|619|><|400|><|1270|><|62|><|1370|><|1832|><|917|><|1661|><|167|><|269|><|1366|><|1508|><|code_end|> +critiques<|t_0.60|><|code_start|><|559|><|584|><|1163|><|1129|><|1313|><|1728|><|721|><|1146|><|1093|><|577|><|928|><|27|><|630|><|1080|><|1346|><|1337|><|320|><|1382|><|1175|><|1682|><|1556|><|990|><|1683|><|860|><|1721|><|110|><|786|><|376|><|1085|><|756|><|1523|><|234|><|1334|><|1506|><|1578|><|659|><|612|><|1108|><|1466|><|1647|><|308|><|1470|><|746|><|556|><|1061|><|code_end|> +about<|t_0.29|><|code_start|><|26|><|1649|><|545|><|1367|><|1263|><|1728|><|450|><|859|><|1434|><|497|><|1220|><|1285|><|179|><|755|><|1154|><|779|><|179|><|1229|><|1213|><|922|><|1774|><|1408|><|code_end|> +some<|t_0.23|><|code_start|><|986|><|28|><|1649|><|778|><|858|><|1519|><|1|><|18|><|26|><|1042|><|1174|><|1309|><|1499|><|1712|><|1692|><|1516|><|1574|><|code_end|> +of<|t_0.07|><|code_start|><|197|><|716|><|1039|><|1662|><|64|><|code_end|> +the<|t_0.08|><|code_start|><|1811|><|1568|><|569|><|886|><|1025|><|1374|><|code_end|> +gameplay<|t_0.48|><|code_start|><|1269|><|1092|><|933|><|1362|><|1762|><|1700|><|1675|><|215|><|781|><|1086|><|461|><|838|><|1022|><|759|><|649|><|1416|><|1004|><|551|><|909|><|787|><|343|><|830|><|1391|><|1040|><|1622|><|1779|><|1360|><|1231|><|1187|><|1317|><|76|><|997|><|989|><|978|><|737|><|189|><|code_end|> +aspects<|t_0.56|><|code_start|><|1423|><|797|><|1316|><|1222|><|147|><|719|><|1347|><|386|><|1390|><|1558|><|154|><|440|><|634|><|592|><|1097|><|1718|><|712|><|763|><|1118|><|1721|><|1311|><|868|><|580|><|362|><|1435|><|868|><|247|><|221|><|886|><|1145|><|1274|><|1284|><|457|><|1043|><|1459|><|1818|><|62|><|599|><|1035|><|62|><|1649|><|778|><|code_end|> +but<|t_0.20|><|code_start|><|780|><|1825|><|1681|><|1007|><|861|><|710|><|702|><|939|><|1669|><|1491|><|613|><|1739|><|823|><|1469|><|648|><|code_end|> +its<|t_0.09|><|code_start|><|92|><|688|><|1623|><|962|><|1670|><|527|><|599|><|code_end|> +still<|t_0.27|><|code_start|><|636|><|10|><|1217|><|344|><|713|><|957|><|823|><|154|><|1649|><|1286|><|508|><|214|><|1760|><|1250|><|456|><|1352|><|1368|><|921|><|615|><|5|><|code_end|> +really<|t_0.36|><|code_start|><|55|><|420|><|1008|><|1659|><|27|><|644|><|1266|><|617|><|761|><|1712|><|109|><|1465|><|1587|><|503|><|1541|><|619|><|197|><|1019|><|817|><|269|><|377|><|362|><|1381|><|507|><|1488|><|4|><|1695|><|code_end|> +enjoyable<|t_0.49|><|code_start|><|678|><|501|><|864|><|319|><|288|><|1472|><|1341|><|686|><|562|><|1463|><|619|><|1563|><|471|><|911|><|730|><|1811|><|1006|><|520|><|861|><|1274|><|125|><|1431|><|638|><|621|><|153|><|876|><|1770|><|437|><|987|><|1653|><|1109|><|898|><|1285|><|80|><|593|><|1709|><|843|><|code_end|> +and<|t_0.15|><|code_start|><|1285|><|987|><|303|><|1037|><|730|><|1164|><|502|><|120|><|1737|><|1655|><|1318|><|code_end|> +it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><|code_end|> +looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|> +lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)"; + + auto tmp = common_tokenize(vocab, voice_data, false, true); + printf("\n\n"); + for (int i = 0; i < tmp.size(); ++i) { + printf("%d, ", tmp[i]); + } + printf("\n\n"); +#else + prompt_add(prompt_inp, llama_tokens { + 151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585, + 152460, 153375, 151670, 198, 74455, 155808, 151669, 151799, + 151873, 151863, 152446, 152372, 152204, 152728, 152229, 152470, + 151970, 153413, 152419, 153334, 153289, 153374, 153199, 152040, + 153260, 152721, 152680, 153297, 152419, 153248, 152400, 152691, + 153368, 153437, 151670, 198, 1722, 155828, 151669, 152607, + 152256, 152991, 152299, 152688, 153163, 153016, 152789, 153198, + 152712, 151911, 153107, 152623, 152170, 152395, 152852, 152207, + 152461, 153321, 153309, 151750, 152137, 153340, 152573, 152267, + 153347, 151789, 152681, 153339, 151992, 152512, 151751, 152179, + 153434, 153180, 152900, 153440, 152474, 153122, 153129, 151904, + 152311, 151670, 198, 1499, 155791, 151669, 152276, 152454, + 153354, 152544, 153204, 153272, 152708, 153433, 152319, 153226, + 153043, 152325, 153267, 152622, 151670, 198, 4250, 155797, + 151669, 153454, 153342, 151989, 152458, 153420, 152303, 152271, + 152827, 153036, 153196, 151708, 153263, 152561, 153207, 152213, + 152112, 153204, 151722, 152542, 151670, 198, 19789, 155796, + 151669, 153353, 153182, 152345, 152471, 152477, 153014, 152002, + 152191, 151734, 152312, 152810, 152237, 153224, 153169, 153224, + 152244, 153387, 153404, 151670, 198, 16069, 155811, 151669, + 152265, 151946, 151808, 152412, 152363, 152305, 153156, 152733, + 152810, 153157, 152016, 152100, 152069, 153234, 152317, 152589, + 152707, 153121, 153341, 152159, 152114, 153156, 153001, 153504, + 153376, 152272, 152433, 152325, 151941, 151670, 198, 285, + 155788, 151669, 152238, 152255, 153427, 152318, 153009, 152381, + 152474, 152680, 152157, 153255, 152324, 151682, 151670, 198, + 32955, 155804, 151669, 153490, 153419, 152364, 152405, 152682, + 152206, 152078, 153369, 152725, 153193, 153027, 152946, 152488, + 153070, 151883, 152890, 152489, 153144, 153375, 152358, 151685, + 152494, 152117, 152740, 151670, 198, 37448, 480, 155840, 151669, + 151902, 152720, 153377, 152027, 152378, 152821, 153207, 153459, + 153028, 153068, 152507, 153255, 152158, 152921, 151958, 152609, + 152748, 152822, 152286, 151714, 152730, 152377, 152353, 152470, + 152606, 152162, 152186, 153071, 152244, 153118, 153375, 153018, + 152712, 153098, 152976, 152336, 151843, 153202, 152297, 151736, + 153380, 153502, 152702, 152115, 153181, 152735, 153277, 153457, + 152393, 153112, 152595, 151670, 198, 19098, 155808, 151669, + 152464, 153452, 152595, 153312, 151937, 151933, 153197, 152239, + 153163, 152922, 153402, 152034, 152591, 153438, 152215, 151673, + 152005, 151785, 152642, 151924, 153278, 151805, 151974, 153482, + 152718, 152862, 153347, 151670, 198, 72, 155780, 151669, 151795, + 152111, 152746, 152377, 153471, 152309, 151670, 198, 19016, + 155788, 151669, 153181, 152271, 152190, 152842, 152224, 152701, + 152939, 152536, 152091, 151815, 152733, 151672, 151670, 198, + 14689, 155788, 151669, 152291, 152072, 152942, 151734, 153042, + 153504, 152589, 153333, 151839, 151941, 153038, 153180, 151670, + 198, 36996, 8303, 155832, 151669, 152231, 152256, 152835, + 152801, 152985, 153400, 152393, 152818, 152765, 152249, 152600, + 151699, 152302, 152752, 153018, 153009, 151992, 153054, 152847, + 153354, 153228, 152662, 153355, 152532, 153393, 151782, 152458, + 152048, 152757, 152428, 153195, 151906, 153006, 153178, 153250, + 152331, 152284, 152780, 153138, 153319, 151980, 153142, 152418, + 152228, 152733, 151670, 198, 9096, 155801, 151669, 151698, + 153321, 152217, 153039, 152935, 153400, 152122, 152531, 153106, + 152169, 152892, 152957, 151851, 152427, 152826, 152451, 151851, + 152901, 152885, 152594, 153446, 153080, 151670, 198, 14689, + 155795, 151669, 152658, 151700, 153321, 152450, 152530, 153191, + 151673, 151690, 151698, 152714, 152846, 152981, 153171, 153384, + 153364, 153188, 153246, 151670, 198, 1055, 155779, 151669, + 151869, 152388, 152711, 153334, 151736, 151670, 198, 1782, + 155780, 151669, 153483, 153240, 152241, 152558, 152697, 153046, + 151670, 198, 5804, 1363, 155820, 151669, 152941, 152764, 152605, + 153034, 153434, 153372, 153347, 151887, 152453, 152758, 152133, + 152510, 152694, 152431, 152321, 153088, 152676, 152223, 152581, + 152459, 152015, 152502, 153063, 152712, 153294, 153451, 153032, + 152903, 152859, 152989, 151748, 152669, 152661, 152650, 152409, + 151861, 151670, 198, 300, 7973, 155828, 151669, 153095, 152469, + 152988, 152894, 151819, 152391, 153019, 152058, 153062, 153230, + 151826, 152112, 152306, 152264, 152769, 153390, 152384, 152435, + 152790, 153393, 152983, 152540, 152252, 152034, 153107, 152540, + 151919, 151893, 152558, 152817, 152946, 152956, 152129, 152715, + 153131, 153490, 151734, 152271, 152707, 151734, 153321, 152450, + 151670, 198, 8088, 155792, 151669, 152452, 153497, 153353, + 152679, 152533, 152382, 152374, 152611, 153341, 153163, 152285, + 153411, 152495, 153141, 152320, 151670, 198, 1199, 155781, + 151669, 151764, 152360, 153295, 152634, 153342, 152199, 152271, + 151670, 198, 43366, 155799, 151669, 152308, 151682, 152889, + 152016, 152385, 152629, 152495, 151826, 153321, 152958, 152180, + 151886, 153432, 152922, 152128, 153024, 153040, 152593, 152287, + 151677, 151670, 198, 53660, 155808, 151669, 151727, 152092, + 152680, 153331, 151699, 152316, 152938, 152289, 152433, 153384, + 151781, 153137, 153259, 152175, 153213, 152291, 151869, 152691, + 152489, 151941, 152049, 152034, 153053, 152179, 153160, 151676, + 153367, 151670, 198, 268, 4123, 480, 155821, 151669, 152350, + 152173, 152536, 151991, 151960, 153144, 153013, 152358, 152234, + 153135, 152291, 153235, 152143, 152583, 152402, 153483, 152678, + 152192, 152533, 152946, 151797, 153103, 152310, 152293, 151825, + 152548, 153442, 152109, 152659, 153325, 152781, 152570, 152957, + 151752, 152265, 153381, 152515, 151670, 198, 437, 155787, + 151669, 152957, 152659, 151975, 152709, 152402, 152836, 152174, + 151792, 153409, 153327, 152990, 151670, 198, 275, 155781, + 151669, 152520, 153038, 152067, 153273, 153185, 152265, 152974, + 151670, 198, 94273, 155799, 151669, 152953, 152938, 153427, + 152244, 151920, 153423, 152929, 152367, 153052, 152129, 152331, + 152257, 152987, 152777, 153448, 152408, 151696, 152408, 152326, + 152699, 151670, 198, 385, 16239, 155828, 151669, 152306, 152268, + 153438, 153228, 152978, 152957, 153153, 153393, 152795, 152110, + 152918, 152923, 152467, 152331, 153053, 153330, 151889, 153444, + 152234, 152624, 151779, 152801, 152784, 152139, 152222, 152751, + 152512, 153287, 153141, 153052, 151840, 152589, 152508, 153499, + 152109, 152255, 151739, 152267, 152759, 153318, 153165, 153349, + 151670,}); +#endif + + // print the prompt token-by-token + + LOG("\n"); + + for (auto id : prompt_inp) { + LOG("%s", common_token_to_piece(ctx_ttc, id).c_str()); + } + + LOG_INF("%s: prompt size: %d\n", __func__, (int) prompt_inp.size()); + + LOG("\n"); + + // create a llama_batch + // we use this object to submit token data for decoding + llama_batch batch = llama_batch_init(std::max(prompt_inp.size(), (size_t) n_parallel), 0, n_parallel); + + std::vector seq_ids(n_parallel, 0); + for (int32_t i = 0; i < n_parallel; ++i) { + seq_ids[i] = i; + } + + // evaluate the initial prompt + for (size_t i = 0; i < prompt_inp.size(); ++i) { + common_batch_add(batch, prompt_inp[i], i, seq_ids, false); + } + GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size()); + + // llama_decode will output logits only for the last token of the prompt + batch.logits[batch.n_tokens - 1] = true; + + if (llama_decode(ctx_ttc, batch) != 0) { + LOG_ERR("%s: llama_decode() failed\n", __func__); + return 1; + } + + if (n_parallel > 1) { + LOG_INF("\n\n%s: generating %d sequences ...\n", __func__, n_parallel); + } + + llama_synchronize(ctx_ttc); + + LOG_INF("%s: time for prompt: %.3f ms\n\n", __func__, (ggml_time_us() - t_main_start) / 1000.0f); + + const auto t_dec_start = ggml_time_us(); + + // main loop + + // remember the batch index of the last token for each parallel sequence + // we need this to determine which logits to sample from + std::vector i_batch(n_parallel, batch.n_tokens - 1); + + int n_past = batch.n_tokens; + int n_decode = 0; + + bool next_token_uses_guide_token = true; + + while (n_decode <= n_predict) { + // prepare the next batch + common_batch_clear(batch); + + // sample the next token for each parallel sequence / stream + for (int32_t i = 0; i < n_parallel; ++i) { + if (i_batch[i] < 0) { + // the stream has already finished + continue; + } + + llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]); + + //guide tokens help prevent hallucinations by forcing the TTS to use the correct word + if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_vocab_is_control(vocab, new_token_id) && !llama_vocab_is_eog(vocab, new_token_id)) { + llama_token guide_token = guide_tokens[0]; + guide_tokens.erase(guide_tokens.begin()); + new_token_id = guide_token; //ensure correct word fragment is used + } + + //this is the token id that always precedes a new word + next_token_uses_guide_token = (new_token_id == 198); + + common_sampler_accept(smpl[i], new_token_id, true); + + codes.push_back(new_token_id); + + const auto * cands = common_sampler_get_candidates(smpl[i]); + + // is it an end of generation? -> mark the stream as finished + if (llama_vocab_is_eog(vocab, new_token_id) || n_decode == n_predict) { + std::string reason; + if (llama_vocab_is_eog(vocab, new_token_id)) { + reason = "eos"; + } else { + reason = "n_predict"; + } + + i_batch[i] = -1; + + LOG("\n"); + if (n_parallel > 1) { + LOG_CNT("\n"); + LOG_INF("%s: stream %d finished at n_past = %d, reason = '%s'\n", __func__, i, n_past, reason.c_str()); + } + + continue; + } + + { + const float p = cands->data[cands->selected].p; + + const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) ((3*p)*float(k_colors.size())))); + + LOG_CNT("%s%d%s", k_colors[col].c_str(), i, "\033[0m"); + //LOG_CNT("%d", i); + } + + i_batch[i] = batch.n_tokens; + + // push this new token for next evaluation + common_batch_add(batch, new_token_id, n_past, { i }, true); + } + + // all streams are finished + if (batch.n_tokens == 0) { + break; + } + + n_decode += 1; + n_past += 1; + + // evaluate the current batch with the transformer model + if (llama_decode(ctx_ttc, batch)) { + LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); + return 1; + } + } + + llama_batch_free(batch); + + LOG("\n"); + LOG_INF("%s: time for decoder: %.3f ms\n", __func__, (ggml_time_us() - t_dec_start) / 1000.0f); + } + + common_perf_print(ctx_ttc, smpl[0]); + + //std::vector codes = {198, 88225, 155856, 151669, 152205, + // 153064, 152537, 153421, 153209, 152524, 151689, 152993, 152438, 152695, + // 153091, 152945, 152829, 152534, 152934, 153020, 151997, 152263, 153010, + // 153146, 152399, 153208, 152496, 151793, 152848, 152263, 152571, 153286, + // 152227, 153300, 152934, 152263, 153208, 152263, 152965, 152430, 152296, + // 153146, 152920, 152376, 152556, 153363, 151775, 152044, 152972, 152690, + // 153379, 152368, 152233, 153422, 152490, 151996, 152022, 151694, 152061, + // 153238, 152539, 153356, 152640, 153021, 153123, 151962, 153094, 151670, + // 198, 20339, 13189, 155824, 151669, 152070, 152007, 152910, 151683, + // 152000, 152373, 152760, 152046, 151735, 152334, 152394, 153073, 152908, + // 151856, 151953, 153247, 153293, 151903, 153480, 153168, 152478, 153359, + // 153429, 151905, 151678, 152567, 152411, 152165, 152556, 153075, 153424, + // 151993, 152999, 153078, 152151, 152088, 153389, 152484, 151874, 151670, + // 198, 285, 155784, 151669, 152226, 152126, 152638, 153215, 151729, + // 152959, 153479, 153059, 151838, 151670, 198, 1782, 155783, 151669, + // 153288, 153055, 153314, 152497, 152962, 152741, 152076, 153253, 151670, + // 198, 471, 16488, 155825, 151669, 152060, 152916, 151893, 153469, 152501, + // 152080, 152743, 151932, 153161, 152096, 152761, 152698, 153401, 153242, + // 153336, 152441, 152838, 153467, 152706, 153496, 153310, 152422, 153360, + // 153115, 152763, 151998, 152373, 153450, 152554, 151968, 153323, 152055, + // 152468, 153111, 153358, 152813, 152010, 151770, 152823, 152960, 151670, + // 198, 22627, 155823, 151669, 152814, 152366, 153484, 152931, 153441, + // 152164, 152877, 152915, 153463, 151692, 152911, 152747, 152776, 151831, + // 153449, 151882, 152975, 152031, 152513, 153150, 152448, 152667, 153133, + // 153189, 152619, 153466, 152054, 152106, 153119, 152277, 152439, 153109, + // 152997, 152141, 153154, 153256, 153311, 151922, 151670, 198, 1055, + // 155781, 151669, 152633, 151850, 153060, 153270, 152560, 153348, 152729, + // 151670, 198, 25312, 155803, 151669, 152521, 153403, 152561, 153337, + // 153383, 152199, 153493, 153326, 151830, 152254, 152248, 152349, 152153, + // 153007, 151823, 153037, 152575, 152457, 152406, 152592, 153116, 153365, + // 153456, 151670, 198, 88225, 155817, 151669, 153271, 151925, 152218, + // 152418, 152253, 153140, 151903, 153151, 152626, 152338, 152647, 153464, + // 152785, 152768, 151711, 152037, 152033, 151804, 152216, 151701, 151855, + // 152348, 152995, 152955, 152905, 152342, 152340, 153391, 153453, 152418, + // 153415, 151990, 153083, 152884, 151670, 198, 151668, 198, 151645}; + + { + const std::string inp_txt = common_detokenize(ctx_ttc, codes, true); + + LOG("\n"); + LOG_INF("codes: '%s'\n", inp_txt.c_str()); + LOG_INF("%s: codes size: %d\n", __func__, (int) codes.size()); + } + + // remove all non-audio tokens (i.e. < 151672 || > 155772) + codes.erase(std::remove_if(codes.begin(), codes.end(), [](llama_token t) { return t < 151672 || t > 155772; }), codes.end()); + + { + const std::string inp_txt = common_detokenize(ctx_ttc, codes, true); + LOG_INF("codes audio: '%s'\n", inp_txt.c_str()); + LOG_INF("%s: codes audio size: %d\n", __func__, (int) codes.size()); + } + + for (auto & token : codes) { + token -= 151672; + } + + const auto t_voc_start = ggml_time_us(); + + const int n_codes = codes.size(); + + llama_batch batch = llama_batch_init(n_codes, 0, 1); + + for (size_t i = 0; i < codes.size(); ++i) { + common_batch_add(batch, codes[i], i, { 0 }, true); // TODO: all logits? + } + GGML_ASSERT(batch.n_tokens == n_codes); + + if (llama_decode(ctx_cts, batch) != 0) { + LOG_ERR("%s: llama_decode() failed\n", __func__); + return 1; + } + + llama_synchronize(ctx_cts); + + LOG_INF("%s: time for vocoder: %.3f ms\n", __func__, (ggml_time_us() - t_voc_start) / 1000.0f); + + const auto t_spec_start = ggml_time_us(); + +#if 1 + // spectral operations + const int n_embd = llama_model_n_embd(model_cts); + const float * embd = llama_get_embeddings(ctx_cts); + + auto audio = embd_to_audio(embd, n_codes, n_embd, params.cpuparams.n_threads); + +#else + // read the spectrogram from a file for debugging purposes + std::vector audio; + { + std::ifstream fin("out.bin", std::ios::binary); + if (!fin) { + LOG_ERR("%s: failed to open file '%s'\n", __func__, "out.bin"); + return 1; + } + + std::vector embd; + + int n_codes; + int n_embd; + + fin.read(reinterpret_cast(&n_codes), sizeof(int)); + fin.read(reinterpret_cast(&n_embd), sizeof(int)); + + embd.resize(n_codes * n_embd); + fin.read(reinterpret_cast(embd.data()), n_codes * n_embd * sizeof(float)); + fin.close(); + + LOG_INF("%s: n_codes: %d, n_embd: %d\n", __func__, n_codes, n_embd); + + audio = embd_to_audio(embd.data(), n_codes, n_embd, params.cpuparams.n_threads); + } +#endif + + const std::string fname = "output.wav"; + + const int n_sr = 24000; // sampling rate + + // zero out first 0.25 seconds + for (int i = 0; i < 24000/4; ++i) { + audio[i] = 0.0f; + } + + LOG_INF("%s: time for spectral ops: %.3f ms\n", __func__, (ggml_time_us() - t_spec_start) / 1000.0f); + LOG_INF("%s: total time: %.3f ms\n", __func__, (ggml_time_us() - t_main_start) / 1000.0f); + + save_wav16(fname, audio, n_sr); + + LOG_INF("%s: audio written to file '%s'\n", __func__, fname.c_str()); + + llama_backend_free(); + + return 0; +} diff --git a/flake.lock b/flake.lock index 10e1f8a29..d114f4422 100644 --- a/flake.lock +++ b/flake.lock @@ -5,11 +5,11 @@ "nixpkgs-lib": "nixpkgs-lib" }, "locked": { - "lastModified": 1725024810, - "narHash": "sha256-ODYRm8zHfLTH3soTFWE452ydPYz2iTvr9T8ftDMUQ3E=", + "lastModified": 1730504689, + "narHash": "sha256-hgmguH29K2fvs9szpq2r3pz2/8cJd2LPS+b4tfNFCwE=", "owner": "hercules-ci", "repo": "flake-parts", - "rev": "af510d4a62d071ea13925ce41c95e3dec816c01d", + "rev": "506278e768c2a08bec68eb62932193e341f55c90", "type": "github" }, "original": { @@ -20,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1724819573, - "narHash": "sha256-GnR7/ibgIH1vhoy8cYdmXE6iyZqKqFxQSVkFgosBh6w=", + "lastModified": 1732014248, + "narHash": "sha256-y/MEyuJ5oBWrWAic/14LaIr/u5E0wRVzyYsouYY3W6w=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "71e91c409d1e654808b2621f28a327acfdad8dc2", + "rev": "23e89b7da85c3640bbc2173fe04f4bd114342367", "type": "github" }, "original": { @@ -36,14 +36,14 @@ }, "nixpkgs-lib": { "locked": { - "lastModified": 1722555339, - "narHash": "sha256-uFf2QeW7eAHlYXuDktm9c25OxOyCoUOQmh5SZ9amE5Q=", + "lastModified": 1730504152, + "narHash": "sha256-lXvH/vOfb4aGYyvFmZK/HlsNsr/0CVWlwYvo2rxJk3s=", "type": "tarball", - "url": "https://github.com/NixOS/nixpkgs/archive/a5d394176e64ab29c852d03346c1fc9b0b7d33eb.tar.gz" + "url": "https://github.com/NixOS/nixpkgs/archive/cc2f28000298e1269cea6612cd06ec9979dd5d7f.tar.gz" }, "original": { "type": "tarball", - "url": "https://github.com/NixOS/nixpkgs/archive/a5d394176e64ab29c852d03346c1fc9b0b7d33eb.tar.gz" + "url": "https://github.com/NixOS/nixpkgs/archive/cc2f28000298e1269cea6612cd06ec9979dd5d7f.tar.gz" } }, "root": { diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 532534bcb..7c069e420 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -32,7 +32,15 @@ else() endif() endif() +# remove the lib prefix on win32 mingw +if (WIN32) + set(CMAKE_STATIC_LIBRARY_PREFIX "") + set(CMAKE_SHARED_LIBRARY_PREFIX "") + set(CMAKE_SHARED_MODULE_PREFIX "") +endif() + option(BUILD_SHARED_LIBS "ggml: build shared libraries" ${BUILD_SHARED_LIBS_DEFAULT}) +option(GGML_BACKEND_DL "ggml: build backends as dynamic libraries (requires BUILD_SHARED_LIBS)" OFF) # # option list @@ -50,17 +58,27 @@ else() set(GGML_BLAS_VENDOR_DEFAULT "Generic") endif() -if (CMAKE_CROSSCOMPILING) +if (CMAKE_CROSSCOMPILING OR DEFINED ENV{SOURCE_DATE_EPOCH}) + message(STATUS "Setting GGML_NATIVE_DEFAULT to OFF") set(GGML_NATIVE_DEFAULT OFF) else() set(GGML_NATIVE_DEFAULT ON) endif() +# defaults +if (NOT GGML_LLAMAFILE_DEFAULT) + set(GGML_LLAMAFILE_DEFAULT OFF) +endif() + +if (NOT GGML_CUDA_GRAPHS_DEFAULT) + set(GGML_CUDA_GRAPHS_DEFAULT OFF) +endif() + # general -option(GGML_STATIC "ggml: static link libraries" OFF) -option(GGML_NATIVE "ggml: enable -march=native flag" ${GGML_NATIVE_DEFAULT}) -option(GGML_LTO "ggml: enable link time optimization" OFF) -option(GGML_CCACHE "ggml: use ccache if available" ON) +option(GGML_STATIC "ggml: static link libraries" OFF) +option(GGML_NATIVE "ggml: optimize the build for the current system" ${GGML_NATIVE_DEFAULT}) +option(GGML_LTO "ggml: enable link time optimization" OFF) +option(GGML_CCACHE "ggml: use ccache if available" ON) # debug option(GGML_ALL_WARNINGS "ggml: enable all compiler warnings" ON) @@ -82,54 +100,62 @@ else() set(INS_ENB ON) endif() -option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF) - -option(GGML_AVX "ggml: enable AVX" ${INS_ENB}) -option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB}) -option(GGML_AVX512 "ggml: enable AVX512" OFF) -option(GGML_AVX512_VBMI "ggml: enable AVX512-VBMI" OFF) -option(GGML_AVX512_VNNI "ggml: enable AVX512-VNNI" OFF) -option(GGML_AVX512_BF16 "ggml: enable AVX512-BF16" OFF) -option(GGML_FMA "ggml: enable FMA" ${INS_ENB}) +option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF) +option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON) +option(GGML_AVX "ggml: enable AVX" ${INS_ENB}) +option(GGML_AVX_VNNI "ggml: enable AVX-VNNI" OFF) +option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB}) +option(GGML_AVX512 "ggml: enable AVX512F" OFF) +option(GGML_AVX512_VBMI "ggml: enable AVX512-VBMI" OFF) +option(GGML_AVX512_VNNI "ggml: enable AVX512-VNNI" OFF) +option(GGML_AVX512_BF16 "ggml: enable AVX512-BF16" OFF) if (NOT MSVC) - option(GGML_F16C "ggml: enable F16C" ${INS_ENB}) # in MSVC F16C is implied with AVX2/AVX512 + # in MSVC F16C and FMA is implied with AVX2/AVX512 + option(GGML_FMA "ggml: enable FMA" ${INS_ENB}) + option(GGML_F16C "ggml: enable F16C" ${INS_ENB}) + # MSVC does not seem to support AMX + option(GGML_AMX_TILE "ggml: enable AMX-TILE" OFF) + option(GGML_AMX_INT8 "ggml: enable AMX-INT8" OFF) + option(GGML_AMX_BF16 "ggml: enable AMX-BF16" OFF) endif() -option(GGML_LASX "ggml: enable lasx" ON) -option(GGML_LSX "ggml: enable lsx" ON) -option(GGML_SVE "ggml: enable SVE" OFF) +option(GGML_LASX "ggml: enable lasx" ON) +option(GGML_LSX "ggml: enable lsx" ON) +option(GGML_RVV "ggml: enable rvv" ON) + +option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) +set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") + if (WIN32) - set(GGML_WIN_VER "0x602" CACHE STRING "ggml: Windows Version") + set(GGML_WIN_VER "0x602" CACHE STRING "ggml: Windows version") endif() # ggml core set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism") +option(GGML_CPU "ggml: enable CPU backend" ON) # 3rd party libs / backends option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON) option(GGML_BLAS "ggml: use BLAS" ${GGML_BLAS_DEFAULT}) set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING "ggml: BLAS library vendor") -option(GGML_LLAMAFILE "ggml: use LLAMAFILE" OFF) +option(GGML_LLAMAFILE "ggml: use LLAMAFILE" ${GGML_LLAMAFILE_DEFAULT}) option(GGML_CUDA "ggml: use CUDA" OFF) option(GGML_MUSA "ggml: use MUSA" OFF) -option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF) option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF) option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF) -set (GGML_CUDA_DMMV_X "32" CACHE STRING "ggml: x stride for dmmv CUDA kernels") -set (GGML_CUDA_MMV_Y "1" CACHE STRING "ggml: y block size for mmv CUDA kernels") option(GGML_CUDA_F16 "ggml: use 16 bit floats for some calculations" OFF) -set (GGML_CUDA_KQUANTS_ITER "2" CACHE STRING - "ggml: iters./thread per block for Q2_K/Q6_K") set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING "ggml: max. batch size for using peer access") option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF) option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF) option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF) -option(GGML_CUDA_USE_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" OFF) +option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT}) -option(GGML_HIPBLAS "ggml: use hipBLAS" OFF) +option(GGML_HIP "ggml: use HIP" OFF) +option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) +option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF) option(GGML_VULKAN "ggml: use Vulkan" OFF) option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF) @@ -141,6 +167,7 @@ option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF) option(GGML_KOMPUTE "ggml: use Kompute" OFF) option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT}) +option(GGML_METAL_USE_BF16 "ggml: use bfloat if available" OFF) option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF) option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF) option(GGML_METAL_EMBED_LIBRARY "ggml: embed Metal library" ${GGML_METAL}) @@ -153,6 +180,16 @@ option(GGML_SYCL "ggml: use SYCL" option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF) set (GGML_SYCL_TARGET "INTEL" CACHE STRING "ggml: sycl target device") +set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING + "ggml: sycl device architecture") + +option(GGML_OPENCL "ggml: use OpenCL" OFF) +option(GGML_OPENCL_PROFILING "ggml: use OpenCL profiling (increases overhead)" OFF) +option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON) +option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adreno" ON) + +# toolchain for vulkan-shaders-gen +set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen") # extra artifacts option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE}) @@ -165,11 +202,7 @@ option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE}) set(CMAKE_C_STANDARD 11) set(CMAKE_C_STANDARD_REQUIRED true) -if (GGML_SYCL) - set(CMAKE_CXX_STANDARD 17) -else() - set(CMAKE_CXX_STANDARD 11) -endif() +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED true) set(THREADS_PREFER_PTHREAD_FLAG ON) @@ -205,45 +238,26 @@ include(CMakePackageConfigHelpers) # all public headers set(GGML_PUBLIC_HEADERS include/ggml.h + include/ggml-cpu.h include/ggml-alloc.h include/ggml-backend.h include/ggml-blas.h include/ggml-cann.h include/ggml-cuda.h - include/ggml.h include/ggml-kompute.h + include/ggml-opt.h include/ggml-metal.h include/ggml-rpc.h include/ggml-sycl.h - include/ggml-vulkan.h) + include/ggml-vulkan.h + include/gguf.h) set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}") #if (GGML_METAL) # set_target_properties(ggml PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/src/ggml-metal.metal") #endif() -install(TARGETS ggml PUBLIC_HEADER) - -if (BUILD_SHARED_LIBS) - install(TARGETS ggml LIBRARY) -endif() - -if (GGML_METAL) - install( - FILES src/ggml-metal.metal - PERMISSIONS - OWNER_READ - OWNER_WRITE - GROUP_READ - WORLD_READ - DESTINATION ${CMAKE_INSTALL_BINDIR}) - - if (NOT GGML_METAL_EMBED_LIBRARY) - install( - FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib - DESTINATION ${CMAKE_INSTALL_BINDIR} - ) - endif() -endif() +install(TARGETS ggml LIBRARY PUBLIC_HEADER) +install(TARGETS ggml-base LIBRARY) if (GGML_STANDALONE) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/ggml.pc.in @@ -253,3 +267,74 @@ if (GGML_STANDALONE) install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc DESTINATION share/pkgconfig) endif() + +# +# Create CMake package +# + +# Generate version info based on git commit. + +find_program(GIT_EXE NAMES git git.exe REQUIRED NO_CMAKE_FIND_ROOT_PATH) +execute_process(COMMAND ${GIT_EXE} rev-list --count HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE GGML_BUILD_NUMBER + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +if(GGML_BUILD_NUMBER EQUAL 1) + message(WARNING "GGML build version fixed at 1 likely due to a shallow clone.") +endif() + +execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + OUTPUT_VARIABLE GGML_BUILD_COMMIT + OUTPUT_STRIP_TRAILING_WHITESPACE +) + +# Capture variables prefixed with GGML_. + +set(variable_set_statements +" +####### Expanded from @GGML_VARIABLES_EXPANED@ by configure_package_config_file() ####### +####### Any changes to this file will be overwritten by the next CMake run ####### + +") + +set(GGML_SHARED_LIB ${BUILD_SHARED_LIBS}) + +get_cmake_property(all_variables VARIABLES) +foreach(variable_name IN LISTS all_variables) + if(variable_name MATCHES "^GGML_") + string(REPLACE ";" "\\;" + variable_value "${${variable_name}}") + + set(variable_set_statements + "${variable_set_statements}set(${variable_name} \"${variable_value}\")\n") + endif() +endforeach() + +set(GGML_VARIABLES_EXPANDED ${variable_set_statements}) + +# Create the CMake package and set install location. + +set(GGML_INSTALL_VERSION 0.0.${GGML_BUILD_NUMBER}) +set(GGML_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files") +set(GGML_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files") +set(GGML_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files") + +configure_package_config_file( + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/ggml-config.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml + PATH_VARS GGML_INCLUDE_INSTALL_DIR + GGML_LIB_INSTALL_DIR + GGML_BIN_INSTALL_DIR) + +write_basic_package_version_file( + ${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake + VERSION ${GGML_INSTALL_VERSION} + COMPATIBILITY SameMajorVersion) + +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake + ${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml) diff --git a/ggml/cmake/ggml-config.cmake.in b/ggml/cmake/ggml-config.cmake.in new file mode 100644 index 000000000..bf39f9c00 --- /dev/null +++ b/ggml/cmake/ggml-config.cmake.in @@ -0,0 +1,147 @@ + +@GGML_VARIABLES_EXPANDED@ + +@PACKAGE_INIT@ + +set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@") +set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@") +set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@") + +find_package(Threads REQUIRED) + +find_library(GGML_LIBRARY ggml + REQUIRED + HINTS ${GGML_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH) + +add_library(ggml::ggml UNKNOWN IMPORTED) +set_target_properties(ggml::ggml + PROPERTIES + IMPORTED_LOCATION "${GGML_LIBRARY}") + +find_library(GGML_BASE_LIBRARY ggml-base + REQUIRED + HINTS ${GGML_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH) + +add_library(ggml::ggml-base UNKNOWN IMPORTED) +set_target_properties(ggml::ggml-base + PROPERTIES + IMPORTED_LOCATION "${GGML_BASE_LIBRARY}") + +if (NOT GGML_SHARED_LIB) + if (APPLE AND GGML_ACCELERATE) + find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${ACCELERATE_FRAMEWORK}) + endif() + + if (GGML_OPENMP) + find_package(OpenMP REQUIRED) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX) + endif() + + if (GGML_CPU_HBM) + find_library(memkind memkind REQUIRED) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES memkind) + endif() + + if (GGML_BLAS) + find_package(BLAS REQUIRED) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${BLAS_LIBRARIES}) + list(APPEND GGML_CPU_INTERFACE_LINK_OPTIONS ${BLAS_LINKER_FLAGS}) + endif() + + if (GGML_CUDA) + find_package(CUDAToolkit REQUIRED) + endif() + + if (GGML_METAL) + find_library(FOUNDATION_LIBRARY Foundation REQUIRED) + find_library(METAL_FRAMEWORK Metal REQUIRED) + find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) + + list(APPEND GGML_METAL_INTERFACE_LINK_LIBRARIES + ${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK}) + endif() + + if (GGML_VULKAN) + find_package(Vulkan REQUIRED) + list(APPEND GGML_VULKAN_INTERFACE_LINK_LIBRARIES Vulkan::Vulkan) + endif() + + if (GGML_HIP) + find_package(hip REQUIRED) + find_package(hipblas REQUIRED) + find_package(rocblas REQUIRED) + list(APPEND GGML_HIP_INTERFACE_LINK_LIBRARIES hip::host roc::rocblas roc::hipblas) + endif() + + if (GGML_SYCL) + find_package(DNNL) + if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL") + list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES DNNL::dnnl) + endif() + if (WIN32) + find_package(IntelSYCL REQUIRED) + find_package(MKL REQUIRED) + list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL) + endif() + endif() +endif() + +set(_ggml_all_targets "") +foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS}) + string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}") + string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx) + + find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend} + REQUIRED + HINTS ${GGML_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH) + + message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}") + + add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED) + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}" + IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" + IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}" + INTERFACE_COMPILE_FEATURES c_std_90 + POSITION_INDEPENDENT_CODE ON) + + string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}") + if(is_cpu_variant) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml" "ggml::ggml-base") + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}") + + if(GGML_CPU_INTERFACE_LINK_OPTIONS) + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}") + endif() + + else() + list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml" "ggml::ggml-base") + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}") + + if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS) + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}") + endif() + endif() + + list(APPEND _ggml_all_targets ggml::${_ggml_backend}) +endforeach() + +add_library(ggml::all INTERFACE IMPORTED) +set_target_properties(ggml::all + PROPERTIES + INTERFACE_LINK_LIBRARIES "${_ggml_all_targets}") + +check_required_components(ggml) diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h index 0dff47d65..23600eea9 100644 --- a/ggml/include/ggml-alloc.h +++ b/ggml/include/ggml-alloc.h @@ -24,7 +24,7 @@ GGML_API void ggml_tallocr_alloc(struct ggml_tallocr * talloc, st // Graph allocator /* Example usage: - ggml_gallocr_t galloc = ggml_gallocr_new(ggml_bacckend_cpu_buffer_type()); + ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type()); // optional: create a worst-case graph and reserve the buffers to avoid reallocations ggml_gallocr_reserve(galloc, build_graph(max_batch)); diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index e497b6d02..fc9571c82 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -3,6 +3,20 @@ #include "ggml.h" #include "ggml-alloc.h" +#ifdef GGML_BACKEND_SHARED +# if defined(_WIN32) && !defined(__MINGW32__) +# ifdef GGML_BACKEND_BUILD +# define GGML_BACKEND_API __declspec(dllexport) extern +# else +# define GGML_BACKEND_API __declspec(dllimport) extern +# endif +# else +# define GGML_BACKEND_API __attribute__ ((visibility ("default"))) extern +# endif +#else +# define GGML_BACKEND_API extern +#endif + #ifdef __cplusplus extern "C" { #endif @@ -12,43 +26,52 @@ extern "C" { typedef struct ggml_backend_event * ggml_backend_event_t; typedef struct ggml_backend * ggml_backend_t; typedef void * ggml_backend_graph_plan_t; + typedef struct ggml_backend_reg * ggml_backend_reg_t; + typedef struct ggml_backend_device * ggml_backend_dev_t; + + + // + // Backend buffer type + // + + GGML_API const char * ggml_backend_buft_name (ggml_backend_buffer_type_t buft); + GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size); + GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft); + GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft); + GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); + GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft); + GGML_API ggml_backend_dev_t ggml_backend_buft_get_device (ggml_backend_buffer_type_t buft); // // Backend buffer // - // buffer type - GGML_API const char * ggml_backend_buft_name (ggml_backend_buffer_type_t buft); - GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size); - GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft); - GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft); - GGML_API GGML_CALL size_t ggml_backend_buft_get_alloc_size (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); - GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft); - - // buffer enum ggml_backend_buffer_usage { GGML_BACKEND_BUFFER_USAGE_ANY = 0, GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1, GGML_BACKEND_BUFFER_USAGE_COMPUTE = 2, }; - GGML_API const char * ggml_backend_buffer_name (ggml_backend_buffer_t buffer); - GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer); - GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer); - GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer); - GGML_API GGML_CALL void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); - GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer); - GGML_API size_t ggml_backend_buffer_get_max_size (ggml_backend_buffer_t buffer); - GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); - GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value); - GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer); - GGML_API void ggml_backend_buffer_set_usage (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage); - GGML_API enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage (ggml_backend_buffer_t buffer); - GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_get_type (ggml_backend_buffer_t buffer); - GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer); + GGML_API const char * ggml_backend_buffer_name (ggml_backend_buffer_t buffer); + GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer); + GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer); + GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer); + GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer); + GGML_API size_t ggml_backend_buffer_get_max_size (ggml_backend_buffer_t buffer); + GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value); + GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer); + GGML_API void ggml_backend_buffer_set_usage (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage); + GGML_API enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage (ggml_backend_buffer_t buffer); + GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_get_type (ggml_backend_buffer_t buffer); + GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer); + + // tensor copy between different backends + GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst); // - // Backend + // Backend (stream) // GGML_API ggml_guid_t ggml_backend_guid(ggml_backend_t backend); @@ -63,9 +86,10 @@ extern "C" { GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); - // "offset" refers to the offset of the tensor data for setting/getting data - GGML_API GGML_CALL void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + // "offset" refers to the offset in tensor->data for setting/getting data + GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); GGML_API void ggml_backend_synchronize(ggml_backend_t backend); @@ -75,65 +99,144 @@ extern "C" { GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan); GGML_API enum ggml_status ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph); GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph); + + // NOTE: will be removed, use device version instead GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op); GGML_API bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft); GGML_API bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op); - // tensor copy between different backends - GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst); - // asynchronous copy // the copy is performed after all the currently queued operations in backend_src // backend_dst will wait for the copy to complete before performing other operations // automatic fallback to sync copy if async is not supported GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst); - // events - GGML_API ggml_backend_event_t ggml_backend_event_new (ggml_backend_t backend); - GGML_API void ggml_backend_event_free (ggml_backend_event_t event); - GGML_API void ggml_backend_event_record (ggml_backend_event_t event); - GGML_API void ggml_backend_event_synchronize(ggml_backend_event_t event); - GGML_API void ggml_backend_event_wait (ggml_backend_t backend, ggml_backend_event_t event); + GGML_API ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend); // - // CPU backend + // Events // - GGML_API ggml_backend_t ggml_backend_cpu_init(void); + GGML_API ggml_backend_event_t ggml_backend_event_new(ggml_backend_dev_t device); + GGML_API void ggml_backend_event_free(ggml_backend_event_t event); + GGML_API void ggml_backend_event_record(ggml_backend_event_t event, ggml_backend_t backend); + GGML_API void ggml_backend_event_synchronize(ggml_backend_event_t event); + GGML_API void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event); - GGML_API GGML_CALL bool ggml_backend_is_cpu (ggml_backend_t backend); - GGML_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads); - GGML_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool); - GGML_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data); + // + // Backend device + // - // Create a backend buffer from an existing pointer - GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size); + enum ggml_backend_dev_type { + // CPU device using system memory + GGML_BACKEND_DEVICE_TYPE_CPU, + // GPU device using dedicated memory + GGML_BACKEND_DEVICE_TYPE_GPU, + // accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX) + GGML_BACKEND_DEVICE_TYPE_ACCEL + }; - GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void); + // functionality supported by the device + struct ggml_backend_dev_caps { + // asynchronous operations + bool async; + // pinned host buffer + bool host_buffer; + // creating buffers from host ptr + bool buffer_from_host_ptr; + // event synchronization + bool events; + }; -#ifdef GGML_USE_CPU_HBM - GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void); -#endif + // all the device properties + struct ggml_backend_dev_props { + const char * name; + const char * description; + size_t memory_free; + size_t memory_total; + enum ggml_backend_dev_type type; + struct ggml_backend_dev_caps caps; + }; + + GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device); + GGML_API const char * ggml_backend_dev_description(ggml_backend_dev_t device); + GGML_API void ggml_backend_dev_memory(ggml_backend_dev_t device, size_t * free, size_t * total); + GGML_API enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device); + GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props); + GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device); + GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params); + GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device); + GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device); + GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size); + + GGML_API bool ggml_backend_dev_supports_op(ggml_backend_dev_t device, const struct ggml_tensor * op); + GGML_API bool ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buffer_type_t buft); + GGML_API bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op); + + // + // Backend (reg) + // + + GGML_API const char * ggml_backend_reg_name(ggml_backend_reg_t reg); + GGML_API size_t ggml_backend_reg_dev_count(ggml_backend_reg_t reg); + GGML_API ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index); + GGML_API void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name); + + // Common functions that may be obtained using ggml_backend_reg_get_proc_address + + // Split buffer type for tensor parallelism + typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split); + // Set the number of threads for the backend + typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads); + // Get additional buffer types provided by the device (returns a NULL-terminated array) + typedef ggml_backend_buffer_type_t * (*ggml_backend_dev_get_extra_bufts_t)(ggml_backend_dev_t device); + // Set the abort callback for the backend + typedef void (*ggml_backend_set_abort_callback_t)(ggml_backend_t backend, ggml_abort_callback abort_callback, void * abort_callback_data); + // Get a list of feature flags supported by the backend (returns a NULL-terminated array) + struct ggml_backend_feature { + const char * name; + const char * value; + }; + typedef struct ggml_backend_feature * (*ggml_backend_get_features_t)(ggml_backend_reg_t reg); // // Backend registry // - // The backend registry is a registry of all the available backends, and allows initializing backends in a generic way + GGML_API void ggml_backend_device_register(ggml_backend_dev_t device); - GGML_API size_t ggml_backend_reg_get_count(void); - GGML_API size_t ggml_backend_reg_find_by_name(const char * name); - GGML_API ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is backend_name:params (params is optional) - GGML_API const char * ggml_backend_reg_get_name(size_t i); - GGML_API ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific - GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i); - GGML_API ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size); + // Backend (reg) enumeration + GGML_API size_t ggml_backend_reg_count(void); + GGML_API ggml_backend_reg_t ggml_backend_reg_get(size_t index); + GGML_API ggml_backend_reg_t ggml_backend_reg_by_name(const char * name); + + // Device enumeration + GGML_API size_t ggml_backend_dev_count(void); + GGML_API ggml_backend_dev_t ggml_backend_dev_get(size_t index); + GGML_API ggml_backend_dev_t ggml_backend_dev_by_name(const char * name); + GGML_API ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type); + + // Direct backend (stream) initialization + // = ggml_backend_dev_init(ggml_backend_dev_by_name(name), params) + GGML_API ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params); + // = ggml_backend_dev_init(ggml_backend_dev_by_type(type), params) + GGML_API ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params); + // = ggml_backend_dev_init(ggml_backend_dev_by_type(GPU) OR ggml_backend_dev_by_type(CPU), NULL) + GGML_API ggml_backend_t ggml_backend_init_best(void); + + // Load a backend from a dynamic library and register it + GGML_API ggml_backend_reg_t ggml_backend_load(const char * path); + // Unload a backend if loaded dynamically and unregister it + GGML_API void ggml_backend_unload(ggml_backend_reg_t reg); + // Load all known backends from dynamic libraries + GGML_API void ggml_backend_load_all(void); + GGML_API void ggml_backend_load_all_from_path(const char * dir_path); // // Backend scheduler // - // The backend scheduler allows for multiple backends to be used together + // The backend scheduler allows for multiple backend devices to be used together // Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends // The backends are selected based on: // - the backend that supports the operation @@ -157,20 +260,26 @@ extern "C" { ggml_backend_sched_reserve(sched, reserve_graph); // compute - graph = build_graph(sched); - ggml_backend_sched_graph_compute(sched, graph); + graph = build_graph(sched); // the graph and its tensors are single-use in terms of allocation, multi-use in terms of computation + for (int i = 0; i < 10; ++i) { + ggml_backend_sched_graph_compute(sched, graph); // on the first iteration the graph is allocated automatically + } // if there are graph inputs: - ggml_backend_sched_reset(sched); - ggml_backend_sched_alloc_graph(sched, graph); - ggml_backend_tensor_set(input_tensor, ...); - ggml_backend_sched_graph_compute(sched, graph); + graph = build_graph(sched); // get a new graph that is not allocated (the metadata for the old graph is freed once ggml_free is called) + ggml_backend_sched_reset(sched); // clear the allocation of the previous graph + ggml_backend_sched_alloc_graph(sched, graph); // explicitly allocate the new graph but do not execute it + ggml_backend_tensor_set(input_tensor, ...); // copy data to the newly allocated graph tensors + ggml_backend_sched_graph_compute(sched, graph); // execute the graph + + // as an alternative to the above it is also possible to assign the inputs to a dedicated context and + // allocate them statically via ggml_backend_alloc_ctx_tensors } */ - struct ggml_backend_sched; typedef struct ggml_backend_sched * ggml_backend_sched_t; + // Evaluation callback for each node in the graph (set with ggml_backend_sched_set_eval_callback) // when ask == true, the scheduler wants to know if the user wants to observe this node // this allows the scheduler to batch nodes together in order to evaluate them in a single call // @@ -179,12 +288,12 @@ extern "C" { // typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data); - // Initialize a backend scheduler + // Initialize a backend scheduler, backends with low index are given priority over backends with high index GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel); GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); // Initialize backend buffers from a measure graph - GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); + GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // returns success GGML_API int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched); GGML_API ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i); @@ -199,12 +308,14 @@ extern "C" { GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node); // Allocate and compute graph on the backend scheduler - GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); + GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph); GGML_API enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph); GGML_API void ggml_backend_sched_synchronize(ggml_backend_sched_t sched); - // Reset all assignments and allocators - must be called before changing the node backends + // Reset all assignments and allocators - must be called before changing the node backends or allocating a new graph. + // This in effect deallocates all tensors that were previously allocated and leaves them with dangling pointers. + // The correct way to use this API is to discard the deallocated tensors and create new ones. GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched); // Set a callback to be called for each resulting node during graph compute @@ -225,7 +336,7 @@ extern "C" { GGML_API struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph); GGML_API void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy); - typedef bool (*GGML_CALL ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data); + typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data); // Compare the output of two backends GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data); @@ -234,6 +345,9 @@ extern "C" { GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor); + // CPU buffer types are always available + GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size); + GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void); #ifdef __cplusplus } diff --git a/ggml/include/ggml-blas.h b/ggml/include/ggml-blas.h index f2e37de06..87a81b363 100644 --- a/ggml/include/ggml-blas.h +++ b/ggml/include/ggml-blas.h @@ -9,13 +9,15 @@ extern "C" { #endif // backend API -GGML_API GGML_CALL ggml_backend_t ggml_backend_blas_init(void); +GGML_BACKEND_API ggml_backend_t ggml_backend_blas_init(void); -GGML_API GGML_CALL bool ggml_backend_is_blas(ggml_backend_t backend); +GGML_BACKEND_API bool ggml_backend_is_blas(ggml_backend_t backend); // number of threads used for conversion to float // for openblas and blis, this will also set the number of threads used for blas operations -GGML_API GGML_CALL void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads); +GGML_BACKEND_API void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_blas_reg(void); #ifdef __cplusplus diff --git a/ggml/include/ggml-cann.h b/ggml/include/ggml-cann.h index ca73211fe..b469e228d 100644 --- a/ggml/include/ggml-cann.h +++ b/ggml/include/ggml-cann.h @@ -34,6 +34,8 @@ extern "C" { */ #define GGML_CANN_MAX_DEVICES 16 +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cann_reg(void); + /** * @brief Initializes the CANN backend for a specified device. * @@ -44,7 +46,7 @@ extern "C" { * @param device The index of the device to initialize. * @return A pointer to the initialized backend instance, or nullptr on failure. */ -GGML_API GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device); +GGML_BACKEND_API ggml_backend_t ggml_backend_cann_init(int32_t device); /** * @brief Checks if a given backend is a CANN backend. @@ -55,7 +57,7 @@ GGML_API GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device); * @param backend The backend instance to check. * @return True if the backend is a CANN backend, false otherwise. */ -GGML_API GGML_CALL bool ggml_backend_is_cann(ggml_backend_t backend); +GGML_BACKEND_API bool ggml_backend_is_cann(ggml_backend_t backend); /** * @brief Retrieves the CANN buffer type for a specified device. @@ -67,7 +69,7 @@ GGML_API GGML_CALL bool ggml_backend_is_cann(ggml_backend_t backend); * @return A pointer to the buffer type interface for the specified device, or * nullptr if the device index is out of range. */ -GGML_API GGML_CALL ggml_backend_buffer_type_t +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cann_buffer_type(int32_t device); /** @@ -78,7 +80,14 @@ ggml_backend_cann_buffer_type(int32_t device); * * @return The number of CANN devices available. */ -GGML_API GGML_CALL int32_t ggml_backend_cann_get_device_count(void); +GGML_BACKEND_API int32_t ggml_backend_cann_get_device_count(void); + +/** + * @brief pinned host buffer for use with the CPU backend for faster copies between CPU and NPU. + * + * @return A pointer to the host buffer type interface. + */ +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type(void); /** * @brief Retrieves the description of a specific CANN device. @@ -90,7 +99,7 @@ GGML_API GGML_CALL int32_t ggml_backend_cann_get_device_count(void); * @param description Pointer to a buffer where the description will be written. * @param description_size Size of the description buffer. */ -GGML_API GGML_CALL void ggml_backend_cann_get_device_description( +GGML_BACKEND_API void ggml_backend_cann_get_device_description( int32_t device, char* description, size_t description_size); /** @@ -105,20 +114,9 @@ GGML_API GGML_CALL void ggml_backend_cann_get_device_description( * @param total Pointer to a variable where the total memory size will be * stored. */ -GGML_API GGML_CALL void ggml_backend_cann_get_device_memory(int32_t device, - size_t* free, - size_t* total); - -/** - * @brief Set the logging callback for GGML. - * - * This function sets the logging callback and user data for logging. - * - * @param log_callback The logging callback to set. - * @param user_data User data to pass to the logging callback. - */ -GGML_API void ggml_backend_cann_log_set_callback(ggml_log_callback log_callback, - void* user_data); +GGML_BACKEND_API void ggml_backend_cann_get_device_memory(int32_t device, + size_t* free, + size_t* total); #ifdef __cplusplus } diff --git a/ggml/include/ggml-cpp.h b/ggml/include/ggml-cpp.h new file mode 100644 index 000000000..a12342c25 --- /dev/null +++ b/ggml/include/ggml-cpp.h @@ -0,0 +1,39 @@ +#pragma once + +#ifndef __cplusplus +#error "This header is for C++ only" +#endif + +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "gguf.h" +#include + +// Smart pointers for ggml types + +// ggml + +struct ggml_context_deleter { void operator()(ggml_context * ctx) { ggml_free(ctx); } }; +struct gguf_context_deleter { void operator()(gguf_context * ctx) { gguf_free(ctx); } }; + +typedef std::unique_ptr ggml_context_ptr; +typedef std::unique_ptr gguf_context_ptr; + +// ggml-alloc + +struct ggml_gallocr_deleter { void operator()(ggml_gallocr_t galloc) { ggml_gallocr_free(galloc); } }; + +typedef std::unique_ptr ggml_gallocr_ptr; + +// ggml-backend + +struct ggml_backend_deleter { void operator()(ggml_backend_t backend) { ggml_backend_free(backend); } }; +struct ggml_backend_buffer_deleter { void operator()(ggml_backend_buffer_t buffer) { ggml_backend_buffer_free(buffer); } }; +struct ggml_backend_event_deleter { void operator()(ggml_backend_event_t event) { ggml_backend_event_free(event); } }; +struct ggml_backend_sched_deleter { void operator()(ggml_backend_sched_t sched) { ggml_backend_sched_free(sched); } }; + +typedef std::unique_ptr ggml_backend_ptr; +typedef std::unique_ptr ggml_backend_buffer_ptr; +typedef std::unique_ptr ggml_backend_event_ptr; +typedef std::unique_ptr ggml_backend_sched_ptr; diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h new file mode 100644 index 000000000..3aa71badb --- /dev/null +++ b/ggml/include/ggml-cpu.h @@ -0,0 +1,135 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + + // the compute plan that needs to be prepared for ggml_graph_compute() + // since https://github.com/ggerganov/ggml/issues/287 + struct ggml_cplan { + size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()` + uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()` + + int n_threads; + struct ggml_threadpool * threadpool; + + // abort ggml_graph_compute when true + ggml_abort_callback abort_callback; + void * abort_callback_data; + }; + + // numa strategies + enum ggml_numa_strategy { + GGML_NUMA_STRATEGY_DISABLED = 0, + GGML_NUMA_STRATEGY_DISTRIBUTE = 1, + GGML_NUMA_STRATEGY_ISOLATE = 2, + GGML_NUMA_STRATEGY_NUMACTL = 3, + GGML_NUMA_STRATEGY_MIRROR = 4, + GGML_NUMA_STRATEGY_COUNT + }; + + GGML_BACKEND_API void ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems + GGML_BACKEND_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node + + GGML_BACKEND_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value); + GGML_BACKEND_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value); + + GGML_BACKEND_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value); + GGML_BACKEND_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value); + + GGML_BACKEND_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i); + GGML_BACKEND_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value); + + GGML_BACKEND_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3); + GGML_BACKEND_API void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value); + + GGML_BACKEND_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i); + GGML_BACKEND_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value); + + GGML_BACKEND_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3); + GGML_BACKEND_API void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value); + + GGML_BACKEND_API struct ggml_threadpool * ggml_threadpool_new (struct ggml_threadpool_params * params); + GGML_BACKEND_API void ggml_threadpool_free (struct ggml_threadpool * threadpool); + GGML_BACKEND_API int ggml_threadpool_get_n_threads (struct ggml_threadpool * threadpool); + GGML_BACKEND_API void ggml_threadpool_pause (struct ggml_threadpool * threadpool); + GGML_BACKEND_API void ggml_threadpool_resume (struct ggml_threadpool * threadpool); + + // ggml_graph_plan() has to be called before ggml_graph_compute() + // when plan.work_size > 0, caller must allocate memory for plan.work_data + GGML_BACKEND_API struct ggml_cplan ggml_graph_plan( + const struct ggml_cgraph * cgraph, + int n_threads, /* = GGML_DEFAULT_N_THREADS */ + struct ggml_threadpool * threadpool /* = NULL */ ); + GGML_BACKEND_API enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan); + + // same as ggml_graph_compute() but the work data is allocated as a part of the context + // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data + GGML_BACKEND_API enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads); + + // + // system info + // + + // x86 + GGML_BACKEND_API int ggml_cpu_has_sse3 (void); + GGML_BACKEND_API int ggml_cpu_has_ssse3 (void); + GGML_BACKEND_API int ggml_cpu_has_avx (void); + GGML_BACKEND_API int ggml_cpu_has_avx_vnni (void); + GGML_BACKEND_API int ggml_cpu_has_avx2 (void); + GGML_BACKEND_API int ggml_cpu_has_f16c (void); + GGML_BACKEND_API int ggml_cpu_has_fma (void); + GGML_BACKEND_API int ggml_cpu_has_avx512 (void); + GGML_BACKEND_API int ggml_cpu_has_avx512_vbmi(void); + GGML_BACKEND_API int ggml_cpu_has_avx512_vnni(void); + GGML_BACKEND_API int ggml_cpu_has_avx512_bf16(void); + GGML_BACKEND_API int ggml_cpu_has_amx_int8 (void); + // ARM + GGML_BACKEND_API int ggml_cpu_has_neon (void); + GGML_BACKEND_API int ggml_cpu_has_arm_fma (void); + GGML_BACKEND_API int ggml_cpu_has_fp16_va (void); + GGML_BACKEND_API int ggml_cpu_has_dotprod (void); + GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void); + GGML_BACKEND_API int ggml_cpu_has_sve (void); + GGML_BACKEND_API int ggml_cpu_get_sve_cnt (void); // sve vector length in bytes + // other + GGML_BACKEND_API int ggml_cpu_has_riscv_v (void); + GGML_BACKEND_API int ggml_cpu_has_vsx (void); + GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void); + GGML_BACKEND_API int ggml_cpu_has_llamafile (void); + + // Internal types and functions exposed for tests and benchmarks + + typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx, + const void * GGML_RESTRICT y, size_t by, int nrc); + + struct ggml_type_traits_cpu { + ggml_from_float_t from_float; + ggml_vec_dot_t vec_dot; + enum ggml_type vec_dot_type; + int64_t nrows; // number of rows to process simultaneously + }; + + GGML_BACKEND_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type); + + GGML_BACKEND_API void ggml_cpu_init(void); + + // + // CPU backend + // + + GGML_BACKEND_API ggml_backend_t ggml_backend_cpu_init(void); + + GGML_BACKEND_API bool ggml_backend_is_cpu (ggml_backend_t backend); + GGML_BACKEND_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads); + GGML_BACKEND_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool); + GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data); + + GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/include/ggml-cuda.h b/ggml/include/ggml-cuda.h index 71bb6dcf0..22ad2c009 100644 --- a/ggml/include/ggml-cuda.h +++ b/ggml/include/ggml-cuda.h @@ -3,7 +3,11 @@ #include "ggml.h" #include "ggml-backend.h" -#ifdef GGML_USE_HIPBLAS +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef GGML_USE_HIP #define GGML_CUDA_NAME "ROCm" #define GGML_CUBLAS_NAME "hipBLAS" #elif defined(GGML_USE_MUSA) @@ -13,35 +17,31 @@ #define GGML_CUDA_NAME "CUDA" #define GGML_CUBLAS_NAME "cuBLAS" #endif - -#ifdef __cplusplus -extern "C" { -#endif - #define GGML_CUDA_MAX_DEVICES 16 // backend API -GGML_API GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device); +GGML_BACKEND_API ggml_backend_t ggml_backend_cuda_init(int device); -GGML_API GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend); +GGML_BACKEND_API bool ggml_backend_is_cuda(ggml_backend_t backend); // device buffer -GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device); +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device); // split tensor buffer that splits matrices by rows across multiple devices -GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split); +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split); // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU -GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void); +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void); -GGML_API GGML_CALL int ggml_backend_cuda_get_device_count(void); -GGML_API GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size); -GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total); +GGML_BACKEND_API int ggml_backend_cuda_get_device_count(void); +GGML_BACKEND_API void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size); +GGML_BACKEND_API void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total); -GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size); -GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer); +GGML_BACKEND_API bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size); +GGML_BACKEND_API void ggml_backend_cuda_unregister_host_buffer(void * buffer); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cuda_reg(void); -GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback, void * user_data); #ifdef __cplusplus } #endif diff --git a/ggml/include/ggml-kompute.h b/ggml/include/ggml-kompute.h index 171465456..154aa56a7 100644 --- a/ggml/include/ggml-kompute.h +++ b/ggml/include/ggml-kompute.h @@ -11,6 +11,8 @@ extern "C" { #endif +#define GGML_KOMPUTE_MAX_DEVICES 16 + struct ggml_vk_device { int index; int type; // same as VkPhysicalDeviceType @@ -35,11 +37,13 @@ struct ggml_vk_device ggml_vk_current_device(void); // forward declaration typedef struct ggml_backend * ggml_backend_t; -GGML_API ggml_backend_t ggml_backend_kompute_init(int device); +GGML_BACKEND_API ggml_backend_t ggml_backend_kompute_init(int device); -GGML_API bool ggml_backend_is_kompute(ggml_backend_t backend); +GGML_BACKEND_API bool ggml_backend_is_kompute(ggml_backend_t backend); -GGML_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device); +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_kompute_reg(void); #ifdef __cplusplus } diff --git a/ggml/include/ggml-metal.h b/ggml/include/ggml-metal.h index d483cf1ac..669c1f84a 100644 --- a/ggml/include/ggml-metal.h +++ b/ggml/include/ggml-metal.h @@ -1,3 +1,5 @@ +// Note: this description is outdated +// // An interface allowing to compute ggml_cgraph with Metal // // This is a fully functional interface that extends ggml with GPU support for Apple devices. @@ -25,9 +27,6 @@ #include #include -// max memory buffers that can be mapped to the device -#define GGML_METAL_MAX_BUFFERS 64 - struct ggml_tensor; struct ggml_cgraph; @@ -40,27 +39,27 @@ extern "C" { // user-code should use only these functions // -GGML_API void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void * user_data); +GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void); -GGML_API ggml_backend_t ggml_backend_metal_init(void); +GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend); -GGML_API bool ggml_backend_is_metal(ggml_backend_t backend); +GGML_DEPRECATED( + GGML_BACKEND_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size), + "obsoleted by the new device interface - https://github.com/ggerganov/llama.cpp/pull/9713"); -GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size); +GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data); -GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb); - -GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data); - -GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void); +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void); // helper to check if the device supports a specific family // ideally, the user code should be doing these checks // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf -GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family); +GGML_BACKEND_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family); // capture all command buffers committed the next time `ggml_backend_graph_compute` is called -GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend); +GGML_BACKEND_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_metal_reg(void); #ifdef __cplusplus } diff --git a/ggml/include/ggml-opencl.h b/ggml/include/ggml-opencl.h new file mode 100644 index 000000000..6b6177135 --- /dev/null +++ b/ggml/include/ggml-opencl.h @@ -0,0 +1,26 @@ +#ifndef GGML_OPENCL_H +#define GGML_OPENCL_H + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// +// backend API +// +GGML_BACKEND_API ggml_backend_t ggml_backend_opencl_init(void); +GGML_BACKEND_API bool ggml_backend_is_opencl(ggml_backend_t backend); + +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type(void); +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_opencl_host_buffer_type(void); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_opencl_reg(void); + +#ifdef __cplusplus +} +#endif + +#endif // GGML_OPENCL_H diff --git a/ggml/include/ggml-opt.h b/ggml/include/ggml-opt.h new file mode 100644 index 000000000..eb5eab9de --- /dev/null +++ b/ggml/include/ggml-opt.h @@ -0,0 +1,216 @@ +// This file contains functionality for training models using GGML. +// It is not strictly needed vs. just vanilla GGML but it provides a more high-level interface for common needs such as datasets. +// At the bottom of this file especially there are relatively high-level functions that are suitable use or adaptation in user code. +// +// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de) + +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#include + +#ifdef __cplusplus +extern "C" { +#endif + + struct ggml_opt_dataset; + struct ggml_opt_context; + struct ggml_opt_result; + + typedef struct ggml_opt_dataset * ggml_opt_dataset_t; + typedef struct ggml_opt_context * ggml_opt_context_t; + typedef struct ggml_opt_result * ggml_opt_result_t; + + // ====== Loss ====== + + // built-in loss types, i.e. the built-in quantities minimized by the optimizer + // custom loss types can be defined via mean or sum which simply reduce the outputs for all datapoints to a single value + enum ggml_opt_loss_type { + GGML_OPT_LOSS_TYPE_MEAN, + GGML_OPT_LOSS_TYPE_SUM, + GGML_OPT_LOSS_TYPE_CROSS_ENTROPY, + GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR, + }; + + // ====== Dataset ====== + + GGML_API ggml_opt_dataset_t ggml_opt_dataset_init( + int64_t ne_datapoint, // number of elements per datapoint + int64_t ne_label, // number of elements per label + int64_t ndata, // total number of datapoints/labels + int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied) + GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset); + + // get underlying tensors that store the data + GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata] + GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata] + + // shuffle idata first datapoints from dataset with RNG from opt_ctx, shuffle all datapoints if idata is negative + GGML_API void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata); + + // get batch at position ibatch from dataset and copy the data to data_batch and labels_batch + GGML_API void ggml_opt_dataset_get_batch( + ggml_opt_dataset_t dataset, + struct ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch] + struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch] + int64_t ibatch); + + // ====== Model / Context ====== + + enum ggml_opt_build_type { + GGML_OPT_BUILD_TYPE_FORWARD, + GGML_OPT_BUILD_TYPE_GRAD, + GGML_OPT_BUILD_TYPE_OPT, + }; + + // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss + struct ggml_opt_optimizer_params { + // AdamW optimizer parameters + struct { + float alpha; // learning rate + float beta1; + float beta2; + float eps; // epsilon for numerical stability + float wd; // weight decay for AdamW, use 0.0f to disable + } adamw; + }; + + // callback to calculate optimizer parameters prior to a backward pass + // userdata can be used to pass arbitrary data + typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata); + + // returns the default optimizer params (constant) + // userdata is not used + GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata); + + // parameters for initializing a new optimization context + struct ggml_opt_params { + ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs + + struct ggml_context * ctx_compute; // created in user code, holds non-static tensors + + // the forward graph is defined by inputs and outputs + // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts + struct ggml_tensor * inputs; + struct ggml_tensor * outputs; + + enum ggml_opt_loss_type loss_type; + enum ggml_opt_build_type build_type; + + int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done + + ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters + void * get_opt_pars_ud; // userdata for calculating optimizer parameters + }; + + // get parameters for an optimization context with defaults set where possible + // parameters for which no sensible defaults exist are supplied as arguments to this function + GGML_API ggml_opt_params ggml_opt_default_params( + ggml_backend_sched_t backend_sched, + struct ggml_context * ctx_compute, + struct ggml_tensor * inputs, + struct ggml_tensor * outputs, + enum ggml_opt_loss_type loss_type); + + GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params); + GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx); + + // set gradients to zero, initilize loss, and optionally reset the optimizer + GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer); + + // get underlying tensors that store data + GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor + GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor + GGML_API struct ggml_tensor * ggml_opt_labels( ggml_opt_context_t opt_ctx); // labels to compare outputs against + GGML_API struct ggml_tensor * ggml_opt_loss( ggml_opt_context_t opt_ctx); // scalar tensor that contains the loss + GGML_API struct ggml_tensor * ggml_opt_pred( ggml_opt_context_t opt_ctx); // predictions made by outputs + GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels + + GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node); + + // ====== Optimization Result ====== + + GGML_API ggml_opt_result_t ggml_opt_result_init(); + GGML_API void ggml_opt_result_free(ggml_opt_result_t result); + GGML_API void ggml_opt_result_reset(ggml_opt_result_t result); + + // get data from result, uncertainties are optional and can be ignored by passing NULL + GGML_API void ggml_opt_result_ndata( ggml_opt_result_t result, int64_t * ndata); // writes 1 value, number of datapoints + GGML_API void ggml_opt_result_loss( ggml_opt_result_t result, double * loss, double * unc); // writes 1 value + GGML_API void ggml_opt_result_pred( ggml_opt_result_t result, int32_t * pred); // writes ndata values + GGML_API void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc); // writes 1 value + + // ====== Computation ====== + + // do forward pass, increment result if not NULL + GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result); + + // do forward pass, increment result if not NULL, do backward pass + GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result); + + // ############################################################################ + // ## The high-level functions start here. They do not depend on any private ## + // ## functions or structs and can be copied to and adapted for user code. ## + // ############################################################################ + + // ====== Intended Usage ====== + // + // 1. Select the appropriate loss for your problem. + // 2. Create a dataset and set the data for the "data" tensor. Also set the "labels" tensor if your loss needs them. + // Setting the shard size to 1 will be fine, it's the granularity with which data is shuffled/loaded (bigger values are faster). + // 3. Create a GGML graph for your model with no_alloc == true. Use two separate contexts for the tensors. + // The first context should contain the model parameters and inputs and be allocated statically in user code. + // The second context should contain all other tensors and will be (re)allocated automatically. + // Due to this automated allocation the data of the second context is not defined when accessed in user code. + // Note that the second dimension of the inputs/outputs are interpreted as the number of datapoints in those tensors. + // 4. Call ggml_opt_fit. If you need more control you can use ggml_opt_epoch instead. + + // signature for a callback while evaluating opt_ctx on dataset, called after an evaluation + typedef void (*ggml_opt_epoch_callback)( + bool train, // true after training evaluation, false after validation evaluation + ggml_opt_context_t opt_ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result, // result associated with the dataset subsection + int64_t ibatch, // number of batches that have been evaluated so far + int64_t ibatch_max, // total number of batches in this dataset subsection + int64_t t_start_us); // time at which the evaluation on the dataset subsection was started + + // do training on front of dataset, do evaluation only on back of dataset + GGML_API void ggml_opt_epoch( + ggml_opt_context_t opt_ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, // result to increment during training, ignored if NULL + ggml_opt_result_t result_eval, // result to increment during evaluation, ignored if NULL + int64_t idata_split, // data index at which to split training and evaluation + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval); + + // callback that prints a progress bar on stderr + GGML_API void ggml_opt_epoch_callback_progress_bar( + bool train, + ggml_opt_context_t opt_ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result, + int64_t ibatch, + int64_t ibatch_max, + int64_t t_start_us); + + // fit model defined by inputs and outputs to dataset + GGML_API void ggml_opt_fit( + ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs + ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs + ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch] + ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used + ggml_opt_dataset_t dataset, // dataset with data and optionally also labels + enum ggml_opt_loss_type loss_type, // loss to minimize + ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t) + int64_t nepoch, // how many times the dataset should be iterated over + int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs + float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f) + bool silent); // whether or not info prints to stderr should be suppressed + +#ifdef __cplusplus +} +#endif diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index aa144832a..ade6c3b0e 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -10,14 +10,18 @@ extern "C" { #define GGML_RPC_MAX_SERVERS 16 // backend API -GGML_API GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint); -GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend); +GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint); +GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend); -GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint); +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint); -GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total); +GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total); -GGML_API GGML_CALL void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem); +GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void); + +GGML_BACKEND_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint); #ifdef __cplusplus } diff --git a/ggml/include/ggml-sycl.h b/ggml/include/ggml-sycl.h index 43ab1519c..5ce349a88 100644 --- a/ggml/include/ggml-sycl.h +++ b/ggml/include/ggml-sycl.h @@ -17,26 +17,33 @@ extern "C" { #endif // backend API -GGML_API ggml_backend_t ggml_backend_sycl_init(int device); +GGML_BACKEND_API ggml_backend_t ggml_backend_sycl_init(int device); + +GGML_BACKEND_API bool ggml_backend_is_sycl(ggml_backend_t backend); // devide buffer -GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device); +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device); // split tensor buffer that splits matrices by rows across multiple devices -GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split); +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split); // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU -GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void); +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void); -GGML_API void ggml_backend_sycl_print_sycl_devices(void); -GGML_API GGML_CALL void ggml_sycl_get_gpu_list(int *id_list, int max_len); -GGML_API GGML_CALL void ggml_sycl_get_device_description(int device, char *description, size_t description_size); -GGML_API GGML_CALL int ggml_backend_sycl_get_device_count(); -GGML_API GGML_CALL void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total); +GGML_BACKEND_API void ggml_backend_sycl_print_sycl_devices(void); +GGML_BACKEND_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len); +GGML_BACKEND_API void ggml_backend_sycl_get_device_description(int device, + char *description, + size_t description_size); +GGML_BACKEND_API int ggml_backend_sycl_get_device_count(); +GGML_BACKEND_API void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total); // SYCL doesn't support registering host memory, keep here for reference -// GGML_API GGML_CALL bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size); -// GGML_API GGML_CALL void ggml_backend_sycl_unregister_host_buffer(void * buffer); +// GGML_BACKEND_API bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size); +// GGML_BACKEND_API void ggml_backend_sycl_unregister_host_buffer(void * buffer); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_sycl_reg(void); + #ifdef __cplusplus } #endif diff --git a/ggml/include/ggml-vulkan.h b/ggml/include/ggml-vulkan.h index af661c2d7..53cdba072 100644 --- a/ggml/include/ggml-vulkan.h +++ b/ggml/include/ggml-vulkan.h @@ -10,19 +10,21 @@ extern "C" { #define GGML_VK_NAME "Vulkan" #define GGML_VK_MAX_DEVICES 16 -GGML_API void ggml_vk_instance_init(void); +GGML_BACKEND_API void ggml_vk_instance_init(void); // backend API -GGML_API GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t dev_num); +GGML_BACKEND_API ggml_backend_t ggml_backend_vk_init(size_t dev_num); -GGML_API GGML_CALL bool ggml_backend_is_vk(ggml_backend_t backend); -GGML_API GGML_CALL int ggml_backend_vk_get_device_count(void); -GGML_API GGML_CALL void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size); -GGML_API GGML_CALL void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total); +GGML_BACKEND_API bool ggml_backend_is_vk(ggml_backend_t backend); +GGML_BACKEND_API int ggml_backend_vk_get_device_count(void); +GGML_BACKEND_API void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size); +GGML_BACKEND_API void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total); -GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num); +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num); // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU -GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void); +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_vk_reg(void); #ifdef __cplusplus } diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 536018b66..1198dc1fd 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -176,25 +176,15 @@ #ifdef GGML_SHARED # if defined(_WIN32) && !defined(__MINGW32__) # ifdef GGML_BUILD -# define GGML_API __declspec(dllexport) +# define GGML_API __declspec(dllexport) extern # else -# define GGML_API __declspec(dllimport) +# define GGML_API __declspec(dllimport) extern # endif # else -# define GGML_API __attribute__ ((visibility ("default"))) +# define GGML_API __attribute__ ((visibility ("default"))) extern # endif #else -# define GGML_API -#endif - -#ifdef GGML_MULTIPLATFORM -# if defined(_WIN32) -# define GGML_CALL -# else -# define GGML_CALL __attribute__((__ms_abi__)) -# endif -#else -# define GGML_CALL +# define GGML_API extern #endif // TODO: support for clang @@ -227,16 +217,17 @@ #define GGML_MAX_DIMS 4 #define GGML_MAX_PARAMS 2048 -#define GGML_MAX_CONTEXTS 64 #define GGML_MAX_SRC 10 -#ifndef GGML_MAX_NAME -#define GGML_MAX_NAME 64 #define GGML_MAX_N_THREADS 512 - -#endif #define GGML_MAX_OP_PARAMS 64 + +#ifndef GGML_MAX_NAME +# define GGML_MAX_NAME 64 +#endif + #define GGML_DEFAULT_N_THREADS 4 #define GGML_DEFAULT_GRAPH_SIZE 2048 + #if UINTPTR_MAX == 0xFFFFFFFF #define GGML_MEM_ALIGN 4 #else @@ -246,34 +237,30 @@ #define GGML_EXIT_SUCCESS 0 #define GGML_EXIT_ABORTED 1 -#define GGML_ROPE_TYPE_NEOX 2 - -#define GGUF_MAGIC "GGUF" - -#define GGUF_VERSION 3 - -#define GGUF_DEFAULT_ALIGNMENT 32 +#define GGML_ROPE_TYPE_NEOX 2 +#define GGML_ROPE_TYPE_MROPE 8 +#define GGML_ROPE_TYPE_VISION 24 #define GGML_UNUSED(x) (void)(x) #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) #ifndef NDEBUG -#define GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0) +# define GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0) #elif defined(__GNUC__) -#define GGML_UNREACHABLE() __builtin_unreachable() +# define GGML_UNREACHABLE() __builtin_unreachable() #elif defined(_MSC_VER) -#define GGML_UNREACHABLE() __assume(0) +# define GGML_UNREACHABLE() __assume(0) #else -#define GGML_UNREACHABLE() ((void) 0) +# define GGML_UNREACHABLE() ((void) 0) #endif #ifdef __cplusplus -#define GGML_NORETURN [[noreturn]] +# define GGML_NORETURN [[noreturn]] #elif defined(_MSC_VER) -#define GGML_NORETURN __declspec(noreturn) +# define GGML_NORETURN __declspec(noreturn) #else -#define GGML_NORETURN _Noreturn +# define GGML_NORETURN _Noreturn #endif #define GGML_ABORT(...) ggml_abort(__FILE__, __LINE__, __VA_ARGS__) @@ -338,7 +325,7 @@ extern "C" { }; // get ggml_status name string - GGML_API GGML_CALL const char * ggml_status_to_string(enum ggml_status status); + GGML_API const char * ggml_status_to_string(enum ggml_status status); // ieee 754-2008 half-precision float16 // todo: make this not an integral type @@ -358,6 +345,7 @@ extern "C" { struct ggml_object; struct ggml_context; + struct ggml_cgraph; // NOTE: always add types at the end of the enum to keep backward compatibility enum ggml_type { @@ -392,12 +380,15 @@ extern "C" { GGML_TYPE_F64 = 28, GGML_TYPE_IQ1_M = 29, GGML_TYPE_BF16 = 30, - GGML_TYPE_Q4_0_4_4 = 31, - GGML_TYPE_Q4_0_4_8 = 32, - GGML_TYPE_Q4_0_8_8 = 33, + // GGML_TYPE_Q4_0_4_4 = 31, support has been removed from gguf files + // GGML_TYPE_Q4_0_4_8 = 32, + // GGML_TYPE_Q4_0_8_8 = 33, GGML_TYPE_TQ1_0 = 34, GGML_TYPE_TQ2_0 = 35, - GGML_TYPE_COUNT, + // GGML_TYPE_IQ4_NL_4_4 = 36, + // GGML_TYPE_IQ4_NL_4_8 = 37, + // GGML_TYPE_IQ4_NL_8_8 = 38, + GGML_TYPE_COUNT = 39, }; // precision @@ -406,12 +397,6 @@ extern "C" { GGML_PREC_F32, }; - enum ggml_backend_type { - GGML_BACKEND_TYPE_CPU = 0, - GGML_BACKEND_TYPE_GPU = 10, - GGML_BACKEND_TYPE_GPU_SPLIT = 20, - }; - // model file types enum ggml_ftype { GGML_FTYPE_UNKNOWN = -1, @@ -438,9 +423,6 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors - GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors - GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors - GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors }; // available tensor operations: @@ -463,6 +445,7 @@ extern "C" { GGML_OP_SUM_ROWS, GGML_OP_MEAN, GGML_OP_ARGMAX, + GGML_OP_COUNT_EQUAL, GGML_OP_REPEAT, GGML_OP_REPEAT_BACK, GGML_OP_CONCAT, @@ -503,6 +486,7 @@ extern "C" { GGML_OP_POOL_2D_BACK, GGML_OP_UPSCALE, // nearest interpolate GGML_OP_PAD, + GGML_OP_PAD_REFLECT_1D, GGML_OP_ARANGE, GGML_OP_TIMESTEP_EMBEDDING, GGML_OP_ARGSORT, @@ -516,7 +500,8 @@ extern "C" { GGML_OP_WIN_UNPART, GGML_OP_GET_REL_POS, GGML_OP_ADD_REL_POS, - GGML_OP_RWKV_WKV, + GGML_OP_RWKV_WKV6, + GGML_OP_GATED_LINEAR_ATTN, GGML_OP_UNARY, @@ -533,6 +518,7 @@ extern "C" { GGML_OP_CROSS_ENTROPY_LOSS, GGML_OP_CROSS_ENTROPY_LOSS_BACK, + GGML_OP_OPT_STEP_ADAMW, GGML_OP_COUNT, }; @@ -563,37 +549,32 @@ extern "C" { }; enum ggml_log_level { - GGML_LOG_LEVEL_ERROR = 2, + GGML_LOG_LEVEL_NONE = 0, + GGML_LOG_LEVEL_DEBUG = 1, + GGML_LOG_LEVEL_INFO = 2, GGML_LOG_LEVEL_WARN = 3, - GGML_LOG_LEVEL_INFO = 4, - GGML_LOG_LEVEL_DEBUG = 5 + GGML_LOG_LEVEL_ERROR = 4, + GGML_LOG_LEVEL_CONT = 5, // continue previous log }; + // this tensor... enum ggml_tensor_flag { - GGML_TENSOR_FLAG_INPUT = 1, - GGML_TENSOR_FLAG_OUTPUT = 2, - GGML_TENSOR_FLAG_PARAM = 4, + GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph + GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph + GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters + GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) }; - // ggml object - struct ggml_object { - size_t offs; - size_t size; - - struct ggml_object * next; - - enum ggml_object_type type; - - char padding[4]; + struct ggml_init_params { + // memory pool + size_t mem_size; // bytes + void * mem_buffer; // if NULL, memory will be allocated internally + bool no_alloc; // don't allocate memory for the tensor data }; - static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object); - // n-dimensional tensor struct ggml_tensor { - enum ggml_type type; - - GGML_DEPRECATED(enum ggml_backend_type backend, "use the buffer type to find the storage location of the tensor"); + enum ggml_type type; struct ggml_backend_buffer * buffer; @@ -611,7 +592,6 @@ extern "C" { int32_t flags; - struct ggml_tensor * grad; struct ggml_tensor * src[GGML_MAX_SRC]; // source tensor and offset for views @@ -624,7 +604,7 @@ extern "C" { void * extra; // extra things e.g. for ggml-cuda.cu - // char padding[4]; + char padding[8]; }; static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); @@ -634,95 +614,6 @@ extern "C" { // If it returns true, the computation is aborted typedef bool (*ggml_abort_callback)(void * data); - // Scheduling priorities - enum ggml_sched_priority { - GGML_SCHED_PRIO_NORMAL, - GGML_SCHED_PRIO_MEDIUM, - GGML_SCHED_PRIO_HIGH, - GGML_SCHED_PRIO_REALTIME - }; - - // Threadpool params - // Use ggml_threadpool_params_default() or ggml_threadpool_params_init() to populate the defaults - struct ggml_threadpool_params { - bool cpumask[GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings) - int n_threads; // number of threads - enum ggml_sched_priority prio; // thread priority - uint32_t poll; // polling level (0 - no polling, 100 - aggressive polling) - bool strict_cpu; // strict cpu placement - bool paused; // start in paused state - }; - - struct ggml_threadpool; // forward declaration, see ggml.c - - typedef struct ggml_threadpool * ggml_threadpool_t; - - // the compute plan that needs to be prepared for ggml_graph_compute() - // since https://github.com/ggerganov/ggml/issues/287 - struct ggml_cplan { - size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()` - uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()` - - int n_threads; - struct ggml_threadpool * threadpool; - - // abort ggml_graph_compute when true - ggml_abort_callback abort_callback; - void * abort_callback_data; - }; - - enum ggml_cgraph_eval_order { - GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0, - GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT, - GGML_CGRAPH_EVAL_ORDER_COUNT - }; - - typedef uint32_t ggml_bitset_t; - - struct ggml_hash_set { - size_t size; - ggml_bitset_t * used; // whether or not the keys are in use i.e. set - struct ggml_tensor ** keys; // actual tensors in the set, keys[i] is only defined if ggml_bitset_get(used, i) - }; - - // computation graph - struct ggml_cgraph { - int size; - int n_nodes; - int n_leafs; - - struct ggml_tensor ** nodes; - struct ggml_tensor ** grads; - struct ggml_tensor ** leafs; - - struct ggml_hash_set visited_hash_set; - - enum ggml_cgraph_eval_order order; - }; - - // scratch buffer - struct ggml_scratch { - size_t offs; - size_t size; - void * data; - }; - - struct ggml_init_params { - // memory pool - size_t mem_size; // bytes - void * mem_buffer; // if NULL, memory will be allocated internally - bool no_alloc; // don't allocate memory for the tensor data - }; - - // numa strategies - enum ggml_numa_strategy { - GGML_NUMA_STRATEGY_DISABLED = 0, - GGML_NUMA_STRATEGY_DISTRIBUTE = 1, - GGML_NUMA_STRATEGY_ISOLATE = 2, - GGML_NUMA_STRATEGY_NUMACTL = 3, - GGML_NUMA_STRATEGY_MIRROR = 4, - GGML_NUMA_STRATEGY_COUNT - }; // // GUID @@ -745,52 +636,49 @@ extern "C" { // accepts a UTF-8 path, even on Windows GGML_API FILE * ggml_fopen(const char * fname, const char * mode); - GGML_API void ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems - GGML_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node - GGML_API void ggml_print_object (const struct ggml_object * obj); GGML_API void ggml_print_objects(const struct ggml_context * ctx); - GGML_API GGML_CALL int64_t ggml_nelements (const struct ggml_tensor * tensor); - GGML_API GGML_CALL int64_t ggml_nrows (const struct ggml_tensor * tensor); - GGML_API GGML_CALL size_t ggml_nbytes (const struct ggml_tensor * tensor); - GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN + GGML_API int64_t ggml_nelements (const struct ggml_tensor * tensor); + GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor); + GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor); + GGML_API size_t ggml_nbytes_pad(const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN - GGML_API GGML_CALL int64_t ggml_blck_size(enum ggml_type type); - GGML_API GGML_CALL size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block - GGML_API GGML_CALL size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row + GGML_API int64_t ggml_blck_size(enum ggml_type type); + GGML_API size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block + GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row GGML_DEPRECATED( GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float "use ggml_row_size() instead"); - GGML_API GGML_CALL const char * ggml_type_name(enum ggml_type type); - GGML_API GGML_CALL const char * ggml_op_name (enum ggml_op op); - GGML_API const char * ggml_op_symbol(enum ggml_op op); + GGML_API const char * ggml_type_name(enum ggml_type type); + GGML_API const char * ggml_op_name (enum ggml_op op); + GGML_API const char * ggml_op_symbol(enum ggml_op op); - GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op); - GGML_API GGML_CALL const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name + GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op); + GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name - GGML_API GGML_CALL size_t ggml_element_size(const struct ggml_tensor * tensor); + GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor); - GGML_API GGML_CALL bool ggml_is_quantized(enum ggml_type type); + GGML_API bool ggml_is_quantized(enum ggml_type type); // TODO: temporary until model loading of ggml examples is refactored GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype); - GGML_API GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor); - GGML_API GGML_CALL bool ggml_is_permuted (const struct ggml_tensor * tensor); - GGML_API GGML_CALL bool ggml_is_empty (const struct ggml_tensor * tensor); - GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor); - GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor); - GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor); - GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor); - GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars + GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor); + GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor); + GGML_API bool ggml_is_empty (const struct ggml_tensor * tensor); + GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor); + GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor); + GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor); + GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor); + GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars - GGML_API GGML_CALL bool ggml_is_contiguous (const struct ggml_tensor * tensor); - GGML_API GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous() - GGML_API GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1 - GGML_API GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2 + GGML_API bool ggml_is_contiguous (const struct ggml_tensor * tensor); + GGML_API bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous() + GGML_API bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1 + GGML_API bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2 GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1); GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1); @@ -804,12 +692,12 @@ extern "C" { // main - GGML_API struct ggml_context * ggml_init(struct ggml_init_params params); - GGML_API void ggml_free(struct ggml_context * ctx); + GGML_API struct ggml_context * ggml_init (struct ggml_init_params params); + GGML_API void ggml_reset(struct ggml_context * ctx); + GGML_API void ggml_free (struct ggml_context * ctx); GGML_API size_t ggml_used_mem(const struct ggml_context * ctx); - GGML_API size_t ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch); GGML_API bool ggml_get_no_alloc(struct ggml_context * ctx); GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc); @@ -849,8 +737,7 @@ extern "C" { int64_t ne2, int64_t ne3); - GGML_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value); - GGML_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value); + GGML_API void * ggml_new_buffer(struct ggml_context * ctx, size_t nbytes); GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src); GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src); @@ -860,35 +747,25 @@ extern "C" { GGML_API struct ggml_tensor * ggml_get_next_tensor (const struct ggml_context * ctx, struct ggml_tensor * tensor); GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name); - GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); - GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value); - GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value); - // Converts a flat index into coordinates - GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3); + GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3); - GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i); - GGML_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value); - - GGML_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3); - GGML_API void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value); - - GGML_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i); - GGML_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value); - - GGML_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3); - GGML_API void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value); + GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor); GGML_API void * ggml_get_data (const struct ggml_tensor * tensor); GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor); - GGML_API GGML_CALL enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor); - GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor); GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name); GGML_ATTRIBUTE_FORMAT(2, 3) GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...); + // Tensor flags + GGML_API void ggml_set_input(struct ggml_tensor * tensor); + GGML_API void ggml_set_output(struct ggml_tensor * tensor); + GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor); + GGML_API void ggml_set_loss(struct ggml_tensor * tensor); + // // operations on tensors with backpropagation // @@ -1039,6 +916,12 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // count number of equal elements in a and b + GGML_API struct ggml_tensor * ggml_count_equal( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + // if a is the same shape as b, and a is not parameter, return a // otherwise, return a new tensor: repeat(a) to fit in b GGML_API struct ggml_tensor * ggml_repeat( @@ -1445,14 +1328,14 @@ extern "C" { // supports 3D: a->ne[2] == b->ne[1] GGML_API struct ggml_tensor * ggml_get_rows( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * a, // data + struct ggml_tensor * b); // row indices GGML_API struct ggml_tensor * ggml_get_rows_back( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c); + struct ggml_tensor * a, // gradients of ggml_get_rows result + struct ggml_tensor * b, // row indices + struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape GGML_API struct ggml_tensor * ggml_diag( struct ggml_context * ctx, @@ -1501,16 +1384,20 @@ extern "C" { float scale, float max_bias); - GGML_API struct ggml_tensor * ggml_soft_max_back( + GGML_API struct ggml_tensor * ggml_soft_max_ext_back( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b, + float scale, + float max_bias); // in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_soft_max_back_inplace( + GGML_API struct ggml_tensor * ggml_soft_max_ext_back_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b, + float scale, + float max_bias); // rotary position embedding // if (mode & 1) - skip n_past elements (NOT SUPPORTED) @@ -1549,6 +1436,22 @@ extern "C" { float beta_fast, float beta_slow); + GGML_API struct ggml_tensor * ggml_rope_multi( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int sections[4], + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); + // in-place, returns view(a) GGML_API struct ggml_tensor * ggml_rope_ext_inplace( struct ggml_context * ctx, @@ -1596,16 +1499,16 @@ extern "C" { "use ggml_rope_ext_inplace instead"); // compute correction dims for YaRN RoPE scaling - GGML_CALL void ggml_rope_yarn_corr_dims( + GGML_API void ggml_rope_yarn_corr_dims( int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]); // rotary position embedding backward, i.e compute dx from dy // a - dy - GGML_API struct ggml_tensor * ggml_rope_back( + GGML_API struct ggml_tensor * ggml_rope_ext_back( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, + struct ggml_tensor * a, // gradients of ggml_rope result + struct ggml_tensor * b, // positions + struct ggml_tensor * c, // freq factors int n_dims, int mode, int n_ctx_orig, @@ -1616,6 +1519,23 @@ extern "C" { float beta_fast, float beta_slow); + GGML_API struct ggml_tensor * ggml_rope_multi_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int sections[4], + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); + + // clamp // in-place, returns view(a) GGML_API struct ggml_tensor * ggml_clamp( @@ -1652,17 +1572,6 @@ extern "C" { int d1, // dilation dimension 1 bool is_2D); - GGML_API struct ggml_tensor * ggml_conv_depthwise_2d( - struct ggml_context * ctx, - struct ggml_tensor * a, // convolution kernel - struct ggml_tensor * b, // data - int s0, // stride dimension 0 - int s1, // stride dimension 1 - int p0, // padding dimension 0 - int p1, // padding dimension 1 - int d0, // dilation dimension 0 - int d1); // dilation dimension 1 - GGML_API struct ggml_tensor * ggml_conv_1d( struct ggml_context * ctx, struct ggml_tensor * a, // convolution kernel @@ -1680,6 +1589,23 @@ extern "C" { int s, // stride int d); // dilation + // depthwise + // TODO: this is very likely wrong for some cases! - needs more testing + GGML_API struct ggml_tensor * ggml_conv_1d_dw( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s0, // stride + int p0, // padding + int d0); // dilation + + GGML_API struct ggml_tensor * ggml_conv_1d_dw_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s0, // stride + int d0); // dilation + GGML_API struct ggml_tensor * ggml_conv_transpose_1d( struct ggml_context * ctx, struct ggml_tensor * a, // convolution kernel @@ -1699,7 +1625,6 @@ extern "C" { int d0, // dilation dimension 0 int d1); // dilation dimension 1 - // kernel size is a->ne[0] x a->ne[1] // stride is equal to kernel size // padding is zero @@ -1726,6 +1651,18 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + // depthwise + GGML_API struct ggml_tensor * ggml_conv_2d_dw( + struct ggml_context * ctx, + struct ggml_tensor * a, // convolution kernel + struct ggml_tensor * b, // data + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1); // dilation dimension 1 + GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1799,6 +1736,13 @@ extern "C" { int p2, int p3); + // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c] + GGML_API struct ggml_tensor * ggml_pad_reflect_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int p0, + int p1); + // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151 // timesteps: [N,] // return: [N, dim] @@ -1852,6 +1796,9 @@ extern "C" { struct ggml_tensor * a, enum ggml_prec prec); + GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec( + const struct ggml_tensor * a); + // TODO: needs to be adapted to ggml_flash_attn_ext GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, @@ -1925,7 +1872,7 @@ extern "C" { struct ggml_tensor * pw, struct ggml_tensor * ph); - GGML_API struct ggml_tensor * ggml_rwkv_wkv( + GGML_API struct ggml_tensor * ggml_rwkv_wkv6( struct ggml_context * ctx, struct ggml_tensor * k, struct ggml_tensor * v, @@ -1934,6 +1881,15 @@ extern "C" { struct ggml_tensor * td, struct ggml_tensor * state); + GGML_API struct ggml_tensor * ggml_gated_linear_attn( + struct ggml_context * ctx, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * q, + struct ggml_tensor * g, + struct ggml_tensor * state, + float scale); + // custom operators typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *); @@ -2017,7 +1973,8 @@ extern "C" { typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata); typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata); - #define GGML_N_TASKS_MAX -1 +#define GGML_N_TASKS_MAX (-1) + // n_tasks == GGML_N_TASKS_MAX means to use max number of tasks GGML_API struct ggml_tensor * ggml_map_custom1( struct ggml_context * ctx, @@ -2070,62 +2027,59 @@ extern "C" { // loss function GGML_API struct ggml_tensor * ggml_cross_entropy_loss( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_context * ctx, + struct ggml_tensor * a, // logits + struct ggml_tensor * b); // labels GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c); + struct ggml_context * ctx, + struct ggml_tensor * a, // logits + struct ggml_tensor * b, // labels + struct ggml_tensor * c); // gradients of cross_entropy_loss result + + // AdamW optimizer step + // Paper: https://arxiv.org/pdf/1711.05101v3.pdf + // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html + GGML_API struct ggml_tensor * ggml_opt_step_adamw( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * grad, + struct ggml_tensor * m, + struct ggml_tensor * v, + struct ggml_tensor * adamw_params); // parameters such a the learning rate // // automatic differentiation // - GGML_API void ggml_set_param( - struct ggml_context * ctx, - struct ggml_tensor * tensor); - - - GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); - GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep); + GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); + GGML_API void ggml_build_backward_expand( + struct ggml_context * ctx_static, // context for static gradients (loss + gradient accumulation) + struct ggml_context * ctx_compute, // context for gradient computation + struct ggml_cgraph * cgraph, + bool accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static // graph allocation in a context - GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false - GGML_API struct ggml_cgraph * ggml_new_graph_custom (struct ggml_context * ctx, size_t size, bool grads); - GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph); - GGML_API struct ggml_cgraph ggml_graph_view (struct ggml_cgraph * cgraph, int i0, int i1); - GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst); - GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // zero grads - GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph); + GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false + GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads); + GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph); + GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst); + GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1 + GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph); + + GGML_API int ggml_graph_size (struct ggml_cgraph * cgraph); + GGML_API struct ggml_tensor * ggml_graph_node (struct ggml_cgraph * cgraph, int i); // if i < 0, returns nodes[n_nodes + i] + GGML_API struct ggml_tensor ** ggml_graph_nodes (struct ggml_cgraph * cgraph); + GGML_API int ggml_graph_n_nodes(struct ggml_cgraph * cgraph); + + GGML_API void ggml_graph_add_node(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); GGML_API size_t ggml_graph_overhead(void); GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads); - GGML_API struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads); - GGML_API void ggml_threadpool_params_init (struct ggml_threadpool_params *p, int n_threads); - GGML_API bool ggml_threadpool_params_match (const struct ggml_threadpool_params *p0, const struct ggml_threadpool_params *p1); - GGML_API struct ggml_threadpool* ggml_threadpool_new (struct ggml_threadpool_params * params); - GGML_API void ggml_threadpool_free (struct ggml_threadpool * threadpool); - GGML_API int ggml_threadpool_get_n_threads(struct ggml_threadpool * threadpool); - GGML_API void ggml_threadpool_pause (struct ggml_threadpool * threadpool); - GGML_API void ggml_threadpool_resume (struct ggml_threadpool * threadpool); - - // ggml_graph_plan() has to be called before ggml_graph_compute() - // when plan.work_size > 0, caller must allocate memory for plan.work_data - GGML_API struct ggml_cplan ggml_graph_plan( - const struct ggml_cgraph * cgraph, - int n_threads, /* = GGML_DEFAULT_N_THREADS */ - struct ggml_threadpool * threadpool /* = NULL */ ); - GGML_API enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan); - - // same as ggml_graph_compute() but the work data is allocated as a part of the context - // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data - GGML_API enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads); - - GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name); + GGML_API struct ggml_tensor * ggml_graph_get_tensor (const struct ggml_cgraph * cgraph, const char * name); + GGML_API struct ggml_tensor * ggml_graph_get_grad (const struct ggml_cgraph * cgraph, const struct ggml_tensor * node); + GGML_API struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node); GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname); GGML_API struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval); @@ -2136,197 +2090,14 @@ extern "C" { // dump the graph into a file using the dot format GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename); - // build gradient checkpointing backward graph gb for gf using provided checkpoints - // gb_tmp will contain original backward graph with rewritten backward process nodes, - // but without the second forward pass nodes. - GGML_API void ggml_build_backward_gradient_checkpointing( - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - struct ggml_cgraph * gb_tmp, - struct ggml_tensor * * checkpoints, - int n_checkpoints); - // - // optimization - // - - // optimization methods - enum ggml_opt_type { - GGML_OPT_TYPE_ADAM, - GGML_OPT_TYPE_LBFGS, - }; - - // linesearch methods - enum ggml_linesearch { - GGML_LINESEARCH_DEFAULT = 1, - - GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0, - GGML_LINESEARCH_BACKTRACKING_WOLFE = 1, - GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2, - }; - - // optimization return values - enum ggml_opt_result { - GGML_OPT_RESULT_OK = 0, - GGML_OPT_RESULT_DID_NOT_CONVERGE, - GGML_OPT_RESULT_NO_CONTEXT, - GGML_OPT_RESULT_INVALID_WOLFE, - GGML_OPT_RESULT_FAIL, - GGML_OPT_RESULT_CANCEL, - - GGML_LINESEARCH_FAIL = -128, - GGML_LINESEARCH_MINIMUM_STEP, - GGML_LINESEARCH_MAXIMUM_STEP, - GGML_LINESEARCH_MAXIMUM_ITERATIONS, - GGML_LINESEARCH_INVALID_PARAMETERS, - }; - - typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel); + // TODO these functions were sandwiched in the old optimization interface, is there a better place for them? typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data); - // optimization parameters - // - // see ggml.c (ggml_opt_default_params) for default values - // - struct ggml_opt_params { - enum ggml_opt_type type; + // Set callback for all future logging events. + // If this is not called, or NULL is supplied, everything is output on stderr. + GGML_API void ggml_log_set(ggml_log_callback log_callback, void * user_data); - size_t graph_size; - - int n_threads; - - // delta-based convergence test - // - // if past == 0 - disabled - // if past > 0: - // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|) - // - int past; - float delta; - - // maximum number of iterations without improvement - // - // if 0 - disabled - // if > 0: - // assume convergence if no cost improvement in this number of iterations - // - int max_no_improvement; - - bool print_forward_graph; - bool print_backward_graph; - - int n_gradient_accumulation; - - // ADAM parameters - struct { - int n_iter; - - float sched; // schedule multiplier (fixed, decay or warmup) - float decay; // weight decay for AdamW, use 0.0f to disable - int decay_min_ndim; // minimum number of tensor dimension to apply weight decay - float alpha; // learning rate - float beta1; - float beta2; - float eps; // epsilon for numerical stability - float eps_f; // epsilon for convergence test - float eps_g; // epsilon for convergence test - float gclip; // gradient clipping - } adam; - - // LBFGS parameters - struct { - int m; // number of corrections to approximate the inv. Hessian - int n_iter; - int max_linesearch; - - float eps; // convergence tolerance - float ftol; // line search tolerance - float wolfe; - float min_step; - float max_step; - - enum ggml_linesearch linesearch; - } lbfgs; - }; - - struct ggml_opt_context { - struct ggml_context * ctx; - struct ggml_opt_params params; - - int iter; - int64_t nx; // number of parameter elements - - bool just_initialized; - - float loss_before; - float loss_after; - - struct { - struct ggml_tensor * g; // current gradient - struct ggml_tensor * m; // first moment - struct ggml_tensor * v; // second moment - struct ggml_tensor * pf; // past function values - float fx_best; - float fx_prev; - int n_no_improvement; - } adam; - - struct { - struct ggml_tensor * x; // current parameters - struct ggml_tensor * xp; // previous parameters - struct ggml_tensor * g; // current gradient - struct ggml_tensor * gp; // previous gradient - struct ggml_tensor * d; // search direction - struct ggml_tensor * pf; // past function values - struct ggml_tensor * lmal; // the L-BFGS memory alpha - struct ggml_tensor * lmys; // the L-BFGS memory ys - struct ggml_tensor * lms; // the L-BFGS memory s - struct ggml_tensor * lmy; // the L-BFGS memory y - float fx_best; - float step; - int j; - int k; - int end; - int n_no_improvement; - } lbfgs; - }; - - GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type); - - // optimize the function defined by the tensor f - GGML_API enum ggml_opt_result ggml_opt( - struct ggml_context * ctx, - struct ggml_opt_params params, - struct ggml_tensor * f); - - // initialize optimizer context - GGML_API void ggml_opt_init( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_opt_params params, - int64_t nx); - - // continue optimizing the function defined by the tensor f - GGML_API enum ggml_opt_result ggml_opt_resume( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_tensor * f); - - // continue optimizing the function defined by the tensor f - GGML_API enum ggml_opt_result ggml_opt_resume_g( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_tensor * f, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - ggml_opt_callback callback, - void * callback_data); - - // - // tensor flags - // - GGML_API void ggml_set_input(struct ggml_tensor * tensor); - GGML_API void ggml_set_output(struct ggml_tensor * tensor); + GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); // // quantization @@ -2357,205 +2128,65 @@ extern "C" { int64_t n_per_row, const float * imatrix); - // - // gguf - // - - enum gguf_type { - GGUF_TYPE_UINT8 = 0, - GGUF_TYPE_INT8 = 1, - GGUF_TYPE_UINT16 = 2, - GGUF_TYPE_INT16 = 3, - GGUF_TYPE_UINT32 = 4, - GGUF_TYPE_INT32 = 5, - GGUF_TYPE_FLOAT32 = 6, - GGUF_TYPE_BOOL = 7, - GGUF_TYPE_STRING = 8, - GGUF_TYPE_ARRAY = 9, - GGUF_TYPE_UINT64 = 10, - GGUF_TYPE_INT64 = 11, - GGUF_TYPE_FLOAT64 = 12, - GGUF_TYPE_COUNT, // marks the end of the enum - }; - - struct gguf_context; - - struct gguf_init_params { - bool no_alloc; - - // if not NULL, create a ggml_context and allocate the tensor data in it - struct ggml_context ** ctx; - }; - - GGML_API struct gguf_context * gguf_init_empty(void); - GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params); - //GGML_API struct gguf_context * gguf_init_from_buffer(..); - - GGML_API void gguf_free(struct gguf_context * ctx); - - GGML_API const char * gguf_type_name(enum gguf_type type); - - GGML_API int gguf_get_version (const struct gguf_context * ctx); - GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx); - GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx); - GGML_API void * gguf_get_data (const struct gguf_context * ctx); - - GGML_API int gguf_get_n_kv(const struct gguf_context * ctx); - GGML_API int gguf_find_key(const struct gguf_context * ctx, const char * key); - GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int key_id); - - GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int key_id); - GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id); - - // will abort if the wrong type is used for the key - GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int key_id); - GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int key_id); - GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int key_id); - GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int key_id); - GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int key_id); - GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int key_id); - GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int key_id); - GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int key_id); - GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int key_id); - GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id); - GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id); - GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id); - GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id); - GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id); - GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id); - GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i); - - GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx); - GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name); - GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i); - GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i); - GGML_API enum ggml_type gguf_get_tensor_type (const struct gguf_context * ctx, int i); - - // removes key if it exists - GGML_API void gguf_remove_key(struct gguf_context * ctx, const char * key); - - // overrides existing values or adds a new one - GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val); - GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val); - GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val); - GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val); - GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val); - GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val); - GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val); - GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val); - GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val); - GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val); - GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val); - GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val); - GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n); - GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n); - - // set or add KV pairs from another context - GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src); - - // manage tensor info - GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor); - GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type); - GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size); - - // writing gguf files can be done in 2 ways: - // - // - write the entire gguf_context to a binary file in a single pass: - // - // gguf_write_to_file(ctx, fname); - // - // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data: - // - // FILE * f = fopen(fname, "wb"); - // fseek(f, gguf_get_meta_size(ctx), SEEK_SET); - // fwrite(f, ...); - // void * data = gguf_meta_get_meta_data(ctx); - // fseek(f, 0, SEEK_SET); - // fwrite(f, data, gguf_get_meta_size(ctx)); - // free(data); - // fclose(f); - // - - // write the entire context to a binary file - GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta); - - // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding - GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx); - GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data); - - // - // system info - // - - GGML_API int ggml_cpu_has_avx (void); - GGML_API int ggml_cpu_has_avx_vnni (void); - GGML_API int ggml_cpu_has_avx2 (void); - GGML_API int ggml_cpu_has_avx512 (void); - GGML_API int ggml_cpu_has_avx512_vbmi(void); - GGML_API int ggml_cpu_has_avx512_vnni(void); - GGML_API int ggml_cpu_has_avx512_bf16(void); - GGML_API int ggml_cpu_has_fma (void); - GGML_API int ggml_cpu_has_neon (void); - GGML_API int ggml_cpu_has_sve (void); - GGML_API int ggml_cpu_has_arm_fma (void); - GGML_API int ggml_cpu_has_metal (void); - GGML_API int ggml_cpu_has_f16c (void); - GGML_API int ggml_cpu_has_fp16_va (void); - GGML_API int ggml_cpu_has_wasm_simd (void); - GGML_API int ggml_cpu_has_blas (void); - GGML_API int ggml_cpu_has_cuda (void); - GGML_API int ggml_cpu_has_vulkan (void); - GGML_API int ggml_cpu_has_kompute (void); - GGML_API int ggml_cpu_has_gpublas (void); - GGML_API int ggml_cpu_has_sse3 (void); - GGML_API int ggml_cpu_has_ssse3 (void); - GGML_API int ggml_cpu_has_sycl (void); - GGML_API int ggml_cpu_has_rpc (void); - GGML_API int ggml_cpu_has_vsx (void); - GGML_API int ggml_cpu_has_matmul_int8(void); - GGML_API int ggml_cpu_has_cann (void); - GGML_API int ggml_cpu_has_llamafile (void); - - // - // Internal types and functions exposed for tests and benchmarks - // - -#ifdef __cplusplus -// restrict not standard in C++ -#define GGML_RESTRICT +#ifdef __cplusplus + // restrict not standard in C++ +# if defined(__GNUC__) +# define GGML_RESTRICT __restrict__ +# elif defined(__clang__) +# define GGML_RESTRICT __restrict +# elif defined(_MSC_VER) +# define GGML_RESTRICT __restrict +# else +# define GGML_RESTRICT +# endif #else -#define GGML_RESTRICT restrict +# define GGML_RESTRICT restrict #endif typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); - typedef void (*ggml_from_float_to_mat_t) - (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bs); - typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx, - const void * GGML_RESTRICT y, size_t by, int nrc); - typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, - const void * GGML_RESTRICT y, int nr, int nc); - typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, - const void * GGML_RESTRICT y, int nr, int nc); - typedef struct { + struct ggml_type_traits { const char * type_name; int64_t blck_size; int64_t blck_size_interleave; // interleave elements in blocks size_t type_size; bool is_quantized; ggml_to_float_t to_float; - ggml_from_float_t from_float; ggml_from_float_t from_float_ref; - ggml_from_float_to_mat_t from_float_to_mat; - ggml_vec_dot_t vec_dot; - enum ggml_type vec_dot_type; - int64_t nrows; // number of rows to process simultaneously - int64_t ncols; // number of columns to process simultaneously - ggml_gemv_t gemv; - ggml_gemm_t gemm; - } ggml_type_traits_t; + }; - GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type); + GGML_API const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type); + + // ggml threadpool + // TODO: currently, only a few functions are in the base ggml API, while the rest are in the CPU backend + // the goal should be to create an API that other backends can use move everything to the ggml base + + // scheduling priorities + enum ggml_sched_priority { + GGML_SCHED_PRIO_NORMAL, + GGML_SCHED_PRIO_MEDIUM, + GGML_SCHED_PRIO_HIGH, + GGML_SCHED_PRIO_REALTIME + }; + + // threadpool params + // Use ggml_threadpool_params_default() or ggml_threadpool_params_init() to populate the defaults + struct ggml_threadpool_params { + bool cpumask[GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings) + int n_threads; // number of threads + enum ggml_sched_priority prio; // thread priority + uint32_t poll; // polling level (0 - no polling, 100 - aggressive polling) + bool strict_cpu; // strict cpu placement + bool paused; // start in paused state + }; + + struct ggml_threadpool; // forward declaration, see ggml.c + + typedef struct ggml_threadpool * ggml_threadpool_t; + + GGML_API struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads); + GGML_API void ggml_threadpool_params_init (struct ggml_threadpool_params * p, int n_threads); + GGML_API bool ggml_threadpool_params_match (const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1); #ifdef __cplusplus } diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h new file mode 100644 index 000000000..79ee20206 --- /dev/null +++ b/ggml/include/gguf.h @@ -0,0 +1,202 @@ +// This file contains functionality related to "GGUF" files, the binary file format used by ggml. +// GGUF files have the following structure: +// +// 1. File magic "GGUF" (4 bytes). +// 2. File version (uint32_t). +// 3. Number of ggml tensors in file (int64_t). +// 4. Number of key-value-pairs in file (int64_t). +// 5. For each KV pair: +// 1. The key (string). +// 2. The value type (gguf_type). +// 3a. If the value type is GGUF_TYPE_ARRAY: +// 1. The type of the array (gguf_type). +// 2. The number of elements in the array (uint64_t). +// 3. The binary representation of each element in the array. +// 3b. Otherwise: +// 1. The binary representation of the value. +// 6. For each ggml tensor: +// 1. The tensor name (string). +// 2. The number of dimensions of the tensor (uint32_t). +// 3. For each dimension: +// 1. The size of the tensor in the dimension (int64_t). +// 4. The tensor data type (ggml_type). +// 5. The tensor data offset in the tensor data binary blob (uint64_t). +// 7. The tensor data binary blob (optional, aligned). +// +// Strings are serialized as the string length (uint64_t) followed by the C string without the null terminator. +// All enums are stored as int32_t. +// All bool values are stored as int8_t. +// If the special key "general.alignment" (uint32_t) is defined it is used for alignment, +// otherwise GGUF_DEFAULT_ALIGNMENT is used. +// +// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de) + +#pragma once + +#include "ggml.h" + +#include +#include + +#define GGUF_MAGIC "GGUF" +#define GGUF_VERSION 3 + +#define GGUF_KEY_GENERAL_ALIGNMENT "general.alignment" + +#define GGUF_DEFAULT_ALIGNMENT 32 + +#ifdef __cplusplus +extern "C" { +#endif + + // types that can be stored as GGUF KV data + enum gguf_type { + GGUF_TYPE_UINT8 = 0, + GGUF_TYPE_INT8 = 1, + GGUF_TYPE_UINT16 = 2, + GGUF_TYPE_INT16 = 3, + GGUF_TYPE_UINT32 = 4, + GGUF_TYPE_INT32 = 5, + GGUF_TYPE_FLOAT32 = 6, + GGUF_TYPE_BOOL = 7, + GGUF_TYPE_STRING = 8, + GGUF_TYPE_ARRAY = 9, + GGUF_TYPE_UINT64 = 10, + GGUF_TYPE_INT64 = 11, + GGUF_TYPE_FLOAT64 = 12, + GGUF_TYPE_COUNT, // marks the end of the enum + }; + + struct gguf_context; + + struct gguf_init_params { + bool no_alloc; + + // if not NULL, create a ggml_context and allocate the tensor data in it + struct ggml_context ** ctx; + }; + + GGML_API struct gguf_context * gguf_init_empty(void); + GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params); + //GGML_API struct gguf_context * gguf_init_from_buffer(..); + + GGML_API void gguf_free(struct gguf_context * ctx); + + GGML_API const char * gguf_type_name(enum gguf_type type); + + GGML_API uint32_t gguf_get_version (const struct gguf_context * ctx); + GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx); + GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx); + + GGML_API int64_t gguf_get_n_kv(const struct gguf_context * ctx); + GGML_API int64_t gguf_find_key(const struct gguf_context * ctx, const char * key); // returns -1 if key is not found + GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int64_t key_id); + + GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int64_t key_id); + GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id); + + // will abort if the wrong type is used for the key + GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int64_t key_id); + GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int64_t key_id); + GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int64_t key_id); + GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int64_t key_id); + GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int64_t key_id); + GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int64_t key_id); + GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int64_t key_id); + GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int64_t key_id); + GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int64_t key_id); + GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int64_t key_id); + GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int64_t key_id); + GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int64_t key_id); + GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id); + GGML_API size_t gguf_get_arr_n (const struct gguf_context * ctx, int64_t key_id); + + // get raw pointer to the first element of the array with the given key_id + // for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference) + GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id); + + // get ith C string from array with given key_id + GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i); + + GGML_API int64_t gguf_get_n_tensors (const struct gguf_context * ctx); + GGML_API int64_t gguf_find_tensor (const struct gguf_context * ctx, const char * name); // returns -1 if the tensor is not found + GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int64_t tensor_id); + GGML_API const char * gguf_get_tensor_name (const struct gguf_context * ctx, int64_t tensor_id); + GGML_API enum ggml_type gguf_get_tensor_type (const struct gguf_context * ctx, int64_t tensor_id); + GGML_API size_t gguf_get_tensor_size (const struct gguf_context * ctx, int64_t tensor_id); + + // removes key if it exists, returns id that the key had prior to removal (-1 if it didn't exist) + GGML_API int64_t gguf_remove_key(struct gguf_context * ctx, const char * key); + + // overrides an existing KV pair or adds a new one, the new KV pair is always at the back + GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val); + GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val); + GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val); + GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val); + GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val); + GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val); + GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val); + GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val); + GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val); + GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val); + GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val); + GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val); + + // creates a new array with n elements of the given type and copies the corresponding number of bytes from data + GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, size_t n); + + // creates a new array with n strings and copies the corresponding strings from data + GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, size_t n); + + // set or add KV pairs from another context + GGML_API void gguf_set_kv(struct gguf_context * ctx, const struct gguf_context * src); + + // add tensor to GGUF context, tensor name must be unique + GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor); + + // after changing a tensor's type, the offsets of all tensors with higher indices are immediately recalculated + // in such a way that the tensor data remains as one contiguous block (except for padding) + GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type); + + // assumes that at least gguf_get_tensor_size bytes can be read from data + GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data); + + // writing gguf files can be done in 3 ways: + // + // - write the entire gguf_context to a binary file in a single pass: + // + // gguf_write_to_file(ctx, fname, /*only_meta =*/ false); + // + // - write only the meta data to a file, then re-open the file and append the tensor data: + // + // gguf_write_to_file(ctx, fname, /*only_meta =*/ true); + // FILE * f = fopen(fname, "ab"); + // fwrite(f, ...); // write tensor data + // fclose(f); + // + // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data: + // + // FILE * f = fopen(fname, "wb"); + // const size_t size_meta = gguf_get_meta_size(ctx); + // fseek(f, size_meta, SEEK_SET); + // fwrite(f, ...); // write tensor data + // void * data = malloc(size_meta); + // gguf_get_meta_data(ctx, data); + // rewind(f); + // fwrite(data, 1, data, f); + // free(data); + // fclose(f); + // + + // write the entire context to a binary file + GGML_API bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta); + + // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding + GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx); + + // writes the meta data to pointer "data" + GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index cd2dcd066..566709135 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -1,7 +1,5 @@ include(CheckCXXCompilerFlag) -unset(GGML_CDEF_PUBLIC) - add_compile_definitions(GGML_SCHED_MAX_COPIES=${GGML_SCHED_MAX_COPIES}) # enable libstdc++ assertions for debug builds @@ -26,879 +24,7 @@ if (NOT MSVC) endif() endif() -if (APPLE AND GGML_ACCELERATE) - find_library(ACCELERATE_FRAMEWORK Accelerate) - if (ACCELERATE_FRAMEWORK) - message(STATUS "Accelerate framework found") - - add_compile_definitions(GGML_USE_ACCELERATE) - add_compile_definitions(ACCELERATE_NEW_LAPACK) - add_compile_definitions(ACCELERATE_LAPACK_ILP64) - - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK}) - else() - message(WARNING "Accelerate framework not found") - endif() -endif() - -if (GGML_METAL) - find_library(FOUNDATION_LIBRARY Foundation REQUIRED) - find_library(METAL_FRAMEWORK Metal REQUIRED) - find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) - - message(STATUS "Metal framework found") - set(GGML_HEADERS_METAL ../include/ggml-metal.h) - set(GGML_SOURCES_METAL ggml-metal.m) - - list(APPEND GGML_CDEF_PUBLIC GGML_USE_METAL) - if (GGML_METAL_NDEBUG) - add_compile_definitions(GGML_METAL_NDEBUG) - endif() - - # copy ggml-common.h and ggml-metal.metal to bin directory - configure_file(ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) - configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) - - if (GGML_METAL_EMBED_LIBRARY) - enable_language(ASM) - - add_compile_definitions(GGML_METAL_EMBED_LIBRARY) - - set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/ggml-common.h") - set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal") - - file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated") - - # merge ggml-common.h and ggml-metal.metal into a single file - set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s") - set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") - - add_custom_command( - OUTPUT ${METALLIB_EMBED_ASM} - COMMAND echo "Embedding Metal library" - COMMAND sed -e '/\#include \"ggml-common.h\"/r ${METALLIB_COMMON}' -e '/\#include \"ggml-common.h\"/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED} - COMMAND echo ".section __DATA,__ggml_metallib" > ${METALLIB_EMBED_ASM} - COMMAND echo ".globl _ggml_metallib_start" >> ${METALLIB_EMBED_ASM} - COMMAND echo "_ggml_metallib_start:" >> ${METALLIB_EMBED_ASM} - COMMAND echo ".incbin \\\"${METALLIB_SOURCE_EMBED}\\\"" >> ${METALLIB_EMBED_ASM} - COMMAND echo ".globl _ggml_metallib_end" >> ${METALLIB_EMBED_ASM} - COMMAND echo "_ggml_metallib_end:" >> ${METALLIB_EMBED_ASM} - DEPENDS ggml-metal.metal ggml-common.h - COMMENT "Generate assembly for embedded Metal library" - ) - - set(GGML_SOURCES_METAL ${GGML_SOURCES_METAL} ${METALLIB_EMBED_ASM}) - else() - if (GGML_METAL_SHADER_DEBUG) - # custom command to do the following: - # xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air - # xcrun -sdk macosx metallib ggml-metal.air -o default.metallib - # - # note: this is the only way I found to disable fast-math in Metal. it's ugly, but at least it works - # disabling fast math is needed in order to pass tests/test-backend-ops - # note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1 - # note: unfortunately, we have to call it default.metallib instead of ggml.metallib - # ref: https://github.com/ggerganov/whisper.cpp/issues/1720 - set(XC_FLAGS -fno-fast-math -fno-inline -g) - else() - set(XC_FLAGS -O3) - endif() - - # Append macOS metal versioning flags - if (GGML_METAL_MACOSX_VERSION_MIN) - message(STATUS "Adding -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN} flag to metal compilation") - list (APPEND XC_FLAGS -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN}) - endif() - - if (GGML_METAL_STD) - message(STATUS "Adding -std=${GGML_METAL_STD} flag to metal compilation") - list (APPEND XC_FLAGS -std=${GGML_METAL_STD}) - endif() - - add_custom_command( - OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib - COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air - COMMAND xcrun -sdk macosx metallib ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib - COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air - COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h - COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal - DEPENDS ggml-metal.metal ggml-common.h - COMMENT "Compiling Metal kernels" - ) - - add_custom_target( - ggml-metal ALL - DEPENDS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib - ) - endif() # GGML_METAL_EMBED_LIBRARY - - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} - ${FOUNDATION_LIBRARY} - ${METAL_FRAMEWORK} - ${METALKIT_FRAMEWORK} - ) -endif() - -if (GGML_MUSA) - set(CMAKE_C_COMPILER clang) - set(CMAKE_C_EXTENSIONS OFF) - set(CMAKE_CXX_COMPILER clang++) - set(CMAKE_CXX_EXTENSIONS OFF) - - set(GGML_CUDA ON) - - list(APPEND GGML_CDEF_PUBLIC GGML_USE_MUSA) -endif() - -if (GGML_OPENMP) - find_package(OpenMP) - if (OpenMP_FOUND) - message(STATUS "OpenMP found") - - add_compile_definitions(GGML_USE_OPENMP) - - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} OpenMP::OpenMP_C OpenMP::OpenMP_CXX) - - if (GGML_MUSA) - set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} "/usr/lib/llvm-10/include/openmp") - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} "/usr/lib/llvm-10/lib/libomp.so") - endif() - else() - message(WARNING "OpenMP not found") - endif() -endif() - -if (GGML_BLAS) - if (GGML_STATIC) - set(BLA_STATIC ON) - endif() - #if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.22) - # set(BLA_SIZEOF_INTEGER 8) - #endif() - - set(BLA_VENDOR ${GGML_BLAS_VENDOR}) - find_package(BLAS) - - if (BLAS_FOUND) - message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}") - - if (("${BLAS_INCLUDE_DIRS}" STREQUAL "") AND NOT (${GGML_BLAS_VENDOR} MATCHES "Apple")) - # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake. - # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268 - find_package(PkgConfig REQUIRED) - if (${GGML_BLAS_VENDOR} MATCHES "Generic") - pkg_check_modules(DepBLAS REQUIRED blas) - elseif (${GGML_BLAS_VENDOR} MATCHES "OpenBLAS") - # As of openblas v0.3.22, the 64-bit is named openblas64.pc - pkg_check_modules(DepBLAS openblas64) - if (NOT DepBLAS_FOUND) - pkg_check_modules(DepBLAS REQUIRED openblas) - endif() - elseif (${GGML_BLAS_VENDOR} MATCHES "FLAME") - pkg_check_modules(DepBLAS REQUIRED blis) - elseif (${GGML_BLAS_VENDOR} MATCHES "ATLAS") - pkg_check_modules(DepBLAS REQUIRED blas-atlas) - elseif (${GGML_BLAS_VENDOR} MATCHES "FlexiBLAS") - pkg_check_modules(DepBLAS REQUIRED flexiblas_api) - elseif (${GGML_BLAS_VENDOR} MATCHES "Intel") - # all Intel* libraries share the same include path - pkg_check_modules(DepBLAS REQUIRED mkl-sdl) - elseif (${GGML_BLAS_VENDOR} MATCHES "NVHPC") - # this doesn't provide pkg-config - # suggest to assign BLAS_INCLUDE_DIRS on your own - if ("${NVHPC_VERSION}" STREQUAL "") - message(WARNING "Better to set NVHPC_VERSION") - else() - set(DepBLAS_FOUND ON) - set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include") - endif() - endif() - if (DepBLAS_FOUND) - set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS}) - else() - message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically" - " detected by pkgconfig, trying to find cblas.h from possible paths...") - find_path(BLAS_INCLUDE_DIRS - NAMES cblas.h - HINTS - /usr/include - /usr/local/include - /usr/include/openblas - /opt/homebrew/opt/openblas/include - /usr/local/opt/openblas/include - /usr/include/x86_64-linux-gnu/openblas/include - ) - endif() - endif() - - message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}") - - add_compile_options(${BLAS_LINKER_FLAGS}) - - list(APPEND GGML_CDEF_PUBLIC GGML_USE_BLAS) - - if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel")) - add_compile_definitions(GGML_BLAS_USE_MKL) - endif() - - set(GGML_HEADERS_BLAS ../include/ggml-blas.h) - set(GGML_SOURCES_BLAS ggml-blas.cpp) - - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} ${BLAS_LIBRARIES}) - set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS}) - else() - message(WARNING "BLAS not found, please refer to " - "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" - " to set correct GGML_BLAS_VENDOR") - endif() -endif() - -if (GGML_LLAMAFILE) - message(STATUS "Using llamafile") - - add_compile_definitions(GGML_USE_LLAMAFILE) - - set(GGML_HEADERS_LLAMAFILE llamafile/sgemm.h) - set(GGML_SOURCES_LLAMAFILE llamafile/sgemm.cpp) -endif() - -if (GGML_CUDA) - cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES - - if (GGML_MUSA) - list(APPEND CMAKE_MODULE_PATH "/usr/local/musa/cmake/") - find_package(MUSAToolkit) - set(CUDAToolkit_FOUND ${MUSAToolkit_FOUND}) - else() - find_package(CUDAToolkit) - endif() - - if (CUDAToolkit_FOUND) - message(STATUS "CUDA found") - - if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - # 52 == lowest CUDA 12 standard - # 60 == FP16 CUDA intrinsics - # 61 == integer CUDA intrinsics - # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster - if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) - set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75") - else() - set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75") - #set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work - endif() - endif() - message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") - - if (GGML_MUSA) - set(CMAKE_CUDA_COMPILER ${MUSAToolkit_MCC_EXECUTABLE}) - else() - enable_language(CUDA) - endif() - - file(GLOB GGML_HEADERS_CUDA "ggml-cuda/*.cuh") - list(APPEND GGML_HEADERS_CUDA "../include/ggml-cuda.h") - - file(GLOB GGML_SOURCES_CUDA "ggml-cuda/*.cu") - list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu") - file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - - if (GGML_CUDA_FA_ALL_QUANTS) - file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) - else() - file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - endif() - - list(APPEND GGML_CDEF_PUBLIC GGML_USE_CUDA) - - add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X}) - add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y}) - add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER}) - add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE}) - - if (GGML_CUDA_USE_GRAPHS) - add_compile_definitions(GGML_CUDA_USE_GRAPHS) - endif() - - if (GGML_CUDA_FORCE_DMMV) - add_compile_definitions(GGML_CUDA_FORCE_DMMV) - endif() - - if (GGML_CUDA_FORCE_MMQ) - add_compile_definitions(GGML_CUDA_FORCE_MMQ) - endif() - - if (GGML_CUDA_FORCE_CUBLAS) - add_compile_definitions(GGML_CUDA_FORCE_CUBLAS) - endif() - - if (GGML_CUDA_NO_VMM) - add_compile_definitions(GGML_CUDA_NO_VMM) - endif() - - if (DEFINED GGML_CUDA_DMMV_Y) - add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_DMMV_Y}) # for backwards compatibility - endif() - - if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) - add_compile_definitions(GGML_CUDA_F16) - endif() - - if (GGML_CUDA_NO_PEER_COPY) - add_compile_definitions(GGML_CUDA_NO_PEER_COPY) - endif() - - if (GGML_MUSA) - set_source_files_properties(${GGML_SOURCES_CUDA} PROPERTIES LANGUAGE CXX) - foreach(SOURCE ${GGML_SOURCES_CUDA}) - set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_22") - endforeach() - endif() - - if (GGML_STATIC) - if (WIN32) - # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt) - else () - if (GGML_MUSA) - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart_static MUSA::mublas_static) - else() - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) - endif() - endif() - else() - if (GGML_MUSA) - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart MUSA::mublas) - else() - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) - endif() - endif() - - if (GGML_CUDA_NO_VMM) - # No VMM requested, no need to link directly with the cuda driver lib (libcuda.so) - else() - if (GGML_MUSA) - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musa_driver) # required by muDeviceGetAttribute(), muMemGetAllocationGranularity(...), ... - else() - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ... - endif() - endif() - else() - message(WARNING "CUDA not found") - endif() -endif() - -if (GGML_HIPBLAS) - if (NOT EXISTS $ENV{ROCM_PATH}) - if (NOT EXISTS /opt/rocm) - set(ROCM_PATH /usr) - else() - set(ROCM_PATH /opt/rocm) - endif() - else() - set(ROCM_PATH $ENV{ROCM_PATH}) - endif() - - list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) - list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake") - - # CMake on Windows doesn't support the HIP language yet - if (WIN32) - set(CXX_IS_HIPCC TRUE) - else() - string(REGEX MATCH "hipcc(\.bat)?$" CXX_IS_HIPCC "${CMAKE_CXX_COMPILER}") - endif() - - if (CXX_IS_HIPCC) - if (LINUX) - if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") - message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++") - endif() - - message(WARNING "Setting hipcc as the C++ compiler is legacy behavior." - " Prefer setting the HIP compiler directly. See README for details.") - endif() - else() - # Forward AMDGPU_TARGETS to CMAKE_HIP_ARCHITECTURES. - if (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) - set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS}) - endif() - cmake_minimum_required(VERSION 3.21) - enable_language(HIP) - endif() - - find_package(hip REQUIRED) - find_package(hipblas REQUIRED) - find_package(rocblas REQUIRED) - - message(STATUS "HIP and hipBLAS found") - - file(GLOB GGML_HEADERS_ROCM "ggml-cuda/*.cuh") - list(APPEND GGML_HEADERS_ROCM "../include/ggml-cuda.h") - - file(GLOB GGML_SOURCES_ROCM "ggml-cuda/*.cu") - list(APPEND GGML_SOURCES_ROCM "ggml-cuda.cu") - file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) - file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) - - if (GGML_CUDA_FA_ALL_QUANTS) - file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) - add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) - else() - file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) - file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) - file(GLOB SRCS "ggml-cuda/template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) - endif() - - list(APPEND GGML_CDEF_PUBLIC GGML_USE_CUDA) - - add_compile_definitions(GGML_USE_HIPBLAS) - add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X}) - add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y}) - add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER}) - - if (GGML_HIP_UMA) - add_compile_definitions(GGML_HIP_UMA) - endif() - - if (GGML_CUDA_FORCE_DMMV) - add_compile_definitions(GGML_CUDA_FORCE_DMMV) - endif() - - if (GGML_CUDA_FORCE_MMQ) - add_compile_definitions(GGML_CUDA_FORCE_MMQ) - endif() - - if (GGML_CUDA_FORCE_CUBLAS) - add_compile_definitions(GGML_CUDA_FORCE_CUBLAS) - endif() - - if (GGML_CUDA_NO_PEER_COPY) - add_compile_definitions(GGML_CUDA_NO_PEER_COPY) - endif() - - if (CXX_IS_HIPCC) - set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX) - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} hip::device) - else() - set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE HIP) - endif() - - if (GGML_STATIC) - message(FATAL_ERROR "Static linking not supported for HIP/ROCm") - endif() - - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} PUBLIC hip::host roc::rocblas roc::hipblas) -endif() - -if (GGML_SYCL) - if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA)$") - message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL or NVIDIA") - endif() - - check_cxx_compiler_flag("-fsycl" SUPPORTS_SYCL) - if ( DEFINED ENV{ONEAPI_ROOT}) - message(STATUS "Using oneAPI Release SYCL compiler (icpx).") - elseif(SUPPORTS_SYCL) - message(WARNING "Using open-source SYCL compiler (clang++). Didn't detect ENV {ONEAPI_ROOT}. - If you expected the oneAPI Release compiler, please install oneAPI & source it, like: - source /opt/intel/oneapi/setvars.sh") - else() - message(FATAL_ERROR, "C++ compiler lacks SYCL support.") - endif() - message(STATUS "SYCL found") - #todo: AOT - - list(APPEND GGML_CDEF_PUBLIC GGML_USE_SYCL) - - if (GGML_SYCL_F16) - add_compile_definitions(GGML_SYCL_F16) - endif() - - if (GGML_CUDA_FORCE_MMQ) - add_compile_definitions(GGML_SYCL_FORCE_MMQ) - endif() - - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing -fsycl") - - if (GGML_SYCL_TARGET STREQUAL "NVIDIA") - add_compile_definitions(GGML_SYCL_WARP_SIZE=32) - else() - add_compile_definitions(GGML_SYCL_WARP_SIZE=16) - endif() - - file(GLOB GGML_HEADERS_SYCL "ggml-sycl/*.hpp") - list(APPEND GGML_HEADERS_SYCL "../include/ggml-sycl.h") - - file(GLOB GGML_SOURCES_SYCL "ggml-sycl/*.cpp") - list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp") - - find_package(DNNL) - message("-- DNNL found:" ${DNNL_FOUND}) - if (GGML_SYCL_TARGET STREQUAL "INTEL") - add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND}) - else() - add_compile_definitions(GGML_SYCL_DNNL=0) - endif() - if (WIN32) - find_package(IntelSYCL REQUIRED) - find_package(MKL REQUIRED) - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL) - else() - if (GGML_SYCL_TARGET STREQUAL "INTEL") - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} -fsycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread) - elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda") - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} -fsycl pthread m dl onemkl) - endif() - endif() - if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL") - list(APPEND GGML_EXTRA_LIBS DNNL::dnnl) - endif() -endif() - -if (GGML_RPC) - message(STATUS "RPC found") - - list(APPEND GGML_CDEF_PUBLIC GGML_USE_RPC) - - if (WIN32) - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} ws2_32) - endif() - - set(GGML_HEADERS_RPC ../include/ggml-rpc.h) - set(GGML_SOURCES_RPC ggml-rpc.cpp) -endif() - -if (GGML_VULKAN) - find_package(Vulkan COMPONENTS glslc REQUIRED) - - if (Vulkan_FOUND) - message(STATUS "Vulkan found") - - list(APPEND GGML_CDEF_PUBLIC GGML_USE_VULKAN) - - # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build - # Posssibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector - if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang") - add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0) - endif() - - if (GGML_VULKAN_CHECK_RESULTS) - add_compile_definitions(GGML_VULKAN_CHECK_RESULTS) - endif() - - if (GGML_VULKAN_DEBUG) - add_compile_definitions(GGML_VULKAN_DEBUG) - endif() - - if (GGML_VULKAN_MEMORY_DEBUG) - add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG) - endif() - - if (GGML_VULKAN_SHADER_DEBUG_INFO) - add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO) - endif() - - if (GGML_VULKAN_PERF) - add_compile_definitions(GGML_VULKAN_PERF) - endif() - - if (GGML_VULKAN_VALIDATE) - add_compile_definitions(GGML_VULKAN_VALIDATE) - endif() - - if (GGML_VULKAN_RUN_TESTS) - add_compile_definitions(GGML_VULKAN_RUN_TESTS) - endif() - - add_subdirectory(vulkan-shaders) - - set (_ggml_vk_genshaders_cmd vulkan-shaders-gen) - set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp) - set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp) - set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders) - set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv) - - file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp") - - add_custom_command( - OUTPUT ${_ggml_vk_header} - ${_ggml_vk_source} - - COMMAND ${_ggml_vk_genshaders_cmd} - --glslc ${Vulkan_GLSLC_EXECUTABLE} - --input-dir ${_ggml_vk_input_dir} - --output-dir ${_ggml_vk_output_dir} - --target-hpp ${_ggml_vk_header} - --target-cpp ${_ggml_vk_source} - --no-clean - - DEPENDS ${_ggml_vk_shader_deps} - COMMENT "Generate vulkan shaders" - ) - - set(GGML_HEADERS_VULKAN ${CMAKE_CURRENT_SOURCE_DIR}/../include/ggml-vulkan.h ${_ggml_vk_header}) - set(GGML_SOURCES_VULKAN ggml-vulkan.cpp ${_ggml_vk_source}) - - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} Vulkan::Vulkan) - set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} ${CMAKE_CURRENT_BINARY_DIR}) - else() - message(WARNING "Vulkan not found") - endif() -endif() - -if (GGML_KOMPUTE) - add_compile_definitions(VULKAN_HPP_DISPATCH_LOADER_DYNAMIC=1) - - find_package(Vulkan COMPONENTS glslc REQUIRED) - find_program(glslc_executable NAMES glslc HINTS Vulkan::glslc) - - if (NOT glslc_executable) - message(FATAL_ERROR "glslc not found") - endif() - - function(compile_shader) - set(options) - set(oneValueArgs) - set(multiValueArgs SOURCES) - cmake_parse_arguments(compile_shader "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - foreach(source ${compile_shader_SOURCES}) - get_filename_component(filename ${source} NAME) - set(spv_file ${filename}.spv) - add_custom_command( - OUTPUT ${spv_file} - DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${source} - ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/common.comp - ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_getrows.comp - ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n_pre.comp - ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n.comp - COMMAND ${glslc_executable} --target-env=vulkan1.2 -o ${spv_file} ${CMAKE_CURRENT_SOURCE_DIR}/${source} - COMMENT "Compiling ${source} to ${spv_file}" - ) - - get_filename_component(RAW_FILE_NAME ${spv_file} NAME) - set(FILE_NAME "shader${RAW_FILE_NAME}") - string(REPLACE ".comp.spv" ".h" HEADER_FILE ${FILE_NAME}) - string(TOUPPER ${HEADER_FILE} HEADER_FILE_DEFINE) - string(REPLACE "." "_" HEADER_FILE_DEFINE "${HEADER_FILE_DEFINE}") - set(OUTPUT_HEADER_FILE "${HEADER_FILE}") - message(STATUS "${HEADER_FILE} generating ${HEADER_FILE_DEFINE}") - if(CMAKE_GENERATOR MATCHES "Visual Studio") - add_custom_command( - OUTPUT ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_BINARY_DIR}/bin/$/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE} - DEPENDS ${spv_file} xxd - COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/$/xxd" - ) - else() - add_custom_command( - OUTPUT ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_BINARY_DIR}/bin/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE} - COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE} - DEPENDS ${spv_file} xxd - COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/xxd" - ) - endif() - endforeach() - endfunction() - - if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt") - message(STATUS "Kompute found") - set(KOMPUTE_OPT_LOG_LEVEL Error CACHE STRING "Kompute log level") - add_subdirectory(kompute) - - # Compile our shaders - compile_shader(SOURCES - kompute-shaders/op_scale.comp - kompute-shaders/op_scale_8.comp - kompute-shaders/op_add.comp - kompute-shaders/op_addrow.comp - kompute-shaders/op_mul.comp - kompute-shaders/op_silu.comp - kompute-shaders/op_relu.comp - kompute-shaders/op_gelu.comp - kompute-shaders/op_softmax.comp - kompute-shaders/op_norm.comp - kompute-shaders/op_rmsnorm.comp - kompute-shaders/op_diagmask.comp - kompute-shaders/op_mul_mat_mat_f32.comp - kompute-shaders/op_mul_mat_f16.comp - kompute-shaders/op_mul_mat_q8_0.comp - kompute-shaders/op_mul_mat_q4_0.comp - kompute-shaders/op_mul_mat_q4_1.comp - kompute-shaders/op_mul_mat_q6_k.comp - kompute-shaders/op_getrows_f32.comp - kompute-shaders/op_getrows_f16.comp - kompute-shaders/op_getrows_q4_0.comp - kompute-shaders/op_getrows_q4_1.comp - kompute-shaders/op_getrows_q6_k.comp - kompute-shaders/op_rope_f16.comp - kompute-shaders/op_rope_f32.comp - kompute-shaders/op_cpy_f16_f16.comp - kompute-shaders/op_cpy_f16_f32.comp - kompute-shaders/op_cpy_f32_f16.comp - kompute-shaders/op_cpy_f32_f32.comp - ) - - # Create a custom target for our generated shaders - add_custom_target(generated_shaders DEPENDS - shaderop_scale.h - shaderop_scale_8.h - shaderop_add.h - shaderop_addrow.h - shaderop_mul.h - shaderop_silu.h - shaderop_relu.h - shaderop_gelu.h - shaderop_softmax.h - shaderop_norm.h - shaderop_rmsnorm.h - shaderop_diagmask.h - shaderop_mul_mat_mat_f32.h - shaderop_mul_mat_f16.h - shaderop_mul_mat_q8_0.h - shaderop_mul_mat_q4_0.h - shaderop_mul_mat_q4_1.h - shaderop_mul_mat_q6_k.h - shaderop_getrows_f32.h - shaderop_getrows_f16.h - shaderop_getrows_q4_0.h - shaderop_getrows_q4_1.h - shaderop_getrows_q6_k.h - shaderop_rope_f16.h - shaderop_rope_f32.h - shaderop_cpy_f16_f16.h - shaderop_cpy_f16_f32.h - shaderop_cpy_f32_f16.h - shaderop_cpy_f32_f32.h - ) - - # Create a custom command that depends on the generated_shaders - add_custom_command( - OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp - COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp - DEPENDS generated_shaders - COMMENT "Ensuring shaders are generated before compiling ggml-kompute.cpp" - ) - - # Add the stamp to the main sources to ensure dependency tracking - set(GGML_SOURCES_KOMPUTE ggml-kompute.cpp ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp) - set(GGML_HEADERS_KOMPUTE ../include/ggml-kompute.h ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp) - - list(APPEND GGML_CDEF_PUBLIC GGML_USE_KOMPUTE) - - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} kompute) - set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} ${CMAKE_CURRENT_BINARY_DIR}) - else() - message(WARNING "Kompute not found") - endif() -endif() - -if (GGML_CPU_HBM) - find_library(memkind memkind REQUIRED) - - message(STATUS "Using memkind for CPU HBM") - - add_compile_definitions(GGML_USE_CPU_HBM) - - target_link_libraries(ggml PUBLIC memkind) -endif() - -if (GGML_CANN) - if ("cann${CANN_INSTALL_DIR}" STREQUAL "cann" AND DEFINED ENV{ASCEND_TOOLKIT_HOME}) - set(CANN_INSTALL_DIR $ENV{ASCEND_TOOLKIT_HOME}) - message(STATUS "CANN: updated CANN_INSTALL_DIR from ASCEND_TOOLKIT_HOME=$ENV{ASCEND_TOOLKIT_HOME}") - endif() - - if (CANN_INSTALL_DIR) - # Only Support Linux. - if (GGML_CANN) - if (NOT UNIX) - set(GGML_CANN OFF) - message(WARNING "CANN: CANN toolkit supports unix but not ${CMAKE_SYSTEM_NAME}. Turning off GGML_CANN") - endif() - endif() - - # Supported platforms: x86-64, arm64 - if (GGML_CANN) - if (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") - elseif (CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "amd64") - else() - set(GGML_CANN OFF) - message(WARNING "CANN: CANN toolkit supports x86-64 and arm64 but not ${CMAKE_SYSTEM_PROCESSOR}. Turning off GGML_CANN") - endif() - endif() - - # Set header and libs - if(GGML_CANN) - set(CANN_INCLUDE_DIRS - ${CANN_INSTALL_DIR}/include - ${CANN_INSTALL_DIR}/include/aclnn - ${CANN_INSTALL_DIR}/acllib/include - ) - - add_subdirectory(ggml-cann/kernels) - list(APPEND CANN_LIBRARIES - ascendcl - nnopbase - opapi - acl_op_compiler - ascendc_kernels - ) - - set(GGML_HEADERS_CANN "../include/ggml-cann.h") - file(GLOB GGML_SOURCES_CANN "ggml-cann/*.cpp") - list(APPEND GGML_SOURCES_CANN "ggml-cann.cpp") - - message(STATUS "CANN: CANN_INCLUDE_DIRS = ${CANN_INCLUDE_DIRS}") - message(STATUS "CANN: CANN_LIBRARIES = ${CANN_LIBRARIES}") - - set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} ${CANN_LIBRARIES} ) - set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} ${CANN_INCLUDE_DIRS}) - set(GGML_EXTRA_LIBDIRS ${GGML_EXTRA_LIBDIRS} ${CANN_INSTALL_DIR}/lib64) - list(APPEND GGML_CDEF_PUBLIC GGML_USE_CANN) - endif() - else() - set(GGML_CANN OFF) - message(WARNING "CANN: Can't find CANN_INSTALL_DIR, do you forget to source set_var.sh. Turning off GGML_CANN") - endif() - - if(NOT GGML_CANN) - message(WARNING "CANN: GGML_CANN is turned OFF, see above for details.") - endif() -endif() - -function(get_flags CCID CCVER) +function(ggml_get_flags CCID CCVER) set(C_FLAGS "") set(CXX_FLAGS "") @@ -916,11 +42,6 @@ function(get_flags CCID CCVER) set(C_FLAGS -Wdouble-promotion) set(CXX_FLAGS -Wno-array-bounds) - if (NOT GGML_MUSA) - if (CCVER VERSION_GREATER_EQUAL 7.1.0) - list(APPEND CXX_FLAGS -Wno-format-truncation) - endif() - endif() if (CCVER VERSION_GREATER_EQUAL 8.1.0) list(APPEND CXX_FLAGS -Wextra-semi) endif() @@ -949,7 +70,7 @@ if (GGML_ALL_WARNINGS) list(APPEND C_FLAGS ${WARNING_FLAGS}) list(APPEND CXX_FLAGS ${WARNING_FLAGS}) - get_flags(${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}) + ggml_get_flags(${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION}) add_compile_options("$<$:${C_FLAGS};${GF_C_FLAGS}>" "$<$:${CXX_FLAGS};${GF_CXX_FLAGS}>") @@ -960,54 +81,6 @@ if (GGML_ALL_WARNINGS) endif() endif() -set(CUDA_CXX_FLAGS "") - -if (GGML_CUDA) - set(CUDA_FLAGS -use_fast_math) - - if (GGML_FATAL_WARNINGS) - list(APPEND CUDA_FLAGS -Werror all-warnings) - endif() - - if (GGML_ALL_WARNINGS AND NOT MSVC) - set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c) - if (NOT CMAKE_CUDA_HOST_COMPILER STREQUAL "") - list(APPEND NVCC_CMD -ccbin ${CMAKE_CUDA_HOST_COMPILER}) - endif() - - execute_process( - COMMAND ${NVCC_CMD} -Xcompiler --version - OUTPUT_VARIABLE CUDA_CCFULLVER - ERROR_QUIET - ) - - if (NOT CUDA_CCFULLVER MATCHES clang) - set(CUDA_CCID "GNU") - execute_process( - COMMAND ${NVCC_CMD} -Xcompiler "-dumpfullversion -dumpversion" - OUTPUT_VARIABLE CUDA_CCVER - ERROR_QUIET - ) - else() - if (CUDA_CCFULLVER MATCHES Apple) - set(CUDA_CCID "AppleClang") - else() - set(CUDA_CCID "Clang") - endif() - string(REGEX REPLACE "^.* version ([0-9.]*).*$" "\\1" CUDA_CCVER ${CUDA_CCFULLVER}) - endif() - - message("-- CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}") - - get_flags(${CUDA_CCID} ${CUDA_CCVER}) - list(APPEND CUDA_CXX_FLAGS ${CXX_FLAGS} ${GF_CXX_FLAGS}) # This is passed to -Xcompiler later - endif() - - if (NOT MSVC) - list(APPEND CUDA_CXX_FLAGS -Wno-pedantic) - endif() -endif() - if (GGML_LTO) include(CheckIPOSupported) check_ipo_supported(RESULT result OUTPUT output) @@ -1065,167 +138,6 @@ if (NOT MSVC) endif() endif() -set(ARCH_FLAGS "") - -if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR - CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR - (NOT CMAKE_OSX_ARCHITECTURES AND - NOT CMAKE_GENERATOR_PLATFORM_LWR AND - CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm.*|ARM64)$")) - - message(STATUS "ARM detected") - - if (MSVC) - add_compile_definitions(__aarch64__) # MSVC defines _M_ARM64 instead - add_compile_definitions(__ARM_NEON) - add_compile_definitions(__ARM_FEATURE_FMA) - - set(CMAKE_REQUIRED_FLAGS_PREV ${CMAKE_REQUIRED_FLAGS}) - string(JOIN " " CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS} "/arch:armv8.2") - - check_cxx_source_compiles("#include \nint main() { int8x16_t _a, _b; int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_DOTPROD) - if (GGML_COMPILER_SUPPORT_DOTPROD) - add_compile_definitions(__ARM_FEATURE_DOTPROD) - endif () - - check_cxx_source_compiles("#include \nint main() { int8x16_t _a, _b; int32x4_t _s = vmlaq_f32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_MATMUL_INT8) - - if (GGML_COMPILER_SUPPORT_MATMUL_INT8) - add_compile_definitions(__ARM_FEATURE_MATMUL_INT8) - endif () - - check_cxx_source_compiles("#include \nint main() { float16_t _a; float16x8_t _s = vdupq_n_f16(_a); return 0; }" GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC) - if (GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC) - add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - endif () - - set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_PREV}) - else() - check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E) - if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "") - list(APPEND ARCH_FLAGS -mfp16-format=ieee) - endif() - if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6") - # Raspberry Pi 1, Zero - list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access) - endif() - if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7") - if ("${CMAKE_SYSTEM_NAME}" STREQUAL "Android") - # Android armeabi-v7a - list(APPEND ARCH_FLAGS -mfpu=neon-vfpv4 -mno-unaligned-access -funsafe-math-optimizations) - else() - # Raspberry Pi 2 - list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access -funsafe-math-optimizations) - endif() - endif() - if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8") - # Android arm64-v8a - # Raspberry Pi 3, 4, Zero 2 (32-bit) - list(APPEND ARCH_FLAGS -mno-unaligned-access) - endif() - if (GGML_SVE) - list(APPEND ARCH_FLAGS -march=armv8.6-a+sve) - endif() - endif() -elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR - (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND - CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$")) - message(STATUS "x86 detected") - if (MSVC) - # instruction set detection for MSVC only - if (GGML_NATIVE) - # TODO: improve, should not reference files from the parent folder - include(../cmake/FindSIMD.cmake) - endif () - if (GGML_AVX512) - list(APPEND ARCH_FLAGS /arch:AVX512) - # MSVC has no compile-time flags enabling specific - # AVX512 extensions, neither it defines the - # macros corresponding to the extensions. - # Do it manually. - if (GGML_AVX512_VBMI) - add_compile_definitions($<$:__AVX512VBMI__>) - add_compile_definitions($<$:__AVX512VBMI__>) - endif() - if (GGML_AVX512_VNNI) - add_compile_definitions($<$:__AVX512VNNI__>) - add_compile_definitions($<$:__AVX512VNNI__>) - endif() - if (GGML_AVX512_BF16) - add_compile_definitions($<$:__AVX512BF16__>) - add_compile_definitions($<$:__AVX512BF16__>) - endif() - elseif (GGML_AVX2) - list(APPEND ARCH_FLAGS /arch:AVX2) - elseif (GGML_AVX) - list(APPEND ARCH_FLAGS /arch:AVX) - endif() - else() - if (GGML_NATIVE) - list(APPEND ARCH_FLAGS -march=native) - endif() - if (GGML_F16C) - list(APPEND ARCH_FLAGS -mf16c) - endif() - if (GGML_FMA) - list(APPEND ARCH_FLAGS -mfma) - endif() - if (GGML_AVX) - list(APPEND ARCH_FLAGS -mavx) - endif() - if (GGML_AVX2) - list(APPEND ARCH_FLAGS -mavx2) - endif() - if (GGML_AVX512) - list(APPEND ARCH_FLAGS -mavx512f) - list(APPEND ARCH_FLAGS -mavx512bw) - endif() - if (GGML_AVX512_VBMI) - list(APPEND ARCH_FLAGS -mavx512vbmi) - endif() - if (GGML_AVX512_VNNI) - list(APPEND ARCH_FLAGS -mavx512vnni) - endif() - if (GGML_AVX512_BF16) - list(APPEND ARCH_FLAGS -mavx512bf16) - endif() - endif() -elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") - message(STATUS "PowerPC detected") - if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le") - list(APPEND ARCH_FLAGS -mcpu=powerpc64le) - else() - list(APPEND ARCH_FLAGS -mcpu=native -mtune=native) - #TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be) - endif() -elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") - message(STATUS "loongarch64 detected") - - list(APPEND ARCH_FLAGS -march=loongarch64) - if (GGML_LASX) - list(APPEND ARCH_FLAGS -mlasx) - endif() - if (GGML_LSX) - list(APPEND ARCH_FLAGS -mlsx) - endif() -else() - message(STATUS "Unknown architecture") -endif() - -add_compile_options("$<$:${ARCH_FLAGS}>") -add_compile_options("$<$:${ARCH_FLAGS}>") - -if (GGML_CUDA) - list(APPEND CUDA_CXX_FLAGS ${ARCH_FLAGS}) - list(JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument - - if (NOT CUDA_CXX_FLAGS_JOINED STREQUAL "") - list(APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED}) - endif() - - add_compile_options("$<$:${CUDA_FLAGS}>") -endif() - if (MINGW) # Target Windows 8 for PrefetchVirtualMemory add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER}) @@ -1239,14 +151,14 @@ endif() # CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional # posix_memalign came in POSIX.1-2001 / SUSv3 # M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985) -add_compile_definitions(_XOPEN_SOURCE=600) # Somehow in OpenBSD whenever POSIX conformance is specified # some string functions rely on locale_t availability, # which was introduced in POSIX.1-2008, forcing us to go higher if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD") - remove_definitions(-D_XOPEN_SOURCE=600) add_compile_definitions(_XOPEN_SOURCE=700) +else() + add_compile_definitions(_XOPEN_SOURCE=600) endif() # Data types, macros and functions related to controlling CPU affinity and @@ -1282,62 +194,158 @@ endif() if (WIN32) add_compile_definitions(_CRT_SECURE_NO_WARNINGS) - - if (BUILD_SHARED_LIBS) - # TODO: should not use this - set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) - endif() endif() -# -# libraries -# - # ggml -add_library(ggml +if (GGML_BACKEND_DL AND NOT BUILD_SHARED_LIBS) + message(FATAL_ERROR "GGML_BACKEND_DL requires BUILD_SHARED_LIBS") +endif() + +add_library(ggml-base ../include/ggml.h ../include/ggml-alloc.h ../include/ggml-backend.h + ../include/ggml-cpp.h + ../include/ggml-opt.h + ../include/gguf.h ggml.c ggml-alloc.c - ggml-backend.c + ggml-backend.cpp + ggml-opt.cpp + ggml-threading.cpp + ggml-threading.h ggml-quants.c ggml-quants.h - ${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA} - ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL} - ${GGML_SOURCES_RPC} ${GGML_HEADERS_RPC} - ${GGML_SOURCES_EXTRA} ${GGML_HEADERS_EXTRA} - ${GGML_SOURCES_SYCL} ${GGML_HEADERS_SYCL} - ${GGML_SOURCES_KOMPUTE} ${GGML_HEADERS_KOMPUTE} - ${GGML_SOURCES_VULKAN} ${GGML_HEADERS_VULKAN} - ${GGML_SOURCES_ROCM} ${GGML_HEADERS_ROCM} - ${GGML_SOURCES_BLAS} ${GGML_HEADERS_BLAS} - ${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE} - ${GGML_SOURCES_CANN} ${GGML_HEADERS_CANN} - ggml-aarch64.c ggml-aarch64.h - ) + gguf.cpp) -if (EMSCRIPTEN) - set_target_properties(ggml PROPERTIES COMPILE_FLAGS "-msimd128") +target_include_directories(ggml-base PRIVATE .) + +add_library(ggml + ggml-backend-reg.cpp) + +target_link_libraries(ggml PUBLIC ggml-base) + +if (CMAKE_SYSTEM_NAME MATCHES "Linux") + target_link_libraries(ggml PRIVATE dl) endif() -target_compile_definitions(ggml PUBLIC ${GGML_CDEF_PUBLIC}) -target_include_directories(ggml PUBLIC ../include) -target_include_directories(ggml PRIVATE . ${GGML_EXTRA_INCLUDES}) -target_link_directories(ggml PRIVATE ${GGML_EXTRA_LIBDIRS}) -target_compile_features (ggml PRIVATE c_std_11) # don't bump +function(ggml_add_backend_library backend) + if (GGML_BACKEND_DL) + add_library(${backend} MODULE ${ARGN}) + # write the shared library to the output directory + set_target_properties(${backend} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) + target_compile_definitions(${backend} PRIVATE GGML_BACKEND_DL) + add_dependencies(ggml ${backend}) + else() + add_library(${backend} ${ARGN}) + target_link_libraries(ggml PUBLIC ${backend}) + install(TARGETS ${backend} LIBRARY) + endif() -target_link_libraries(ggml PRIVATE Threads::Threads ${GGML_EXTRA_LIBS}) + target_link_libraries(${backend} PRIVATE ggml-base) + target_include_directories(${backend} PRIVATE ..) + + if (${BUILD_SHARED_LIBS}) + target_compile_definitions(${backend} PRIVATE GGML_BACKEND_BUILD) + target_compile_definitions(${backend} PUBLIC GGML_BACKEND_SHARED) + endif() + + if(NOT GGML_AVAILABLE_BACKENDS) + set(GGML_AVAILABLE_BACKENDS "${backend}" + CACHE INTERNAL "List of backends for cmake package") + else() + list(FIND GGML_AVAILABLE_BACKENDS "${backend}" has_backend) + if(has_backend EQUAL -1) + set(GGML_AVAILABLE_BACKENDS "${GGML_AVAILABLE_BACKENDS};${backend}" + CACHE INTERNAL "List of backends for cmake package") + endif() + endif() +endfunction() + +function(ggml_add_backend backend) + string(TOUPPER "GGML_${backend}" backend_id) + if (${backend_id}) + string(TOLOWER "ggml-${backend}" backend_target) + add_subdirectory(${backend_target}) + message(STATUS "Including ${backend} backend") + if (NOT GGML_BACKEND_DL) + string(TOUPPER "GGML_USE_${backend}" backend_use) + target_compile_definitions(ggml PUBLIC ${backend_use}) + endif() + endif() +endfunction() + +function(ggml_add_cpu_backend_variant tag_name) + set(GGML_CPU_TAG_NAME ${tag_name}) + # other: OPENMP LLAMAFILE CPU_HBM + foreach (feat NATIVE + AVX AVX2 AVX_VNNI FMA F16C + AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 + AMX_TILE AMX_INT8 AMX_BF16) + set(GGML_${feat} OFF) + endforeach() + + foreach (feat ${ARGN}) + set(GGML_${feat} ON) + endforeach() + + ggml_add_cpu_backend_variant_impl(${tag_name}) +endfunction() + +ggml_add_backend(CPU) + +if (GGML_CPU_ALL_VARIANTS) + if (NOT GGML_BACKEND_DL) + message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL") + endif() + ggml_add_cpu_backend_variant(sandybridge AVX) + ggml_add_cpu_backend_variant(haswell AVX F16C AVX2 FMA) + ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 FMA AVX512) + ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI) + ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI) + if (NOT MSVC) + # MSVC doesn't support AMX + ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8) + endif() +elseif (GGML_CPU) + ggml_add_cpu_backend_variant_impl("") +endif() + +ggml_add_backend(BLAS) +ggml_add_backend(CANN) +ggml_add_backend(CUDA) +ggml_add_backend(HIP) +ggml_add_backend(Kompute) +ggml_add_backend(METAL) +ggml_add_backend(MUSA) +ggml_add_backend(RPC) +ggml_add_backend(SYCL) +ggml_add_backend(Vulkan) +ggml_add_backend(OpenCL) + +foreach (target ggml-base ggml) + target_include_directories(${target} PUBLIC $ $) + target_compile_features (${target} PRIVATE c_std_11 cxx_std_17) # don't bump +endforeach() + +target_link_libraries(ggml-base PRIVATE Threads::Threads) find_library(MATH_LIBRARY m) if (MATH_LIBRARY) - if (NOT WIN32 OR NOT GGML_SYCL) - target_link_libraries(ggml PRIVATE ${MATH_LIBRARY}) + if (NOT WIN32 OR NOT DEFINED ENV{ONEAPI_ROOT}) + target_link_libraries(ggml-base PRIVATE m) endif() endif() -if (BUILD_SHARED_LIBS) - set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON) - target_compile_definitions(ggml PRIVATE GGML_SHARED GGML_BUILD) +if (CMAKE_SYSTEM_NAME MATCHES "Android") + target_link_libraries(ggml-base PRIVATE dl) +endif() + +if (BUILD_SHARED_LIBS) + foreach (target ggml-base ggml) + set_target_properties(${target} PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_compile_definitions(${target} PRIVATE GGML_BUILD) + target_compile_definitions(${target} PUBLIC GGML_SHARED) + endforeach() endif() diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c deleted file mode 100644 index 72cb83c9b..000000000 --- a/ggml/src/ggml-aarch64.c +++ /dev/null @@ -1,2792 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd. -#define GGML_COMMON_IMPL_C -#include "ggml-common.h" - -#include "ggml-quants.h" -#include "ggml-impl.h" - -#include -#include -#include -#include -#include // for qsort -#include // for GGML_ASSERT - -#include "ggml-aarch64.h" - -#if defined(__GNUC__) -#pragma GCC diagnostic ignored "-Woverlength-strings" -#elif defined(_MSC_VER) -#pragma warning(disable: 4244 4267) // possible loss of data -#endif - -#define UNUSED GGML_UNUSED - -// Functions to create the interleaved data layout formats - -// interleave 4 block_q4_0s in blocks of blck_size_interleave -// returns an interleaved block_q4_0x4 -// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks -// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave -// -// - in : an array of block_q4_0 pointers -// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of -// blck_size_interleave bytes -// - xor_mask : the mask to convert the nibbles in block_q4_0 quants bytes -// from bias offset form to pure sign form (this saves subtract -// operations durin unpacking) -// -#if defined(__AVX__) -#if defined(__F16C__) -// the _mm256_cvt intrinsics require F16C -#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) -#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask) _mm256_cvtph_ps(_mm_shuffle_epi32(_mm_maskload_epi32((int const*)(x), loadMask), 68)) -#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask)) -#else -static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { - float tmp[8]; - - for (int i = 0; i < 8; i++) { - tmp[i] = GGML_FP16_TO_FP32(x[i]); - } - - return _mm256_loadu_ps(tmp); -} -static inline __m256 __avx_repeat_f32cx8_load(ggml_fp16_t *x) { - float tmp[8]; - - for (int i = 0; i < 4; i++) { - tmp[i] = GGML_FP16_TO_FP32(x[i]); - tmp[i + 4] = GGML_FP16_TO_FP32(x[i]); - } - - return _mm256_loadu_ps(tmp); -} -static inline __m256 __avx_rearranged_f32cx8_load(ggml_fp16_t *x, __m128i arrangeMask) { - uint16_t tmphalf[8]; - float tmp[8]; - - _mm_storeu_si128((__m128i*)tmphalf, _mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask)); - for (int i = 0; i < 8; i++) { - tmp[i] = GGML_FP16_TO_FP32(tmphalf[i]); - } - - return _mm256_loadu_ps(tmp); -} - -#define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x) -#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask) __avx_repeat_f32cx8_load(x) -#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) __avx_rearranged_f32cx8_load(x, arrangeMask) -#endif -#endif - - -#if defined(__AVX2__) || defined(__AVX512F__) -static inline __m256i sum_i16_pairs_int(const __m256i x) { - const __m256i ones = _mm256_set1_epi16(1); - return _mm256_madd_epi16(ones, x); -} - -static inline __m256i mul_sum_us8_pairs_int(const __m256i ax, const __m256i sy) { -#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) - const __m256i zero = _mm256_setzero_si256(); - return _mm256_dpbusd_epi32(zero, ax, sy); -#else - // Perform multiplication and create 16-bit values - const __m256i dot = _mm256_maddubs_epi16(ax, sy); - return sum_i16_pairs_int(dot); -#endif -} - -// Integer variant of the function defined in ggml-quants.c -// multiply int8_t, add results pairwise twice and return as float vector -static inline __m256i mul_sum_i8_pairs_int(const __m256i x, const __m256i y) { -#if __AVXVNNIINT8__ - const __m256i zero = _mm256_setzero_si256(); - return _mm256_dpbssd_epi32(zero, x, y); -#else - // Get absolute values of x vectors - const __m256i ax = _mm256_sign_epi8(x, x); - // Sign the values of the y vectors - const __m256i sy = _mm256_sign_epi8(y, x); - return mul_sum_us8_pairs_int(ax, sy); -#endif -} -#endif - -static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { - block_q4_0x4 out; - - for (int i = 0; i < 4; i++) { - out.d[i] = in[i].d; - } - - for (int i = 0; i < QK4_0 * 2; i++) { - int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave; - int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave; - src_offset += (i % blck_size_interleave); - - out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; - } - - return out; -} - -// interleave 8 block_q4_0s in blocks of blck_size_interleave -// returns an interleaved block_q4_0x8 -// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks -// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave -static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { - block_q4_0x8 out; - - for (int i = 0; i < 8; i++) { - out.d[i] = in[i].d; - } - - for (int i = 0; i < QK4_0 * 4; i++) { - int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave; - int src_id = (i % (8 * blck_size_interleave)) / blck_size_interleave; - src_offset += (i % blck_size_interleave); - - out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; - } - - return out; -} - -void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) { - assert(QK8_0 == 32); - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; - - block_q8_0x4 * restrict y = (block_q8_0x4 *) vy; - -#if defined(__ARM_NEON) - float32x4_t srcv[4][8]; - float id[4]; - - for (int i = 0; i < nb; i++) { - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int row_iter = 0; row_iter < 4; row_iter++) { - for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); - - for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); - for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); - for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - id[row_iter] = d ? 1.0f / d : 0.0f; - - y[i].d[row_iter] = GGML_FP32_TO_FP16(d); - } - - for (int j = 0; j < 8; j++) { - float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]); - int32x4_t vi = vcvtnq_s32_f32(v); - y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[1][j], id[1]); - vi = vcvtnq_s32_f32(v); - y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0); - y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1); - y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2); - y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[2][j], id[2]); - vi = vcvtnq_s32_f32(v); - y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0); - y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1); - y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2); - y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[3][j], id[3]); - vi = vcvtnq_s32_f32(v); - y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0); - y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1); - y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2); - y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3); - } - } -#else - // scalar - const int blck_size_interleave = 4; - float srcv[4][QK8_0]; - float id[4]; - - for (int i = 0; i < nb; i++) { - for (int row_iter = 0; row_iter < 4; row_iter++) { - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; - amax = MAX(amax, fabsf(srcv[row_iter][j])); - } - - const float d = amax / ((1 << 7) - 1); - id[row_iter] = d ? 1.0f / d : 0.0f; - - y[i].d[row_iter] = GGML_FP32_TO_FP16(d); - } - - for (int j = 0; j < QK8_0 * 4; j++) { - int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; - int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; - src_offset += (j % blck_size_interleave); - - float x0 = srcv[src_id][src_offset] * id[src_id]; - y[i].qs[j] = roundf(x0); - } - } -#endif -} - -void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) { - assert(QK8_0 == 32); - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; - - block_q8_0x4 * restrict y = (block_q8_0x4 *) vy; - -#if defined(__ARM_NEON) - float32x4_t srcv[4][8]; - float id[4]; - - for (int i = 0; i < nb; i++) { - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int row_iter = 0; row_iter < 4; row_iter++) { - for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); - - for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); - for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); - for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - id[row_iter] = d ? 1.0f / d : 0.0f; - - y[i].d[row_iter] = GGML_FP32_TO_FP16(d); - } - - for (int j = 0; j < 4; j++) { - float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]); - int32x4_t vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3); - v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[1][2 * j], id[1]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3); - v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[2][2 * j], id[2]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3); - v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[3][2 * j], id[3]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3); - v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3); - } - } -#elif defined(__AVX2__) || defined(__AVX__) - float id[4]; - __m256 srcv[4][4]; - __m256 idvec[4]; - - for (int i = 0; i < nb; i++) { - for (int row_iter = 0; row_iter < 4; row_iter++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 32 ); - __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 8 ); - __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 16 ); - __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 24 ); - - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float maxScalar = _mm_cvtss_f32( max4 ); - - // Divided by 127.f to mirror results in quantize_row_q8_0 - const float d = maxScalar / 127.f; - id[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; //d ? 1.0f / d : 0.0f; - - // Store the scale for the individual block - y[i].d[row_iter] = GGML_FP32_TO_FP16(d); - - // Store the values in blocks of eight values - Aim is to use these later for block interleaving - srcv[row_iter][0] = v0; - srcv[row_iter][1] = v1; - srcv[row_iter][2] = v2; - srcv[row_iter][3] = v3; - idvec[row_iter] = _mm256_set1_ps(id[row_iter]); - } - - // The loop iterates four times - The aim is to get 4 corresponding chunks of eight bytes from the original weight blocks that are interleaved - for (int j = 0; j < 4; j++) { - // Apply the multiplier - __m256 v0 = _mm256_mul_ps(srcv[0][j], idvec[0]); - __m256 v1 = _mm256_mul_ps(srcv[1][j], idvec[1]); - __m256 v2 = _mm256_mul_ps(srcv[2][j], idvec[2]); - __m256 v3 = _mm256_mul_ps(srcv[3][j], idvec[3]); - - // Round to nearest integer - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - -#if defined(__AVX2__) - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); - i2 = _mm256_packs_epi32( i2, i3 ); - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); - - // Permute and store the quantized weights in the required order after the pack instruction - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0); -#else - // Since we don't have in AVX some necessary functions, - // we split the registers in half and call AVX2 analogs from SSE - __m128i ni0 = _mm256_castsi256_si128( i0 ); - __m128i ni1 = _mm256_extractf128_si256( i0, 1); - __m128i ni2 = _mm256_castsi256_si128( i1 ); - __m128i ni3 = _mm256_extractf128_si256( i1, 1); - __m128i ni4 = _mm256_castsi256_si128( i2 ); - __m128i ni5 = _mm256_extractf128_si256( i2, 1); - __m128i ni6 = _mm256_castsi256_si128( i3 ); - __m128i ni7 = _mm256_extractf128_si256( i3, 1); - - // Convert int32 to int16 - ni0 = _mm_packs_epi32( ni0, ni1 ); - ni2 = _mm_packs_epi32( ni2, ni3 ); - ni4 = _mm_packs_epi32( ni4, ni5 ); - ni6 = _mm_packs_epi32( ni6, ni7 ); - // Convert int16 to int8 - ni0 = _mm_packs_epi16( ni0, ni2 ); - ni4 = _mm_packs_epi16( ni4, ni6 ); - _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j), ni0); - _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j + 16), ni4); -#endif - } - } -#else - // scalar - const int blck_size_interleave = 8; - float srcv[4][QK8_0]; - float id[4]; - - for (int i = 0; i < nb; i++) { - for (int row_iter = 0; row_iter < 4; row_iter++) { - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; - amax = MAX(amax, fabsf(srcv[row_iter][j])); - } - - const float d = amax / ((1 << 7) - 1); - id[row_iter] = d ? 1.0f / d : 0.0f; - - y[i].d[row_iter] = GGML_FP32_TO_FP16(d); - } - - for (int j = 0; j < QK8_0 * 4; j++) { - int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; - int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; - src_offset += (j % blck_size_interleave); - - float x0 = srcv[src_id][src_offset] * id[src_id]; - y[i].qs[j] = roundf(x0); - } - } -#endif -} - -void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) { - assert(nrow == 4); - UNUSED(nrow); - if (blck_size_interleave == 4) { - quantize_q8_0_4x4(x, vy, n_per_row); - } else if (blck_size_interleave == 8) { - quantize_q8_0_4x8(x, vy, n_per_row); - } else { - assert(false); - } -} - -static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) { - assert(n_per_row % QK4_0 == 0); - const int nb = n_per_row / QK4_0; - - void * out_ptr = NULL; - if (nrows_interleaved == 8) { - out_ptr = (block_q4_0x8 *) dst; - } - else if (nrows_interleaved == 4) { - out_ptr = (block_q4_0x4 *) dst; - } - assert(nrows_interleaved <= 8); - block_q4_0 dst_tmp[8]; - - for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) { - - for (int64_t x = 0; x < nb; x++) { - - for (int i = 0; i < nrows_interleaved; i++ ) { - quantize_row_q4_0_ref(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0); - } - - if (nrows_interleaved == 8) { - *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave, 0x88); - out_ptr = (block_q4_0x8 *) out_ptr + 1; - } - else if (nrows_interleaved == 4) { - *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave, 0x88); - out_ptr = (block_q4_0x4 *) out_ptr + 1; - } - } - } - - return ((nrow * n_per_row) / QK4_0 * sizeof(block_q4_0)); -} - -size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - UNUSED(quant_weights); - return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4); -} - -size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - UNUSED(quant_weights); - return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8); -} - -size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - UNUSED(quant_weights); - return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8); -} - -void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined(__ARM_FEATURE_SVE) - if (ggml_sve_cnt_b == QK8_0) { - GGML_ASSERT(!(ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) && - "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); - } -#endif -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - GGML_ASSERT(!(ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) && - "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance"); -#elif defined(__ARM_NEON) && defined(__aarch64__) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - - __asm__ __volatile__( - "movi v31.16b, #0x4\n" - "movi v30.16b, #0xf0\n" - "add %x[b_ptr], %x[b_ptr], #0x8\n" - "1:" // Column loop - "add x22, %x[a_ptr], #0x2\n" - "movi v29.16b, #0x0\n" - "mov x21, %x[nb]\n" - "2:" // Block loop - "ldr q28, [%x[b_ptr], #0x0]\n" - "ldr q27, [x22, #0x0]\n" - "movi v26.4s, #0x0\n" - "sub x20, x22, #0x2\n" - "ldr q25, [x22, #0x10]\n" - "ldr q24, [%x[b_ptr], #0x10]\n" - "sub x21, x21, #0x1\n" - "add x22, x22, #0x22\n" - "ldr q23, [%x[b_ptr], #0x20]\n" - "ldr q22, [%x[b_ptr], #0x30]\n" - "ld1r { v21.8h }, [x20]\n" - "ldr q20, [%x[b_ptr], #-0x8]\n" - "sshl v16.16b, v28.16b, v31.16b\n" - "and v28.16b, v28.16b, v30.16b\n" - "sshl v19.16b, v24.16b, v31.16b\n" - "and v24.16b, v24.16b, v30.16b\n" - "add %x[b_ptr], %x[b_ptr], #0x48\n" - "sshl v18.16b, v23.16b, v31.16b\n" - "and v23.16b, v23.16b, v30.16b\n" - ".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n" - "sshl v17.16b, v22.16b, v31.16b\n" - "and v22.16b, v22.16b, v30.16b\n" - "fcvtl v21.4s, v21.4h\n" - "fcvtl v16.4s, v20.4h\n" - ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n" - "fmul v16.4s, v16.4s, v21.4s\n" - ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n" - ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n" - ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n" - ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n" - ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n" - ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "fmla v29.4s, v26.4s, v16.4s\n" - "cbnz x21, 2b\n" - "sub %x[nc], %x[nc], #0x4\n" - "str q29, [%x[res_ptr], #0x0]\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "cbnz %x[nc], 1b\n" - : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) - : [a_ptr] "r" (a_ptr), [nb] "r" (nb) - : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22" - ); -#else - float sumf[4]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; - } - sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } -#endif -} - -void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined(__ARM_FEATURE_SVE) - if (ggml_sve_cnt_b == QK8_0) { - GGML_ASSERT(!(ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) && - "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); - } -#endif -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - - __asm__ __volatile__( - "movi v2.16b, #0x4\n" - "movi v1.16b, #0xf0\n" - "add %x[b_ptr], %x[b_ptr], #0x8\n" - "1:" // Column loop - "add x23, %x[a_ptr], #0x2\n" - "movi v0.16b, #0x0\n" - "mov x22, %x[nb]\n" - "2:" // Block loop - "ldr q31, [%x[b_ptr], #0x0]\n" - "ldr q30, [%x[b_ptr], #0x10]\n" - "mov x21, x23\n" - "movi v29.4s, #0x0\n" - "ldr q28, [%x[b_ptr], #0x20]\n" - "ldr q27, [%x[b_ptr], #0x30]\n" - "movi v26.4s, #0x0\n" - "sub x20, x23, #0x2\n" - "ld1r { v25.8h }, [x20]\n" - "ldr q24, [%x[b_ptr], #-0x8]\n" - "sub x22, x22, #0x1\n" - "add x23, x23, #0x22\n" - "ld1r { v23.2d }, [x21], #0x8\n" - "sshl v22.16b, v31.16b, v2.16b\n" - "sshl v16.16b, v30.16b, v2.16b\n" - "add %x[b_ptr], %x[b_ptr], #0x48\n" - "ld1r { v21.2d }, [x21], #0x8\n" - "sshl v20.16b, v28.16b, v2.16b\n" - "sshl v19.16b, v27.16b, v2.16b\n" - "ld1r { v18.2d }, [x21], #0x8\n" - "ld1r { v17.2d }, [x21], #0x8\n" - "and v31.16b, v31.16b, v1.16b\n" - "and v30.16b, v30.16b, v1.16b\n" - ".inst 0x4e9796dd // sdot v29.4s, v22.16b, v23.16b\n" - ".inst 0x4e97961a // sdot v26.4s, v16.16b, v23.16b\n" - "and v28.16b, v28.16b, v1.16b\n" - "and v27.16b, v27.16b, v1.16b\n" - "fcvtl v25.4s, v25.4h\n" - "fcvtl v16.4s, v24.4h\n" - ".inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n" - ".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n" - "fmul v16.4s, v16.4s, v25.4s\n" - ".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n" - ".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n" - ".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n" - ".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n" - "addp v29.4s, v29.4s, v26.4s\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "fmla v0.4s, v29.4s, v16.4s\n" - "cbnz x22, 2b\n" - "sub %x[nc], %x[nc], #0x4\n" - "str q0, [%x[res_ptr], #0x0]\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "cbnz %x[nc], 1b\n" - : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) - : [a_ptr] "r" (a_ptr), [nb] "r" (nb) - : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23" - ); -#elif defined(__ARM_NEON) && defined(__aarch64__) - GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) && - "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " - "performance"); -#else - float sumf[4]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; - } - sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } -#endif -} - -void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined(__ARM_FEATURE_SVE) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) - if (ggml_sve_cnt_b == QK8_0) { - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - - __asm__ __volatile__( - "ptrue p0.b\n" - "add %x[b_ptr], %x[b_ptr], #0x10\n" - "1:" // Column loop - "add x22, %x[a_ptr], #0x2\n" - "mov z31.b, #0x0\n" - "mov x21, %x[nb]\n" - "2:" // Block loop - "ld1b { z30.b }, p0/Z, [%x[b_ptr]]\n" - "ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\n" - "mov z28.s, #0x0\n" - "mov z27.s, #0x0\n" - "ld1rd { z26.d }, p0/Z, [x22]\n" - "ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\n" - "sub x20, x22, #0x2\n" - "sub x21, x21, #0x1\n" - "ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\n" - "ld1rd { z23.d }, p0/Z, [x22, #8]\n" - "lsl z22.b, z30.b, #0x4\n" - "lsl z16.b, z29.b, #0x4\n" - "and z30.b, z30.b, #0xf0\n" - "and z29.b, z29.b, #0xf0\n" - "ld1rd { z21.d }, p0/Z, [x22, #16]\n" - "ld1rd { z20.d }, p0/Z, [x22, #24]\n" - "lsl z19.b, z25.b, #0x4\n" - "and z25.b, z25.b, #0xf0\n" - "ld1rh { z17.h }, p0/Z, [x20]\n" - "ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\n" - "sdot z28.s, z22.b, z26.b\n" - "sdot z27.s, z16.b, z26.b\n" - "lsl z16.b, z24.b, #0x4\n" - "add x22, x22, #0x22\n" - "and z24.b, z24.b, #0xf0\n" - "add %x[b_ptr], %x[b_ptr], #0x90\n" - "fcvt z17.s, p0/m, z17.h\n" - "fcvt z18.s, p0/m, z18.h\n" - "sdot z28.s, z19.b, z23.b\n" - "sdot z27.s, z16.b, z23.b\n" - "fmul z18.s, z18.s, z17.s\n" - "sdot z28.s, z30.b, z21.b\n" - "sdot z27.s, z29.b, z21.b\n" - "sdot z28.s, z25.b, z20.b\n" - "sdot z27.s, z24.b, z20.b\n" - "uzp1 z17.s, z28.s, z27.s\n" - "uzp2 z16.s, z28.s, z27.s\n" - "add z17.s, z17.s, z16.s\n" - "asr z17.s, z17.s, #0x4\n" - "scvtf z17.s, p0/m, z17.s\n" - "fmla z31.s, p0/M, z17.s, z18.s\n" - "cbnz x21, 2b\n" - "sub %x[nc], %x[nc], #0x8\n" - "st1w { z31.s }, p0, [%x[res_ptr]]\n" - "add %x[res_ptr], %x[res_ptr], #0x20\n" - "cbnz %x[nc], 1b\n" - : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) - : [a_ptr] "r" (a_ptr), [nb] "r" (nb) - : "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" - ); - return; - } - else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { - GGML_ASSERT((ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) && - "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal " - "performance"); - } - else if (ggml_cpu_has_neon()) { - GGML_ASSERT(((ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) || ggml_cpu_has_matmul_int8()) && - "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 " - "quantization format for optimal performance"); - } -#endif -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - GGML_ASSERT(ggml_cpu_has_sve() && - "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance"); -#elif defined(__ARM_NEON) && defined(__aarch64__) - GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) && - "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " - "performance"); -#elif defined(__AVX2__) - // Lookup table to convert signed nibbles to signed bytes - __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0)); - signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); - __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0); - __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); - - // Permute mask used for easier vector processing at later stages - const __m256i m4b = _mm256_set1_epi8(0x0F); - - int64_t b_nb = n / QK4_0; - - const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx; - const block_q8_0 * a_ptr_start = (const block_q8_0 *)vy; - - // Process Q8_0 blocks one by one - for (int64_t y = 0; y < nr; y++) { - - // Pointers to LHS blocks of block_q8_0 format - const block_q8_0 * a_ptr = a_ptr_start + (y * nb); - - // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation - for (int64_t x = 0; x < nc / 8; x++) { - - // Pointers to RHS blocks - const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); - - // Master FP accumulator - __m256 acc_row = _mm256_setzero_ps(); - - for (int64_t b = 0; b < nb; b++) { - // Load 8 blocks of Q4_0 interleaved as 8 bytes (B0 - B7) - const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); - const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 1); - const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 2); - const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 3); - - // 4-bit -> 8-bit - Sign is maintained - const __m256i rhs_vec_0123_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_0, m4b)); // B0(0-7) B1(0-7) B2(0-7) B3(0-7) - const __m256i rhs_vec_4567_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_0, m4b)); // B4(0-7) B5(0-7) B6(0-7) B7(0-7) - const __m256i rhs_vec_0123_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_1, m4b)); // B0(8-15) B1(8-15) B2(8-15) B3(8-15) - const __m256i rhs_vec_4567_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_1, m4b)); // B0(8-15) B1(8-15) B2(8-15) B3(8-15) - - const __m256i rhs_vec_0123_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b)); // B0(16-23) B1(16-23) B2(16-23) B3(16-23) - const __m256i rhs_vec_4567_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b)); // B4(16-23) B5(16-23) B6(16-23) B7(16-23) - const __m256i rhs_vec_0123_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b)); // B0(24-31) B1(24-31) B2(24-31) B3(24-31) - const __m256i rhs_vec_4567_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b)); // B4(24-31) B5(24-31) B6(24-31) B7(24-31) - - // Load the scale values for the 8 blocks interleaved in block_q4_0x8 - const __m256 col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask); - - // Load and convert to FP32 scale from block_q8_0 - const __m256 row_scale_f32 = _mm256_set1_ps(GGML_FP16_TO_FP32(a_ptr[b].d)); - - // Load the block values in block_q8_0 in batches of 16 bytes and replicate the same across 256 bit vector - __m256i lhs_vec_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)a_ptr[b].qs)); - __m256i lhs_vec_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16))); - - lhs_vec_0 = _mm256_permute2f128_si256(lhs_vec_0, lhs_vec_0, 0); // A0 (0-15) A0(0-15) - lhs_vec_1 = _mm256_permute2f128_si256(lhs_vec_1, lhs_vec_1, 0); // A0 (16-31) A0(16-31)) - - __m256i iacc = _mm256_setzero_si256(); - - // Dot product done within 32 bit lanes and accumulated in the same vector - // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3) - // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7) - // ........................................................................... - // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31) - - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0))); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85))); - - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170))); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255))); - - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0))); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85))); - - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170))); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255))); - - // Accumulated values multipled with appropriate scales - acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row); - } - - // Accumulated output values permuted so as to be stored in appropriate order post accumulation - acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask); - _mm256_storeu_ps(s + (y * nr + x * 8), acc_row); - } - } -#else - float sumf[8]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; - } - sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } -#endif -} - -void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) - if (ggml_sve_cnt_b == QK8_0) { - GGML_ASSERT(!(ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) && - "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); - } -#endif -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - GGML_ASSERT(!(ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) && - "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance"); -#elif defined(__ARM_NEON) && defined(__aarch64__) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - size_t res_stride = bs * sizeof(float); - - __asm__ __volatile__( - "mov x10, %x[nr]\n" - "mov x9, #0x88\n" - "cmp x10, #0x10\n" - "mul x9, %x[nb], x9\n" - "blt 4f\n" - "1:" // Row loop - "add x28, %x[b_ptr], #0x8\n" - "mov x27, %x[nc]\n" - "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" - "2:" // Column loop - "add x25, %x[a_ptr], #0x8\n" - "movi v15.16b, #0x0\n" - "movi v19.16b, #0x0\n" - "mov x24, %x[nb]\n" - "add x23, x25, x9\n" - "movi v18.16b, #0x0\n" - "movi v14.16b, #0x0\n" - "add x22, x23, x9\n" - "movi v11.16b, #0x0\n" - "movi v13.16b, #0x0\n" - "add x21, x22, x9\n" - "movi v23.16b, #0x0\n" - "movi v16.16b, #0x0\n" - "movi v25.16b, #0x0\n" - "movi v7.16b, #0x0\n" - "movi v0.16b, #0x0\n" - "movi v4.16b, #0x0\n" - "movi v5.16b, #0x0\n" - "movi v21.16b, #0x0\n" - "movi v8.16b, #0x0\n" - "movi v1.16b, #0x0\n" - "3:" // Block loop - "ldr q3, [x28, #0x0]\n" - "ldr q31, [x25, #0x0]\n" - "movi v28.16b, #0x4\n" - "movi v10.4s, #0x0\n" - "ldr q22, [x28, #0x10]\n" - "ldr q6, [x25, #0x10]\n" - "movi v29.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "ldr q27, [x28, #0x20]\n" - "ldr q30, [x28, #0x30]\n" - "movi v20.4s, #0x0\n" - "movi v24.16b, #0xf0\n" - "ldr d2, [x25, #-0x8]\n" - "ldr d26, [x23, #-0x8]\n" - "sshl v12.16b, v3.16b, v28.16b\n" - "sub x20, x28, #0x8\n" - "ldr d17, [x20, #0x0]\n" - "and v3.16b, v3.16b, v24.16b\n" - "subs x24, x24, #0x1\n" - "add x28, x28, #0x48\n" - ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n" - ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n" - ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n" - ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n" - "sshl v31.16b, v22.16b, v28.16b\n" - "and v22.16b, v22.16b, v24.16b\n" - "fcvtl v17.4s, v17.4h\n" - "fcvtl v2.4s, v2.4h\n" - "fcvtl v26.4s, v26.4h\n" - ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n" - ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n" - ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n" - ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n" - "sshl v6.16b, v27.16b, v28.16b\n" - "sshl v28.16b, v30.16b, v28.16b\n" - "and v27.16b, v27.16b, v24.16b\n" - "and v30.16b, v30.16b, v24.16b\n" - "ldr q24, [x25, #0x20]\n" - ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n" - ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" - ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n" - ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x30]\n" - ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n" - ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n" - ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n" - ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x40]\n" - ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n" - ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" - ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n" - ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x50]\n" - ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n" - ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n" - ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n" - ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x60]\n" - ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n" - ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" - ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n" - ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x70]\n" - "add x25, x25, #0x88\n" - ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n" - ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n" - ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n" - ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n" - "fmul v24.4s, v17.4s, v2.s[0]\n" - "scvtf v10.4s, v10.4s, #0x4\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v15.4s, v10.4s, v24.4s\n" - "ldr q24, [x23, #0x0]\n" - "fmul v10.4s, v17.4s, v2.s[1]\n" - "fmla v19.4s, v29.4s, v10.4s\n" - "ldr q10, [x23, #0x10]\n" - "fmul v29.4s, v17.4s, v2.s[2]\n" - "fmul v2.4s, v17.4s, v2.s[3]\n" - "fmla v18.4s, v9.4s, v29.4s\n" - "movi v9.4s, #0x0\n" - "movi v29.4s, #0x0\n" - ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n" - ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n" - "fmla v14.4s, v20.4s, v2.4s\n" - "movi v20.4s, #0x0\n" - "movi v2.4s, #0x0\n" - ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n" - ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" - "ldr q24, [x23, #0x20]\n" - ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n" - ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n" - ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n" - ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n" - "ldr q10, [x23, #0x30]\n" - ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n" - ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" - ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n" - ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" - "ldr q24, [x23, #0x40]\n" - ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n" - ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n" - ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n" - ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n" - "ldr q10, [x23, #0x50]\n" - ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n" - ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" - ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n" - ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" - "ldr q24, [x23, #0x60]\n" - ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n" - ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n" - ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n" - ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n" - "ldr q10, [x23, #0x70]\n" - "add x23, x23, #0x88\n" - ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n" - ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" - ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n" - ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x0]\n" - ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n" - ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n" - ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n" - ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n" - "fmul v10.4s, v17.4s, v26.s[0]\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "fmla v11.4s, v9.4s, v10.4s\n" - "ldr q9, [x22, #0x10]\n" - "fmul v10.4s, v17.4s, v26.s[1]\n" - "fmla v13.4s, v29.4s, v10.4s\n" - "ldr d29, [x22, #-0x8]\n" - "fmul v10.4s, v17.4s, v26.s[2]\n" - "fmul v26.4s, v17.4s, v26.s[3]\n" - "fcvtl v29.4s, v29.4h\n" - "fmla v23.4s, v20.4s, v10.4s\n" - "movi v20.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "fmla v16.4s, v2.4s, v26.4s\n" - "movi v26.4s, #0x0\n" - "movi v2.4s, #0x0\n" - ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n" - ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" - ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n" - ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x20]\n" - ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n" - ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" - ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n" - ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n" - "ldr q9, [x22, #0x30]\n" - ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n" - ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n" - ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n" - ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x40]\n" - ".inst 0x4f89e394 // sdot v20.4s, v28.16b, v9.4b[0]\n" - ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" - ".inst 0x4f89eb9a // sdot v26.4s, v28.16b, v9.4b[2]\n" - ".inst 0x4fa9eb82 // sdot v2.4s, v28.16b, v9.4b[3]\n" - "ldr q9, [x22, #0x50]\n" - ".inst 0x4f98e074 // sdot v20.4s, v3.16b, v24.4b[0]\n" - ".inst 0x4fb8e06a // sdot v10.4s, v3.16b, v24.4b[1]\n" - ".inst 0x4f98e87a // sdot v26.4s, v3.16b, v24.4b[2]\n" - ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x60]\n" - ".inst 0x4f89e2d4 // sdot v20.4s, v22.16b, v9.4b[0]\n" - ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" - ".inst 0x4f89eada // sdot v26.4s, v22.16b, v9.4b[2]\n" - ".inst 0x4fa9eac2 // sdot v2.4s, v22.16b, v9.4b[3]\n" - "ldr q9, [x22, #0x70]\n" - "add x22, x22, #0x88\n" - ".inst 0x4f98e374 // sdot v20.4s, v27.16b, v24.4b[0]\n" - ".inst 0x4fb8e36a // sdot v10.4s, v27.16b, v24.4b[1]\n" - ".inst 0x4f98eb7a // sdot v26.4s, v27.16b, v24.4b[2]\n" - ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" - "ldr q24, [x21, #0x0]\n" - ".inst 0x4f89e3d4 // sdot v20.4s, v30.16b, v9.4b[0]\n" - ".inst 0x4fa9e3ca // sdot v10.4s, v30.16b, v9.4b[1]\n" - ".inst 0x4f89ebda // sdot v26.4s, v30.16b, v9.4b[2]\n" - ".inst 0x4fa9ebc2 // sdot v2.4s, v30.16b, v9.4b[3]\n" - "fmul v9.4s, v17.4s, v29.s[0]\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "scvtf v10.4s, v10.4s, #0x4\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "fmla v25.4s, v20.4s, v9.4s\n" - "ldr q9, [x21, #0x10]\n" - "fmul v20.4s, v17.4s, v29.s[1]\n" - "fmla v7.4s, v10.4s, v20.4s\n" - "ldr d20, [x21, #-0x8]\n" - "fmul v10.4s, v17.4s, v29.s[2]\n" - "fmul v29.4s, v17.4s, v29.s[3]\n" - "fcvtl v20.4s, v20.4h\n" - "fmla v0.4s, v26.4s, v10.4s\n" - "movi v26.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "fmla v4.4s, v2.4s, v29.4s\n" - "movi v2.4s, #0x0\n" - "movi v29.4s, #0x0\n" - ".inst 0x4f98e19a // sdot v26.4s, v12.16b, v24.4b[0]\n" - ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" - ".inst 0x4f98e982 // sdot v2.4s, v12.16b, v24.4b[2]\n" - ".inst 0x4fb8e99d // sdot v29.4s, v12.16b, v24.4b[3]\n" - "ldr q12, [x21, #0x20]\n" - "fmul v24.4s, v17.4s, v20.s[0]\n" - ".inst 0x4f89e3fa // sdot v26.4s, v31.16b, v9.4b[0]\n" - ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" - ".inst 0x4f89ebe2 // sdot v2.4s, v31.16b, v9.4b[2]\n" - ".inst 0x4fa9ebfd // sdot v29.4s, v31.16b, v9.4b[3]\n" - "ldr q9, [x21, #0x30]\n" - "fmul v31.4s, v17.4s, v20.s[1]\n" - ".inst 0x4f8ce0da // sdot v26.4s, v6.16b, v12.4b[0]\n" - ".inst 0x4face0ca // sdot v10.4s, v6.16b, v12.4b[1]\n" - ".inst 0x4f8ce8c2 // sdot v2.4s, v6.16b, v12.4b[2]\n" - ".inst 0x4face8dd // sdot v29.4s, v6.16b, v12.4b[3]\n" - "ldr q12, [x21, #0x40]\n" - "fmul v6.4s, v17.4s, v20.s[2]\n" - "fmul v20.4s, v17.4s, v20.s[3]\n" - ".inst 0x4f89e39a // sdot v26.4s, v28.16b, v9.4b[0]\n" - ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" - ".inst 0x4f89eb82 // sdot v2.4s, v28.16b, v9.4b[2]\n" - ".inst 0x4fa9eb9d // sdot v29.4s, v28.16b, v9.4b[3]\n" - "ldr q9, [x21, #0x50]\n" - ".inst 0x4f8ce07a // sdot v26.4s, v3.16b, v12.4b[0]\n" - ".inst 0x4face06a // sdot v10.4s, v3.16b, v12.4b[1]\n" - ".inst 0x4f8ce862 // sdot v2.4s, v3.16b, v12.4b[2]\n" - ".inst 0x4face87d // sdot v29.4s, v3.16b, v12.4b[3]\n" - "ldr q12, [x21, #0x60]\n" - ".inst 0x4f89e2da // sdot v26.4s, v22.16b, v9.4b[0]\n" - ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" - ".inst 0x4f89eac2 // sdot v2.4s, v22.16b, v9.4b[2]\n" - ".inst 0x4fa9eadd // sdot v29.4s, v22.16b, v9.4b[3]\n" - "ldr q17, [x21, #0x70]\n" - "add x21, x21, #0x88\n" - ".inst 0x4f8ce37a // sdot v26.4s, v27.16b, v12.4b[0]\n" - ".inst 0x4face36a // sdot v10.4s, v27.16b, v12.4b[1]\n" - ".inst 0x4f8ceb62 // sdot v2.4s, v27.16b, v12.4b[2]\n" - ".inst 0x4faceb7d // sdot v29.4s, v27.16b, v12.4b[3]\n" - ".inst 0x4f91e3da // sdot v26.4s, v30.16b, v17.4b[0]\n" - ".inst 0x4fb1e3ca // sdot v10.4s, v30.16b, v17.4b[1]\n" - ".inst 0x4f91ebc2 // sdot v2.4s, v30.16b, v17.4b[2]\n" - ".inst 0x4fb1ebdd // sdot v29.4s, v30.16b, v17.4b[3]\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "scvtf v10.4s, v10.4s, #0x4\n" - "fmla v5.4s, v26.4s, v24.4s\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "fmla v21.4s, v10.4s, v31.4s\n" - "fmla v8.4s, v2.4s, v6.4s\n" - "fmla v1.4s, v29.4s, v20.4s\n" - "bgt 3b\n" - "mov x20, %x[res_ptr]\n" - "subs x27, x27, #0x4\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "str q15, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q19, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q18, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q14, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q11, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q13, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q23, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q16, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q25, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q7, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q0, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q4, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q5, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q21, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q8, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q1, [x20, #0x0]\n" - "bne 2b\n" - "mov x20, #0x4\n" - "sub x10, x10, #0x10\n" - "cmp x10, #0x10\n" - "mov %x[res_ptr], x26\n" - "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" - "bge 1b\n" - "4:" // Row loop skip - "cbz x10, 9f\n" - "5:" // Row tail: Row loop - "add x24, %x[b_ptr], #0x8\n" - "mov x23, %x[nc]\n" - "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" - "6:" // Row tail: Column loop - "movi v15.16b, #0x0\n" - "movi v19.16b, #0x0\n" - "add x25, %x[a_ptr], #0x8\n" - "mov x21, %x[nb]\n" - "movi v18.16b, #0x0\n" - "movi v14.16b, #0x0\n" - "7:" // Row tail: Block loop - "ldr q7, [x24, #0x0]\n" - "ldr q5, [x25, #0x0]\n" - "movi v9.16b, #0x4\n" - "movi v4.4s, #0x0\n" - "ldr q3, [x24, #0x10]\n" - "ldr q2, [x25, #0x10]\n" - "movi v1.4s, #0x0\n" - "movi v0.4s, #0x0\n" - "ldr q13, [x24, #0x20]\n" - "ldr q31, [x25, #0x20]\n" - "movi v30.4s, #0x0\n" - "movi v29.16b, #0xf0\n" - "ldr q28, [x24, #0x30]\n" - "ldr q27, [x25, #0x30]\n" - "sshl v20.16b, v7.16b, v9.16b\n" - "sub x20, x24, #0x8\n" - "ldr q26, [x25, #0x40]\n" - "ldr q25, [x25, #0x50]\n" - "sshl v17.16b, v3.16b, v9.16b\n" - "and v7.16b, v7.16b, v29.16b\n" - "ldr q24, [x25, #0x60]\n" - "ldr q16, [x25, #0x70]\n" - "sshl v22.16b, v13.16b, v9.16b\n" - "and v3.16b, v3.16b, v29.16b\n" - "ldr d21, [x20, #0x0]\n" - "ldr d12, [x25, #-0x8]\n" - ".inst 0x4f85e284 // sdot v4.4s, v20.16b, v5.4b[0]\n" - ".inst 0x4fa5e281 // sdot v1.4s, v20.16b, v5.4b[1]\n" - ".inst 0x4f85ea80 // sdot v0.4s, v20.16b, v5.4b[2]\n" - ".inst 0x4fa5ea9e // sdot v30.4s, v20.16b, v5.4b[3]\n" - "sshl v9.16b, v28.16b, v9.16b\n" - "subs x21, x21, #0x1\n" - "and v13.16b, v13.16b, v29.16b\n" - "and v28.16b, v28.16b, v29.16b\n" - "add x25, x25, #0x88\n" - "add x24, x24, #0x48\n" - "fcvtl v21.4s, v21.4h\n" - "fcvtl v12.4s, v12.4h\n" - ".inst 0x4f82e224 // sdot v4.4s, v17.16b, v2.4b[0]\n" - ".inst 0x4fa2e221 // sdot v1.4s, v17.16b, v2.4b[1]\n" - ".inst 0x4f82ea20 // sdot v0.4s, v17.16b, v2.4b[2]\n" - ".inst 0x4fa2ea3e // sdot v30.4s, v17.16b, v2.4b[3]\n" - "fmul v11.4s, v21.4s, v12.s[0]\n" - "fmul v23.4s, v21.4s, v12.s[1]\n" - "fmul v17.4s, v21.4s, v12.s[2]\n" - ".inst 0x4f9fe2c4 // sdot v4.4s, v22.16b, v31.4b[0]\n" - "fmul v6.4s, v21.4s, v12.s[3]\n" - ".inst 0x4fbfe2c1 // sdot v1.4s, v22.16b, v31.4b[1]\n" - ".inst 0x4f9feac0 // sdot v0.4s, v22.16b, v31.4b[2]\n" - ".inst 0x4fbfeade // sdot v30.4s, v22.16b, v31.4b[3]\n" - ".inst 0x4f9be124 // sdot v4.4s, v9.16b, v27.4b[0]\n" - ".inst 0x4fbbe121 // sdot v1.4s, v9.16b, v27.4b[1]\n" - ".inst 0x4f9be920 // sdot v0.4s, v9.16b, v27.4b[2]\n" - ".inst 0x4fbbe93e // sdot v30.4s, v9.16b, v27.4b[3]\n" - ".inst 0x4f9ae0e4 // sdot v4.4s, v7.16b, v26.4b[0]\n" - ".inst 0x4fbae0e1 // sdot v1.4s, v7.16b, v26.4b[1]\n" - ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n" - ".inst 0x4fbae8fe // sdot v30.4s, v7.16b, v26.4b[3]\n" - ".inst 0x4f99e064 // sdot v4.4s, v3.16b, v25.4b[0]\n" - ".inst 0x4fb9e061 // sdot v1.4s, v3.16b, v25.4b[1]\n" - ".inst 0x4f99e860 // sdot v0.4s, v3.16b, v25.4b[2]\n" - ".inst 0x4fb9e87e // sdot v30.4s, v3.16b, v25.4b[3]\n" - ".inst 0x4f98e1a4 // sdot v4.4s, v13.16b, v24.4b[0]\n" - ".inst 0x4fb8e1a1 // sdot v1.4s, v13.16b, v24.4b[1]\n" - ".inst 0x4f98e9a0 // sdot v0.4s, v13.16b, v24.4b[2]\n" - ".inst 0x4fb8e9be // sdot v30.4s, v13.16b, v24.4b[3]\n" - ".inst 0x4f90e384 // sdot v4.4s, v28.16b, v16.4b[0]\n" - ".inst 0x4fb0e381 // sdot v1.4s, v28.16b, v16.4b[1]\n" - ".inst 0x4f90eb80 // sdot v0.4s, v28.16b, v16.4b[2]\n" - ".inst 0x4fb0eb9e // sdot v30.4s, v28.16b, v16.4b[3]\n" - "scvtf v4.4s, v4.4s, #0x4\n" - "scvtf v1.4s, v1.4s, #0x4\n" - "scvtf v0.4s, v0.4s, #0x4\n" - "fmla v15.4s, v4.4s, v11.4s\n" - "scvtf v30.4s, v30.4s, #0x4\n" - "fmla v19.4s, v1.4s, v23.4s\n" - "fmla v18.4s, v0.4s, v17.4s\n" - "fmla v14.4s, v30.4s, v6.4s\n" - "bgt 7b\n" - "mov x20, %x[res_ptr]\n" - "cmp x10, #0x1\n" - "str q15, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x10, #0x2\n" - "str q19, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x10, #0x3\n" - "str q18, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "str q14, [x20, #0x0]\n" - "8:" // Row tail: Accumulator store skip - "subs x23, x23, #0x4\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "bne 6b\n" - "subs x10, x10, #0x4\n" - "add %x[a_ptr], %x[a_ptr], x9\n" - "mov %x[res_ptr], x22\n" - "bgt 5b\n" - "9:" // Row tail: Row loop skip - : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) - : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" - ); -#else - float sumf[4][4]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; - } - sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } -#endif -} - -void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) - if (ggml_sve_cnt_b == QK8_0) { - GGML_ASSERT(!(ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) && - "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); - } -#endif -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - size_t res_stride = bs * sizeof(float); - - __asm__ __volatile__( - "mov x10, %x[nr]\n" - "mov x9, #0x88\n" - "cmp x10, #0x10\n" - "mul x9, %x[nb], x9\n" - "blt 4f\n" - "1:" // Row loop - "add x28, %x[b_ptr], #0x8\n" - "mov x27, %x[nc]\n" - "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" - "2:" // Column loop - "add x25, %x[a_ptr], #0x8\n" - "movi v2.16b, #0x0\n" - "movi v10.16b, #0x0\n" - "mov x24, %x[nb]\n" - "add x23, x25, x9\n" - "movi v12.16b, #0x0\n" - "movi v28.16b, #0x0\n" - "add x22, x23, x9\n" - "movi v11.16b, #0x0\n" - "movi v13.16b, #0x0\n" - "add x21, x22, x9\n" - "movi v22.16b, #0x0\n" - "movi v23.16b, #0x0\n" - "movi v25.16b, #0x0\n" - "movi v5.16b, #0x0\n" - "movi v7.16b, #0x0\n" - "movi v4.16b, #0x0\n" - "movi v6.16b, #0x0\n" - "movi v30.16b, #0x0\n" - "movi v24.16b, #0x0\n" - "movi v14.16b, #0x0\n" - "3:" // Block loop - "ldr q21, [x28, #0x0]\n" - "ldr q16, [x28, #0x10]\n" - "movi v1.16b, #0x4\n" - "movi v19.4s, #0x0\n" - "ldr q27, [x25, #0x0]\n" - "ldr q15, [x25, #0x10]\n" - "movi v26.4s, #0x0\n" - "movi v18.4s, #0x0\n" - "ldr q29, [x28, #0x20]\n" - "ldr q3, [x28, #0x30]\n" - "movi v17.4s, #0x0\n" - "movi v0.16b, #0xf0\n" - "ldr d20, [x25, #-0x8]\n" - "ldr d9, [x23, #-0x8]\n" - "sshl v8.16b, v21.16b, v1.16b\n" - "sshl v31.16b, v16.16b, v1.16b\n" - "and v21.16b, v21.16b, v0.16b\n" - "and v16.16b, v16.16b, v0.16b\n" - "sub x20, x28, #0x8\n" - "subs x24, x24, #0x1\n" - "add x28, x28, #0x48\n" - ".inst 0x4e88a773 // smmla v19.4s, v27.16b, v8.16b\n" - ".inst 0x4e9fa77a // smmla v26.4s, v27.16b, v31.16b\n" - "ldr q27, [x25, #0x20]\n" - ".inst 0x4e88a5f2 // smmla v18.4s, v15.16b, v8.16b\n" - ".inst 0x4e9fa5f1 // smmla v17.4s, v15.16b, v31.16b\n" - "sshl v15.16b, v29.16b, v1.16b\n" - "sshl v1.16b, v3.16b, v1.16b\n" - "and v29.16b, v29.16b, v0.16b\n" - "and v3.16b, v3.16b, v0.16b\n" - "ldr q0, [x25, #0x30]\n" - "fcvtl v20.4s, v20.4h\n" - ".inst 0x4e8fa773 // smmla v19.4s, v27.16b, v15.16b\n" - "fcvtl v9.4s, v9.4h\n" - ".inst 0x4e81a77a // smmla v26.4s, v27.16b, v1.16b\n" - "ldr q27, [x25, #0x40]\n" - ".inst 0x4e8fa412 // smmla v18.4s, v0.16b, v15.16b\n" - ".inst 0x4e81a411 // smmla v17.4s, v0.16b, v1.16b\n" - "ldr q0, [x25, #0x50]\n" - ".inst 0x4e95a773 // smmla v19.4s, v27.16b, v21.16b\n" - ".inst 0x4e90a77a // smmla v26.4s, v27.16b, v16.16b\n" - "ldr q27, [x25, #0x60]\n" - ".inst 0x4e95a412 // smmla v18.4s, v0.16b, v21.16b\n" - ".inst 0x4e90a411 // smmla v17.4s, v0.16b, v16.16b\n" - "ldr q0, [x25, #0x70]\n" - "add x25, x25, #0x88\n" - ".inst 0x4e9da773 // smmla v19.4s, v27.16b, v29.16b\n" - ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n" - "ldr d27, [x20, #0x0]\n" - ".inst 0x4e9da412 // smmla v18.4s, v0.16b, v29.16b\n" - ".inst 0x4e83a411 // smmla v17.4s, v0.16b, v3.16b\n" - "fcvtl v27.4s, v27.4h\n" - "uzp1 v0.2d, v19.2d, v26.2d\n" - "uzp2 v26.2d, v19.2d, v26.2d\n" - "fmul v19.4s, v27.4s, v20.s[0]\n" - "scvtf v0.4s, v0.4s, #0x4\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "fmla v2.4s, v0.4s, v19.4s\n" - "ldr q19, [x23, #0x0]\n" - "uzp1 v0.2d, v18.2d, v17.2d\n" - "uzp2 v18.2d, v18.2d, v17.2d\n" - "fmul v17.4s, v27.4s, v20.s[1]\n" - "scvtf v0.4s, v0.4s, #0x4\n" - "scvtf v18.4s, v18.4s, #0x4\n" - "fmla v10.4s, v26.4s, v17.4s\n" - "ldr q17, [x23, #0x10]\n" - "fmul v26.4s, v27.4s, v20.s[2]\n" - "fmul v20.4s, v27.4s, v20.s[3]\n" - "fmla v12.4s, v0.4s, v26.4s\n" - "ldr d0, [x22, #-0x8]\n" - "ldr d26, [x21, #-0x8]\n" - "fcvtl v0.4s, v0.4h\n" - "fmla v28.4s, v18.4s, v20.4s\n" - "movi v20.4s, #0x0\n" - "movi v18.4s, #0x0\n" - ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" - ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" - "ldr q19, [x23, #0x20]\n" - "fcvtl v26.4s, v26.4h\n" - ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" - ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" - "ldr q19, [x23, #0x40]\n" - ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" - ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" - "ldr q19, [x23, #0x60]\n" - ".inst 0x4e9da674 // smmla v20.4s, v19.16b, v29.16b\n" - ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n" - "uzp1 v19.2d, v20.2d, v18.2d\n" - "scvtf v19.4s, v19.4s, #0x4\n" - "uzp2 v20.2d, v20.2d, v18.2d\n" - "fmul v18.4s, v27.4s, v9.s[0]\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v11.4s, v19.4s, v18.4s\n" - "ldr q18, [x22, #0x0]\n" - "fmul v19.4s, v27.4s, v9.s[1]\n" - "fmla v13.4s, v20.4s, v19.4s\n" - "movi v19.4s, #0x0\n" - "movi v20.4s, #0x0\n" - ".inst 0x4e88a633 // smmla v19.4s, v17.16b, v8.16b\n" - ".inst 0x4e9fa634 // smmla v20.4s, v17.16b, v31.16b\n" - "ldr q17, [x23, #0x30]\n" - ".inst 0x4e8fa633 // smmla v19.4s, v17.16b, v15.16b\n" - ".inst 0x4e81a634 // smmla v20.4s, v17.16b, v1.16b\n" - "ldr q17, [x23, #0x50]\n" - ".inst 0x4e95a633 // smmla v19.4s, v17.16b, v21.16b\n" - ".inst 0x4e90a634 // smmla v20.4s, v17.16b, v16.16b\n" - "ldr q17, [x23, #0x70]\n" - "add x23, x23, #0x88\n" - ".inst 0x4e9da633 // smmla v19.4s, v17.16b, v29.16b\n" - ".inst 0x4e83a634 // smmla v20.4s, v17.16b, v3.16b\n" - "uzp1 v17.2d, v19.2d, v20.2d\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "uzp2 v20.2d, v19.2d, v20.2d\n" - "fmul v19.4s, v27.4s, v9.s[2]\n" - "fmul v9.4s, v27.4s, v9.s[3]\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v22.4s, v17.4s, v19.4s\n" - "ldr q17, [x22, #0x10]\n" - "movi v19.4s, #0x0\n" - ".inst 0x4e88a653 // smmla v19.4s, v18.16b, v8.16b\n" - "fmla v23.4s, v20.4s, v9.4s\n" - "movi v20.4s, #0x0\n" - "movi v9.4s, #0x0\n" - ".inst 0x4e9fa654 // smmla v20.4s, v18.16b, v31.16b\n" - "ldr q18, [x22, #0x20]\n" - ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" - ".inst 0x4e8fa653 // smmla v19.4s, v18.16b, v15.16b\n" - ".inst 0x4e81a654 // smmla v20.4s, v18.16b, v1.16b\n" - "ldr q18, [x22, #0x40]\n" - ".inst 0x4e95a653 // smmla v19.4s, v18.16b, v21.16b\n" - ".inst 0x4e90a654 // smmla v20.4s, v18.16b, v16.16b\n" - "ldr q18, [x22, #0x60]\n" - ".inst 0x4e9da653 // smmla v19.4s, v18.16b, v29.16b\n" - ".inst 0x4e83a654 // smmla v20.4s, v18.16b, v3.16b\n" - "movi v18.4s, #0x0\n" - ".inst 0x4e9fa632 // smmla v18.4s, v17.16b, v31.16b\n" - "ldr q17, [x22, #0x30]\n" - ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" - ".inst 0x4e81a632 // smmla v18.4s, v17.16b, v1.16b\n" - "ldr q17, [x22, #0x50]\n" - ".inst 0x4e95a629 // smmla v9.4s, v17.16b, v21.16b\n" - ".inst 0x4e90a632 // smmla v18.4s, v17.16b, v16.16b\n" - "ldr q17, [x22, #0x70]\n" - "add x22, x22, #0x88\n" - ".inst 0x4e9da629 // smmla v9.4s, v17.16b, v29.16b\n" - ".inst 0x4e83a632 // smmla v18.4s, v17.16b, v3.16b\n" - "uzp1 v17.2d, v19.2d, v20.2d\n" - "uzp2 v20.2d, v19.2d, v20.2d\n" - "fmul v19.4s, v27.4s, v0.s[0]\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v25.4s, v17.4s, v19.4s\n" - "ldr q19, [x21, #0x0]\n" - "fmul v17.4s, v27.4s, v0.s[1]\n" - "fmla v5.4s, v20.4s, v17.4s\n" - "ldr q17, [x21, #0x10]\n" - "uzp1 v20.2d, v9.2d, v18.2d\n" - "uzp2 v9.2d, v9.2d, v18.2d\n" - "fmul v18.4s, v27.4s, v0.s[2]\n" - "fmul v0.4s, v27.4s, v0.s[3]\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "fmla v7.4s, v20.4s, v18.4s\n" - "movi v20.4s, #0x0\n" - "movi v18.4s, #0x0\n" - ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" - ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" - "ldr q19, [x21, #0x20]\n" - "fmla v4.4s, v9.4s, v0.4s\n" - "movi v9.4s, #0x0\n" - "movi v0.4s, #0x0\n" - ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" - "fmul v8.4s, v27.4s, v26.s[0]\n" - ".inst 0x4e9fa620 // smmla v0.4s, v17.16b, v31.16b\n" - "ldr q17, [x21, #0x30]\n" - ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" - "fmul v31.4s, v27.4s, v26.s[1]\n" - ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" - "ldr q19, [x21, #0x40]\n" - ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" - "fmul v15.4s, v27.4s, v26.s[2]\n" - "fmul v27.4s, v27.4s, v26.s[3]\n" - ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n" - "ldr q1, [x21, #0x50]\n" - ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" - ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" - "ldr q26, [x21, #0x60]\n" - ".inst 0x4e95a429 // smmla v9.4s, v1.16b, v21.16b\n" - ".inst 0x4e90a420 // smmla v0.4s, v1.16b, v16.16b\n" - "ldr q21, [x21, #0x70]\n" - "add x21, x21, #0x88\n" - ".inst 0x4e9da754 // smmla v20.4s, v26.16b, v29.16b\n" - ".inst 0x4e83a752 // smmla v18.4s, v26.16b, v3.16b\n" - ".inst 0x4e9da6a9 // smmla v9.4s, v21.16b, v29.16b\n" - ".inst 0x4e83a6a0 // smmla v0.4s, v21.16b, v3.16b\n" - "uzp1 v29.2d, v20.2d, v18.2d\n" - "uzp2 v21.2d, v20.2d, v18.2d\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "uzp1 v18.2d, v9.2d, v0.2d\n" - "uzp2 v16.2d, v9.2d, v0.2d\n" - "scvtf v21.4s, v21.4s, #0x4\n" - "fmla v6.4s, v29.4s, v8.4s\n" - "scvtf v18.4s, v18.4s, #0x4\n" - "scvtf v16.4s, v16.4s, #0x4\n" - "fmla v30.4s, v21.4s, v31.4s\n" - "fmla v24.4s, v18.4s, v15.4s\n" - "fmla v14.4s, v16.4s, v27.4s\n" - "bgt 3b\n" - "mov x20, %x[res_ptr]\n" - "subs x27, x27, #0x4\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "str q2, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q10, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q12, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q28, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q11, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q13, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q22, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q23, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q25, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q5, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q7, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q4, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q6, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q30, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q24, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q14, [x20, #0x0]\n" - "bne 2b\n" - "mov x20, #0x4\n" - "sub x10, x10, #0x10\n" - "cmp x10, #0x10\n" - "mov %x[res_ptr], x26\n" - "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" - "bge 1b\n" - "4:" // Row loop skip - "cbz x10, 9f\n" - "5:" // Row tail: Row loop - "add x24, %x[b_ptr], #0x8\n" - "mov x23, %x[nc]\n" - "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" - "6:" // Row tail: Column loop - "movi v2.16b, #0x0\n" - "movi v10.16b, #0x0\n" - "add x25, %x[a_ptr], #0x8\n" - "mov x21, %x[nb]\n" - "movi v12.16b, #0x0\n" - "movi v28.16b, #0x0\n" - "7:" // Row tail: Block loop - "ldr q6, [x24, #0x0]\n" - "ldr q5, [x24, #0x10]\n" - "movi v17.16b, #0x4\n" - "movi v8.4s, #0x0\n" - "ldr q4, [x25, #0x0]\n" - "ldr q13, [x25, #0x10]\n" - "movi v27.4s, #0x0\n" - "movi v0.4s, #0x0\n" - "ldr q31, [x24, #0x20]\n" - "ldr q14, [x24, #0x30]\n" - "movi v29.4s, #0x0\n" - "movi v22.16b, #0xf0\n" - "ldr q11, [x25, #0x20]\n" - "ldr q23, [x25, #0x30]\n" - "sshl v21.16b, v6.16b, v17.16b\n" - "sshl v16.16b, v5.16b, v17.16b\n" - "ldr q20, [x25, #0x40]\n" - "ldr q26, [x25, #0x50]\n" - "and v6.16b, v6.16b, v22.16b\n" - "and v5.16b, v5.16b, v22.16b\n" - "ldr q25, [x25, #0x60]\n" - "ldr q3, [x25, #0x70]\n" - "sshl v19.16b, v31.16b, v17.16b\n" - "sshl v18.16b, v14.16b, v17.16b\n" - "ldr d17, [x25, #-0x8]\n" - ".inst 0x4e95a488 // smmla v8.4s, v4.16b, v21.16b\n" - ".inst 0x4e90a49b // smmla v27.4s, v4.16b, v16.16b\n" - "and v31.16b, v31.16b, v22.16b\n" - ".inst 0x4e95a5a0 // smmla v0.4s, v13.16b, v21.16b\n" - ".inst 0x4e90a5bd // smmla v29.4s, v13.16b, v16.16b\n" - "and v14.16b, v14.16b, v22.16b\n" - "sub x20, x24, #0x8\n" - "ldr d16, [x20, #0x0]\n" - "subs x21, x21, #0x1\n" - "add x25, x25, #0x88\n" - "fcvtl v17.4s, v17.4h\n" - "add x24, x24, #0x48\n" - ".inst 0x4e93a568 // smmla v8.4s, v11.16b, v19.16b\n" - ".inst 0x4e92a57b // smmla v27.4s, v11.16b, v18.16b\n" - ".inst 0x4e93a6e0 // smmla v0.4s, v23.16b, v19.16b\n" - ".inst 0x4e92a6fd // smmla v29.4s, v23.16b, v18.16b\n" - "fcvtl v16.4s, v16.4h\n" - ".inst 0x4e86a688 // smmla v8.4s, v20.16b, v6.16b\n" - ".inst 0x4e85a69b // smmla v27.4s, v20.16b, v5.16b\n" - "fmul v23.4s, v16.4s, v17.s[0]\n" - "fmul v21.4s, v16.4s, v17.s[1]\n" - "fmul v1.4s, v16.4s, v17.s[2]\n" - "fmul v20.4s, v16.4s, v17.s[3]\n" - ".inst 0x4e86a740 // smmla v0.4s, v26.16b, v6.16b\n" - ".inst 0x4e85a75d // smmla v29.4s, v26.16b, v5.16b\n" - ".inst 0x4e9fa728 // smmla v8.4s, v25.16b, v31.16b\n" - ".inst 0x4e8ea73b // smmla v27.4s, v25.16b, v14.16b\n" - ".inst 0x4e9fa460 // smmla v0.4s, v3.16b, v31.16b\n" - ".inst 0x4e8ea47d // smmla v29.4s, v3.16b, v14.16b\n" - "uzp1 v19.2d, v8.2d, v27.2d\n" - "uzp2 v18.2d, v8.2d, v27.2d\n" - "scvtf v19.4s, v19.4s, #0x4\n" - "uzp1 v17.2d, v0.2d, v29.2d\n" - "uzp2 v16.2d, v0.2d, v29.2d\n" - "scvtf v18.4s, v18.4s, #0x4\n" - "fmla v2.4s, v19.4s, v23.4s\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "scvtf v16.4s, v16.4s, #0x4\n" - "fmla v10.4s, v18.4s, v21.4s\n" - "fmla v12.4s, v17.4s, v1.4s\n" - "fmla v28.4s, v16.4s, v20.4s\n" - "bgt 7b\n" - "mov x20, %x[res_ptr]\n" - "cmp x10, #0x1\n" - "str q2, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x10, #0x2\n" - "str q10, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x10, #0x3\n" - "str q12, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "str q28, [x20, #0x0]\n" - "8:" // Row tail: Accumulator store skip - "subs x23, x23, #0x4\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "bne 6b\n" - "subs x10, x10, #0x4\n" - "add %x[a_ptr], %x[a_ptr], x9\n" - "mov %x[res_ptr], x22\n" - "bgt 5b\n" - "9:" // Row tail: Row loop skip - : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) - : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" - ); -#elif defined(__ARM_NEON) && defined(__aarch64__) - GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) && - "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " - "performance"); -#else - float sumf[4][4]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; - } - sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } -#endif -} - -void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) - if (ggml_sve_cnt_b == QK8_0) { - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - size_t res_stride = bs * sizeof(float); - - __asm__ __volatile__( - "mov x20, #0x4\n" - "mov x13, %x[nr]\n" - "mov z28.s, #-0x4\n" - "mov x12, #0x88\n" - "ptrue p1.b\n" - "whilelt p0.s, XZR, x20\n" - "cmp x13, #0x10\n" - "mul x12, %x[nb], x12\n" - "blt 4f\n" - "1:" // Row loop - "add x11, %x[b_ptr], #0x10\n" - "mov x10, %x[nc]\n" - "add x9, %x[res_ptr], %x[res_stride], LSL #4\n" - "2:" // Column loop - "add x28, %x[a_ptr], #0x8\n" - "mov z24.b, #0x0\n" - "mov z15.b, #0x0\n" - "mov x27, %x[nb]\n" - "add x26, x28, x12\n" - "mov z12.b, #0x0\n" - "mov z0.b, #0x0\n" - "add x25, x26, x12\n" - "mov z13.b, #0x0\n" - "mov z1.b, #0x0\n" - "add x24, x25, x12\n" - "mov z20.b, #0x0\n" - "mov z25.b, #0x0\n" - "mov z11.b, #0x0\n" - "mov z16.b, #0x0\n" - "mov z19.b, #0x0\n" - "mov z26.b, #0x0\n" - "mov z8.b, #0x0\n" - "mov z29.b, #0x0\n" - "mov z27.b, #0x0\n" - "mov z10.b, #0x0\n" - "3:" // Block loop - "ld1b { z30.b }, p1/Z, [x11]\n" - "ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\n" - "mov z18.s, #0x0\n" - "mov z7.s, #0x0\n" - "ld1rqb { z3.b }, p1/Z, [x28]\n" - "ld1rqb { z5.b }, p1/Z, [x28, #16]\n" - "mov z9.s, #0x0\n" - "mov z22.s, #0x0\n" - "ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\n" - "ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\n" - "sub x20, x11, #0x10\n" - "sub x23, x28, #0x8\n" - "lsl z31.b, z30.b, #0x4\n" - "lsl z6.b, z21.b, #0x4\n" - "ld1h { z23.s }, p1/Z, [x20]\n" - "sub x22, x26, #0x8\n" - "and z30.b, z30.b, #0xf0\n" - "and z21.b, z21.b, #0xf0\n" - "sub x21, x25, #0x8\n" - "sub x20, x24, #0x8\n" - "lsl z14.b, z4.b, #0x4\n" - "lsl z2.b, z17.b, #0x4\n" - "subs x27, x27, #0x1\n" - "add x11, x11, #0x90\n" - ".inst 0x451f9872 // smmla z18.s, z3.b, z31.b\n" - ".inst 0x45069867 // smmla z7.s, z3.b, z6.b\n" - "ld1rqb { z3.b }, p1/Z, [x28, #32]\n" - "and z4.b, z4.b, #0xf0\n" - ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" - ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" - "ld1rqb { z5.b }, p1/Z, [x28, #48]\n" - "and z17.b, z17.b, #0xf0\n" - "fcvt z23.s, p1/m, z23.h\n" - ".inst 0x450e9872 // smmla z18.s, z3.b, z14.b\n" - ".inst 0x45029867 // smmla z7.s, z3.b, z2.b\n" - "ld1rqb { z3.b }, p1/Z, [x28, #64]\n" - ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" - ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" - "ld1rqb { z5.b }, p1/Z, [x28, #80]\n" - "fscale z23.s, p1/m, z23.s, z28.s\n" - ".inst 0x451e9872 // smmla z18.s, z3.b, z30.b\n" - ".inst 0x45159867 // smmla z7.s, z3.b, z21.b\n" - "ld1rqb { z3.b }, p1/Z, [x28, #96]\n" - ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" - ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" - "ld1rqb { z5.b }, p1/Z, [x28, #112]\n" - "add x28, x28, #0x88\n" - ".inst 0x45049872 // smmla z18.s, z3.b, z4.b\n" - ".inst 0x45119867 // smmla z7.s, z3.b, z17.b\n" - "ld1h { z3.s }, p0/Z, [x23]\n" - ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" - ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" - "fcvt z3.s, p1/m, z3.h\n" - "uzp1 z5.d, z18.d, z7.d\n" - "uzp2 z18.d, z18.d, z7.d\n" - "mov z3.q, z3.q[0]\n" - "uzp1 z7.d, z9.d, z22.d\n" - "uzp2 z22.d, z9.d, z22.d\n" - "fmul z9.s, z23.s, z3.s[0]\n" - "scvtf z5.s, p1/m, z5.s\n" - "scvtf z18.s, p1/m, z18.s\n" - "scvtf z7.s, p1/m, z7.s\n" - "scvtf z22.s, p1/m, z22.s\n" - "fmla z24.s, p1/M, z5.s, z9.s\n" - "ld1rqb { z5.b }, p1/Z, [x26]\n" - "fmul z9.s, z23.s, z3.s[1]\n" - "fmla z15.s, p1/M, z18.s, z9.s\n" - "ld1rqb { z18.b }, p1/Z, [x26, #16]\n" - "fmul z9.s, z23.s, z3.s[2]\n" - "fmul z3.s, z23.s, z3.s[3]\n" - "fmla z12.s, p1/M, z7.s, z9.s\n" - "mov z9.s, #0x0\n" - "ld1h { z7.s }, p0/Z, [x22]\n" - ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" - "fmla z0.s, p1/M, z22.s, z3.s\n" - "mov z22.s, #0x0\n" - "ld1h { z3.s }, p0/Z, [x21]\n" - ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" - "ld1rqb { z5.b }, p1/Z, [x26, #32]\n" - "fcvt z7.s, p1/m, z7.h\n" - "fcvt z3.s, p1/m, z3.h\n" - ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" - ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" - "ld1rqb { z5.b }, p1/Z, [x26, #64]\n" - "mov z7.q, z7.q[0]\n" - "mov z3.q, z3.q[0]\n" - ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" - ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" - "ld1rqb { z5.b }, p1/Z, [x26, #96]\n" - ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" - ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" - "uzp1 z5.d, z9.d, z22.d\n" - "scvtf z5.s, p1/m, z5.s\n" - "uzp2 z22.d, z9.d, z22.d\n" - "fmul z9.s, z23.s, z7.s[0]\n" - "scvtf z22.s, p1/m, z22.s\n" - "fmla z13.s, p1/M, z5.s, z9.s\n" - "ld1rqb { z9.b }, p1/Z, [x25]\n" - "fmul z5.s, z23.s, z7.s[1]\n" - "fmla z1.s, p1/M, z22.s, z5.s\n" - "mov z5.s, #0x0\n" - "mov z22.s, #0x0\n" - ".inst 0x451f9a45 // smmla z5.s, z18.b, z31.b\n" - ".inst 0x45069a56 // smmla z22.s, z18.b, z6.b\n" - "ld1rqb { z18.b }, p1/Z, [x26, #48]\n" - ".inst 0x450e9a45 // smmla z5.s, z18.b, z14.b\n" - ".inst 0x45029a56 // smmla z22.s, z18.b, z2.b\n" - "ld1rqb { z18.b }, p1/Z, [x26, #80]\n" - ".inst 0x451e9a45 // smmla z5.s, z18.b, z30.b\n" - ".inst 0x45159a56 // smmla z22.s, z18.b, z21.b\n" - "ld1rqb { z18.b }, p1/Z, [x26, #112]\n" - "add x26, x26, #0x88\n" - ".inst 0x45049a45 // smmla z5.s, z18.b, z4.b\n" - ".inst 0x45119a56 // smmla z22.s, z18.b, z17.b\n" - "uzp1 z18.d, z5.d, z22.d\n" - "scvtf z18.s, p1/m, z18.s\n" - "uzp2 z22.d, z5.d, z22.d\n" - "fmul z5.s, z23.s, z7.s[2]\n" - "fmul z7.s, z23.s, z7.s[3]\n" - "scvtf z22.s, p1/m, z22.s\n" - "fmla z20.s, p1/M, z18.s, z5.s\n" - "ld1rqb { z18.b }, p1/Z, [x25, #16]\n" - "ld1h { z5.s }, p0/Z, [x20]\n" - "fcvt z5.s, p1/m, z5.h\n" - "fmla z25.s, p1/M, z22.s, z7.s\n" - "mov z22.s, #0x0\n" - "mov z7.s, #0x0\n" - ".inst 0x451f9936 // smmla z22.s, z9.b, z31.b\n" - ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" - "ld1rqb { z9.b }, p1/Z, [x25, #32]\n" - "mov z5.q, z5.q[0]\n" - ".inst 0x450e9936 // smmla z22.s, z9.b, z14.b\n" - ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" - "ld1rqb { z9.b }, p1/Z, [x25, #64]\n" - ".inst 0x451e9936 // smmla z22.s, z9.b, z30.b\n" - ".inst 0x45159927 // smmla z7.s, z9.b, z21.b\n" - "ld1rqb { z9.b }, p1/Z, [x25, #96]\n" - ".inst 0x45049936 // smmla z22.s, z9.b, z4.b\n" - ".inst 0x45119927 // smmla z7.s, z9.b, z17.b\n" - "uzp1 z9.d, z22.d, z7.d\n" - "scvtf z9.s, p1/m, z9.s\n" - "uzp2 z22.d, z22.d, z7.d\n" - "fmul z7.s, z23.s, z3.s[0]\n" - "scvtf z22.s, p1/m, z22.s\n" - "fmla z11.s, p1/M, z9.s, z7.s\n" - "ld1rqb { z9.b }, p1/Z, [x24]\n" - "fmul z7.s, z23.s, z3.s[1]\n" - "fmla z16.s, p1/M, z22.s, z7.s\n" - "mov z22.s, #0x0\n" - "mov z7.s, #0x0\n" - ".inst 0x451f9a56 // smmla z22.s, z18.b, z31.b\n" - ".inst 0x45069a47 // smmla z7.s, z18.b, z6.b\n" - "ld1rqb { z18.b }, p1/Z, [x25, #48]\n" - ".inst 0x450e9a56 // smmla z22.s, z18.b, z14.b\n" - ".inst 0x45029a47 // smmla z7.s, z18.b, z2.b\n" - "ld1rqb { z18.b }, p1/Z, [x25, #80]\n" - ".inst 0x451e9a56 // smmla z22.s, z18.b, z30.b\n" - ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" - "ld1rqb { z18.b }, p1/Z, [x25, #112]\n" - "add x25, x25, #0x88\n" - ".inst 0x45049a56 // smmla z22.s, z18.b, z4.b\n" - ".inst 0x45119a47 // smmla z7.s, z18.b, z17.b\n" - "uzp1 z18.d, z22.d, z7.d\n" - "scvtf z18.s, p1/m, z18.s\n" - "uzp2 z7.d, z22.d, z7.d\n" - "fmul z22.s, z23.s, z3.s[2]\n" - "fmul z3.s, z23.s, z3.s[3]\n" - "scvtf z7.s, p1/m, z7.s\n" - "fmla z19.s, p1/M, z18.s, z22.s\n" - "ld1rqb { z18.b }, p1/Z, [x24, #16]\n" - "fmul z22.s, z23.s, z5.s[0]\n" - "fmla z26.s, p1/M, z7.s, z3.s\n" - "mov z3.s, #0x0\n" - "mov z7.s, #0x0\n" - ".inst 0x451f9923 // smmla z3.s, z9.b, z31.b\n" - ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" - "ld1rqb { z9.b }, p1/Z, [x24, #32]\n" - ".inst 0x450e9923 // smmla z3.s, z9.b, z14.b\n" - ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" - "mov z9.s, #0x0\n" - ".inst 0x451f9a49 // smmla z9.s, z18.b, z31.b\n" - "mov z31.s, #0x0\n" - ".inst 0x45069a5f // smmla z31.s, z18.b, z6.b\n" - "ld1rqb { z6.b }, p1/Z, [x24, #48]\n" - "ld1rqb { z18.b }, p1/Z, [x24, #64]\n" - ".inst 0x450e98c9 // smmla z9.s, z6.b, z14.b\n" - "fmul z14.s, z23.s, z5.s[1]\n" - ".inst 0x450298df // smmla z31.s, z6.b, z2.b\n" - "ld1rqb { z6.b }, p1/Z, [x24, #80]\n" - "fmul z2.s, z23.s, z5.s[2]\n" - "fmul z23.s, z23.s, z5.s[3]\n" - ".inst 0x451e9a43 // smmla z3.s, z18.b, z30.b\n" - ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" - "ld1rqb { z5.b }, p1/Z, [x24, #96]\n" - ".inst 0x451e98c9 // smmla z9.s, z6.b, z30.b\n" - ".inst 0x451598df // smmla z31.s, z6.b, z21.b\n" - "ld1rqb { z18.b }, p1/Z, [x24, #112]\n" - "add x24, x24, #0x88\n" - ".inst 0x450498a3 // smmla z3.s, z5.b, z4.b\n" - ".inst 0x451198a7 // smmla z7.s, z5.b, z17.b\n" - ".inst 0x45049a49 // smmla z9.s, z18.b, z4.b\n" - ".inst 0x45119a5f // smmla z31.s, z18.b, z17.b\n" - "uzp1 z18.d, z3.d, z7.d\n" - "uzp2 z5.d, z3.d, z7.d\n" - "scvtf z18.s, p1/m, z18.s\n" - "uzp1 z6.d, z9.d, z31.d\n" - "uzp2 z9.d, z9.d, z31.d\n" - "scvtf z5.s, p1/m, z5.s\n" - "fmla z8.s, p1/M, z18.s, z22.s\n" - "scvtf z6.s, p1/m, z6.s\n" - "scvtf z9.s, p1/m, z9.s\n" - "fmla z29.s, p1/M, z5.s, z14.s\n" - "fmla z27.s, p1/M, z6.s, z2.s\n" - "fmla z10.s, p1/M, z9.s, z23.s\n" - "bgt 3b\n" - "mov x20, %x[res_ptr]\n" - "subs x10, x10, #0x8\n" - "add %x[res_ptr], %x[res_ptr], #0x20\n" - "st1w { z24.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z15.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z12.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z0.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z13.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z1.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z20.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z25.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z11.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z16.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z19.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z26.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z8.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z29.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z27.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z10.s }, p1, [x20]\n" - "bne 2b\n" - "mov x20, #0x4\n" - "sub x13, x13, #0x10\n" - "cmp x13, #0x10\n" - "mov %x[res_ptr], x9\n" - "madd %x[a_ptr], x20, x12, %x[a_ptr]\n" - "bge 1b\n" - "4:" // Row loop skip - "cbz x13, 9f\n" - "5:" // Row tail: Row loop - "add x25, %x[b_ptr], #0x10\n" - "mov x24, %x[nc]\n" - "add x23, %x[res_ptr], %x[res_stride], LSL #2\n" - "6:" // Row tail: Column loop - "mov z24.b, #0x0\n" - "mov z15.b, #0x0\n" - "add x28, %x[a_ptr], #0x8\n" - "mov x22, %x[nb]\n" - "mov z12.b, #0x0\n" - "mov z0.b, #0x0\n" - "7:" // Row tail: Block loop - "ld1b { z3.b }, p1/Z, [x25]\n" - "ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\n" - "mov z2.s, #0x0\n" - "mov z25.s, #0x0\n" - "ld1rqb { z26.b }, p1/Z, [x28]\n" - "ld1rqb { z21.b }, p1/Z, [x28, #16]\n" - "mov z27.s, #0x0\n" - "mov z19.s, #0x0\n" - "ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\n" - "ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\n" - "sub x21, x25, #0x10\n" - "sub x20, x28, #0x8\n" - "lsl z20.b, z3.b, #0x4\n" - "lsl z4.b, z6.b, #0x4\n" - "ld1rqb { z10.b }, p1/Z, [x28, #32]\n" - "ld1rqb { z23.b }, p1/Z, [x28, #48]\n" - "and z3.b, z3.b, #0xf0\n" - "and z6.b, z6.b, #0xf0\n" - "ld1rqb { z11.b }, p1/Z, [x28, #64]\n" - "ld1rqb { z7.b }, p1/Z, [x28, #80]\n" - "lsl z8.b, z29.b, #0x4\n" - "lsl z14.b, z16.b, #0x4\n" - "ld1rqb { z18.b }, p1/Z, [x28, #96]\n" - "ld1rqb { z30.b }, p1/Z, [x28, #112]\n" - ".inst 0x45149b42 // smmla z2.s, z26.b, z20.b\n" - ".inst 0x45049b59 // smmla z25.s, z26.b, z4.b\n" - "and z29.b, z29.b, #0xf0\n" - "ld1h { z17.s }, p1/Z, [x21]\n" - ".inst 0x45149abb // smmla z27.s, z21.b, z20.b\n" - ".inst 0x45049ab3 // smmla z19.s, z21.b, z4.b\n" - "and z16.b, z16.b, #0xf0\n" - "ld1h { z4.s }, p0/Z, [x20]\n" - "subs x22, x22, #0x1\n" - "add x28, x28, #0x88\n" - "fcvt z17.s, p1/m, z17.h\n" - "add x25, x25, #0x90\n" - ".inst 0x45089942 // smmla z2.s, z10.b, z8.b\n" - ".inst 0x450e9959 // smmla z25.s, z10.b, z14.b\n" - "fcvt z4.s, p1/m, z4.h\n" - ".inst 0x45089afb // smmla z27.s, z23.b, z8.b\n" - ".inst 0x450e9af3 // smmla z19.s, z23.b, z14.b\n" - "fscale z17.s, p1/m, z17.s, z28.s\n" - "mov z4.q, z4.q[0]\n" - ".inst 0x45039962 // smmla z2.s, z11.b, z3.b\n" - ".inst 0x45069979 // smmla z25.s, z11.b, z6.b\n" - "fmul z23.s, z17.s, z4.s[0]\n" - "fmul z9.s, z17.s, z4.s[1]\n" - "fmul z21.s, z17.s, z4.s[2]\n" - "fmul z4.s, z17.s, z4.s[3]\n" - ".inst 0x450398fb // smmla z27.s, z7.b, z3.b\n" - ".inst 0x450698f3 // smmla z19.s, z7.b, z6.b\n" - ".inst 0x451d9a42 // smmla z2.s, z18.b, z29.b\n" - ".inst 0x45109a59 // smmla z25.s, z18.b, z16.b\n" - ".inst 0x451d9bdb // smmla z27.s, z30.b, z29.b\n" - ".inst 0x45109bd3 // smmla z19.s, z30.b, z16.b\n" - "uzp1 z31.d, z2.d, z25.d\n" - "uzp2 z13.d, z2.d, z25.d\n" - "scvtf z31.s, p1/m, z31.s\n" - "uzp1 z17.d, z27.d, z19.d\n" - "uzp2 z18.d, z27.d, z19.d\n" - "scvtf z13.s, p1/m, z13.s\n" - "fmla z24.s, p1/M, z31.s, z23.s\n" - "scvtf z17.s, p1/m, z17.s\n" - "scvtf z18.s, p1/m, z18.s\n" - "fmla z15.s, p1/M, z13.s, z9.s\n" - "fmla z12.s, p1/M, z17.s, z21.s\n" - "fmla z0.s, p1/M, z18.s, z4.s\n" - "bgt 7b\n" - "mov x20, %x[res_ptr]\n" - "cmp x13, #0x1\n" - "st1w { z24.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x13, #0x2\n" - "st1w { z15.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x13, #0x3\n" - "st1w { z12.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "st1w { z0.s }, p1, [x20]\n" - "8:" // Row tail: Accumulator store skip - "subs x24, x24, #0x8\n" - "add %x[res_ptr], %x[res_ptr], #0x20\n" - "bne 6b\n" - "subs x13, x13, #0x4\n" - "add %x[a_ptr], %x[a_ptr], x12\n" - "mov %x[res_ptr], x23\n" - "bgt 5b\n" - "9:" // Row tail: Row loop skip - : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) - : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) - : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" - ); - return; - } - else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { - GGML_ASSERT((ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) && - "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal " - "performance"); - } - else if (ggml_cpu_has_neon()) { - GGML_ASSERT(((ggml_cpu_has_sve() && (ggml_sve_cnt_b == QK8_0)) || ggml_cpu_has_matmul_int8()) && - "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 " - "quantization format for optimal performance"); - } -#endif -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - GGML_ASSERT(ggml_cpu_has_sve() && - "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance"); -#elif defined(__ARM_NEON) && defined(__aarch64__) - GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) && - "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " - "performance"); -#elif defined(__AVX2__) || defined(__AVX512F__) - const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx; - const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy; - int64_t b_nb = n / QK4_0; - int64_t y = 0; - // Mask to mask out nibbles from packed bytes - const __m256i m4b = _mm256_set1_epi8(0x0F); - const __m128i loadMask = _mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3); - // Lookup table to convert signed nibbles to signed bytes - __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0)); - signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); - // Permute mask used for easier vector processing at later stages - __m256i requiredOrder = _mm256_set_epi32(3 ,2 ,1 ,0, 7 ,6, 5, 4); - - // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation - int anr = nr - nr %16; // Used to align nr with boundary of 16 - - for (; y < anr / 4; y += 4) { - const block_q8_0x4 * a_ptrs[4]; - - a_ptrs[0] = a_ptr_start + (y * nb); - for (int i = 0; i < 3; ++i) { - a_ptrs[i + 1] = a_ptrs[i] + nb; - } - - // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation - for (int64_t x = 0; x < nc / 8; x++) { - - const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); - - // Master FP accumulators - __m256 acc_rows[16]; - for (int i = 0; i < 16; i++) { - acc_rows[i] = _mm256_setzero_ps(); - } - - for (int64_t b = 0; b < nb; b++) { - // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 - const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); - const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32)); - const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64)); - const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96)); - - // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values - const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); - const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); - const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); - const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); - - // 4-bit -> 8-bit - Sign is maintained - const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) - const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) - - const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) - const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) - - const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) - const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) - - const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) - const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) - - // Shuffle pattern one - right side input - const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) - const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) - - const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) - const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) - - const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) - const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) - - const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) - const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) - - // Shuffle pattern two - right side input - - const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) - const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) - - const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) - const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) - - const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) - const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) - - const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) - const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) - - // Scale values - Load the wight scale values of block_q4_0x8 - const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d); - - // Process LHS in groups of four - for (int rp = 0; rp < 4; rp++) { - // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 - // Loaded as set of 128 bit vectors and repeated into a 256 bit vector - __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs))); - __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0); - __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17); - __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32))); - __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0); - __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17); - __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64))); - __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0); - __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17); - __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96))); - __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0); - __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17); - - // Shuffle pattern one - left side input - const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) - const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) - - const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) - const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) - - const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) - const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) - - const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) - const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) - - // Shuffle pattern two - left side input - const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) - const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) - - const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) - const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) - - const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) - const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) - - const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) - const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) - - // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane - // Resembles MMLAs into 2x2 matrices in ARM Version - __m256i iacc_mat_00_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1)); - __m256i iacc_mat_01_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1)); - __m256i iacc_mat_10_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1)); - __m256i iacc_mat_11_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1)); - __m256i iacc_mat_00_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2)); - __m256i iacc_mat_01_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2)); - __m256i iacc_mat_10_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2)); - __m256i iacc_mat_11_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2)); - - // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block - __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); - __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); - __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); - __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); - - // Straighten out to make 4 row vectors - __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204); - __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204); - __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204); - __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204); - - // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes - const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask); - - // Multiply with appropiate scales and accumulate - acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); - acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); - acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); - acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]); - } - } - - // Store the accumulated values - for (int i = 0; i < 16; i++) { - _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); - } - } - } - - // Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation - for (; y < nr / 4; y ++) { - - const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb); - - // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 - for (int64_t x = 0; x < nc / 8; x++) { - - const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); - - // Master FP accumulators - __m256 acc_rows[4]; - for (int i = 0; i < 4; i++) { - acc_rows[i] = _mm256_setzero_ps(); - } - - for (int64_t b = 0; b < nb; b++) { - // Load the eight block_q8_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 - const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); - const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32)); - const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64)); - const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96)); - - // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of valuess - const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); - const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); - const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); - const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); - - // 4-bit -> 8-bit - Sign is maintained - const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) - const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) - - const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) - const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) - - const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) - const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) - - const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) - const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) - - // Shuffle pattern one - right side input - const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) - const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) - - const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) - const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) - - const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) - const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) - - const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) - const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) - - // Shuffle pattern two - right side input - - const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) - const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) - - const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) - const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) - - const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) - const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) - - const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) - const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) - - // Scale values - Load the wight scale values of block_q4_0x8 - const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d); - - // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 - // Loaded as set of 128 bit vectors and repeated into a 256 bit vector - __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs))); - __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0); - __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17); - __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32))); - __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0); - __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17); - __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64))); - __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0); - __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17); - __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96))); - __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0); - __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17); - - // Shuffle pattern one - left side input - - const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) - const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) - - const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) - const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) - - const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) - const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) - - const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) - const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) - - // Shuffle pattern two - left side input - - const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) - const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) - - const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) - const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) - - const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) - const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) - - const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) - const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) - - // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane - // Resembles MMLAs into 2x2 matrices in ARM Version - __m256i iacc_mat_00_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1)); - __m256i iacc_mat_01_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1)); - __m256i iacc_mat_10_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1)); - __m256i iacc_mat_11_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1)); - __m256i iacc_mat_00_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2)); - __m256i iacc_mat_01_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2)); - __m256i iacc_mat_10_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2)); - __m256i iacc_mat_11_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2)); - - // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block - __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); - __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); - __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); - __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); - - - // Straighten out to make 4 row vectors - __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204); - __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204); - __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204); - __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204); - - // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes - const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask); - - // Multiply with appropiate scales and accumulate - acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); - acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); - acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); - acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]); - } - - // Store the accumulated values - for (int i = 0; i < 4; i++) { - _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); - } - } - } -#else - float sumf[4][8]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; - } - sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } -#endif -} diff --git a/ggml/src/ggml-aarch64.h b/ggml/src/ggml-aarch64.h deleted file mode 100644 index 517babaf1..000000000 --- a/ggml/src/ggml-aarch64.h +++ /dev/null @@ -1,39 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd. -#pragma once - -#define GGML_COMMON_DECL_C -#include "ggml-common.h" - -#include "ggml.h" - -// GGML internal header - -#ifdef __cplusplus -extern "C" { -#endif - -// Quantization -void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); - -void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave); - -// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") -size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q4_0_4x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); - -// GEMV -void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); - -// GEMM -void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); - -#ifdef __cplusplus -} -#endif - diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index e485326ab..9a3bf9f29 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -14,7 +14,7 @@ //#define GGML_ALLOCATOR_DEBUG -//#define AT_PRINTF(...) fprintf(stderr, __VA_ARGS__) +//#define AT_PRINTF(...) GGML_LOG_DEBUG(__VA_ARGS__) #define AT_PRINTF(...) @@ -37,6 +37,7 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml return true; } +// ops that return true for this function must not use restrict pointers for their backend implementations static bool ggml_op_can_inplace(enum ggml_op op) { switch (op) { case GGML_OP_SCALE: @@ -52,8 +53,12 @@ static bool ggml_op_can_inplace(enum ggml_op op) { case GGML_OP_LOG: case GGML_OP_UNARY: case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_SILU_BACK: case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: return true; default: @@ -89,7 +94,7 @@ void ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tenso size = GGML_PAD(size, talloc->alignment); if (talloc->offset + size > ggml_backend_buffer_get_size(talloc->buffer)) { - fprintf(stderr, "%s: not enough space in the buffer to allocate %s (needed %zu, available %zu)\n", + GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %s (needed %zu, available %zu)\n", __func__, tensor->name, size, ggml_backend_buffer_get_size(talloc->buffer) - talloc->offset); GGML_ABORT("not enough space in the buffer"); } @@ -172,7 +177,7 @@ static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t siz best_fit_block = alloc->n_free_blocks - 1; } else { // this should never happen - fprintf(stderr, "%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n", + GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n", __func__, size, max_avail); GGML_ABORT("not enough space in the buffer"); } @@ -209,16 +214,16 @@ static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t siz } } } - fprintf(stderr, "max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0); + GGML_LOG_DEBUG("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0); for (int i = 0; i < 1024; i++) { if (alloc->allocated_tensors[i].tensor) { - fprintf(stderr, "%s [%zx-%zx] (%.2f MB) ", alloc->allocated_tensors[i].tensor->name, + GGML_LOG_DEBUG("%s [%zx-%zx] (%.2f MB) ", alloc->allocated_tensors[i].tensor->name, alloc->allocated_tensors[i].offset, alloc->allocated_tensors[i].offset + ggml_nbytes(alloc->allocated_tensors[i].tensor), ggml_nbytes(alloc->allocated_tensors[i].tensor) / 1024.0 / 1024.0); } } - fprintf(stderr, "\n"); + GGML_LOG_DEBUG("\n"); } #endif @@ -294,6 +299,12 @@ static void ggml_dyn_tallocr_reset(struct ggml_dyn_tallocr * alloc) { alloc->free_blocks[0].offset = 0; alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows alloc->max_size = 0; + +#ifdef GGML_ALLOCATOR_DEBUG + for (int i = 0; i < 1024; i++) { + alloc->allocated_tensors[i].tensor = NULL; + } +#endif } static struct ggml_dyn_tallocr * ggml_dyn_tallocr_new(size_t alignment) { @@ -342,7 +353,6 @@ struct tensor_alloc { }; struct leaf_alloc { - int buffer_id; struct tensor_alloc leaf; }; @@ -461,18 +471,12 @@ static bool ggml_gallocr_is_own(ggml_gallocr_t galloc, struct ggml_tensor * t) { return ggml_gallocr_hash_get(galloc, t)->allocated; } -static void ggml_gallocr_set_node_offset(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id, size_t offset) { - struct hash_node * hn = ggml_gallocr_hash_get(galloc, node); - hn->buffer_id = buffer_id; - hn->offset = offset; - hn->allocated = true; -} - static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor * t) { return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated; } static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) { + GGML_ASSERT(buffer_id >= 0); struct hash_node * hn = ggml_gallocr_hash_get(galloc, node); if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) { @@ -535,7 +539,6 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor size_t offset = ggml_dyn_tallocr_alloc(alloc, size, node); hn->buffer_id = buffer_id; hn->offset = offset; - return; } } @@ -734,7 +737,6 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c for (int i = 0; i < graph->n_leafs; i++) { struct ggml_tensor * leaf = graph->leafs[i]; struct hash_node * hn = ggml_gallocr_hash_get(galloc, leaf); - galloc->leaf_allocs[i].buffer_id = hn->buffer_id; if (leaf->view_src || leaf->data) { galloc->leaf_allocs[i].leaf.buffer_id = -1; galloc->leaf_allocs[i].leaf.offset = SIZE_MAX; @@ -762,13 +764,13 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c // even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views if (new_size > cur_size || galloc->buffers[i] == NULL) { #ifndef NDEBUG - fprintf(stderr, "%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); + GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); #endif ggml_backend_buffer_free(galloc->buffers[i]); galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size); if (galloc->buffers[i] == NULL) { - fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size); + GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size); return false; } ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); @@ -812,21 +814,25 @@ static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor * } static bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_tensor * node, struct tensor_alloc * talloc) { - size_t node_size = (node->data || node->view_src) ? 0 : ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node); + size_t node_size = 0; + if (!node->data && !node->view_src) { + GGML_ASSERT(talloc->buffer_id >= 0); // prevent segfault when misusing the API + node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node); + } return talloc->size_max >= node_size; } static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph * graph) { if (galloc->n_nodes != graph->n_nodes) { #ifndef NDEBUG - fprintf(stderr, "%s: graph has different number of nodes\n", __func__); + GGML_LOG_DEBUG("%s: graph has different number of nodes\n", __func__); #endif return true; } if (galloc->n_leafs != graph->n_leafs) { #ifndef NDEBUG - fprintf(stderr, "%s: graph has different number of leafs\n", __func__); + GGML_LOG_DEBUG("%s: graph has different number of leafs\n", __func__); #endif return true; } @@ -837,7 +843,7 @@ static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph if (!ggml_gallocr_node_needs_realloc(galloc, node, &node_alloc->dst)) { #ifndef NDEBUG - fprintf(stderr, "%s: node %s is not valid\n", __func__, node->name); + GGML_LOG_DEBUG("%s: node %s is not valid\n", __func__, node->name); #endif return true; } @@ -849,7 +855,7 @@ static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph } if (!ggml_gallocr_node_needs_realloc(galloc, src, &node_alloc->src[j])) { #ifndef NDEBUG - fprintf(stderr, "%s: src %d (%s) of node %s is not valid\n", __func__, j, src->name, node->name); + GGML_LOG_DEBUG("%s: src %d (%s) of node %s is not valid\n", __func__, j, src->name, node->name); #endif return true; } @@ -863,14 +869,14 @@ bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph) if (ggml_gallocr_needs_realloc(galloc, graph)) { if (galloc->n_buffers == 1) { #ifndef NDEBUG - fprintf(stderr, "%s: reallocating buffers automatically\n", __func__); + GGML_LOG_DEBUG("%s: reallocating buffers automatically\n", __func__); #endif if (!ggml_gallocr_reserve(galloc, graph)) { return false; } } else { #ifndef NDEBUG - fprintf(stderr, "%s: cannot reallocate multi buffer graph automatically, call reserve\n", __func__); + GGML_LOG_DEBUG("%s: cannot reallocate multi buffer graph automatically, call reserve\n", __func__); #endif return false; } @@ -934,7 +940,7 @@ static bool alloc_tensor_range(struct ggml_context * ctx, ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size); if (buffer == NULL) { #ifndef NDEBUG - fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size); + GGML_LOG_DEBUG("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size); #endif for (size_t i = 0; i < *n_buffers; i++) { ggml_backend_buffer_free((*buffers)[i]); @@ -984,7 +990,7 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte } if (this_size > max_size) { - fprintf(stderr, "%s: tensor %s is too large to fit in a %s buffer (tensor size: %zu, max buffer size: %zu)\n", + GGML_LOG_ERROR("%s: tensor %s is too large to fit in a %s buffer (tensor size: %zu, max buffer size: %zu)\n", __func__, t->name, ggml_backend_buft_name(buft), this_size, max_size); @@ -1016,7 +1022,7 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte if (n_buffers == 0) { #ifndef NDEBUG - fprintf(stderr, "%s: all tensors in the context are already allocated\n", __func__); + GGML_LOG_DEBUG("%s: all tensors in the context are already allocated\n", __func__); #endif return NULL; } diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h index 36ca37086..d1c2d76d8 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -8,145 +8,247 @@ extern "C" { #endif - // - // Backend buffer - // + #define GGML_BACKEND_API_VERSION 1 - // buffer type - typedef void * ggml_backend_buffer_type_context_t; + // + // Backend buffer type + // struct ggml_backend_buffer_type_i { - const char * (*GGML_CALL get_name) (ggml_backend_buffer_type_t buft); + const char * (*get_name) (ggml_backend_buffer_type_t buft); // allocate a buffer of this type - ggml_backend_buffer_t (*GGML_CALL alloc_buffer) (ggml_backend_buffer_type_t buft, size_t size); + ggml_backend_buffer_t (*alloc_buffer) (ggml_backend_buffer_type_t buft, size_t size); // tensor alignment - size_t (*GGML_CALL get_alignment) (ggml_backend_buffer_type_t buft); - // max buffer size that can be allocated - size_t (*GGML_CALL get_max_size) (ggml_backend_buffer_type_t buft); - // data size needed to allocate the tensor, including padding - size_t (*GGML_CALL get_alloc_size) (ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor); - // check if tensor data is in host memory - bool (*GGML_CALL is_host) (ggml_backend_buffer_type_t buft); + size_t (*get_alignment) (ggml_backend_buffer_type_t buft); + // (optional) max buffer size that can be allocated (defaults to SIZE_MAX) + size_t (*get_max_size) (ggml_backend_buffer_type_t buft); + // (optional) data size needed to allocate the tensor, including padding (defaults to ggml_nbytes) + size_t (*get_alloc_size)(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor); + // (optional) check if tensor data is in host memory and uses standard ggml tensor layout (defaults to false) + bool (*is_host) (ggml_backend_buffer_type_t buft); }; struct ggml_backend_buffer_type { struct ggml_backend_buffer_type_i iface; - ggml_backend_buffer_type_context_t context; + ggml_backend_dev_t device; + void * context; }; - // buffer - typedef void * ggml_backend_buffer_context_t; + // + // Backend buffer + // struct ggml_backend_buffer_i { - const char * (*GGML_CALL get_name) (ggml_backend_buffer_t buffer); - void (*GGML_CALL free_buffer)(ggml_backend_buffer_t buffer); - void * (*GGML_CALL get_base) (ggml_backend_buffer_t buffer); - void (*GGML_CALL init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); - void (*GGML_CALL set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - void (*GGML_CALL get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); - bool (*GGML_CALL cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer - void (*GGML_CALL clear) (ggml_backend_buffer_t buffer, uint8_t value); - void (*GGML_CALL reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras + // (optional) free the buffer + void (*free_buffer) (ggml_backend_buffer_t buffer); + // base address of the buffer + void * (*get_base) (ggml_backend_buffer_t buffer); + // (optional) initialize a tensor in the buffer (eg. add tensor extras) + void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + // tensor data access + void (*memset_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); + void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + // (optional) tensor copy: dst is in the buffer, src may be in any buffer, including buffers from a different backend (return false if not supported) + bool (*cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); + // clear the entire buffer + void (*clear) (ggml_backend_buffer_t buffer, uint8_t value); + // (optional) reset any internal state due to tensor initialization, such as tensor extras + void (*reset) (ggml_backend_buffer_t buffer); }; struct ggml_backend_buffer { struct ggml_backend_buffer_i iface; ggml_backend_buffer_type_t buft; - ggml_backend_buffer_context_t context; + void * context; size_t size; enum ggml_backend_buffer_usage usage; }; - GGML_CALL ggml_backend_buffer_t ggml_backend_buffer_init( - ggml_backend_buffer_type_t buft, - struct ggml_backend_buffer_i iface, - ggml_backend_buffer_context_t context, - size_t size); + GGML_API ggml_backend_buffer_t ggml_backend_buffer_init( + ggml_backend_buffer_type_t buft, + struct ggml_backend_buffer_i iface, + void * context, + size_t size); // do not use directly, use ggml_backend_tensor_copy instead - bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst); + GGML_API bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst); + // multi-buffer // buffer that contains a collection of buffers - GGML_CALL ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers); - GGML_CALL bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer); - GGML_CALL void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage); + GGML_API ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers); + GGML_API bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer); + GGML_API void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage); // - // Backend + // Backend (stream) // - typedef void * ggml_backend_context_t; - struct ggml_backend_i { - const char * (*GGML_CALL get_name)(ggml_backend_t backend); + const char * (*get_name)(ggml_backend_t backend); - void (*GGML_CALL free)(ggml_backend_t backend); - - // buffer allocation - ggml_backend_buffer_type_t (*GGML_CALL get_default_buffer_type)(ggml_backend_t backend); + void (*free)(ggml_backend_t backend); // (optional) asynchronous tensor data access - void (*GGML_CALL set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - void (*GGML_CALL get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); - bool (*GGML_CALL cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst); + void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + bool (*cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst); - // (optional) complete all pending operations - void (*GGML_CALL synchronize)(ggml_backend_t backend); + // (optional) complete all pending operations (required if the backend supports async operations) + void (*synchronize)(ggml_backend_t backend); - // compute graph with a plan (not used currently) - // create a new plan for a graph - ggml_backend_graph_plan_t (*GGML_CALL graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph); - void (*GGML_CALL graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan); + // (optional) graph plans (not used currently) + // compute graph with a plan + ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph); + void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan); // update the plan with a new graph - this should be faster than creating a new plan when the graph has the same topology - void (*GGML_CALL graph_plan_update) (ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph * cgraph); + void (*graph_plan_update) (ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph * cgraph); // compute the graph with the plan - enum ggml_status (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan); + enum ggml_status (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan); - // compute graph without a plan (async) - enum ggml_status (*GGML_CALL graph_compute) (ggml_backend_t backend, struct ggml_cgraph * cgraph); - - // check if the backend can compute an operation - bool (*GGML_CALL supports_op)(ggml_backend_t backend, const struct ggml_tensor * op); - - // check if the backend can use tensors allocated in a buffer type - bool (*GGML_CALL supports_buft)(ggml_backend_t backend, ggml_backend_buffer_type_t buft); - - // check if the backend wants to run an operation, even if the weights are allocated in a CPU buffer - // these should be expensive operations with large batch sizes that may benefit from running on this backend - // even if the weight has to be copied from the CPU temporarily - bool (*GGML_CALL offload_op)(ggml_backend_t backend, const struct ggml_tensor * op); + // compute graph (always async if supported by the backend) + enum ggml_status (*graph_compute) (ggml_backend_t backend, struct ggml_cgraph * cgraph); // (optional) event synchronization - // create a new event that can record events on this backend instance - ggml_backend_event_t (*GGML_CALL event_new) (ggml_backend_t backend); - void (*GGML_CALL event_free) (ggml_backend_event_t event); - // record an event on the backend instance that created it - void (*GGML_CALL event_record) (ggml_backend_event_t event); - // wait for an event on on a different backend instance - void (*GGML_CALL event_wait) (ggml_backend_t backend, ggml_backend_event_t event); - // block until an event is recorded - void (*GGML_CALL event_synchronize) (ggml_backend_event_t event); + // record an event on this stream + void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event); + // wait for an event on on a different stream + void (*event_wait) (ggml_backend_t backend, ggml_backend_event_t event); }; struct ggml_backend { ggml_guid_t guid; - struct ggml_backend_i iface; - ggml_backend_context_t context; + ggml_backend_dev_t device; + void * context; }; struct ggml_backend_event { - ggml_backend_t backend; + struct ggml_backend_device * device; void * context; }; // - // Backend registry + // Backend device // - typedef ggml_backend_t (*GGML_CALL ggml_backend_init_fn)(const char * params, void * user_data); + // Note: if additional properties are needed, we should add a struct with all of them + // the current functions to obtain the properties can remain, since they are more convenient for often used properties + struct ggml_backend_device_i { + // device name: short identifier for this device, such as "CPU" or "CUDA0" + const char * (*get_name)(ggml_backend_dev_t dev); - GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data); + // device description: short informative description of the device, could be the model name + const char * (*get_description)(ggml_backend_dev_t dev); + + // device memory in bytes + void (*get_memory)(ggml_backend_dev_t dev, size_t * free, size_t * total); + + // device type + enum ggml_backend_dev_type (*get_type)(ggml_backend_dev_t dev); + + // device properties + void (*get_props)(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props); + + // backend (stream) initialization + ggml_backend_t (*init_backend)(ggml_backend_dev_t dev, const char * params); + + // preferred buffer type + ggml_backend_buffer_type_t (*get_buffer_type)(ggml_backend_dev_t dev); + + // (optional) host buffer type (in system memory, typically this is a pinned memory buffer for faster transfers between host and device) + ggml_backend_buffer_type_t (*get_host_buffer_type)(ggml_backend_dev_t dev); + + // (optional) buffer from pointer: create a buffer from a host pointer (useful for memory mapped models and importing data from other libraries) + ggml_backend_buffer_t (*buffer_from_host_ptr)(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size); + + // check if the backend can compute an operation + bool (*supports_op)(ggml_backend_dev_t dev, const struct ggml_tensor * op); + + // check if the backend can use tensors allocated in a buffer type + bool (*supports_buft)(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft); + + // (optional) check if the backend wants to run an operation, even if the weights are allocated in an incompatible buffer + // these should be expensive operations that may benefit from running on this backend instead of the CPU backend + bool (*offload_op)(ggml_backend_dev_t dev, const struct ggml_tensor * op); + + // (optional) event synchronization + ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev); + void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event); + void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event); + }; + + struct ggml_backend_device { + struct ggml_backend_device_i iface; + ggml_backend_reg_t reg; + void * context; + }; + + // + // Backend (reg) + // + + struct ggml_backend_reg_i { + const char * (*get_name)(ggml_backend_reg_t reg); + + // enumerate available devices + size_t (*get_device_count)(ggml_backend_reg_t reg); + ggml_backend_dev_t (*get_device)(ggml_backend_reg_t reg, size_t index); + + // (optional) get a pointer to a function in the backend + // backends can add custom functions that are not part of the standard ggml-backend interface + void * (*get_proc_address)(ggml_backend_reg_t reg, const char * name); + }; + + struct ggml_backend_reg { + int api_version; // initialize to GGML_BACKEND_API_VERSION + struct ggml_backend_reg_i iface; + void * context; + }; + + // Internal backend registry API + GGML_API void ggml_backend_register(ggml_backend_reg_t reg); + + // Add backend dynamic loading support to the backend + + // Initialize the backend + typedef ggml_backend_reg_t (*ggml_backend_init_t)(void); + // Optional: obtain a score for the backend based on the system configuration + // Higher scores are preferred, 0 means the backend is not supported in the current system + typedef int (*ggml_backend_score_t)(void); + +#ifdef GGML_BACKEND_DL +# ifdef __cplusplus +# define GGML_BACKEND_DL_IMPL(reg_fn) \ + extern "C" { \ + GGML_BACKEND_API ggml_backend_reg_t ggml_backend_init(void); \ + } \ + ggml_backend_reg_t ggml_backend_init(void) { \ + return reg_fn(); \ + } +# define GGML_BACKEND_DL_SCORE_IMPL(score_fn) \ + extern "C" { \ + GGML_BACKEND_API int ggml_backend_score(void); \ + } \ + int ggml_backend_score(void) { \ + return score_fn(); \ + } +# else +# define GGML_BACKEND_DL_IMPL(reg_fn) \ + GGML_BACKEND_API ggml_backend_reg_t ggml_backend_init(void); \ + ggml_backend_reg_t ggml_backend_init(void) { \ + return reg_fn(); \ + } +# define GGML_BACKEND_DL_SCORE_IMPL(score_fn) \ + GGML_BACKEND_API int ggml_backend_score(void); \ + int ggml_backend_score(void) { \ + return score_fn(); \ + } +# endif +#else +# define GGML_BACKEND_DL_IMPL(reg_fn) +# define GGML_BACKEND_DL_SCORE_IMPL(score_fn) +#endif #ifdef __cplusplus } diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp new file mode 100644 index 000000000..955ed505f --- /dev/null +++ b/ggml/src/ggml-backend-reg.cpp @@ -0,0 +1,582 @@ +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +#elif defined(__APPLE__) +# include +# include +#else +# include +# include +#endif + +// Backend registry +#ifdef GGML_USE_CPU +#include "ggml-cpu.h" +#endif + +#ifdef GGML_USE_CUDA +#include "ggml-cuda.h" +#endif + +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + +#ifdef GGML_USE_SYCL +#include "ggml-sycl.h" +#endif + +#ifdef GGML_USE_VULKAN +#include "ggml-vulkan.h" +#endif + +#ifdef GGML_USE_OPENCL +#include "ggml-opencl.h" +#endif + +#ifdef GGML_USE_BLAS +#include "ggml-blas.h" +#endif + +#ifdef GGML_USE_RPC +#include "ggml-rpc.h" +#endif + +#ifdef GGML_USE_CANN +#include "ggml-cann.h" +#endif + +#ifdef GGML_USE_KOMPUTE +#include "ggml-kompute.h" +#endif + +// disable C++17 deprecation warning for std::codecvt_utf8 +#if defined(__clang__) +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wdeprecated-declarations" +#endif + +static std::wstring utf8_to_utf16(const std::string & str) { + std::wstring_convert> converter; + return converter.from_bytes(str); +} + +static std::string utf16_to_utf8(const std::wstring & str) { + std::wstring_convert> converter; + return converter.to_bytes(str); +} + +#if defined(__clang__) +# pragma clang diagnostic pop +#endif + +#ifdef _WIN32 + +using dl_handle = std::remove_pointer_t; + +struct dl_handle_deleter { + void operator()(HMODULE handle) { + FreeLibrary(handle); + } +}; + +static dl_handle * dl_load_library(const std::wstring & path) { + // suppress error dialogs for missing DLLs + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + + HMODULE handle = LoadLibraryW(path.c_str()); + + SetErrorMode(old_mode); + + return handle; +} + +static void * dl_get_sym(dl_handle * handle, const char * name) { + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + + void * p = (void *) GetProcAddress(handle, name); + + SetErrorMode(old_mode); + + return p; +} + +#else + +using dl_handle = void; + +struct dl_handle_deleter { + void operator()(void * handle) { + dlclose(handle); + } +}; + +static void * dl_load_library(const std::wstring & path) { + dl_handle * handle = dlopen(utf16_to_utf8(path).c_str(), RTLD_NOW | RTLD_LOCAL); + + return handle; +} + +static void * dl_get_sym(dl_handle * handle, const char * name) { + return dlsym(handle, name); +} + +#endif + +using dl_handle_ptr = std::unique_ptr; + +struct ggml_backend_reg_entry { + ggml_backend_reg_t reg; + dl_handle_ptr handle; +}; + +struct ggml_backend_registry { + std::vector backends; + std::vector devices; + + ggml_backend_registry() { +#ifdef GGML_USE_CUDA + register_backend(ggml_backend_cuda_reg()); +#endif +#ifdef GGML_USE_METAL + register_backend(ggml_backend_metal_reg()); +#endif +#ifdef GGML_USE_SYCL + register_backend(ggml_backend_sycl_reg()); +#endif +#ifdef GGML_USE_VULKAN + register_backend(ggml_backend_vk_reg()); +#endif +#ifdef GGML_USE_OPENCL + register_backend(ggml_backend_opencl_reg()); +#endif +#ifdef GGML_USE_CANN + register_backend(ggml_backend_cann_reg()); +#endif +#ifdef GGML_USE_BLAS + register_backend(ggml_backend_blas_reg()); +#endif +#ifdef GGML_USE_RPC + register_backend(ggml_backend_rpc_reg()); +#endif +#ifdef GGML_USE_KOMPUTE + register_backend(ggml_backend_kompute_reg()); +#endif +#ifdef GGML_USE_CPU + register_backend(ggml_backend_cpu_reg()); +#endif + } + + ~ggml_backend_registry() { + // FIXME: backends cannot be safely unloaded without a function to destroy all the backend resources, + // since backend threads may still be running and accessing resources from the dynamic library + for (auto & entry : backends) { + if (entry.handle) { + entry.handle.release(); // NOLINT + } + } + } + + void register_backend(ggml_backend_reg_t reg, dl_handle_ptr handle = nullptr) { + if (!reg) { + return; + } + +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n", + __func__, ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg)); +#endif + backends.push_back({ reg, std::move(handle) }); + for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) { + register_device(ggml_backend_reg_dev_get(reg, i)); + } + } + + void register_device(ggml_backend_dev_t device) { +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device)); +#endif + devices.push_back(device); + } + + ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) { + dl_handle_ptr handle { dl_load_library(path) }; + if (!handle) { + if (!silent) { + GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(path).c_str()); + } + return nullptr; + } + + auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); + if (score_fn && score_fn() == 0) { + if (!silent) { + GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, utf16_to_utf8(path).c_str()); + } + return nullptr; + } + + auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init"); + if (!backend_init_fn) { + if (!silent) { + GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, utf16_to_utf8(path).c_str()); + } + return nullptr; + } + + ggml_backend_reg_t reg = backend_init_fn(); + if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) { + if (!silent) { + if (!reg) { + GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, utf16_to_utf8(path).c_str()); + } else { + GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n", + __func__, utf16_to_utf8(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION); + } + } + return nullptr; + } + + GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str()); + + register_backend(reg, std::move(handle)); + + return reg; + } + + void unload_backend(ggml_backend_reg_t reg, bool silent) { + auto it = std::find_if(backends.begin(), backends.end(), + [reg](const ggml_backend_reg_entry & entry) { return entry.reg == reg; }); + + if (it == backends.end()) { + if (!silent) { + GGML_LOG_ERROR("%s: backend not found\n", __func__); + } + return; + } + + if (!silent) { + GGML_LOG_DEBUG("%s: unloading %s backend\n", __func__, ggml_backend_reg_name(reg)); + } + + // remove devices + devices.erase( + std::remove_if(devices.begin(), devices.end(), + [reg](ggml_backend_dev_t dev) { return ggml_backend_dev_backend_reg(dev) == reg; }), + devices.end()); + + // remove backend + backends.erase(it); + } +}; + +static ggml_backend_registry & get_reg() { + static ggml_backend_registry reg; + return reg; +} + +// Internal API +void ggml_backend_register(ggml_backend_reg_t reg) { + get_reg().register_backend(reg); +} + +void ggml_backend_device_register(ggml_backend_dev_t device) { + get_reg().register_device(device); +} + +// Backend (reg) enumeration +static bool striequals(const char * a, const char * b) { + for (; *a && *b; a++, b++) { + if (std::tolower(*a) != std::tolower(*b)) { + return false; + } + } + return *a == *b; +} + +size_t ggml_backend_reg_count() { + return get_reg().backends.size(); +} + +ggml_backend_reg_t ggml_backend_reg_get(size_t index) { + GGML_ASSERT(index < ggml_backend_reg_count()); + return get_reg().backends[index].reg; +} + +ggml_backend_reg_t ggml_backend_reg_by_name(const char * name) { + for (size_t i = 0; i < ggml_backend_reg_count(); i++) { + ggml_backend_reg_t reg = ggml_backend_reg_get(i); + if (striequals(ggml_backend_reg_name(reg), name)) { + return reg; + } + } + return nullptr; +} + +// Device enumeration +size_t ggml_backend_dev_count() { + return get_reg().devices.size(); +} + +ggml_backend_dev_t ggml_backend_dev_get(size_t index) { + GGML_ASSERT(index < ggml_backend_dev_count()); + return get_reg().devices[index]; +} + +ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) { + for (size_t i = 0; i < ggml_backend_dev_count(); i++) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (striequals(ggml_backend_dev_name(dev), name)) { + return dev; + } + } + return nullptr; +} + +ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type) { + for (size_t i = 0; i < ggml_backend_dev_count(); i++) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == type) { + return dev; + } + } + return nullptr; +} + +// Convenience functions +ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params) { + ggml_backend_dev_t dev = ggml_backend_dev_by_name(name); + if (!dev) { + return nullptr; + } + return ggml_backend_dev_init(dev, params); +} + +ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params) { + ggml_backend_dev_t dev = ggml_backend_dev_by_type(type); + if (!dev) { + return nullptr; + } + return ggml_backend_dev_init(dev, params); +} + +ggml_backend_t ggml_backend_init_best(void) { + ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU); + if (!dev) { + dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + } + if (!dev) { + return nullptr; + } + return ggml_backend_dev_init(dev, nullptr); +} + +// Dynamic loading +ggml_backend_reg_t ggml_backend_load(const char * path) { + return get_reg().load_backend(utf8_to_utf16(path), false); +} + +void ggml_backend_unload(ggml_backend_reg_t reg) { + get_reg().unload_backend(reg, true); +} + +static std::wstring get_executable_path() { +#if defined(__APPLE__) + // get executable path + std::vector path; + uint32_t size; + while (true) { + size = path.size(); + if (_NSGetExecutablePath(path.data(), &size) == 0) { + break; + } + path.resize(size); + } + std::string base_path(path.data(), size); + // remove executable name + auto last_slash = base_path.find_last_of('/'); + if (last_slash != std::string::npos) { + base_path = base_path.substr(0, last_slash); + } + return utf8_to_utf16(base_path + "/"); +#elif defined(__linux__) || defined(__FreeBSD__) + std::string base_path = "."; + std::vector path(1024); + while (true) { + // get executable path +# if defined(__linux__) + ssize_t len = readlink("/proc/self/exe", path.data(), path.size()); +# elif defined(__FreeBSD__) + ssize_t len = readlink("/proc/curproc/file", path.data(), path.size()); +# endif + if (len == -1) { + break; + } + if (len < (ssize_t) path.size()) { + base_path = std::string(path.data(), len); + // remove executable name + auto last_slash = base_path.find_last_of('/'); + if (last_slash != std::string::npos) { + base_path = base_path.substr(0, last_slash); + } + break; + } + path.resize(path.size() * 2); + } + + return utf8_to_utf16(base_path + "/"); +#elif defined(_WIN32) + std::vector path(MAX_PATH); + DWORD len = GetModuleFileNameW(NULL, path.data(), path.size()); + if (len == 0) { + return {}; + } + std::wstring base_path(path.data(), len); + // remove executable name + auto last_slash = base_path.find_last_of('\\'); + if (last_slash != std::string::npos) { + base_path = base_path.substr(0, last_slash); + } + return base_path + L"\\"; +#else + return {}; +#endif +} + +static std::wstring backend_filename_prefix() { +#ifdef _WIN32 + return L"ggml-"; +#else + return L"libggml-"; +#endif +} + +static std::wstring backend_filename_suffix() { +#ifdef _WIN32 + return L".dll"; +#else + return L".so"; +#endif +} + +static std::wstring path_separator() { +#ifdef _WIN32 + return L"\\"; +#else + return L"/"; +#endif +} + +static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, const char * user_search_path) { + // enumerate all the files that match [lib]ggml-name-*.[so|dll] in the search paths + // TODO: search system paths + std::wstring file_prefix = backend_filename_prefix() + utf8_to_utf16(name) + L"-"; + std::vector search_paths; + if (user_search_path == nullptr) { + search_paths.push_back(L"." + path_separator()); + search_paths.push_back(get_executable_path()); + } else { + search_paths.push_back(utf8_to_utf16(user_search_path) + path_separator()); + } + + int best_score = 0; + std::wstring best_path; + + namespace fs = std::filesystem; + for (const auto & search_path : search_paths) { + if (!fs::exists(search_path)) { + continue; + } + fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied); + for (const auto & entry : dir_it) { + if (entry.is_regular_file()) { + std::wstring filename = entry.path().filename().wstring(); + std::wstring ext = entry.path().extension().wstring(); + if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { + dl_handle_ptr handle { dl_load_library(entry.path().wstring()) }; + if (!handle && !silent) { + GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); + } + if (handle) { + auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); + if (score_fn) { + int s = score_fn(); +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s); +#endif + if (s > best_score) { + best_score = s; + best_path = entry.path().wstring(); + } + } else { + if (!silent) { + GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); + } + } + } + } + } + } + } + + if (best_score == 0) { + // try to load the base backend + for (const auto & search_path : search_paths) { + std::wstring path = search_path + backend_filename_prefix() + utf8_to_utf16(name) + backend_filename_suffix(); + if (fs::exists(path)) { + return get_reg().load_backend(path, silent); + } + } + return nullptr; + } + + return get_reg().load_backend(best_path, silent); +} + +void ggml_backend_load_all() { + ggml_backend_load_all_from_path(nullptr); +} + +void ggml_backend_load_all_from_path(const char * dir_path) { +#ifdef NDEBUG + bool silent = true; +#else + bool silent = false; +#endif + + ggml_backend_load_best("blas", silent, dir_path); + ggml_backend_load_best("cann", silent, dir_path); + ggml_backend_load_best("cuda", silent, dir_path); + ggml_backend_load_best("hip", silent, dir_path); + ggml_backend_load_best("kompute", silent, dir_path); + ggml_backend_load_best("metal", silent, dir_path); + ggml_backend_load_best("rpc", silent, dir_path); + ggml_backend_load_best("sycl", silent, dir_path); + ggml_backend_load_best("vulkan", silent, dir_path); + ggml_backend_load_best("opencl", silent, dir_path); + ggml_backend_load_best("musa", silent, dir_path); + ggml_backend_load_best("cpu", silent, dir_path); + // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend + const char * backend_path = std::getenv("GGML_BACKEND_PATH"); + if (backend_path) { + ggml_backend_load(backend_path); + } +} diff --git a/ggml/src/ggml-backend.c b/ggml/src/ggml-backend.cpp similarity index 71% rename from ggml/src/ggml-backend.c rename to ggml/src/ggml-backend.cpp index b5d9301a7..dba7be33b 100644 --- a/ggml/src/ggml-backend.c +++ b/ggml/src/ggml-backend.cpp @@ -1,3 +1,14 @@ +// Note: porting this file to C++ is a work in progress + +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include +#endif + +#include "ggml-backend.h" #include "ggml-backend-impl.h" #include "ggml-alloc.h" #include "ggml-impl.h" @@ -8,9 +19,14 @@ #include #include #include +#include +#include +#ifdef __APPLE__ +#include +#include +#endif -#define MAX(a, b) ((a) > (b) ? (a) : (b)) // backend buffer type @@ -18,7 +34,12 @@ const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) { return buft->iface.get_name(buft); } -GGML_CALL ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { +ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + if (size == 0) { + // return a dummy buffer for zero-sized allocations + return ggml_backend_buffer_init(buft, {}, NULL, 0); + } + return buft->iface.alloc_buffer(buft, size); } @@ -34,7 +55,7 @@ size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) { return SIZE_MAX; } -GGML_CALL size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) { +size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) { // get_alloc_size is optional, defaults to ggml_nbytes if (buft->iface.get_alloc_size) { size_t size = buft->iface.get_alloc_size(buft, tensor); @@ -51,16 +72,18 @@ bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) { return false; } +ggml_backend_dev_t ggml_backend_buft_get_device(ggml_backend_buffer_type_t buft) { + return buft->device; +} + // backend buffer -GGML_CALL ggml_backend_buffer_t ggml_backend_buffer_init( - ggml_backend_buffer_type_t buft, - struct ggml_backend_buffer_i iface, - ggml_backend_buffer_context_t context, - size_t size) { - ggml_backend_buffer_t buffer = malloc(sizeof(struct ggml_backend_buffer)); - - (*buffer) = (struct ggml_backend_buffer) { +ggml_backend_buffer_t ggml_backend_buffer_init( + ggml_backend_buffer_type_t buft, + struct ggml_backend_buffer_i iface, + void * context, + size_t size) { + ggml_backend_buffer_t buffer = new ggml_backend_buffer { /* .interface = */ iface, /* .buft = */ buft, /* .context = */ context, @@ -72,7 +95,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_buffer_init( } const char * ggml_backend_buffer_name(ggml_backend_buffer_t buffer) { - return buffer->iface.get_name(buffer); + return ggml_backend_buft_name(ggml_backend_buffer_get_type(buffer)); } void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) { @@ -83,7 +106,7 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) { if (buffer->iface.free_buffer != NULL) { buffer->iface.free_buffer(buffer); } - free(buffer); + delete buffer; } size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) { @@ -91,6 +114,11 @@ size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) { } void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { + // get_base is optional if the buffer is zero-sized + if (buffer->size == 0) { + return NULL; + } + void * base = buffer->iface.get_base(buffer); GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL"); @@ -98,14 +126,23 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { return base; } -GGML_CALL void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { +void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { // init_tensor is optional if (buffer->iface.init_tensor) { buffer->iface.init_tensor(buffer, tensor); } } -size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer) { +void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + // clear is optional if the buffer is zero-sized + if (buffer->size == 0) { + return; + } + + buffer->iface.clear(buffer, value); +} + +size_t ggml_backend_buffer_get_alignment(ggml_backend_buffer_t buffer) { return ggml_backend_buft_get_alignment(ggml_backend_buffer_get_type(buffer)); } @@ -117,10 +154,6 @@ size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct g return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_get_type(buffer), tensor); } -void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { - buffer->iface.clear(buffer, value); -} - bool ggml_backend_buffer_is_host(ggml_backend_buffer_t buffer) { return ggml_backend_buft_is_host(ggml_backend_buffer_get_type(buffer)); } @@ -181,7 +214,7 @@ void ggml_backend_free(ggml_backend_t backend) { } ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) { - return backend->iface.get_default_buffer_type(backend); + return ggml_backend_dev_buffer_type(backend->device); } ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size) { @@ -218,32 +251,49 @@ void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_ten } } -GGML_CALL void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { +void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor); ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + if (size == 0) { + return; + } + GGML_ASSERT(buf != NULL && "tensor buffer not set"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); - if (!size) { - return; - } - buf->iface.set_tensor(buf, tensor, data, offset, size); } -GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { +void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor); ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + if (size == 0) { + return; + } + GGML_ASSERT(buf != NULL && "tensor buffer not set"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); - if (!size) { + buf->iface.get_tensor(buf, tensor, data, offset, size); +} + +void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + + if (size == 0) { return; } - buf->iface.get_tensor(buf, tensor, data, offset, size); + GGML_ASSERT(buf != NULL && "tensor buffer not set"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + GGML_ASSERT(buf->iface.memset_tensor != NULL && "memset not implemented by backend buffer"); + + buf->iface.memset_tensor(buf, tensor, value, offset, size); } void ggml_backend_synchronize(ggml_backend_t backend) { @@ -283,18 +333,19 @@ enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct } bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { - return backend->iface.supports_op(backend, op); + return ggml_backend_dev_supports_op(backend->device, op); } bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { - return backend->iface.supports_buft(backend, buft); + return ggml_backend_dev_supports_buft(backend->device, buft); } bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) { - if (backend->iface.offload_op != NULL) { - return backend->iface.offload_op(backend, op); - } - return false; + return ggml_backend_dev_offload_op(backend->device, op); +} + +ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) { + return backend->device; } // backend copy @@ -327,7 +378,7 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src)); } else if (!ggml_backend_buffer_copy_tensor(src, dst)) { #ifndef NDEBUG - fprintf(stderr, "%s: warning: slow copy from %s to %s\n", __func__, ggml_backend_buffer_name(src->buffer), ggml_backend_buffer_name(dst->buffer)); + GGML_LOG_DEBUG("%s: warning: slow copy from %s to %s\n", __func__, ggml_backend_buffer_name(src->buffer), ggml_backend_buffer_name(dst->buffer)); #endif size_t nbytes = ggml_nbytes(src); void * data = malloc(nbytes); @@ -359,30 +410,31 @@ void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t b // events -ggml_backend_event_t ggml_backend_event_new(ggml_backend_t backend) { - if (backend->iface.event_new == NULL) { +ggml_backend_event_t ggml_backend_event_new(ggml_backend_dev_t device) { + // null device is allowed for the transition period to the device interface + if (device == NULL || device->iface.event_new == NULL) { return NULL; } - return backend->iface.event_new(backend); + return device->iface.event_new(device); } void ggml_backend_event_free(ggml_backend_event_t event) { if (event == NULL) { return; } - event->backend->iface.event_free(event); + event->device->iface.event_free(event->device, event); } -void ggml_backend_event_record(ggml_backend_event_t event) { - GGML_ASSERT(event->backend->iface.event_record != NULL); +void ggml_backend_event_record(ggml_backend_event_t event, ggml_backend_t backend) { + GGML_ASSERT(backend->iface.event_record != NULL); - event->backend->iface.event_record(event); + backend->iface.event_record(backend, event); } void ggml_backend_event_synchronize(ggml_backend_event_t event) { - GGML_ASSERT(event->backend->iface.event_synchronize != NULL); + GGML_ASSERT(event->device->iface.event_synchronize); - event->backend->iface.event_synchronize(event); + event->device->iface.event_synchronize(event->device, event); } void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { @@ -391,555 +443,88 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) backend->iface.event_wait(backend, event); } -// backend registry +// Backend device -#define GGML_REG_MAX_BACKENDS 64 - -struct ggml_backend_reg { - char name[128]; - ggml_backend_init_fn init_fn; - ggml_backend_buffer_type_t default_buffer_type; - void * user_data; -}; - -static struct ggml_backend_reg ggml_backend_registry[GGML_REG_MAX_BACKENDS]; -static size_t ggml_backend_registry_count = 0; - -GGML_CALL static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data); - -GGML_CALL static void ggml_backend_registry_init(void) { - static bool initialized = false; - - if (initialized) { - return; - } - - initialized = true; - - ggml_backend_register("CPU", ggml_backend_reg_cpu_init, ggml_backend_cpu_buffer_type(), NULL); - - // add forward decls here to avoid including the backend headers -#ifdef GGML_USE_CUDA - extern GGML_CALL void ggml_backend_cuda_reg_devices(void); - ggml_backend_cuda_reg_devices(); -#endif - -#ifdef GGML_USE_SYCL - extern void ggml_backend_sycl_reg_devices(void); - ggml_backend_sycl_reg_devices(); -#endif - -#ifdef GGML_USE_METAL - extern GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); - extern GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void); - ggml_backend_register("Metal", ggml_backend_reg_metal_init, ggml_backend_metal_buffer_type(), NULL); -#endif - -#ifdef GGML_USE_VULKAN - extern GGML_CALL int ggml_backend_vk_reg_devices(void); - ggml_backend_vk_reg_devices(); -#endif - -#ifdef GGML_USE_KOMPUTE - extern GGML_CALL void ggml_backend_kompute_reg_devices(void); - ggml_backend_kompute_reg_devices(); -#endif - -#ifdef GGML_USE_CANN - extern GGML_CALL int ggml_backend_cann_reg_devices(void); - ggml_backend_cann_reg_devices(); -#endif +const char * ggml_backend_dev_name(ggml_backend_dev_t device) { + return device->iface.get_name(device); } -GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) { - GGML_ASSERT(ggml_backend_registry_count < GGML_REG_MAX_BACKENDS); - - size_t id = ggml_backend_registry_count; - - ggml_backend_registry[id] = (struct ggml_backend_reg) { - /* .name = */ {0}, - /* .fn = */ init_fn, - /* .default_buffer_type = */ default_buffer_type, - /* .user_data = */ user_data, - }; - - snprintf(ggml_backend_registry[id].name, sizeof(ggml_backend_registry[id].name), "%s", name); - -#ifndef NDEBUG - fprintf(stderr, "%s: registered backend %s\n", __func__, name); -#endif - - ggml_backend_registry_count++; +const char * ggml_backend_dev_description(ggml_backend_dev_t device) { + return device->iface.get_description(device); } -size_t ggml_backend_reg_get_count(void) { - ggml_backend_registry_init(); - - return ggml_backend_registry_count; +void ggml_backend_dev_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { + device->iface.get_memory(device, free, total); } -size_t ggml_backend_reg_find_by_name(const char * name) { - ggml_backend_registry_init(); - - for (size_t i = 0; i < ggml_backend_registry_count; i++) { - // TODO: case insensitive in a portable way - if (strcmp(ggml_backend_registry[i].name, name) == 0) { - return i; - } - } - - // not found - return SIZE_MAX; +enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) { + return device->iface.get_type(device); } -// init from backend:params string -ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str) { - ggml_backend_registry_init(); +void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props) { + memset(props, 0, sizeof(*props)); + device->iface.get_props(device, props); +} - const char * params = strchr(backend_str, ':'); - char backend_name[128]; - if (params == NULL) { - snprintf(backend_name, sizeof(backend_name), "%s", backend_str); - params = ""; - } else { - snprintf(backend_name, sizeof(backend_name), "%.*s", (int)(params - backend_str), backend_str); - params++; - } +ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device) { + return device->reg; +} - size_t backend_i = ggml_backend_reg_find_by_name(backend_name); +ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params) { + return device->iface.init_backend(device, params); +} - if (backend_i == SIZE_MAX) { - fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name); +ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) { + return device->iface.get_buffer_type(device); +} + +ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device) { + if (device->iface.get_host_buffer_type == NULL) { return NULL; } - return ggml_backend_reg_init_backend(backend_i, params); + return device->iface.get_host_buffer_type(device); } -const char * ggml_backend_reg_get_name(size_t i) { - ggml_backend_registry_init(); - - GGML_ASSERT(i < ggml_backend_registry_count); - return ggml_backend_registry[i].name; +ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size) { + return device->iface.buffer_from_host_ptr(device, ptr, size, max_tensor_size); } -ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params) { - ggml_backend_registry_init(); - - GGML_ASSERT(i < ggml_backend_registry_count); - return ggml_backend_registry[i].init_fn(params, ggml_backend_registry[i].user_data); +bool ggml_backend_dev_supports_op(ggml_backend_dev_t device, const struct ggml_tensor * op) { + return device->iface.supports_op(device, op); } -ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i) { - ggml_backend_registry_init(); - - GGML_ASSERT(i < ggml_backend_registry_count); - return ggml_backend_registry[i].default_buffer_type; +bool ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buffer_type_t buft) { + return device->iface.supports_buft(device, buft); } -ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size) { - ggml_backend_registry_init(); - - GGML_ASSERT(i < ggml_backend_registry_count); - return ggml_backend_buft_alloc_buffer(ggml_backend_registry[i].default_buffer_type, size); -} - -// backend CPU - -static const size_t TENSOR_ALIGNMENT = 32; // required for mmap as gguf only guarantees 32-byte alignment - -GGML_CALL static const char * ggml_backend_cpu_buffer_name(ggml_backend_buffer_t buffer) { - return "CPU"; - - GGML_UNUSED(buffer); -} - -GGML_CALL static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { - uintptr_t data = (uintptr_t)buffer->context; - - // align the buffer - if (data % TENSOR_ALIGNMENT != 0) { - data = GGML_PAD(data, TENSOR_ALIGNMENT); +bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op) { + if (device->iface.offload_op != NULL) { + return device->iface.offload_op(device, op); } - return (void *)data; -} - -GGML_CALL static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { - free(buffer->context); -} - -GGML_CALL static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - memcpy((char *)tensor->data + offset, data, size); - - GGML_UNUSED(buffer); -} - -GGML_CALL static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { - memcpy(data, (const char *)tensor->data + offset, size); - - GGML_UNUSED(buffer); -} - -GGML_CALL static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { - if (ggml_backend_buffer_is_host(src->buffer)) { - memcpy(dst->data, src->data, ggml_nbytes(src)); - return true; - } return false; - - GGML_UNUSED(buffer); } -GGML_CALL static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { - memset(buffer->context, value, buffer->size); +// Backend (reg) + +const char * ggml_backend_reg_name(ggml_backend_reg_t reg) { + return reg->iface.get_name(reg); } -static struct ggml_backend_buffer_i cpu_backend_buffer_i = { - /* .get_name = */ ggml_backend_cpu_buffer_name, - /* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer, - /* .get_base = */ ggml_backend_cpu_buffer_get_base, - /* .init_tensor = */ NULL, // no initialization required - /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, - /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, - /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, - /* .clear = */ ggml_backend_cpu_buffer_clear, - /* .reset = */ NULL, -}; - -// for buffers from ptr, free is not called -static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = { - /* .get_name = */ ggml_backend_cpu_buffer_name, - /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed - /* .get_base = */ ggml_backend_cpu_buffer_get_base, - /* .init_tensor = */ NULL, // no initialization required - /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, - /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, - /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, - /* .clear = */ ggml_backend_cpu_buffer_clear, - /* .reset = */ NULL, -}; - -GGML_CALL static const char * ggml_backend_cpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) { - return "CPU"; - - GGML_UNUSED(buft); +size_t ggml_backend_reg_dev_count(ggml_backend_reg_t reg) { + return reg->iface.get_device_count(reg); } -GGML_CALL static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned - void * data = malloc(size); // TODO: use GGML_ALIGNED_MALLOC (move to ggml-impl.h) - if (data == NULL) { - fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size); +ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index) { + return reg->iface.get_device(reg, index); +} + +void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) { + if (!reg->iface.get_proc_address) { return NULL; } - - return ggml_backend_buffer_init(buft, cpu_backend_buffer_i, data, size); -} - -GGML_CALL static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return TENSOR_ALIGNMENT; - - GGML_UNUSED(buft); -} - -GGML_CALL static bool ggml_backend_cpu_buffer_type_is_host(ggml_backend_buffer_type_t buft) { - return true; - - GGML_UNUSED(buft); -} - -GGML_CALL ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) { - static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type = { - /* .iface = */ { - /* .get_name = */ ggml_backend_cpu_buffer_type_get_name, - /* .alloc_buffer = */ ggml_backend_cpu_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment, - /* .get_max_size = */ NULL, // defaults to SIZE_MAX - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .is_host = */ ggml_backend_cpu_buffer_type_is_host, - }, - /* .context = */ NULL, - }; - - return &ggml_backend_cpu_buffer_type; -} - -#ifdef GGML_USE_CPU_HBM - -// buffer type HBM - -#include - -GGML_CALL static const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffer_type_t buft) { - return "CPU_HBM"; - - GGML_UNUSED(buft); -} - -GGML_CALL static const char * ggml_backend_cpu_hbm_buffer_get_name(ggml_backend_buffer_t buf) { - return "CPU_HBM"; - - GGML_UNUSED(buf); -} - -GGML_CALL static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) { - hbw_free(buffer->context); -} - -GGML_CALL static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - //void * ptr = hbw_malloc(size); - void * ptr; - int result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size); - if (result != 0) { - fprintf(stderr, "failed to allocate HBM buffer of size %zu\n", size); - return NULL; - } - - ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); - buffer->buft = buft; - buffer->iface.get_name = ggml_backend_cpu_hbm_buffer_get_name; - buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer; - - return buffer; -} - -ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) { - static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = { - /* .iface = */ { - /* .get_name = */ ggml_backend_cpu_hbm_buffer_type_get_name, - /* .alloc_buffer = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment, - /* .get_max_size = */ NULL, // defaults to SIZE_MAX - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .is_host = */ ggml_backend_cpu_buffer_type_is_host, - }, - /* .context = */ NULL, - }; - - return &ggml_backend_cpu_buffer_type_hbm; -} -#endif - -struct ggml_backend_cpu_context { - int n_threads; - ggml_threadpool_t threadpool; - - void * work_data; - size_t work_size; - - ggml_abort_callback abort_callback; - void * abort_callback_data; -}; - -GGML_CALL static const char * ggml_backend_cpu_name(ggml_backend_t backend) { - return "CPU"; - - GGML_UNUSED(backend); -} - -GGML_CALL static void ggml_backend_cpu_free(ggml_backend_t backend) { - struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; - free(cpu_ctx->work_data); - free(cpu_ctx); - free(backend); -} - -GGML_CALL static ggml_backend_buffer_type_t ggml_backend_cpu_get_default_buffer_type(ggml_backend_t backend) { - return ggml_backend_cpu_buffer_type(); - - GGML_UNUSED(backend); -} - -struct ggml_backend_plan_cpu { - struct ggml_cplan cplan; - struct ggml_cgraph cgraph; -}; - -GGML_CALL static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) { - struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; - - struct ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct ggml_backend_plan_cpu)); - - cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool); - cpu_plan->cgraph = *cgraph; // FIXME: deep copy - - if (cpu_plan->cplan.work_size > 0) { - cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size); - if (cpu_plan->cplan.work_data == NULL) { - free(cpu_plan); - return NULL; - } - } - - cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback; - cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data; - - return cpu_plan; -} - -GGML_CALL static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { - struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan; - - free(cpu_plan->cplan.work_data); - free(cpu_plan); - - GGML_UNUSED(backend); -} - -GGML_CALL static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { - struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan; - - return ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan); - - GGML_UNUSED(backend); -} - -GGML_CALL static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { - struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; - - struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool); - - if (cpu_ctx->work_size < cplan.work_size) { - free(cpu_ctx->work_data); - cpu_ctx->work_data = malloc(cplan.work_size); - if (cpu_ctx->work_data == NULL) { - cpu_ctx->work_size = 0; - return GGML_STATUS_ALLOC_FAILED; - } - cpu_ctx->work_size = cplan.work_size; - } - cplan.work_data = cpu_ctx->work_data; - - cplan.abort_callback = cpu_ctx->abort_callback; - cplan.abort_callback_data = cpu_ctx->abort_callback_data; - - return ggml_graph_compute(cgraph, &cplan); -} - -GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { - switch (op->op) { - case GGML_OP_CPY: - return - op->type != GGML_TYPE_IQ2_XXS && - op->type != GGML_TYPE_IQ2_XS && - op->type != GGML_TYPE_IQ1_S && - op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float - case GGML_OP_MUL_MAT: - return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type; - case GGML_OP_ROPE_BACK: - return op->src[2] == NULL && (op->op_params[2] & 4) == 0; - case GGML_OP_IM2COL_BACK: - return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; - default: - return true; - } - - GGML_UNUSED(backend); -} - -GGML_CALL static bool ggml_backend_cpu_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { - return ggml_backend_buft_is_host(buft); - - GGML_UNUSED(backend); -} - -static struct ggml_backend_i cpu_backend_i = { - /* .get_name = */ ggml_backend_cpu_name, - /* .free = */ ggml_backend_cpu_free, - /* .get_default_buffer_type = */ ggml_backend_cpu_get_default_buffer_type, - /* .set_tensor_async = */ NULL, - /* .get_tensor_async = */ NULL, - /* .cpy_tensor_async = */ NULL, - /* .synchronize = */ NULL, - /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create, - /* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free, - /* .graph_plan_update = */ NULL, - /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute, - /* .graph_compute = */ ggml_backend_cpu_graph_compute, - /* .supports_op = */ ggml_backend_cpu_supports_op, - /* .supports_buft = */ ggml_backend_cpu_supports_buft, - /* .offload_op = */ NULL, - /* .event_new = */ NULL, - /* .event_free = */ NULL, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, - /* .event_synchronize = */ NULL, -}; - -static ggml_guid_t ggml_backend_cpu_guid(void) { - static ggml_guid guid = { 0xaa, 0x67, 0xc7, 0x43, 0x96, 0xe6, 0xa3, 0x8a, 0xe3, 0xaf, 0xea, 0x92, 0x36, 0xbc, 0xfc, 0x89 }; - return &guid; -} - -ggml_backend_t ggml_backend_cpu_init(void) { - struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context)); - if (ctx == NULL) { - return NULL; - } - - ctx->n_threads = GGML_DEFAULT_N_THREADS; - ctx->threadpool = NULL; - ctx->work_data = NULL; - ctx->work_size = 0; - ctx->abort_callback = NULL; - ctx->abort_callback_data = NULL; - - ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend)); - if (cpu_backend == NULL) { - free(ctx); - return NULL; - } - - *cpu_backend = (struct ggml_backend) { - /* .guid = */ ggml_backend_cpu_guid(), - /* .interface = */ cpu_backend_i, - /* .context = */ ctx - }; - return cpu_backend; -} - -GGML_CALL bool ggml_backend_is_cpu(ggml_backend_t backend) { - return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cpu_guid()); -} - -void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) { - GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); - - struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; - ctx->n_threads = n_threads; -} - -void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool) { - GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); - - struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; - - if (ctx->threadpool && ctx->threadpool != threadpool) { - // already had a different threadpool, pause/suspend it before switching - ggml_threadpool_pause(ctx->threadpool); - } - ctx->threadpool = threadpool; -} - -void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) { - GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); - - struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; - ctx->abort_callback = abort_callback; - ctx->abort_callback_data = abort_callback_data; -} - -GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) { - GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned"); - return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size); -} - -GGML_CALL static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data) { - return ggml_backend_cpu_init(); - - GGML_UNUSED(params); - GGML_UNUSED(user_data); + return reg->iface.get_proc_address(reg, name); } // multi-buffer buffer @@ -949,16 +534,8 @@ struct ggml_backend_multi_buffer_context { size_t n_buffers; }; -typedef struct ggml_backend_multi_buffer_context * ggml_backend_multi_buffer_context_t; - -GGML_CALL static const char * ggml_backend_multi_buffer_get_name(ggml_backend_buffer_t buffer) { - ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context; - - return ctx->buffers[0]->iface.get_name(ctx->buffers[0]); -} - -GGML_CALL static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) { - ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context; +static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context; for (size_t i = 0; i < ctx->n_buffers; i++) { ggml_backend_buffer_free(ctx->buffers[i]); } @@ -967,31 +544,27 @@ GGML_CALL static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_ free(ctx); } -GGML_CALL static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { - ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context; +static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context; for (size_t i = 0; i < ctx->n_buffers; i++) { ggml_backend_buffer_clear(ctx->buffers[i], value); } } -static struct ggml_backend_buffer_i ggml_backend_multi_buffer_context_interface(void) { - static struct ggml_backend_buffer_i multi_backend_buffer_i = { - /* .get_name = */ ggml_backend_multi_buffer_get_name, - /* .free_buffer = */ ggml_backend_multi_buffer_free_buffer, - /* .get_base = */ NULL, - /* .init_tensor = */ NULL, - /* .set_tensor = */ NULL, - /* .get_tensor = */ NULL, - /* .cpy_tensor = */ NULL, - /* .clear = */ ggml_backend_multi_buffer_clear, - /* .reset = */ NULL, - }; +static const struct ggml_backend_buffer_i ggml_backend_multi_buffer_i = { + /* .free_buffer = */ ggml_backend_multi_buffer_free_buffer, + /* .get_base = */ NULL, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ NULL, + /* .get_tensor = */ NULL, + /* .cpy_tensor = */ NULL, + /* .clear = */ ggml_backend_multi_buffer_clear, + /* .reset = */ NULL, +}; - return multi_backend_buffer_i; -} - -GGML_CALL ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers) { - ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) malloc(sizeof(struct ggml_backend_multi_buffer_context)); +ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers) { + ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) malloc(sizeof(struct ggml_backend_multi_buffer_context)); ctx->n_buffers = n_buffers; ctx->buffers = (ggml_backend_buffer_t *) malloc(n_buffers * sizeof(ggml_backend_buffer_t)); @@ -1003,16 +576,16 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_back total_size += ggml_backend_buffer_get_size(buffers[i]); } - return ggml_backend_buffer_init(buffers[0]->buft, ggml_backend_multi_buffer_context_interface(), ctx, total_size); + return ggml_backend_buffer_init(buffers[0]->buft, ggml_backend_multi_buffer_i, ctx, total_size); } -GGML_CALL bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer) { - return buffer->iface.get_name == ggml_backend_multi_buffer_get_name; +bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer) { + return buffer->iface.free_buffer == ggml_backend_multi_buffer_free_buffer; } -GGML_CALL void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) { +void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) { GGML_ASSERT(ggml_backend_buffer_is_multi_buffer(buffer)); - ggml_backend_multi_buffer_context_t ctx = (ggml_backend_multi_buffer_context_t) buffer->context; + ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context; for (size_t i = 0; i < ctx->n_buffers; i++) { ggml_backend_buffer_set_usage(ctx->buffers[i], usage); } @@ -1099,7 +672,7 @@ struct ggml_backend_sched { char * context_buffer; size_t context_buffer_size; - bool debug; + int debug; }; #define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor) @@ -1118,7 +691,7 @@ static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backen } static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor, const struct ggml_tensor * op) { - ggml_backend_buffer_t buffer = tensor->buffer; + ggml_backend_buffer_t buffer = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; if (buffer == NULL) { return -1; } @@ -1132,7 +705,7 @@ static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, co } #ifndef NDEBUG - fprintf(stderr, "%s: warning: no backend supports op %s with a weight with buffer type %s used in tensor %s, the weight will need to be copied\n", + GGML_LOG_DEBUG("%s: warning: no backend supports op %s with a weight with buffer type %s used in tensor %s, the weight will need to be copied\n", __func__, ggml_op_desc(tensor), ggml_backend_buffer_name(buffer), tensor->name); #endif @@ -1151,8 +724,6 @@ static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS_DEBUG*GGML // returns the backend that should be used for the node based on the current locations static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * tensor) { - // TODO: use supports_op to check if the backend supports the op - // assign pre-allocated nodes to their backend int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor, tensor); if (cur_backend_id != -1) { @@ -1171,7 +742,8 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st if (tensor->buffer || (tensor->view_src && tensor->view_src->buffer)) { // since the tensor is pre-allocated, it cannot be moved to another backend - GGML_ABORT("pre-allocated tensor in a backend that cannot run the operation"); + ggml_backend_buffer_t buffer = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ABORT("pre-allocated tensor (%s) in a buffer (%s) that cannot run the operation (%s)", tensor->name, ggml_backend_buffer_name(buffer), ggml_op_name(tensor->op)); } // graph input @@ -1187,10 +759,12 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st if (src == NULL) { continue; } - if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { + // skip ROPE since the rope freqs tensor is too small to choose a backend based on it + // not an ideal solution + if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor); // check if a backend with higher prio wants to offload the op - if (src_backend_id == sched->n_backends - 1) { + if (src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) { for (int b = 0; b < src_backend_id; b++) { if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) { SET_CAUSE(tensor, "1.off"); @@ -1221,32 +795,37 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str for (int i = 0; i < graph->n_nodes; i++) { if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) { ggml_backend_t split_backend = sched->backends[sched->splits[cur_split].backend_id]; - fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend), + GGML_LOG_DEBUG("\n## SPLIT #%d: %s # %d inputs", cur_split, ggml_backend_name(split_backend), sched->splits[cur_split].n_inputs); for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) { - fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, + if (j == 0) { + GGML_LOG_DEBUG(": "); + } + GGML_LOG_DEBUG("[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j]))); } - fprintf(stderr, "\n"); + GGML_LOG_DEBUG("\n"); cur_split++; } struct ggml_tensor * node = graph->nodes[i]; if (ggml_is_view_op(node->op)) { continue; } - ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node); - fprintf(stderr, "node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name, - fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node)); - for (int j = 0; j < GGML_MAX_SRC; j++) { - struct ggml_tensor * src = node->src[j]; - if (src == NULL) { - continue; + if (sched->debug > 1) { + ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node); + GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name, + fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node)); + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + ggml_backend_t src_backend = ggml_backend_sched_get_tensor_backend(sched, src); + GGML_LOG_DEBUG(" %20.20s (%5.5s) [%5.5s %8.8s]", src->name, + fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src)); } - ggml_backend_t src_backend = ggml_backend_sched_get_tensor_backend(sched, src); - fprintf(stderr, " %20.20s (%5.5s) [%5.5s %8.8s]", src->name, - fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src)); + GGML_LOG_DEBUG("\n"); } - fprintf(stderr, "\n"); } } @@ -1538,11 +1117,11 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg if (src == NULL) { continue; } - // check if a weight is on a different backend + // check if a weight is on a different and incompatible backend // by starting a new split, the memory of the previously offloaded weights can be reused if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { int src_backend_id = tensor_backend_id(src); - if (src_backend_id != cur_backend_id) { + if (src_backend_id != cur_backend_id && !ggml_backend_sched_buffer_supported(sched, src, cur_backend_id)) { need_new_split = true; break; } @@ -1554,7 +1133,6 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg int src_backend_id = sched->hv_tensor_backend_ids[id]; bool supported = ggml_backend_sched_buffer_supported(sched, src, cur_backend_id); if (src_backend_id != cur_backend_id && tensor_id_copy(id, cur_backend_id, 0) == NULL && !supported) { - //printf("starting new split because of too many inputs: node %s, input %s\n", node->name, src->name); need_new_split = true; break; } @@ -1567,7 +1145,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg i_split++; if (i_split >= sched->splits_capacity) { sched->splits_capacity *= 2; - sched->splits = realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split)); + sched->splits = (ggml_backend_sched_split *) + realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split)); GGML_ASSERT(sched->splits != NULL); } split = &sched->splits[i_split]; @@ -1653,11 +1232,11 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg sched->prev_leaf_backend_ids = tmp; } - int graph_size = MAX(graph->n_nodes, graph->n_leafs) + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sched->n_copies; + int graph_size = std::max(graph->n_nodes, graph->n_leafs) + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sched->n_copies; if (sched->graph.size < graph_size) { sched->graph.size = graph_size; - sched->graph.nodes = realloc(sched->graph.nodes, graph_size * sizeof(struct ggml_tensor *)); - sched->graph.leafs = realloc(sched->graph.leafs, graph_size * sizeof(struct ggml_tensor *)); + sched->graph.nodes = (ggml_tensor **) realloc(sched->graph.nodes, graph_size * sizeof(struct ggml_tensor *)); + sched->graph.leafs = (ggml_tensor **) realloc(sched->graph.leafs, graph_size * sizeof(struct ggml_tensor *)); GGML_ASSERT(sched->graph.nodes != NULL); GGML_ASSERT(sched->graph.leafs != NULL); } @@ -1759,11 +1338,11 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { // the re-allocation may cause the split inputs to be moved to a different address ggml_backend_sched_synchronize(sched); #ifndef NDEBUG - fprintf(stderr, "%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed); + GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed); #endif ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids); if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { - fprintf(stderr, "%s: failed to allocate graph\n", __func__); + GGML_LOG_ERROR("%s: failed to allocate graph\n", __func__); return false; } } @@ -1856,7 +1435,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s // record the event of this copy if (split->n_inputs > 0) { if (sched->events[split_backend_id][sched->cur_copy] != NULL) { - ggml_backend_event_record(sched->events[split_backend_id][sched->cur_copy]); + ggml_backend_event_record(sched->events[split_backend_id][sched->cur_copy], split_backend); } } } @@ -1874,41 +1453,43 @@ ggml_backend_sched_t ggml_backend_sched_new( bool parallel) { GGML_ASSERT(n_backends > 0); GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS); - GGML_ASSERT(ggml_backend_is_cpu(backends[n_backends - 1])); // last backend must be CPU + GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU); - struct ggml_backend_sched * sched = calloc(1, sizeof(struct ggml_backend_sched)); + struct ggml_backend_sched * sched = (ggml_backend_sched *) calloc(1, sizeof(struct ggml_backend_sched)); - sched->debug = getenv("GGML_SCHED_DEBUG") != NULL; + const char * GGML_SCHED_DEBUG = getenv("GGML_SCHED_DEBUG"); + sched->debug = GGML_SCHED_DEBUG ? atoi(GGML_SCHED_DEBUG) : 0; sched->n_backends = n_backends; sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1; // initialize hash table // FIXME: needs to be size*2 to account for leafs (do it in graph_split instead) sched->hash_set = ggml_hash_set_new(graph_size); - sched->hv_tensor_backend_ids = malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0])); - sched->hv_tensor_copies = malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *)); + sched->hv_tensor_backend_ids = (int *) malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0])); + sched->hv_tensor_copies = (ggml_tensor **) malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *)); const size_t ggml_sched_max_splits = graph_size; // at most there is one split for each node in the graph const size_t nodes_size = graph_size + ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2; - sched->node_backend_ids = calloc(nodes_size, sizeof(sched->node_backend_ids[0])); - sched->leaf_backend_ids = calloc(nodes_size, sizeof(sched->leaf_backend_ids[0])); - sched->prev_node_backend_ids = calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0])); - sched->prev_leaf_backend_ids = calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0])); + sched->node_backend_ids = (int *) calloc(nodes_size, sizeof(sched->node_backend_ids[0])); + sched->leaf_backend_ids = (int *) calloc(nodes_size, sizeof(sched->leaf_backend_ids[0])); + sched->prev_node_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0])); + sched->prev_leaf_backend_ids = (int *) calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0])); sched->context_buffer_size = ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false); - sched->context_buffer = malloc(sched->context_buffer_size); + sched->context_buffer = (char *) malloc(sched->context_buffer_size); const int initial_splits_capacity = 16; - sched->splits = calloc(initial_splits_capacity, sizeof(sched->splits[0])); + sched->splits = (ggml_backend_sched_split *) calloc(initial_splits_capacity, sizeof(sched->splits[0])); sched->splits_capacity = initial_splits_capacity; for (int b = 0; b < n_backends; b++) { sched->backends[b] = backends[b]; sched->bufts[b] = bufts ? bufts[b] : ggml_backend_get_default_buffer_type(backends[b]); GGML_ASSERT(ggml_backend_supports_buft(backends[b], sched->bufts[b])); + if (sched->n_copies > 1) { for (int c = 0; c < sched->n_copies; c++) { - sched->events[b][c] = ggml_backend_event_new(backends[b]); + sched->events[b][c] = ggml_backend_event_new(backends[b]->device); } } } @@ -1961,12 +1542,13 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * ggml_backend_sched_split_graph(sched, measure_graph); + ggml_backend_sched_synchronize(sched); + if (!ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) { return false; } ggml_backend_sched_reset(sched); - ggml_backend_sched_synchronize(sched); return true; } @@ -2144,8 +1726,8 @@ static void graph_copy_init_tensor(struct ggml_hash_set * hash_set, struct ggml_ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) { struct ggml_hash_set hash_set = ggml_hash_set_new(graph->visited_hash_set.size); - struct ggml_tensor ** node_copies = calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT - bool * node_init = calloc(hash_set.size, sizeof(node_init[0])); + struct ggml_tensor ** node_copies = (ggml_tensor **) calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT + bool * node_init = (bool *) calloc(hash_set.size, sizeof(node_init[0])); struct ggml_init_params params = { /* .mem_size = */ ggml_tensor_overhead()*hash_set.size + ggml_graph_overhead_custom(graph->size, false), @@ -2157,13 +1739,13 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s struct ggml_context * ctx_unallocated = ggml_init(params); if (ctx_allocated == NULL || ctx_unallocated == NULL) { - fprintf(stderr, "failed to allocate context for graph copy\n"); + GGML_LOG_ERROR("%s: failed to allocate context for graph copy\n", __func__); ggml_hash_set_free(&hash_set); free(node_copies); free(node_init); ggml_free(ctx_allocated); ggml_free(ctx_unallocated); - return (struct ggml_backend_graph_copy) { + return { /* .buffer = */ NULL, /* .ctx_allocated = */ NULL, /* .ctx_unallocated = */ NULL, @@ -2180,13 +1762,13 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s // allocate nodes ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx_allocated, backend); if (buffer == NULL) { - fprintf(stderr, "failed to allocate buffer for graph copy\n"); + GGML_LOG_ERROR("%s: failed to allocate buffer for graph copy\n", __func__); ggml_hash_set_free(&hash_set); free(node_copies); free(node_init); ggml_free(ctx_allocated); ggml_free(ctx_unallocated); - return (struct ggml_backend_graph_copy) { + return { /* .buffer = */ NULL, /* .ctx_allocated = */ NULL, /* .ctx_unallocated = */ NULL, @@ -2215,7 +1797,7 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s free(node_copies); free(node_init); - return (struct ggml_backend_graph_copy) { + return { /* .buffer = */ buffer, /* .ctx_allocated = */ ctx_allocated, /* .ctx_unallocated = */ ctx_unallocated, @@ -2267,3 +1849,154 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t return true; } + +// CPU backend - buffer + +static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { + uintptr_t data = (uintptr_t)buffer->context; + + // align the buffer + if (data % TENSOR_ALIGNMENT != 0) { + data = GGML_PAD(data, TENSOR_ALIGNMENT); + } + + return (void *)data; +} + +static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_aligned_free(buffer->context, buffer->size); +} + +static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + memset((char *)tensor->data + offset, value, size); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + memcpy((char *)tensor->data + offset, data, size); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + memcpy(data, (const char *)tensor->data + offset, size); + + GGML_UNUSED(buffer); +} + +static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { + if (ggml_backend_buffer_is_host(src->buffer)) { + memcpy(dst->data, src->data, ggml_nbytes(src)); + return true; + } + return false; + + GGML_UNUSED(buffer); +} + +static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + memset(buffer->context, value, buffer->size); +} + +static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = { + /* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer, + /* .get_base = */ ggml_backend_cpu_buffer_get_base, + /* .init_tensor = */ NULL, // no initialization required + /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, + /* .clear = */ ggml_backend_cpu_buffer_clear, + /* .reset = */ NULL, +}; + +static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = { + /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed + /* .get_base = */ ggml_backend_cpu_buffer_get_base, + /* .init_tensor = */ NULL, // no initialization required + /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, + /* .clear = */ ggml_backend_cpu_buffer_clear, + /* .reset = */ NULL, +}; + +// CPU backend buffer type + +// this buffer type is defined here to make it available to all backends + +static const char * ggml_backend_cpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + void * data = ggml_aligned_malloc(size); + + if (data == NULL) { + GGML_LOG_ERROR("%s: failed to allocate buffer of size %zu\n", __func__, size); + return NULL; + } + + return ggml_backend_buffer_init(buft, ggml_backend_cpu_buffer_i, data, size); +} + +static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return TENSOR_ALIGNMENT; + + GGML_UNUSED(buft); +} + +static bool ggml_backend_cpu_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + return true; + + GGML_UNUSED(buft); +} + +ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) { + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type = { + /* .iface = */ { + /* .get_name = */ ggml_backend_cpu_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .is_host = */ ggml_backend_cpu_buffer_type_is_host, + }, + /* .device = */ NULL, // FIXME ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ NULL, + }; + + return &ggml_backend_cpu_buffer_type; +} + +static const char * ggml_backend_cpu_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU_Mapped"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_type_t ggml_backend_cpu_buffer_from_ptr_type(void) { + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type = { + /* .iface = */ { + /* .get_name = */ ggml_backend_cpu_buffer_from_ptr_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .is_host = */ ggml_backend_cpu_buffer_type_is_host, + }, + /* .device = */ NULL, // FIXME ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ NULL, + }; + + return &ggml_backend_cpu_buffer_type; +} + +ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) { + GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned"); + return ggml_backend_buffer_init(ggml_backend_cpu_buffer_from_ptr_type(), ggml_backend_cpu_buffer_from_ptr_i, ptr, size); +} diff --git a/ggml/src/ggml-blas/CMakeLists.txt b/ggml/src/ggml-blas/CMakeLists.txt new file mode 100644 index 000000000..0bf3c05d9 --- /dev/null +++ b/ggml/src/ggml-blas/CMakeLists.txt @@ -0,0 +1,87 @@ +if (GGML_STATIC) + set(BLA_STATIC ON) +endif() +#if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.22) +# set(BLA_SIZEOF_INTEGER 8) +#endif() + +set(BLA_VENDOR ${GGML_BLAS_VENDOR}) +find_package(BLAS) + +if (BLAS_FOUND) + message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}") + + ggml_add_backend_library(ggml-blas + ggml-blas.cpp + ) + + if (${GGML_BLAS_VENDOR} MATCHES "Apple") + add_compile_definitions(ACCELERATE_NEW_LAPACK) + add_compile_definitions(ACCELERATE_LAPACK_ILP64) + add_compile_definitions(GGML_BLAS_USE_ACCELERATE) + elseif ("${BLAS_INCLUDE_DIRS}" STREQUAL "") + # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake. + # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268 + find_package(PkgConfig REQUIRED) + if (${GGML_BLAS_VENDOR} MATCHES "Generic") + pkg_check_modules(DepBLAS blas) + elseif (${GGML_BLAS_VENDOR} MATCHES "OpenBLAS") + # As of openblas v0.3.22, the 64-bit is named openblas64.pc + pkg_check_modules(DepBLAS openblas64) + if (NOT DepBLAS_FOUND) + pkg_check_modules(DepBLAS openblas) + endif() + elseif (${GGML_BLAS_VENDOR} MATCHES "FLAME") + add_compile_definitions(GGML_BLAS_USE_BLIS) + pkg_check_modules(DepBLAS blis) + elseif (${GGML_BLAS_VENDOR} MATCHES "ATLAS") + pkg_check_modules(DepBLAS blas-atlas) + elseif (${GGML_BLAS_VENDOR} MATCHES "FlexiBLAS") + pkg_check_modules(DepBLAS flexiblas_api) + elseif (${GGML_BLAS_VENDOR} MATCHES "Intel") + add_compile_definitions(GGML_BLAS_USE_MKL) + # all Intel* libraries share the same include path + pkg_check_modules(DepBLAS mkl-sdl) + elseif (${GGML_BLAS_VENDOR} MATCHES "NVHPC") + # this doesn't provide pkg-config + # suggest to assign BLAS_INCLUDE_DIRS on your own + if ("${NVHPC_VERSION}" STREQUAL "") + message(WARNING "Better to set NVHPC_VERSION") + else() + set(DepBLAS_FOUND ON) + set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include") + endif() + endif() + if (DepBLAS_FOUND) + set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS}) + else() + message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically" + " detected by pkgconfig, trying to find cblas.h from possible paths...") + find_path(BLAS_INCLUDE_DIRS + NAMES cblas.h + HINTS + /usr/include + /usr/local/include + /usr/include/openblas + /opt/homebrew/opt/openblas/include + /usr/local/opt/openblas/include + /usr/include/x86_64-linux-gnu/openblas/include + ) + endif() + endif() + + message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}") + + target_compile_options(ggml-blas PRIVATE ${BLAS_LINKER_FLAGS}) + + if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel")) + add_compile_definitions(GGML_BLAS_USE_MKL) + endif() + + target_link_libraries (ggml-blas PRIVATE ${BLAS_LIBRARIES}) + target_include_directories(ggml-blas PRIVATE ${BLAS_INCLUDE_DIRS}) +else() + message(ERROR "BLAS not found, please refer to " + "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" + " to set correct GGML_BLAS_VENDOR") +endif() diff --git a/ggml/src/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp similarity index 55% rename from ggml/src/ggml-blas.cpp rename to ggml/src/ggml-blas/ggml-blas.cpp index 713731735..ec158dfac 100644 --- a/ggml/src/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -1,10 +1,12 @@ +#include "ggml-impl.h" #include "ggml-blas.h" #include "ggml-backend-impl.h" #include #include +#include -#if defined(GGML_USE_ACCELERATE) +#if defined(GGML_BLAS_USE_ACCELERATE) # include #elif defined(GGML_BLAS_USE_MKL) # include @@ -25,30 +27,6 @@ struct ggml_backend_blas_context { #endif }; -// helper function to determine if it is better to use BLAS or not -// for large matrices, BLAS is faster -static bool ggml_backend_blas_use_blas(const struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - const int64_t ne10 = src1->ne[0]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - - // TODO: find the optimal values for these - if (ggml_is_contiguous(src0) && - ggml_is_contiguous(src1) && - src1->type == GGML_TYPE_F32 && - (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) { - - /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/ - return true; - } - - return false; -} - static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; @@ -87,8 +65,8 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg // convert src0 to float if (type != GGML_TYPE_F32) { - ggml_type_traits_t type_traits = ggml_internal_get_type_traits(type); - ggml_to_float_t const to_float = type_traits.to_float; + const auto * type_traits = ggml_get_type_traits(type); + ggml_to_float_t const to_float = type_traits->to_float; for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { @@ -234,25 +212,19 @@ static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct g // backend interface -GGML_CALL static const char * ggml_backend_blas_name(ggml_backend_t backend) { +static const char * ggml_backend_blas_get_name(ggml_backend_t backend) { return "BLAS"; GGML_UNUSED(backend); } -GGML_CALL static void ggml_backend_blas_free(ggml_backend_t backend) { +static void ggml_backend_blas_free(ggml_backend_t backend) { ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context; delete ctx; delete backend; } -GGML_CALL static ggml_backend_buffer_type_t ggml_backend_blas_get_default_buffer_type(ggml_backend_t backend) { - return ggml_backend_cpu_buffer_type(); - - GGML_UNUSED(backend); -} - -GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { +static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context; for (int i = 0; i < cgraph->n_nodes; i++) { @@ -284,31 +256,9 @@ GGML_CALL static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t GGML_UNUSED(backend); } -GGML_CALL static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { - const struct ggml_tensor * src0 = op->src[0]; - const struct ggml_tensor * src1 = op->src[1]; - - return (op->op == GGML_OP_MUL_MAT && ggml_backend_blas_use_blas(op)) || - (op->op == GGML_OP_OUT_PROD && op->src[0]->type == GGML_TYPE_F32 && - op->src[1]->type == GGML_TYPE_F32 && - ggml_is_matrix(src0) && - ggml_is_matrix(src1) && - ggml_is_contiguous(src0) && - (ggml_is_contiguous(src1) || ggml_is_transposed(src1))); - - GGML_UNUSED(backend); -} - -GGML_CALL static bool ggml_backend_blas_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { - return ggml_backend_buft_is_host(buft); - - GGML_UNUSED(backend); -} - static struct ggml_backend_i blas_backend_i = { - /* .get_name = */ ggml_backend_blas_name, + /* .get_name = */ ggml_backend_blas_get_name, /* .free = */ ggml_backend_blas_free, - /* .get_default_buffer_type = */ ggml_backend_blas_get_default_buffer_type, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, /* .cpy_tensor_async = */ NULL, @@ -318,14 +268,8 @@ static struct ggml_backend_i blas_backend_i = { /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_blas_graph_compute, - /* .supports_op = */ ggml_backend_blas_supports_op, - /* .supports_buft = */ ggml_backend_blas_supports_buft, - /* .offload_op = */ NULL, - /* .event_new = */ NULL, - /* .event_free = */ NULL, /* .event_record = */ NULL, /* .event_wait = */ NULL, - /* .event_synchronize = */ NULL, }; static ggml_guid_t ggml_backend_blas_guid(void) { @@ -339,23 +283,24 @@ ggml_backend_t ggml_backend_blas_init(void) { ggml_backend_t backend = new ggml_backend { /* .guid = */ ggml_backend_blas_guid(), /* .interface = */ blas_backend_i, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0), /* .context = */ ctx, }; -#if !defined(NDEBUG) && defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP) +#if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP) if (openblas_get_parallel() != OPENBLAS_OPENMP) { - fprintf(stderr, "%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__); + GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__); } #endif -#if !defined(NDEBUG) && defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP) - fprintf(stderr, "%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__); +#if defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP) + GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__); #endif return backend; } -GGML_CALL bool ggml_backend_is_blas(ggml_backend_t backend) { +bool ggml_backend_is_blas(ggml_backend_t backend) { return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_blas_guid()); } @@ -365,3 +310,208 @@ void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads) ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context; ctx->n_threads = n_threads; } + +// device interface + +static const char * ggml_backend_blas_device_get_name(ggml_backend_dev_t dev) { + return "BLAS"; + + GGML_UNUSED(dev); +} + +static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t dev) { + #if defined(GGML_BLAS_USE_ACCELERATE) + return "Accelerate"; + #elif defined(GGML_BLAS_USE_MKL) + return "MKL"; + #elif defined(GGML_BLAS_USE_BLIS) + return "BLIS"; + #elif defined(GGML_BLAS_USE_NVPL) + return "NVPL"; + #elif defined(OPENBLAS_VERSION) + return "OpenBLAS"; + #else + return "BLAS"; + #endif + + GGML_UNUSED(dev); +} + +static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + // TODO + *free = 0; + *total = 0; + + GGML_UNUSED(dev); +} + +static enum ggml_backend_dev_type ggml_backend_blas_device_get_type(ggml_backend_dev_t dev) { + return GGML_BACKEND_DEVICE_TYPE_ACCEL; + + GGML_UNUSED(dev); +} + +static void ggml_backend_blas_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_blas_device_get_name(dev); + props->description = ggml_backend_blas_device_get_description(dev); + props->type = ggml_backend_blas_device_get_type(dev); + ggml_backend_blas_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ true, + /* .events = */ false, + }; +} + +static ggml_backend_t ggml_backend_blas_device_init_backend(ggml_backend_dev_t dev, const char * params) { + return ggml_backend_blas_init(); + + GGML_UNUSED(dev); + GGML_UNUSED(params); +} + +static ggml_backend_buffer_type_t ggml_backend_blas_device_get_buffer_type(ggml_backend_dev_t dev) { + return ggml_backend_cpu_buffer_type(); + + GGML_UNUSED(dev); +} + +static ggml_backend_buffer_t ggml_backend_blas_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + return ggml_backend_cpu_buffer_from_ptr(ptr, size); + + GGML_UNUSED(dev); + GGML_UNUSED(max_tensor_size); +} + +static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + + switch (op->op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + return true; + + case GGML_OP_MUL_MAT: + { + // BLAS usually is only faster for large matrices + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + + const int64_t ne10 = src1->ne[0]; + + const int64_t ne0 = op->ne[0]; + const int64_t ne1 = op->ne[1]; + + // TODO: find the optimal value + const int64_t min_batch = 32; + + return ggml_is_contiguous(src0) && + ggml_is_contiguous(src1) && + src1->type == GGML_TYPE_F32 && + (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch) && + (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL); + } + + case GGML_OP_OUT_PROD: + return op->src[0]->type == GGML_TYPE_F32 && + op->src[1]->type == GGML_TYPE_F32 && + ggml_is_matrix(src0) && + ggml_is_matrix(src1) && + ggml_is_contiguous(src0) && + (ggml_is_contiguous(src1) || ggml_is_transposed(src1)) && + (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL); + + default: + return false; + + } + + GGML_UNUSED(dev); +} + +static bool ggml_backend_blas_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + return ggml_backend_buft_is_host(buft); + + GGML_UNUSED(dev); +} + +static const struct ggml_backend_device_i ggml_backend_blas_device_i = { + /* .get_name = */ ggml_backend_blas_device_get_name, + /* .get_description = */ ggml_backend_blas_device_get_description, + /* .get_memory = */ ggml_backend_blas_device_get_memory, + /* .get_type = */ ggml_backend_blas_device_get_type, + /* .get_props = */ ggml_backend_blas_device_get_props, + /* .init_backend = */ ggml_backend_blas_device_init_backend, + /* .get_buffer_type = */ ggml_backend_blas_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_blas_device_buffer_from_host_ptr, + /* .supports_op = */ ggml_backend_blas_device_supports_op, + /* .supports_buft = */ ggml_backend_blas_device_supports_buft, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +// backend reg interface + +static const char * ggml_backend_blas_reg_get_name(ggml_backend_reg_t reg) { + return "BLAS"; + + GGML_UNUSED(reg); +} + +static size_t ggml_backend_blas_reg_get_device_count(ggml_backend_reg_t reg) { + return 1; + + GGML_UNUSED(reg); +} + +static ggml_backend_dev_t ggml_backend_blas_reg_get_device(ggml_backend_reg_t reg, size_t index) { + GGML_ASSERT(index == 0); + + static ggml_backend_device ggml_backend_blas_device = { + /* .iface = */ ggml_backend_blas_device_i, + /* .reg = */ reg, + /* .context = */ nullptr, + }; + + return &ggml_backend_blas_device; + + GGML_UNUSED(reg); + GGML_UNUSED(index); +} + +static void * ggml_backend_blas_get_proc_address(ggml_backend_reg_t reg, const char * name) { + if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) { + return (void *)ggml_backend_blas_set_n_threads; + } + return NULL; + + GGML_UNUSED(reg); + GGML_UNUSED(name); +} + +static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = { + /* .get_name = */ ggml_backend_blas_reg_get_name, + /* .get_device_count = */ ggml_backend_blas_reg_get_device_count, + /* .get_device = */ ggml_backend_blas_reg_get_device, + /* .get_proc_address = */ ggml_backend_blas_get_proc_address, +}; + +ggml_backend_reg_t ggml_backend_blas_reg(void) { + static struct ggml_backend_reg ggml_backend_blas_reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_blas_reg_i, + /* .context = */ NULL, + }; + + return &ggml_backend_blas_reg; +} + +GGML_BACKEND_DL_IMPL(ggml_backend_blas_reg) diff --git a/ggml/src/ggml-cann/CMakeLists.txt b/ggml/src/ggml-cann/CMakeLists.txt new file mode 100644 index 000000000..05cf06bfa --- /dev/null +++ b/ggml/src/ggml-cann/CMakeLists.txt @@ -0,0 +1,76 @@ +if ("cann${CANN_INSTALL_DIR}" STREQUAL "cann" AND DEFINED ENV{ASCEND_TOOLKIT_HOME}) + set(CANN_INSTALL_DIR $ENV{ASCEND_TOOLKIT_HOME}) + message(STATUS "CANN: updated CANN_INSTALL_DIR from ASCEND_TOOLKIT_HOME=$ENV{ASCEND_TOOLKIT_HOME}") +endif() + +# Auto-detech Soc type and Soc version, if detect failed, will abort build +set(SOC_VERSION "") +function(detect_ascend_soc_type SOC_VERSION) + execute_process( + COMMAND bash -c "npu-smi info|awk -F' ' 'NF > 0 && NR==7 {print $3}'" + OUTPUT_VARIABLE npu_info + RESULT_VARIABLE npu_result + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + if("${npu_info}" STREQUAL "" OR ${npu_result}) + message(FATAL_ERROR "Auto-detech ascend soc type failed, please specify manually or check ascend device working normally.") + endif() + set(${SOC_VERSION} "Ascend${npu_info}" PARENT_SCOPE) +endfunction() + +if(NOT SOC_TYPE) + detect_ascend_soc_type(SOC_VERSION) + set(SOC_TYPE "${SOC_VERSION}") + message(STATUS "CANN: SOC_VERSION auto-detected is:${SOC_VERSION}") +endif() + +string(TOLOWER ${SOC_TYPE} SOC_VERSION) # SOC_VERSION need lower + +# Construct Soc specify compile option: ASCEND_#Soc_Major_SN. Such as ASCEND_910B, ASCEND_310P. +string(REGEX MATCH "[0-9]+[a-zA-Z]" SOC_TYPE_MAJOR_SN "${SOC_VERSION}") +set(SOC_TYPE_COMPILE_OPTION "ASCEND_${SOC_TYPE_MAJOR_SN}") +string(TOUPPER ${SOC_TYPE_COMPILE_OPTION} SOC_TYPE_COMPILE_OPTION) + +if (CANN_INSTALL_DIR) + # Only Support Linux. + if (NOT UNIX) + message(FATAL_ERROR "CANN: CANN toolkit supports unix but not ${CMAKE_SYSTEM_NAME}") + endif() + + # Supported platforms: x86-64, arm64 + if (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") + elseif (CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "amd64") + else() + message(FATAL_ERROR "CANN: CANN toolkit supports x86-64 and arm64 but not ${CMAKE_SYSTEM_PROCESSOR}") + endif() + + # Set header and libs + set(CANN_INCLUDE_DIRS + ${CANN_INSTALL_DIR}/include + ${CANN_INSTALL_DIR}/include/aclnn + ${CANN_INSTALL_DIR}/acllib/include + ) + + add_subdirectory(kernels) + list(APPEND CANN_LIBRARIES + ascendcl + nnopbase + opapi + acl_op_compiler + ascendc_kernels + ) + + file(GLOB GGML_SOURCES_CANN "*.cpp") + + ggml_add_backend_library(ggml-cann ${GGML_SOURCES_CANN}) + target_link_libraries(ggml-cann PRIVATE ${CANN_LIBRARIES}) + target_include_directories(ggml-cann PRIVATE ${CANN_INCLUDE_DIRS}) + target_link_directories(ggml-cann PRIVATE ${CANN_INSTALL_DIR}/lib64) + + target_compile_definitions(ggml-cann PRIVATE "-D${SOC_TYPE_COMPILE_OPTION}") + + message(STATUS "CANN: CANN_INCLUDE_DIRS = ${CANN_INCLUDE_DIRS}") + message(STATUS "CANN: CANN_LIBRARIES = ${CANN_LIBRARIES}") +else() + message(FATAL_ERROR "CANN: Can't find CANN_INSTALL_DIR, did you forget to source set_var.sh?") +endif() diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index a4ec8418e..b2d857e1e 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -22,11 +22,14 @@ #include "aclnn_ops.h" +#include #include +#include #include #include #include #include +#include #include #include #include @@ -34,6 +37,7 @@ #include #include #include +#include #include #include #include @@ -53,6 +57,7 @@ #include #include +#include "ggml-impl.h" #include "kernels/ascendc_kernels.h" #define GGML_COMMON_DECL_C @@ -241,10 +246,14 @@ void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclTensor* acl_src1 = ggml_cann_create_tensor(src1); aclTensor* acl_dst = ggml_cann_create_tensor(dst); - int64_t concat_dim = 1; + const int32_t dim = ggml_get_op_params_i32(dst, 0); + + GGML_ASSERT(dim >= 0 && dim < 4); + int32_t acl_dim = 3 - dim; + aclTensor* tensors[] = {acl_src0, acl_src1}; aclTensorList* tensorList = aclCreateTensorList(tensors, 2); - aclnn_concat(ctx, tensorList, acl_dst, concat_dim); + aclnn_concat(ctx, tensorList, acl_dst, acl_dim); ACL_CHECK(aclDestroyTensorList(tensorList)); ACL_CHECK(aclDestroyTensor(acl_dst)); @@ -1096,9 +1105,9 @@ static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, void* buffer, } /** - * @brief Creates an ACL tensor initialized with ones using a provided buffer. + * @brief Creates an ACL tensor initialized with value using a provided buffer. * - * This function initializes a tensor with ones using the specified buffer and + * This function initializes a tensor with value using the specified buffer and * tensor parameters. * * @param ctx The context for the CANN backend operations. @@ -1111,12 +1120,12 @@ static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, void* buffer, * @param type_size The size of each element in the tensor data type. * @param value The value to be used for initializing the tensor (default * is 1.0). - * @return An ACL tensor initialized with ones. + * @return An ACL tensor initialized with value. */ -static aclTensor* aclnn_ones(ggml_backend_cann_context& ctx, void* buffer, - size_t n_bytes, int64_t* ne, int64_t dims, - aclDataType type, size_t type_size, - float value = 1.0f) { +static aclTensor* aclnn_values(ggml_backend_cann_context& ctx, void* buffer, + size_t n_bytes, int64_t* ne, int64_t dims, + aclDataType type, size_t type_size, + float value = 1.0f) { aclTensor* acl_tensor = aclnn_zero(ctx, buffer, n_bytes, ne, dims, type, type_size); float alpha_host = 1.0f; @@ -1158,7 +1167,7 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { size_t one_tensor_n_bytes = src->ne[0] * ggml_element_size(src); ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes); - aclTensor* acl_gamma = aclnn_ones( + aclTensor* acl_gamma = aclnn_values( ctx, one_tensor_allocator.get(), one_tensor_n_bytes, src->ne, 1, ggml_cann_type_mapping(src->type), ggml_element_size(src)); @@ -1202,9 +1211,9 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes); aclTensor* mask_tensor = - aclnn_ones(ctx, one_tensor_allocator.get(), one_tensor_n_bytes, src->ne, - GGML_MAX_DIMS, ggml_cann_type_mapping(src->type), - ggml_element_size(src), value); + aclnn_values(ctx, one_tensor_allocator.get(), one_tensor_n_bytes, + src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type), + ggml_element_size(src), value); uint64_t workspaceSize = 0; aclOpExecutor* executor; @@ -1437,10 +1446,6 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_tensor* src0 = dst->src[0]; // kernel ggml_tensor* src1 = dst->src[1]; // input - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); - GGML_TENSOR_BINARY_OP_LOCALS; // aclnnIm2col only works on 2D. set s1, p1, d1 to 1 to perform 2D @@ -1462,9 +1467,6 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) { const int64_t OH = is_2D ? ne2 : 1; const int64_t OW = ne1; - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb10 == sizeof(float)); - // memory allocated increased to 3x when is_2D == false const int64_t n_bytes_factor = is_2D ? 1 : 3; @@ -1768,6 +1770,92 @@ static void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src, ACL_CHECK(aclnnSin(workspaceAddr, workspaceSize, executor, ctx.stream())); } +/** + * @brief Performs element-wise division of tensor1 by tensor2 , multiplies the + result by the scalar value and adds it to self . + * + * Performs element-wise division of tensor1 by tensor2, + * multiplies the result by the scalar value and adds it to self . + * The operation is defined as: + * \f[ + * \text{out}_i = \text{selft}_i + \text{value} \times + \frac{\text{tensor1}_i}{\text{tensor2}_i} + * \f] + + * @param ctx The context for the CANN backend operations. + * @param acl_self The source tensor on which the addcdiv function will be + applied. + * @param tensor1 Numerator tensor. + * @param tensor2 Denominator tensor. + * @param value The value to be used for coefficient. + */ +static void aclnn_inplace_addcdiv(ggml_backend_cann_context& ctx, + aclTensor* acl_self, aclTensor* tensor1, + aclTensor* tensor2, float value) { + uint64_t workspaceSize = 0; + aclOpExecutor* executor; + void* workspaceAddr = nullptr; + aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT); + + ACL_CHECK(aclnnInplaceAddcdivGetWorkspaceSize( + acl_self, tensor1, tensor2, acl_value, &workspaceSize, &executor)); + if (workspaceSize > 0) { + ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); + workspaceAddr = workspace_allocator.get(); + } + + ACL_CHECK(aclnnInplaceAddcdiv(workspaceAddr, workspaceSize, executor, + ctx.stream())); +} + +/** + * @brief Matrix division, optionally in-place. + * + * This function division each element of the source tensor `acl_src` by the + * tensor `acl_other` and stores the result in the destination tensor `acl_dst`. + * If `inplace` is true, `acl_dst` will not be used and the operation is + * performed in-place on `acl_src`. The operation is defined as: \f[ + * \text{dst}_i = \frac{\text{acl_src}_i}{\text{acl_other}_i} + * \f] + * + * @param ctx The context for the CANN backend operations. + * @param acl_src Numerator tensor.. + * @param acl_other Denominator tensor. + * @param acl_dst The destination tensor where the result will be stored if + * `inplace` is false. + * @param inplace Flag indicating whether to perform the operation in-place on + * `acl_src`. + */ +static void aclnn_div_tensor(ggml_backend_cann_context& ctx, aclTensor* acl_src, + aclTensor* acl_other, aclTensor* acl_dst, + bool inplace) { + uint64_t workspaceSize = 0; + aclOpExecutor* executor; + void* workspaceAddr = nullptr; + + if (inplace) { + ACL_CHECK(aclnnInplaceDivGetWorkspaceSize(acl_src, acl_other, + &workspaceSize, &executor)); + if (workspaceSize > 0) { + ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); + workspaceAddr = workspace_allocator.get(); + } + + ACL_CHECK(aclnnInplaceDiv(workspaceAddr, workspaceSize, executor, + ctx.stream())); + } else { + ACL_CHECK(aclnnDivGetWorkspaceSize(acl_src, acl_other, acl_dst, + &workspaceSize, &executor)); + if (workspaceSize > 0) { + ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); + workspaceAddr = workspace_allocator.get(); + } + + ACL_CHECK( + aclnnDiv(workspaceAddr, workspaceSize, executor, ctx.stream())); + } +} + void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, ggml_tensor* dst) { const ggml_tensor* src = dst->src[0]; @@ -2311,7 +2399,16 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ctx.stream())); switch (src0->type) { - case GGML_TYPE_F32: + case GGML_TYPE_F32: { +#ifdef ASCEND_310P + // Special operation for get_row_f32 kernel of 310P: clear the + // content of dest data buffer when row is not aligned to 32 bytes + if ((src0->ne[0] % 8) != 0) { + size_t dst_len = src1->ne[0] * src1->ne[1] * src1->ne[2] * + src0->ne[0] * ggml_type_size(GGML_TYPE_F32); + ACL_CHECK(aclrtMemset((char*)dst->data, dst_len, 0, dst_len)); + } +#endif aclrtlaunch_ascendc_get_row_f32( 24, ctx.stream(), src0->data, src1->data, dst->data, ((ggml_tensor*)src0->extra)->ne, @@ -2320,7 +2417,19 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne, ((ggml_tensor*)dst->extra)->nb); break; - case GGML_TYPE_F16: + } + case GGML_TYPE_F16: { +#ifdef ASCEND_310P + // Special operation for get_row_f16 kernel of 310P: clear the + // content of dest data buffer when row is not aligned to 32 bytes + if ((src0->ne[0] % 16) != 0) { + size_t dst_len = + src1->ne[0] * src1->ne[1] * src1->ne[2] * src0->ne[0] * + ggml_type_size( + GGML_TYPE_F32); // out is also f32, even input is f16 + ACL_CHECK(aclrtMemset((char*)dst->data, dst_len, 0, dst_len)); + } +#endif aclrtlaunch_ascendc_get_row_f16( 24, ctx.stream(), src0->data, src1->data, dst->data, ((ggml_tensor*)src0->extra)->ne, @@ -2329,6 +2438,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne, ((ggml_tensor*)dst->extra)->nb); break; + } case GGML_TYPE_Q4_0: aclrtlaunch_ascendc_get_row_q4_0( 24, ctx.stream(), src0->data, src1->data, dst->data, @@ -2407,7 +2517,6 @@ static void aclnn_mat_mul(ggml_backend_cann_context& ctx, aclTensor* acl_input, aclTensor* acl_weight, aclTensor* acl_dst) { int8_t cube_math_type = 1; // ALLOW_FP32_DOWN_PRECISION, when input is // fp32, atlas a2 will transpose it to HFLOAT32. - uint64_t workspaceSize = 0; aclOpExecutor* executor; void* workspaceAddr = nullptr; @@ -2425,6 +2534,81 @@ static void aclnn_mat_mul(ggml_backend_cann_context& ctx, aclTensor* acl_input, aclnnMatmul(workspaceAddr, workspaceSize, executor, ctx.stream())); } +/** + * @brief Performs matrix multiplication of two 2D tensors. + * + * This function computes the matrix multiplication of the input tensor + * `acl_input` and the weight tensor `acl_weight`, and stores the result in the + * destination tensor `acl_dst`. + * The operation is defined as: + * \f[ + * \text {acl_dst}=\text {acl_input@acl_weight} + * \f] + * + * @param ctx The context for the CANN backend operations. + * @param acl_input The input tensor for the matrix multiplication. + * @param acl_weight The weight tensor for the matrix multiplication. + * @param acl_dst The destination tensor where the result of the matrix + * multiplication will be stored. + */ +static void aclnn_mat_mul_2d(ggml_backend_cann_context& ctx, + aclTensor* acl_input, aclTensor* acl_weight, + aclTensor* acl_dst) { + int8_t cube_math_type = 2; + uint64_t workspaceSize = 0; + aclOpExecutor* executor; + void* workspaceAddr = nullptr; + + ACL_CHECK(aclnnMmGetWorkspaceSize(acl_input, acl_weight, acl_dst, + cube_math_type, &workspaceSize, + &executor)); + + if (workspaceSize > 0) { + ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); + workspaceAddr = workspace_allocator.get(); + } + + ACL_CHECK(aclnnMm(workspaceAddr, workspaceSize, executor, ctx.stream())); +} + +/** + * @brief Performs matrix multiplication of two 3D tensors. + * + * This function computes the matrix multiplication of the input tensor + * `acl_input` and the weight tensor `acl_weight`, and stores the result in the + * destination tensor `acl_dst`. + * The operation is defined as: + * \f[ + * \text {acl_dst}=\text {acl_input@acl_weight} + * \f] + * + * @param ctx The context for the CANN backend operations. + * @param acl_input The input tensor for the matrix multiplication. + * @param acl_weight The weight tensor for the matrix multiplication. + * @param acl_dst The destination tensor where the result of the matrix + * multiplication will be stored. + */ +static void aclnn_mat_mul_3d(ggml_backend_cann_context& ctx, + aclTensor* acl_input, aclTensor* acl_weight, + aclTensor* acl_dst) { + int8_t cube_math_type = 2; + uint64_t workspaceSize = 0; + aclOpExecutor* executor; + void* workspaceAddr = nullptr; + + ACL_CHECK(aclnnBatchMatMulGetWorkspaceSize(acl_input, acl_weight, acl_dst, + cube_math_type, &workspaceSize, + &executor)); + + if (workspaceSize > 0) { + ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); + workspaceAddr = workspace_allocator.get(); + } + + ACL_CHECK( + aclnnBatchMatMul(workspaceAddr, workspaceSize, executor, ctx.stream())); +} + /** * @brief Performs matrix multiplication with floating-point precision on * tensors using the CANN backend. @@ -2446,20 +2630,39 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx, // broadcast, when weight ne2 or ne3 is not 1, weight need repeat. BCAST_MUL_MAT_SHAPE(input, weight, dst); - // transpose weight: [1,2,3,4] -> [1,2,4,3] + int64_t n_dims = bcast_dims; + if (bcast_input_ne[3] == bcast_weight_ne[3] && bcast_input_ne[3] == 1) { + if (bcast_input_ne[2] == 1 && bcast_weight_ne[2] == 1) { + n_dims = 2; + } else if (bcast_input_ne[2] == 1) { + n_dims = 3; + } + } + + aclTensor* acl_input_tensor = + ggml_cann_create_tensor(input, bcast_input_ne, bcast_input_nb, n_dims); int64_t transpose_ne[] = {bcast_weight_ne[1], bcast_weight_ne[0], bcast_weight_ne[2], bcast_weight_ne[3], bcast_weight_ne[4], bcast_weight_ne[5]}; size_t transpose_nb[] = {bcast_weight_nb[1], bcast_weight_nb[0], bcast_weight_nb[2], bcast_weight_nb[3], bcast_weight_nb[4], bcast_weight_nb[5]}; - aclTensor* acl_weight_tensor = - ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, bcast_dims); - aclTensor* acl_input_tensor = - ggml_cann_create_tensor(input, BCAST_MUL_MAT_PARAM(input)); - aclTensor* acl_dst = ggml_cann_create_tensor(dst, BCAST_MUL_MAT_PARAM(dst)); - aclnn_mat_mul(ctx, acl_input_tensor, acl_weight_tensor, acl_dst); + ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims); + aclTensor* acl_dst = + ggml_cann_create_tensor(dst, bcast_dst_ne, bcast_dst_nb, n_dims); + + switch (n_dims) { + case 2: + aclnn_mat_mul_2d(ctx, acl_input_tensor, acl_weight_tensor, acl_dst); + break; + case 3: + aclnn_mat_mul_3d(ctx, acl_input_tensor, acl_weight_tensor, acl_dst); + break; + default: + aclnn_mat_mul(ctx, acl_input_tensor, acl_weight_tensor, acl_dst); + break; + } ACL_CHECK(aclDestroyTensor(acl_weight_tensor)); ACL_CHECK(aclDestroyTensor(acl_input_tensor)); @@ -2480,51 +2683,47 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx, * multiplication will be stored. */ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, - ggml_tensor* dst, - const enum ggml_type type) { + ggml_tensor* dst, + const enum ggml_type type) { ggml_tensor* src0 = dst->src[0]; // weight ggml_tensor* src1 = dst->src[1]; // input - // The shape of the weight is NCHW. Matrix multiplication uses HW dims. HC - // is regarded as batch. weight need transpose. - int64_t weight_ne[] = {src0->ne[1], src0->ne[0]}; + // The shape of the weight is NCHW. + // Matrix multiplication uses HW dims. + // HC is regarded as batch. + // weight need transpose. float weight_elem_size; if (type == GGML_TYPE_Q4_0) { weight_elem_size = float(sizeof(uint8_t)) / 2; - } - else if (type == GGML_TYPE_Q8_0) { + } else if (type == GGML_TYPE_Q8_0) { weight_elem_size = float(sizeof(uint8_t)); - } - else { + } else { GGML_ABORT("Only support Q4_0 and Q8_0 MUL_MAT"); } - float weight_nb[] = {weight_elem_size * src0->ne[0], weight_elem_size}; - - // size of one matrix is element_size * height * width. - size_t weight_stride = weight_elem_size * src0->ne[0] * src0->ne[1]; + float weight_nb[] = {src0->ne[0] * weight_elem_size, weight_elem_size}; + size_t weight_stride = src0->ne[1] * src0->ne[0] * weight_elem_size; size_t weight_size = weight_stride * src0->ne[2] * src0->ne[3]; // scale stored at the end of weight. Also need transpose. - GGML_ASSERT(QK4_0 == QK8_0); - int64_t scale_ne[] = {src0->ne[1], src0->ne[0] / QK8_0}; size_t scale_elem_size = sizeof(uint16_t); size_t scale_nb[] = {src0->ne[0] / QK8_0 * scale_elem_size, scale_elem_size}; - size_t scale_stride = scale_elem_size * src0->ne[0] * src0->ne[1] / QK8_0; + size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size; char* scale_offset = (char*)src0->data + weight_size; // input - void* input_buffer; size_t input_elem_size = sizeof(uint16_t); int64_t input_ne[] = {src1->ne[0], src1->ne[1]}; - size_t input_nb[] = {input_elem_size, input_elem_size * src1->ne[0]}; - size_t input_stride = input_elem_size * src1->ne[0] * src1->ne[1]; - + size_t input_nb[] = {input_elem_size, input_ne[0] * input_elem_size}; + size_t input_stride = input_ne[0] * input_ne[1] * input_elem_size; ggml_cann_pool_alloc input_alloctor(ctx.pool()); + void* input_buffer = src1->data; + + // case in if (src1->type != GGML_TYPE_F16) { aclTensor* acl_src1_tensor = ggml_cann_create_tensor(src1); - input_alloctor.alloc(ggml_nelements(src1) * input_elem_size); - input_buffer = input_alloctor.get(); + input_buffer = + input_alloctor.alloc(ggml_nelements(src1) * input_elem_size); int64_t* input_cast_ne = src1->ne; size_t input_cast_nb[GGML_MAX_DIMS]; @@ -2537,85 +2736,136 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, input_buffer, ACL_FLOAT16, input_elem_size, input_cast_ne, input_cast_nb, GGML_MAX_DIMS); aclnn_cast(ctx, acl_src1_tensor, acl_input_tensor, ACL_FLOAT16); + ACL_CHECK(aclDestroyTensor(acl_input_tensor)); ACL_CHECK(aclDestroyTensor(acl_src1_tensor)); - } else { - input_buffer = src1->data; } // output size_t output_elem_size = sizeof(uint16_t); - int64_t output_ne[] = {dst->ne[0], dst->ne[1]}; - size_t output_nb[] = {output_elem_size, output_elem_size * dst->ne[0]}; - ggml_cann_pool_alloc output_alloctor( - ctx.pool(), ggml_nelements(dst) * output_elem_size); - void* output_buffer = output_alloctor.get(); - size_t output_stride = output_elem_size * dst->ne[0] * dst->ne[1]; + size_t output_nb[] = {output_elem_size, dst->ne[0] * output_elem_size}; + ggml_cann_pool_alloc output_allocator(ctx.pool()); + void* output_buffer = + output_allocator.alloc(ggml_nelements(dst) * output_elem_size); + size_t output_stride = dst->ne[0] * dst->ne[1] * output_elem_size; // aclnn + int64_t max_elem_size = 65535; + int64_t split_size = (src0->ne[1] / max_elem_size) + 1; + ggml_cann_pool_alloc workspace_allocator(ctx.pool()); + aclOpExecutor* executor = nullptr; uint64_t workspaceSize = 0; - aclOpExecutor* executor; void* workspaceAddr = nullptr; - for (int64_t n1 = 0; n1 < src1->ne[3]; n1++) { for (int64_t c1 = 0; c1 < src1->ne[2]; c1++) { int64_t n0 = n1 / (src1->ne[3] / src0->ne[3]); int64_t c0 = c1 / (src1->ne[2] / src0->ne[2]); - int64_t batch1 = n1 * src1->ne[2] + c1; - int64_t batch0 = n0 * src0->ne[2] + c0; + int64_t batch1 = (n1 * src1->ne[2]) + c1; + int64_t batch0 = (n0 * src0->ne[2]) + c0; aclTensor* acl_input_tensor = ggml_cann_create_tensor( (char*)input_buffer + batch1 * input_stride, ACL_FLOAT16, input_elem_size, input_ne, input_nb, 2); + + // first split + int64_t weight_ne_offset = 0; + int64_t weight_ne[2] = { + max_elem_size > src0->ne[1] ? src0->ne[1] : max_elem_size, + src0->ne[0]}; + int64_t scale_ne_offset = 0; + int64_t scale_ne[2] = {weight_ne[0], weight_ne[1] / QK8_0}; + int64_t output_ne_offset = 0; + int64_t output_ne[2] = {weight_ne[0], dst->ne[1]}; + aclTensor* acl_weight_tensor = ggml_cann_create_tensor( (char*)src0->data + batch0 * weight_stride, ggml_cann_type_mapping(type), weight_elem_size, weight_ne, - weight_nb, 2); + weight_nb, 2, ACL_FORMAT_ND, weight_ne_offset); aclTensor* acl_scale_tensor = ggml_cann_create_tensor( scale_offset + batch0 * scale_stride, ACL_FLOAT16, - scale_elem_size, scale_ne, scale_nb, 2); + scale_elem_size, scale_ne, scale_nb, 2, ACL_FORMAT_ND, + scale_ne_offset); aclTensor* acl_output_tensor = ggml_cann_create_tensor( (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16, - output_elem_size, output_ne, output_nb, 2); + output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND, + output_ne_offset); ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize( acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr, nullptr, nullptr, nullptr, QK8_0, acl_output_tensor, &workspaceSize, &executor)); - - if (workspaceSize > 0 && workspaceAddr == nullptr) { - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), - workspaceSize); - workspaceAddr = workspace_allocator.get(); + if (workspaceAddr == nullptr) { + workspaceAddr = workspace_allocator.alloc(workspaceSize); } - ACL_CHECK(aclnnWeightQuantBatchMatmulV2( workspaceAddr, workspaceSize, executor, ctx.stream())); - ACL_CHECK(aclDestroyTensor(acl_input_tensor)); ACL_CHECK(aclDestroyTensor(acl_weight_tensor)); ACL_CHECK(aclDestroyTensor(acl_scale_tensor)); ACL_CHECK(aclDestroyTensor(acl_output_tensor)); + + // other splits + for (int64_t split = 1; split < split_size; split++) { + weight_ne_offset += + weight_elem_size * weight_ne[0] * weight_ne[1]; + weight_ne[0] = max_elem_size * (split + 1) > src0->ne[1] + ? src0->ne[1] - (max_elem_size * split) + : max_elem_size; + scale_ne_offset += scale_elem_size * scale_ne[0] * scale_ne[1]; + scale_ne[0] = weight_ne[0]; + output_ne_offset += + output_elem_size * output_ne[0] * output_ne[1]; + output_ne[0] = weight_ne[0]; + + acl_weight_tensor = ggml_cann_create_tensor( + (char*)src0->data + batch0 * weight_stride, + ggml_cann_type_mapping(type), weight_elem_size, weight_ne, + weight_nb, 2, ACL_FORMAT_ND, weight_ne_offset); + acl_scale_tensor = ggml_cann_create_tensor( + scale_offset + batch0 * scale_stride, ACL_FLOAT16, + scale_elem_size, scale_ne, scale_nb, 2, ACL_FORMAT_ND, + scale_ne_offset); + acl_output_tensor = ggml_cann_create_tensor( + (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16, + output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND, + output_ne_offset); + + ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize( + acl_input_tensor, acl_weight_tensor, acl_scale_tensor, + nullptr, nullptr, nullptr, nullptr, QK8_0, + acl_output_tensor, &workspaceSize, &executor)); + ACL_CHECK(aclnnWeightQuantBatchMatmulV2( + workspaceAddr, workspaceSize, executor, ctx.stream())); + + ACL_CHECK(aclDestroyTensor(acl_weight_tensor)); + ACL_CHECK(aclDestroyTensor(acl_scale_tensor)); + ACL_CHECK(aclDestroyTensor(acl_output_tensor)); + } + + ACL_CHECK(aclDestroyTensor(acl_input_tensor)); } } // cast out - int64_t* output_cast_ne = dst->ne; - size_t output_cast_nb[GGML_MAX_DIMS]; - output_cast_nb[0] = sizeof(uint16_t); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - output_cast_nb[i] = output_cast_nb[i - 1] * output_cast_ne[i - 1]; + if (dst->type != GGML_TYPE_F16) { + int64_t* output_cast_ne = dst->ne; + size_t output_cast_nb[GGML_MAX_DIMS]; + output_cast_nb[0] = sizeof(uint16_t); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + output_cast_nb[i] = output_cast_nb[i - 1] * output_cast_ne[i - 1]; + } + + aclTensor* acl_output_tensor = ggml_cann_create_tensor( + output_buffer, ACL_FLOAT16, output_elem_size, output_cast_ne, + output_cast_nb, GGML_MAX_DIMS); + aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); + aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, + ggml_cann_type_mapping(dst->type)); + + ACL_CHECK(aclDestroyTensor(acl_output_tensor)); + ACL_CHECK(aclDestroyTensor(acl_dst_tensor)); } - - aclTensor* acl_output_tensor = - ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, output_elem_size, - output_cast_ne, output_cast_nb, GGML_MAX_DIMS); - aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); - aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, ACL_FLOAT); - - ACL_CHECK(aclDestroyTensor(acl_output_tensor)); - ACL_CHECK(aclDestroyTensor(acl_dst_tensor)); } void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -2714,12 +2964,14 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx, static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, aclTensor* acl_cos_repeat_tensor, aclTensor* acl_sin_repeat_tensor, - float theta_scale, bool is_neox) { + float theta_scale, float freq_scale, + float attn_factor, bool is_neox) { // int sin/cos cache, cache has different repeat method depond on // @param.is_neox ggml_tensor* src0 = dst->src[0]; // input ggml_tensor* src1 = dst->src[1]; // position + ggml_tensor* src2 = dst->src[2]; // freq_factors // arange, [0,1,...,ne0/2] int64_t arange_length = src0->ne[0] / 2; @@ -2748,11 +3000,26 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_cann_pool_alloc theta_scale_allocator(ctx.pool(), arange_length * sizeof(float_t)); void* theta_scale_buffer = theta_scale_allocator.get(); - aclTensor* acl_theta_scale_tensor = aclnn_ones( + aclTensor* acl_theta_scale_tensor = aclnn_values( ctx, theta_scale_buffer, arange_length * sizeof(float_t), arange_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), theta_scale); aclnn_pow_tensor_tensor(ctx, acl_theta_scale_tensor, acl_arange_tensor); + // freq_scale + if (freq_scale != 1) { + aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true); + } + + // freq_factors + if (src2) { + aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor( + src2->data, ggml_cann_type_mapping(src2->type), + ggml_type_size(src2->type), arange_ne, arange_nb, GGML_MAX_DIMS); + aclnn_div_tensor(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor, + nullptr, true); + ACL_CHECK(aclDestroyTensor(acl_freq_factors_tensor)); + } + // position GGML_ASSERT(src1->type == GGML_TYPE_I32); int64_t position_length = src1->ne[0]; @@ -2816,6 +3083,12 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, GGML_MAX_DIMS, ACL_FORMAT_ND); aclnn_cos(ctx, acl_permute_tensor, acl_cos_tensor); + // attn_factor + if (attn_factor != 1) { + aclnn_muls(ctx, acl_sin_tensor, attn_factor, nullptr, true); + aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true); + } + // repeat if (is_neox) { int64_t repeatsArray[] = {1, 1, 1, 2}; @@ -2841,15 +3114,27 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, ACL_CHECK(aclDestroyTensor(acl_cos_tensor)); } +#ifdef __cplusplus +extern "C" { +#endif +aclnnStatus aclnnRotaryPositionEmbeddingGetWorkspaceSize( + const aclTensor* x, const aclTensor* cos, const aclTensor* sin, + int64_t mode, const aclTensor* yOut, uint64_t* workspaceSize, + aclOpExecutor** executor); +aclnnStatus aclnnRotaryPositionEmbedding(void* workspace, + uint64_t workspaceSize, + aclOpExecutor* executor, + aclrtStream stream); +#ifdef __cplusplus +} +#endif + void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // TODO: use ascendc // Only test with LLAMA model. ggml_tensor* src0 = dst->src[0]; // input ggml_tensor* src2 = dst->src[2]; // freq_factors - // TODO: with freq_factors - GGML_ASSERT(src2 == NULL); - // param float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; // const int n_past = ((int32_t *) dst->op_params)[0]; @@ -2867,13 +3152,11 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { memcpy(&beta_fast, (int32_t*)dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t*)dst->op_params + 10, sizeof(float)); - GGML_ASSERT(n_dims <= ne0); + // TODO: n_dims <= ne0 + GGML_ASSERT(n_dims == ne0); GGML_ASSERT(n_dims % 2 == 0); - // TODO: ext_factor != 0 GGML_ASSERT(ext_factor == 0); - // TODO: freq_scale != 1 - GGML_ASSERT(freq_scale == 1); const float theta_scale = powf(freq_base, -2.0f / n_dims); @@ -2904,7 +3187,13 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor, - theta_scale, is_neox); + theta_scale, freq_scale, attn_factor, is_neox); + + aclTensor* acl_src = ggml_cann_create_tensor(src0); + aclTensor* acl_dst = ggml_cann_create_tensor(dst); + +#ifdef ASCEND_310P + // Special ROPE operation for 310P // roll input void* input_roll_buffer; @@ -2947,7 +3236,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { for (int i = 1; i < GGML_MAX_DIMS; i++) { minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1]; } - acl_minus_one_tensor = aclnn_ones( + acl_minus_one_tensor = aclnn_values( ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0], minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1); int64_t dim = 3; @@ -2974,17 +3263,15 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor)); ACL_CHECK(aclDestroyTensor(acl_input_tensor)); - // init [-1, -1, -1, 1, 1,1,...] minus_one_scale_buffer = minus_one_scale_allocator.get(); - int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1}; size_t minus_one_nb[GGML_MAX_DIMS]; minus_one_nb[0] = sizeof(float_t); for (int i = 1; i < GGML_MAX_DIMS; i++) { minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1]; } - acl_minus_one_tensor = aclnn_ones( + acl_minus_one_tensor = aclnn_values( ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0], minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1); // -1 * first half @@ -3026,14 +3313,12 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { acl_input_roll_mul_scale_tensor); // output - aclTensor* acl_src0 = ggml_cann_create_tensor(src0); - aclTensor* acl_dst = ggml_cann_create_tensor(dst); void* output_fp32_buffer; if (src0->type == GGML_TYPE_F32) { - aclnn_inplace_mul(ctx, acl_src0, acl_cos_reshape_tensor); + aclnn_inplace_mul(ctx, acl_src, acl_cos_reshape_tensor); aclnn_inplace_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor); - aclnn_add(ctx, acl_src0, acl_input_roll_mul_scale_tensor, acl_dst); + aclnn_add(ctx, acl_src, acl_input_roll_mul_scale_tensor, acl_dst); // TODO: ne0 != n_dims in mode2 } else if (src0->type == GGML_TYPE_F16) { size_t input_fp32_nb[GGML_MAX_DIMS]; @@ -3060,7 +3345,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclTensor* output_fp32_tensor = ggml_cann_create_tensor( output_fp32_buffer, ACL_FLOAT, sizeof(float_t), dst->ne, input_fp32_nb, GGML_MAX_DIMS); - aclnn_mul(ctx, acl_src0, acl_cos_reshape_tensor, input_fp32_tensor1); + aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor, input_fp32_tensor1); aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor, input_fp32_tensor2); aclnn_add(ctx, input_fp32_tensor1, input_fp32_tensor2, @@ -3070,13 +3355,73 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ACL_CHECK(aclDestroyTensor(input_fp32_tensor1)); ACL_CHECK(aclDestroyTensor(input_fp32_tensor2)); ACL_CHECK(aclDestroyTensor(output_fp32_tensor)); + ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor)); + ACL_CHECK(aclDestroyTensor(acl_minus_one_tensor)); + ACL_CHECK(aclDestroyTensor(acl_input_roll_mul_scale_tensor)); + ACL_CHECK(aclDestroyTensor(acl_input_roll_reshape_tensor)); + ACL_CHECK(aclDestroyTensor(acl_src)); + } + return; +#endif + + // src0 == GGML_TYPE_F16 + // TODO: optimization this `if` code + if (src0->type == GGML_TYPE_F16) { + ggml_cann_pool_alloc sin_final_allocator( + ctx.pool(), src0->ne[0] * src0->ne[2] * ggml_type_size(src0->type)); + ggml_cann_pool_alloc cos_final_allocator( + ctx.pool(), src0->ne[0] * src0->ne[2] * ggml_type_size(src0->type)); + void* sin_final_buffer = sin_final_allocator.get(); + void* cos_final_buffer = cos_final_allocator.get(); + + int64_t sin_final_ne[4] = {src0->ne[0], 1, src0->ne[2], 1}; + size_t sin_final_nb[GGML_MAX_DIMS]; + sin_final_nb[0] = ggml_type_size(src0->type); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + sin_final_nb[i] = sin_final_nb[i - 1] * sin_final_ne[i - 1]; + } + aclTensor* acl_sin_final_tensor = ggml_cann_create_tensor( + sin_final_buffer, ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), sin_final_ne, sin_final_nb, + GGML_MAX_DIMS); + aclTensor* acl_cos_final_tensor = ggml_cann_create_tensor( + cos_final_buffer, ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), sin_final_ne, sin_final_nb, + GGML_MAX_DIMS); + + aclnn_cast(ctx, acl_sin_reshape_tensor, acl_sin_final_tensor, + ggml_cann_type_mapping(src0->type)); + aclnn_cast(ctx, acl_cos_reshape_tensor, acl_cos_final_tensor, + ggml_cann_type_mapping(src0->type)); + ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor)); + ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor)); + acl_sin_reshape_tensor = acl_sin_final_tensor; + acl_cos_reshape_tensor = acl_cos_final_tensor; } - ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor)); + uint64_t workspaceSize = 0; + aclOpExecutor* executor; + + void* workspaceAddr = nullptr; + + int acl_mode = mode; + if (mode == 0) { + acl_mode = 1; + } + + ACL_CHECK(aclnnRotaryPositionEmbeddingGetWorkspaceSize( + acl_src, acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode, + acl_dst, &workspaceSize, &executor)); + if (workspaceSize > 0) { + ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); + workspaceAddr = workspace_allocator.get(); + } + + ACL_CHECK(aclnnRotaryPositionEmbedding(workspaceAddr, workspaceSize, + executor, ctx.stream())); + + ACL_CHECK(aclDestroyTensor(acl_src)); ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor)); - ACL_CHECK(aclDestroyTensor(acl_minus_one_tensor)); - ACL_CHECK(aclDestroyTensor(acl_input_roll_mul_scale_tensor)); - ACL_CHECK(aclDestroyTensor(acl_input_roll_reshape_tensor)); - ACL_CHECK(aclDestroyTensor(acl_src0)); + ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor)); ACL_CHECK(aclDestroyTensor(acl_dst)); } diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index e6a570107..5164cb74e 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -211,22 +211,26 @@ struct ggml_cann_pool_alloc { struct ggml_backend_cann_context { int32_t device; /**< Device ID. */ std::string name; /**< Name of the device. */ + std::string description; /**< Description of the device. */ aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */ - aclrtStream streams[GGML_CANN_MAX_STREAMS] = { - {nullptr}}; /**< Array of streams for the device. */ + aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */ /** * @brief Constructor for initializing the context with a given device. * @param device Device ID. */ explicit ggml_backend_cann_context(int device) - : device(device), name("CANN" + std::to_string(device)) {} + : device(device), name("CANN" + std::to_string(device)) { + ggml_cann_set_device(device); + description = aclrtGetSocName(); + } /** * @brief Destructor for cleaning up resources. */ ~ggml_backend_cann_context() { + ggml_cann_set_device(device); if (copy_event != nullptr) { ACL_CHECK(aclrtDestroyEvent(copy_event)); } diff --git a/ggml/src/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp similarity index 77% rename from ggml/src/ggml-cann.cpp rename to ggml/src/ggml-cann/ggml-cann.cpp index 06930ba2e..d410c0244 100644 --- a/ggml/src/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -30,6 +30,7 @@ #include #include +#include "ggml-impl.h" #include "ggml-backend-impl.h" #include "ggml-cann/aclnn_ops.h" #include "ggml-cann/common.h" @@ -38,68 +39,7 @@ #include "ggml-common.h" -/** - * @brief Default logging callback for GGML. - * - * This function is the default logging callback that logs messages to stderr. - * - * @param level The log level. - * @param msg The log message. - * @param user_data User data passed to the callback. - */ -static void ggml_cann_default_log_callback(enum ggml_log_level level, - const char* msg, void* user_data) { - GGML_UNUSED(level); - GGML_UNUSED(user_data); - fprintf(stderr, "%s", msg); -} - -ggml_log_callback ggml_cann_log_callback = ggml_cann_default_log_callback; -void* ggml_cann_log_user_data = NULL; - -GGML_API void ggml_backend_cann_log_set_callback(ggml_log_callback log_callback, - void* user_data) { - ggml_cann_log_callback = log_callback; - ggml_cann_log_user_data = user_data; -} - -#define GGML_CANN_LOG_INFO(...) ggml_cann_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__) -#define GGML_CANN_LOG_WARN(...) ggml_cann_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__) -#define GGML_CANN_LOG_ERROR(...) \ - ggml_cann_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) - -GGML_ATTRIBUTE_FORMAT(2, 3) - -/** - * @brief Log a message using the current logging callback. - * - * This function formats a log message and passes it to the current logging - * callback. - * - * @param level The log level. - * @param format The format string for the log message. - * @param ... The arguments for the format string. - */ -static void ggml_cann_log(enum ggml_log_level level, const char* format, ...) { - if (ggml_cann_log_callback != NULL) { - va_list args; - va_start(args, format); - char buffer[128]; - int len = vsnprintf(buffer, 128, format, args); - if (len < 128) { - ggml_cann_log_callback(level, buffer, ggml_cann_log_user_data); - } else { - // vsnprintf adds a null terminator - std::vector buffer2(len + 1); - va_end(args); - va_start(args, format); - vsnprintf(&buffer2[0], buffer2.size(), format, args); - ggml_cann_log_callback(level, buffer2.data(), - ggml_cann_log_user_data); - } - va_end(args); - } -} +#define GGML_CANN_NAME "CANN" /** * @brief Handles CANN errors by printing an error message and aborting. @@ -115,10 +55,10 @@ static void ggml_cann_log(enum ggml_log_level level, const char* format, ...) { int32_t id = -1; aclrtGetDevice(&id); - GGML_CANN_LOG_ERROR("CANN error: %s\n", msg); - GGML_CANN_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, + GGML_LOG_ERROR("CANN error: %s\n", msg); + GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line); - GGML_CANN_LOG_ERROR(" %s\n", stmt); + GGML_LOG_ERROR(" %s\n", stmt); // abort with GGML_ASSERT to get a stack trace GGML_ABORT("CANN error"); } @@ -164,7 +104,7 @@ static ggml_cann_device_info ggml_cann_init() { aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count); if (err != ACL_SUCCESS) { - GGML_CANN_LOG_ERROR("%s: failed to initialize CANN: %s\n", + GGML_LOG_ERROR("%s: failed to initialize CANN: %s\n", __func__, aclGetRecentErrMsg()); return info; } @@ -182,6 +122,10 @@ static ggml_cann_device_info ggml_cann_init() { ACL_CHECK(aclrtMemGetAllocationGranularity( &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED, &info.devices[id].vmm_granularity)); + + size_t free, total; + ggml_backend_cann_get_device_memory(id, &free, &total); + info.devices[id].total_vram = free; } // TODO: add more device info later. @@ -268,6 +212,11 @@ struct ggml_cann_pool_leg : public ggml_cann_pool { * @return A pointer to the allocated buffer. */ void* alloc(size_t size, size_t* actual_size) override { + const size_t alignment = 128; + size = GGML_PAD(size, alignment); + if (size == 0) { + size = alignment; + } #ifdef DEBUG_CANN_MALLOC int nnz = 0; size_t max_size = 0; @@ -306,15 +255,13 @@ struct ggml_cann_pool_leg : public ggml_cann_pool { return ptr; } void* ptr; - size_t look_ahead_size = (size_t)(1.05 * size); - look_ahead_size = 256 * ((look_ahead_size + 255) / 256); ggml_cann_set_device(device); ACL_CHECK( - aclrtMalloc(&ptr, look_ahead_size, ACL_MEM_MALLOC_HUGE_FIRST)); - *actual_size = look_ahead_size; - pool_size += look_ahead_size; + aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST)); + *actual_size = size; + pool_size += size; #ifdef DEBUG_CANN_MALLOC - GGML_CANN_LOG_INFO( + GGML_LOG_INFO( "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, " "requested %u MB\n", __func__, device, nnz, (uint32_t)(max_size / 1024 / 1024), @@ -356,7 +303,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { /** * @brief The maximum size of the virtual memory pool (32 GB). */ - static const size_t CANN_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB + size_t max_size; /** * @brief The device ID associated with this buffer pool. @@ -401,7 +348,11 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { */ explicit ggml_cann_pool_vmm(int device) : device(device), - granularity(ggml_cann_info().devices[device].vmm_granularity) {} + granularity(ggml_cann_info().devices[device].vmm_granularity) { + auto dev = ggml_cann_info().devices[device]; + granularity = dev.vmm_granularity; + max_size = dev.total_vram; + } /** * @brief Destructor to free all buffers in the virtual memory pool. @@ -430,17 +381,19 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { // round up the allocation size to the alignment to ensure that all // allocations are aligned for all data types const size_t alignment = 128; - size = alignment * ((size + alignment - 1) / alignment); + size = GGML_PAD(size, alignment); + if (size == 0) { + size = alignment; + } size_t avail = pool_size - pool_used; if (size > avail) { // round up to the next multiple of the granularity size_t reserve_size = size - avail; - reserve_size = - granularity * ((reserve_size + granularity - 1) / granularity); + reserve_size = GGML_PAD(reserve_size, granularity); - GGML_ASSERT(pool_size + reserve_size <= CANN_POOL_VMM_MAX_SIZE); + GGML_ASSERT(pool_size + reserve_size <= max_size); // allocate more physical memory aclrtPhysicalMemProp prop = {}; @@ -456,7 +409,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { // reserve virtual address space (if not already reserved) if (pool_addr == 0) { ACL_CHECK(aclrtReserveMemAddress( - &pool_addr, CANN_POOL_VMM_MAX_SIZE, 0, NULL, 1)); + &pool_addr, max_size, 0, NULL, 1)); } // map at the end of the pool @@ -469,10 +422,11 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { // add to the pool pool_size += reserve_size; - // GGML_CANN_LOG_INFO("cann pool[%d]: size increased to %llu MB ( - // reserved %llu MB)\n", - // device, (unsigned long long) (pool_size/1024/1024), - // (unsigned long long) (reserve_size/1024/1024)); +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (reserved %llu MB)\n", + device, (unsigned long long) (pool_size/1024/1024), + (unsigned long long) (reserve_size/1024/1024)); +#endif } GGML_ASSERT(pool_addr != 0); @@ -482,7 +436,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { pool_used += size; #ifdef DEBUG_CANN_MALLOC - GGML_CANN_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device, + GGML_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device, (unsigned long long)size, (unsigned long long)ptr); #endif return ptr; @@ -496,7 +450,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { */ void free(void* ptr, size_t size) override { #ifdef DEBUG_CANN_MALLOC - GGML_CANN_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device, + GGML_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device, (unsigned long long)size, (unsigned long long)ptr); #endif @@ -517,7 +471,6 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { */ std::unique_ptr ggml_backend_cann_context::new_pool_for_device( int device) { - // return std::unique_ptr(new ggml_cann_pool_leg(device)); return std::unique_ptr(new ggml_cann_pool_vmm(device)); } @@ -549,23 +502,6 @@ struct ggml_backend_cann_buffer_context { ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); } }; -/** - * @brief Retrieve the name associated with a CANN buffer. - * - * This function returns the name of a CANN buffer, which is stored in the - * context of the buffer. - * - * @param buffer The CANN buffer whose name is to be retrieved. - * @return A pointer to a C-string containing the name of the buffer. - */ - -GGML_CALL static const char* ggml_backend_cann_buffer_get_name( - ggml_backend_buffer_t buffer) { - return "CANN"; - - GGML_UNUSED(buffer); -} - /** * @brief Check if a buffer is a CANN buffer. * @@ -575,9 +511,10 @@ GGML_CALL static const char* ggml_backend_cann_buffer_get_name( * @param buffer The buffer to check. * @return true if the buffer is a CANN buffer, false otherwise. */ -GGML_CALL static bool ggml_backend_buffer_is_cann( +static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft); +static bool ggml_backend_buffer_is_cann( ggml_backend_buffer_t buffer) { - return buffer->iface.get_name == ggml_backend_cann_buffer_get_name; + return ggml_backend_buft_is_cann(buffer->buft); } /** @@ -588,7 +525,7 @@ GGML_CALL static bool ggml_backend_buffer_is_cann( * * @param buffer The CANN buffer to free. */ -GGML_CALL static void ggml_backend_cann_buffer_free_buffer( +static void ggml_backend_cann_buffer_free_buffer( ggml_backend_buffer_t buffer) { ggml_backend_cann_buffer_context* ctx = (ggml_backend_cann_buffer_context*)buffer->context; @@ -604,7 +541,7 @@ GGML_CALL static void ggml_backend_cann_buffer_free_buffer( * @param buffer The CANN buffer whose base pointer is to be retrieved. * @return A pointer to the base of the device memory allocated for the buffer. */ -GGML_CALL static void* ggml_backend_cann_buffer_get_base( +static void* ggml_backend_cann_buffer_get_base( ggml_backend_buffer_t buffer) { ggml_backend_cann_buffer_context* ctx = (ggml_backend_cann_buffer_context*)buffer->context; @@ -624,9 +561,9 @@ GGML_CALL static void* ggml_backend_cann_buffer_get_base( * @param dst Pointer to the destination buffer where transformed data will be * stored. */ -GGML_CALL static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor, - const void* src, - void* dst) { +static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor, + const void* src, + void* dst) { int64_t n_elems = ggml_nelements(tensor); int64_t groups = n_elems / QK4_0; @@ -676,7 +613,7 @@ GGML_CALL static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor, * @param dst Pointer to the destination buffer where the Q4.0 formatted data * will be stored. */ -GGML_CALL static void ggml_backend_cann_transform_back_q4_0( +static void ggml_backend_cann_transform_back_q4_0( const ggml_tensor* tensor, void* src, void* dst) { int64_t n_elems = ggml_nelements(tensor); @@ -725,9 +662,9 @@ GGML_CALL static void ggml_backend_cann_transform_back_q4_0( * @param dst Pointer to the destination buffer where transformed data will be * stored. */ -GGML_CALL static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor, - const void* src, - void* dst) { +static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor, + const void* src, + void* dst) { int64_t n_elems = ggml_nelements(tensor); int64_t groups = n_elems / QK8_0; size_t quant_bytes = n_elems * sizeof(uint8_t); @@ -759,7 +696,7 @@ GGML_CALL static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor, * @param dst Pointer to the destination buffer where the Q8.0 formatted data * will be stored. */ -GGML_CALL static void ggml_backend_cann_transform_back_q8_0( +static void ggml_backend_cann_transform_back_q8_0( const ggml_tensor* tensor, const void* src, void* dst) { int64_t n_elems = ggml_nelements(tensor); int64_t groups = n_elems / QK8_0; @@ -791,8 +728,8 @@ GGML_CALL static void ggml_backend_cann_transform_back_q8_0( * @param dst Pointer to the destination buffer where transformed data will be * stored. */ -GGML_CALL static void ggml_backend_cann_transform(ggml_tensor* tensor, - const void* src, void* dst) { +static void ggml_backend_cann_transform(ggml_tensor* tensor, + const void* src, void* dst) { switch (tensor->type) { case GGML_TYPE_Q4_0: ggml_backend_cann_transform_q4_0(tensor, src, dst); @@ -817,7 +754,7 @@ GGML_CALL static void ggml_backend_cann_transform(ggml_tensor* tensor, * @param dst Pointer to the destination buffer where transformed tensor data * will be stored. */ -GGML_CALL static void ggml_backend_cann_transform_back( +static void ggml_backend_cann_transform_back( const ggml_tensor* tensor, void* src, void* dst) { switch (tensor->type) { case GGML_TYPE_Q4_0: @@ -840,7 +777,7 @@ GGML_CALL static void ggml_backend_cann_transform_back( * @param type The tensor type to check. * @return true if transformation is needed, false otherwise. */ -GGML_CALL static bool need_transform(ggml_type type) { +static bool need_transform(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: @@ -859,7 +796,7 @@ GGML_CALL static bool need_transform(ggml_type type) { * @param buffer The CANN buffer from which to initialize the tensor. * @param tensor Pointer to the tensor to be initialized. */ -GGML_CALL static void ggml_backend_cann_buffer_init_tensor( +static void ggml_backend_cann_buffer_init_tensor( ggml_backend_buffer_t buffer, ggml_tensor* tensor) { if (tensor->view_src != NULL && tensor->view_offs == 0) { GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); @@ -895,7 +832,7 @@ GGML_CALL static void ggml_backend_cann_buffer_init_tensor( * @param offset Offset in the source data from where to start copying. * @param size Size of the data to be copied, in bytes. */ -GGML_CALL static void ggml_backend_cann_buffer_set_tensor( +static void ggml_backend_cann_buffer_set_tensor( ggml_backend_buffer_t buffer, ggml_tensor *tensor, const void *data, size_t offset, size_t size) { ggml_backend_cann_buffer_context *ctx = @@ -913,13 +850,6 @@ GGML_CALL static void ggml_backend_cann_buffer_set_tensor( void *transform_buffer = malloc(size); ggml_backend_cann_transform(tensor, data, transform_buffer); -#ifndef NDEBUG - void *check_buffer = malloc(size); - ggml_backend_cann_transform_back(tensor, transform_buffer, - check_buffer); - GGML_ASSERT(memcmp(data, check_buffer, size) == 0); - free(check_buffer); -#endif ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE)); @@ -940,7 +870,7 @@ GGML_CALL static void ggml_backend_cann_buffer_set_tensor( * @param offset Offset in the destination buffer where to start copying. * @param size Size of the data to be copied, in bytes. */ -GGML_CALL static void ggml_backend_cann_buffer_get_tensor( +static void ggml_backend_cann_buffer_get_tensor( ggml_backend_buffer_t buffer, const ggml_tensor* tensor, void* data, size_t offset, size_t size) { ggml_backend_cann_buffer_context* ctx = @@ -974,7 +904,7 @@ GGML_CALL static void ggml_backend_cann_buffer_get_tensor( * @param dst Pointer to the destination tensor where the data will be copied. * @return true if the copy operation succeeded, false otherwise. */ -GGML_CALL static bool ggml_backend_cann_buffer_cpy_tensor( +static bool ggml_backend_cann_buffer_cpy_tensor( ggml_backend_buffer_t buffer, const ggml_tensor* src, ggml_tensor* dst) { if (ggml_backend_buffer_is_cann(src->buffer)) { ggml_backend_cann_buffer_context* src_ctx = @@ -1016,7 +946,7 @@ GGML_CALL static bool ggml_backend_cann_buffer_cpy_tensor( * @param buffer The CANN buffer to be cleared. * @param value The value to which each byte in the buffer will be set. */ -GGML_CALL static void ggml_backend_cann_buffer_clear( +static void ggml_backend_cann_buffer_clear( ggml_backend_buffer_t buffer, uint8_t value) { ggml_backend_cann_buffer_context* ctx = (ggml_backend_cann_buffer_context*)buffer->context; @@ -1031,11 +961,11 @@ GGML_CALL static void ggml_backend_cann_buffer_clear( * This structure defines function pointers to operations that can be performed * on a CANN buffer within the backend. */ -static ggml_backend_buffer_i ggml_backend_cann_buffer_interface = { - /* .get_name = */ ggml_backend_cann_buffer_get_name, +static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = { /* .free_buffer = */ ggml_backend_cann_buffer_free_buffer, /* .get_base = */ ggml_backend_cann_buffer_get_base, /* .init_tensor = */ ggml_backend_cann_buffer_init_tensor, + /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_cann_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cann_buffer_get_tensor, /* .cpy_tensor = */ ggml_backend_cann_buffer_cpy_tensor, @@ -1063,11 +993,12 @@ struct ggml_backend_cann_buffer_type_context { * @param buft Pointer to the buffer type context. * @return Const pointer to the C-style string containing the name. */ -GGML_CALL static const char* ggml_backend_cann_buffer_type_name( +static const char* ggml_backend_cann_buffer_type_name( ggml_backend_buffer_type_t buft) { - return "CANN"; + ggml_backend_cann_buffer_type_context* buft_ctx = + (ggml_backend_cann_buffer_type_context*)buft->context; - GGML_UNUSED(buft); + return buft_ctx->name.c_str(); } /** @@ -1080,7 +1011,7 @@ GGML_CALL static const char* ggml_backend_cann_buffer_type_name( * @param size Size in bytes of the buffer to allocate. * @return Pointer to the allocated buffer, or nullptr if allocation fails. */ -GGML_CALL static ggml_backend_buffer_t +static ggml_backend_buffer_t ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { ggml_backend_cann_buffer_type_context* buft_ctx = @@ -1093,7 +1024,7 @@ ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, void* dev_ptr; aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST); if (err != ACL_SUCCESS) { - GGML_CANN_LOG_ERROR( + GGML_LOG_ERROR( "%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, aclGetRecentErrMsg()); @@ -1119,7 +1050,7 @@ ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, * @return The alignment requirement in bytes (fixed at 128 bytes for CANN * buffers). */ -GGML_CALL static size_t ggml_backend_cann_buffer_type_get_alignment( +static size_t ggml_backend_cann_buffer_type_get_alignment( ggml_backend_buffer_type_t buft) { return 128; @@ -1140,7 +1071,7 @@ GGML_CALL static size_t ggml_backend_cann_buffer_type_get_alignment( * @return The total allocation size in bytes required for the tensor in the * CANN buffer. */ -GGML_CALL static size_t ggml_backend_cann_buffer_type_get_alloc_size( +static size_t ggml_backend_cann_buffer_type_get_alloc_size( ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) { size_t size = ggml_nbytes(tensor); int64_t ne0 = tensor->ne[0]; @@ -1166,19 +1097,25 @@ GGML_CALL static size_t ggml_backend_cann_buffer_type_get_alloc_size( GGML_UNUSED(buft); } +static bool ggml_backend_cann_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + return false; + + GGML_UNUSED(buft); +} + /** * @brief Interface for managing CANN buffer types in the GGML backend. * * Provides function pointers for allocating, querying properties, and managing * memory for CANN buffer types in the GGML backend. */ -static ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = { +static const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = { /* .get_name = */ ggml_backend_cann_buffer_type_name, /* .alloc_buffer = */ ggml_backend_cann_buffer_type_alloc_buffer, /* .get_alignment = */ ggml_backend_cann_buffer_type_get_alignment, /* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_alloc_size = */ ggml_backend_cann_buffer_type_get_alloc_size, - /* .is_host = */ NULL, + /* .is_host = */ ggml_backend_cann_buffer_type_is_host, }; /** @@ -1191,7 +1128,7 @@ static ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = { * @return A pointer to the buffer type interface for the specified device, or * nullptr if the device index is out of range. */ -GGML_CALL ggml_backend_buffer_type_t +ggml_backend_buffer_type_t ggml_backend_cann_buffer_type(int32_t device) { static std::mutex mutex; std::lock_guard lock(mutex); @@ -1206,9 +1143,10 @@ ggml_backend_cann_buffer_type(int32_t device) { static bool ggml_backend_cann_buffer_type_initialized = false; if (!ggml_backend_cann_buffer_type_initialized) { - for (int32_t i = 0; i < GGML_CANN_MAX_DEVICES; i++) { + for (int32_t i = 0; i < ggml_cann_info().device_count; i++) { ggml_backend_cann_buffer_types[i] = { /* .iface = */ ggml_backend_cann_buffer_type_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), i), /* .context = */ new ggml_backend_cann_buffer_type_context{ i, "CANN" + std::to_string(i)}, @@ -1220,6 +1158,121 @@ ggml_backend_cann_buffer_type(int32_t device) { return &ggml_backend_cann_buffer_types[device]; } +/** + * @brief Retrieves the name associated with a CANN host buffer type. + * + * This function returns the descriptive name associated with the specified + * CANN host buffer type context. + * + * @param buft Pointer to the host buffer type context. + * @return Const pointer to the C-style string containing the name. + */ +static const char * ggml_backend_cann_host_buffer_type_name(ggml_backend_buffer_type_t buft) { + return "CANN_Host"; + + GGML_UNUSED(buft); +} + +/** + * @brief Retrieves the name associated with a CANN host buffer. + * + * This function returns the descriptive name associated with the specified + * CANN host buffer context. + * + * @param buft Pointer to the host buffer context. + * @return Const pointer to the C-style string containing the name. + */ +static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buffer) { + return "CANN_Host"; + + GGML_UNUSED(buffer); +} + +/** + * @brief Free resources associated with a CANN host buffer. + * + * This function frees the resources associated with a CANN host buffer, including + * its context. + * + * @param buffer The CANN host buffer to free. + */ +static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) { + ACL_CHECK(aclrtFreeHost(buffer->context)); +} + +/** + * @brief Allocates a new CANN host buffer of the specified size. + * + * This function allocates a new CANN host buffer with the given size. + * @param size Size in bytes of the host buffer to allocate. + * @return Pointer to the allocated host buffer, or nullptr if allocation fails. + */ +static void * ggml_cann_host_malloc(size_t size) { + if (getenv("GGML_CANN_NO_PINNED") != nullptr) { + return nullptr; + } + + const size_t alignment = 128; + size = GGML_PAD(size, alignment); + if (size == 0) { + size = alignment; + } + + void * hostPtr = nullptr; + aclError err = aclrtMallocHost((void **) &hostPtr, size); + if (err != ACL_SUCCESS) { + GGML_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__, + size / 1024.0 / 1024.0, aclGetRecentErrMsg()); + return nullptr; + } + return hostPtr; +} + +/** + * @brief Allocates a new CANN host buffer of the specified type and size. + * + * @param buft Pointer to the host buffer type context. + * @param size Size in bytes of the host buffer to allocate. + * @return Pointer to the allocated host buffer, or CPU buffer pointer if allocation fails. + */ +static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + void * hostPtr = ggml_cann_host_malloc(size); + + if (hostPtr == nullptr) { + // fallback to cpu buffer + return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); + } + + ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(hostPtr, size); + buffer->buft = buft; + buffer->iface.free_buffer = ggml_backend_cann_host_buffer_free; + + return buffer; +} + +/** + * @brief Interface for managing CANN host buffer types in the GGML backend. + * + * Provides function pointers for allocating, querying properties, and managing + * memory for CANN buffer types in the GGML backend. + */ +ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() { + static struct ggml_backend_buffer_type ggml_backend_cann_buffer_type_host = { + /* .iface = */ { + /* .get_name = */ ggml_backend_cann_host_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_cann_host_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, + /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0), + /* .context = */ nullptr, + }; + + return &ggml_backend_cann_buffer_type_host; +} + /** * @brief Computes the forward operation for a given tensor using CANN * operations. @@ -1383,7 +1436,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, * @param backend Pointer to the CANN backend structure. * @return A pointer to a constant string representing the backend name. */ -GGML_CALL static const char* ggml_backend_cann_name(ggml_backend_t backend) { +static const char* ggml_backend_cann_name(ggml_backend_t backend) { ggml_backend_cann_context* cann_ctx = (ggml_backend_cann_context*)backend->context; @@ -1398,7 +1451,7 @@ GGML_CALL static const char* ggml_backend_cann_name(ggml_backend_t backend) { * * @param backend Pointer to the CANN backend structure to be freed. */ -GGML_CALL static void ggml_backend_cann_free(ggml_backend_t backend) { +static void ggml_backend_cann_free(ggml_backend_t backend) { ggml_backend_cann_context* cann_ctx = (ggml_backend_cann_context*)backend->context; ACL_CHECK(aclrtSynchronizeDevice()); @@ -1413,24 +1466,6 @@ GGML_CALL static void ggml_backend_cann_free(ggml_backend_t backend) { delete backend; } -/** - * @brief Retrieves the default buffer type associated with the CANN backend. - * - * This function returns the buffer type specific to the device associated - * with the CANN backend. It is used to allocate buffers for computations - * performed by the backend. - * - * @param backend Pointer to the CANN backend structure. - * @return Pointer to the buffer type structure for the CANN backend. - */ -GGML_CALL static ggml_backend_buffer_type_t -ggml_backend_cann_get_default_buffer_type(ggml_backend_t backend) { - ggml_backend_cann_context* cann_ctx = - (ggml_backend_cann_context*)backend->context; - - return ggml_backend_cann_buffer_type(cann_ctx->device); -} - /** * @brief Sets tensor data asynchronously in the CANN backend. * @@ -1444,11 +1479,11 @@ ggml_backend_cann_get_default_buffer_type(ggml_backend_t backend) { * @param offset Offset in bytes within the host data. * @param size Size of the data to copy in bytes. */ -GGML_CALL static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend, - ggml_tensor *tensor, - const void *data, - size_t offset, - size_t size) { +static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend, + ggml_tensor *tensor, + const void *data, + size_t offset, + size_t size) { ggml_backend_cann_context *cann_ctx = (ggml_backend_cann_context *)backend->context; @@ -1460,13 +1495,6 @@ GGML_CALL static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend, void *transform_buffer = malloc(size); ggml_backend_cann_transform(tensor, data, transform_buffer); -#ifndef NDEBUG - void *check_buffer = malloc(size); - ggml_backend_cann_transform_back(tensor, transform_buffer, - check_buffer); - GGML_ASSERT(memcmp(data, check_buffer, size)); - free(check_buffer); -#endif ACL_CHECK(aclrtMemcpyAsync( (char *)tensor->data + offset, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream())); @@ -1475,7 +1503,7 @@ GGML_CALL static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend, } } -GGML_CALL static void ggml_backend_cann_get_tensor_async( +static void ggml_backend_cann_get_tensor_async( ggml_backend_t backend, const ggml_tensor *tensor, void *data, size_t offset, size_t size) { ggml_backend_cann_context *cann_ctx = @@ -1514,7 +1542,7 @@ GGML_CALL static void ggml_backend_cann_get_tensor_async( * @param dst Pointer to the destination tensor to copy data to. * @return true if the copy operation succeeds, false otherwise. */ -GGML_CALL static bool ggml_backend_cann_cpy_tensor_async( +static bool ggml_backend_cann_cpy_tensor_async( ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor* src, ggml_tensor* dst) { GGML_ASSERT(ggml_backend_is_cann(backend_src) || @@ -1582,7 +1610,7 @@ GGML_CALL static bool ggml_backend_cann_cpy_tensor_async( * * @param backend Pointer to the CANN backend structure to synchronize. */ -GGML_CALL static void ggml_backend_cann_synchronize(ggml_backend_t backend) { +static void ggml_backend_cann_synchronize(ggml_backend_t backend) { ggml_backend_cann_context* cann_ctx = (ggml_backend_cann_context*)backend->context; @@ -1603,7 +1631,7 @@ GGML_CALL static void ggml_backend_cann_synchronize(ggml_backend_t backend) { * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation * completes successfully, otherwise an appropriate error status. */ -GGML_CALL static enum ggml_status ggml_backend_cann_graph_compute( +static enum ggml_status ggml_backend_cann_graph_compute( ggml_backend_t backend, ggml_cgraph* cgraph) { ggml_backend_cann_context* cann_ctx = (ggml_backend_cann_context*)backend->context; @@ -1620,7 +1648,7 @@ GGML_CALL static enum ggml_status ggml_backend_cann_graph_compute( bool ok = ggml_cann_compute_forward(*cann_ctx, node); if (!ok) { - GGML_CANN_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, + GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); } GGML_ASSERT(ok); @@ -1641,7 +1669,7 @@ GGML_CALL static enum ggml_status ggml_backend_cann_graph_compute( * @return bool Returns true if the operation is supported by the backend, * otherwise false. */ -GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend, +static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_tensor* op) { switch (op->op) { case GGML_OP_UNARY: @@ -1659,12 +1687,14 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend, } case GGML_OP_MUL_MAT: { switch (op->src[0]->type) { + case GGML_TYPE_Q8_0: + // Current groupsize should not be greater than k-1 in + // aclnnWeightQuantBatchMatmulV2GetWorkspaceSize + if (op->src[0]->ne[0] <= QK8_0) { + return false; + } case GGML_TYPE_F16: case GGML_TYPE_F32: - case GGML_TYPE_Q8_0: - // TODO: fix me - // Current groupsize should not be greater than k-1 in - // aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(). case GGML_TYPE_Q4_0: return true; default: @@ -1696,9 +1726,50 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend, return false; } } + case GGML_OP_CONT: { + // TODO: support GGML_TYPE_BF16 + switch (op->src[0]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + return true; + default: + return false; + } + } + case GGML_OP_ROPE: { + // TODO: with ops-test v == 1 + float * ext_factor = (float*)((int32_t*)op->op_params + 7); + // TODO: n_dims <= ne0 + if (op->src[0]->ne[0] != op->op_params[1]) { + return false; + } + // TODO: ext_factor != 0 + if (*ext_factor != 0) { + return false; + } + + const int mode = ((const int32_t *) op->op_params)[2]; + if (mode & GGML_ROPE_TYPE_MROPE) { + return false; + } + if (mode & GGML_ROPE_TYPE_VISION) { + return false; + } + + return true; + } + case GGML_OP_UPSCALE: { + // aclnnUpsampleNearest2dGetWorkspaceSize not support + // selfDimN[2]/outDimN[2] or selfDimC[3]/outDimC[3] not equal + if (op->src[0]->ne[2] * op->ne[3] != op->src[0]->ne[3] * op->ne[2]) { + return false; + } + return true; + } + case GGML_OP_IM2COL: + case GGML_OP_CONCAT: case GGML_OP_DUP: case GGML_OP_REPEAT: - case GGML_OP_CONCAT: case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -1712,17 +1783,13 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend, case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_CLAMP: - case GGML_OP_CONT: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: - case GGML_OP_ROPE: - case GGML_OP_IM2COL: case GGML_OP_POOL_2D: case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: case GGML_OP_ACC: case GGML_OP_GROUP_NORM: - case GGML_OP_UPSCALE: case GGML_OP_PAD: case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: @@ -1732,7 +1799,7 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend, return false; } - GGML_UNUSED(backend); + GGML_UNUSED(dev); } /** @@ -1750,31 +1817,6 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) { return buft->iface.get_name == ggml_backend_cann_buffer_type_name; } -/** - * @brief Checks if the CANN backend supports a specific backend buffer type. - * - * This function determines whether the CANN backend supports the given backend - * buffer type by comparing the device context of the backend and buffer type. - * It returns true if the devices are same between the backend context and - * buffer type context. - * - * @param backend Pointer to the CANN backend. - * @param buft Pointer to the backend buffer type to check. - * @return bool Returns true if the CANN backend supports the buffer type, - * otherwise false. - */ -GGML_CALL static bool ggml_backend_cann_supports_buft( - ggml_backend_t backend, ggml_backend_buffer_type_t buft) { - if (ggml_backend_buft_is_cann(buft)) { - ggml_backend_cann_context * cann_ctx = - (ggml_backend_cann_context *)backend->context; - ggml_backend_cann_buffer_type_context * buft_ctx = - (ggml_backend_cann_buffer_type_context *)buft->context; - return buft_ctx->device == cann_ctx->device; - } - return false; -} - /** * @brief Determines if a tensor operation should be offloaded to the CANN * backend. @@ -1789,54 +1831,14 @@ GGML_CALL static bool ggml_backend_cann_supports_buft( * @return bool Returns true if the operation should be offloaded, otherwise * false. */ -GGML_CALL static bool ggml_backend_cann_offload_op(ggml_backend_t backend, +static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev, const ggml_tensor* op) { const int min_batch_size = 32; - GGML_UNUSED(backend); + GGML_UNUSED(dev); return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS; } -/** - * @brief Creates a new event for the CANN backend. - * - * This function initializes a new event for the CANN backend by setting the - * device and creating an ACL runtime event. The created event is then wrapped - * in a ggml_backend_event structure and returned. - * - * @param backend Pointer to the CANN backend. - * @return ggml_backend_event_t Returns a pointer to the new event structure. - */ -static ggml_backend_event_t ggml_backend_cann_event_new( - ggml_backend_t backend) { - ggml_backend_cann_context* cann_ctx = - (ggml_backend_cann_context*)backend->context; - - ggml_cann_set_device(cann_ctx->device); - - aclrtEvent event; - ACL_CHECK(aclrtCreateEvent(&event)); - - return new ggml_backend_event{ - /* .backend = */ backend, - /* .context = */ event, - }; -} - -/** - * @brief Frees a CANN backend event. - * - * This function destroys the ACL runtime event associated with the given CANN - * backend event and then deletes the event structure itself. - * - * @param event Pointer to the event structure to be freed. - */ -static void ggml_backend_cann_event_free(ggml_backend_event_t event) { - ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context)); - - delete event; -} - /** * @brief Records an event on the CANN backend stream. * @@ -1845,10 +1847,9 @@ static void ggml_backend_cann_event_free(ggml_backend_event_t event) { * * @param event Pointer to the event structure to be recorded. */ -static void ggml_backend_cann_event_record(ggml_backend_event_t event) { +static void ggml_backend_cann_event_record(ggml_backend_t backend, ggml_backend_event_t event) { ggml_backend_cann_context* cann_ctx = - (ggml_backend_cann_context*)event->backend->context; - + (ggml_backend_cann_context*)backend->context; ACL_CHECK(aclrtRecordEvent((aclrtEvent)event->context, cann_ctx->stream())); } @@ -1866,8 +1867,7 @@ static void ggml_backend_cann_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { ggml_backend_cann_context* cann_ctx = (ggml_backend_cann_context*)backend->context; - - if (ggml_backend_is_cann(event->backend)) { + if (ggml_backend_is_cann(backend)) { ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(), (aclrtEvent)event->context)); } else { @@ -1875,17 +1875,6 @@ static void ggml_backend_cann_event_wait(ggml_backend_t backend, } } -/** - * @brief Synchronizes the given event on the CANN backend. - * - * This function waits for the specified event to complete on the ACL runtime. - * - * @param event Pointer to the event structure to be synchronized. - */ -static void ggml_backend_cann_event_synchronize(ggml_backend_event_t event) { - ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context)); -} - /** * @brief Structure defining the interface for the CANN backend. * @@ -1893,10 +1882,9 @@ static void ggml_backend_cann_event_synchronize(ggml_backend_event_t event) { * supported by the CANN backend, including name retrieval, memory * management, tensor operations, synchronization, and event handling. */ -static ggml_backend_i ggml_backend_cann_interface = { +static const ggml_backend_i ggml_backend_cann_interface = { /* .get_name = */ ggml_backend_cann_name, /* .free = */ ggml_backend_cann_free, - /* .get_default_buffer_type = */ ggml_backend_cann_get_default_buffer_type, /* .set_tensor_async = */ ggml_backend_cann_set_tensor_async, /* .get_tensor_async = */ ggml_backend_cann_get_tensor_async, /* .cpy_tensor_async = */ ggml_backend_cann_cpy_tensor_async, @@ -1906,14 +1894,8 @@ static ggml_backend_i ggml_backend_cann_interface = { /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_cann_graph_compute, - /* .supports_op = */ ggml_backend_cann_supports_op, - /* .supports_buft = */ ggml_backend_cann_supports_buft, - /* .offload_op = */ ggml_backend_cann_offload_op, - /* .event_new = */ ggml_backend_cann_event_new, - /* .event_free = */ ggml_backend_cann_event_free, /* .event_record = */ ggml_backend_cann_event_record, /* .event_wait = */ ggml_backend_cann_event_wait, - /* .event_synchronize = */ ggml_backend_cann_event_synchronize, }; /** @@ -1930,91 +1912,277 @@ static ggml_guid_t ggml_backend_cann_guid() { return &guid; } -GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device) { +// backend device +struct ggml_backend_cann_device_context { + int device; + std::string name; + std::string description; +}; + +static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context; + return ctx->name.c_str(); +} + +static const char* ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context; + return ctx->description.c_str(); +} + +static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context; + ggml_backend_cann_get_device_memory(ctx->device, free, total); +} + +static enum ggml_backend_dev_type ggml_backend_cann_device_get_type(ggml_backend_dev_t dev) { + GGML_UNUSED(dev); + return GGML_BACKEND_DEVICE_TYPE_GPU; +} + +static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + props->name = ggml_backend_cann_device_get_name(dev); + props->description = ggml_backend_cann_device_get_description(dev); + props->type = ggml_backend_cann_device_get_type(dev); + ggml_backend_cann_device_get_memory(dev, &props->memory_free, &props->memory_total); + + bool host_buffer = getenv("GGML_CANN_NO_PINNED") == nullptr; + + props->caps = { + /* .async = */ false, + /* .host_buffer = */ host_buffer, + /* .buffer_from_host_ptr = */ false, + /* .events = */ true, + }; +} + +static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) { + GGML_UNUSED(params); + ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context; + return ggml_backend_cann_init(ctx->device); +} + +/** + * @brief Checks if the CANN backend supports a specific backend buffer type. + * + * This function determines whether the CANN backend supports the given backend + * buffer type by comparing the device context of the backend and buffer type. + * It returns true if the devices are same between the backend context and + * buffer type context. + * + * @param backend Pointer to the CANN backend. + * @param buft Pointer to the backend buffer type to check. + * @return bool Returns true if the CANN backend supports the buffer type, + * otherwise false. + */ +static bool ggml_backend_cann_supports_buft( + ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + if (ggml_backend_buft_is_cann(buft)) { + ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context; + ggml_backend_cann_buffer_type_context * buft_ctx = + (ggml_backend_cann_buffer_type_context *)buft->context; + return buft_ctx->device == dev_ctx->device; + } + return false; +} + +static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context; + return ggml_backend_cann_buffer_type(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(ggml_backend_dev_t dev) { + GGML_UNUSED(dev); + return ggml_backend_cann_host_buffer_type(); +} + +/** + * @brief Creates a new event for the CANN backend device. + * + * This function initializes a new event for the CANN backend by setting the + * device and creating an ACL runtime event. The created event is then wrapped + * in a ggml_backend_event structure and returned. + * + * @param backend Pointer to the CANN backend. + * @return ggml_backend_event_t Returns a pointer to the new event structure. + */ +static ggml_backend_event_t ggml_backend_cann_device_event_new( + ggml_backend_dev_t dev) { + ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context; + + ggml_cann_set_device(dev_ctx->device); + + aclrtEvent event; + ACL_CHECK(aclrtCreateEvent(&event)); + + return new ggml_backend_event{ + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), dev_ctx->device), + /* .context = */ event, + }; +} + +/** + * @brief Frees a CANN backend event. + * + * This function destroys the ACL runtime event associated with the given CANN + * backend event and then deletes the event structure itself. + * + * @param event Pointer to the event structure to be freed. + */ +static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) { + ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context)); + + delete event; + GGML_UNUSED(dev); +} + +/** + * @brief Synchronizes the given event on the CANN backend. + * + * This function waits for the specified event to complete on the ACL runtime. + * + * @param event Pointer to the event structure to be synchronized. + */ +static void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) { + ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context)); + + GGML_UNUSED(dev); +} + +static const ggml_backend_device_i ggml_backend_cann_device_interface = { + /* .get_name = */ ggml_backend_cann_device_get_name, + /* .get_description = */ ggml_backend_cann_device_get_description, + /* .get_memory = */ ggml_backend_cann_device_get_memory, + /* .get_type = */ ggml_backend_cann_device_get_type, + /* .get_props = */ ggml_backend_cann_device_get_props, + /* .init_backend = */ ggml_backend_cann_device_init, // called for every card + /* .get_buffer_type = */ ggml_backend_cann_device_get_buffer_type, + /* .get_host_buffer_type = */ ggml_backend_cann_device_get_host_buffer_type, + /* .buffer_from_host_ptr = */ NULL, // not supported for CANN + /* .supports_op = */ ggml_backend_cann_supports_op, + /* .supports_buft = */ ggml_backend_cann_supports_buft, + /* .offload_op = */ ggml_backend_cann_offload_op, + /* .event_new = */ ggml_backend_cann_device_event_new, + /* .event_free = */ ggml_backend_cann_device_event_free, + /* .event_synchronize = */ ggml_backend_cann_device_event_synchronize, +}; + + +// backend reg +struct ggml_backend_cann_reg_context { + std::vector devices; +}; + +static const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) { + GGML_UNUSED(reg); + return GGML_CANN_NAME; +} + +static size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) { + ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context; + return ctx->devices.size(); +} + +static ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) { + ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context; + GGML_ASSERT(index < ctx->devices.size()); + return ctx->devices[index]; +} + +static void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) { + GGML_UNUSED(reg); + GGML_UNUSED(name); + // reserved for future use + return nullptr; +} + +static const ggml_backend_reg_i ggml_backend_cann_reg_interface = { + /* .get_name = */ ggml_backend_cann_reg_get_name, + /* .get_device_count = */ ggml_backend_cann_reg_get_device_count, + /* .get_device = */ ggml_backend_cann_reg_get_device, + /* .get_proc_address = */ ggml_backend_cann_reg_get_proc_address, +}; + +// backend registry, called only once for cann backend +ggml_backend_reg_t ggml_backend_cann_reg() { + static ggml_backend_reg reg; + static bool initialized = false; + + { + static std::mutex mutex; + std::lock_guard lock(mutex); + if (!initialized) { + aclInit(nullptr); + ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context; + + for (int i = 0; i < ggml_cann_info().device_count; i++) { + ggml_backend_cann_device_context* dev_ctx = new ggml_backend_cann_device_context(); + dev_ctx->description = aclrtGetSocName(); + dev_ctx->device = i; + dev_ctx->name = GGML_CANN_NAME + std::to_string(i); + ggml_cann_set_device(i); + ggml_backend_dev_t dev = new ggml_backend_device { + /* .iface = */ ggml_backend_cann_device_interface, + /* .reg = */ ®, + /* .context = */ dev_ctx + }; + ctx->devices.push_back(dev); + } + + reg = ggml_backend_reg { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_cann_reg_interface, + /* .context = */ ctx + }; + } + + initialized = true; + } + + return ® +} + +ggml_backend_t ggml_backend_cann_init(int32_t device) { aclInit(nullptr); if (device < 0 || device >= ggml_backend_cann_get_device_count()) { - GGML_CANN_LOG_ERROR("%s: error: invalid device %d\n", __func__, device); + GGML_LOG_ERROR("%s: error: invalid device %d\n", __func__, device); return nullptr; } ggml_backend_cann_context* ctx = new ggml_backend_cann_context(device); if (ctx == nullptr) { - GGML_CANN_LOG_ERROR("%s: error: failed to allocate context\n", __func__); + GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); return nullptr; } - + ggml_cann_set_device(ctx->device); ggml_backend_t cann_backend = new ggml_backend{/* .guid = */ ggml_backend_cann_guid(), /* .interface = */ ggml_backend_cann_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device), /* .context = */ ctx}; return cann_backend; } -GGML_CALL bool ggml_backend_is_cann(ggml_backend_t backend) { +bool ggml_backend_is_cann(ggml_backend_t backend) { return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cann_guid()); } -GGML_CALL int32_t ggml_backend_cann_get_device_count() { +int32_t ggml_backend_cann_get_device_count() { return ggml_cann_info().device_count; } -GGML_CALL void ggml_backend_cann_get_device_description( +void ggml_backend_cann_get_device_description( int32_t device, char* description, size_t description_size) { ggml_cann_set_device(device); const char* soc_name = aclrtGetSocName(); snprintf(description, description_size, "%s", soc_name); } -GGML_CALL void ggml_backend_cann_get_device_memory(int32_t device, size_t* free, - size_t* total) { +void ggml_backend_cann_get_device_memory(int32_t device, size_t* free, + size_t* total) { ggml_cann_set_device(device); ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total)); } -// backend registry -/** - * @brief Initializes a CANN backend based on the provided parameters. - * - * This function initializes a CANN backend using the device index and then - * initializes the backend using `ggml_backend_cann_init`. - * - * @param params Parameters for initialization (unused in this implementation). - * @param user_data User data containing the device index to initialize the - * backend. - * @return ggml_backend_t The initialized CANN backend. - */ -GGML_CALL static ggml_backend_t ggml_backend_reg_cann_init(const char* params, - void* user_data) { - ggml_backend_t cann_backend = - ggml_backend_cann_init((int)(intptr_t)user_data); - return cann_backend; - - GGML_UNUSED(params); -} - -extern "C" GGML_CALL int ggml_backend_cann_reg_devices(); - -/** - * @brief Registers CANN (Ascend) devices as backend options. - * - * This function initializes ACL, retrieves the number of available CANN - * devices, and registers each device as a backend option using - * `ggml_backend_register`. Each device is given a unique name based on - * `GGML_CANN_NAME` followed by its index. - * - * @return int The number of CANN devices registered. - */ -GGML_CALL int ggml_backend_cann_reg_devices() { - uint32_t device_count = ggml_backend_cann_get_device_count(); - // initialization - for (uint32_t i = 0; i < device_count; i++) { - char name[128]; - snprintf(name, sizeof(name), "CANN%d", i); - ggml_backend_register(name, ggml_backend_reg_cann_init, - ggml_backend_cann_buffer_type(i), - (void*)(intptr_t)i); - } - return device_count; -} +GGML_BACKEND_DL_IMPL(ggml_backend_cann_reg) diff --git a/ggml/src/ggml-cann/kernels/CMakeLists.txt b/ggml/src/ggml-cann/kernels/CMakeLists.txt index 5b4fef91b..d687220c3 100644 --- a/ggml/src/ggml-cann/kernels/CMakeLists.txt +++ b/ggml/src/ggml-cann/kernels/CMakeLists.txt @@ -1,7 +1,3 @@ -if (NOT SOC_TYPE) - set (SOC_TYPE "Ascend910B3") -endif() - file(GLOB SRC_FILES get_row_f32.cpp get_row_f16.cpp @@ -13,7 +9,6 @@ file(GLOB SRC_FILES dup.cpp ) -string(TOLOWER ${SOC_TYPE} SOC_VERSION) set(ASCEND_CANN_PACKAGE_PATH ${CANN_INSTALL_DIR}) set(RUN_MODE "npu" CACHE STRING "run mode: npu/sim") @@ -30,4 +25,6 @@ ascendc_library(ascendc_kernels STATIC ${SRC_FILES} ) +message(STATUS "CANN: compile ascend kernels witch SOC_TYPE:${SOC_TYPE}, SOC_VERSION:${SOC_VERSION}, compile macro:-D${SOC_TYPE_COMPILE_OPTION}.") +ascendc_compile_definitions(ascendc_kernels PRIVATE "-D${SOC_TYPE_COMPILE_OPTION}") # ascendc_compile_definitions(ascendc_kernels PRIVATE -DASCENDC_DUMP) diff --git a/ggml/src/ggml-cann/kernels/dup.cpp b/ggml/src/ggml-cann/kernels/dup.cpp index e2c651152..c7ba38d10 100644 --- a/ggml/src/ggml-cann/kernels/dup.cpp +++ b/ggml/src/ggml-cann/kernels/dup.cpp @@ -5,6 +5,7 @@ using namespace AscendC; #define BUFFER_NUM 2 +const int64_t SUPPORTED_MAX_DIM = 65535; // currently the limit of max block dim supportted by dup kernel is 65535template template class DupByRows { @@ -51,24 +52,36 @@ class DupByRows { __aicore__ inline void copy_in() { LocalTensor src_local = src_queue.AllocTensor(); - - DataCopyExtParams dataCopyParams; - dataCopyParams.blockCount = 1; - dataCopyParams.blockLen = num_elem * sizeof(SRC_T); - DataCopyPadExtParams padParams; - DataCopyPad(src_local, src_gm, dataCopyParams, padParams); - + const size_t elem_per_block = 32 / sizeof(SRC_T); + size_t tail = num_elem % elem_per_block; + size_t cpy_elements_len = tail > 0 ? num_elem + 1 : num_elem; + DataCopy(src_local, src_gm, cpy_elements_len); src_queue.EnQue(src_local); } __aicore__ inline void copy_out() { LocalTensor dst_local = dst_queue.DeQue(); - +#ifdef ASCEND_310P + const size_t elem_per_block = 32 / sizeof(DST_T); + size_t tail = num_elem % elem_per_block; + size_t len = num_elem & ~(elem_per_block - 1); + if (len > 0) { + DataCopy(dst_gm, dst_local, len); + } + if(tail != 0) { + for (size_t i = tail; i < elem_per_block; i++) { + dst_local[len + i].SetValue(0, 0); + } + SetAtomicAdd(); + DataCopy(dst_gm[len], dst_local[len], elem_per_block); + SetAtomicNone(); + } +#else DataCopyExtParams dataCopyParams; dataCopyParams.blockCount = 1; dataCopyParams.blockLen = num_elem * sizeof(DST_T); DataCopyPad(dst_gm, dst_local, dataCopyParams); - +#endif dst_queue.FreeTensor(dst_local); } diff --git a/ggml/src/ggml-cann/kernels/get_row_f16.cpp b/ggml/src/ggml-cann/kernels/get_row_f16.cpp index c704b5b2e..416b45104 100644 --- a/ggml/src/ggml-cann/kernels/get_row_f16.cpp +++ b/ggml/src/ggml-cann/kernels/get_row_f16.cpp @@ -14,7 +14,7 @@ class GET_ROW_F16 { int64_t *output_ne_ub, size_t *output_nb_ub) { // TODO, use template for F16/f32 int64_t op_block_num = GetBlockNum(); - int64_t op_block_idx = GetBlockIdx(); + op_block_idx = GetBlockIdx(); for (int i = 0; i < 4; i++) { input_ne[i] = input_ne_ub[i]; @@ -59,32 +59,42 @@ class GET_ROW_F16 { } __aicore__ inline void copy_in(uint32_t offset, size_t len) { + size_t origin_len = len; LocalTensor input_local = input_queue.AllocTensor(); - size_t tail = len % 32; - len = len & ~31; - DataCopy(input_local, input_gm[offset], len); + const size_t elem_per_block = 32 / sizeof(half); + size_t tail = len % elem_per_block; + len = len & ~(elem_per_block - 1); if(tail != 0) { - DataCopyExtParams dataCopyParams; - dataCopyParams.blockCount = 1; - dataCopyParams.blockLen = tail * sizeof(half); - DataCopyPadExtParams padParams; - DataCopyPad(input_local[len], input_gm[offset + len], - dataCopyParams, padParams); + len += elem_per_block; } + DataCopy(input_local, input_gm[offset], len); input_queue.EnQue(input_local); } __aicore__ inline void copy_out(uint32_t offset, size_t len) { LocalTensor output_local = output_queue.DeQue(); - size_t tail = len % 32; - len = len & ~31; - DataCopy(output_gm[offset], output_local, len); + const size_t elem_per_block = 32 / sizeof(float); + size_t tail = len % elem_per_block; + len = len & ~(elem_per_block - 1); + if (len > 0) { + DataCopy(output_gm[offset], output_local, len); + } + if(tail != 0) { +#ifdef ASCEND_310P + for (size_t i = tail; i < elem_per_block; i++) { + output_local[len + i].SetValue(0, 0); + } + SetAtomicAdd(); + DataCopy(output_gm[offset + len], output_local[len], elem_per_block); + SetAtomicNone(); +#else DataCopyExtParams dataCopyParams; dataCopyParams.blockCount = 1; dataCopyParams.blockLen = tail * sizeof(float); DataCopyPad(output_gm[offset + len], output_local[len], dataCopyParams); +#endif } output_queue.FreeTensor(output_local); } @@ -150,6 +160,7 @@ class GET_ROW_F16 { GlobalTensor output_gm; TQue input_queue; TQue output_queue; + int64_t op_block_idx; }; template diff --git a/ggml/src/ggml-cann/kernels/get_row_f32.cpp b/ggml/src/ggml-cann/kernels/get_row_f32.cpp index 9db080af3..02116905b 100644 --- a/ggml/src/ggml-cann/kernels/get_row_f32.cpp +++ b/ggml/src/ggml-cann/kernels/get_row_f32.cpp @@ -13,7 +13,7 @@ class GET_ROW_F32 { int64_t *indices_ne_ub, size_t *indices_nb_ub, int64_t *output_ne_ub, size_t *output_nb_ub) { int64_t op_block_num = GetBlockNum(); - int64_t op_block_idx = GetBlockIdx(); + op_block_idx = GetBlockIdx(); for (int i = 0; i < 4; i++) { input_ne[i] = input_ne_ub[i]; @@ -55,31 +55,40 @@ class GET_ROW_F32 { __aicore__ inline void copy_in(uint32_t offset, size_t len) { LocalTensor input_local = input_queue.AllocTensor(); - size_t tail = len % 32; - len = len & ~31; - DataCopy(input_local, input_gm[offset], len); + const size_t elem_per_block = 32 / sizeof(float); + size_t tail = len % elem_per_block; + len = len & ~(elem_per_block - 1); if(tail != 0) { - DataCopyExtParams dataCopyParams; - dataCopyParams.blockCount = 1; - dataCopyParams.blockLen = tail * sizeof(float); - DataCopyPadExtParams padParams; - DataCopyPad(input_local[len], input_gm[offset + len], - dataCopyParams, padParams); + len += elem_per_block; } + DataCopy(input_local, input_gm[offset], len); input_queue.EnQue(input_local); } __aicore__ inline void copy_out(uint32_t offset, size_t len) { LocalTensor output_local = output_queue.DeQue(); - size_t tail = len % 32; - len = len & ~31; - DataCopy(output_gm[offset], output_local, len); + const size_t elem_per_block = 32 / sizeof(float); + size_t tail = len % elem_per_block; + len = len & ~(elem_per_block - 1); + if (len > 0) { + DataCopy(output_gm[offset], output_local, len); + } + if(tail != 0) { +#ifdef ASCEND_310P + for (size_t i = tail; i < elem_per_block; i++) { + output_local[len + i].SetValue(0, 0); + } + SetAtomicAdd(); + DataCopy(output_gm[offset + len], output_local[len], elem_per_block); + SetAtomicNone(); +#else DataCopyExtParams dataCopyParams; dataCopyParams.blockCount = 1; dataCopyParams.blockLen = tail * sizeof(float); DataCopyPad(output_gm[offset + len], output_local[len], dataCopyParams); +#endif } output_queue.FreeTensor(output_local); } @@ -144,6 +153,7 @@ class GET_ROW_F32 { GlobalTensor output_gm; TQue input_queue; TQue output_queue; + int64_t op_block_idx; }; template diff --git a/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp b/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp index a80bfeec2..4fbe72208 100644 --- a/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +++ b/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp @@ -2,6 +2,15 @@ // optimize me. Use template to avoid copy code. using namespace AscendC; +#ifdef ASCEND_310P // 310P not support 4bit get row + extern "C" __global__ __aicore__ void ascendc_get_row_q4_0( + GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm, + GM_ADDR input_ne_gm, GM_ADDR indices_ne_gm, GM_ADDR indices_nb_gm, + GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) { + // let following test cases can continue run, here just print error information. Of Cource the test case that call this operator is failed. + printf("Ascend310P not support 4bit get row.\n"); + } +#else #define BUFFER_NUM 2 @@ -191,3 +200,5 @@ extern "C" __global__ __aicore__ void ascendc_get_row_q4_0( indices_nb_ub, output_ne_ub, output_nb_ub); op.calculate(); } + +#endif // #ifdef ASCEND_310P diff --git a/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp b/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp index 8423b3f02..504b43afa 100644 --- a/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +++ b/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp @@ -1,6 +1,14 @@ #include "kernel_operator.h" using namespace AscendC; +#ifdef ASCEND_310P + extern "C" __global__ __aicore__ void ascendc_quantize_f16_q8_0( + GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm, + GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) { + // let following test cases can continue run, here just print error information. Of Cource the test case that call this operator is failed. + printf("Ascend310P not support f16->8bit quantization.\n"); + } +#else #define BUFFER_NUM 2 #define QK8_0 32 @@ -206,3 +214,5 @@ extern "C" __global__ __aicore__ void ascendc_quantize_f16_q8_0( op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub); op.calculate(); } + +#endif // #ifdef ASCEND_310P diff --git a/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp b/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp index b7c575093..05b0bc1df 100644 --- a/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +++ b/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp @@ -1,6 +1,14 @@ #include "kernel_operator.h" using namespace AscendC; +#ifdef ASCEND_310P // 310P not support f32->8bit quantization + extern "C" __global__ __aicore__ void ascendc_quantize_f32_q8_0( + GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm, + GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) { + // let following test cases can continue run, here just print error information. Of Cource the test case that call this operator is failed. + printf("Ascend310P not support f32->8bit quantization.\n"); + } +#else #define BUFFER_NUM 2 #define QK8_0 32 @@ -204,3 +212,5 @@ extern "C" __global__ __aicore__ void ascendc_quantize_f32_q8_0( op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub); op.calculate(); } + +#endif // #ifdef ASCEND_310P diff --git a/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp b/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp index 9c8c86b66..1188937b7 100644 --- a/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +++ b/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp @@ -1,6 +1,21 @@ #include "kernel_operator.h" using namespace AscendC; +#ifdef ASCEND_310P // 310P not support float->4bit quantization + extern "C" __global__ __aicore__ void ascendc_quantize_f32_to_q4_0( + GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm, + GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) { + // let following test cases can continue run, here just print error information. Of Cource the test case that call this operator is failed. + printf("Ascend310P not support f32->4bit quantization.\n"); + } + + extern "C" __global__ __aicore__ void ascendc_quantize_f16_to_q4_0( + GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm, + GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) { + // let following test cases can continue run, here just print error information. Of Cource the test case that call this operator is failed. + printf("Ascend310P not support f16->4bit quantization.\n"); + } +#else #define BUFFER_NUM 2 #define Group_Size 32 @@ -276,3 +291,5 @@ extern "C" __global__ __aicore__ void ascendc_quantize_f32_to_q4_0( op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub); op.calculate(); } + +#endif // #ifdef ASCEND_310P diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 050161393..f13fd4dea 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -6,7 +6,20 @@ typedef uint16_t ggml_half; typedef uint32_t ggml_half2; -#define GGML_COMMON_AGGR +#define GGML_COMMON_AGGR_U +#define GGML_COMMON_AGGR_S + +#define GGML_COMMON_DECL +#elif defined(GGML_COMMON_DECL_CPP) +#include + +typedef uint16_t ggml_half; +typedef uint32_t ggml_half2; + +// std-c++ allow anonymous unions but some compiler warn on it +#define GGML_COMMON_AGGR_U data +// std-c++ do not allow it. +#define GGML_COMMON_AGGR_S data #define GGML_COMMON_DECL #elif defined(GGML_COMMON_DECL_METAL) @@ -15,7 +28,8 @@ typedef uint32_t ggml_half2; typedef half ggml_half; typedef half2 ggml_half2; -#define GGML_COMMON_AGGR +#define GGML_COMMON_AGGR_U +#define GGML_COMMON_AGGR_S #define GGML_COMMON_DECL #elif defined(GGML_COMMON_DECL_CUDA) @@ -29,7 +43,8 @@ typedef half2 ggml_half2; typedef half ggml_half; typedef half2 ggml_half2; -#define GGML_COMMON_AGGR data +#define GGML_COMMON_AGGR_U +#define GGML_COMMON_AGGR_S data #define GGML_COMMON_DECL #elif defined(GGML_COMMON_DECL_HIP) @@ -39,7 +54,8 @@ typedef half2 ggml_half2; typedef half ggml_half; typedef half2 ggml_half2; -#define GGML_COMMON_AGGR data +#define GGML_COMMON_AGGR_U +#define GGML_COMMON_AGGR_S data #define GGML_COMMON_DECL #elif defined(GGML_COMMON_DECL_SYCL) @@ -49,7 +65,8 @@ typedef half2 ggml_half2; typedef sycl::half ggml_half; typedef sycl::half2 ggml_half2; -#define GGML_COMMON_AGGR data +#define GGML_COMMON_AGGR_U +#define GGML_COMMON_AGGR_S data #define GGML_COMMON_DECL #endif @@ -154,9 +171,9 @@ typedef struct { struct { ggml_half d; // delta ggml_half m; // min - } GGML_COMMON_AGGR; + } GGML_COMMON_AGGR_S; ggml_half2 dm; - }; + } GGML_COMMON_AGGR_U; uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding"); @@ -175,9 +192,9 @@ typedef struct { struct { ggml_half d; // delta ggml_half m; // min - } GGML_COMMON_AGGR; + } GGML_COMMON_AGGR_S; ggml_half2 dm; - }; + } GGML_COMMON_AGGR_U; uint8_t qh[4]; // 5-th bit of quants uint8_t qs[QK5_1 / 2]; // nibbles / quants } block_q5_1; @@ -196,37 +213,13 @@ typedef struct { struct { ggml_half d; // delta ggml_half s; // d * sum(qs[i]) - } GGML_COMMON_AGGR; + } GGML_COMMON_AGGR_S; ggml_half2 ds; - }; + } GGML_COMMON_AGGR_U; int8_t qs[QK8_1]; // quants } block_q8_1; static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, "wrong q8_1 block size/padding"); -typedef struct { - ggml_half d[4]; // deltas for 4 q4_0 blocks - uint8_t qs[QK4_0 * 2]; // nibbles / quants for 4 q4_0 blocks -} block_q4_0x4; -static_assert(sizeof(block_q4_0x4) == 4 * sizeof(ggml_half) + QK4_0 * 2, "wrong q4_0x4 block size/padding"); - -typedef struct { - ggml_half d[8]; // deltas for 8 q4_0 blocks - uint8_t qs[QK4_0 * 4]; // nibbles / quants for 8 q4_0 blocks -} block_q4_0x8; -static_assert(sizeof(block_q4_0x8) == 8 * sizeof(ggml_half) + QK4_0 * 4, "wrong q4_0x8 block size/padding"); - -typedef struct { - ggml_half d[4]; // deltas for 4 q8_0 blocks - int8_t qs[QK8_0 * 4]; // quants for 4 q8_0 blocks -} block_q8_0x4; -static_assert(sizeof(block_q8_0x4) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong q8_0x4 block size/padding"); - -typedef struct { - ggml_half d[8]; // deltas for 8 q8_0 blocks - int8_t qs[QK8_0 * 8]; // quants for 8 q8_0 blocks -} block_q8_0x8; -static_assert(sizeof(block_q8_0x8) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong q8_0x8 block size/padding"); - // // Ternary quantization // @@ -261,9 +254,9 @@ typedef struct { struct { ggml_half d; // super-block scale for quantized scales ggml_half dmin; // super-block scale for quantized mins - } GGML_COMMON_AGGR; + } GGML_COMMON_AGGR_S; ggml_half2 dm; - }; + } GGML_COMMON_AGGR_U; } block_q2_K; static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); @@ -288,9 +281,9 @@ typedef struct { struct { ggml_half d; // super-block scale for quantized scales ggml_half dmin; // super-block scale for quantized mins - } GGML_COMMON_AGGR; + } GGML_COMMON_AGGR_S; ggml_half2 dm; - }; + } GGML_COMMON_AGGR_U; uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits uint8_t qs[QK_K/2]; // 4--bit quants } block_q4_K; @@ -305,9 +298,9 @@ typedef struct { struct { ggml_half d; // super-block scale for quantized scales ggml_half dmin; // super-block scale for quantized mins - } GGML_COMMON_AGGR; + } GGML_COMMON_AGGR_S; ggml_half2 dm; - }; + } GGML_COMMON_AGGR_U; uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits uint8_t qh[QK_K/8]; // quants, high bit uint8_t qs[QK_K/2]; // quants, low 4 bits @@ -431,6 +424,13 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_ #define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = { #define GGML_TABLE_END() }; +#define GGML_COMMON_IMPL +#elif defined(GGML_COMMON_IMPL_CPP) +#include + +#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = { +#define GGML_TABLE_END() }; + #define GGML_COMMON_IMPL #elif defined(GGML_COMMON_IMPL_METAL) #include @@ -473,7 +473,7 @@ GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128) 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255, GGML_TABLE_END() -//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics +//#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A // lowest compute capability for integer intrinsics GGML_TABLE_BEGIN(uint64_t, ksigns64, 128) 0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff, 0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff, diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt new file mode 100644 index 000000000..6b3641c42 --- /dev/null +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -0,0 +1,346 @@ +function(ggml_add_cpu_backend_variant_impl tag_name) + if (tag_name) + set(GGML_CPU_NAME ggml-cpu-${tag_name}) + else() + set(GGML_CPU_NAME ggml-cpu) + endif() + + ggml_add_backend_library(${GGML_CPU_NAME}) + + list (APPEND GGML_CPU_SOURCES + ggml-cpu/ggml-cpu.c + ggml-cpu/ggml-cpu.cpp + ggml-cpu/ggml-cpu-aarch64.cpp + ggml-cpu/ggml-cpu-aarch64.h + ggml-cpu/ggml-cpu-hbm.cpp + ggml-cpu/ggml-cpu-hbm.h + ggml-cpu/ggml-cpu-quants.c + ggml-cpu/ggml-cpu-quants.h + ggml-cpu/ggml-cpu-traits.cpp + ggml-cpu/ggml-cpu-traits.h + ggml-cpu/amx/amx.cpp + ggml-cpu/amx/amx.h + ggml-cpu/amx/mmq.cpp + ggml-cpu/amx/mmq.h + ggml-cpu/ggml-cpu-impl.h + ) + + target_compile_features(${GGML_CPU_NAME} PRIVATE c_std_11 cxx_std_17) + target_include_directories(${GGML_CPU_NAME} PRIVATE . ggml-cpu) + + if (APPLE AND GGML_ACCELERATE) + find_library(ACCELERATE_FRAMEWORK Accelerate) + if (ACCELERATE_FRAMEWORK) + message(STATUS "Accelerate framework found") + + target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_ACCELERATE) + target_compile_definitions(${GGML_CPU_NAME} PRIVATE ACCELERATE_NEW_LAPACK) + target_compile_definitions(${GGML_CPU_NAME} PRIVATE ACCELERATE_LAPACK_ILP64) + + target_link_libraries(${GGML_CPU_NAME} PRIVATE ${ACCELERATE_FRAMEWORK}) + else() + message(WARNING "Accelerate framework not found") + endif() + endif() + + if (GGML_OPENMP) + find_package(OpenMP) + if (OpenMP_FOUND) + target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_OPENMP) + + target_link_libraries(${GGML_CPU_NAME} PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX) + else() + message(WARNING "OpenMP not found") + endif() + endif() + + if (GGML_LLAMAFILE) + target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_LLAMAFILE) + + list(APPEND GGML_CPU_SOURCES + ggml-cpu/llamafile/sgemm.cpp + ggml-cpu/llamafile/sgemm.h) + endif() + + if (GGML_CPU_HBM) + find_library(memkind memkind REQUIRED) + + message(STATUS "Using memkind for CPU HBM") + + target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_HBM) + + target_link_libraries(${GGML_CPU_NAME} PUBLIC memkind) + endif() + + if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR + CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR + (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND + CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm.*|ARM64)$")) + + message(STATUS "ARM detected") + + if (MSVC AND NOT CMAKE_C_COMPILER_ID STREQUAL "Clang") + message(FATAL_ERROR "MSVC is not supported for ARM, use clang") + else() + check_cxx_compiler_flag(-mfp16-format=ieee GGML_COMPILER_SUPPORTS_FP16_FORMAT_I3E) + if (NOT "${GGML_COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "") + list(APPEND ARCH_FLAGS -mfp16-format=ieee) + endif() + + if (GGML_NATIVE) + # -mcpu=native does not always enable all the features in some compilers, + # so we check for them manually and enable them if available + + execute_process( + COMMAND ${CMAKE_C_COMPILER} -mcpu=native -E -v - + INPUT_FILE "/dev/null" + OUTPUT_QUIET + ERROR_VARIABLE ARM_MCPU + RESULT_VARIABLE ARM_MCPU_RESULT + ) + if (NOT ARM_MCPU_RESULT) + string(REGEX MATCH "-mcpu=[^ ']+" ARM_MCPU_FLAG "${ARM_MCPU}") + endif() + if ("${ARM_MCPU_FLAG}" STREQUAL "") + set(ARM_MCPU_FLAG -mcpu=native) + message(STATUS "ARM -mcpu not found, -mcpu=native will be used") + endif() + + include(CheckCXXSourceRuns) + + function(check_arm_feature tag code) + set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) + set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+${tag}") + check_cxx_source_runs( + "${code}" + GGML_MACHINE_SUPPORTS_${tag} + ) + if (GGML_MACHINE_SUPPORTS_${tag}) + set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE) + else() + set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE) + endif() + set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) + endfunction() + + check_arm_feature(dotprod "#include \nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }") + check_arm_feature(i8mm "#include \nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }") + check_arm_feature(sve "#include \nint main() { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }") + + list(APPEND ARCH_FLAGS "${ARM_MCPU_FLAG}${ARM_MCPU_FLAG_FIX}") + else() + if (GGML_CPU_ARM_ARCH) + list(APPEND ARCH_FLAGS -march=${GGML_CPU_ARM_ARCH}) + endif() + endif() + + # show enabled features + if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") + set(FEAT_INPUT_FILE "NUL") + else() + set(FEAT_INPUT_FILE "/dev/null") + endif() + + execute_process( + COMMAND ${CMAKE_C_COMPILER} ${ARCH_FLAGS} -dM -E - + INPUT_FILE ${FEAT_INPUT_FILE} + OUTPUT_VARIABLE ARM_FEATURE + RESULT_VARIABLE ARM_FEATURE_RESULT + ) + if (ARM_FEATURE_RESULT) + message(WARNING "Failed to get ARM features") + else() + foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC) + string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos) + if (NOT ${feature_pos} EQUAL -1) + message(STATUS "ARM feature ${feature} enabled") + endif() + endforeach() + endif() + endif() + elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR + (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND + CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64|amd64)$")) + + message(STATUS "x86 detected") + + if (MSVC) + # instruction set detection for MSVC only + if (GGML_NATIVE) + include(ggml-cpu/cmake/FindSIMD.cmake) + endif () + if (GGML_AVX512) + list(APPEND ARCH_FLAGS /arch:AVX512) + # /arch:AVX512 includes: __AVX512F__, __AVX512CD__, __AVX512BW__, __AVX512DQ__, and __AVX512VL__ + # MSVC has no compile-time flags enabling specific + # AVX512 extensions, neither it defines the + # macros corresponding to the extensions. + # Do it manually. + list(APPEND ARCH_DEFINITIONS GGML_AVX512) + if (GGML_AVX512_VBMI) + list(APPEND ARCH_DEFINITIONS __AVX512VBMI__) + if (CMAKE_C_COMPILER_ID STREQUAL "Clang") + list(APPEND ARCH_FLAGS -mavx512vbmi) + endif() + endif() + if (GGML_AVX512_VNNI) + list(APPEND ARCH_DEFINITIONS __AVX512VNNI__ GGML_AVX512_VNNI) + if (CMAKE_C_COMPILER_ID STREQUAL "Clang") + list(APPEND ARCH_FLAGS -mavx512vnni) + endif() + endif() + if (GGML_AVX512_BF16) + list(APPEND ARCH_DEFINITIONS __AVX512BF16__ GGML_AVX512_BF16) + if (CMAKE_C_COMPILER_ID STREQUAL "Clang") + list(APPEND ARCH_FLAGS -mavx512bf16) + endif() + endif() + if (GGML_AMX_TILE) + list(APPEND ARCH_DEFINITIONS __AMX_TILE__ GGML_AMX_TILE) + endif() + if (GGML_AMX_INT8) + list(APPEND ARCH_DEFINITIONS __AMX_INT8__ GGML_AMX_INT8) + endif() + if (GGML_AMX_BF16) + list(APPEND ARCH_DEFINITIONS __AMX_BF16__ GGML_AMX_BF16) + endif() + elseif (GGML_AVX2) + list(APPEND ARCH_FLAGS /arch:AVX2) + list(APPEND ARCH_DEFINITIONS GGML_AVX2 GGML_FMA GGML_F16C) + elseif (GGML_AVX) + list(APPEND ARCH_FLAGS /arch:AVX) + list(APPEND ARCH_DEFINITIONS GGML_AVX) + else () + list(APPEND ARCH_FLAGS /arch:SSE4.2) + list(APPEND ARCH_DEFINITIONS GGML_SSE42) + endif() + if (GGML_AVX_VNNI) + list(APPEND ARCH_DEFINITIONS __AVXVNNI__ GGML_AVX_VNNI) + endif() + else () + if (GGML_NATIVE) + list(APPEND ARCH_FLAGS -march=native) + else () + list(APPEND ARCH_FLAGS -msse4.2) + list(APPEND ARCH_DEFINITIONS GGML_SSE42) + if (GGML_F16C) + list(APPEND ARCH_FLAGS -mf16c) + list(APPEND ARCH_DEFINITIONS GGML_F16C) + endif() + if (GGML_FMA) + list(APPEND ARCH_FLAGS -mfma) + list(APPEND ARCH_DEFINITIONS GGML_FMA) + endif() + if (GGML_AVX) + list(APPEND ARCH_FLAGS -mavx) + list(APPEND ARCH_DEFINITIONS GGML_AVX) + endif() + if (GGML_AVX2) + list(APPEND ARCH_FLAGS -mavx2) + list(APPEND ARCH_DEFINITIONS GGML_AVX2) + endif() + if (GGML_AVX_VNNI) + list(APPEND ARCH_FLAGS -mavxvnni) + list(APPEND ARCH_DEFINITIONS GGML_AVX_VNNI) + endif() + if (GGML_AVX512) + list(APPEND ARCH_FLAGS -mavx512f) + list(APPEND ARCH_FLAGS -mavx512cd) + list(APPEND ARCH_FLAGS -mavx512vl) + list(APPEND ARCH_FLAGS -mavx512dq) + list(APPEND ARCH_FLAGS -mavx512bw) + list(APPEND ARCH_DEFINITIONS GGML_AVX512) + endif() + if (GGML_AVX512_VBMI) + list(APPEND ARCH_FLAGS -mavx512vbmi) + list(APPEND ARCH_DEFINITIONS GGML_AVX512_VBMI) + endif() + if (GGML_AVX512_VNNI) + list(APPEND ARCH_FLAGS -mavx512vnni) + list(APPEND ARCH_DEFINITIONS GGML_AVX512_VNNI) + endif() + if (GGML_AVX512_BF16) + list(APPEND ARCH_FLAGS -mavx512bf16) + list(APPEND ARCH_DEFINITIONS GGML_AVX512_BF16) + endif() + if (GGML_AMX_TILE) + list(APPEND ARCH_FLAGS -mamx-tile) + list(APPEND ARCH_DEFINITIONS GGML_AMX_TILE) + endif() + if (GGML_AMX_INT8) + list(APPEND ARCH_FLAGS -mamx-int8) + list(APPEND ARCH_DEFINITIONS GGML_AMX_INT8) + endif() + if (GGML_AMX_BF16) + list(APPEND ARCH_FLAGS -mamx-bf16) + list(APPEND ARCH_DEFINITIONS GGML_AMX_BF16) + endif() + endif() + endif() + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") + message(STATUS "PowerPC detected") + execute_process(COMMAND bash -c "grep POWER10 /proc/cpuinfo | head -n 1" OUTPUT_VARIABLE POWER10_M) + string(FIND "${POWER10_M}" "POWER10" substring_index) + if (NOT DEFINED substring_index OR "${substring_index}" STREQUAL "") + set(substring_index -1) + endif() + + if (${substring_index} GREATER_EQUAL 0) + list(APPEND ARCH_FLAGS -mcpu=power10) + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le") + list(APPEND ARCH_FLAGS -mcpu=powerpc64le) + else() + list(APPEND ARCH_FLAGS -mcpu=native -mtune=native) + # TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be) + endif() + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") + message(STATUS "loongarch64 detected") + + list(APPEND ARCH_FLAGS -march=loongarch64) + if (GGML_LASX) + list(APPEND ARCH_FLAGS -mlasx) + endif() + if (GGML_LSX) + list(APPEND ARCH_FLAGS -mlsx) + endif() + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "riscv64") + message(STATUS "RISC-V detected") + if (GGML_RVV) + list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d) + endif() + else() + message(STATUS "Unknown architecture") + endif() + + if (GGML_CPU_AARCH64) + target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_AARCH64) + endif() + + message(STATUS "Adding CPU backend variant ${GGML_CPU_NAME}: ${ARCH_FLAGS} ${ARCH_DEFINITIONS}") + target_sources(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_SOURCES}) + target_compile_options(${GGML_CPU_NAME} PRIVATE ${ARCH_FLAGS}) + target_compile_definitions(${GGML_CPU_NAME} PRIVATE ${ARCH_DEFINITIONS}) + + if (GGML_BACKEND_DL) + if (GGML_NATIVE) + # the feature check relies on ARCH_DEFINITIONS, but it is not set with GGML_NATIVE + message(FATAL_ERROR "GGML_NATIVE is not compatible with GGML_BACKEND_DL, consider using GGML_CPU_ALL_VARIANTS") + endif() + + # The feature detection code is compiled as a separate target so that + # it can be built without the architecture flags + # Since multiple variants of the CPU backend may be included in the same + # build, using set_source_files_properties() to set the arch flags is not possible + set(GGML_CPU_FEATS_NAME ${GGML_CPU_NAME}-feats) + add_library(${GGML_CPU_FEATS_NAME} OBJECT ggml-cpu/cpu-feats-x86.cpp) + target_include_directories(${GGML_CPU_FEATS_NAME} PRIVATE . .. ../include) + target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARCH_DEFINITIONS}) + target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED) + set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_link_libraries(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_FEATS_NAME}) + endif() + + if (EMSCRIPTEN) + set_target_properties(${GGML_CPU_NAME} PROPERTIES COMPILE_FLAGS "-msimd128") + endif() +endfunction() diff --git a/ggml/src/ggml-cpu/amx/amx.cpp b/ggml/src/ggml-cpu/amx/amx.cpp new file mode 100644 index 000000000..5ec5263ce --- /dev/null +++ b/ggml/src/ggml-cpu/amx/amx.cpp @@ -0,0 +1,220 @@ +#include "amx.h" +#include "common.h" +#include "mmq.h" +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include "ggml-cpu.h" +#include "ggml-cpu-traits.h" + +#if defined(__gnu_linux__) +#include +#include +#endif + +#include +#include +#include + +#if defined(__AMX_INT8__) && defined(__AVX512VNNI__) + +// AMX type_trais +namespace ggml::cpu::amx { +class tensor_traits : public ggml::cpu::tensor_traits { + bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { + size = ggml_backend_amx_desired_wsize(op); + return true; + } + + bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { + if (op->op == GGML_OP_MUL_MAT) { + ggml_backend_amx_mul_mat(params, op); + return true; + } + return false; + } +}; + +static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) { + static tensor_traits traits; + return &traits; +} +} // namespace ggml::cpu::amx + +// AMX buffer interface +static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) { + free(buffer->context); +} + +static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) { + return (void *) (buffer->context); +} + +static void ggml_backend_amx_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + tensor->extra = (void *) ggml::cpu::amx::get_tensor_traits(buffer, tensor); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, + uint8_t value, size_t offset, size_t size) { + memset((char *) tensor->data + offset, value, size); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, + const void * data, size_t offset, size_t size) { + if (qtype_has_amx_kernels(tensor->type)) { + GGML_LOG_DEBUG("%s: amx repack tensor %s of type %s\n", __func__, tensor->name, ggml_type_name(tensor->type)); + ggml_backend_amx_convert_weight(tensor, data, offset, size); + } else { + memcpy((char *) tensor->data + offset, data, size); + } + + GGML_UNUSED(buffer); +} + +/* +// need to figure what we need to do with buffer->extra. +static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(!qtype_has_amx_kernels(tensor->type)); + memcpy(data, (const char *)tensor->data + offset, size); + + GGML_UNUSED(buffer); +} + +static bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { + if (ggml_backend_buffer_is_host(src->buffer)) { + if (qtype_has_amx_kernels(src->type)) { + ggml_backend_amx_convert_weight(dst, src->data, 0, ggml_nbytes(dst)); + } else { + memcpy(dst->data, src->data, ggml_nbytes(src)); + } + return true; + } + return false; + + GGML_UNUSED(buffer); +} +*/ + +static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + memset(buffer->context, value, buffer->size); +} + +static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = { + /* .free_buffer = */ ggml_backend_amx_buffer_free_buffer, + /* .get_base = */ ggml_backend_amx_buffer_get_base, + /* .init_tensor = */ ggml_backend_amx_buffer_init_tensor, + /* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_amx_buffer_set_tensor, + /* .get_tensor = */ nullptr, + /* .cpy_tensor = */ nullptr, + /* .clear = */ ggml_backend_amx_buffer_clear, + /* .reset = */ nullptr, +}; + +static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "AMX"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_amx_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + void * data = ggml_aligned_malloc(size); + if (data == NULL) { + fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size); + return NULL; + } + + return ggml_backend_buffer_init(buft, ggml_backend_amx_buffer_interface, data, size); +} + +static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return TENSOR_ALIGNMENT; + + GGML_UNUSED(buft); +} + +namespace ggml::cpu::amx { +class extra_buffer_type : ggml::cpu::extra_buffer_type { + bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { + // handle only 2d gemm for now + auto is_contiguous_2d = [](const struct ggml_tensor * t) { + return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1; + }; + + if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous + is_contiguous_2d(op->src[1]) && // src1 must be contiguous + op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() && + op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x + (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) { + // src1 must be host buffer + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + // src1 must be float32 + if (op->src[1]->type == GGML_TYPE_F32) { + return true; + } + } + return false; + } + + ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { + if (op->op == GGML_OP_MUL_MAT && op->src[0]->buffer && + op->src[0]->buffer->buft == ggml_backend_amx_buffer_type()) { + return (ggml::cpu::tensor_traits *) op->src[0]->extra; + } + + return nullptr; + } +}; +} // namespace ggml::cpu::amx + +static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + return ggml_backend_amx_get_alloc_size(tensor); + + GGML_UNUSED(buft); +} + +#define ARCH_GET_XCOMP_PERM 0x1022 +#define ARCH_REQ_XCOMP_PERM 0x1023 +#define XFEATURE_XTILECFG 17 +#define XFEATURE_XTILEDATA 18 + +static bool ggml_amx_init() { +#if defined(__gnu_linux__) + if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) { + fprintf(stderr, "AMX is not ready to be used!\n"); + return false; + } + return true; +#elif defined(_WIN32) + return true; +#endif +} + +ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() { + static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = { + /* .iface = */ { + /* .get_name = */ ggml_backend_amx_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_amx_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_amx_buffer_type_get_alignment, + /* .get_max_size = */ nullptr, // defaults to SIZE_MAX + /* .get_alloc_size = */ ggml_backend_amx_buffer_type_get_alloc_size, + /* .is_host = */ nullptr, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ new ggml::cpu::amx::extra_buffer_type(), + }; + + if (!ggml_amx_init()) { + return nullptr; + } + + return &ggml_backend_buffer_type_amx; +} + +#endif // defined(__AMX_INT8__) && defined(__AVX512VNNI__) diff --git a/ggml/src/ggml-cpu/amx/amx.h b/ggml/src/ggml-cpu/amx/amx.h new file mode 100644 index 000000000..5b65d76bd --- /dev/null +++ b/ggml/src/ggml-cpu/amx/amx.h @@ -0,0 +1,8 @@ +#include "ggml-backend.h" +#include "ggml-cpu-impl.h" + +// GGML internal header + +#if defined(__AMX_INT8__) && defined(__AVX512VNNI__) +ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void); +#endif diff --git a/ggml/src/ggml-cpu/amx/common.h b/ggml/src/ggml-cpu/amx/common.h new file mode 100644 index 000000000..f392e8985 --- /dev/null +++ b/ggml/src/ggml-cpu/amx/common.h @@ -0,0 +1,91 @@ +#pragma once + +#include "ggml.h" +#include "ggml-cpu-impl.h" + +#include +#include +#include + +#if defined(GGML_USE_OPENMP) +#include +#endif + +#define TILE_M 16 +#define TILE_N 16 +#define TILE_K 32 +#define VNNI_BLK 4 + +#define AMX_BLK_SIZE 32 + +#define TMM0 0 +#define TMM1 1 +#define TMM2 2 +#define TMM3 3 +#define TMM4 4 +#define TMM5 5 +#define TMM6 6 +#define TMM7 7 + +// parallel routines +template ::value, int>::type = 0> +inline T div_up(T x, T y) { return (x + y - 1) / y; } + +template +inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) { +#if 0 + // onednn partition pattern + T& n_my = n_end; + if (nth <= 1 || n == 0) { + n_start = 0; + n_my = n; + } else { + T n1 = div_up(n, nth); + T n2 = n1 - 1; + T T1 = n - n2 * nth; + n_my = ith < T1 ? n1 : n2; + n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2; + } + n_end += n_start; +#else + // pytorch aten partition pattern + T n_my = div_up(n, nth); + n_start = ith * n_my; + n_end = std::min(n_start + n_my, n); +#endif +} + +template +inline void parallel_for(int n, const func_t& f) { +#if defined(GGML_USE_OPENMP) +#pragma omp parallel +{ + int nth = omp_get_num_threads(); + int ith = omp_get_thread_num(); + int tbegin, tend; + balance211(n, nth, ith, tbegin, tend); + f(tbegin, tend); +} +#else + f(0, n); +#endif +} + +template +inline void parallel_for_ggml(const ggml_compute_params * params, int n, const func_t & f) { + int tbegin, tend; + balance211(n, params->nth, params->ith, tbegin, tend); + f(tbegin, tend); +} + +// quantized types that have AMX support +inline bool qtype_has_amx_kernels(const enum ggml_type type) { + // TODO: fix padding for vnni format + return (type == GGML_TYPE_Q4_0) || + (type == GGML_TYPE_Q4_1) || + (type == GGML_TYPE_Q8_0) || + (type == GGML_TYPE_Q4_K) || + (type == GGML_TYPE_Q5_K) || + (type == GGML_TYPE_Q6_K) || + (type == GGML_TYPE_IQ4_XS); +} diff --git a/ggml/src/ggml-cpu/amx/mmq.cpp b/ggml/src/ggml-cpu/amx/mmq.cpp new file mode 100644 index 000000000..0ea91596b --- /dev/null +++ b/ggml/src/ggml-cpu/amx/mmq.cpp @@ -0,0 +1,2511 @@ + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Wpedantic" +#pragma GCC diagnostic ignored "-Wunused-local-typedefs" +#endif + +#include "amx.h" +#include "mmq.h" +#include "ggml-impl.h" +#include "ggml-cpu-impl.h" +#include "ggml-cpu-quants.h" +#include "ggml-quants.h" +#include +#include + +#if defined(__gnu_linux__) +#include +#include +#endif + +#if (defined(_WIN32) || defined(_WIN64)) +#define RESTRICT __restrict +#else +#define RESTRICT __restrict__ +#endif + +#if (defined(_WIN32) || defined(_WIN64)) +#define ALWAYS_INLINE __forceinline +#elif __has_attribute(always_inline) || defined(__GNUC__) +#define ALWAYS_INLINE __attribute__((__always_inline__)) inline +#else +#define ALWAYS_INLINE inline +#endif + +#if defined(__AMX_INT8__) && defined(__AVX512VNNI__) + +namespace { + +// Forced unrolling +template +struct Unroll { + template + ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + Unroll{}(f, args...); + f(std::integral_constant{}, args...); + } +}; + +template <> +struct Unroll<1> { + template + ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + f(std::integral_constant{}, args...); + } +}; + +// type traits +template struct PackedTypes {}; +template <> struct PackedTypes { using type = int8_t; }; +template <> struct PackedTypes { using type = uint8_t; }; +template <> struct PackedTypes { using type = int8_t; }; +template using packed_B_type = typename PackedTypes::type; + +template +struct do_compensate : std::integral_constant::value> {}; + +template +struct do_unpack : std::integral_constant::value || + std::is_same::value> {}; + +template +struct is_type_qkk : std::integral_constant::value || + std::is_same::value || + std::is_same::value || + std::is_same::value> {}; + +#define GGML_DISPATCH_FLOATING_TYPES(TYPE, ...) \ + [&] { \ + switch (TYPE) { \ + case GGML_TYPE_F16: { \ + using type = ggml_fp16_t; \ + constexpr int blck_size = 16; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_BF16: { \ + using type = ggml_bf16_t; \ + constexpr int blck_size = 32; \ + return __VA_ARGS__(); \ + } \ + default: \ + fprintf(stderr, "Unsupported floating data type\n"); \ + } \ + }() + +#define GGML_DISPATCH_QTYPES(QT, ...) \ + [&] { \ + switch (QT) { \ + case GGML_TYPE_Q4_0: { \ + using type = block_q4_0; \ + using vec_dot_type = block_q8_0; \ + constexpr int blck_size = QK4_0; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_Q4_1: { \ + using type = block_q4_1; \ + using vec_dot_type = block_q8_1; \ + constexpr int blck_size = QK4_1; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_Q8_0: { \ + using type = block_q8_0; \ + using vec_dot_type = block_q8_0; \ + constexpr int blck_size = QK8_0; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_Q4_K: { \ + using type = block_q4_K; \ + using vec_dot_type = block_q8_K; \ + constexpr int blck_size = QK_K; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_Q5_K: { \ + using type = block_q5_K; \ + using vec_dot_type = block_q8_K; \ + constexpr int blck_size = QK_K; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_Q6_K: { \ + using type = block_q6_K; \ + using vec_dot_type = block_q8_K; \ + constexpr int blck_size = QK_K; \ + return __VA_ARGS__(); \ + } \ + case GGML_TYPE_IQ4_XS: { \ + using type = block_iq4_xs; \ + using vec_dot_type = block_q8_K; \ + constexpr int blck_size = QK_K; \ + return __VA_ARGS__(); \ + } \ + default: \ + fprintf(stderr, "Unsupported quantized data type: %d\n", int(TYPE)); \ + } \ + }() + +#define GGML_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \ + [&] { \ + if (BOOL_V) { \ + constexpr bool BOOL_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool BOOL_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +// define amx tile config data structure +struct tile_config_t{ + uint8_t palette_id = 0; + uint8_t start_row = 0; + uint8_t reserved_0[14] = {0}; + uint16_t colsb[16] = {0}; + uint8_t rows[16] = {0}; +}; + +// Notes: amx tile config +// +// Typically, TMUL calculates A and B of size 16 x 64 containing INT8 values, +// and accumulate the result to a 16 x 16 matrix C containing INT32 values, +// +// As many GGUF quantized types as `block_size` of 32, so a 16-16-32 config is used +// instead of the normally used 16-16-64 config. +// +// Block A: {16, 32}, dtype = int8_t +// Block B: {16, 32}, dtype = uint8_t/int8_t +// Block C: {16, 16}, dtype = int32_t +// +// Block B needs to be prepacked to vnni format before feeding into TMUL: +// packed_B: from {n, k} to {k/vnni_blk, n, vnni_blck}, viewed in 2d, we get {8, 64} +// +// Therefore, we get tileconfig: +// A B C +// rows 16 8 16 +// colsb 32 64 16 +// +// For tile distribution, follow a 2-2-4 pattern, e.g. A used TMM2-TMM3, B used TMM0-TMM1, +// C used TMM4-TMM7: +// B TMM0 B TMM1 +// A TMM2 C TMM4 C TMM6 +// A TMM3 C TMM5 C TMM7 +// +// Each `amx` kernel handles 4 blocks at a time: 2MB * 2NB, when m < 2 * BLOCK_M, unpack A +// will be needed. +// +// Here another commonly used pattern 1-3-3 is skipped, as it is mostly used when m <=16; +// and the sinlge batch gemm (m=1) has a special fast path with `avx512-vnni`. +// +// ref: https://www.intel.com/content/www/us/en/developer/articles/code-sample/ +// advanced-matrix-extensions-intrinsics-functions.html +// + +#define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb +void ggml_tile_config_init(void) { + static thread_local bool is_first_time = true; + + if (!is_first_time) { + return; + } + + static thread_local tile_config_t tc; + tile_config_t current_tc; + _tile_storeconfig(¤t_tc); + + // load only when config changes + if (tc.palette_id == 0 || (memcmp(¤t_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 && + memcmp(¤t_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) { + tc.palette_id = 1; + tc.start_row = 0; + TC_CONFIG_TILE(TMM0, 8, 64); + TC_CONFIG_TILE(TMM1, 8, 64); + TC_CONFIG_TILE(TMM2, 16, 32); + TC_CONFIG_TILE(TMM3, 16, 32); + TC_CONFIG_TILE(TMM4, 16, 64); + TC_CONFIG_TILE(TMM5, 16, 64); + TC_CONFIG_TILE(TMM6, 16, 64); + TC_CONFIG_TILE(TMM7, 16, 64); + _tile_loadconfig(&tc); + } + + is_first_time = false; +} + +// we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation. +// See the notes `s8s8 igemm compensation in avx512-vnni` for detail. +template +int get_tile_size() { + int tile_size = TILE_N * sizeof(TB); + if (do_compensate::value) { + tile_size += TILE_N * sizeof(int32_t); + } + if (std::is_same::value || + std::is_same::value) { + tile_size += TILE_N * 4; + } + if (std::is_same::value) { + tile_size += TILE_N * 2; + } + return tile_size; +} + +template +int get_row_size(int K) { + int KB = K / BLOCK_K; + int row_size = KB * sizeof(TB); + if (do_compensate::value) { + row_size += KB * sizeof(int32_t); + } + if (std::is_same::value || + std::is_same::value) { + row_size += KB * 4; + } + if (std::is_same::value) { + row_size += KB * 2; + } + return row_size; +} + +// vectorized dtype conversion +inline float FP16_TO_FP32(ggml_half val) { + __m256i v = _mm256_setr_epi16( + val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); + __m512 o = _mm512_cvtph_ps(v); + return _mm512_cvtss_f32(o); +} + +inline __m512 FP16_TO_FP32_VEC(ggml_half val) { + __m256i v = _mm256_set1_epi16(val); + return _mm512_cvtph_ps(v); +} + +// horizontal reduce +inline float _mm512_reduce_max_ps(const __m512 x) { + __m512 v = x; + __m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E); + v = _mm512_max_ps(v, v1); + v1 = _mm512_shuffle_f32x4(v, v, 0xB1); + v = _mm512_max_ps(v, v1); + v1 = _mm512_shuffle_ps(v, v, 0x4E); + v = _mm512_max_ps(v, v1); + v1 = _mm512_shuffle_ps(v, v, 0xB1); + v = _mm512_max_ps(v, v1); + return _mm512_cvtss_f32(v); +} + +// transpose utils +#define SHUFFLE_EPI32(a, b, mask) \ + _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask)) +inline void transpose_8x8_32bit(__m256i * v, __m256i * v1) { + // unpacking and 32-bit elements + v1[0] = _mm256_unpacklo_epi32(v[0], v[1]); + v1[1] = _mm256_unpackhi_epi32(v[0], v[1]); + v1[2] = _mm256_unpacklo_epi32(v[2], v[3]); + v1[3] = _mm256_unpackhi_epi32(v[2], v[3]); + v1[4] = _mm256_unpacklo_epi32(v[4], v[5]); + v1[5] = _mm256_unpackhi_epi32(v[4], v[5]); + v1[6] = _mm256_unpacklo_epi32(v[6], v[7]); + v1[7] = _mm256_unpackhi_epi32(v[6], v[7]); + + // shuffling the 32-bit elements + v[0] = SHUFFLE_EPI32(v1[0], v1[2], 0x44); + v[1] = SHUFFLE_EPI32(v1[0], v1[2], 0xee); + v[2] = SHUFFLE_EPI32(v1[4], v1[6], 0x44); + v[3] = SHUFFLE_EPI32(v1[4], v1[6], 0xee); + v[4] = SHUFFLE_EPI32(v1[1], v1[3], 0x44); + v[5] = SHUFFLE_EPI32(v1[1], v1[3], 0xee); + v[6] = SHUFFLE_EPI32(v1[5], v1[7], 0x44); + v[7] = SHUFFLE_EPI32(v1[5], v1[7], 0xee); + + // shuffling 128-bit elements + v1[0] = _mm256_permute2f128_si256(v[2], v[0], 0x02); + v1[1] = _mm256_permute2f128_si256(v[3], v[1], 0x02); + v1[2] = _mm256_permute2f128_si256(v[6], v[4], 0x02); + v1[3] = _mm256_permute2f128_si256(v[7], v[5], 0x02); + v1[4] = _mm256_permute2f128_si256(v[2], v[0], 0x13); + v1[5] = _mm256_permute2f128_si256(v[3], v[1], 0x13); + v1[6] = _mm256_permute2f128_si256(v[6], v[4], 0x13); + v1[7] = _mm256_permute2f128_si256(v[7], v[5], 0x13); +} + +inline void transpose_16x4_32bit(__m512i * r, __m512i * d) { + + static const __m512i index1 = _mm512_set_epi32( + 0x0f, 0x0b, 0x07, 0x03, + 0x0e, 0x0a, 0x06, 0x02, + 0x0d, 0x09, 0x05, 0x01, + 0x0c, 0x08, 0x04, 0x00); + + d[0] = _mm512_permutexvar_epi32(index1, r[0]); + d[1] = _mm512_permutexvar_epi32(index1, r[1]); + d[2] = _mm512_permutexvar_epi32(index1, r[2]); + d[3] = _mm512_permutexvar_epi32(index1, r[3]); + + r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44); + r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee); + r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44); + r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee); + + d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88); + d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd); + d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88); + d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd); +} + +inline void transpose_16x16_32bit(__m512i * v) { + __m512i v1[16]; + v1[0] = _mm512_unpacklo_epi32(v[0], v[1]); + v1[1] = _mm512_unpackhi_epi32(v[0], v[1]); + v1[2] = _mm512_unpacklo_epi32(v[2], v[3]); + v1[3] = _mm512_unpackhi_epi32(v[2], v[3]); + v1[4] = _mm512_unpacklo_epi32(v[4], v[5]); + v1[5] = _mm512_unpackhi_epi32(v[4], v[5]); + v1[6] = _mm512_unpacklo_epi32(v[6], v[7]); + v1[7] = _mm512_unpackhi_epi32(v[6], v[7]); + v1[8] = _mm512_unpacklo_epi32(v[8], v[9]); + v1[9] = _mm512_unpackhi_epi32(v[8], v[9]); + v1[10] = _mm512_unpacklo_epi32(v[10], v[11]); + v1[11] = _mm512_unpackhi_epi32(v[10], v[11]); + v1[12] = _mm512_unpacklo_epi32(v[12], v[13]); + v1[13] = _mm512_unpackhi_epi32(v[12], v[13]); + v1[14] = _mm512_unpacklo_epi32(v[14], v[15]); + v1[15] = _mm512_unpackhi_epi32(v[14], v[15]); + + v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]); + v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]); + v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]); + v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]); + v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]); + v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]); + v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]); + v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]); + v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]); + v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]); + v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]); + v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]); + v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]); + v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]); + v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]); + v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]); + + v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88); + v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88); + v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88); + v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88); + v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd); + v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd); + v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd); + v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd); + v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88); + v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88); + v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88); + v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88); + v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd); + v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd); + v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd); + v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd); + + v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88); + v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88); + v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88); + v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88); + v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88); + v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88); + v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88); + v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88); + v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd); + v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd); + v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd); + v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd); + v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd); + v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd); + v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd); + v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd); +} + +void quantize_row_q8_K_vnni(const float * RESTRICT x, void * RESTRICT vy, int64_t k) { + assert(k % QK_K == 0); + const int KB = k / QK_K; + constexpr int kVecs = QK_K / 16; + + block_q8_K * y = reinterpret_cast(vy); + + // hold 16 float vecs from x + __m512 v[kVecs]; + + // hold the quants vecs + __m512i vq[kVecs / 4]; + + // hold the packed quants vecs + __m512i vq_packed[kVecs / 4]; + + const __m512 signBit = _mm512_set1_ps(-0.f); + + for (int i = 0; i < KB; ++i) { + // Compute max(abs(e)) for the block + __m512 vamax = _mm512_set1_ps(0.f); + for (int j = 0; j < kVecs; ++j) { + v[j] = _mm512_loadu_ps(x); x += 16; + vamax = _mm512_max_ps(vamax, _mm512_andnot_ps(signBit, v[j])); + } + const float amax = _mm512_reduce_max_ps(vamax); + + // Quantize these floats + const float iscale = 127.f / amax; + y[i].d = GGML_FP32_TO_FP16(1 / iscale); + const float id = ( amax != 0.0f ) ? iscale : 0.f; + const __m512 vscale = _mm512_set1_ps(id); + + // Apply multiplier and round to nearest integer + for (int j = 0; j < kVecs; ++j) { + v[j] = _mm512_mul_ps(v[j], vscale); + v[j] = _mm512_roundscale_ps(v[j], (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + + // Pack to epi8 vecs + for (int j = 0; j < kVecs / 4; ++j) { + __m128i q8_0 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 0])); + __m128i q8_1 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 1])); + __m128i q8_2 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 2])); + __m128i q8_3 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 3])); + + __m256i q8_01 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_0), (q8_1), 1); + __m256i q8_23 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_2), (q8_3), 1); + + vq[j] = _mm512_inserti32x8(_mm512_castsi256_si512(q8_01), q8_23, 1); + _mm512_storeu_si512((__m512i *)(y[i].qs + j * 64), vq[j]); + } + + // Compute the bsums with vnni + transpose_16x4_32bit(vq, vq_packed); + + const __m512i one = _mm512_set1_epi8(1); + __m512i sum = _mm512_setzero_si512(); + for (int k = 0; k < 4; ++k) { + sum = _mm512_dpbusd_epi32(sum, one, vq_packed[k]); + } + _mm256_storeu_si256((__m256i *)(y[i].bsums), _mm512_cvtepi32_epi16(sum)); + } +} + +// quantize A from float to `vec_dot_type` +template +inline void from_float(const float * x, char * vy, int64_t k); + +template <> +inline void from_float(const float * x, char * vy, int64_t k) { + quantize_row_q8_0(x, (block_q8_0 *)vy, k); +} + +template <> +inline void from_float(const float * x, char * vy, int64_t k) { + quantize_row_q8_1(x, (block_q8_1 *)vy, k); +} + +template <> +inline void from_float(const float * x, char * vy, int64_t k) { +#if 1 + // TODO: this is reference impl! + quantize_row_q8_K_ref(x, (block_q8_K *)vy, k); +#else + quantize_row_q8_K_vnni(x, vy, k); +#endif +} + +// load A from memory to array when nrows can not fill in whole tile +void unpack_A(int8_t * RESTRICT tile, const block_q8_0 * RESTRICT A, int lda, int nr) { + assert(nr != TILE_M); + for (int m = 0; m < nr; ++m) { + const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs)); + _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v); + } +} + +void unpack_A(int8_t * RESTRICT tile, const block_q8_1 * RESTRICT A, int lda, int nr) { + assert(nr != TILE_M); + for (int m = 0; m < nr; ++m) { + const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs)); + _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v); + } +} + +template +void unpack_A(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) { + assert(nr <= TILE_M); + for (int m = 0; m < nr; ++m) { + const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs + k * 32)); + _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v); + } +} + +template <> +void unpack_A(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) { + assert(nr <= TILE_M); + // zero padding k from 16 to 32, so that we don't have to re-config amx + const __m128i zero = _mm_setzero_si128(); + for (int m = 0; m < nr; ++m) { + const __m128i v = _mm_loadu_si128((const __m128i *)(A[m * lda].qs + k * 16)); + const __m256i r = _mm256_insertf128_si256(_mm256_castsi128_si256(v), zero, 1); + _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), r); + } +} + +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) +inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) { + const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); + const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); + const __m256i lowMask = _mm256_set1_epi8(0xF); + return _mm256_and_si256(lowMask, bytes); +} + +// used for block_q4_K +inline __m512i bytes_from_nibbles_64(const uint8_t * rsi) { + const __m256i tmp = _mm256_loadu_si256((const __m256i *)rsi); + const __m256i lowMask = _mm256_set1_epi8(0xF); + const __m256i q4l = _mm256_and_si256(tmp, lowMask); + const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(tmp, 4), lowMask); + return _mm512_inserti32x8(_mm512_castsi256_si512(q4l), q4h, 1); +} + +// used for block_q5_K +inline __m512i bytes_from_nibbles_64(const uint8_t * qs, const uint8_t * qh, int k) { + const __m256i lowMask = _mm256_set1_epi8(0xF); + __m256i hmask = _mm256_set1_epi8(1); + hmask = _mm256_slli_epi16(hmask, k); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i *)qs); + const __m256i hbits = _mm256_loadu_si256((const __m256i *)qh); + + const __m256i q5l_0 = _mm256_and_si256(q5bits, lowMask); + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 0), 4); + const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), lowMask); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 1), 4); + const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); + + return _mm512_inserti32x8(_mm512_castsi256_si512(q5_0), q5_1, 1); +} + +// used for block_q6_K +inline void bytes_from_nibbles_128(__m512i& r0, __m512i& r1, const uint8_t * qs, const uint8_t * qh) { + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(0x3); + + const __m256i q6bits1 = _mm256_loadu_si256((const __m256i *)qs); + const __m256i q6bits2 = _mm256_loadu_si256((const __m256i *)(qs + 32)); + const __m256i q6bitsH = _mm256_loadu_si256((const __m256i *)qh); + + const __m256i q6h_0 = _mm256_slli_epi16(_mm256_and_si256( q6bitsH, m2), 4); + const __m256i q6h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 2), m2), 4); + const __m256i q6h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 4), m2), 4); + const __m256i q6h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 6), m2), 4); + + const __m256i q6_0 = _mm256_or_si256(_mm256_and_si256(q6bits1, m4), q6h_0); + const __m256i q6_1 = _mm256_or_si256(_mm256_and_si256(q6bits2, m4), q6h_1); + const __m256i q6_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits1, 4), m4), q6h_2); + const __m256i q6_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits2, 4), m4), q6h_3); + + r0 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_0), q6_1, 1); + r1 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_2), q6_3, 1); +} + +inline __m512i packNibbles(__m512i r0, __m512i r1) { + return _mm512_or_si512(r0, _mm512_slli_epi16(r1, 4)); +} + +template +inline void pack_qs(void * RESTRICT packed_B, const TB * RESTRICT B, int KB) { + int8_t tmp[8 * 64]; + __m256i v[8], v2[8]; + for (int n = 0; n < 8; ++n) { + v[n] = bytes_from_nibbles_32(B[n * KB].qs); + } + transpose_8x8_32bit(v, v2); + for (int n = 0; n < 8; ++n) { + _mm256_storeu_si256((__m256i *)(tmp + n * 64), v2[n]); + } + for (int n = 0; n < 8; ++n) { + v[n] = bytes_from_nibbles_32(B[(n + 8) * KB].qs); + } + transpose_8x8_32bit(v, v2); + for (int n = 0; n < 8; ++n) { + _mm256_storeu_si256((__m256i *)(tmp + n * 64 + 32), v2[n]); + } + + // pack again with 128 to fully utilize vector length + for (int n = 0; n < 8; n += 2) { + __m512i r0 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64)); + __m512i r1 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64 + 64)); + __m512i r1r0 = packNibbles(r0, r1); + _mm512_storeu_si512((__m512i *)((char *)packed_B + n * 32), r1r0); + } +} + +template <> +inline void pack_qs(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) { + __m256i v[8], v2[8]; + for (int n = 0; n < 8; ++n) { + v[n] = _mm256_loadu_si256((const __m256i *)(B[n * KB].qs)); + } + transpose_8x8_32bit(v, v2); + for (int n = 0; n < 8; ++n) { + _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64), v2[n]); + } + for (int n = 0; n < 8; ++n) { + v[n] = _mm256_loadu_si256((const __m256i *)(B[(n + 8) * KB].qs)); + } + transpose_8x8_32bit(v, v2); + for (int n = 0; n < 8; ++n) { + _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64 + 32), v2[n]); + } +} + +template <> +inline void pack_qs(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) { + __m512i v[16]; + // QK_K 256 with 8 groups, handle 2 groups at a time + char * pb = (char *)packed_B; + for (int k = 0; k < QK_K / 64; ++k) { + // pack 2 groups { n, g, k} to {g, k/4, 4n} + // e.g. {16, 2, 32} to {2, 8, 64} + for (int n = 0; n < TILE_N; ++n) { + v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32); + } + + transpose_16x16_32bit(v); + + // pack again with 128 to fully utilize vector length + for (int n = 0; n < TILE_N; n += 2) { + _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1])); + pb += 64; + } + } +} + +template <> +inline void pack_qs(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) { + __m512i v[16]; + const __m512i lowMask = _mm512_set1_epi8(0xF); + // QK_K 256 with 8 groups, handle 2 groups at a time + char * pb = (char *)packed_B; + char * ph = (char *)packed_B + (QK_K / 2) * TILE_N; + for (int k = 0; k < QK_K / 64; ++k) { + // pack 2 groups { n, g, k} to {g, k/4, 4n} + // e.g. {16, 2, 32} to {2, 8, 64} + for (int n = 0; n < TILE_N; ++n) { + v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32, B[n * KB].qh, /* group */2 * k); + } + + transpose_16x16_32bit(v); + + // 1. pack lower 4bits with 2 groups + for (int n = 0; n < TILE_N; n += 2) { + // get lower 4 bits + const __m512i r0 = _mm512_and_si512(v[n], lowMask); + const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask); + _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64; + } + + // 2. pack higher 1bit with 2 groups + const __m512i hmask = _mm512_set1_epi8(0x10); + for (int g = 0; g < 2; ++g) { + __m512i hbits = _mm512_setzero_si512(); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 0], hmask), 4)); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 1], hmask), 3)); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 2], hmask), 2)); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 3], hmask), 1)); + hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 8 + 4], hmask) ); + hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 5], hmask), 1)); + hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 6], hmask), 2)); + hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 7], hmask), 3)); + _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64; + } + } +} + +template <> +inline void pack_qs(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) { + __m512i v[32]; + const __m512i lowMask = _mm512_set1_epi8(0xF); + // QK_K 256 with 8 groups, handle 4 groups at a time + char * pb = (char *)packed_B; + char * ph = (char *)packed_B + (QK_K / 2) * TILE_N; + for (int k = 0; k < QK_K / 128; ++k) { + for (int n = 0; n < TILE_N; ++n) { + bytes_from_nibbles_128(v[n], v[n + 16], B[n * KB].ql + k * 64, B[n * KB].qh + k * 32); + } + + // top half: group 0,1 or 4,5; bottom half: group 2,3 or 6,7 + transpose_16x16_32bit(v); + transpose_16x16_32bit(v + 16); + + // 1. pack lower 4bits with 4 groups + for (int n = 0; n < 32; n += 2) { + const __m512i r0 = _mm512_and_si512(v[n], lowMask); + const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask); + _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64; + } + + // 2. pack higher 2bit with 4 groups + const __m512i hmask = _mm512_set1_epi8(0x30); + for (int g = 0; g < 8; ++g) { + __m512i hbits = _mm512_setzero_si512(); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 0], hmask), 4)); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 1], hmask), 2)); + hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 4 + 2], hmask) ); + hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 4 + 3], hmask), 2)); + _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64; + } + } +} + +template <> +inline void pack_qs(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) { + __m512i v[16]; + char * pb = (char *)packed_B; + for (int k = 0; k < QK_K / 64; ++k) { + for (int n = 0; n < TILE_N; ++n) { + __m256i r0 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 0); + __m256i r1 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 16); + v[n] = _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1); + } + + transpose_16x16_32bit(v); + + // pack again with 128 to fully utilize vector length + for (int n = 0; n < TILE_N; n += 2) { + _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1])); + pb += 64; + } + } +} + +// pack B to vnni formats in 4bits or 8 bits +void pack_B(void * RESTRICT packed_B, const block_q4_0 * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + ggml_half * d0 = reinterpret_cast((char *)packed_B + TILE_N * TILE_K / 2); + for (int n = 0; n < TILE_N; ++n) { + d0[n] = B[n * KB].d; + } +} + +void pack_B(void * RESTRICT packed_B, const block_q4_1 * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + ggml_half * d0 = reinterpret_cast((char *)packed_B + TILE_N * TILE_K / 2); + ggml_half * m0 = d0 + TILE_N; + for (int n = 0; n < TILE_N; ++n) { + d0[n] = B[n * KB].d; + m0[n] = B[n * KB].m; + } +} + +inline void s8s8_compensation(void * RESTRICT packed_B) { + // packed_B layout: + // quants {TILE_N, TILEK} int8_t + // d0 {TILE_N} ggml_half + // comp {TILE_N} int32_t + const int offset = TILE_N * TILE_K + TILE_N * sizeof(ggml_half); + __m512i vcomp = _mm512_setzero_si512(); + const __m512i off = _mm512_set1_epi8(static_cast(0x80)); + for (int k = 0; k < 8; ++k) { + __m512i vb = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + k * 64)); + vcomp = _mm512_dpbusd_epi32(vcomp, off, vb); + } + _mm512_storeu_si512((__m512i *)((char *)(packed_B) + offset), vcomp); +} + +void pack_B(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + ggml_half * d0 = reinterpret_cast((char *)packed_B + TILE_N * TILE_K); + for (int n = 0; n < TILE_N; ++n) { + d0[n] = B[n * KB].d; + } + s8s8_compensation(packed_B); +} + +// convert 8 * {min, scale} from int6 to int8 +inline void unpack_mins_and_scales(const uint8_t * scales, uint32_t * utmp) { + const uint32_t kmask1 = 0x3f3f3f3f; + const uint32_t kmask2 = 0x0f0f0f0f; + const uint32_t kmask3 = 0x03030303; + + memcpy(utmp, scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; +} + +// packed_B layout: +// quants {8, TILE_N, 16} uint8 +// scales {8, TILE_N} uint8 +// mins {8, TILE_N} uint8 +// d {TILE_N} ggml_half +// dmin {TILE_N} ggml_half +void pack_B(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + + uint8_t * scales = reinterpret_cast((char *)packed_B + (QK_K / 2) * TILE_N); + uint8_t * mins = scales + 8 * TILE_N; + ggml_half * d = reinterpret_cast(mins + 8 * TILE_N); + ggml_half * dmin = d + TILE_N; + + union { + uint32_t u32[4]; + uint8_t u8[16]; + } s; + + for (int n = 0; n < TILE_N; ++n) { + unpack_mins_and_scales(B[n * KB].scales, s.u32); + for (int k = 0; k < 8; ++k) { + scales[k * TILE_N + n] = s.u8[k]; + mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8]; + } + d[n] = B[n * KB].d; + dmin[n] = B[n * KB].dmin; + } +} + +// packed_B layout: +// quants {8, TILE_N, 16} uint8 +// qh {8, TILE_N, 4} uint8 +// scales {8, TILE_N} uint8 +// mins {8, TILE_N} uint8 +// d {TILE_N} ggml_half +// dmin {TILE_N} ggml_half +void pack_B(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + + uint8_t * scales = reinterpret_cast((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N); + uint8_t * mins = scales + 8 * TILE_N; + ggml_half * d = reinterpret_cast(mins + 8 * TILE_N); + ggml_half * dmin = d + TILE_N; + + union { + uint32_t u32[4]; + uint8_t u8[16]; + } s; + + for (int n = 0; n < TILE_N; ++n) { + unpack_mins_and_scales(B[n * KB].scales, s.u32); + for (int k = 0; k < 8; ++k) { + scales[k * TILE_N + n] = s.u8[k]; + mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8]; + } + d[n] = B[n * KB].d; + dmin[n] = B[n * KB].dmin; + } +} + +// packed_B layout: +// quants {16, TILE_N, 8} uint8 +// qh {16, TILE_N, 4} uint8 +// scales {16, TILE_N} uint8 +// d {TILE_N} ggml_half +void pack_B(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + + uint8_t * scales = reinterpret_cast((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N); + ggml_half * d = reinterpret_cast(scales + 16 * TILE_N); + for (int n = 0; n < TILE_N; ++n) { + const int8_t * ps = B[n * KB].scales; + for (int k = 0; k < 16; ++k) { + scales[k * TILE_N + n] = ps[k]; + } + d[n] = B[n * KB].d; + } +} + +// packed_B layout: +// quants {8, TILE_N, 16} uint8 +// scales {8, TILE_N} int8 +// d {TILE_N} ggml_half +void pack_B(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + + int8_t * scales = reinterpret_cast((char *)packed_B + (QK_K / 2) * TILE_N); + ggml_half * d = reinterpret_cast(scales + 8 * TILE_N); + + // pack the scales + for (int n = 0; n < TILE_N; ++n) { + uint16_t sh = B[n * KB].scales_h; + for (int k = 0; k < 8; k += 2) { + const int16_t ls1 = ((B[n * KB].scales_l[k / 2] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int16_t ls2 = ((B[n * KB].scales_l[k / 2] >> 4) | ((sh << 2) & 0x30)) - 32; + scales[(k + 0) * TILE_N + n] = ls1; + scales[(k + 1) * TILE_N + n] = ls2; + sh >>= 4; + } + d[n] = B[n * KB].d; + } +} + +template> +void unpack_B(packed_B_t * RESTRICT tile, const void * RESTRICT packed_B) { + GGML_UNUSED(tile); + GGML_UNUSED(packed_B); +} + +template <> +void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B) { + const __m512i off = _mm512_set1_epi8(8); + const __m512i lowMask = _mm512_set1_epi8(0xF); + for (int n = 0; n < 8; n += 2) { + __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32)); + const __m512i r0 = _mm512_sub_epi8(_mm512_and_si512(bytes, lowMask), off); + const __m512i r1 = _mm512_sub_epi8(_mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask), off); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); + } +} + +template <> +void unpack_B(uint8_t * RESTRICT tile, const void * RESTRICT packed_B) { + const __m512i lowMask = _mm512_set1_epi8(0xF); + for (int n = 0; n < 8; n += 2) { + __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32)); + const __m512i r0 = _mm512_and_si512(bytes, lowMask); + const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); + } +} + +// packed_B_t for QKK is int8_t +template +void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { + const int packed_B_group_size = QK_K / 2 * TILE_N / 8; + const char * packed_B_group = (const char *)packed_B + k * packed_B_group_size; + const __m512i lowMask = _mm512_set1_epi8(0xF); + for (int n = 0; n < 8; n += 2) { + __m512i bytes = _mm512_loadu_si512(packed_B_group + n * 32); + const __m512i r0 = _mm512_and_si512(bytes, lowMask); + const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); + } +} + +template <> +void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { + // lower 4bits, stride 256 bytes + const int packed_l4_group_size = QK_K / 2 * TILE_N / 8; + const char * pb = (const char *)packed_B + k * packed_l4_group_size; + + // higher 1bit, stride 64 bytes + const int packed_h1_group_size = QK_K / 8 * TILE_N / 8; + const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h1_group_size; + const __m512i hbits = _mm512_loadu_si512(ph); + + const __m512i lowMask = _mm512_set1_epi8(0xF); + __m512i hmask0 = _mm512_set1_epi8(0x1); + __m512i hmask1 = _mm512_set1_epi8(0x2); + + for (int n = 0; n < 8; n += 2) { + __m512i bytes = _mm512_loadu_si512(pb + n * 32); + __m512i r0 = _mm512_and_si512(bytes, lowMask); + __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + __m512i h0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), n), 4); + __m512i h1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), n + 1), 4); + + hmask0 = _mm512_slli_epi16(hmask0, 2); + hmask1 = _mm512_slli_epi16(hmask1, 2); + r0 = _mm512_add_epi8(r0, h0); + r1 = _mm512_add_epi8(r1, h1); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); + } +} + +template <> +void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { + // lower 4bits, stride 128 bytes + const int packed_l4_group_size = QK_K / 2 * TILE_N / 16; + const char * pb = (const char *)packed_B + k * packed_l4_group_size; + + // higher 2bits, stride 64 bytes + const int packed_h2_group_size = QK_K / 4 * TILE_N / 16; + const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h2_group_size; + const __m512i hbits = _mm512_loadu_si512(ph); + + const __m512i off = _mm512_set1_epi8(32); + const __m512i lowMask = _mm512_set1_epi8(0xF); + __m512i hmask0 = _mm512_set1_epi8(0x3); // 0011 + __m512i hmask1 = _mm512_set1_epi8(0xC); // 1100 + + // notes: skip zero padding from row4 to row7 as we have done so in `unpack_A` + __m512i bytes = _mm512_loadu_si512(pb); + __m512i r0 = _mm512_and_si512(bytes, lowMask); + __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + __m512i h0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask0), 4); + __m512i h1 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask1), 2); + _mm512_storeu_si512((__m512i *)(tile + 0), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off)); + _mm512_storeu_si512((__m512i *)(tile + 64), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off)); + + hmask0 = _mm512_slli_epi16(hmask0, 4); + hmask1 = _mm512_slli_epi16(hmask1, 4); + + bytes = _mm512_loadu_si512(pb + 64); + r0 = _mm512_and_si512(bytes, lowMask); + r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + h0 = _mm512_and_si512(hbits, hmask0); + h1 = _mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), 2); + _mm512_storeu_si512((__m512i *)(tile + 128), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off)); + _mm512_storeu_si512((__m512i *)(tile + 192), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off)); +} + +template <> +void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { + static const __m512i values128 = _mm512_set_epi8( + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127 + ); + + const int packed_B_group_size = QK_K / 2 * TILE_N / 8; + const char * pb = (const char *)packed_B + k * packed_B_group_size; + const __m512i lowMask = _mm512_set1_epi8(0xF); + + for (int n = 0; n < 8; n += 2) { + __m512i bytes = _mm512_loadu_si512(pb + n * 32); + const __m512i r0 = _mm512_shuffle_epi8(values128, _mm512_and_si512(bytes, lowMask)); + const __m512i r1 = _mm512_shuffle_epi8(values128, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask)); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); + } +} + +template +struct acc_C {}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) { + const int offset = TILE_N * TILE_K / 2; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset))); + + for (int m = 0; m < nr; ++m) { + const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d)); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_1 * A, int lda, const void * packed_B, int nr) { + const int offset = TILE_N * TILE_K / 2; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset))); + const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset + TILE_N * sizeof(ggml_half)))); + + for (int m = 0; m < nr; ++m) { + const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d)); + const __m512 vs1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].s)); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum); + vsum = _mm512_fmadd_ps(vm0, vs1, vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) { + const int offset = TILE_N * TILE_K; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset))); + + for (int m = 0; m < nr; ++m) { + const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d)); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { + const uint8_t * scales = reinterpret_cast((const char *)packed_B + (QK_K / 2) * TILE_N); + const uint8_t * mins = scales + 8 * TILE_N; + const ggml_half * d0 = reinterpret_cast(mins + 8 * TILE_N); + const ggml_half * dmin = d0 + TILE_N; + + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); + const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin)); + + for (int m = 0; m < nr; ++m) { + const float d1 = A[m * lda].d; + const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); + const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + + __m512i acc_m = _mm512_setzero_si512(); + for (int k = 0; k < 4; ++k) { + __m512i vmask = _mm512_set1_epi32(k); + __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s)); + __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32))); + acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); + } + + vsum = _mm512_fmadd_ps(vtile, vd, vsum); + vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { + const uint8_t * scales = reinterpret_cast((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N); + const uint8_t * mins = scales + 8 * TILE_N; + const ggml_half * d0 = reinterpret_cast(mins + 8 * TILE_N); + const ggml_half * dmin = d0 + TILE_N; + + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); + const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin)); + + for (int m = 0; m < nr; ++m) { + const float d1 = A[m * lda].d; + const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); + const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + + __m512i acc_m = _mm512_setzero_si512(); + for (int k = 0; k < 4; ++k) { + __m512i vmask = _mm512_set1_epi32(k); + __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s)); + __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32))); + acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); + } + + vsum = _mm512_fmadd_ps(vtile, vd, vsum); + vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { + const uint8_t * scales = reinterpret_cast((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N); + const ggml_half * d0 = reinterpret_cast(scales + 16 * TILE_N); + + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); + + for (int m = 0; m < nr; ++m) { + const float d1 = A[m * lda].d; + const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + + vsum = _mm512_fmadd_ps(vtile, vd, vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { + const int8_t * scales = reinterpret_cast((const char *)packed_B + (QK_K / 2) * TILE_N); + const ggml_half * d0 = reinterpret_cast(scales + 8 * TILE_N); + + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); + + for (int m = 0; m < nr; ++m) { + const float d1 = A[m * lda].d; + const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + + vsum = _mm512_fmadd_ps(vtile, vd, vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template constexpr int get_quants_size(); +template <> constexpr int get_quants_size() { return (QK_K / 2) * TILE_N; } +template <> constexpr int get_quants_size() { return (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N; } +template <> constexpr int get_quants_size() { return (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N; } +template <> constexpr int get_quants_size() { return (QK_K / 2) * TILE_N; } + +// used for QKK format +template ::value, int>::type = 0> +inline void scale_C(const int32_t * RESTRICT tile, int32_t * RESTRICT sumi, const void * packed_B, int k, int nr) { + const uint8_t * scales = reinterpret_cast((const char *)packed_B + get_quants_size()); + const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(scales + k * TILE_N))); + + for (int m = 0; m < nr; ++m) { + __m512i vsumi; + if (is_acc) { + vsumi = _mm512_loadu_si512(sumi + m * TILE_N); + } else { + vsumi = _mm512_setzero_si512(); + } + __m512i vtile = _mm512_loadu_si512(tile + m * TILE_N); + vsumi = _mm512_add_epi32(vsumi, _mm512_mullo_epi32(vtile, vscale)); + _mm512_storeu_si512((__m512i *)(sumi + m * TILE_N), vsumi); + } +} + +template +struct tinygemm_kernel_avx { + static void apply(int K, const TA * RESTRICT A, const TB * RESTRICT B, TC * RESTRICT C, int ldc) { + GGML_UNUSED(K); + GGML_UNUSED(A); + GGML_UNUSED(B); + GGML_UNUSED(C); + GGML_UNUSED(ldc); + } +}; + +template +struct tinygemm_kernel_avx { + static void apply(int K, const float * RESTRICT A, const ggml_fp16_t * RESTRICT B, float * RESTRICT C, int ldc) { + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N; + assert(BLOCK_K == 16); + + __m512 va; + __m512 vb[COLS]; + __m512 vc[ROWS * COLS]; + + auto loadc = [&](auto idx) { + vc[idx] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](auto idx, auto k) { + constexpr int row = idx / COLS; + constexpr int col = idx % COLS; + + if constexpr (col == 0) { + va = _mm512_loadu_ps(A + row * K + k); + } + if constexpr (row == 0) { + vb[col] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(B + col * K + k))); + } + vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]); + }; + + for (int k = 0; k < K; k += 16) { + Unroll{}(compute, k); + } + + auto storec = [&](auto idx) { + constexpr int row = idx / COLS; + constexpr int col = idx % COLS; + C[row * ldc + col] = _mm512_reduce_add_ps(vc[idx]); + }; + Unroll{}(storec); + } +}; + +#define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_avx::apply( \ + K, (const float *)src1->data + mb_start * K, \ + (const type *)src0->data + nb_start * K, \ + (float *)dst->data + mb_start * ldc + nb_start, ldc); + + +// re-organize in the format {NB, KB, TILE_SIZE}: +#define PACKED_INDEX(n, k, KB, tile_size) (n * KB + k) * tile_size + +template +void convert_B_packed_format(void * RESTRICT packed_B, const TB * RESTRICT B, int N, int K) { + const int NB = N / TILE_N; + const int KB = K / BLOCK_K; + const int TILE_SIZE = get_tile_size(); + + // parallel on NB should be enough + parallel_for(NB, [&](int begin, int end) { + for (int n = begin; n < end; ++n) { + for (int k = 0; k < KB; ++k) { + int n0 = n * TILE_N; + pack_B((char *)packed_B + PACKED_INDEX(n, k, KB, TILE_SIZE), &B[n0 * KB + k], KB); + } + } + }); +} + +template +struct tinygemm_kernel_vnni {}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q4_0); + + const block_q8_0 * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + __m512i va[8]; + __m512 vc[COLS]; + __m512 vd1; + + // sum of offsets, shared across COLS + // + // avx512-vnni does not have `_mm512_dpbssd_epi32`, + // need to transfrom ss to us: + // a * (b - 8) is equavilent to b * a - 8 * a + // s u u u s u s + // + __m512i vcomp; + + const __m512i off = _mm512_set1_epi8(8); + const __m512i lowMask = _mm512_set1_epi8(0xF); + + auto loadc = [&](auto col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](auto col, auto i) { + // load a and compute compensation + if constexpr (col == 0) { + const int32_t * a_ptr = reinterpret_cast(A[0 * KB + i].qs); + vcomp = _mm512_setzero_si512(); + for (int k = 0; k < 8; ++k) { + va[k] = _mm512_set1_epi32(a_ptr[k]); + vcomp = _mm512_dpbusd_epi32(vcomp, off, va[k]); + } + vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d)); + } + + // load b + __m512i vsum = _mm512_setzero_si512(); + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + for (int k = 0; k < 8; k += 2) { + __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32)); + __m512i vb0 = _mm512_and_si512(bytes, lowMask); + vsum = _mm512_dpbusd_epi32(vsum, vb0, va[k + 0]); + __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va[k + 1]); + } + const int offset = TILE_N * TILE_K / 2; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset))); + vsum = _mm512_sub_epi32(vsum, vcomp); + + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](auto col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q4_1); + + const block_q8_1 * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + __m512i va[8]; + __m512i vb[8]; + __m512 vc[COLS]; + __m512 vd1, vs1; + + const __m512i lowMask = _mm512_set1_epi8(0xF); + + auto loadc = [&](auto col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](auto col, auto i) { + // load a + if constexpr (col == 0) { + const int32_t * a_ptr = reinterpret_cast(A[0 * KB + i].qs); + for (int k = 0; k < 8; ++k) { + va[k] = _mm512_set1_epi32(a_ptr[k]); + } + vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d)); + vs1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].s)); + } + + // load b + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + for (int k = 0; k < 8; k += 2) { + __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32)); + vb[k + 0] = _mm512_and_si512(bytes, lowMask); + vb[k + 1] = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + } + const int offset = TILE_N * TILE_K / 2; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset))); + const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset + TILE_N * sizeof(ggml_half)))); + + __m512i vsum = _mm512_setzero_si512(); + for (int k = 0; k < 8; ++k) { + vsum = _mm512_dpbusd_epi32(vsum, vb[k], va[k]); + } + + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]); + vc[col] = _mm512_fmadd_ps(vm0, vs1, vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](auto col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q8_0) + TILE_N * sizeof(int32_t); + + const block_q8_0 * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + __m512i va[8]; + __m512i vb[8]; + __m512 vc[COLS]; + __m512 vd1; + + // Notes: s8s8 igemm compensation in avx512-vnni + // change s8s8 to u8s8 with compensate + // a * b = (a + 128) * b - 128 * b + // s s u s u s + // + // (128 * b is pre-computed when packing B to vnni formats) + // + const __m512i off = _mm512_set1_epi8(static_cast(0x80)); + + auto loadc = [&](auto col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](auto col, auto i) { + // load a and add offset 128 + if constexpr (col == 0) { + const int32_t * a_ptr = reinterpret_cast(A[0 * KB + i].qs); + for (int k = 0; k < 8; ++k) { + va[k] = _mm512_set1_epi32(a_ptr[k]); + va[k] = _mm512_add_epi8(va[k], off); + } + vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d)); + } + + // load b + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + for (int k = 0; k < 8; ++k) { + vb[k] = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 64)); + } + const int offset = TILE_N * TILE_K; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset))); + const int offset2 = TILE_N * TILE_K + TILE_N * sizeof(ggml_half); + const __m512i vcomp = _mm512_loadu_si512((const __m512i *)(b_ptr + offset2)); + + __m512i vsum = _mm512_setzero_si512(); + for (int k = 0; k < 8; ++k) { + vsum = _mm512_dpbusd_epi32(vsum, va[k], vb[k]); + } + vsum = _mm512_sub_epi32(vsum, vcomp); + + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](auto col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q4_K) + TILE_N * 4; + + const block_q8_K * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + // a.qs: 8 groups, 32 bytes each group (m256i) + __m512i va[8]; + // a.bsum: 8 groups, 2 bytes each group (m128i) + __m512i va_bsum; + __m512 vc[COLS]; + __m512 vd1; + + // packed_B: + const int offset_scales = (QK_K / 2) * TILE_N; + const int offset_mins = (QK_K / 2) * TILE_N + 8 * TILE_N; + const int offset_d0 = (QK_K / 2) * TILE_N + 16 * TILE_N; + const int offset_dmin = (QK_K / 2) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half); + + const __m512i lowMask = _mm512_set1_epi8(0xF); + + auto loadc = [&](auto col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + // Notes: vnni formats in QK_K + // a) quants vnni format + // int8 {k/4, n, 4}, viewed as 2d {k/4, 4n}, k = 32 + // from {16, 32} to {8, 64} + // + // b) min vnni format + // int16 {k/2, n, 2}, viewed as 2d {k/2, 2n}, k = 8 + // from {16, 8} to {4, 32} + // + auto compute = [&](auto col, auto i) { + // load a + if constexpr (col == 0) { + for (int k_group = 0; k_group < QK_K / 32; ++k_group) { + va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32))); + } + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + va_bsum = _mm512_castsi128_si512(q8s); + vd1 = _mm512_set1_ps(A[0 * KB + i].d); + } + + // step 1: accumultate the quants + __m512i acc = _mm512_setzero_si512(); + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + const char * b_qs = b_ptr; + for (int k_group = 0; k_group < QK_K / 32; ++k_group) { + __m512i vsum = _mm512_setzero_si512(); + for (int k = 0; k < 8; k += 2) { + __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]); + __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]); + + __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs); + __m512i vb0 = _mm512_and_si512(bytes, lowMask); + vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); + __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); + + b_qs += 64; + } + // vacc += scale * (q8 @ q4) + const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); + acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); + } + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); + + // step 2: accumulate the mins + __m512i acc_m = _mm512_setzero_si512(); + for (int k = 0; k < 4; ++k) { + __m512i vmask = _mm512_set1_epi32(k); + __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum); + __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32))); + acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); + } + const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin))); + vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](auto col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q5_K) + TILE_N * 4; + + const block_q8_K * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + // a.qs: 8 groups, 32 bytes each group (m256i) + __m512i va[8]; + // a.bsum: 8 groups, 2 bytes each group (m128i) + __m512i va_bsum; + __m512 vc[COLS]; + __m512 vd1; + + // packed_B: + const int offset_qh = (QK_K / 2) * TILE_N; + const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N; + const int offset_mins = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 8 * TILE_N; + const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N; + const int offset_dmin = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half); + + const __m512i lowMask = _mm512_set1_epi8(0xF); + + auto loadc = [&](auto col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + // Q5_K and Q4_K shares the same vnni formats, refer to notes above. + auto compute = [&](auto col, auto i) { + // load a + if constexpr (col == 0) { + for (int k_group = 0; k_group < QK_K / 32; ++k_group) { + va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32))); + } + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + va_bsum = _mm512_castsi128_si512(q8s); + vd1 = _mm512_set1_ps(A[0 * KB + i].d); + } + + // step 1: accumultate the quants + __m512i acc = _mm512_setzero_si512(); + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + const char * b_qs = b_ptr; + const char * b_qh = b_ptr + offset_qh; + for (int k_group = 0; k_group < QK_K / 32; ++k_group) { + __m512i vsum = _mm512_setzero_si512(); + __m512i hmask0 = _mm512_set1_epi8(0x1); + __m512i hmask1 = _mm512_set1_epi8(0x2); + __m512i hbits = _mm512_loadu_si512((const __m512i *)(b_qh + k_group * 64)); + for (int k = 0; k < 8; k += 2) { + __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]); + __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]); + + __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs); + __m512i vb0 = _mm512_and_si512(bytes, lowMask); + __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + + __m512i vh0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), k), 4); + __m512i vh1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), k + 1), 4); + + hmask0 = _mm512_slli_epi16(hmask0, 2); + hmask1 = _mm512_slli_epi16(hmask1, 2); + vb0 = _mm512_add_epi8(vb0, vh0); + vb1 = _mm512_add_epi8(vb1, vh1); + + vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); + + b_qs += 64; + } + // vacc += scale * (q8 @ q5) + const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); + acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); + } + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); + + // step 2: accumulate the mins + __m512i acc_m = _mm512_setzero_si512(); + for (int k = 0; k < 4; ++k) { + __m512i vmask = _mm512_set1_epi32(k); + __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum); + __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32))); + acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); + } + const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin))); + vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](auto col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q6_K); + + const block_q8_K * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + // load the 256 bytes from A to 4 avx512 vectors + __m512i va[4]; + __m512 vc[COLS]; + __m512 vd1; + + // packed_B: + const int offset_qh = (QK_K / 2) * TILE_N; + const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N; + const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N + 16 * TILE_N; + + // compensation + __m512i vcomp; + + const __m512i m32s = _mm512_set1_epi32(32); + const __m512i lowMask = _mm512_set1_epi8(0xF); + + auto loadc = [&](auto col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](auto col, auto i) { + if constexpr (col == 0) { + // load a + va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0)); + va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64)); + va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128)); + va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192)); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); + vcomp = _mm512_mullo_epi32(_mm512_cvtepi16_epi32(q8sums), m32s); + vd1 = _mm512_set1_ps(A[0 * KB + i].d); + } + + // accmulate the quants + __m512i acc = _mm512_setzero_si512(); + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + const char * b_qs = b_ptr; + const char * b_qh = b_ptr + offset_qh; + int mask = 0; + for (int k_group = 0; k_group < QK_K / 16; ++k_group) { + int r = k_group >> 2; + __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + + __m512i vsum = _mm512_setzero_si512(); + __m512i hmask = _mm512_set1_epi8(0x3); + + __m512i bytes = _mm512_loadu_si512(b_qs); + __m512i hbits = _mm512_loadu_si512(b_qh); + __m512i vb0 = _mm512_and_si512(bytes, lowMask); + __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + __m512i vh0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask), 4); + __m512i vh1 = _mm512_slli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 2)), 2); + + vb0 = _mm512_add_epi8(vb0, vh0); + vb1 = _mm512_add_epi8(vb1, vh1); + vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); + b_qs += 64; + + va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + + bytes = _mm512_loadu_si512(b_qs); + vb0 = _mm512_and_si512(bytes, lowMask); + vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + vh0 = _mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 4)); + vh1 = _mm512_srli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 6)), 2); + vb0 = _mm512_add_epi8(vb0, vh0); + vb1 = _mm512_add_epi8(vb1, vh1); + vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); + b_qs += 64; + b_qh += 64; + + // B * A - 32 * A + __m512i vmask = _mm512_set1_epi32(k_group); + vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp)); + + // vacc += scale * (q8 @ q6) + const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); + acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); + } + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](int col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_iq4_xs) + TILE_N * 2; + + const block_q8_K * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + // load the 256 bytes from A to 4 avx512 vectors + __m512i va[4]; + __m512 vc[COLS]; + __m512 vd1; + + // packed_B: + const int offset_scales = (QK_K / 2) * TILE_N ; + const int offset_d0 = (QK_K / 2) * TILE_N + 8 * TILE_N; + + // compensation + __m512i vcomp; + + const __m256i m128s = _mm256_set1_epi16(128); + const __m512i lowMask = _mm512_set1_epi8(0xF); + + const __m512i values128 = _mm512_set_epi8( + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127 + ); + const __m512i off = _mm512_set1_epi8(static_cast(0x80)); + const __m512i values256 = _mm512_add_epi8(values128, off); + + auto loadc = [&](auto col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](auto col, auto i) { + if constexpr (col == 0) { + // load a + va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0)); + va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64)); + va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128)); + va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192)); + + // compensation: 128 * A + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); + vcomp = _mm512_castsi256_si512(_mm256_madd_epi16(q8sums, m128s)); + vd1 = _mm512_set1_ps(A[0 * KB + i].d); + } + + // accmulate the quants + __m512i acc = _mm512_setzero_si512(); + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + const char * b_qs = b_ptr; + int mask = 0; + for (int k_group = 0; k_group < QK_K / 32; ++k_group) { + int r = k_group >> 1; + __m512i vmask = _mm512_set1_epi32(k_group); + __m512i vsum = _mm512_setzero_si512(); + for (int k = 0; k < 8; k += 2) { + __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + + __m512i bytes = _mm512_loadu_si512(b_qs); + __m512i vb0 = _mm512_shuffle_epi8(values256, _mm512_and_si512(bytes, lowMask)); + __m512i vb1 = _mm512_shuffle_epi8(values256, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask)); + + vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); + b_qs += 64; + } + // (B + 128) * A - 128 * A + vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp)); + + // vacc += scale * (q8 @ q4) + const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); + acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); + } + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](auto col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \ + tinygemm_kernel_vnni::apply( \ + KB, (const char *)wdata + 0 * row_size_A, \ + (const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \ + (float *) dst->data + 0 * N + nb_start, ldc) + +template ::value, int>::type = 0> +void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, TC * RESTRICT C, int ldc) { + using packed_B_t = packed_B_type; + const int TILE_SIZE = get_tile_size(); + const bool need_unpack = do_unpack::value; + + GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N); + const TA * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + const int m0 = std::min(M, TILE_M); + const int m1 = std::max(M - TILE_M, 0); + const int lda = KB * sizeof(TA); + //const int ldb = KB * sizeof(TB); + + static thread_local packed_B_t Tile0[TILE_N * TILE_K]; + static thread_local packed_B_t Tile1[TILE_N * TILE_K]; + static thread_local int8_t Tile23[TILE_M * TILE_K]; + + static thread_local int32_t TileC0[TILE_M * TILE_N * 4]; + static thread_local int32_t TileC1[TILE_M * TILE_N * 4]; + + // double buffering C to interleave avx512 and amx + int32_t * C_cur = TileC0; + int32_t * C_pre = TileC1; + + auto Tile4 = [&](int32_t * base) { return base; }; + auto Tile5 = [&](int32_t * base) { return base + TILE_M * TILE_N; }; + auto Tile6 = [&](int32_t * base) { return base + 2 * TILE_M * TILE_N; }; + auto Tile7 = [&](int32_t * base) { return base + 3 * TILE_M * TILE_N; }; + + if (M == 2 * TILE_M) { + // i = 0 + const char * B_blk0 = B + PACKED_INDEX(0, 0, KB, TILE_SIZE); + const char * B_blk1 = B + PACKED_INDEX(1, 0, KB, TILE_SIZE); + if (need_unpack) { + unpack_B(Tile0, B_blk0); + _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK); + } + + _tile_zero(TMM4); + _tile_loadd(TMM2, A[0].qs, lda); + _tile_dpbssd(TMM4, TMM2, TMM0); + _tile_stored(TMM4, Tile4(C_pre), TILE_N * sizeof(int32_t)); + + _tile_zero(TMM5); + _tile_loadd(TMM3, A[TILE_M * KB + 0].qs, lda); + _tile_dpbssd(TMM5, TMM3, TMM0); + _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t)); + + if (need_unpack) { + unpack_B(Tile1, B_blk0); + _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); + } + + _tile_zero(TMM6); + _tile_dpbssd(TMM6, TMM2, TMM1); + _tile_stored(TMM6, Tile6(C_pre), TILE_N * sizeof(int32_t)); + + _tile_zero(TMM7); + _tile_dpbssd(TMM7, TMM3, TMM1); + _tile_stored(TMM7, Tile7(C_pre), TILE_N * sizeof(int32_t)); + + for (int i = 1; i < KB; ++i) { + // index of previous iter + const int ii = i - 1; + const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE); + const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE); + GGML_DISPATCH_BOOL(ii > 0, is_acc, [&] { + if (need_unpack) { + unpack_B(Tile0, B_blk0); + _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK); + } + _tile_zero(TMM4); + _tile_loadd(TMM2, A[i].qs, lda); + acc_C::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); + + _tile_dpbssd(TMM4, TMM2, TMM0); + _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t)); + + _tile_zero(TMM5); + _tile_loadd(TMM3, A[TILE_M * KB + i].qs, lda); + acc_C::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); + + _tile_dpbssd(TMM5, TMM3, TMM0); + _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t)); + + if (need_unpack) { + unpack_B(Tile1, B_blk1); + _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); + } + _tile_zero(TMM6); + acc_C::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); + + _tile_dpbssd(TMM6, TMM2, TMM1); + _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t)); + + _tile_zero(TMM7); + acc_C::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); + + _tile_dpbssd(TMM7, TMM3, TMM1); + _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t)); + + std::swap(C_cur, C_pre); + }); + } + // final accumulation + { + int ii = KB - 1; + acc_C::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); + acc_C::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); + acc_C::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); + acc_C::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); + } + } else { + for (int i = 0; i < KB; ++i) { + _tile_zero(TMM4); + _tile_zero(TMM6); + if (m1 != 0) { + _tile_zero(TMM5); + _tile_zero(TMM7); + } + + const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE); + const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE); + if (need_unpack) { + unpack_B(Tile0, B_blk0); + _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK); + } + + if (need_unpack) { + unpack_B(Tile1, B_blk1); + _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); + } + + if (m0 == TILE_M) { + _tile_loadd(TMM2, A[i].qs, lda); + } else { + unpack_A(Tile23, &A[i], KB, m0); + _tile_loadd(TMM2, Tile23, TILE_K); + } + + _tile_dpbssd(TMM4, TMM2, TMM0); + _tile_dpbssd(TMM6, TMM2, TMM1); + + _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t)); + _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t)); + + GGML_DISPATCH_BOOL(i > 0, is_acc, [&] { + acc_C::apply(C, ldc, Tile4(C_cur), &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0); + acc_C::apply(C + TILE_N, ldc, Tile6(C_cur), &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0); + }); + + if (m1 != 0) { + unpack_A(Tile23, &A[TILE_M * KB + i], KB, m1); + _tile_loadd(TMM3, Tile23, TILE_K); + + _tile_dpbssd(TMM5, TMM3, TMM0); + _tile_dpbssd(TMM7, TMM3, TMM1); + _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t)); + _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t)); + GGML_DISPATCH_BOOL(i > 0, is_acc, [&] { + acc_C::apply(C + TILE_M * ldc, ldc, Tile5(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1); + acc_C::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1); + }); + } + } + } + return; +} + +template ::value, int>::type = 0> +void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + static_assert(std::is_same::value); + const int TILE_SIZE = get_tile_size(); + + GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N); + const TA * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + const int m0 = std::min(M, TILE_M); + const int m1 = std::max(M - TILE_M, 0); + //const int lda = KB * sizeof(TA); + + static thread_local int8_t Tile0[TILE_N * TILE_K]; + static thread_local int8_t Tile1[TILE_N * TILE_K]; + static thread_local int8_t Tile23[TILE_M * TILE_K]; + + // mat mul result for each group + static thread_local int32_t Tile4[TILE_M * TILE_N]; + static thread_local int32_t Tile5[TILE_M * TILE_N]; + static thread_local int32_t Tile6[TILE_M * TILE_N]; + static thread_local int32_t Tile7[TILE_M * TILE_N]; + + // sum of each QK_K block, contains 8 groups, int32 + static thread_local int32_t Sumi4[TILE_M * TILE_N]; + static thread_local int32_t Sumi5[TILE_M * TILE_N]; + static thread_local int32_t Sumi6[TILE_M * TILE_N]; + static thread_local int32_t Sumi7[TILE_M * TILE_N]; + + const int k_group_size = std::is_same::value ? 16 : 32; + for (int i = 0; i < KB; ++i) { + // step 1: accumulate the quants across 8 groups, each group with 32 + for (int k = 0; k < QK_K / k_group_size; ++k) { + GGML_DISPATCH_BOOL(k > 0, is_acc, [&] { + _tile_zero(TMM4); + _tile_zero(TMM6); + + unpack_B(Tile0, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k); + _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); + + unpack_B(Tile1, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k); + _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); + + unpack_A(Tile23, &A[i], KB, k, m0); + _tile_loadd(TMM2, Tile23, TILE_K); + + _tile_dpbssd(TMM4, TMM2, TMM0); + _tile_dpbssd(TMM6, TMM2, TMM1); + + _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); + _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); + + scale_C(Tile4, Sumi4, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m0); + scale_C(Tile6, Sumi6, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m0); + + if (m1 != 0) { + _tile_zero(TMM5); + _tile_zero(TMM7); + + unpack_A(Tile23, &A[TILE_M * KB + i], KB, k, m1); + _tile_loadd(TMM3, Tile23, TILE_K); + + _tile_dpbssd(TMM5, TMM3, TMM0); + _tile_dpbssd(TMM7, TMM3, TMM1); + + _tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); + _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); + + scale_C(Tile5, Sumi5, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m1); + scale_C(Tile7, Sumi7, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m1); + } + }); + } + + // step 2: accmulate the mins + GGML_DISPATCH_BOOL(i > 0, is_acc, [&] { + acc_C::apply(C, ldc, Sumi4, &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0); + acc_C::apply(C + TILE_N, ldc, Sumi6, &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0); + if (m1 != 0) { + acc_C::apply(C + TILE_M * ldc, ldc, Sumi5, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1); + acc_C::apply(C + TILE_M * ldc + TILE_N, ldc, Sumi7, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1); + } + }); + } + return; +} + +} // anonymous namespace + +// get the packed tensor size for quantized weights +size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor) { + const enum ggml_type TYPE = tensor->type; + + const int K = tensor->ne[0]; // ne0: in_features + const int N = tensor->ne[1]; // ne1: out_features + + auto get_tensor_size = [&] { + size_t row_size_B{0}; + GGML_DISPATCH_QTYPES(TYPE, [&] { + row_size_B = get_row_size(K); + }); + return N * row_size_B; + }; + + if (qtype_has_amx_kernels(TYPE)) { + return get_tensor_size(); + } else { + // for f16, bf16 we don't do packing + return ggml_nbytes(tensor); + } +} + +// pack weight to vnni format +void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_ASSERT(offset == 0 && size == ggml_nbytes(tensor)); // only full tensor conversion is supported for now + + const enum ggml_type TYPE = tensor->type; + + const int K = tensor->ne[0]; // ne0: in_features + const int N = tensor->ne[1]; // ne1: out_features + + GGML_DISPATCH_QTYPES(TYPE, [&] { + convert_B_packed_format((void *)((char *)tensor->data + offset), (const type *)data, N, K); + }); +} + +size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) { + struct ggml_tensor * src0 = dst->src[0]; + + const enum ggml_type TYPE = src0->type; + + const bool is_floating_type = TYPE == GGML_TYPE_F16; + if (is_floating_type) { + return 0; + } + + const int M = dst->ne[1]; + const int K = src0->ne[0]; + + size_t desired_wsize = 0; + + GGML_DISPATCH_QTYPES(TYPE, [&] { + const size_t row_size_A = K / blck_size * sizeof(vec_dot_type); + desired_wsize = M * row_size_A; + }); + + return desired_wsize; +} + +// NB: mixed dtype gemm with Advanced Matrix Extensions (Intel AMX) +// +// src0: weight in shape of {N, K}, quantized +// src1: input in shape of {M, K}, float32 +// dst: output in shape of {M, N}, float32 +// +// the function performs: dst = src1 @ src0.T +// +void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_tensor * dst) { + struct ggml_tensor * src0 = dst->src[0]; + struct ggml_tensor * src1 = dst->src[1]; + + const enum ggml_type TYPE = src0->type; + + // f16 only has avx512 kernels for now, + // amx kernels will be added once 6th gen xeon is released. + const bool is_floating_type = TYPE == GGML_TYPE_F16; + + const int M = dst->ne[1]; + const int N = dst->ne[0]; + const int K = src0->ne[0]; + const int ldc = dst->nb[1] / dst->nb[0]; + + if (is_floating_type) { + constexpr int BLOCK_M = 4; + constexpr int BLOCK_N = 6; + const int MB = div_up(M, BLOCK_M); + const int NB = div_up(N, BLOCK_N); + + parallel_for_ggml(params, MB * NB, [&](int begin, int end) { + GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] { + for (int i = begin; i < end; ++i) { + int mb = i / NB; + int nb = i % NB; + + int mb_start = mb * BLOCK_M; + int mb_size = std::min(BLOCK_M, M - mb_start); + int nb_start = nb * BLOCK_N; + int nb_size = std::min(BLOCK_N, N - nb_start); + + switch (mb_size << 4 | nb_size) { + case 0x12: LAUNCH_TINYGEMM_KERNEL_AVX(1, 2); break; + case 0x14: LAUNCH_TINYGEMM_KERNEL_AVX(1, 4); break; + case 0x16: LAUNCH_TINYGEMM_KERNEL_AVX(1, 6); break; + case 0x22: LAUNCH_TINYGEMM_KERNEL_AVX(2, 2); break; + case 0x24: LAUNCH_TINYGEMM_KERNEL_AVX(2, 4); break; + case 0x26: LAUNCH_TINYGEMM_KERNEL_AVX(2, 6); break; + case 0x32: LAUNCH_TINYGEMM_KERNEL_AVX(3, 2); break; + case 0x34: LAUNCH_TINYGEMM_KERNEL_AVX(3, 4); break; + case 0x36: LAUNCH_TINYGEMM_KERNEL_AVX(3, 6); break; + case 0x42: LAUNCH_TINYGEMM_KERNEL_AVX(4, 2); break; + case 0x44: LAUNCH_TINYGEMM_KERNEL_AVX(4, 4); break; + case 0x46: LAUNCH_TINYGEMM_KERNEL_AVX(4, 6); break; + default: fprintf(stderr, "Unexpected block size!\n"); + } + } + }); + }); + return; + } + + // pointer to work space, used convert A from float to quantized type + void * wdata = params->wdata; + + //TODO: performance improvement: merge quant A + if (params->ith == 0) { + GGML_DISPATCH_QTYPES(TYPE, [&] { + const size_t row_size_A = K / blck_size * sizeof(vec_dot_type); + const size_t desired_wsize = M * row_size_A; + if (params->wsize < desired_wsize) { + GGML_ABORT("insufficient work space size"); + } + + // Q4_0, Q4_1, Q8_0 handles 1 TILE_K per blck_size + // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size + GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size); + + const float * A_data = static_cast(src1->data); + for (int m = 0; m < M; ++m) { + from_float(A_data + m * K, (char *)wdata + m * row_size_A, K); + } + }); + } + + ggml_barrier(params->threadpool); + + if (M == 1) { + // MB = 1 and handle 8 tiles in each block + constexpr int kTilesN = 4; + constexpr int BLOCK_N = TILE_N * kTilesN; + const int NB = div_up(N, BLOCK_N); + + parallel_for_ggml(params, NB, [&](int begin, int end) { + GGML_DISPATCH_QTYPES(TYPE, [&] { + const int KB = K / blck_size; + const int TILE_SIZE = get_tile_size(); + const int row_size_A = KB * sizeof(vec_dot_type); + for (int i = begin; i < end; ++i) { + int nb = i; + int nb_start = nb * BLOCK_N; + int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96 + + switch (nb_size) { + //case 160: LAUNCH_TINYGEMM_KERNEL_VNNI(160); break; + case 128: LAUNCH_TINYGEMM_KERNEL_VNNI(128); break; + case 96: LAUNCH_TINYGEMM_KERNEL_VNNI(96); break; + case 64: LAUNCH_TINYGEMM_KERNEL_VNNI(64); break; + case 32: LAUNCH_TINYGEMM_KERNEL_VNNI(32); break; + default: fprintf(stderr, "Unexpected n block size!\n"); + } + } + }); + }); + return; + } + + // handle 4 tiles at a tile + constexpr int BLOCK_M = TILE_M * 2; + constexpr int BLOCK_N = TILE_N * 2; + const int MB = div_up(M, BLOCK_M); + const int NB = div_up(N, BLOCK_N); + + parallel_for_ggml(params, MB * NB, [&](int begin, int end) { + // init tile config for each thread + ggml_tile_config_init(); + + GGML_DISPATCH_QTYPES(TYPE, [&] { + const int KB = K / blck_size; + const int TILE_SIZE = get_tile_size(); + const int row_size_A = KB * sizeof(vec_dot_type); + + for (int i = begin; i < end; ++i) { + int mb = i / NB; + int nb = i % NB; + + int mb_start = mb * BLOCK_M; + int mb_size = std::min(BLOCK_M, M - mb_start); + int nb_start = nb * BLOCK_N; + int nb_size = BLOCK_N; + + tinygemm_kernel_amx( + mb_size, nb_size, KB, + (const char *)wdata + mb_start * row_size_A, + (const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE), + (float *) dst->data + mb_start * N + nb_start, ldc); + } + }); + }); +} + +#endif // if defined(__AMX_INT8__) && defined(__AVX512VNNI__) diff --git a/ggml/src/ggml-cpu/amx/mmq.h b/ggml/src/ggml-cpu/amx/mmq.h new file mode 100644 index 000000000..baf768477 --- /dev/null +++ b/ggml/src/ggml-cpu/amx/mmq.h @@ -0,0 +1,10 @@ +#pragma once +#include "common.h" + +size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst); + +size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor); + +void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + +void ggml_backend_amx_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/cmake/FindSIMD.cmake b/ggml/src/ggml-cpu/cmake/FindSIMD.cmake similarity index 100% rename from ggml/cmake/FindSIMD.cmake rename to ggml/src/ggml-cpu/cmake/FindSIMD.cmake diff --git a/ggml/src/ggml-cpu/cpu-feats-x86.cpp b/ggml/src/ggml-cpu/cpu-feats-x86.cpp new file mode 100644 index 000000000..e8133d411 --- /dev/null +++ b/ggml/src/ggml-cpu/cpu-feats-x86.cpp @@ -0,0 +1,323 @@ +#include "ggml-backend-impl.h" + +#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64)) + +#ifdef _MSC_VER +#include +#endif + +#include +#include +#include +#include +#include + +// ref: https://cdrdv2-public.intel.com/782156/325383-sdm-vol-2abcd.pdf +struct cpuid_x86 { + bool SSE3(void) { return f_1_ecx[0]; } + bool PCLMULQDQ(void) { return f_1_ecx[1]; } + bool MONITOR(void) { return f_1_ecx[3]; } + bool SSSE3(void) { return f_1_ecx[9]; } + bool FMA(void) { return f_1_ecx[12]; } + bool CMPXCHG16B(void) { return f_1_ecx[13]; } + bool SSE41(void) { return f_1_ecx[19]; } + bool SSE42(void) { return f_1_ecx[20]; } + bool MOVBE(void) { return f_1_ecx[22]; } + bool POPCNT(void) { return f_1_ecx[23]; } + bool AES(void) { return f_1_ecx[25]; } + bool XSAVE(void) { return f_1_ecx[26]; } + bool OSXSAVE(void) { return f_1_ecx[27]; } + bool AVX(void) { return f_1_ecx[28]; } + bool F16C(void) { return f_1_ecx[29]; } + bool RDRAND(void) { return f_1_ecx[30]; } + + bool MSR(void) { return f_1_edx[5]; } + bool CX8(void) { return f_1_edx[8]; } + bool SEP(void) { return f_1_edx[11]; } + bool CMOV(void) { return f_1_edx[15]; } + bool CLFSH(void) { return f_1_edx[19]; } + bool MMX(void) { return f_1_edx[23]; } + bool FXSR(void) { return f_1_edx[24]; } + bool SSE(void) { return f_1_edx[25]; } + bool SSE2(void) { return f_1_edx[26]; } + + bool FSGSBASE(void) { return f_7_ebx[0]; } + bool BMI1(void) { return f_7_ebx[3]; } + bool HLE(void) { return is_intel && f_7_ebx[4]; } + bool AVX2(void) { return f_7_ebx[5]; } + bool BMI2(void) { return f_7_ebx[8]; } + bool ERMS(void) { return f_7_ebx[9]; } + bool INVPCID(void) { return f_7_ebx[10]; } + bool RTM(void) { return is_intel && f_7_ebx[11]; } + bool AVX512F(void) { return f_7_ebx[16]; } + bool AVX512DQ(void) { return f_7_ebx[17]; } + bool RDSEED(void) { return f_7_ebx[18]; } + bool ADX(void) { return f_7_ebx[19]; } + bool AVX512PF(void) { return f_7_ebx[26]; } + bool AVX512ER(void) { return f_7_ebx[27]; } + bool AVX512CD(void) { return f_7_ebx[28]; } + bool AVX512BW(void) { return f_7_ebx[30]; } + bool AVX512VL(void) { return f_7_ebx[31]; } + + bool SHA(void) { return f_7_ebx[29]; } + + bool PREFETCHWT1(void) { return f_7_ecx[0]; } + + bool LAHF(void) { return f_81_ecx[0]; } + bool LZCNT(void) { return is_intel && f_81_ecx[5]; } + bool ABM(void) { return is_amd && f_81_ecx[5]; } + bool SSE4a(void) { return is_amd && f_81_ecx[6]; } + bool XOP(void) { return is_amd && f_81_ecx[11]; } + bool TBM(void) { return is_amd && f_81_ecx[21]; } + + bool SYSCALL(void) { return is_intel && f_81_edx[11]; } + bool MMXEXT(void) { return is_amd && f_81_edx[22]; } + bool RDTSCP(void) { return is_intel && f_81_edx[27]; } + bool _3DNOWEXT(void) { return is_amd && f_81_edx[30]; } + bool _3DNOW(void) { return is_amd && f_81_edx[31]; } + + bool AVX512_VBMI(void) { return f_7_ecx[1]; } + bool AVX512_VNNI(void) { return f_7_ecx[11]; } + bool AVX512_FP16(void) { return f_7_edx[23]; } + bool AVX512_BF16(void) { return f_7_1_eax[5]; } + bool AVX_VNNI(void) { return f_7_1_eax[4]; } + + bool AMX_TILE(void) { return f_7_edx[24]; } + bool AMX_INT8(void) { return f_7_edx[25]; } + bool AMX_FP16(void) { return f_7_1_eax[21]; } + bool AMX_BF16(void) { return f_7_edx[22]; } + +#ifdef _MSC_VER + static void cpuid(int cpu_info[4], int eax) { + __cpuid(cpu_info, eax); + } + static void cpuidex(int cpu_info[4], int eax, int ecx) { + __cpuidex(cpu_info, eax, ecx); + } +#else + static void cpuid(int cpu_info[4], int eax) { + __asm__ __volatile__( + "cpuid" + : "=a"(cpu_info[0]), "=b"(cpu_info[1]), "=c"(cpu_info[2]), "=d"(cpu_info[3]) + : "a"(eax), "c"(0)); + } + static void cpuidex(int cpu_info[4], int eax, int ecx) { + __asm__ __volatile__( + "cpuid" + : "=a"(cpu_info[0]), "=b"(cpu_info[1]), "=c"(cpu_info[2]), "=d"(cpu_info[3]) + : "a"(eax), "c"(ecx)); + } +#endif + + cpuid_x86() { + std::array cpui; + std::vector> data; + + // calling __cpuid with 0x0 as the function_id argument + // gets the number of the highest valid function ID. + cpuid(cpui.data(), 0); + int n_ids = cpui[0]; + + for (int i = 0; i <= n_ids; ++i) { + cpuidex(cpui.data(), i, 0); + data.push_back(cpui); + } + + // capture vendor string + char vendor[0x20] = {}; + *reinterpret_cast(vendor) = data[0][1]; + *reinterpret_cast(vendor + 4) = data[0][3]; + *reinterpret_cast(vendor + 8) = data[0][2]; + this->vendor = vendor; + if (this->vendor == "GenuineIntel") { + is_intel = true; + } else if (this->vendor == "AuthenticAMD") { + is_amd = true; + } + + // load bitset with flags for function 0x00000001 + if (n_ids >= 1) { + f_1_ecx = data[1][2]; + f_1_edx = data[1][3]; + } + + // load bitset with flags for function 0x00000007 + if (n_ids >= 7) { + f_7_ebx = data[7][1]; + f_7_ecx = data[7][2]; + f_7_edx = data[7][3]; + cpuidex(cpui.data(), 7, 1); + f_7_1_eax = cpui[0]; + } + + // calling __cpuid with 0x80000000 as the function_id argument + // gets the number of the highest valid extended ID. + cpuid(cpui.data(), 0x80000000); + unsigned int n_ex_ids = cpui[0]; + + std::vector> ext_data; + for (unsigned int i = 0x80000000; i <= n_ex_ids; ++i) { + cpuidex(cpui.data(), i, 0); + ext_data.push_back(cpui); + } + + // load bitset with flags for function 0x80000001 + if (n_ex_ids >= 0x80000001) { + f_81_ecx = ext_data[1][2]; + f_81_edx = ext_data[1][3]; + } + + // interpret CPU brand string if reported + char brand[0x40] = {}; + if (n_ex_ids >= 0x80000004) { + std::memcpy(brand, ext_data[2].data(), sizeof(cpui)); + std::memcpy(brand + 16, ext_data[3].data(), sizeof(cpui)); + std::memcpy(brand + 32, ext_data[4].data(), sizeof(cpui)); + this->brand = brand; + } + } + + bool is_intel = false; + bool is_amd = false; + std::string vendor; + std::string brand; + std::bitset<32> f_1_ecx; + std::bitset<32> f_1_edx; + std::bitset<32> f_7_ebx; + std::bitset<32> f_7_ecx; + std::bitset<32> f_7_edx; + std::bitset<32> f_7_1_eax; + std::bitset<32> f_81_ecx; + std::bitset<32> f_81_edx; +}; + +#if 0 +void test_x86_is() { + cpuid_x86 is; + printf("CPU Vendor: %s\n", is.vendor.c_str()); + printf("Brand: %s\n", is.brand.c_str()); + printf("is_intel: %d\n", is.is_intel); + printf("is_amd: %d\n", is.is_amd); + printf("sse3: %d\n", is.SSE3()); + printf("pclmulqdq: %d\n", is.PCLMULQDQ()); + printf("ssse3: %d\n", is.SSSE3()); + printf("fma: %d\n", is.FMA()); + printf("cmpxchg16b: %d\n", is.CMPXCHG16B()); + printf("sse41: %d\n", is.SSE41()); + printf("sse42: %d\n", is.SSE42()); + printf("movbe: %d\n", is.MOVBE()); + printf("popcnt: %d\n", is.POPCNT()); + printf("aes: %d\n", is.AES()); + printf("xsave: %d\n", is.XSAVE()); + printf("osxsave: %d\n", is.OSXSAVE()); + printf("avx: %d\n", is.AVX()); + printf("f16c: %d\n", is.F16C()); + printf("rdrand: %d\n", is.RDRAND()); + printf("msr: %d\n", is.MSR()); + printf("cx8: %d\n", is.CX8()); + printf("sep: %d\n", is.SEP()); + printf("cmov: %d\n", is.CMOV()); + printf("clflush: %d\n", is.CLFSH()); + printf("mmx: %d\n", is.MMX()); + printf("fxsr: %d\n", is.FXSR()); + printf("sse: %d\n", is.SSE()); + printf("sse2: %d\n", is.SSE2()); + printf("fsgsbase: %d\n", is.FSGSBASE()); + printf("bmi1: %d\n", is.BMI1()); + printf("hle: %d\n", is.HLE()); + printf("avx2: %d\n", is.AVX2()); + printf("bmi2: %d\n", is.BMI2()); + printf("erms: %d\n", is.ERMS()); + printf("invpcid: %d\n", is.INVPCID()); + printf("rtm: %d\n", is.RTM()); + printf("avx512f: %d\n", is.AVX512F()); + printf("rdseed: %d\n", is.RDSEED()); + printf("adx: %d\n", is.ADX()); + printf("avx512pf: %d\n", is.AVX512PF()); + printf("avx512er: %d\n", is.AVX512ER()); + printf("avx512cd: %d\n", is.AVX512CD()); + printf("sha: %d\n", is.SHA()); + printf("prefetchwt1: %d\n", is.PREFETCHWT1()); + printf("lahf: %d\n", is.LAHF()); + printf("lzcnt: %d\n", is.LZCNT()); + printf("abm: %d\n", is.ABM()); + printf("sse4a: %d\n", is.SSE4a()); + printf("xop: %d\n", is.XOP()); + printf("tbm: %d\n", is.TBM()); + printf("syscall: %d\n", is.SYSCALL()); + printf("mmxext: %d\n", is.MMXEXT()); + printf("rdtscp: %d\n", is.RDTSCP()); + printf("3dnowext: %d\n", is._3DNOWEXT()); + printf("3dnow: %d\n", is._3DNOW()); + printf("avx512_vbmi: %d\n", is.AVX512_VBMI()); + printf("avx512_vnni: %d\n", is.AVX512_VNNI()); + printf("avx512_fp16: %d\n", is.AVX512_FP16()); + printf("avx512_bf16: %d\n", is.AVX512_BF16()); + printf("amx_tile: %d\n", is.AMX_TILE()); + printf("amx_int8: %d\n", is.AMX_INT8()); + printf("amx_fp16: %d\n", is.AMX_FP16()); + printf("amx_bf16: %d\n", is.AMX_BF16()); +} +#endif + +static int ggml_backend_cpu_x86_score() { + // FIXME: this does not check for OS support + + int score = 0; + cpuid_x86 is; + +#ifdef GGML_FMA + if (!is.FMA()) { return 0; } + score += 1; +#endif +#ifdef GGML_F16C + if (!is.F16C()) { return 0; } + score += 1<<1; +#endif +#ifdef GGML_SSE42 + if (!is.SSE42()) { return 0; } + score += 1<<2; +#endif +#ifdef GGML_AVX + if (!is.AVX()) { return 0; } + score += 1<<4; +#endif +#ifdef GGML_AVX2 + if (!is.AVX2()) { return 0; } + score += 1<<5; +#endif +#ifdef GGML_AVX_VNNI + if (!is.AVX_VNNI()) { return 0; } + score += 1<<6; +#endif +#ifdef GGML_AVX512 + if (!is.AVX512F()) { return 0; } + if (!is.AVX512CD()) { return 0; } + if (!is.AVX512VL()) { return 0; } + if (!is.AVX512DQ()) { return 0; } + if (!is.AVX512BW()) { return 0; } + score += 1<<7; +#endif +#ifdef GGML_AVX512_VBMI + if (!is.AVX512_VBMI()) { return 0; } + score += 1<<8; +#endif +#ifdef GGML_AVX512_BF16 + if (!is.AVX512_BF16()) { return 0; } + score += 1<<9; +#endif +#ifdef GGML_AVX512_VNNI + if (!is.AVX512_VNNI()) { return 0; } + score += 1<<10; +#endif +#ifdef GGML_AMX_INT8 + if (!is.AMX_INT8()) { return 0; } + score += 1<<11; +#endif + + return score; +} + +GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_x86_score) + +#endif // defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64)) diff --git a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp new file mode 100644 index 000000000..b311a5b1c --- /dev/null +++ b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp @@ -0,0 +1,4247 @@ +#define GGML_COMMON_IMPL_CPP +#define GGML_COMMON_DECL_CPP +#include "ggml-common.h" +#include "ggml-backend-impl.h" + +#include "ggml-quants.h" +#include "ggml-impl.h" +#include "ggml-cpu.h" +#include "ggml-cpu-impl.h" +#include "ggml-cpu-traits.h" + +#include +#include +#include +#include +#include // for qsort +#include // for GGML_ASSERT + +#include "ggml-cpu-aarch64.h" + +// TODO: move to include file? +template constexpr int QK_0() { + if constexpr (K == 4) { + return QK4_0; + } + if constexpr (K == 8) { + return QK8_0; + } + return -1; +} + +template struct block { + ggml_half d[N]; // deltas for N qK_0 blocks + int8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_0 blocks +}; + +// control size +static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding"); +static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding"); +static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding"); +static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding"); + +using block_q4_0x4 = block<4, 4>; +using block_q4_0x8 = block<4, 8>; +using block_q8_0x4 = block<8, 4>; +using block_q8_0x8 = block<8, 8>; + +struct block_iq4_nlx4 { + ggml_half d[4]; // deltas for 4 iq4_nl blocks + uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks +}; + +static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding"); + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Woverlength-strings" +#elif defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +#define UNUSED GGML_UNUSED + +// Functions to create the interleaved data layout formats + +// interleave 4 block_q4_0s in blocks of blck_size_interleave +// returns an interleaved block_q4_0x4 +// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks +// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave +// +// - in : an array of block_q4_0 pointers +// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of +// blck_size_interleave bytes +// - xor_mask : the mask to convert the nibbles in block_q4_0 quants bytes +// from bias offset form to pure sign form (this saves subtract +// operations durin unpacking) +// +#if defined(__AVX__) +#if defined(__F16C__) +#if defined(__AVX512F__) +#define GGML_F32Cx8x2_LOAD(x, y) _mm512_cvtph_ps(_mm256_set_m128i(_mm_loadu_si128((const __m128i *)(y)), _mm_loadu_si128((const __m128i *)(x)))) +#define GGML_F32Cx16_REPEAT_LOAD(x) _mm512_cvtph_ps(_mm256_set_m128i(x, x)) +#endif +// the _mm256_cvt intrinsics require F16C +#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) +#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask) _mm256_cvtph_ps(_mm_shuffle_epi32(_mm_maskload_epi32((int const*)(x), loadMask), 68)) +#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask)) +#else +#if defined(__AVX512F__) +static inline __m512 __avx512_f32cx8x2_load(ggml_fp16_t *x, ggml_fp16_t *y) { + float tmp[16]; + + for (int i = 0; i < 8; i++) { + tmp[i] = GGML_FP16_TO_FP32(x[i]); + } + + for (int i = 0; i < 8; i++) { + tmp[i + 8] = GGML_FP16_TO_FP32(y[i]); + } + + return _mm512_loadu_ps(tmp); +} +static inline __m512 __avx512_repeat_f32cx16_load(__m128i x) { + float tmp[16]; + uint16_t tmphalf[8]; + _mm_storeu_si128((__m128i*)tmphalf, x); + + for (int i = 0; i < 4; i++) { + tmp[i] = GGML_FP16_TO_FP32(tmphalf[i]); + tmp[i + 4] = GGML_FP16_TO_FP32(tmphalf[i]); + tmp[i + 8] = GGML_FP16_TO_FP32(tmphalf[i]); + tmp[i + 12] = GGML_FP16_TO_FP32(tmphalf[i]); + } + + return _mm512_loadu_ps(tmp); +} +#endif +static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { + float tmp[8]; + + for (int i = 0; i < 8; i++) { + tmp[i] = GGML_FP16_TO_FP32(x[i]); + } + + return _mm256_loadu_ps(tmp); +} +static inline __m256 __avx_repeat_f32cx8_load(ggml_fp16_t *x) { + float tmp[8]; + + for (int i = 0; i < 4; i++) { + tmp[i] = GGML_FP16_TO_FP32(x[i]); + tmp[i + 4] = GGML_FP16_TO_FP32(x[i]); + } + + return _mm256_loadu_ps(tmp); +} +static inline __m256 __avx_rearranged_f32cx8_load(ggml_fp16_t *x, __m128i arrangeMask) { + uint16_t tmphalf[8]; + float tmp[8]; + + _mm_storeu_si128((__m128i*)tmphalf, _mm_shuffle_epi8(_mm_loadu_si128((const __m128i *) x), arrangeMask)); + for (int i = 0; i < 8; i++) { + tmp[i] = GGML_FP16_TO_FP32(tmphalf[i]); + } + + return _mm256_loadu_ps(tmp); +} + +#define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x) +#define GGML_F32Cx8_REPEAT_LOAD(x, loadMask) __avx_repeat_f32cx8_load(x) +#define GGML_F32Cx8_REARRANGE_LOAD(x, arrangeMask) __avx_rearranged_f32cx8_load(x, arrangeMask) +#if defined(__AVX512F__) +#define GGML_F32Cx8x2_LOAD(x, y) __avx512_f32cx8x2_load(x, y) +#define GGML_F32Cx16_REPEAT_LOAD(x) __avx512_repeat_f32cx16_load(x) +#endif +#endif +#endif + + +#if defined(__AVX2__) || defined(__AVX512F__) +#if defined(__AVX512F__) +// add int16_t pairwise and return as 512 bit int vector +static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) { + const __m512i ones = _mm512_set1_epi16(1); + return _mm512_madd_epi16(ones, x); +} + +static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) { +#if defined(__AVX512VNNI__) + const __m512i zero = _mm512_setzero_si512(); + return _mm512_dpbusd_epi32(zero, ax, sy); +#else + // Perform multiplication and create 16-bit values + const __m512i dot = _mm512_maddubs_epi16(ax, sy); + return sum_i16_pairs_int_32x16(dot); +#endif +} + +// multiply int8_t, add results pairwise twice and return as 512 bit int vector +static inline __m512i mul_sum_i8_pairs_int32x16(const __m512i x, const __m512i y) { + const __m512i zero = _mm512_setzero_si512(); + // Get absolute values of x vectors + const __m512i ax = _mm512_abs_epi8(x); + // Sign the values of the y vectors + __mmask64 blt0 = _mm512_movepi8_mask(x); + const __m512i sy = _mm512_mask_sub_epi8(y, blt0, zero, y); + return mul_sum_us8_pairs_int32x16(ax, sy); +} +#endif + +// add int16_t pairwise and return as 256 bit int vector +static inline __m256i sum_i16_pairs_int32x8(const __m256i x) { + const __m256i ones = _mm256_set1_epi16(1); + return _mm256_madd_epi16(ones, x); +} + +static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) { +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) + const __m256i zero = _mm256_setzero_si256(); + return _mm256_dpbusd_epi32(zero, ax, sy); +#elif defined(__AVXVNNI__) + const __m256i zero = _mm256_setzero_si256(); + return _mm256_dpbusd_avx_epi32(zero, ax, sy); +#else + // Perform multiplication and create 16-bit values + const __m256i dot = _mm256_maddubs_epi16(ax, sy); + return sum_i16_pairs_int32x8(dot); +#endif +} + +// Integer variant of the function defined in ggml-quants.c +// multiply int8_t, add results pairwise twice and return as 256 bit int vector +static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y) { +#if __AVXVNNIINT8__ + const __m256i zero = _mm256_setzero_si256(); + return _mm256_dpbssd_epi32(zero, x, y); +#else + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8(x, x); + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8(y, x); + return mul_sum_us8_pairs_int32x8(ax, sy); +#endif +} +#endif + +static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; + +static void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; + +#if defined(__ARM_NEON) + float32x4_t srcv[4][8]; + float id[4]; + + for (int i = 0; i < nb; i++) { + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int row_iter = 0; row_iter < 4; row_iter++) { + for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); + + for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); + for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); + for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < 8; j++) { + float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]); + int32x4_t vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[1][j], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[2][j], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[3][j], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3); + } + } +#else + // scalar + const int blck_size_interleave = 4; + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < QK8_0 * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + + float x0 = srcv[src_id][src_offset] * id[src_id]; + y[i].qs[j] = roundf(x0); + } + } +#endif +} + +static void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; + +#if defined(__ARM_NEON) + float32x4_t srcv[4][8]; + float id[4]; + + for (int i = 0; i < nb; i++) { + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int row_iter = 0; row_iter < 4; row_iter++) { + for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); + + for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); + for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); + for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < 4; j++) { + float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]); + int32x4_t vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[1][2 * j], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[2][2 * j], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[3][2 * j], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3); + } + } +#elif defined(__AVX2__) || defined(__AVX__) + float id[4]; + __m256 srcv[4][4]; + __m256 idvec[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 32 ); + __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 8 ); + __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 16 ); + __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 32 + 24 ); + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Divided by 127.f to mirror results in quantize_row_q8_0 + const float d = maxScalar / 127.f; + id[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; //d ? 1.0f / d : 0.0f; + + // Store the scale for the individual block + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + + // Store the values in blocks of eight values - Aim is to use these later for block interleaving + srcv[row_iter][0] = v0; + srcv[row_iter][1] = v1; + srcv[row_iter][2] = v2; + srcv[row_iter][3] = v3; + idvec[row_iter] = _mm256_set1_ps(id[row_iter]); + } + + // The loop iterates four times - The aim is to get 4 corresponding chunks of eight bytes from the original weight blocks that are interleaved + for (int j = 0; j < 4; j++) { + // Apply the multiplier + __m256 v0 = _mm256_mul_ps(srcv[0][j], idvec[0]); + __m256 v1 = _mm256_mul_ps(srcv[1][j], idvec[1]); + __m256 v2 = _mm256_mul_ps(srcv[2][j], idvec[2]); + __m256 v3 = _mm256_mul_ps(srcv[3][j], idvec[3]); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); + i2 = _mm256_packs_epi32( i2, i3 ); + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); + + // Permute and store the quantized weights in the required order after the pack instruction + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 32 * j + 16), ni4); +#endif + } + } +#else + // scalar + const int blck_size_interleave = 8; + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < QK8_0 * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + + float x0 = srcv[src_id][src_offset] * id[src_id]; + y[i].qs[j] = roundf(x0); + } + } +#endif +} + +static void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) { + assert(nrow == 4); + UNUSED(nrow); + if (blck_size_interleave == 4) { + quantize_q8_0_4x4(x, vy, n_per_row); + } else if (blck_size_interleave == 8) { + quantize_q8_0_4x8(x, vy, n_per_row); + } else { + assert(false); + } +} + +static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx; + + for (int c = 0; c < nc; c += ncols_interleaved) { + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + float32x4_t acc = vdupq_n_f32(0); + for (int b = 0; b < nb; b++) { + int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs); + int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16); + int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32); + int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48); + float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d); + + int8x16_t a0 = vld1q_s8(a_ptr->qs); + int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2); + float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d); + + int32x4_t ret = vdupq_n_s32(0); + + ret = vdotq_laneq_s32(ret, b0 << 4, a0, 0); + ret = vdotq_laneq_s32(ret, b1 << 4, a0, 1); + ret = vdotq_laneq_s32(ret, b2 << 4, a0, 2); + ret = vdotq_laneq_s32(ret, b3 << 4, a0, 3); + + ret = vdotq_laneq_s32(ret, b0 & 0xf0U, a1, 0); + ret = vdotq_laneq_s32(ret, b1 & 0xf0U, a1, 1); + ret = vdotq_laneq_s32(ret, b2 & 0xf0U, a1, 2); + ret = vdotq_laneq_s32(ret, b3 & 0xf0U, a1, 3); + + acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4), + vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd))); + a_ptr++; + b_ptr++; + } + vst1q_f32(s, acc); + s += ncols_interleaved; + } + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +static void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx; + + for (int c = 0; c < nc; c += ncols_interleaved) { + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + float32x4_t acc = vdupq_n_f32(0); + for (int b = 0; b < nb; b++) { + int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs); + int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16); + int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32); + int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48); + float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d); + + int8x16_t a0 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs); + int8x16_t a1 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 1); + int8x16_t a2 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 2); + int8x16_t a3 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 3); + float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d); + + int32x4_t ret0 = vdupq_n_s32(0); + int32x4_t ret1 = vdupq_n_s32(0); + + ret0 = vdotq_s32(ret0, b0 << 4, a0); + ret1 = vdotq_s32(ret1, b1 << 4, a0); + ret0 = vdotq_s32(ret0, b2 << 4, a1); + ret1 = vdotq_s32(ret1, b3 << 4, a1); + + ret0 = vdotq_s32(ret0, b0 & 0xf0U, a2); + ret1 = vdotq_s32(ret1, b1 & 0xf0U, a2); + ret0 = vdotq_s32(ret0, b2 & 0xf0U, a3); + ret1 = vdotq_s32(ret1, b3 & 0xf0U, a3); + + int32x4_t ret = vpaddq_s32(ret0, ret1); + + acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4), + vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd))); + a_ptr++; + b_ptr++; + } + vst1q_f32(s, acc); + s += ncols_interleaved; + } + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) +#if defined(__ARM_FEATURE_SVE) + if (ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + + __asm__ __volatile__( + "ptrue p0.b\n" + "add %x[b_ptr], %x[b_ptr], #0x10\n" + "1:" // Column loop + "add x22, %x[a_ptr], #0x2\n" + "mov z31.b, #0x0\n" + "mov x21, %x[nb]\n" + "2:" // Block loop + "ld1b { z30.b }, p0/Z, [%x[b_ptr]]\n" + "ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\n" + "mov z28.s, #0x0\n" + "mov z27.s, #0x0\n" + "ld1rd { z26.d }, p0/Z, [x22]\n" + "ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\n" + "sub x20, x22, #0x2\n" + "sub x21, x21, #0x1\n" + "ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\n" + "ld1rd { z23.d }, p0/Z, [x22, #8]\n" + "lsl z22.b, z30.b, #0x4\n" + "lsl z16.b, z29.b, #0x4\n" + "and z30.b, z30.b, #0xf0\n" + "and z29.b, z29.b, #0xf0\n" + "ld1rd { z21.d }, p0/Z, [x22, #16]\n" + "ld1rd { z20.d }, p0/Z, [x22, #24]\n" + "lsl z19.b, z25.b, #0x4\n" + "and z25.b, z25.b, #0xf0\n" + "ld1rh { z17.h }, p0/Z, [x20]\n" + "ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\n" + "sdot z28.s, z22.b, z26.b\n" + "sdot z27.s, z16.b, z26.b\n" + "lsl z16.b, z24.b, #0x4\n" + "add x22, x22, #0x22\n" + "and z24.b, z24.b, #0xf0\n" + "add %x[b_ptr], %x[b_ptr], #0x90\n" + "fcvt z17.s, p0/m, z17.h\n" + "fcvt z18.s, p0/m, z18.h\n" + "sdot z28.s, z19.b, z23.b\n" + "sdot z27.s, z16.b, z23.b\n" + "fmul z18.s, z18.s, z17.s\n" + "sdot z28.s, z30.b, z21.b\n" + "sdot z27.s, z29.b, z21.b\n" + "sdot z28.s, z25.b, z20.b\n" + "sdot z27.s, z24.b, z20.b\n" + "uzp1 z17.s, z28.s, z27.s\n" + "uzp2 z16.s, z28.s, z27.s\n" + "add z17.s, z17.s, z16.s\n" + "asr z17.s, z17.s, #0x4\n" + "scvtf z17.s, p0/m, z17.s\n" + "fmla z31.s, p0/M, z17.s, z18.s\n" + "cbnz x21, 2b\n" + "sub %x[nc], %x[nc], #0x8\n" + "st1w { z31.s }, p0, [%x[res_ptr]]\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "cbnz %x[nc], 1b\n" + : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) + : [a_ptr] "r" (a_ptr), [nb] "r" (nb) + : "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); + return; + } +#endif // #if defined(__ARM_FEATURE_SVE) +#elif defined(__AVX2__) + // Lookup table to convert signed nibbles to signed bytes + __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0)); + signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); + __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0); + __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); + + // Permute mask used for easier vector processing at later stages + const __m256i m4b = _mm256_set1_epi8(0x0F); + + int64_t b_nb = n / QK4_0; + + const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx; + const block_q8_0 * a_ptr_start = (const block_q8_0 *)vy; + + // Process Q8_0 blocks one by one + for (int64_t y = 0; y < nr; y++) { + + // Pointers to LHS blocks of block_q8_0 format + const block_q8_0 * a_ptr = a_ptr_start + (y * nb); + + // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < nc / 8; x++) { + + // Pointers to RHS blocks + const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); + + // Master FP accumulator + __m256 acc_row = _mm256_setzero_ps(); + + for (int64_t b = 0; b < nb; b++) { + // Load 8 blocks of Q4_0 interleaved as 8 bytes (B0 - B7) + const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); + const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 1); + const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 2); + const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs) + 3); + + // 4-bit -> 8-bit - Sign is maintained + const __m256i rhs_vec_0123_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_0, m4b)); // B0(0-7) B1(0-7) B2(0-7) B3(0-7) + const __m256i rhs_vec_4567_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_0, m4b)); // B4(0-7) B5(0-7) B6(0-7) B7(0-7) + const __m256i rhs_vec_0123_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_0123_1, m4b)); // B0(8-15) B1(8-15) B2(8-15) B3(8-15) + const __m256i rhs_vec_4567_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_vec_4567_1, m4b)); // B0(8-15) B1(8-15) B2(8-15) B3(8-15) + + const __m256i rhs_vec_0123_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b)); // B0(16-23) B1(16-23) B2(16-23) B3(16-23) + const __m256i rhs_vec_4567_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b)); // B4(16-23) B5(16-23) B6(16-23) B7(16-23) + const __m256i rhs_vec_0123_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b)); // B0(24-31) B1(24-31) B2(24-31) B3(24-31) + const __m256i rhs_vec_4567_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b)); // B4(24-31) B5(24-31) B6(24-31) B7(24-31) + + // Load the scale values for the 8 blocks interleaved in block_q4_0x8 + const __m256 col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask); + + // Load and convert to FP32 scale from block_q8_0 + const __m256 row_scale_f32 = _mm256_set1_ps(GGML_FP16_TO_FP32(a_ptr[b].d)); + + // Load the block values in block_q8_0 in batches of 16 bytes and replicate the same across 256 bit vector + __m256i lhs_vec_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)a_ptr[b].qs)); + __m256i lhs_vec_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16))); + + lhs_vec_0 = _mm256_permute2f128_si256(lhs_vec_0, lhs_vec_0, 0); // A0 (0-15) A0(0-15) + lhs_vec_1 = _mm256_permute2f128_si256(lhs_vec_1, lhs_vec_1, 0); // A0 (16-31) A0(16-31)) + + __m256i iacc = _mm256_setzero_si256(); + + // Dot product done within 32 bit lanes and accumulated in the same vector + // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3) + // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7) + // ........................................................................... + // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31) + + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0))); + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85))); + + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170))); + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255))); + + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0))); + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85))); + + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170))); + iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255))); + + // Accumulated values multipled with appropriate scales + acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row); + } + + // Accumulated output values permuted so as to be stored in appropriate order post accumulation + acc_row = _mm256_permutevar8x32_ps(acc_row, finalpermutemask); + _mm256_storeu_ps(s + (y * nr + x * 8), acc_row); + } + } + return; +#elif defined(__riscv_v_intrinsic) + if (__riscv_vlenb() >= QK4_0) { + const size_t vl = QK4_0; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + + vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + for (int l = 0; l < nb; l++) { + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, vl / 4)); + + const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); + const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); + const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); + const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0); + const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1); + const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0); + const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1); + + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_hi_m)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + // vector version needs Zvfhmin extension + const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d); + const float b_scales[8] = { + GGML_FP16_TO_FP32(b_ptr[l].d[0]), + GGML_FP16_TO_FP32(b_ptr[l].d[1]), + GGML_FP16_TO_FP32(b_ptr[l].d[2]), + GGML_FP16_TO_FP32(b_ptr[l].d[3]), + GGML_FP16_TO_FP32(b_ptr[l].d[4]), + GGML_FP16_TO_FP32(b_ptr[l].d[5]), + GGML_FP16_TO_FP32(b_ptr[l].d[6]), + GGML_FP16_TO_FP32(b_ptr[l].d[7]) + }; + const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4); + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4); + sumf = __riscv_vfmacc_vv_f32m1(sumf, tmp1, b_scales_vec, vl / 4); + } + __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4); + } + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) + { + float sumf[8]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + +static void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl); + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + float * res_ptr = s; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + + float32x4_t sumf = vdupq_n_f32(0); + for (int l = 0; l < nb; l++) { + uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0); + uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16); + uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32); + uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48); + + int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4); + int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F); + int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4); + int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F); + int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4); + int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F); + int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4); + int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F); + + int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0); + int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16); + + int32x4_t sumi = vdupq_n_s32(0); + sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0); + sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0); + sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1); + sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1); + sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2); + sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2); + sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3); + sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3); + + float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d)); + float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d)); + float32x4_t d = a_d * b_d; + + sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi)); + } + + vst1q_f32(res_ptr + x * 4, sumf); + } + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + { + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + +static void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x10, %x[nr]\n" + "mov x9, #0x88\n" + "cmp x10, #0x10\n" + "mul x9, %x[nb], x9\n" + "blt 4f\n" + "1:" // Row loop + "add x28, %x[b_ptr], #0x8\n" + "mov x27, %x[nc]\n" + "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x25, %x[a_ptr], #0x8\n" + "movi v15.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "mov x24, %x[nb]\n" + "add x23, x25, x9\n" + "movi v18.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "add x22, x23, x9\n" + "movi v11.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "add x21, x22, x9\n" + "movi v23.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "movi v0.16b, #0x0\n" + "movi v4.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v8.16b, #0x0\n" + "movi v1.16b, #0x0\n" + "3:" // Block loop + "ldr q3, [x28, #0x0]\n" + "ldr q31, [x25, #0x0]\n" + "movi v28.16b, #0x4\n" + "movi v10.4s, #0x0\n" + "ldr q22, [x28, #0x10]\n" + "ldr q6, [x25, #0x10]\n" + "movi v29.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "ldr q27, [x28, #0x20]\n" + "ldr q30, [x28, #0x30]\n" + "movi v20.4s, #0x0\n" + "movi v24.16b, #0xf0\n" + "ldr d2, [x25, #-0x8]\n" + "ldr d26, [x23, #-0x8]\n" + "sshl v12.16b, v3.16b, v28.16b\n" + "sub x20, x28, #0x8\n" + "ldr d17, [x20, #0x0]\n" + "and v3.16b, v3.16b, v24.16b\n" + "subs x24, x24, #0x1\n" + "add x28, x28, #0x48\n" + ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n" + ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n" + ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n" + ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n" + "sshl v31.16b, v22.16b, v28.16b\n" + "and v22.16b, v22.16b, v24.16b\n" + "fcvtl v17.4s, v17.4h\n" + "fcvtl v2.4s, v2.4h\n" + "fcvtl v26.4s, v26.4h\n" + ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n" + ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n" + ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n" + ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n" + "sshl v6.16b, v27.16b, v28.16b\n" + "sshl v28.16b, v30.16b, v28.16b\n" + "and v27.16b, v27.16b, v24.16b\n" + "and v30.16b, v30.16b, v24.16b\n" + "ldr q24, [x25, #0x20]\n" + ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x30]\n" + ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n" + ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n" + ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n" + ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x40]\n" + ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x50]\n" + ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n" + ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n" + ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n" + ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x60]\n" + ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x70]\n" + "add x25, x25, #0x88\n" + ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n" + ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n" + ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n" + ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n" + "fmul v24.4s, v17.4s, v2.s[0]\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v15.4s, v10.4s, v24.4s\n" + "ldr q24, [x23, #0x0]\n" + "fmul v10.4s, v17.4s, v2.s[1]\n" + "fmla v19.4s, v29.4s, v10.4s\n" + "ldr q10, [x23, #0x10]\n" + "fmul v29.4s, v17.4s, v2.s[2]\n" + "fmul v2.4s, v17.4s, v2.s[3]\n" + "fmla v18.4s, v9.4s, v29.4s\n" + "movi v9.4s, #0x0\n" + "movi v29.4s, #0x0\n" + ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n" + "fmla v14.4s, v20.4s, v2.4s\n" + "movi v20.4s, #0x0\n" + "movi v2.4s, #0x0\n" + ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x20]\n" + ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n" + ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n" + ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n" + ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x30]\n" + ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x40]\n" + ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n" + ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n" + ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n" + ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x50]\n" + ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x60]\n" + ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n" + ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n" + ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n" + ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x70]\n" + "add x23, x23, #0x88\n" + ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x0]\n" + ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n" + ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n" + ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n" + ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n" + "fmul v10.4s, v17.4s, v26.s[0]\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v11.4s, v9.4s, v10.4s\n" + "ldr q9, [x22, #0x10]\n" + "fmul v10.4s, v17.4s, v26.s[1]\n" + "fmla v13.4s, v29.4s, v10.4s\n" + "ldr d29, [x22, #-0x8]\n" + "fmul v10.4s, v17.4s, v26.s[2]\n" + "fmul v26.4s, v17.4s, v26.s[3]\n" + "fcvtl v29.4s, v29.4h\n" + "fmla v23.4s, v20.4s, v10.4s\n" + "movi v20.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "fmla v16.4s, v2.4s, v26.4s\n" + "movi v26.4s, #0x0\n" + "movi v2.4s, #0x0\n" + ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" + ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x20]\n" + ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" + ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n" + ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x30]\n" + ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x40]\n" + ".inst 0x4f89e394 // sdot v20.4s, v28.16b, v9.4b[0]\n" + ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" + ".inst 0x4f89eb9a // sdot v26.4s, v28.16b, v9.4b[2]\n" + ".inst 0x4fa9eb82 // sdot v2.4s, v28.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x50]\n" + ".inst 0x4f98e074 // sdot v20.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e06a // sdot v10.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e87a // sdot v26.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x60]\n" + ".inst 0x4f89e2d4 // sdot v20.4s, v22.16b, v9.4b[0]\n" + ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" + ".inst 0x4f89eada // sdot v26.4s, v22.16b, v9.4b[2]\n" + ".inst 0x4fa9eac2 // sdot v2.4s, v22.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x70]\n" + "add x22, x22, #0x88\n" + ".inst 0x4f98e374 // sdot v20.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e36a // sdot v10.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb7a // sdot v26.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x21, #0x0]\n" + ".inst 0x4f89e3d4 // sdot v20.4s, v30.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ca // sdot v10.4s, v30.16b, v9.4b[1]\n" + ".inst 0x4f89ebda // sdot v26.4s, v30.16b, v9.4b[2]\n" + ".inst 0x4fa9ebc2 // sdot v2.4s, v30.16b, v9.4b[3]\n" + "fmul v9.4s, v17.4s, v29.s[0]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v25.4s, v20.4s, v9.4s\n" + "ldr q9, [x21, #0x10]\n" + "fmul v20.4s, v17.4s, v29.s[1]\n" + "fmla v7.4s, v10.4s, v20.4s\n" + "ldr d20, [x21, #-0x8]\n" + "fmul v10.4s, v17.4s, v29.s[2]\n" + "fmul v29.4s, v17.4s, v29.s[3]\n" + "fcvtl v20.4s, v20.4h\n" + "fmla v0.4s, v26.4s, v10.4s\n" + "movi v26.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "fmla v4.4s, v2.4s, v29.4s\n" + "movi v2.4s, #0x0\n" + "movi v29.4s, #0x0\n" + ".inst 0x4f98e19a // sdot v26.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" + ".inst 0x4f98e982 // sdot v2.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e99d // sdot v29.4s, v12.16b, v24.4b[3]\n" + "ldr q12, [x21, #0x20]\n" + "fmul v24.4s, v17.4s, v20.s[0]\n" + ".inst 0x4f89e3fa // sdot v26.4s, v31.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" + ".inst 0x4f89ebe2 // sdot v2.4s, v31.16b, v9.4b[2]\n" + ".inst 0x4fa9ebfd // sdot v29.4s, v31.16b, v9.4b[3]\n" + "ldr q9, [x21, #0x30]\n" + "fmul v31.4s, v17.4s, v20.s[1]\n" + ".inst 0x4f8ce0da // sdot v26.4s, v6.16b, v12.4b[0]\n" + ".inst 0x4face0ca // sdot v10.4s, v6.16b, v12.4b[1]\n" + ".inst 0x4f8ce8c2 // sdot v2.4s, v6.16b, v12.4b[2]\n" + ".inst 0x4face8dd // sdot v29.4s, v6.16b, v12.4b[3]\n" + "ldr q12, [x21, #0x40]\n" + "fmul v6.4s, v17.4s, v20.s[2]\n" + "fmul v20.4s, v17.4s, v20.s[3]\n" + ".inst 0x4f89e39a // sdot v26.4s, v28.16b, v9.4b[0]\n" + ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" + ".inst 0x4f89eb82 // sdot v2.4s, v28.16b, v9.4b[2]\n" + ".inst 0x4fa9eb9d // sdot v29.4s, v28.16b, v9.4b[3]\n" + "ldr q9, [x21, #0x50]\n" + ".inst 0x4f8ce07a // sdot v26.4s, v3.16b, v12.4b[0]\n" + ".inst 0x4face06a // sdot v10.4s, v3.16b, v12.4b[1]\n" + ".inst 0x4f8ce862 // sdot v2.4s, v3.16b, v12.4b[2]\n" + ".inst 0x4face87d // sdot v29.4s, v3.16b, v12.4b[3]\n" + "ldr q12, [x21, #0x60]\n" + ".inst 0x4f89e2da // sdot v26.4s, v22.16b, v9.4b[0]\n" + ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" + ".inst 0x4f89eac2 // sdot v2.4s, v22.16b, v9.4b[2]\n" + ".inst 0x4fa9eadd // sdot v29.4s, v22.16b, v9.4b[3]\n" + "ldr q17, [x21, #0x70]\n" + "add x21, x21, #0x88\n" + ".inst 0x4f8ce37a // sdot v26.4s, v27.16b, v12.4b[0]\n" + ".inst 0x4face36a // sdot v10.4s, v27.16b, v12.4b[1]\n" + ".inst 0x4f8ceb62 // sdot v2.4s, v27.16b, v12.4b[2]\n" + ".inst 0x4faceb7d // sdot v29.4s, v27.16b, v12.4b[3]\n" + ".inst 0x4f91e3da // sdot v26.4s, v30.16b, v17.4b[0]\n" + ".inst 0x4fb1e3ca // sdot v10.4s, v30.16b, v17.4b[1]\n" + ".inst 0x4f91ebc2 // sdot v2.4s, v30.16b, v17.4b[2]\n" + ".inst 0x4fb1ebdd // sdot v29.4s, v30.16b, v17.4b[3]\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "fmla v5.4s, v26.4s, v24.4s\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "fmla v21.4s, v10.4s, v31.4s\n" + "fmla v8.4s, v2.4s, v6.4s\n" + "fmla v1.4s, v29.4s, v20.4s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x27, x27, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "str q15, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q18, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q14, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q11, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q13, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q16, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q7, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q0, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q4, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q5, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q21, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q8, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q1, [x20, #0x0]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x10, x10, #0x10\n" + "cmp x10, #0x10\n" + "mov %x[res_ptr], x26\n" + "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x10, 9f\n" + "5:" // Row tail: Row loop + "add x24, %x[b_ptr], #0x8\n" + "mov x23, %x[nc]\n" + "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "movi v15.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "add x25, %x[a_ptr], #0x8\n" + "mov x21, %x[nb]\n" + "movi v18.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "7:" // Row tail: Block loop + "ldr q7, [x24, #0x0]\n" + "ldr q5, [x25, #0x0]\n" + "movi v9.16b, #0x4\n" + "movi v4.4s, #0x0\n" + "ldr q3, [x24, #0x10]\n" + "ldr q2, [x25, #0x10]\n" + "movi v1.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "ldr q13, [x24, #0x20]\n" + "ldr q31, [x25, #0x20]\n" + "movi v30.4s, #0x0\n" + "movi v29.16b, #0xf0\n" + "ldr q28, [x24, #0x30]\n" + "ldr q27, [x25, #0x30]\n" + "sshl v20.16b, v7.16b, v9.16b\n" + "sub x20, x24, #0x8\n" + "ldr q26, [x25, #0x40]\n" + "ldr q25, [x25, #0x50]\n" + "sshl v17.16b, v3.16b, v9.16b\n" + "and v7.16b, v7.16b, v29.16b\n" + "ldr q24, [x25, #0x60]\n" + "ldr q16, [x25, #0x70]\n" + "sshl v22.16b, v13.16b, v9.16b\n" + "and v3.16b, v3.16b, v29.16b\n" + "ldr d21, [x20, #0x0]\n" + "ldr d12, [x25, #-0x8]\n" + ".inst 0x4f85e284 // sdot v4.4s, v20.16b, v5.4b[0]\n" + ".inst 0x4fa5e281 // sdot v1.4s, v20.16b, v5.4b[1]\n" + ".inst 0x4f85ea80 // sdot v0.4s, v20.16b, v5.4b[2]\n" + ".inst 0x4fa5ea9e // sdot v30.4s, v20.16b, v5.4b[3]\n" + "sshl v9.16b, v28.16b, v9.16b\n" + "subs x21, x21, #0x1\n" + "and v13.16b, v13.16b, v29.16b\n" + "and v28.16b, v28.16b, v29.16b\n" + "add x25, x25, #0x88\n" + "add x24, x24, #0x48\n" + "fcvtl v21.4s, v21.4h\n" + "fcvtl v12.4s, v12.4h\n" + ".inst 0x4f82e224 // sdot v4.4s, v17.16b, v2.4b[0]\n" + ".inst 0x4fa2e221 // sdot v1.4s, v17.16b, v2.4b[1]\n" + ".inst 0x4f82ea20 // sdot v0.4s, v17.16b, v2.4b[2]\n" + ".inst 0x4fa2ea3e // sdot v30.4s, v17.16b, v2.4b[3]\n" + "fmul v11.4s, v21.4s, v12.s[0]\n" + "fmul v23.4s, v21.4s, v12.s[1]\n" + "fmul v17.4s, v21.4s, v12.s[2]\n" + ".inst 0x4f9fe2c4 // sdot v4.4s, v22.16b, v31.4b[0]\n" + "fmul v6.4s, v21.4s, v12.s[3]\n" + ".inst 0x4fbfe2c1 // sdot v1.4s, v22.16b, v31.4b[1]\n" + ".inst 0x4f9feac0 // sdot v0.4s, v22.16b, v31.4b[2]\n" + ".inst 0x4fbfeade // sdot v30.4s, v22.16b, v31.4b[3]\n" + ".inst 0x4f9be124 // sdot v4.4s, v9.16b, v27.4b[0]\n" + ".inst 0x4fbbe121 // sdot v1.4s, v9.16b, v27.4b[1]\n" + ".inst 0x4f9be920 // sdot v0.4s, v9.16b, v27.4b[2]\n" + ".inst 0x4fbbe93e // sdot v30.4s, v9.16b, v27.4b[3]\n" + ".inst 0x4f9ae0e4 // sdot v4.4s, v7.16b, v26.4b[0]\n" + ".inst 0x4fbae0e1 // sdot v1.4s, v7.16b, v26.4b[1]\n" + ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n" + ".inst 0x4fbae8fe // sdot v30.4s, v7.16b, v26.4b[3]\n" + ".inst 0x4f99e064 // sdot v4.4s, v3.16b, v25.4b[0]\n" + ".inst 0x4fb9e061 // sdot v1.4s, v3.16b, v25.4b[1]\n" + ".inst 0x4f99e860 // sdot v0.4s, v3.16b, v25.4b[2]\n" + ".inst 0x4fb9e87e // sdot v30.4s, v3.16b, v25.4b[3]\n" + ".inst 0x4f98e1a4 // sdot v4.4s, v13.16b, v24.4b[0]\n" + ".inst 0x4fb8e1a1 // sdot v1.4s, v13.16b, v24.4b[1]\n" + ".inst 0x4f98e9a0 // sdot v0.4s, v13.16b, v24.4b[2]\n" + ".inst 0x4fb8e9be // sdot v30.4s, v13.16b, v24.4b[3]\n" + ".inst 0x4f90e384 // sdot v4.4s, v28.16b, v16.4b[0]\n" + ".inst 0x4fb0e381 // sdot v1.4s, v28.16b, v16.4b[1]\n" + ".inst 0x4f90eb80 // sdot v0.4s, v28.16b, v16.4b[2]\n" + ".inst 0x4fb0eb9e // sdot v30.4s, v28.16b, v16.4b[3]\n" + "scvtf v4.4s, v4.4s, #0x4\n" + "scvtf v1.4s, v1.4s, #0x4\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "fmla v15.4s, v4.4s, v11.4s\n" + "scvtf v30.4s, v30.4s, #0x4\n" + "fmla v19.4s, v1.4s, v23.4s\n" + "fmla v18.4s, v0.4s, v17.4s\n" + "fmla v14.4s, v30.4s, v6.4s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x10, #0x1\n" + "str q15, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x2\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x3\n" + "str q18, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "str q14, [x20, #0x0]\n" + "8:" // Row tail: Accumulator store skip + "subs x23, x23, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "bne 6b\n" + "subs x10, x10, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x9\n" + "mov %x[res_ptr], x22\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + { + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + +static void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x10, %x[nr]\n" + "mov x9, #0x88\n" + "cmp x10, #0x10\n" + "mul x9, %x[nb], x9\n" + "blt 4f\n" + "1:" // Row loop + "add x28, %x[b_ptr], #0x8\n" + "mov x27, %x[nc]\n" + "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x25, %x[a_ptr], #0x8\n" + "movi v2.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "mov x24, %x[nb]\n" + "add x23, x25, x9\n" + "movi v12.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "add x22, x23, x9\n" + "movi v11.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "add x21, x22, x9\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "movi v4.16b, #0x0\n" + "movi v6.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "3:" // Block loop + "ldr q21, [x28, #0x0]\n" + "ldr q16, [x28, #0x10]\n" + "movi v1.16b, #0x4\n" + "movi v19.4s, #0x0\n" + "ldr q27, [x25, #0x0]\n" + "ldr q15, [x25, #0x10]\n" + "movi v26.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "ldr q29, [x28, #0x20]\n" + "ldr q3, [x28, #0x30]\n" + "movi v17.4s, #0x0\n" + "movi v0.16b, #0xf0\n" + "ldr d20, [x25, #-0x8]\n" + "ldr d9, [x23, #-0x8]\n" + "sshl v8.16b, v21.16b, v1.16b\n" + "sshl v31.16b, v16.16b, v1.16b\n" + "and v21.16b, v21.16b, v0.16b\n" + "and v16.16b, v16.16b, v0.16b\n" + "sub x20, x28, #0x8\n" + "subs x24, x24, #0x1\n" + "add x28, x28, #0x48\n" + ".inst 0x4e88a773 // smmla v19.4s, v27.16b, v8.16b\n" + ".inst 0x4e9fa77a // smmla v26.4s, v27.16b, v31.16b\n" + "ldr q27, [x25, #0x20]\n" + ".inst 0x4e88a5f2 // smmla v18.4s, v15.16b, v8.16b\n" + ".inst 0x4e9fa5f1 // smmla v17.4s, v15.16b, v31.16b\n" + "sshl v15.16b, v29.16b, v1.16b\n" + "sshl v1.16b, v3.16b, v1.16b\n" + "and v29.16b, v29.16b, v0.16b\n" + "and v3.16b, v3.16b, v0.16b\n" + "ldr q0, [x25, #0x30]\n" + "fcvtl v20.4s, v20.4h\n" + ".inst 0x4e8fa773 // smmla v19.4s, v27.16b, v15.16b\n" + "fcvtl v9.4s, v9.4h\n" + ".inst 0x4e81a77a // smmla v26.4s, v27.16b, v1.16b\n" + "ldr q27, [x25, #0x40]\n" + ".inst 0x4e8fa412 // smmla v18.4s, v0.16b, v15.16b\n" + ".inst 0x4e81a411 // smmla v17.4s, v0.16b, v1.16b\n" + "ldr q0, [x25, #0x50]\n" + ".inst 0x4e95a773 // smmla v19.4s, v27.16b, v21.16b\n" + ".inst 0x4e90a77a // smmla v26.4s, v27.16b, v16.16b\n" + "ldr q27, [x25, #0x60]\n" + ".inst 0x4e95a412 // smmla v18.4s, v0.16b, v21.16b\n" + ".inst 0x4e90a411 // smmla v17.4s, v0.16b, v16.16b\n" + "ldr q0, [x25, #0x70]\n" + "add x25, x25, #0x88\n" + ".inst 0x4e9da773 // smmla v19.4s, v27.16b, v29.16b\n" + ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n" + "ldr d27, [x20, #0x0]\n" + ".inst 0x4e9da412 // smmla v18.4s, v0.16b, v29.16b\n" + ".inst 0x4e83a411 // smmla v17.4s, v0.16b, v3.16b\n" + "fcvtl v27.4s, v27.4h\n" + "uzp1 v0.2d, v19.2d, v26.2d\n" + "uzp2 v26.2d, v19.2d, v26.2d\n" + "fmul v19.4s, v27.4s, v20.s[0]\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "fmla v2.4s, v0.4s, v19.4s\n" + "ldr q19, [x23, #0x0]\n" + "uzp1 v0.2d, v18.2d, v17.2d\n" + "uzp2 v18.2d, v18.2d, v17.2d\n" + "fmul v17.4s, v27.4s, v20.s[1]\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "fmla v10.4s, v26.4s, v17.4s\n" + "ldr q17, [x23, #0x10]\n" + "fmul v26.4s, v27.4s, v20.s[2]\n" + "fmul v20.4s, v27.4s, v20.s[3]\n" + "fmla v12.4s, v0.4s, v26.4s\n" + "ldr d0, [x22, #-0x8]\n" + "ldr d26, [x21, #-0x8]\n" + "fcvtl v0.4s, v0.4h\n" + "fmla v28.4s, v18.4s, v20.4s\n" + "movi v20.4s, #0x0\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" + ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" + "ldr q19, [x23, #0x20]\n" + "fcvtl v26.4s, v26.4h\n" + ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" + ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" + "ldr q19, [x23, #0x40]\n" + ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" + ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" + "ldr q19, [x23, #0x60]\n" + ".inst 0x4e9da674 // smmla v20.4s, v19.16b, v29.16b\n" + ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n" + "uzp1 v19.2d, v20.2d, v18.2d\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "uzp2 v20.2d, v20.2d, v18.2d\n" + "fmul v18.4s, v27.4s, v9.s[0]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v11.4s, v19.4s, v18.4s\n" + "ldr q18, [x22, #0x0]\n" + "fmul v19.4s, v27.4s, v9.s[1]\n" + "fmla v13.4s, v20.4s, v19.4s\n" + "movi v19.4s, #0x0\n" + "movi v20.4s, #0x0\n" + ".inst 0x4e88a633 // smmla v19.4s, v17.16b, v8.16b\n" + ".inst 0x4e9fa634 // smmla v20.4s, v17.16b, v31.16b\n" + "ldr q17, [x23, #0x30]\n" + ".inst 0x4e8fa633 // smmla v19.4s, v17.16b, v15.16b\n" + ".inst 0x4e81a634 // smmla v20.4s, v17.16b, v1.16b\n" + "ldr q17, [x23, #0x50]\n" + ".inst 0x4e95a633 // smmla v19.4s, v17.16b, v21.16b\n" + ".inst 0x4e90a634 // smmla v20.4s, v17.16b, v16.16b\n" + "ldr q17, [x23, #0x70]\n" + "add x23, x23, #0x88\n" + ".inst 0x4e9da633 // smmla v19.4s, v17.16b, v29.16b\n" + ".inst 0x4e83a634 // smmla v20.4s, v17.16b, v3.16b\n" + "uzp1 v17.2d, v19.2d, v20.2d\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "uzp2 v20.2d, v19.2d, v20.2d\n" + "fmul v19.4s, v27.4s, v9.s[2]\n" + "fmul v9.4s, v27.4s, v9.s[3]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v22.4s, v17.4s, v19.4s\n" + "ldr q17, [x22, #0x10]\n" + "movi v19.4s, #0x0\n" + ".inst 0x4e88a653 // smmla v19.4s, v18.16b, v8.16b\n" + "fmla v23.4s, v20.4s, v9.4s\n" + "movi v20.4s, #0x0\n" + "movi v9.4s, #0x0\n" + ".inst 0x4e9fa654 // smmla v20.4s, v18.16b, v31.16b\n" + "ldr q18, [x22, #0x20]\n" + ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" + ".inst 0x4e8fa653 // smmla v19.4s, v18.16b, v15.16b\n" + ".inst 0x4e81a654 // smmla v20.4s, v18.16b, v1.16b\n" + "ldr q18, [x22, #0x40]\n" + ".inst 0x4e95a653 // smmla v19.4s, v18.16b, v21.16b\n" + ".inst 0x4e90a654 // smmla v20.4s, v18.16b, v16.16b\n" + "ldr q18, [x22, #0x60]\n" + ".inst 0x4e9da653 // smmla v19.4s, v18.16b, v29.16b\n" + ".inst 0x4e83a654 // smmla v20.4s, v18.16b, v3.16b\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e9fa632 // smmla v18.4s, v17.16b, v31.16b\n" + "ldr q17, [x22, #0x30]\n" + ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" + ".inst 0x4e81a632 // smmla v18.4s, v17.16b, v1.16b\n" + "ldr q17, [x22, #0x50]\n" + ".inst 0x4e95a629 // smmla v9.4s, v17.16b, v21.16b\n" + ".inst 0x4e90a632 // smmla v18.4s, v17.16b, v16.16b\n" + "ldr q17, [x22, #0x70]\n" + "add x22, x22, #0x88\n" + ".inst 0x4e9da629 // smmla v9.4s, v17.16b, v29.16b\n" + ".inst 0x4e83a632 // smmla v18.4s, v17.16b, v3.16b\n" + "uzp1 v17.2d, v19.2d, v20.2d\n" + "uzp2 v20.2d, v19.2d, v20.2d\n" + "fmul v19.4s, v27.4s, v0.s[0]\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v25.4s, v17.4s, v19.4s\n" + "ldr q19, [x21, #0x0]\n" + "fmul v17.4s, v27.4s, v0.s[1]\n" + "fmla v5.4s, v20.4s, v17.4s\n" + "ldr q17, [x21, #0x10]\n" + "uzp1 v20.2d, v9.2d, v18.2d\n" + "uzp2 v9.2d, v9.2d, v18.2d\n" + "fmul v18.4s, v27.4s, v0.s[2]\n" + "fmul v0.4s, v27.4s, v0.s[3]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "fmla v7.4s, v20.4s, v18.4s\n" + "movi v20.4s, #0x0\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" + ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" + "ldr q19, [x21, #0x20]\n" + "fmla v4.4s, v9.4s, v0.4s\n" + "movi v9.4s, #0x0\n" + "movi v0.4s, #0x0\n" + ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" + "fmul v8.4s, v27.4s, v26.s[0]\n" + ".inst 0x4e9fa620 // smmla v0.4s, v17.16b, v31.16b\n" + "ldr q17, [x21, #0x30]\n" + ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" + "fmul v31.4s, v27.4s, v26.s[1]\n" + ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" + "ldr q19, [x21, #0x40]\n" + ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" + "fmul v15.4s, v27.4s, v26.s[2]\n" + "fmul v27.4s, v27.4s, v26.s[3]\n" + ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n" + "ldr q1, [x21, #0x50]\n" + ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" + ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" + "ldr q26, [x21, #0x60]\n" + ".inst 0x4e95a429 // smmla v9.4s, v1.16b, v21.16b\n" + ".inst 0x4e90a420 // smmla v0.4s, v1.16b, v16.16b\n" + "ldr q21, [x21, #0x70]\n" + "add x21, x21, #0x88\n" + ".inst 0x4e9da754 // smmla v20.4s, v26.16b, v29.16b\n" + ".inst 0x4e83a752 // smmla v18.4s, v26.16b, v3.16b\n" + ".inst 0x4e9da6a9 // smmla v9.4s, v21.16b, v29.16b\n" + ".inst 0x4e83a6a0 // smmla v0.4s, v21.16b, v3.16b\n" + "uzp1 v29.2d, v20.2d, v18.2d\n" + "uzp2 v21.2d, v20.2d, v18.2d\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "uzp1 v18.2d, v9.2d, v0.2d\n" + "uzp2 v16.2d, v9.2d, v0.2d\n" + "scvtf v21.4s, v21.4s, #0x4\n" + "fmla v6.4s, v29.4s, v8.4s\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "fmla v30.4s, v21.4s, v31.4s\n" + "fmla v24.4s, v18.4s, v15.4s\n" + "fmla v14.4s, v16.4s, v27.4s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x27, x27, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "str q2, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q12, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q28, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q11, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q13, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q22, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q5, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q7, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q4, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q6, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q30, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q24, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q14, [x20, #0x0]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x10, x10, #0x10\n" + "cmp x10, #0x10\n" + "mov %x[res_ptr], x26\n" + "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x10, 9f\n" + "5:" // Row tail: Row loop + "add x24, %x[b_ptr], #0x8\n" + "mov x23, %x[nc]\n" + "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "movi v2.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "add x25, %x[a_ptr], #0x8\n" + "mov x21, %x[nb]\n" + "movi v12.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "7:" // Row tail: Block loop + "ldr q6, [x24, #0x0]\n" + "ldr q5, [x24, #0x10]\n" + "movi v17.16b, #0x4\n" + "movi v8.4s, #0x0\n" + "ldr q4, [x25, #0x0]\n" + "ldr q13, [x25, #0x10]\n" + "movi v27.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "ldr q31, [x24, #0x20]\n" + "ldr q14, [x24, #0x30]\n" + "movi v29.4s, #0x0\n" + "movi v22.16b, #0xf0\n" + "ldr q11, [x25, #0x20]\n" + "ldr q23, [x25, #0x30]\n" + "sshl v21.16b, v6.16b, v17.16b\n" + "sshl v16.16b, v5.16b, v17.16b\n" + "ldr q20, [x25, #0x40]\n" + "ldr q26, [x25, #0x50]\n" + "and v6.16b, v6.16b, v22.16b\n" + "and v5.16b, v5.16b, v22.16b\n" + "ldr q25, [x25, #0x60]\n" + "ldr q3, [x25, #0x70]\n" + "sshl v19.16b, v31.16b, v17.16b\n" + "sshl v18.16b, v14.16b, v17.16b\n" + "ldr d17, [x25, #-0x8]\n" + ".inst 0x4e95a488 // smmla v8.4s, v4.16b, v21.16b\n" + ".inst 0x4e90a49b // smmla v27.4s, v4.16b, v16.16b\n" + "and v31.16b, v31.16b, v22.16b\n" + ".inst 0x4e95a5a0 // smmla v0.4s, v13.16b, v21.16b\n" + ".inst 0x4e90a5bd // smmla v29.4s, v13.16b, v16.16b\n" + "and v14.16b, v14.16b, v22.16b\n" + "sub x20, x24, #0x8\n" + "ldr d16, [x20, #0x0]\n" + "subs x21, x21, #0x1\n" + "add x25, x25, #0x88\n" + "fcvtl v17.4s, v17.4h\n" + "add x24, x24, #0x48\n" + ".inst 0x4e93a568 // smmla v8.4s, v11.16b, v19.16b\n" + ".inst 0x4e92a57b // smmla v27.4s, v11.16b, v18.16b\n" + ".inst 0x4e93a6e0 // smmla v0.4s, v23.16b, v19.16b\n" + ".inst 0x4e92a6fd // smmla v29.4s, v23.16b, v18.16b\n" + "fcvtl v16.4s, v16.4h\n" + ".inst 0x4e86a688 // smmla v8.4s, v20.16b, v6.16b\n" + ".inst 0x4e85a69b // smmla v27.4s, v20.16b, v5.16b\n" + "fmul v23.4s, v16.4s, v17.s[0]\n" + "fmul v21.4s, v16.4s, v17.s[1]\n" + "fmul v1.4s, v16.4s, v17.s[2]\n" + "fmul v20.4s, v16.4s, v17.s[3]\n" + ".inst 0x4e86a740 // smmla v0.4s, v26.16b, v6.16b\n" + ".inst 0x4e85a75d // smmla v29.4s, v26.16b, v5.16b\n" + ".inst 0x4e9fa728 // smmla v8.4s, v25.16b, v31.16b\n" + ".inst 0x4e8ea73b // smmla v27.4s, v25.16b, v14.16b\n" + ".inst 0x4e9fa460 // smmla v0.4s, v3.16b, v31.16b\n" + ".inst 0x4e8ea47d // smmla v29.4s, v3.16b, v14.16b\n" + "uzp1 v19.2d, v8.2d, v27.2d\n" + "uzp2 v18.2d, v8.2d, v27.2d\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "uzp1 v17.2d, v0.2d, v29.2d\n" + "uzp2 v16.2d, v0.2d, v29.2d\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "fmla v2.4s, v19.4s, v23.4s\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "fmla v10.4s, v18.4s, v21.4s\n" + "fmla v12.4s, v17.4s, v1.4s\n" + "fmla v28.4s, v16.4s, v20.4s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x10, #0x1\n" + "str q2, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x2\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x3\n" + "str q12, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "str q28, [x20, #0x0]\n" + "8:" // Row tail: Accumulator store skip + "subs x23, x23, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "bne 6b\n" + "subs x10, x10, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x9\n" + "mov %x[res_ptr], x22\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x20, #0x4\n" + "mov x13, %x[nr]\n" + "mov z28.s, #-0x4\n" + "mov x12, #0x88\n" + "ptrue p1.b\n" + "whilelt p0.s, XZR, x20\n" + "cmp x13, #0x10\n" + "mul x12, %x[nb], x12\n" + "blt 4f\n" + "1:" // Row loop + "add x11, %x[b_ptr], #0x10\n" + "mov x10, %x[nc]\n" + "add x9, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x28, %x[a_ptr], #0x8\n" + "mov z24.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov x27, %x[nb]\n" + "add x26, x28, x12\n" + "mov z12.b, #0x0\n" + "mov z0.b, #0x0\n" + "add x25, x26, x12\n" + "mov z13.b, #0x0\n" + "mov z1.b, #0x0\n" + "add x24, x25, x12\n" + "mov z20.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z8.b, #0x0\n" + "mov z29.b, #0x0\n" + "mov z27.b, #0x0\n" + "mov z10.b, #0x0\n" + "3:" // Block loop + "ld1b { z30.b }, p1/Z, [x11]\n" + "ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\n" + "mov z18.s, #0x0\n" + "mov z7.s, #0x0\n" + "ld1rqb { z3.b }, p1/Z, [x28]\n" + "ld1rqb { z5.b }, p1/Z, [x28, #16]\n" + "mov z9.s, #0x0\n" + "mov z22.s, #0x0\n" + "ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\n" + "ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\n" + "sub x20, x11, #0x10\n" + "sub x23, x28, #0x8\n" + "lsl z31.b, z30.b, #0x4\n" + "lsl z6.b, z21.b, #0x4\n" + "ld1h { z23.s }, p1/Z, [x20]\n" + "sub x22, x26, #0x8\n" + "and z30.b, z30.b, #0xf0\n" + "and z21.b, z21.b, #0xf0\n" + "sub x21, x25, #0x8\n" + "sub x20, x24, #0x8\n" + "lsl z14.b, z4.b, #0x4\n" + "lsl z2.b, z17.b, #0x4\n" + "subs x27, x27, #0x1\n" + "add x11, x11, #0x90\n" + ".inst 0x451f9872 // smmla z18.s, z3.b, z31.b\n" + ".inst 0x45069867 // smmla z7.s, z3.b, z6.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #32]\n" + "and z4.b, z4.b, #0xf0\n" + ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" + ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #48]\n" + "and z17.b, z17.b, #0xf0\n" + "fcvt z23.s, p1/m, z23.h\n" + ".inst 0x450e9872 // smmla z18.s, z3.b, z14.b\n" + ".inst 0x45029867 // smmla z7.s, z3.b, z2.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #64]\n" + ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" + ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #80]\n" + "fscale z23.s, p1/m, z23.s, z28.s\n" + ".inst 0x451e9872 // smmla z18.s, z3.b, z30.b\n" + ".inst 0x45159867 // smmla z7.s, z3.b, z21.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #96]\n" + ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" + ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #112]\n" + "add x28, x28, #0x88\n" + ".inst 0x45049872 // smmla z18.s, z3.b, z4.b\n" + ".inst 0x45119867 // smmla z7.s, z3.b, z17.b\n" + "ld1h { z3.s }, p0/Z, [x23]\n" + ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" + ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" + "fcvt z3.s, p1/m, z3.h\n" + "uzp1 z5.d, z18.d, z7.d\n" + "uzp2 z18.d, z18.d, z7.d\n" + "mov z3.q, z3.q[0]\n" + "uzp1 z7.d, z9.d, z22.d\n" + "uzp2 z22.d, z9.d, z22.d\n" + "fmul z9.s, z23.s, z3.s[0]\n" + "scvtf z5.s, p1/m, z5.s\n" + "scvtf z18.s, p1/m, z18.s\n" + "scvtf z7.s, p1/m, z7.s\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z24.s, p1/M, z5.s, z9.s\n" + "ld1rqb { z5.b }, p1/Z, [x26]\n" + "fmul z9.s, z23.s, z3.s[1]\n" + "fmla z15.s, p1/M, z18.s, z9.s\n" + "ld1rqb { z18.b }, p1/Z, [x26, #16]\n" + "fmul z9.s, z23.s, z3.s[2]\n" + "fmul z3.s, z23.s, z3.s[3]\n" + "fmla z12.s, p1/M, z7.s, z9.s\n" + "mov z9.s, #0x0\n" + "ld1h { z7.s }, p0/Z, [x22]\n" + ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" + "fmla z0.s, p1/M, z22.s, z3.s\n" + "mov z22.s, #0x0\n" + "ld1h { z3.s }, p0/Z, [x21]\n" + ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #32]\n" + "fcvt z7.s, p1/m, z7.h\n" + "fcvt z3.s, p1/m, z3.h\n" + ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" + ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #64]\n" + "mov z7.q, z7.q[0]\n" + "mov z3.q, z3.q[0]\n" + ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" + ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #96]\n" + ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" + ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" + "uzp1 z5.d, z9.d, z22.d\n" + "scvtf z5.s, p1/m, z5.s\n" + "uzp2 z22.d, z9.d, z22.d\n" + "fmul z9.s, z23.s, z7.s[0]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z13.s, p1/M, z5.s, z9.s\n" + "ld1rqb { z9.b }, p1/Z, [x25]\n" + "fmul z5.s, z23.s, z7.s[1]\n" + "fmla z1.s, p1/M, z22.s, z5.s\n" + "mov z5.s, #0x0\n" + "mov z22.s, #0x0\n" + ".inst 0x451f9a45 // smmla z5.s, z18.b, z31.b\n" + ".inst 0x45069a56 // smmla z22.s, z18.b, z6.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #48]\n" + ".inst 0x450e9a45 // smmla z5.s, z18.b, z14.b\n" + ".inst 0x45029a56 // smmla z22.s, z18.b, z2.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #80]\n" + ".inst 0x451e9a45 // smmla z5.s, z18.b, z30.b\n" + ".inst 0x45159a56 // smmla z22.s, z18.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #112]\n" + "add x26, x26, #0x88\n" + ".inst 0x45049a45 // smmla z5.s, z18.b, z4.b\n" + ".inst 0x45119a56 // smmla z22.s, z18.b, z17.b\n" + "uzp1 z18.d, z5.d, z22.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp2 z22.d, z5.d, z22.d\n" + "fmul z5.s, z23.s, z7.s[2]\n" + "fmul z7.s, z23.s, z7.s[3]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z20.s, p1/M, z18.s, z5.s\n" + "ld1rqb { z18.b }, p1/Z, [x25, #16]\n" + "ld1h { z5.s }, p0/Z, [x20]\n" + "fcvt z5.s, p1/m, z5.h\n" + "fmla z25.s, p1/M, z22.s, z7.s\n" + "mov z22.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9936 // smmla z22.s, z9.b, z31.b\n" + ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #32]\n" + "mov z5.q, z5.q[0]\n" + ".inst 0x450e9936 // smmla z22.s, z9.b, z14.b\n" + ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #64]\n" + ".inst 0x451e9936 // smmla z22.s, z9.b, z30.b\n" + ".inst 0x45159927 // smmla z7.s, z9.b, z21.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #96]\n" + ".inst 0x45049936 // smmla z22.s, z9.b, z4.b\n" + ".inst 0x45119927 // smmla z7.s, z9.b, z17.b\n" + "uzp1 z9.d, z22.d, z7.d\n" + "scvtf z9.s, p1/m, z9.s\n" + "uzp2 z22.d, z22.d, z7.d\n" + "fmul z7.s, z23.s, z3.s[0]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z11.s, p1/M, z9.s, z7.s\n" + "ld1rqb { z9.b }, p1/Z, [x24]\n" + "fmul z7.s, z23.s, z3.s[1]\n" + "fmla z16.s, p1/M, z22.s, z7.s\n" + "mov z22.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9a56 // smmla z22.s, z18.b, z31.b\n" + ".inst 0x45069a47 // smmla z7.s, z18.b, z6.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #48]\n" + ".inst 0x450e9a56 // smmla z22.s, z18.b, z14.b\n" + ".inst 0x45029a47 // smmla z7.s, z18.b, z2.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #80]\n" + ".inst 0x451e9a56 // smmla z22.s, z18.b, z30.b\n" + ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #112]\n" + "add x25, x25, #0x88\n" + ".inst 0x45049a56 // smmla z22.s, z18.b, z4.b\n" + ".inst 0x45119a47 // smmla z7.s, z18.b, z17.b\n" + "uzp1 z18.d, z22.d, z7.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp2 z7.d, z22.d, z7.d\n" + "fmul z22.s, z23.s, z3.s[2]\n" + "fmul z3.s, z23.s, z3.s[3]\n" + "scvtf z7.s, p1/m, z7.s\n" + "fmla z19.s, p1/M, z18.s, z22.s\n" + "ld1rqb { z18.b }, p1/Z, [x24, #16]\n" + "fmul z22.s, z23.s, z5.s[0]\n" + "fmla z26.s, p1/M, z7.s, z3.s\n" + "mov z3.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9923 // smmla z3.s, z9.b, z31.b\n" + ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" + "ld1rqb { z9.b }, p1/Z, [x24, #32]\n" + ".inst 0x450e9923 // smmla z3.s, z9.b, z14.b\n" + ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" + "mov z9.s, #0x0\n" + ".inst 0x451f9a49 // smmla z9.s, z18.b, z31.b\n" + "mov z31.s, #0x0\n" + ".inst 0x45069a5f // smmla z31.s, z18.b, z6.b\n" + "ld1rqb { z6.b }, p1/Z, [x24, #48]\n" + "ld1rqb { z18.b }, p1/Z, [x24, #64]\n" + ".inst 0x450e98c9 // smmla z9.s, z6.b, z14.b\n" + "fmul z14.s, z23.s, z5.s[1]\n" + ".inst 0x450298df // smmla z31.s, z6.b, z2.b\n" + "ld1rqb { z6.b }, p1/Z, [x24, #80]\n" + "fmul z2.s, z23.s, z5.s[2]\n" + "fmul z23.s, z23.s, z5.s[3]\n" + ".inst 0x451e9a43 // smmla z3.s, z18.b, z30.b\n" + ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x24, #96]\n" + ".inst 0x451e98c9 // smmla z9.s, z6.b, z30.b\n" + ".inst 0x451598df // smmla z31.s, z6.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x24, #112]\n" + "add x24, x24, #0x88\n" + ".inst 0x450498a3 // smmla z3.s, z5.b, z4.b\n" + ".inst 0x451198a7 // smmla z7.s, z5.b, z17.b\n" + ".inst 0x45049a49 // smmla z9.s, z18.b, z4.b\n" + ".inst 0x45119a5f // smmla z31.s, z18.b, z17.b\n" + "uzp1 z18.d, z3.d, z7.d\n" + "uzp2 z5.d, z3.d, z7.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp1 z6.d, z9.d, z31.d\n" + "uzp2 z9.d, z9.d, z31.d\n" + "scvtf z5.s, p1/m, z5.s\n" + "fmla z8.s, p1/M, z18.s, z22.s\n" + "scvtf z6.s, p1/m, z6.s\n" + "scvtf z9.s, p1/m, z9.s\n" + "fmla z29.s, p1/M, z5.s, z14.s\n" + "fmla z27.s, p1/M, z6.s, z2.s\n" + "fmla z10.s, p1/M, z9.s, z23.s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x10, x10, #0x8\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "st1w { z24.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z15.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z12.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z0.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z13.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z1.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z20.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z25.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z11.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z16.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z19.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z26.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z8.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z29.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z27.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z10.s }, p1, [x20]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x13, x13, #0x10\n" + "cmp x13, #0x10\n" + "mov %x[res_ptr], x9\n" + "madd %x[a_ptr], x20, x12, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x13, 9f\n" + "5:" // Row tail: Row loop + "add x25, %x[b_ptr], #0x10\n" + "mov x24, %x[nc]\n" + "add x23, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "mov z24.b, #0x0\n" + "mov z15.b, #0x0\n" + "add x28, %x[a_ptr], #0x8\n" + "mov x22, %x[nb]\n" + "mov z12.b, #0x0\n" + "mov z0.b, #0x0\n" + "7:" // Row tail: Block loop + "ld1b { z3.b }, p1/Z, [x25]\n" + "ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\n" + "mov z2.s, #0x0\n" + "mov z25.s, #0x0\n" + "ld1rqb { z26.b }, p1/Z, [x28]\n" + "ld1rqb { z21.b }, p1/Z, [x28, #16]\n" + "mov z27.s, #0x0\n" + "mov z19.s, #0x0\n" + "ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\n" + "ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\n" + "sub x21, x25, #0x10\n" + "sub x20, x28, #0x8\n" + "lsl z20.b, z3.b, #0x4\n" + "lsl z4.b, z6.b, #0x4\n" + "ld1rqb { z10.b }, p1/Z, [x28, #32]\n" + "ld1rqb { z23.b }, p1/Z, [x28, #48]\n" + "and z3.b, z3.b, #0xf0\n" + "and z6.b, z6.b, #0xf0\n" + "ld1rqb { z11.b }, p1/Z, [x28, #64]\n" + "ld1rqb { z7.b }, p1/Z, [x28, #80]\n" + "lsl z8.b, z29.b, #0x4\n" + "lsl z14.b, z16.b, #0x4\n" + "ld1rqb { z18.b }, p1/Z, [x28, #96]\n" + "ld1rqb { z30.b }, p1/Z, [x28, #112]\n" + ".inst 0x45149b42 // smmla z2.s, z26.b, z20.b\n" + ".inst 0x45049b59 // smmla z25.s, z26.b, z4.b\n" + "and z29.b, z29.b, #0xf0\n" + "ld1h { z17.s }, p1/Z, [x21]\n" + ".inst 0x45149abb // smmla z27.s, z21.b, z20.b\n" + ".inst 0x45049ab3 // smmla z19.s, z21.b, z4.b\n" + "and z16.b, z16.b, #0xf0\n" + "ld1h { z4.s }, p0/Z, [x20]\n" + "subs x22, x22, #0x1\n" + "add x28, x28, #0x88\n" + "fcvt z17.s, p1/m, z17.h\n" + "add x25, x25, #0x90\n" + ".inst 0x45089942 // smmla z2.s, z10.b, z8.b\n" + ".inst 0x450e9959 // smmla z25.s, z10.b, z14.b\n" + "fcvt z4.s, p1/m, z4.h\n" + ".inst 0x45089afb // smmla z27.s, z23.b, z8.b\n" + ".inst 0x450e9af3 // smmla z19.s, z23.b, z14.b\n" + "fscale z17.s, p1/m, z17.s, z28.s\n" + "mov z4.q, z4.q[0]\n" + ".inst 0x45039962 // smmla z2.s, z11.b, z3.b\n" + ".inst 0x45069979 // smmla z25.s, z11.b, z6.b\n" + "fmul z23.s, z17.s, z4.s[0]\n" + "fmul z9.s, z17.s, z4.s[1]\n" + "fmul z21.s, z17.s, z4.s[2]\n" + "fmul z4.s, z17.s, z4.s[3]\n" + ".inst 0x450398fb // smmla z27.s, z7.b, z3.b\n" + ".inst 0x450698f3 // smmla z19.s, z7.b, z6.b\n" + ".inst 0x451d9a42 // smmla z2.s, z18.b, z29.b\n" + ".inst 0x45109a59 // smmla z25.s, z18.b, z16.b\n" + ".inst 0x451d9bdb // smmla z27.s, z30.b, z29.b\n" + ".inst 0x45109bd3 // smmla z19.s, z30.b, z16.b\n" + "uzp1 z31.d, z2.d, z25.d\n" + "uzp2 z13.d, z2.d, z25.d\n" + "scvtf z31.s, p1/m, z31.s\n" + "uzp1 z17.d, z27.d, z19.d\n" + "uzp2 z18.d, z27.d, z19.d\n" + "scvtf z13.s, p1/m, z13.s\n" + "fmla z24.s, p1/M, z31.s, z23.s\n" + "scvtf z17.s, p1/m, z17.s\n" + "scvtf z18.s, p1/m, z18.s\n" + "fmla z15.s, p1/M, z13.s, z9.s\n" + "fmla z12.s, p1/M, z17.s, z21.s\n" + "fmla z0.s, p1/M, z18.s, z4.s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x13, #0x1\n" + "st1w { z24.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x13, #0x2\n" + "st1w { z15.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x13, #0x3\n" + "st1w { z12.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "st1w { z0.s }, p1, [x20]\n" + "8:" // Row tail: Accumulator store skip + "subs x24, x24, #0x8\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "bne 6b\n" + "subs x13, x13, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x12\n" + "mov %x[res_ptr], x23\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); + return; + } +#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) +#elif defined(__AVX2__) || defined(__AVX512F__) + { + const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx; + const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy; + int64_t b_nb = n / QK4_0; + int64_t y = 0; + // Mask to mask out nibbles from packed bytes + const __m256i m4b = _mm256_set1_epi8(0x0F); + const __m128i loadMask = _mm_blend_epi32(_mm_setzero_si128(), _mm_set1_epi32(0xFFFFFFFF), 3); + // Lookup table to convert signed nibbles to signed bytes + __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0)); + signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); + // Permute mask used for easier vector processing at later stages + __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4); + int64_t xstart = 0; + int anr = nr - nr%16; // Used to align nr with boundary of 16 + #ifdef __AVX512F__ + int anc = nc - nc%16; // Used to align nc with boundary of 16 + // Mask to mask out nibbles from packed bytes expanded to 512 bit length + const __m512i m4bexpanded = _mm512_set1_epi8(0x0F); + // Lookup table to convert signed nibbles to signed bytes expanded to 512 bit length + __m512i signextendlutexpanded = _mm512_inserti32x8(_mm512_castsi256_si512(signextendlut), signextendlut, 1); + + // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation + for (; y < anr / 4; y += 4) { + + const block_q8_0x4 * a_ptrs[4]; + + a_ptrs[0] = a_ptr_start + (y * nb); + for (int i = 0; i < 3; ++i) { + a_ptrs[i + 1] = a_ptrs[i] + nb; + } + + // Take group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < anc / 8; x += 2) { + + const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x) * b_nb); + const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb); + + // Master FP accumulators + __m512 acc_rows[16]; + for (int i = 0; i < 16; i++) { + acc_rows[i] = _mm512_setzero_ps(); + } + + for (int64_t b = 0; b < nb; b++) { + // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96)); + + const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs)); + const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32)); + const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64)); + const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96)); + + // Save the values in the following vectors in the formats B0B1B4B5B8B9BCBD, B2B3B6B7BABBBEBF for further processing and storing of values + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + + const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240); + const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240); + + const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1); + const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1); + const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1); + const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1); + + // 4-bit -> 8-bit - Sign is maintained + const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7) + const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7) + + const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15) + const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15) + + const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23) + const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23) + + const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31) + const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31) + + // Shuffle pattern one - right side input + const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3) + const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3) + + const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11) + const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11) + + const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19) + const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19) + + const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27) + const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27) + + // Shuffle pattern two - right side input + + const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7) + const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7) + + const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15) + const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15) + + const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23) + const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23) + + const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31) + const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31) + + // Scale values - Load the weight scale values of two block_q4_0x8 + const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + + // Process LHS in pairs of rows + for (int rp = 0; rp < 4; rp++) { + + // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector + __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs))); + __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0); + __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17); + __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32))); + __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0); + __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17); + __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64))); + __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0); + __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17); + __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96))); + __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0); + __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17); + + __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1); + __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1); + __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1); + __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1); + __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1); + __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1); + __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1); + __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1); + + // Shuffle pattern one - left side input + + const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) + const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) + + const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) + const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) + + const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) + const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) + + const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) + const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) + + // Shuffle pattern two - left side input + + const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) + const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) + + const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) + const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) + + const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) + const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) + + const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) + const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + // Resembles MMLAs into 2x2 matrices in ARM Version + __m512i iacc_mat_00_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1)); + __m512i iacc_mat_01_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1)); + __m512i iacc_mat_10_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1)); + __m512i iacc_mat_11_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1)); + __m512i iacc_mat_00_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2)); + __m512i iacc_mat_01_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2)); + __m512i iacc_mat_10_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2)); + __m512i iacc_mat_11_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2)); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); + __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); + __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); + __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); + + + // Straighten out to make 4 row vectors + __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78)); + __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01); + __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78)); + __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11); + + // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes + const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68); + const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16); + + // Multiply with appropiate scales and accumulate + acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); + acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); + acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); + acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]); + } + } + + // Store the accumulated values + for (int i = 0; i < 16; i++) { + _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); + } + } + } + // Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation + for (; y < nr / 4; y ++) { + + const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb); + + // Take group of two block_q4_0x8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < anc / 8; x += 2) { + + const block_q4_0x8 * b_ptr_0 = b_ptr_start + ((x) * b_nb); + const block_q4_0x8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb); + + // Master FP accumulators + __m512 acc_rows[4]; + for (int i = 0; i < 4; i++) { + acc_rows[i] = _mm512_setzero_ps(); + } + + for (int64_t b = 0; b < nb; b++) { + // Load the sixteen block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....BE,BF + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 32)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 64)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_0[b].qs + 96)); + + const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs)); + const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 32)); + const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 64)); + const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i *)(b_ptr_1[b].qs + 96)); + + // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of valuess + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + + const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240); + const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240); + + const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1); + const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1); + const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1); + const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1); + + // 4-bit -> 8-bit - Sign is maintained + const __m512i rhs_mat_014589CD_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) B8(0-7) B9(0-7) BC(0-7) BD(0-7) + const __m512i rhs_mat_2367ABEF_0 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) BA(0-7) BB(0-7) BE(0-7) BF(0-7) + + const __m512i rhs_mat_014589CD_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) B8(8-15) B9(8-15) BC(8-15) BD(8-15) + const __m512i rhs_mat_2367ABEF_1 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) BA(8-15) BB(8-15) BE(8-15) BF(8-15) + + const __m512i rhs_mat_014589CD_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) B8(16-23) B9(16-23) BC(16-23) BD(16-23) + const __m512i rhs_mat_2367ABEF_2 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) BA(16-23) BB(16-23) BE(16-23) BF(16-23) + + const __m512i rhs_mat_014589CD_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) B8(24-31) B9(24-31) BC(24-31) BD(24-31) + const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31) + + // Shuffle pattern one - right side input + const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3) + const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3) + + const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11) + const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11) + + const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19) + const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19) + + const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27) + const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27) + + // Shuffle pattern two - right side input + + const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7) + const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7) + + const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15) + const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15) + + const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23) + const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23) + + const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31) + const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31) + + + // Scale values - Load the weight scale values of two block_q4_0x8 + const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + + // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector + __m256i lhs_mat_ymm_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs))); + __m256i lhs_mat_ymm_01_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 0); + __m256i lhs_mat_ymm_23_0 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_0, lhs_mat_ymm_0123_0, 17); + __m256i lhs_mat_ymm_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32))); + __m256i lhs_mat_ymm_01_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 0); + __m256i lhs_mat_ymm_23_1 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_1, lhs_mat_ymm_0123_1, 17); + __m256i lhs_mat_ymm_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64))); + __m256i lhs_mat_ymm_01_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 0); + __m256i lhs_mat_ymm_23_2 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_2, lhs_mat_ymm_0123_2, 17); + __m256i lhs_mat_ymm_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96))); + __m256i lhs_mat_ymm_01_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 0); + __m256i lhs_mat_ymm_23_3 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_3, lhs_mat_ymm_0123_3, 17); + + __m512i lhs_mat_01_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_0), lhs_mat_ymm_01_0, 1); + __m512i lhs_mat_23_0 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_0), lhs_mat_ymm_23_0, 1); + __m512i lhs_mat_01_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_1), lhs_mat_ymm_01_1, 1); + __m512i lhs_mat_23_1 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_1), lhs_mat_ymm_23_1, 1); + __m512i lhs_mat_01_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_2), lhs_mat_ymm_01_2, 1); + __m512i lhs_mat_23_2 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_2), lhs_mat_ymm_23_2, 1); + __m512i lhs_mat_01_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_3), lhs_mat_ymm_01_3, 1); + __m512i lhs_mat_23_3 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_3), lhs_mat_ymm_23_3, 1); + + // Shuffle pattern one - left side input + + const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) + const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) + + const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) + const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) + + const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) + const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) + + const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) + const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) + + // Shuffle pattern two - left side input + + const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) + const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) + + const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) + const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) + + const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) + const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) + + const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) + const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + // Resembles MMLAs into 2x2 matrices in ARM Version + __m512i iacc_mat_00_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1)); + __m512i iacc_mat_01_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1)); + __m512i iacc_mat_10_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1)); + __m512i iacc_mat_11_sp1 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1)); + __m512i iacc_mat_00_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2)); + __m512i iacc_mat_01_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2)); + __m512i iacc_mat_10_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2)); + __m512i iacc_mat_11_sp2 = + _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2)); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); + __m512i iacc_mat_01 = _mm512_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); + __m512i iacc_mat_10 = _mm512_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); + __m512i iacc_mat_11 = _mm512_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); + + + // Straighten out to make 4 row vectors + __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78)); + __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01); + __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78)); + __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11); + + // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes + const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68); + const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16); + + // Multiply with appropiate scales and accumulate + acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); + acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); + acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); + acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]); + } + + // Store the accumulated values + for (int i = 0; i < 4; i++) { + _mm512_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); + } + } + } + if (anc != nc) { + xstart = anc/8; + y = 0; + } + #endif // __AVX512F__ + + // Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation + + for (; y < anr / 4; y += 4) { + const block_q8_0x4 * a_ptrs[4]; + + a_ptrs[0] = a_ptr_start + (y * nb); + for (int i = 0; i < 3; ++i) { + a_ptrs[i + 1] = a_ptrs[i] + nb; + } + + // Take group of eight block_q4_0x8 structures at each pass of the loop and perform dot product operation + for (int64_t x = xstart; x < nc / 8; x++) { + + const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); + + // Master FP accumulators + __m256 acc_rows[16]; + for (int i = 0; i < 16; i++) { + acc_rows[i] = _mm256_setzero_ps(); + } + + for (int64_t b = 0; b < nb; b++) { + // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96)); + + // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of values + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + + // 4-bit -> 8-bit - Sign is maintained + const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) + const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) + + const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) + const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) + + const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) + const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) + + const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) + const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) + + // Shuffle pattern one - right side input + const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) + const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) + + const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) + const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) + + const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) + const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) + + const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) + const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) + + // Shuffle pattern two - right side input + + const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) + const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) + + const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) + const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) + + const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) + const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) + + const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) + const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) + + // Scale values - Load the wight scale values of block_q4_0x8 + const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d); + + // Process LHS in groups of four + for (int rp = 0; rp < 4; rp++) { + // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated into a 256 bit vector + __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs))); + __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0); + __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17); + __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32))); + __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0); + __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17); + __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64))); + __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0); + __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17); + __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96))); + __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0); + __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17); + + // Shuffle pattern one - left side input + const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) + const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) + + const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) + const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) + + const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) + const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) + + const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) + const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) + + // Shuffle pattern two - left side input + const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) + const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) + + const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) + const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) + + const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) + const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) + + const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) + const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + // Resembles MMLAs into 2x2 matrices in ARM Version + __m256i iacc_mat_00_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1)); + __m256i iacc_mat_01_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1)); + __m256i iacc_mat_10_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1)); + __m256i iacc_mat_11_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1)); + __m256i iacc_mat_00_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2)); + __m256i iacc_mat_01_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2)); + __m256i iacc_mat_10_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2)); + __m256i iacc_mat_11_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2)); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); + __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); + __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); + __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); + + // Straighten out to make 4 row vectors + __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204); + __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204); + __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204); + __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204); + + // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes + const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask); + + // Multiply with appropiate scales and accumulate + acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); + acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); + acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); + acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]); + } + } + + // Store the accumulated values + for (int i = 0; i < 16; i++) { + _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); + } + } + } + + // Take a block_q8_0x4 structures at each pass of the loop and perform dot product operation + for (; y < nr / 4; y ++) { + + const block_q8_0x4 * a_ptr = a_ptr_start + (y * nb); + + // Load the eight block_q4_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 + for (int64_t x = xstart; x < nc / 8; x++) { + + const block_q4_0x8 * b_ptr = b_ptr_start + (x * b_nb); + + // Master FP accumulators + __m256 acc_rows[4]; + for (int i = 0; i < 4; i++) { + acc_rows[i] = _mm256_setzero_ps(); + } + + for (int64_t b = 0; b < nb; b++) { + // Load the eight block_q8_0 quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7 + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 32)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 64)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i *)(b_ptr[b].qs + 96)); + + // Save the values in the following vectors in the formats B0B1B4B5, B2B3B6B7 for further processing and storing of valuess + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + + // 4-bit -> 8-bit - Sign is maintained + const __m256i rhs_mat_0145_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_0, m4b)); //B0(0-7) B1(0-7) B4(0-7) B5(0-7) + const __m256i rhs_mat_2367_0 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_0, m4b)); //B2(0-7) B3(0-7) B6(0-7) B7(0-7) + + const __m256i rhs_mat_0145_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_0145_1, m4b)); //B0(8-15) B1(8-15) B4(8-15) B5(8-15) + const __m256i rhs_mat_2367_1 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(rhs_raw_mat_2367_1, m4b)); //B2(8-15) B3(8-15) B6(8-15) B7(8-15) + + const __m256i rhs_mat_0145_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_0, 4), m4b)); //B0(16-23) B1(16-23) B4(16-23) B5(16-23) + const __m256i rhs_mat_2367_2 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_0, 4), m4b)); //B2(16-23) B3(16-23) B6(16-23) B7(16-23) + + const __m256i rhs_mat_0145_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_0145_1, 4), m4b)); //B0(24-31) B1(24-31) B4(24-31) B5(24-31) + const __m256i rhs_mat_2367_3 = _mm256_shuffle_epi8(signextendlut, _mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_1, 4), m4b)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) + + // Shuffle pattern one - right side input + const __m256i rhs_mat_0145_0_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) + const __m256i rhs_mat_2367_0_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) + + const __m256i rhs_mat_0145_1_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) + const __m256i rhs_mat_2367_1_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) + + const __m256i rhs_mat_0145_2_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) + const __m256i rhs_mat_2367_2_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) + + const __m256i rhs_mat_0145_3_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) + const __m256i rhs_mat_2367_3_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) + + // Shuffle pattern two - right side input + + const __m256i rhs_mat_0145_0_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) + const __m256i rhs_mat_2367_0_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) + + const __m256i rhs_mat_0145_1_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) + const __m256i rhs_mat_2367_1_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) + + const __m256i rhs_mat_0145_2_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) + const __m256i rhs_mat_2367_2_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) + + const __m256i rhs_mat_0145_3_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) + const __m256i rhs_mat_2367_3_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) + + // Scale values - Load the wight scale values of block_q4_0x8 + const __m256 col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d); + + // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated into a 256 bit vector + __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs))); + __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0); + __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17); + __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32))); + __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0); + __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17); + __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64))); + __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0); + __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17); + __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96))); + __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0); + __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17); + + // Shuffle pattern one - left side input + + const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) + const __m256i lhs_mat_23_0_sp1 = _mm256_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) + + const __m256i lhs_mat_01_1_sp1 = _mm256_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) + const __m256i lhs_mat_23_1_sp1 = _mm256_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) + + const __m256i lhs_mat_01_2_sp1 = _mm256_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) + const __m256i lhs_mat_23_2_sp1 = _mm256_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) + + const __m256i lhs_mat_01_3_sp1 = _mm256_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) + const __m256i lhs_mat_23_3_sp1 = _mm256_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) + + // Shuffle pattern two - left side input + + const __m256i lhs_mat_01_0_sp2 = _mm256_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) + const __m256i lhs_mat_23_0_sp2 = _mm256_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) + + const __m256i lhs_mat_01_1_sp2 = _mm256_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) + const __m256i lhs_mat_23_1_sp2 = _mm256_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) + + const __m256i lhs_mat_01_2_sp2 = _mm256_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) + const __m256i lhs_mat_23_2_sp2 = _mm256_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) + + const __m256i lhs_mat_01_3_sp2 = _mm256_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) + const __m256i lhs_mat_23_3_sp2 = _mm256_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + // Resembles MMLAs into 2x2 matrices in ARM Version + __m256i iacc_mat_00_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1)); + __m256i iacc_mat_01_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1)); + __m256i iacc_mat_10_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1)); + __m256i iacc_mat_11_sp1 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1)); + __m256i iacc_mat_00_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2)); + __m256i iacc_mat_01_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2)); + __m256i iacc_mat_10_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2)); + __m256i iacc_mat_11_sp2 = + _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2)); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); + __m256i iacc_mat_01 = _mm256_add_epi32(iacc_mat_01_sp1, iacc_mat_01_sp2); + __m256i iacc_mat_10 = _mm256_add_epi32(iacc_mat_10_sp1, iacc_mat_10_sp2); + __m256i iacc_mat_11 = _mm256_add_epi32(iacc_mat_11_sp1, iacc_mat_11_sp2); + + + // Straighten out to make 4 row vectors + __m256i iacc_row_0 = _mm256_blend_epi32(iacc_mat_00, _mm256_shuffle_epi32(iacc_mat_01, 78), 204); + __m256i iacc_row_1 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01, 204); + __m256i iacc_row_2 = _mm256_blend_epi32(iacc_mat_10, _mm256_shuffle_epi32(iacc_mat_11, 78), 204); + __m256i iacc_row_3 = _mm256_blend_epi32(_mm256_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11, 204); + + // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes + const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask); + + // Multiply with appropiate scales and accumulate + acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); + acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); + acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); + acc_rows[3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]); + } + + // Store the accumulated values + for (int i = 0; i < 4; i++) { + _mm256_storeu_ps((float *)(s + ((y * 4 + i) * bs + x * 8)), acc_rows[i]); + } + } + } + return; + } +#elif defined(__riscv_v_intrinsic) + if (__riscv_vlenb() >= QK4_0) { + const size_t vl = QK4_0; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + for (int l = 0; l < nb; l++) { + const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); + const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); + const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); + const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0); + const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1); + const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0); + const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1); + + // vector version needs Zvfhmin extension + const float a_scales[4] = { + GGML_FP16_TO_FP32(a_ptr[l].d[0]), + GGML_FP16_TO_FP32(a_ptr[l].d[1]), + GGML_FP16_TO_FP32(a_ptr[l].d[2]), + GGML_FP16_TO_FP32(a_ptr[l].d[3]) + }; + const float b_scales[8] = { + GGML_FP16_TO_FP32(b_ptr[l].d[0]), + GGML_FP16_TO_FP32(b_ptr[l].d[1]), + GGML_FP16_TO_FP32(b_ptr[l].d[2]), + GGML_FP16_TO_FP32(b_ptr[l].d[3]), + GGML_FP16_TO_FP32(b_ptr[l].d[4]), + GGML_FP16_TO_FP32(b_ptr[l].d[5]), + GGML_FP16_TO_FP32(b_ptr[l].d[6]), + GGML_FP16_TO_FP32(b_ptr[l].d[7]) + }; + const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4); + + const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0]; + const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32]; + const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64]; + const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l0; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l0 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4); + sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4); + } + + const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8]; + const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40]; + const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72]; + const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l1; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l1 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4); + sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4); + } + + const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16]; + const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48]; + const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80]; + const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l2; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l2 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4); + sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4); + } + + const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24]; + const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56]; + const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88]; + const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l3; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l3 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4); + sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4); + } + } + __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4); + __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4); + __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4); + __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4); + } + } + + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) + float sumf[4][8]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +static void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + + float32x4_t sumf[4]; + for (int m = 0; m < 4; m++) { + sumf[m] = vdupq_n_f32(0); + } + + for (int l = 0; l < nb; l++) { + float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d)); + float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d)); + + int32x4_t sumi_0 = vdupq_n_s32(0); + int32x4_t sumi_1 = vdupq_n_s32(0); + int32x4_t sumi_2 = vdupq_n_s32(0); + int32x4_t sumi_3 = vdupq_n_s32(0); + + for (int k = 0; k < 4; k++) { + int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0); + int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64); + + uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k); + int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4); + int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF); + + sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0); + sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1); + sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2); + sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3); + sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0); + sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1); + sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2); + sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3); + } + + sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0)); + sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1)); + sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2)); + sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3)); + } + + for (int m = 0; m < 4; m++) { + vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]); + } + } + } + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + { + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + +static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x4 out; + + for (int i = 0; i < 4; i++) { + out.d[i] = in[i].d; + } + + const int end = QK4_0 * 2 / blck_size_interleave; + + if (blck_size_interleave == 8) { + const uint64_t xor_mask = 0x8888888888888888ULL; + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elems; + // Using memcpy to avoid unaligned memory accesses + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); + elems ^= xor_mask; + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + } + } else if (blck_size_interleave == 4) { + const uint32_t xor_mask = 0x88888888; + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint32_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint32_t)); + elems ^= xor_mask; + memcpy(&out.qs[dst_offset], &elems, sizeof(uint32_t)); + } + } else { + GGML_ASSERT(false); + } + + return out; +} + +// interleave 8 block_q4_0s in blocks of blck_size_interleave +// returns an interleaved block_q4_0x8 +// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks +// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave +static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x8 out; + + for (int i = 0; i < 8; i++) { + out.d[i] = in[i].d; + } + + const int end = QK4_0 * 4 / blck_size_interleave; + const uint64_t xor_mask = 0x8888888888888888ULL; + + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); + elems ^= xor_mask; + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + } + + return out; +} + +static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 4 || interleave_block == 8); + constexpr int nrows_interleaved = 4; + + block_q4_0x4 * dst = (block_q4_0x4 *)t->data; + const block_q4_0 * src = (const block_q4_0 *)data; + block_q4_0 dst_tmp[4]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x4(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 8); + constexpr int nrows_interleaved = 8; + + block_q4_0x8 * dst = (block_q4_0x8*)t->data; + const block_q4_0 * src = (const block_q4_0*) data; + block_q4_0 dst_tmp[8]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x8(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) { + block_iq4_nlx4 out; + + for (int i = 0; i < 4; i++) { + out.d[i] = in[i].d; + } + + const int end = QK4_NL * 2 / blck_size_interleave; + + // TODO: this branch seems wrong + //if (blck_size_interleave == 8) { + // for (int i = 0; i < end; ++i) { + // int src_id = i % 4; + // int src_offset = (i / 4) * blck_size_interleave; + // int dst_offset = i * blck_size_interleave; + + // // Using memcpy to avoid unaligned memory accesses + // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); + // } + //} else + if (blck_size_interleave == 4) { + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t)); + } + } else { + GGML_ASSERT(false); + } + + return out; +} + +static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); + //GGML_ASSERT(interleave_block == 4 || interleave_block == 8); + GGML_ASSERT(interleave_block == 4); + + block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data; + const block_iq4_nl * src = (const block_iq4_nl *)data; + block_iq4_nl dst_tmp[4]; + int nrow = ggml_nrows(t); + int nrows_interleaved = 4; + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_iq4_nlx4(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +namespace ggml::cpu::aarch64 { +// repack +template +int repack(struct ggml_tensor *, const void *, size_t); + +// TODO: generalise. +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_4_bl(t, 4, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_4_bl(t, 8, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size); +} + +// TODO: needs to be revisited +//template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { +// return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size); +//} + +// gemv +template +void gemv(int, float *, size_t, const void *, const void *, int, int); + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> +void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + +// gemm +template +void gemm(int, float *, size_t, const void *, const void *, int, int); + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> +void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + +class tensor_traits_base : public ggml::cpu::tensor_traits { + public: + virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0; +}; + +template class tensor_traits : public tensor_traits_base { + + bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { + // not realy a GGML_TYPE_Q8_0 but same size. + switch (op->op) { + case GGML_OP_MUL_MAT: + size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])); + return true; + case GGML_OP_MUL_MAT_ID: + size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])); + size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc. + size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2]; + return true; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + forward_mul_mat(params, op); + return true; + case GGML_OP_MUL_MAT_ID: + forward_mul_mat_id(params, op); + return true; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_n_dims(op->src[0]) == 2); + // GGML_ASSERT(ggml_n_dims(op->src[1]) == 2); + + char * wdata = static_cast(params->wdata); + const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10); + + assert(params->wsize >= nbw1 * ne11); + + const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float; + + int64_t i11_processed = 0; + for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { + quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10, + INTER_SIZE); + } + i11_processed = ne11 - ne11 % 4; + for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) { + from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10); + } + + ggml_barrier(params->threadpool); + + const void * src1_wdata = params->wdata; + const size_t src1_col_stride = ggml_row_size(GGML_TYPE_Q8_0, ne10); + int64_t src0_start = (ith * ne01) / nth; + int64_t src0_end = ((ith + 1) * ne01) / nth; + src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start; + src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end; + if (src0_start >= src0_end) { + return; + } + + // If there are more than three rows in src1, use gemm; otherwise, use gemv. + if (ne11 > 3) { + gemm(ne00, (float *) ((char *) dst->data) + src0_start, ne01, + (const char *) src0->data + src0_start * nb01, + (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start); + } + for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) { + gemv(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01, + (const char *) src0->data + src0_start * nb01, + (const char *) src1_wdata + (src1_col_stride * iter), 1, + src0_end - src0_start); + } + } + + void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + const ggml_tensor * ids = op->src[2]; + ggml_tensor * dst = op; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const ggml_from_float_t from_float = ggml_get_type_traits_cpu(GGML_TYPE_Q8_0)->from_float; + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(src0->type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne03 == 1); + GGML_ASSERT(ne13 == 1); + GGML_ASSERT(ne3 == 1); + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + // row groups + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_expert + + const size_t nbw1 = ggml_row_size(GGML_TYPE_Q8_0, ne10); + const size_t nbw2 = nbw1*ne11; + const size_t nbw3 = nbw2*ne12; + + struct mmid_row_mapping { + int32_t i1; + int32_t i2; + }; + + GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) + + n_as * ne12 * sizeof(mmid_row_mapping))); + + auto wdata = (char *) params->wdata; + auto wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t)); + int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] + struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12] + + // src1: float32 => block_q8_0 + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = ith; i11 < ne11; i11 += nth) { + from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11), + (void *) (wdata + i12 * nbw2 + i11 * nbw1), + ne10); + } + } + +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)] + + if (ith == 0) { + // initialize matrix_row_counts + memset(matrix_row_counts, 0, n_as * sizeof(int64_t)); + + // group rows by src0 matrix + for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { + for (int32_t id = 0; id < n_ids; ++id) { + const int32_t i02 = + *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); + + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 }; + matrix_row_counts[i02] += 1; + } + } + } + + ggml_barrier(params->threadpool); + + // compute each matrix multiplication in sequence + for (int cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + + if (cne1 == 0) { + continue; + } + + auto src0_cur = (const char *) src0->data + cur_a*nb02; + + //const int64_t nr0 = ne01; // src0 rows + const int64_t nr1 = cne1; // src1 rows + + int64_t src0_cur_start = (ith * ne01) / nth; + int64_t src0_cur_end = ((ith + 1) * ne01) / nth; + src0_cur_start = + (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start; + src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end; + + if (src0_cur_start >= src0_cur_end) return; + + for (int ir1 = 0; ir1 < nr1; ir1++) { + struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1); + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2); + + gemv( + ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, + ne01, src0_cur + src0_cur_start * nb01, + src1_col, 1, src0_cur_end - src0_cur_start); + } + } +#undef MMID_MATRIX_ROW + } + + int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { + GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type), + (int) NB_COLS, (int) INTER_SIZE); + return ggml::cpu::aarch64::repack(t, data, data_size); + } +}; + +// instance for Q4 +static const tensor_traits q4_0_4x4_q8_0; +static const tensor_traits q4_0_4x8_q8_0; +static const tensor_traits q4_0_8x8_q8_0; + +// instance for IQ4 +static const tensor_traits iq4_nl_4x4_q8_0; + +} // namespace ggml::cpu::aarch64 + +static const ggml::cpu::tensor_traits * ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur) { + if (cur->type == GGML_TYPE_Q4_0) { + if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) { + if (cur->ne[1] % 8 == 0) { + return &ggml::cpu::aarch64::q4_0_8x8_q8_0; + } + } + if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + if (cur->ne[1] % 4 == 0) { + return &ggml::cpu::aarch64::q4_0_4x8_q8_0; + } + } + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + if (cur->ne[1] % 4 == 0) { + return &ggml::cpu::aarch64::q4_0_4x4_q8_0; + } + } + } else if (cur->type == GGML_TYPE_IQ4_NL) { + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + if (cur->ne[1] % 4 == 0) { + return &ggml::cpu::aarch64::iq4_nl_4x4_q8_0; + } + } + } + + return nullptr; +} + +static void ggml_backend_cpu_aarch64_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + tensor->extra = (void *) const_cast(ggml_aarch64_get_optimal_repack_type(tensor)); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_cpu_aarch64_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, + const void * data, size_t offset, size_t size) { + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + + auto tensor_traits = (ggml::cpu::aarch64::tensor_traits_base *) tensor->extra; + auto OK = tensor_traits->repack(tensor, data, size); + + GGML_ASSERT(OK == 0); + GGML_UNUSED(buffer); +} + +static const char * ggml_backend_cpu_aarch64_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU_AARCH64"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); + + if (buffer == nullptr) { + return nullptr; + } + + buffer->buft = buft; + buffer->iface.init_tensor = ggml_backend_cpu_aarch64_buffer_init_tensor; + buffer->iface.set_tensor = ggml_backend_cpu_aarch64_buffer_set_tensor; + buffer->iface.get_tensor = nullptr; + buffer->iface.cpy_tensor = nullptr; + return buffer; +} + +static size_t ggml_backend_cpu_aarch64_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return TENSOR_ALIGNMENT; + + GGML_UNUSED(buft); +} + +namespace ggml::cpu::aarch64 { +class extra_buffer_type : ggml::cpu::extra_buffer_type { + bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { + if ( op->op == GGML_OP_MUL_MAT && + op->src[0]->buffer && + (ggml_n_dims(op->src[0]) == 2) && + op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type() && + ggml_aarch64_get_optimal_repack_type(op->src[0]) + ) { + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type == GGML_TYPE_F32) { + return true; + } + //if (op->src[1]->type == GGML_TYPE_Q8_0) { + // return true; + //} + // may be possible if Q8_0 packed... + } else if (op->op == GGML_OP_MUL_MAT_ID + && op->src[0]->buffer + && (ggml_n_dims(op->src[0]) == 3) + && op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type() + && ggml_aarch64_get_optimal_repack_type(op->src[0]) + ) { + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type == GGML_TYPE_F32) { + return true; + } + //if (op->src[1]->type == GGML_TYPE_Q8_0) { + // return true; + //} + } + return false; + } + + ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { + if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) { + if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type()) { + return (ggml::cpu::tensor_traits *) op->src[0]->extra; + } + } + return nullptr; + } +}; +} // namespace ggml::cpu::aarch64 + +ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void) { + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_aarch64 = { + /* .iface = */ { + /* .get_name = */ ggml_backend_cpu_aarch64_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_aarch64_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_aarch64_buffer_type_get_alignment, + /* .get_max_size = */ nullptr, // defaults to SIZE_MAX + /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes + /* .is_host = */ nullptr, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ new ggml::cpu::aarch64::extra_buffer_type(), + }; + + return &ggml_backend_cpu_buffer_type_aarch64; +} diff --git a/ggml/src/ggml-cpu/ggml-cpu-aarch64.h b/ggml/src/ggml-cpu/ggml-cpu-aarch64.h new file mode 100644 index 000000000..6e84c826b --- /dev/null +++ b/ggml/src/ggml-cpu/ggml-cpu-aarch64.h @@ -0,0 +1,8 @@ +#pragma once + +#include "ggml-cpu-traits.h" +#include "ggml.h" + +// GGML internal header + +ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void); diff --git a/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp b/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp new file mode 100644 index 000000000..fa8dea2af --- /dev/null +++ b/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp @@ -0,0 +1,55 @@ +#ifdef GGML_USE_CPU_HBM + +#include "ggml-backend.h" +#include "ggml-backend-impl.h" +#include "ggml-cpu.h" +#include "ggml-impl.h" + +#include "ggml-cpu-hbm.h" + +// buffer type HBM + +#include + +static const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU_HBM"; + + GGML_UNUSED(buft); +} + +static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) { + hbw_free(buffer->context); +} + +static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, + size_t size) { + void * ptr; + int result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size); + if (result != 0) { + GGML_LOG_ERROR("failed to allocate HBM buffer of size %zu\n", size); + return NULL; + } + + ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); + buffer->buft = buft; + buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer; + + return buffer; +} + +ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) { + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = { + /* .iface = */ { + /* .get_name = */ ggml_backend_cpu_hbm_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment, + /* .get_max_size = */ nullptr, // defaults to SIZE_MAX + /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes + /* .is_host = */ ggml_backend_cpu_buffer_type_is_host, + }, + /* .context = */ nullptr, + }; + + return &ggml_backend_cpu_buffer_type_hbm; +} +#endif diff --git a/ggml/src/ggml-cpu/ggml-cpu-hbm.h b/ggml/src/ggml-cpu/ggml-cpu-hbm.h new file mode 100644 index 000000000..09a1f09d7 --- /dev/null +++ b/ggml/src/ggml-cpu/ggml-cpu-hbm.h @@ -0,0 +1,8 @@ +#pragma once + +#include "ggml-backend.h" +#include "ggml.h" + +// GGML CPU internal header + +ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void); diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h new file mode 100644 index 000000000..d71076ad1 --- /dev/null +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -0,0 +1,386 @@ +#pragma once + +// GGML CPU internal header + +#include "ggml.h" +#include "ggml-impl.h" +#include // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/ +//#include +#include +#include // memcpy +#include // fabsf + + +#ifdef __cplusplus +extern "C" { +#endif + +struct ggml_compute_params { + // ith = thread index, nth = number of threads + int ith, nth; + + // work buffer for all threads + size_t wsize; + void * wdata; + + struct ggml_threadpool * threadpool; +}; + + +#if defined(_MSC_VER) + +#define m512bh(p) p +#define m512i(p) p + +#else + +#define m512bh(p) (__m512bh)(p) +#define m512i(p) (__m512i)(p) + +#endif + +// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512 +#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)) +#ifndef __FMA__ +#define __FMA__ +#endif +#ifndef __F16C__ +#define __F16C__ +#endif +#endif + +// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available +#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)) +#ifndef __SSE3__ +#define __SSE3__ +#endif +#ifndef __SSSE3__ +#define __SSSE3__ +#endif +#endif + +#if defined(__ARM_FEATURE_SVE) +#include +#include +#endif + +// 16-bit float +// on Arm, we use __fp16 +// on x86, we use uint16_t +#if defined(__ARM_NEON) + +// if YCM cannot find , make a symbolic link to it, for example: +// +// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ +// +#include + +#ifdef _MSC_VER + +typedef uint16_t ggml_fp16_internal_t; + +#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) } + +#else + +typedef __fp16 ggml_fp16_internal_t; + +#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) } + +#endif // _MSC_VER + +#if !defined(__aarch64__) + +// 32-bit ARM compatibility + +// vaddlvq_s16 +// vpaddq_s16 +// vpaddq_s32 +// vaddvq_s32 +// vaddvq_f32 +// vmaxvq_f32 +// vcvtnq_s32_f32 +// vzip1_u8 +// vzip2_u8 + +inline static int32_t vaddlvq_s16(int16x8_t v) { + int32x4_t v0 = vreinterpretq_s32_s64(vpaddlq_s32(vpaddlq_s16(v))); + return vgetq_lane_s32(v0, 0) + vgetq_lane_s32(v0, 2); +} + +inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { + int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); + int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); + return vcombine_s16(a0, b0); +} + +inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) { + int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a)); + int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b)); + return vcombine_s32(a0, b0); +} + +inline static int32_t vaddvq_s32(int32x4_t v) { + return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); +} + +inline static float vaddvq_f32(float32x4_t v) { + return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); +} + +inline static float vmaxvq_f32(float32x4_t v) { + return + MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), + MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3))); +} + +inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { + int32x4_t res; + + res[0] = roundf(vgetq_lane_f32(v, 0)); + res[1] = roundf(vgetq_lane_f32(v, 1)); + res[2] = roundf(vgetq_lane_f32(v, 2)); + res[3] = roundf(vgetq_lane_f32(v, 3)); + + return res; +} + +inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) { + uint8x8_t res; + + res[0] = a[0]; res[1] = b[0]; + res[2] = a[1]; res[3] = b[1]; + res[4] = a[2]; res[5] = b[2]; + res[6] = a[3]; res[7] = b[3]; + + return res; +} + +inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) { + uint8x8_t res; + + res[0] = a[4]; res[1] = b[4]; + res[2] = a[5]; res[3] = b[5]; + res[4] = a[6]; res[5] = b[6]; + res[6] = a[7]; res[7] = b[7]; + + return res; +} + +// vld1q_s16_x2 +// vld1q_u8_x2 +// vld1q_u8_x4 +// vld1q_s8_x2 +// vld1q_s8_x4 +// TODO: double-check these work correctly + +typedef struct ggml_int16x8x2_t { + int16x8_t val[2]; +} ggml_int16x8x2_t; + +inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) { + ggml_int16x8x2_t res; + + res.val[0] = vld1q_s16(ptr + 0); + res.val[1] = vld1q_s16(ptr + 8); + + return res; +} + +typedef struct ggml_uint8x16x2_t { + uint8x16_t val[2]; +} ggml_uint8x16x2_t; + +inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) { + ggml_uint8x16x2_t res; + + res.val[0] = vld1q_u8(ptr + 0); + res.val[1] = vld1q_u8(ptr + 16); + + return res; +} + +typedef struct ggml_uint8x16x4_t { + uint8x16_t val[4]; +} ggml_uint8x16x4_t; + +inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) { + ggml_uint8x16x4_t res; + + res.val[0] = vld1q_u8(ptr + 0); + res.val[1] = vld1q_u8(ptr + 16); + res.val[2] = vld1q_u8(ptr + 32); + res.val[3] = vld1q_u8(ptr + 48); + + return res; +} + +typedef struct ggml_int8x16x2_t { + int8x16_t val[2]; +} ggml_int8x16x2_t; + +inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) { + ggml_int8x16x2_t res; + + res.val[0] = vld1q_s8(ptr + 0); + res.val[1] = vld1q_s8(ptr + 16); + + return res; +} + +typedef struct ggml_int8x16x4_t { + int8x16_t val[4]; +} ggml_int8x16x4_t; + +inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) { + ggml_int8x16x4_t res; + + res.val[0] = vld1q_s8(ptr + 0); + res.val[1] = vld1q_s8(ptr + 16); + res.val[2] = vld1q_s8(ptr + 32); + res.val[3] = vld1q_s8(ptr + 48); + + return res; +} + +// NOTE: not tested +inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) { + int8x16_t res; + + res[ 0] = a[b[ 0]]; + res[ 1] = a[b[ 1]]; + res[ 2] = a[b[ 2]]; + res[ 3] = a[b[ 3]]; + res[ 4] = a[b[ 4]]; + res[ 5] = a[b[ 5]]; + res[ 6] = a[b[ 6]]; + res[ 7] = a[b[ 7]]; + res[ 8] = a[b[ 8]]; + res[ 9] = a[b[ 9]]; + res[10] = a[b[10]]; + res[11] = a[b[11]]; + res[12] = a[b[12]]; + res[13] = a[b[13]]; + res[14] = a[b[14]]; + res[15] = a[b[15]]; + + return res; +} + +// NOTE: not tested +inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) { + uint8x16_t res; + + res[ 0] = a[b[ 0]]; + res[ 1] = a[b[ 1]]; + res[ 2] = a[b[ 2]]; + res[ 3] = a[b[ 3]]; + res[ 4] = a[b[ 4]]; + res[ 5] = a[b[ 5]]; + res[ 6] = a[b[ 6]]; + res[ 7] = a[b[ 7]]; + res[ 8] = a[b[ 8]]; + res[ 9] = a[b[ 9]]; + res[10] = a[b[10]]; + res[11] = a[b[11]]; + res[12] = a[b[12]]; + res[13] = a[b[13]]; + res[14] = a[b[14]]; + res[15] = a[b[15]]; + + return res; +} + +#else + +#define ggml_int16x8x2_t int16x8x2_t +#define ggml_uint8x16x2_t uint8x16x2_t +#define ggml_uint8x16x4_t uint8x16x4_t +#define ggml_int8x16x2_t int8x16x2_t +#define ggml_int8x16x4_t int8x16x4_t + +#define ggml_vld1q_s16_x2 vld1q_s16_x2 +#define ggml_vld1q_u8_x2 vld1q_u8_x2 +#define ggml_vld1q_u8_x4 vld1q_u8_x4 +#define ggml_vld1q_s8_x2 vld1q_s8_x2 +#define ggml_vld1q_s8_x4 vld1q_s8_x4 +#define ggml_vqtbl1q_s8 vqtbl1q_s8 +#define ggml_vqtbl1q_u8 vqtbl1q_u8 + +#endif // !defined(__aarch64__) + +#if !defined(__ARM_FEATURE_DOTPROD) + +inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) { + const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b)); + const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); + + return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))); +} + +#else + +#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c) + +#endif // !defined(__ARM_FEATURE_DOTPROD) + +#endif // defined(__ARM_NEON) + +#ifdef __wasm_simd128__ +#include +#else +#ifdef __POWER9_VECTOR__ +#include +#undef bool +#define bool _Bool +#else +#if defined(_MSC_VER) || defined(__MINGW32__) +#include +#else +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__) +#if !defined(__riscv) +#include +#endif +#endif +#endif +#endif +#endif + +#ifdef __riscv_v_intrinsic +#include +#endif + +#if defined(__loongarch64) +#if defined(__loongarch_asx) +#include +#endif +#if defined(__loongarch_sx) +#include +#endif +#endif + +#if defined(__loongarch_asx) + +typedef union { + int32_t i; + float f; +} ft_union; + +/* float type data load instructions */ +static __m128 __lsx_vreplfr2vr_s(float val) { + ft_union fi_tmpval = {.f = val}; + return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i); +} + +static __m256 __lasx_xvreplfr2vr_s(float val) { + ft_union fi_tmpval = {.f = val}; + return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i); +} +#endif + +// TODO: move to ggml-threading +void ggml_barrier(struct ggml_threadpool * tp); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c new file mode 100644 index 000000000..88303ff0e --- /dev/null +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.c @@ -0,0 +1,10920 @@ +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" + +#include "ggml-quants.h" +#include "ggml-cpu-quants.h" +#include "ggml-impl.h" +#include "ggml-cpu-impl.h" +#include "ggml-cpu.h" + +#include +#include +#include +#include +#include // for qsort +#include // for GGML_ASSERT + +#define GROUP_MAX_EPS 1e-15f +#define GROUP_MAX_EPS_IQ3_XXS 1e-8f +#define GROUP_MAX_EPS_IQ2_S 1e-8f +#define GROUP_MAX_EPS_IQ1_M 1e-7f +#define GROUP_MAX_EPS_IQ1_S 1e-12f + +#if defined(_MSC_VER) +// disable "possible loss of data" to avoid warnings for hundreds of casts +// we should just be careful :) +#pragma warning(disable: 4244 4267) +#endif + +#define UNUSED GGML_UNUSED + +// some compilers don't provide _mm256_set_m128i, e.g. gcc 7 +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) +// multiply int8_t, add results pairwise twice +static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { + // Get absolute values of x vectors + const __m128i ax = _mm_sign_epi8(x, x); + // Sign the values of the y vectors + const __m128i sy = _mm_sign_epi8(y, x); + // Perform multiplication and create 16-bit values + const __m128i dot = _mm_maddubs_epi16(ax, sy); + const __m128i ones = _mm_set1_epi16(1); + return _mm_madd_epi16(ones, dot); +} + +#if __AVX__ || __AVX2__ || __AVX512F__ +// horizontally add 8 floats +static inline float hsum_float_8(const __m256 x) { + __m128 res = _mm256_extractf128_ps(x, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(x)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} + +// horizontally add 8 int32_t +static inline int hsum_i32_8(const __m256i a) { + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +// horizontally add 4 int32_t +static inline int hsum_i32_4(const __m128i a) { + const __m128i hi64 = _mm_unpackhi_epi64(a, a); + const __m128i sum64 = _mm_add_epi32(hi64, a); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +#if defined(__AVX2__) || defined(__AVX512F__) +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m256i shuf_mask = _mm256_set_epi64x( + 0x0303030303030303, 0x0202020202020202, + 0x0101010101010101, 0x0000000000000000); + __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask); + const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytes = _mm256_or_si256(bytes, bit_mask); + return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1)); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) +{ + const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); + const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); + const __m256i lowMask = _mm256_set1_epi8( 0xF ); + return _mm256_and_si256(lowMask, bytes); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m256i x) { + const __m256i ones = _mm256_set1_epi16(1); + const __m256i summed_pairs = _mm256_madd_epi16(ones, x); + return _mm256_cvtepi32_ps(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); + return _mm256_cvtepi32_ps(summed_pairs); +#elif defined(__AVXVNNI__) + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy); + return _mm256_cvtepi32_ps(summed_pairs); +#else + // Perform multiplication and create 16-bit values + const __m256i dot = _mm256_maddubs_epi16(ax, sy); + return sum_i16_pairs_float(dot); +#endif +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { +#if __AVXVNNIINT8__ + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y); + return _mm256_cvtepi32_ps(summed_pairs); +#else + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8(x, x); + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8(y, x); + return mul_sum_us8_pairs_float(ax, sy); +#endif +} + +static inline __m128i packNibbles( __m256i bytes ) +{ + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh +#if __AVX512F__ + const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000 + bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh + return _mm256_cvtepi16_epi8(bytes); // abcd_efgh +#else + const __m256i lowByte = _mm256_set1_epi16( 0xFF ); + __m256i high = _mm256_andnot_si256( lowByte, bytes ); + __m256i low = _mm256_and_si256( lowByte, bytes ); + high = _mm256_srli_epi16( high, 4 ); + bytes = _mm256_or_si256( low, high ); + + // Compress uint16_t lanes into bytes + __m128i r0 = _mm256_castsi256_si128( bytes ); + __m128i r1 = _mm256_extracti128_si256( bytes, 1 ); + return _mm_packus_epi16( r0, r1 ); +#endif +} +#elif defined(__AVX__) +static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) +{ + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh + const __m128i lowByte = _mm_set1_epi16( 0xFF ); + __m128i high = _mm_andnot_si128( lowByte, bytes1 ); + __m128i low = _mm_and_si128( lowByte, bytes1 ); + high = _mm_srli_epi16( high, 4 ); + bytes1 = _mm_or_si128( low, high ); + high = _mm_andnot_si128( lowByte, bytes2 ); + low = _mm_and_si128( lowByte, bytes2 ); + high = _mm_srli_epi16( high, 4 ); + bytes2 = _mm_or_si128( low, high ); + + return _mm_packus_epi16( bytes1, bytes2); +} + +static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) { + const __m128i ax = _mm_sign_epi8(x, x); + const __m128i sy = _mm_sign_epi8(y, x); + return _mm_maddubs_epi16(ax, sy); +} + +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); + const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202); + __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl); + __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh); + const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytesl = _mm_or_si128(bytesl, bit_mask); + bytesh = _mm_or_si128(bytesh, bit_mask); + bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1)); + bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1)); + return MM256_SET_M128I(bytesh, bytesl); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) +{ + // Load 16 bytes from memory + __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi); + __m128i tmph = _mm_srli_epi16(tmpl, 4); + const __m128i lowMask = _mm_set1_epi8(0xF); + tmpl = _mm_and_si128(lowMask, tmpl); + tmph = _mm_and_si128(lowMask, tmph); + return MM256_SET_M128I(tmph, tmpl); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) { + const __m128i ones = _mm_set1_epi16(1); + const __m128i summed_pairsl = _mm_madd_epi16(ones, xl); + const __m128i summed_pairsh = _mm_madd_epi16(ones, xh); + const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl); + return _mm256_cvtepi32_ps(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { + const __m128i axl = _mm256_castsi256_si128(ax); + const __m128i axh = _mm256_extractf128_si256(ax, 1); + const __m128i syl = _mm256_castsi256_si128(sy); + const __m128i syh = _mm256_extractf128_si256(sy, 1); + // Perform multiplication and create 16-bit values + const __m128i dotl = _mm_maddubs_epi16(axl, syl); + const __m128i doth = _mm_maddubs_epi16(axh, syh); + return sum_i16_pairs_float(doth, dotl); +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { + const __m128i xl = _mm256_castsi256_si128(x); + const __m128i xh = _mm256_extractf128_si256(x, 1); + const __m128i yl = _mm256_castsi256_si128(y); + const __m128i yh = _mm256_extractf128_si256(y, 1); + // Get absolute values of x vectors + const __m128i axl = _mm_sign_epi8(xl, xl); + const __m128i axh = _mm_sign_epi8(xh, xh); + // Sign the values of the y vectors + const __m128i syl = _mm_sign_epi8(yl, xl); + const __m128i syh = _mm_sign_epi8(yh, xh); + // Perform multiplication and create 16-bit values + const __m128i dotl = _mm_maddubs_epi16(axl, syl); + const __m128i doth = _mm_maddubs_epi16(axh, syh); + return sum_i16_pairs_float(doth, dotl); +} + +// larger version of mul_sum_i8_pairs_float where x and y are each represented by four 128-bit vectors +static inline __m256 mul_sum_i8_quad_float(const __m128i x_1_0, const __m128i x_1_1, const __m128i x_2_0, const __m128i x_2_1, + const __m128i y_1_0, const __m128i y_1_1, const __m128i y_2_0, const __m128i y_2_1) { + const __m128i mone = _mm_set1_epi16(1); + + const __m128i p16_1_0 = mul_add_epi8_sse(x_1_0, y_1_0); + const __m128i p16_1_1 = mul_add_epi8_sse(x_1_1, y_1_1); + const __m128i p16_2_0 = mul_add_epi8_sse(x_2_0, y_2_0); + const __m128i p16_2_1 = mul_add_epi8_sse(x_2_1, y_2_1); + const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone); + const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone); + const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone); + const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone); + const __m128i p_1 = _mm_add_epi32(p_1_0, p_1_1); + const __m128i p_2 = _mm_add_epi32(p_2_0, p_2_1); + return _mm256_cvtepi32_ps(MM256_SET_M128I(p_2, p_1)); +} + +// quad fp16 delta calculation +static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const float x1, const float y1) { + // GGML_FP16_TO_FP32 is faster than Intel F16C + return _mm256_set_m128(_mm_set1_ps(GGML_FP16_TO_FP32(x1) * GGML_FP16_TO_FP32(y1)), + _mm_set1_ps(GGML_FP16_TO_FP32(x0) * GGML_FP16_TO_FP32(y0))); +} +#endif +#elif defined(__SSSE3__) +// horizontally add 4x4 floats +static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { + __m128 res_0 =_mm_hadd_ps(a, b); + __m128 res_1 =_mm_hadd_ps(c, d); + __m128 res =_mm_hadd_ps(res_0, res_1); + res =_mm_hadd_ps(res, res); + res =_mm_hadd_ps(res, res); + + return _mm_cvtss_f32(res); +} +#endif // __AVX__ || __AVX2__ || __AVX512F__ +#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) + +#if defined(__ARM_NEON) || defined(__wasm_simd128__) || defined(__POWER9_VECTOR__) +#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s +#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) +#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) +#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) +#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) +#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) +#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s) +#define B8(c,s ) B7(c,s, c), B7(c,s, s) + +// precomputed tables for expanding 8bits to 8 bytes: +static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4 +static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 +#endif + +#if defined(__loongarch_asx) + +#ifdef __clang__ +#define VREGS_PREFIX "$vr" +#define XREGS_PREFIX "$xr" +#else // GCC +#define VREGS_PREFIX "$f" +#define XREGS_PREFIX "$f" +#endif +#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31" +// Convert __m128i to __m256i +static inline __m256i ____m256i(__m128i in) { + __m256i out = __lasx_xvldi(0); + __asm__ volatile ( + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " XREGS_PREFIX"\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[in], " VREGS_PREFIX "\\j \n\t" + " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + : [out] "+f" (out) : [in] "f" (in) + ); + return out; +} +// Convert two __m128i to __m256i +static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) { + __m256i out; + __asm__ volatile ( + ".irp i," __ALL_REGS "\n\t" + " .ifc %[hi], " VREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[lo], " VREGS_PREFIX "\\j \n\t" + " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + ".ifnc %[out], %[hi] \n\t" + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " XREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[hi], " VREGS_PREFIX "\\j \n\t" + " xvori.b $xr\\i, $xr\\j, 0 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + ".endif \n\t" + : [out] "=f" (out), [hi] "+f" (inhi) + : [lo] "f" (inlo) + ); + return out; +} +// Convert __m256i low part to __m128i +static inline __m128i lasx_extracti128_lo(__m256i in) { + __m128i out; + __asm__ volatile ( + ".ifnc %[out], %[in] \n\t" + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " VREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[in], " XREGS_PREFIX "\\j \n\t" + " vori.b $vr\\i, $vr\\j, 0 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + ".endif \n\t" + : [out] "=f" (out) : [in] "f" (in) + ); + return out; +} +// Convert __m256i high part to __m128i +static inline __m128i lasx_extracti128_hi(__m256i in) { + __m128i out; + __asm__ volatile ( + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " VREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[in], " XREGS_PREFIX "\\j \n\t" + " xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + : [out] "=f" (out) : [in] "f" (in) + ); + return out; +} + +static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) { + v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7}; + return (__m256i)__ret; +} + +static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) { + v4i32 __ret = {d, c, b, a}; + return (__m128i)__ret; +} + +static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) { + v4i64 __ret = {d, c, b, a}; + return (__m256i)__ret; +} + +static __m256i lasx_insertf128( __m128i x, __m128i y) { + return lasx_set_q(x, y); +} + +static __m128i lsx_shuffle_b(__m128i a, __m128i b) { + __m128i mask_f, zero, tmp0, tmp2, mask; + int f = 0x8f; + mask_f = __lsx_vreplgr2vr_b(f); + zero = __lsx_vldi(0); + tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits + tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive + mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask + tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones + return __lsx_vshuf_b(a, zero, tmp2); +} + +static __m256i lasx_shuffle_b(__m256i a, __m256i b) { + __m256i mask_f, zero, tmp0, tmp2, mask; + int f = 0x8f; + mask_f = __lasx_xvreplgr2vr_b(f); + zero = __lasx_xvldi(0); + tmp0 = __lasx_xvand_v(b, mask_f); // get mask with low 4 bit and sign bits + tmp0 = __lasx_xvori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive + mask = __lasx_xvsle_b(zero, tmp0); // if mask >= 0, set mask + tmp2 = __lasx_xvand_v(tmp0, mask); // maskout the in2 < ones + return __lasx_xvshuf_b(a, zero, tmp2); +} + +static __m256i lasx_extu8_16(__m128i a) { + __m128i zero = __lsx_vldi(0); + __m128i vlo = __lsx_vilvl_b(zero, a); + __m128i vhi = __lsx_vilvh_b(zero, a); + return lasx_set_q(vhi, vlo); +} + +static __m256i lasx_ext8_16(__m128i a) { + __m128i sign = __lsx_vslti_b(a, 0); + __m128i vlo = __lsx_vilvl_b(sign, a); + __m128i vhi = __lsx_vilvh_b(sign, a); + return lasx_set_q(vhi, vlo); +} + +static __m256i lasx_ext16_32(__m128i a) { + __m256i tmp1; + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7); + return tmp1; +} + +static __m128i lasx_extracti128( __m256i a, int pos) { + __m128i ret; + if( pos == 0) + { + ret = lasx_extracti128_lo(a); + } else { + ret = lasx_extracti128_hi(a); + } + return ret; +} + +static __m128 lasx_extractf128( __m256 a, int pos) { + __m128 ret; + if( pos == 0) + { + ret = (__m128)lasx_extracti128_lo((__m256i)a); + } else { + ret = (__m128)lasx_extracti128_hi((__m256i)a); + } + return ret; +} + +static __m128i lsx_hadd_h(__m128i a, __m128i b) { + __m128i tmp1 = __lsx_vpickev_h(b, a); + __m128i tmp2 = __lsx_vpickod_h(b, a); + return __lsx_vadd_h(tmp1, tmp2); +} + +static __m128i lsx_hadd_w(__m128i a, __m128i b) { + __m128i tmp1 = __lsx_vpickev_w(b, a); + __m128i tmp2 = __lsx_vpickod_w(b, a); + return __lsx_vadd_w(tmp1, tmp2); +} + +static __m128 lsx_hadd_s(__m128 a, __m128 b) { + __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a); + __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a); + + return __lsx_vfadd_s(tmp1, tmp2); +} + +static __m256i lasx_maddubs_h(__m256i a, __m256i b) { + __m256i tmp1, tmp2; + tmp1 = __lasx_xvmulwev_h_b(a, b); + tmp2 = __lasx_xvmulwod_h_b(a, b); + return __lasx_xvsadd_h(tmp1, tmp2); +} + +static __m256i lasx_madd_h(__m256i a, __m256i b) { + __m256i tmp1, tmp2; + tmp1 = __lasx_xvmulwev_w_h(a, b); + tmp2 = __lasx_xvmulwod_w_h(a, b); + return __lasx_xvadd_w(tmp1, tmp2); +} + +static __m256i lasx_packs_w(__m256i a, __m256i b) { + __m256i tmp, tmp1; + tmp = __lasx_xvsat_w(a, 15); + tmp1 = __lasx_xvsat_w(b, 15); + return __lasx_xvpickev_h(tmp1, tmp); +} + +static __m256i lasx_packs_h(__m256i a, __m256i b) { + __m256i tmp, tmp1; + tmp = __lasx_xvsat_h(a, 7); + tmp1 = __lasx_xvsat_h(b, 7); + return __lasx_xvpickev_b(tmp1, tmp); +} + +static __m128i lsx_packs_w(__m128i a, __m128i b) { + __m128i tmp, tmp1; + tmp = __lsx_vsat_w(a, 15); + tmp1 = __lsx_vsat_w(b, 15); + return __lsx_vpickev_h(tmp1, tmp); +} + +static __m128i lsx_packs_h(__m128i a, __m128i b) { + __m128i tmp, tmp1; + tmp = __lsx_vsat_h(a, 7); + tmp1 = __lsx_vsat_h(b, 7); + return __lsx_vpickev_b(tmp1, tmp); +} + +static __m128i lsx_packus_h(__m128i a, __m128i b) { + __m128i tmp, tmp1; + tmp = __lsx_vsat_hu(a, 7); + tmp1 = __lsx_vsat_hu(b, 7); + return __lsx_vpickev_b(tmp1, tmp); +} + + +static __m128i lsx_maddubs_h(__m128i a, __m128i b) { + __m128i tmp1, tmp2; + tmp1 = __lsx_vmulwev_h_b(a, b); + tmp2 = __lsx_vmulwod_h_b(a, b); + return __lsx_vsadd_h(tmp1, tmp2); +} + +static __m128i lsx_madd_h(__m128i a, __m128i b) { + __m128i tmp1, tmp2; + tmp1 = __lsx_vmulwev_w_h(a, b); + tmp2 = __lsx_vmulwod_w_h(a, b); + return __lsx_vadd_w(tmp1, tmp2); +} + +// multiply int8_t, add results pairwise twice +static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { + // Get absolute values of x vectors + const __m128i ax = __lsx_vsigncov_b(x, x); + // Sign the values of the y vectors + const __m128i sy = __lsx_vsigncov_b(x, y); + // Perform multiplication and create 16-bit values + const __m128i dot = lsx_maddubs_h(ax, sy); + const __m128i ones = __lsx_vreplgr2vr_h(1); + return lsx_madd_h(ones, dot); +} + +// horizontally add 8 floats +static inline float hsum_float_8(const __m256 x) { + __m128 res = lasx_extractf128(x, 1); + ft_union tmp; + res = __lsx_vfadd_s(res, lasx_extractf128(x, 0)); + res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res)); + res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0)); + tmp.i = __lsx_vpickve2gr_w(res, 0); + return tmp.f; +} + +// horizontally add 8 int32_t +static inline int hsum_i32_8(const __m256i a) { + + __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11); + __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00); + + __m128i tmp1_128 = lasx_extracti128_lo(tmp1); + __m128i tmp2_128 = lasx_extracti128_lo(tmp2); + + __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128); + + __m128i ev = __lsx_vpickev_w(sum128, sum128); + __m128i od = __lsx_vpickod_w(sum128, sum128); + __m128i sum64 = __lsx_vadd_w(ev, od); + + int sum64_1, sum64_2; + sum64_1 = __lsx_vpickve2gr_w(sum64, 0); + sum64_2 = __lsx_vpickve2gr_w(sum64, 1); + + return sum64_1 + sum64_2; +} + +// horizontally add 4 int32_t +static inline int hsum_i32_4(const __m128i a) { + __m128i ev = __lsx_vpickev_w(a, a); + __m128i od = __lsx_vpickod_w(a, a); + __m128i sum64 = __lsx_vadd_w(ev, od); + + int sum64_1, sum64_2; + sum64_1 = __lsx_vpickve2gr_w(sum64, 0); + sum64_2 = __lsx_vpickve2gr_w(sum64, 1); + + return sum64_1 + sum64_2; +} + +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m256i shuf_mask = lasx_set_d( + 0x0303030303030303, 0x0202020202020202, + 0x0101010101010101, 0x0000000000000000); + + __m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask); + const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe); + bytes = __lasx_xvor_v(bytes, bit_mask); + return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1)); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) { + const __m128i lo = __lsx_vld((const __m128i *)rsi, 0); + __m128i hi = __lsx_vsrli_h(lo, 4); + return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m256i x) { + __m256i v = __lasx_xvpackod_h(x, x); + __m256i summed_pairs = __lasx_xvaddwev_w_h(x, v); + return __lasx_xvffint_s_w(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { + // Perform multiplication and create 16-bit values + const __m256i dot = lasx_maddubs_h(ax, sy); + return sum_i16_pairs_float(dot); +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { + + // Get absolute values of x vectors + const __m256i ax = __lasx_xvsigncov_b(x, x); + // Sign the values of the y vectors + const __m256i sy = __lasx_xvsigncov_b(x, y); + + return mul_sum_us8_pairs_float(ax, sy); +} + +static inline __m128i packNibbles( __m256i bytes ) { + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh + const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF); + __m256i high = __lasx_xvandn_v(lowByte, bytes); + __m256i low = __lasx_xvand_v(lowByte, bytes); + high = __lasx_xvsrli_h(high, 4); + bytes = __lasx_xvor_v(low, high); + // Compress uint16_t lanes into bytes + __m128i *r0 = (__m128i *)&bytes; + __m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11); + __m128i *r1 = (__m128i *)&tmp_h128; + + __m128i zero = __lsx_vldi(0); + __m128i tmp, tmp2, tmp3; + + tmp = __lsx_vmax_h(zero, *r0); + tmp2 = __lsx_vsat_hu(tmp, 7); + + tmp = __lsx_vmax_h(zero, *r1); + tmp3 = __lsx_vsat_hu(tmp, 7); + return __lsx_vpickev_b(tmp3, tmp2); +} +#endif //__loongarch_asx + +void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) { + quantize_row_q4_0_ref(x, y, k); +} + +void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) { + quantize_row_q4_1_ref(x, y, k); +} + +void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) { + quantize_row_q5_0_ref(x, y, k); +} + +void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) { + quantize_row_q5_1_ref(x, y, k); +} + +void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + } + } +#elif defined(__wasm_simd128__) + for (int i = 0; i < nb; i++) { + v128_t srcv [8]; + v128_t asrcv[8]; + v128_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), + wasm_f32x4_extract_lane(amaxv[0], 1)), + MAX(wasm_f32x4_extract_lane(amaxv[0], 2), + wasm_f32x4_extract_lane(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); + const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); + + y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); + y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); + y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); + y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); + } + } +#elif defined(__AVX2__) || defined(__AVX__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = maxScalar / 127.f; + y[i].d = GGML_FP32_TO_FP16(d); + const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)y[i].qs, i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + + _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); +#endif + } +#elif defined(__riscv_v_intrinsic) + + size_t vl = __riscv_vsetvl_e32m4(QK8_0); + + for (int i = 0; i < nb; i++) { + // load elements + vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl); + + vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); + float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); + + // convert to integer + vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); + vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); + + // store result + __riscv_vse8_v_i8m1(y[i].qs , vs, vl); + } + +#elif defined(__POWER9_VECTOR__) + for (int i = 0; i < nb; i++) { + vector float srcv [8]; + vector float asrcv[8]; + vector float amaxv[8]; + vector signed int vi[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(vec_extract(amaxv[0], 0), + vec_extract(amaxv[0], 1)), + MAX(vec_extract(amaxv[0], 2), + vec_extract(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + const vector float vid = vec_splats(id); + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const vector float v = vec_round(vec_mul(srcv[j], vid)); + vi[j] = vec_cts(v, 0); + } + vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]); + vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]); + } + +#elif defined(__loongarch_asx) + for (int i = 0; i < nb; i++) { + ft_union fi; + __m256 v0 = (__m256)__lasx_xvld( x , 0); + __m256 v1 = (__m256)__lasx_xvld( x , 32); + __m256 v2 = (__m256)__lasx_xvld( x , 64); + __m256 v3 = (__m256)__lasx_xvld( x , 96); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f ); + __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) ); + + __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs , 0) ); + max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) ); + __m128 tmp = max4; + max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 )); + fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 ); + const float max_scalar = fi.f; + + // Quantize these floats + const float d = max_scalar / 127.f; + y[i].d = GGML_FP32_TO_FP16(d); + const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; + const __m256 mul = (__m256)__lasx_xvreplfr2vr_s( id ); + + // Apply the multiplier + v0 = __lasx_xvfmul_s( v0, mul ); + v1 = __lasx_xvfmul_s( v1, mul ); + v2 = __lasx_xvfmul_s( v2, mul ); + v3 = __lasx_xvfmul_s( v3, mul ); + + // Round to nearest integer + __m256i i0 = __lasx_xvftintrne_w_s( v0 ); + __m256i i1 = __lasx_xvftintrne_w_s( v1 ); + __m256i i2 = __lasx_xvftintrne_w_s( v2 ); + __m256i i3 = __lasx_xvftintrne_w_s( v3 ); + + __m128i ni0 = lasx_extracti128( i0, 0 ); + __m128i ni1 = lasx_extracti128( i0, 1); + __m128i ni2 = lasx_extracti128( i1, 0); + __m128i ni3 = lasx_extracti128( i1, 1); + __m128i ni4 = lasx_extracti128( i2, 0); + __m128i ni5 = lasx_extracti128( i2, 1); + __m128i ni6 = lasx_extracti128( i3, 0); + __m128i ni7 = lasx_extracti128( i3, 1); + + // Convert int32 to int16 + ni0 = lsx_packs_w( ni0, ni1 ); + ni2 = lsx_packs_w( ni2, ni3 ); + ni4 = lsx_packs_w( ni4, ni5 ); + ni6 = lsx_packs_w( ni6, ni7 ); + // Convert int16 to int8 + ni0 = lsx_packs_h( ni0, ni2 ); + ni4 = lsx_packs_h( ni4, ni6 ); + + __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0); + __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0); + + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_0_ref(x, y, k); +#endif +} + +void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + block_q8_1 * restrict y = vy; + +#if defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + int32x4_t accv = vdupq_n_s32(0); + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + + accv = vaddq_s32(accv, vi); + } + + y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); + } +#elif defined(__wasm_simd128__) + for (int i = 0; i < nb; i++) { + v128_t srcv [8]; + v128_t asrcv[8]; + v128_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), + wasm_f32x4_extract_lane(amaxv[0], 1)), + MAX(wasm_f32x4_extract_lane(amaxv[0], 2), + wasm_f32x4_extract_lane(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + v128_t accv = wasm_i32x4_splat(0); + + for (int j = 0; j < 8; j++) { + const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); + const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); + + y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); + y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); + y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); + y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); + + accv = wasm_i32x4_add(accv, vi); + } + + y[i].s = GGML_FP32_TO_FP16( + d * (wasm_i32x4_extract_lane(accv, 0) + + wasm_i32x4_extract_lane(accv, 1) + + wasm_i32x4_extract_lane(accv, 2) + + wasm_i32x4_extract_lane(accv, 3))); + } +#elif defined(__AVX2__) || defined(__AVX__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float max_scalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = max_scalar / 127.f; + y[i].d = GGML_FP32_TO_FP16(d); + const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Compute the sum of the quants and set y[i].s + y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); + + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)y[i].qs, i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Compute the sum of the quants and set y[i].s + const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3)); + const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7)); + y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_4(_mm_add_epi32(s0, s1))); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + + _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); +#endif + } +#elif defined(__riscv_v_intrinsic) + + size_t vl = __riscv_vsetvl_e32m4(QK8_1); + + for (int i = 0; i < nb; i++) { + // load elements + vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl); + + vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl); + vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); + float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); + + // convert to integer + vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); + vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); + + // store result + __riscv_vse8_v_i8m1(y[i].qs , vs, vl); + + // compute sum for y[i].s + vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl); + + // set y[i].s + int sum = __riscv_vmv_x_s_i16m1_i16(vwrs); + y[i].s = GGML_FP32_TO_FP16(sum*d); + } + +#elif defined(__POWER9_VECTOR__) + for (int i = 0; i < nb; i++) { + vector float srcv [8]; + vector float asrcv[8]; + vector float amaxv[8]; + vector signed int vi[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(vec_extract(amaxv[0], 0), + vec_extract(amaxv[0], 1)), + MAX(vec_extract(amaxv[0], 2), + vec_extract(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + const vector float vid = vec_splats(id); + + y[i].d = GGML_FP32_TO_FP16(d); + + vector int accv = vec_splats(0); + + for (int j = 0; j < 8; j++) { + const vector float v = vec_round(vec_mul(srcv[j], vid)); + vi[j] = vec_cts(v, 0); + + accv = vec_add(accv, vi[j]); + } + vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]); + vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]); + + accv = vec_add(accv, vec_sld(accv, accv, 4)); + accv = vec_add(accv, vec_sld(accv, accv, 8)); + y[i].s = GGML_FP32_TO_FP16(d * vec_extract(accv, 0)); + } + +#elif defined(__loongarch_asx) + for (int i = 0; i < nb; i++) { + ft_union ft; + __m256 v0 = (__m256)__lasx_xvld( x , 0 ); + __m256 v1 = (__m256)__lasx_xvld( x , 32 ); + __m256 v2 = (__m256)__lasx_xvld( x , 64 ); + __m256 v3 = (__m256)__lasx_xvld( x , 96 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f ); + __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) ); + + __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) ); + max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) ); + __m128 tmp = max4; + max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 )); + ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 ); + const float max_scalar = ft.f; + + // Quantize these floats + const float d = max_scalar / 127.f; + y[i].d = GGML_FP32_TO_FP16(d); + const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; + const __m256 mul = __lasx_xvreplfr2vr_s( id ); + + // Apply the multiplier + v0 = __lasx_xvfmul_s( v0, mul ); + v1 = __lasx_xvfmul_s( v1, mul ); + v2 = __lasx_xvfmul_s( v2, mul ); + v3 = __lasx_xvfmul_s( v3, mul ); + + // Round to nearest integer + __m256i i0 = __lasx_xvftintrne_w_s( v0 ); + __m256i i1 = __lasx_xvftintrne_w_s( v1 ); + __m256i i2 = __lasx_xvftintrne_w_s( v2 ); + __m256i i3 = __lasx_xvftintrne_w_s( v3 ); + + __m128i ni0 = lasx_extracti128(i0, 0); + __m128i ni1 = lasx_extracti128( i0, 1); + __m128i ni2 = lasx_extracti128( i1, 0); + __m128i ni3 = lasx_extracti128( i1, 1); + __m128i ni4 = lasx_extracti128( i2, 0 ); + __m128i ni5 = lasx_extracti128( i2, 1); + __m128i ni6 = lasx_extracti128( i3, 0); + __m128i ni7 = lasx_extracti128( i3, 1); + + // Compute the sum of the quants and set y[i].s + const __m128i s0 = __lsx_vadd_w(__lsx_vadd_w(ni0, ni1), __lsx_vadd_w(ni2, ni3)); + const __m128i s1 = __lsx_vadd_w(__lsx_vadd_w(ni4, ni5), __lsx_vadd_w(ni6, ni7)); + y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_4(__lsx_vadd_w(s0, s1))); + + // Convert int32 to int16 + ni0 = lsx_packs_w( ni0, ni1 ); + ni2 = lsx_packs_w( ni2, ni3 ); + ni4 = lsx_packs_w( ni4, ni5 ); + ni6 = lsx_packs_w( ni6, ni7 ); + // Convert int16 to int8 + ni0 = lsx_packs_h( ni0, ni2 ); + ni4 = lsx_packs_h( ni4, ni6 ); + + __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0); + __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0); + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_1_ref(x, y, k); +#endif +} + +// +// 2-6 bit quantization in super-blocks +// + +// +// ===================== Helper functions +// +static inline int nearest_int(float fval) { + assert(fabsf(fval) <= 4194303.f); + float val = fval + 12582912.f; + int i; memcpy(&i, &val, sizeof(int)); + return (i & 0x007fffff) - 0x00400000; +} + +static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type, + const float * restrict qw) { + float max = 0; + float amax = 0; + for (int i = 0; i < n; ++i) { + float ax = fabsf(x[i]); + if (ax > amax) { amax = ax; max = x[i]; } + } + if (amax < GROUP_MAX_EPS) { // all zero + for (int i = 0; i < n; ++i) { + L[i] = 0; + } + return 0.f; + } + float iscale = -nmax / max; + if (rmse_type == 0) { + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); + } + return 1/iscale; + } + bool return_early = false; + if (rmse_type < 0) { + rmse_type = -rmse_type; + return_early = true; + } + float sumlx = 0; + float suml2 = 0; +#ifdef HAVE_BUGGY_APPLE_LINKER + // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 + for (volatile int i = 0; i < n; ++i) { +#else + for (int i = 0; i < n; ++i) { +#endif + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l + nmax; + float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i])); + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + float scale = suml2 ? sumlx/suml2 : 0.0f; + if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale; + float best = scale * sumlx; + for (int is = -9; is <= 9; ++is) { + if (is == 0) { + continue; + } + iscale = -(nmax + 0.1f*is) / max; + sumlx = suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i])); + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + if (suml2 > 0 && sumlx*sumlx > best*suml2) { + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); + } + scale = sumlx/suml2; best = scale*sumlx; + } + } + return scale; +} + +static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) { + float max = 0; + float amax = 0; + for (int i = 0; i < n; ++i) { + float ax = fabsf(x[i]); + if (ax > amax) { amax = ax; max = x[i]; } + } + if (amax < GROUP_MAX_EPS) { // all zero + for (int i = 0; i < n; ++i) { L[i] = 0; } + return 0.f; + } + float iscale = -nmax / max; + if (do_rmse) { + float sumlx = 0; + float suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l; + float w = x[i]*x[i]; + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + for (int itry = 0; itry < 5; ++itry) { + int n_changed = 0; + for (int i = 0; i < n; ++i) { + float w = x[i]*x[i]; + float slx = sumlx - w*x[i]*L[i]; + if (slx > 0) { + float sl2 = suml2 - w*L[i]*L[i]; + int new_l = nearest_int(x[i] * sl2 / slx); + new_l = MAX(-nmax, MIN(nmax-1, new_l)); + if (new_l != L[i]) { + slx += w*x[i]*new_l; + sl2 += w*new_l*new_l; + if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) { + L[i] = new_l; sumlx = slx; suml2 = sl2; + ++n_changed; + } + } + } + } + if (!n_changed) { + break; + } + } + for (int i = 0; i < n; ++i) { + L[i] += nmax; + } + return sumlx / suml2; + } + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l + nmax; + } + return 1/iscale; +} + +static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min, + int ntry, float alpha) { + float min = x[0]; + float max = x[0]; + for (int i = 1; i < n; ++i) { + if (x[i] < min) min = x[i]; + if (x[i] > max) max = x[i]; + } + if (max == min) { + for (int i = 0; i < n; ++i) L[i] = 0; + *the_min = 0; + return 0.f; + } + if (min > 0) min = 0; + float iscale = nmax/(max - min); + float scale = 1/iscale; + for (int itry = 0; itry < ntry; ++itry) { + float sumlx = 0; int suml2 = 0; + bool did_change = false; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + l = MAX(0, MIN(nmax, l)); + if (l != L[i]) { + L[i] = l; + did_change = true; + } + sumlx += (x[i] - min)*l; + suml2 += l*l; + } + scale = sumlx/suml2; + float sum = 0; + for (int i = 0; i < n; ++i) { + sum += x[i] - scale*L[i]; + } + min = alpha*min + (1 - alpha)*sum/n; + if (min > 0) min = 0; + iscale = 1/scale; + if (!did_change) break; + } + *the_min = -min; + return scale; +} + +static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights, + uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux, + float rmin, float rdelta, int nstep, bool use_mad) { + float min = x[0]; + float max = x[0]; + float sum_w = weights[0]; + float sum_x = sum_w * x[0]; +#ifdef HAVE_BUGGY_APPLE_LINKER + // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 + for (volatile int i = 1; i < n; ++i) { +#else + for (int i = 1; i < n; ++i) { +#endif + if (x[i] < min) min = x[i]; + if (x[i] > max) max = x[i]; + float w = weights[i]; + sum_w += w; + sum_x += w * x[i]; + } + if (min > 0) min = 0; + if (max == min) { + for (int i = 0; i < n; ++i) L[i] = 0; + *the_min = -min; + return 0.f; + } + float iscale = nmax/(max - min); + float scale = 1/iscale; + float best_mad = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + L[i] = MAX(0, MIN(nmax, l)); + float diff = scale * L[i] + min - x[i]; + diff = use_mad ? fabsf(diff) : diff * diff; + float w = weights[i]; + best_mad += w * diff; + } + if (nstep < 1) { + *the_min = -min; + return scale; + } + for (int is = 0; is <= nstep; ++is) { + iscale = (rmin + rdelta*is + nmax)/(max - min); + float sum_l = 0, sum_l2 = 0, sum_xl = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + l = MAX(0, MIN(nmax, l)); + Laux[i] = l; + float w = weights[i]; + sum_l += w*l; + sum_l2 += w*l*l; + sum_xl += w*l*x[i]; + } + float D = sum_w * sum_l2 - sum_l * sum_l; + if (D > 0) { + float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D; + float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D; + if (this_min > 0) { + this_min = 0; + this_scale = sum_xl / sum_l2; + } + float mad = 0; + for (int i = 0; i < n; ++i) { + float diff = this_scale * Laux[i] + this_min - x[i]; + diff = use_mad ? fabsf(diff) : diff * diff; + float w = weights[i]; + mad += w * diff; + } + if (mad < best_mad) { + for (int i = 0; i < n; ++i) { + L[i] = Laux[i]; + } + best_mad = mad; + scale = this_scale; + min = this_min; + } + } + } + *the_min = -min; + return scale; +} + +static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) { + if (j < 4) { + *d = q[j] & 63; *m = q[j + 4] & 63; + } else { + *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} + +//========================- 2-bit (de)-quantization + +void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) { + quantize_row_q2_K_ref(x, vy, k); +} + +//========================= 3-bit (de)-quantization + +void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) { + quantize_row_q3_K_ref(x, vy, k); +} + +// ====================== 4-bit (de)-quantization + +void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_q4_K * restrict y = vy; + quantize_row_q4_K_ref(x, y, k); +} + +// ====================== 5-bit (de)-quantization + +void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_q5_K * restrict y = vy; + quantize_row_q5_K_ref(x, y, k); +} + +// ====================== 6-bit (de)-quantization + +void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_q6_K * restrict y = vy; + quantize_row_q6_K_ref(x, y, k); +} + +// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs) + +void quantize_row_tq1_0(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_tq1_0 * restrict y = vy; + quantize_row_tq1_0_ref(x, y, k); +} + +void quantize_row_tq2_0(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_tq2_0 * restrict y = vy; + quantize_row_tq2_0_ref(x, y, k); +} + +static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; + +//===================================== Q8_K ============================================== + +void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { + quantize_row_q8_K_ref(x, y, k); +} + +//===================================== Dot products ================================= + +// +// Helper functions +// +#if __AVX__ || __AVX2__ || __AVX512F__ + +// shuffles to pick the required scales in dot products +static inline __m256i get_scale_shuffle_q3k(int i) { + static const uint8_t k_shuffle[128] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} +static inline __m256i get_scale_shuffle_k4(int i) { + static const uint8_t k_shuffle[256] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, + 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, + 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, + 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} +static inline __m128i get_scale_shuffle(int i) { + static const uint8_t k_shuffle[128] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, + 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, + 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, + 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 + }; + return _mm_loadu_si128((const __m128i*)k_shuffle + i); +} +#elif defined(__loongarch_asx) +// shuffles to pick the required scales in dot products +static inline __m256i get_scale_shuffle_q3k(int i) { + static const uint8_t k_shuffle[128] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, + }; + return __lasx_xvld((const __m256i*)k_shuffle + i, 0); +} +static inline __m256i get_scale_shuffle_k4(int i) { + static const uint8_t k_shuffle[256] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, + 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, + 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, + 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 + }; + return __lasx_xvld((const __m256i*)k_shuffle + i, 0); +} +static inline __m128i get_scale_shuffle(int i) { + static const uint8_t k_shuffle[128] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, + 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, + 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, + 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 + }; + return __lsx_vld((const __m128i*)k_shuffle + i, 0); +} +#endif + +void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); +#if defined(__ARM_FEATURE_MATMUL_INT8) + assert((nrc == 2) || (nrc == 1)); +#else + assert(nrc == 1); +#endif + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q4_0 * restrict vx0 = vx; + const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx); + const block_q8_0 * restrict vy0 = vy; + const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by); + + float32x4_t sumv0 = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i++) { + const block_q4_0 * restrict b_x0 = &vx0[i]; + const block_q4_0 * restrict b_x1 = &vx1[i]; + const block_q8_0 * restrict b_y0 = &vy0[i]; + const block_q8_0 * restrict b_y1 = &vy1[i]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const uint8x16_t v0_0 = vld1q_u8(b_x0->qs); + const uint8x16_t v0_1 = vld1q_u8(b_x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // sub 8 + const int8x16_t x0_l = vsubq_s8(v0_0l, s8b); + const int8x16_t x0_h = vsubq_s8(v0_0h, s8b); + const int8x16_t x1_l = vsubq_s8(v0_1l, s8b); + const int8x16_t x1_h = vsubq_s8(v0_1h, s8b); + + // load y + const int8x16_t y0_l = vld1q_s8(b_y0->qs); + const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); + const int8x16_t y1_l = vld1q_s8(b_y1->qs); + const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); + + float32_t _scale[4] = { + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d) + }; + float32x4_t scale = vld1q_f32(_scale); + + int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + + int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + + int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + + int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + + sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), + l1, r1)), l2, r2)), l3, r3))), scale); + } + + float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2); + float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); + + vst1_f32(s, vget_low_f32 (sumv2)); + vst1_f32(s + bs, vget_high_f32(sumv2)); + + return; + } +#endif + + int ib = 0; + float sumf = 0; + +#if defined(__ARM_FEATURE_SVE) + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); + + const int vector_length = ggml_cpu_get_sve_cnt()*8; + + // VLA Implementation using switch case + switch (vector_length) { + case 128: + { + // predicate for activating higher lanes for 4 float32 elements + const svbool_t ph4 = svptrue_pat_b32(SV_VL4); + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); + const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); + + // 4-bit -> 8-bit + const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F)); + const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04)); + const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F)); + const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04)); + + // sub 8 + const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8); + const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8); + const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8); + const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8); + + // load y + const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16); + const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs); + const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16); + + // dot product + sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4, + svdot_s32(svdup_n_s32(0), qx0ls, qy0l), + svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4, + svdot_s32(svdup_n_s32(0), qx1ls, qy1l), + svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + } break; + case 256: + { + // predicate for activating higher lanes for 16 int8 elements + const svbool_t ph16 = svptrue_pat_b8(SV_VL16); + // predicate for activating lower lanes for 16 int8 elements + const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16); + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); + const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); + + // 4-bit -> 8-bit + const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04)); + const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04)); + + // sub 8 + const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); + const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); + + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + + // dot product + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), + svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), + svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + } break; + case 512: + { + // predicate for activating higher lanes for 32 int8 elements + const svbool_t ph32 = svptrue_pat_b8(SV_VL32); + + // predicate for activating higher lanes for 16 int8 elements + const svbool_t ph16 = svptrue_pat_b8(SV_VL16); + // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes + const svbool_t pl16 = svnot_b_z(ph32, ph16); + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs); + const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs); + + // 4-bit -> 8-bit + const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04)); + const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04)); + + // sub 8 + const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8); + const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8); + + // load y + const svint8_t qy0 = svld1_s8(ph32, y0->qs); + const svint8_t qy1 = svld1_s8(ph32, y1->qs); + + // dot product + sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32, + svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32, + svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1)); + } break; + default: + assert(false && "Unsupported vector length"); + break; + } + +#elif defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // sub 8 + const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); + const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); + const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); + const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + + // dot product into int32x4_t + const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); + const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (; ib < nb; ++ib) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); + + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m256i off = _mm256_set1_epi8( 8 ); + qx = _mm256_sub_epi8( qx, off ); + + __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps( d, q, acc ); + } + + sumf = hsum_float_8(acc); +#elif defined(__AVX__) + __m256 accum = _mm256_setzero_ps(); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs); + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs); + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1); + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1); + + const __m128i q4b_1_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_1), _mm_set1_epi8(8)); + const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8)); + const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8)); + const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8)); + + const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0); + const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1); + const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0); + const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1); + const __m128i p_1 = _mm_add_epi16(p16_1_0, p16_1_1); + const __m128i p_2 = _mm_add_epi16(p16_2_0, p16_2_1); + const __m256 p = sum_i16_pairs_float(p_2, p_1); + + const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d); + accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum); + } + + sumf = hsum_float_8(accum); +#elif defined(__SSSE3__) + // set constants + const __m128i lowMask = _mm_set1_epi8(0xF); + const __m128i off = _mm_set1_epi8(8); + + // Initialize accumulator with zeros + __m128 acc_0 = _mm_setzero_ps(); + __m128 acc_1 = _mm_setzero_ps(); + __m128 acc_2 = _mm_setzero_ps(); + __m128 acc_3 = _mm_setzero_ps(); + + for (; ib + 1 < nb; ib += 2) { + _mm_prefetch(&x[ib] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[ib] + sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); + + const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[ib].qs); + + __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + + __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); + __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16)); + bx_1 = _mm_sub_epi8(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + + _mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) ); + + const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); + + __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); + __m128i by_2 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); + bx_2 = _mm_sub_epi8(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); + + __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); + __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[ib + 1].qs + 16)); + bx_3 = _mm_sub_epi8(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); + + // Convert int32_t to float + __m128 p0 = _mm_cvtepi32_ps(i32_0); + __m128 p1 = _mm_cvtepi32_ps(i32_1); + __m128 p2 = _mm_cvtepi32_ps(i32_2); + __m128 p3 = _mm_cvtepi32_ps(i32_3); + + // Apply the scale + __m128 p0_d = _mm_mul_ps( d_0_1, p0 ); + __m128 p1_d = _mm_mul_ps( d_0_1, p1 ); + __m128 p2_d = _mm_mul_ps( d_2_3, p2 ); + __m128 p3_d = _mm_mul_ps( d_2_3, p3 ); + + // Acummulate + acc_0 = _mm_add_ps(p0_d, acc_0); + acc_1 = _mm_add_ps(p1_d, acc_1); + acc_2 = _mm_add_ps(p2_d, acc_2); + acc_3 = _mm_add_ps(p3_d, acc_3); + } + + sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#elif defined(__riscv_v_intrinsic) + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + for (; ib < nb; ++ib) { + // load elements + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); + + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); + + // mask and store lower part of x, and then upper part + vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + + vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + + // subtract offset + vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl); + vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl); + + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d); + } + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector signed char v8 = vec_splats((signed char)0x8); + + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 8 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); + + vector signed char q4x0 = vec_and(qxs, lowMask); + vector signed char q4x1 = vec_sr(qxs, v4); + + q4x0 = vec_sub(q4x0, v8); + q4x1 = vec_sub(q4x1, v8); + + vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); + + vector signed int vsumi0 = v0; + + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi0 = vec_sum4s(qv1, vsumi0); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + // Initialize accumulator with zeros + __m256 acc = (__m256)__lasx_xvldi(0); + + // Main loop + for (; ib < nb; ++ib) { + /* Compute combined scale for the block */ + const __m256 d = __lasx_xvreplfr2vr_s( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); + + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m256i off = __lasx_xvreplgr2vr_b( 8 ); + qx = __lasx_xvsub_b( qx, off ); + + __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + /* Multiply q with scale and accumulate */ + acc = __lasx_xvfmadd_s( d, q, acc ); + } + + sumf = hsum_float_8(acc); +#elif defined(__loongarch_sx) + // set constants + const __m128i low_mask = __lsx_vreplgr2vr_b(0xF); + const __m128i off = __lsx_vreplgr2vr_b(8); + + // Initialize accumulator with zeros + __m128 acc_0 = __lsx_vldi(0); + __m128 acc_1 = __lsx_vldi(0); + __m128 acc_2 = __lsx_vldi(0); + __m128 acc_3 = __lsx_vldi(0); + + for (; ib + 1 < nb; ib += 2) { + + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); + + const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0); + + __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1); + __m128i by_0 = __lsx_vld((const __m128i *)y[ib].qs, 0); + bx_0 = __lsx_vsub_b(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + + __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4)); + __m128i by_1 = __lsx_vld((const __m128i *)(y[ib].qs + 16), 0); + bx_1 = __lsx_vsub_b(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + + //_mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0); + //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) ); + + const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0); + + __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3); + __m128i by_2 = __lsx_vld((const __m128i *)y[ib + 1].qs, 0); + bx_2 = __lsx_vsub_b(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); + + __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4)); + __m128i by_3 = __lsx_vld((const __m128i *)(y[ib + 1].qs + 16), 0); + bx_3 = __lsx_vsub_b(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); + + // Convert int32_t to float + __m128 p0 = __lsx_vffint_s_w(i32_0); + __m128 p1 = __lsx_vffint_s_w(i32_1); + __m128 p2 = __lsx_vffint_s_w(i32_2); + __m128 p3 = __lsx_vffint_s_w(i32_3); + + // Apply the scale + __m128 p0_d = __lsx_vfmul_s( d_0_1, p0 ); + __m128 p1_d = __lsx_vfmul_s( d_0_1, p1 ); + __m128 p2_d = __lsx_vfmul_s( d_2_3, p2 ); + __m128 p3_d = __lsx_vfmul_s( d_2_3, p3 ); + + // Acummulate + acc_0 = __lsx_vfadd_s(p0_d, acc_0); + acc_1 = __lsx_vfadd_s(p1_d, acc_1); + acc_2 = __lsx_vfadd_s(p2_d, acc_2); + acc_3 = __lsx_vfadd_s(p3_d, acc_3); + } + + sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#endif + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[ib].qs[j] & 0x0F) - 8; + const int v1 = (x[ib].qs[j] >> 4) - 8; + + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d); + } + + *s = sumf; +} + +void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); +#if defined(__ARM_FEATURE_MATMUL_INT8) + assert((nrc == 2) || (nrc == 1)); +#else + assert(nrc == 1); +#endif + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_1 * restrict x = vx; + const block_q8_1 * restrict y = vy; + +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q4_1 * restrict vx0 = vx; + const block_q4_1 * restrict vx1 = (const block_q4_1 *) ((const uint8_t*)vx + bx); + const block_q8_1 * restrict vy0 = vy; + const block_q8_1 * restrict vy1 = (const block_q8_1 *) ((const uint8_t*)vy + by); + + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t summs0 = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i++) { + const block_q4_1 * restrict b_x0 = &vx0[i]; + const block_q4_1 * restrict b_x1 = &vx1[i]; + const block_q8_1 * restrict b_y0 = &vy0[i]; + const block_q8_1 * restrict b_y1 = &vy1[i]; + + float32_t summs_t[4] = { + GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s), + GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s), + GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s), + GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s) + }; + summs0 = vaddq_f32(summs0, vld1q_f32(summs_t)); + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + const uint8x16_t v0_0 = vld1q_u8(b_x0->qs); + const uint8x16_t v0_1 = vld1q_u8(b_x1->qs); + + // 4-bit -> 8-bit + const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // load y + const int8x16_t y0_l = vld1q_s8(b_y0->qs); + const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); + const int8x16_t y1_l = vld1q_s8(b_y1->qs); + const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); + + // mmla into int32x4_t + float32_t _scale[4] = { + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d) + }; + float32x4_t scale = vld1q_f32(_scale); + + int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + + int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + + int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + + int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), + l1, r1)), l2, r2)), l3, r3))), scale); + } + + float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2); + float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); + + sumv2 = vaddq_f32(sumv2, summs0); + + vst1_f32(s, vget_low_f32 (sumv2)); + vst1_f32(s + bs, vget_high_f32(sumv2)); + + return; + } +#endif + + int ib = 0; + float sumf = 0; + + // TODO: add WASM SIMD +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + float summs = 0; + + for (; ib + 1 < nb; ib += 2) { + const block_q4_1 * restrict x0 = &x[ib + 0]; + const block_q4_1 * restrict x1 = &x[ib + 1]; + const block_q8_1 * restrict y0 = &y[ib + 0]; + const block_q8_1 * restrict y1 = &y[ib + 1]; + + summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s) + GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s); + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + + // dot product into int32x4_t + const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); + const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; +#elif defined(__AVX2__) || defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + // Main loop + for (; ib < nb; ++ib) { + const float d0 = GGML_FP16_TO_FP32(x[ib].d); + const float d1 = GGML_FP16_TO_FP32(y[ib].d); + + summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); + + const __m256 d0v = _mm256_set1_ps( d0 ); + const __m256 d1v = _mm256_set1_ps( d1 ); + + // Compute combined scales + const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); + + // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes + const __m256i qx = bytes_from_nibbles_32(x[ib].qs); + const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[ib].qs ); + + const __m256 xy = mul_sum_us8_pairs_float(qx, qy); + + // Accumulate d0*d1*x*y +#if defined(__AVX2__) + acc = _mm256_fmadd_ps( d0d1, xy, acc ); +#else + acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc ); +#endif + } + + sumf = hsum_float_8(acc) + summs; +#elif defined(__riscv_v_intrinsic) + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + for (; ib < nb; ++ib) { + // load elements + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); + + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); + + // mask and store lower part of x, and then upper part + vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + + vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); + } + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 4 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[ib].m)); + vector float vys = {GGML_FP16_TO_FP32(y[ib].s), 0.0f, 0.0f, 0.0f}; + vsumf0 = vec_madd(vxmin, vys, vsumf0); + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); + + vector unsigned char q4x0 = (vector unsigned char)vec_and(qxs, lowMask); + vector unsigned char q4x1 = (vector unsigned char)vec_sr(qxs, v4); + + vector signed int vsumi0 = v0; + + vsumi0 = vec_msum(q8y0, q4x0, vsumi0); + vsumi0 = vec_msum(q8y1, q4x1, vsumi0); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + // Initialize accumulator with zeros + __m256 acc = (__m256)__lasx_xvldi(0); + + float summs = 0; + + // Main loop + for (; ib < nb; ++ib) { + const float d0 = GGML_FP16_TO_FP32(x[ib].d); + const float d1 = GGML_FP16_TO_FP32(y[ib].d); + + summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); + + const __m256 d0v = __lasx_xvreplfr2vr_s( d0 ); + const __m256 d1v = __lasx_xvreplfr2vr_s( d1 ); + + // Compute combined scales + const __m256 d0d1 = __lasx_xvfmul_s( d0v, d1v ); + + // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes + const __m256i qx = bytes_from_nibbles_32(x[ib].qs); + const __m256i qy = __lasx_xvld( (const __m256i *)y[ib].qs, 0); + + const __m256 xy = mul_sum_us8_pairs_float(qx, qy); + + // Accumulate d0*d1*x*y + acc = __lasx_xvfmadd_s( d0d1, xy, acc ); + } + + sumf = hsum_float_8(acc) + summs; +#endif + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[ib].qs[j] & 0x0F); + const int v1 = (x[ib].qs[j] >> 4); + + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); + } + + *s = sumf; +} + +void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + int ib = 0; + float sumf = 0; + + assert(n % qk == 0); + assert(qk == QK5_0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + uint32_t qh0; + uint32_t qh1; + + uint64_t tmp0[4]; + uint64_t tmp1[4]; + + for (; ib + 1 < nb; ib += 2) { + const block_q5_0 * restrict x0 = &x[ib]; + const block_q5_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + // extract the 5th bit via lookup table ((!b) << 4) + memcpy(&qh0, x0->qh, sizeof(qh0)); + memcpy(&qh1, x1->qh, sizeof(qh1)); + + tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF]; + tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF]; + tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF]; + tmp0[3] = table_b2b_1[(qh0 >> 24) ]; + + tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF]; + tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF]; + tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF]; + tmp1[3] = table_b2b_1[(qh1 >> 24) ]; + + const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); + const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); + const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); + const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) + const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0); + const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0); + const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1); + const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), + ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), + ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__wasm_simd128__) + v128_t sumv = wasm_f32x4_splat(0.0f); + + uint32_t qh; + uint64_t tmp[4]; + + // TODO: check if unrolling this is better + for (; ib < nb; ++ib) { + const block_q5_0 * restrict x0 = &x[ib]; + const block_q8_0 * restrict y0 = &y[ib]; + + const v128_t m4b = wasm_i8x16_splat(0x0F); + + // extract the 5th bit + memcpy(&qh, x0->qh, sizeof(qh)); + + tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_1[(qh >> 24) ]; + + const v128_t qhl = wasm_v128_load(tmp + 0); + const v128_t qhh = wasm_v128_load(tmp + 2); + + const v128_t v0 = wasm_v128_load(x0->qs); + + // 4-bit -> 8-bit + const v128_t v0l = wasm_v128_and (v0, m4b); + const v128_t v0h = wasm_u8x16_shr(v0, 4); + + // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) + const v128_t v0lf = wasm_i8x16_sub(v0l, qhl); + const v128_t v0hf = wasm_i8x16_sub(v0h, qhh); + + // load y + const v128_t v1l = wasm_v128_load(y0->qs); + const v128_t v1h = wasm_v128_load(y0->qs + 16); + + // int8x16 -> int16x8 + const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); + const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); + const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); + const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); + + const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); + const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); + const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); + const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); + + // dot product + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4( + wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), + wasm_i32x4_dot_i16x8(v0lfh, v1lh)), + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), + wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), + wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); + } + + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (; ib < nb; ++ib) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); + + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + __m256i bxhi = bytes_from_bits_32(x[ib].qh); + bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); + qx = _mm256_or_si256(qx, bxhi); + + __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps(d, q, acc); + } + + sumf = hsum_float_8(acc); +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + __m128i mask = _mm_set1_epi8((char)0xF0); + + // Main loop + for (; ib < nb; ++ib) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); + + __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs); + const __m256i bxhi = bytes_from_bits_32(x[ib].qh); + __m128i bxhil = _mm256_castsi256_si128(bxhi); + __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); + bxhil = _mm_andnot_si128(bxhil, mask); + bxhih = _mm_andnot_si128(bxhih, mask); + __m128i bxl = _mm256_castsi256_si128(bx_0); + __m128i bxh = _mm256_extractf128_si256(bx_0, 1); + bxl = _mm_or_si128(bxl, bxhil); + bxh = _mm_or_si128(bxh, bxhih); + bx_0 = MM256_SET_M128I(bxh, bxl); + + const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0); + + /* Multiply q with scale and accumulate */ + acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); + } + + sumf = hsum_float_8(acc); +#elif defined(__riscv_v_intrinsic) + uint32_t qh; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + // These temporary registers are for masking and shift operations + vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); + vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl); + + vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl); + vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); + + for (; ib < nb; ++ib) { + memcpy(&qh, x[ib].qh, sizeof(uint32_t)); + + // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl); + vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl); + vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); + + // ((qh & (1u << (j + 16))) >> (j + 12)); + vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl); + vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl); + + // narrowing + vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl); + vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); + + vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl); + vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); + + // load + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); + + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); + + vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + + vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); + vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); + + vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + + vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl); + vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl); + + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi; + } + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v4 = vec_splats((unsigned char)4); + + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 4 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector signed long long aux64x2_0 = {(uint64_t)(table_b2b_1[x[ib].qh[0]]), (uint64_t)(table_b2b_1[x[ib].qh[1]])}; + vector signed long long aux64x2_1 = {(uint64_t)(table_b2b_1[x[ib].qh[2]]), (uint64_t)(table_b2b_1[x[ib].qh[3]])}; + + vector signed char qh0 = (vector signed char)aux64x2_0; + vector signed char qh1 = (vector signed char)aux64x2_1; + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + + vector signed char q5x0 = vec_sub(vec_and (qxs, lowMask), qh0); + vector signed char q5x1 = vec_sub(vec_sr(qxs, v4), qh1); + + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl( 16, y[ib].qs); + + vector signed short qv0 = vec_add(vec_mule(q5x0, q8y0), vec_mulo(q5x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q5x1, q8y1), vec_mulo(q5x1, q8y1)); + + qv0 = vec_add(qv0, qv1); + + vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0)); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + // Initialize accumulator with zeros + __m256 acc = (__m256)__lasx_xvldi(0); + + // Main loop + for (; ib < nb; ++ib) { + /* Compute combined scale for the block */ + const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); //FIXME + + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + __m256i bxhi = bytes_from_bits_32(x[ib].qh); + bxhi = __lasx_xvandn_v(bxhi, __lasx_xvreplgr2vr_b((char)0xF0)); + qx = __lasx_xvor_v(qx, bxhi); + + __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + /* Multiply q with scale and accumulate */ + acc = __lasx_xvfmadd_s(d, q, acc); + } + + sumf = hsum_float_8(acc); +#endif + for (; ib < nb; ++ib) { + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); + + const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16); + const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16); + + sumi0 += (x0 * y[ib].qs[j]); + sumi1 += (x1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi; + } + + *s = sumf; +} + +void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + int ib = 0; + float sumf = 0; + + assert(n % qk == 0); + assert(qk == QK5_1); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_1 * restrict x = vx; + const block_q8_1 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + float summs0 = 0.0f; + float summs1 = 0.0f; + + uint32_t qh0; + uint32_t qh1; + + uint64_t tmp0[4]; + uint64_t tmp1[4]; + + for (; ib + 1 < nb; ib += 2) { + const block_q5_1 * restrict x0 = &x[ib]; + const block_q5_1 * restrict x1 = &x[ib + 1]; + const block_q8_1 * restrict y0 = &y[ib]; + const block_q8_1 * restrict y1 = &y[ib + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + summs0 += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s); + summs1 += GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s); + + // extract the 5th bit via lookup table ((b) << 4) + memcpy(&qh0, x0->qh, sizeof(qh0)); + memcpy(&qh1, x1->qh, sizeof(qh1)); + + tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF]; + tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF]; + tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF]; + tmp0[3] = table_b2b_0[(qh0 >> 24) ]; + + tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF]; + tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF]; + tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF]; + tmp1[3] = table_b2b_0[(qh1 >> 24) ]; + + const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); + const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); + const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); + const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // add high bit + const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0); + const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0); + const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1); + const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), + ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), + ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; +#elif defined(__wasm_simd128__) + v128_t sumv = wasm_f32x4_splat(0.0f); + + float summs = 0.0f; + + uint32_t qh; + uint64_t tmp[4]; + + // TODO: check if unrolling this is better + for (; ib < nb; ++ib) { + const block_q5_1 * restrict x0 = &x[ib]; + const block_q8_1 * restrict y0 = &y[ib]; + + summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s); + + const v128_t m4b = wasm_i8x16_splat(0x0F); + + // extract the 5th bit + memcpy(&qh, x0->qh, sizeof(qh)); + + tmp[0] = table_b2b_0[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_0[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_0[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_0[(qh >> 24) ]; + + const v128_t qhl = wasm_v128_load(tmp + 0); + const v128_t qhh = wasm_v128_load(tmp + 2); + + const v128_t v0 = wasm_v128_load(x0->qs); + + // 4-bit -> 8-bit + const v128_t v0l = wasm_v128_and (v0, m4b); + const v128_t v0h = wasm_u8x16_shr(v0, 4); + + // add high bit + const v128_t v0lf = wasm_v128_or(v0l, qhl); + const v128_t v0hf = wasm_v128_or(v0h, qhh); + + // load y + const v128_t v1l = wasm_v128_load(y0->qs); + const v128_t v1h = wasm_v128_load(y0->qs + 16); + + // int8x16 -> int16x8 + const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); + const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); + const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); + const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); + + const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); + const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); + const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); + const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); + + // dot product + sumv = wasm_f32x4_add(sumv, + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), + wasm_i32x4_dot_i16x8(v0lfh, v1lh)), + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), + wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), + wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); + } + + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.0f; + + // Main loop + for (; ib < nb; ++ib) { + const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d)); + + summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); + + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + __m256i bxhi = bytes_from_bits_32(x[ib].qh); + bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); + qx = _mm256_or_si256(qx, bxhi); + + const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[ib].d)); + const __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); + + const __m256 q = mul_sum_us8_pairs_float(qx, qy); + + acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); + } + + sumf = hsum_float_8(acc) + summs; +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + __m128i mask = _mm_set1_epi8(0x10); + + float summs = 0.0f; + + // Main loop + for (; ib < nb; ++ib) { + const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d)); + + summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); + + __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs); + const __m256i bxhi = bytes_from_bits_32(x[ib].qh); + __m128i bxhil = _mm256_castsi256_si128(bxhi); + __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); + bxhil = _mm_and_si128(bxhil, mask); + bxhih = _mm_and_si128(bxhih, mask); + __m128i bxl = _mm256_castsi256_si128(bx_0); + __m128i bxh = _mm256_extractf128_si256(bx_0, 1); + bxl = _mm_or_si128(bxl, bxhil); + bxh = _mm_or_si128(bxh, bxhih); + bx_0 = MM256_SET_M128I(bxh, bxl); + + const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[ib].d)); + const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs); + + const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0); + + acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); + } + + sumf = hsum_float_8(acc) + summs; +#elif defined(__riscv_v_intrinsic) + uint32_t qh; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + // temporary registers for shift operations + vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); + vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); + + for (; ib < nb; ++ib) { + memcpy(&qh, x[ib].qh, sizeof(uint32_t)); + + // load qh + vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl); + + // ((qh >> (j + 0)) << 4) & 0x10; + vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl); + vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); + vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl); + + // ((qh >> (j + 12)) ) & 0x10; + vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl); + vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl); + + // narrowing + vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl); + vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); + + vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl); + vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); + + // load + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); + + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); + + vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + + vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); + vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); + + vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); + } + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 4 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[ib].m)); + vector float vys = {GGML_FP16_TO_FP32(y[ib].s), 0.f, 0.f, 0.f}; + vsumf0 = vec_madd(vxmin, vys, vsumf0); + + vector unsigned long long aux64x2_0 = {(uint64_t)(table_b2b_0[x[ib].qh[0]]), (uint64_t)(table_b2b_0[x[ib].qh[1]])}; + vector unsigned long long aux64x2_1 = {(uint64_t)(table_b2b_0[x[ib].qh[2]]), (uint64_t)(table_b2b_0[x[ib].qh[3]])}; + + vector signed char qh0 = (vector signed char)aux64x2_0; + vector signed char qh1 = (vector signed char)aux64x2_1; + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + + vector unsigned char q5x0 = (vector unsigned char)vec_or(vec_and(qxs, lowMask), qh0); + vector unsigned char q5x1 = (vector unsigned char)vec_or(vec_sr(qxs, v4), qh1); + + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl( 16, y[ib].qs); + + vector signed int vsumi0 = v0; + + vsumi0 = vec_msum(q8y0, q5x0, vsumi0); + vsumi0 = vec_msum(q8y1, q5x1, vsumi0); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + // Initialize accumulator with zeros + __m256 acc = (__m256)__lasx_xvldi(0); + + float summs = 0.0f; + + // Main loop + for (; ib < nb; ++ib) { + const __m256 dx = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d)); + + summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); + + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + __m256i bxhi = bytes_from_bits_32(x[ib].qh); + bxhi = __lasx_xvand_v(bxhi, __lasx_xvreplgr2vr_b(0x10)); + qx = __lasx_xvor_v(qx, bxhi); + + const __m256 dy = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib].d)); + const __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); + + const __m256 q = mul_sum_us8_pairs_float(qx, qy); + + acc = __lasx_xvfmadd_s(q, __lasx_xvfmul_s(dx, dy), acc); + } + + sumf = hsum_float_8(acc) + summs; +#endif + for (; ib < nb; ++ib) { + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0; + const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1; + + sumi0 += (x0 * y[ib].qs[j]); + sumi1 += (x1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); + } + + *s = sumf; +} + +void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); +#if defined(__ARM_FEATURE_MATMUL_INT8) + assert((nrc == 2) || (nrc == 1)); +#else + assert(nrc == 1); +#endif + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q8_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q8_0 * restrict vx0 = vx; + const block_q8_0 * restrict vx1 = (const block_q8_0 *) ((const uint8_t*)vx + bx); + const block_q8_0 * restrict vy0 = vy; + const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by); + + float32x4_t sumv0 = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i++) { + const block_q8_0 * restrict b_x0 = &vx0[i]; + const block_q8_0 * restrict b_y0 = &vy0[i]; + + const block_q8_0 * restrict b_x1 = &vx1[i]; + const block_q8_0 * restrict b_y1 = &vy1[i]; + + const int8x16_t x0_l = vld1q_s8(b_x0->qs); + const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16); + const int8x16_t x1_l = vld1q_s8(b_x1->qs); + const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16); + + // load y + const int8x16_t y0_l = vld1q_s8(b_y0->qs); + const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); + const int8x16_t y1_l = vld1q_s8(b_y1->qs); + const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); + + float32_t _scale[4] = { + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d) + }; + float32x4_t scale = vld1q_f32(_scale); + + int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + + int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + + int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + + int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + + sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), + l1, r1)), l2, r2)), l3, r3))), scale); + } + + float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2); + float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); + + vst1_f32(s, vget_low_f32 (sumv2)); + vst1_f32(s + bs, vget_high_f32(sumv2)); + + return; + } +#endif + + int ib = 0; + float sumf = 0; + +#if defined(__ARM_FEATURE_SVE) + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); + + const int vector_length = ggml_cpu_get_sve_cnt()*8; + + //VLA Implemenation for SVE + switch (vector_length) { + case 128: + { + // predicate for activating lanes for 16 Int8 elements + const svbool_t ph16 = svptrue_pat_b8 (SV_VL16); + const svbool_t pl16 = svptrue_pat_b32(SV_VL4); + + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svint8_t qx0_0 = svld1_s8(ph16, x0->qs); + const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16); + const svint8_t qx1_0 = svld1_s8(ph16, x1->qs); + const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16); + + // load y + const svint8_t qy0_0 = svld1_s8(ph16, y0->qs); + const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16); + const svint8_t qy1_0 = svld1_s8(ph16, y1->qs); + const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16); + + sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16, + svdot_s32(svdup_n_s32(0), qx0_0, qy0_0), + svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16, + svdot_s32(svdup_n_s32(0), qx1_0, qy1_0), + svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1)); + } break; + case 256: + { + //printf("sve256"); + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + // load x + const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); + const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); + + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), + svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), + svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + } break; + case 512: + { + // predicate for activating high 256 bit + const svbool_t ph32 = svptrue_pat_b8(SV_VL32); + // predicate for activating low 256 bit + const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32); + + // predicate for activating high lanes for 8 float32 elements + const svbool_t ph8 = svptrue_pat_b32(SV_VL8); + // predicate for activating low lanes for 8 float32 elements + const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8); + + svfloat32_t sumv00 = svdup_n_f32(0.0f); + + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + //load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits + // and add them to make one 64 element vector + // load x + const svint8_t qx_32 = svld1_s8(ph32, x0->qs); + svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2); + + qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64); + + // load y + const svint8_t qy_32 = svld1_s8(ph32, y0->qs); + svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2); + + qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64); + + // scale creation + const float32_t deq1 = GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d); + const float32_t deq2 = GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d); + + // duplicate deq1 in first half of vector and deq2 in second half of vector + const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2); + + const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64)); + + sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp); + } + + sumf = svaddv_f32(svptrue_b32(), sumv00); + break; + } + default: + assert(false && "Unsupported vector length"); + break; + } +#elif defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; + + const int8x16_t x0_0 = vld1q_s8(x0->qs); + const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); + const int8x16_t x1_0 = vld1q_s8(x1->qs); + const int8x16_t x1_1 = vld1q_s8(x1->qs + 16); + + // load y + const int8x16_t y0_0 = vld1q_s8(y0->qs); + const int8x16_t y0_1 = vld1q_s8(y0->qs + 16); + const int8x16_t y1_0 = vld1q_s8(y1->qs); + const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), + ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), + ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (; ib < nb; ++ib) { + // Compute combined scale for the block + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); + __m256i qx = _mm256_loadu_si256((const __m256i *)x[ib].qs); + __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + // Multiply q with scale and accumulate + acc = _mm256_fmadd_ps( d, q, acc ); + } + + sumf = hsum_float_8(acc); +#elif defined(__AVX__) + __m256 accum = _mm256_setzero_ps(); + + for (; ib + 1 < nb; ib += 2) { + const __m128i qx_1_0 = _mm_loadu_si128((const __m128i *)x[ib].qs); + const __m128i qx_1_1 = _mm_loadu_si128((const __m128i *)x[ib].qs + 1); + const __m128i qx_2_0 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); + const __m128i qx_2_1 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs + 1); + const __m128i qy_1_0 = _mm_loadu_si128((const __m128i *)y[ib].qs); + const __m128i qy_1_1 = _mm_loadu_si128((const __m128i *)y[ib].qs + 1); + const __m128i qy_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); + const __m128i qy_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1); + + const __m256 p = mul_sum_i8_quad_float(qx_1_0, qx_1_1, qx_2_0, qx_2_1, qy_1_0, qy_1_1, qy_2_0, qy_2_1); + const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d); + accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum); + } + + sumf = hsum_float_8(accum); +#elif defined(__riscv_v_intrinsic) + size_t vl = __riscv_vsetvl_e8m1(qk); + + for (; ib < nb; ++ib) { + // load elements + vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[ib].qs, vl); + vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[ib].qs, vl); + + vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); + + sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)); + } +#elif defined(__POWER9_VECTOR__) + const vector signed int v0 = vec_splats((int32_t)0); + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 8 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector signed char q8x0 = vec_xl( 0, x[ib].qs); + vector signed char q8x1 = vec_xl(16, x[ib].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); + + vector signed short qv0 = vec_mule(q8x0, q8y0); + vector signed short qv1 = vec_mulo(q8x0, q8y0); + vector signed short qv2 = vec_mule(q8x1, q8y1); + vector signed short qv3 = vec_mulo(q8x1, q8y1); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi1 = vec_sum4s(qv1, vsumi1); + vsumi0 = vec_sum4s(qv2, vsumi0); + vsumi1 = vec_sum4s(qv3, vsumi1); + + vsumi0 = vec_add(vsumi0, vsumi1); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + // Initialize accumulator with zeros + __m256 acc = (__m256)__lasx_xvldi(0); + + // Main loop + for (; ib < nb; ++ib) { + // Compute combined scale for the block + const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); + __m256i qx = __lasx_xvld((const __m256i *)x[ib].qs, 0); + __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + // Multiply q with scale and accumulate + acc = __lasx_xvfmadd_s( d, q, acc ); + } + + sumf = hsum_float_8(acc); +#endif + for (; ib < nb; ++ib) { + int sumi = 0; + + for (int j = 0; j < qk; j++) { + sumi += x[ib].qs[j]*y[ib].qs[j]; + } + + sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)); + } + + *s = sumf; +} + +void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq1_0 * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + float sumf = 0.0f; + + uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; + + const uint8x16_t shift = vld1q_u8(k_shift); + + for (int i = 0; i < nb; ++i) { +#if defined(__ARM_FEATURE_DOTPROD) + int32x4_t sumi0 = vdupq_n_s32(0); + int32x4_t sumi1 = vdupq_n_s32(0); +#else + int16x8_t sumi0 = vdupq_n_s16(0); + int16x8_t sumi1 = vdupq_n_s16(0); +#endif + + // first 32 bytes of 5 elements + { + uint8x16_t qx0 = vld1q_u8(x[i].qs + 0); + uint8x16_t qx1 = vld1q_u8(x[i].qs + 16); + uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3)); + uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3)); + uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9)); + uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9)); + uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27)); + uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27)); + uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81)); + uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81)); + + // multiply by 3 and keep the 2 bits above 8 bits + int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6)); + int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6)); + int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6)); + int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6)); + int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6)); + int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6)); + int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6)); + int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6)); + int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6)); + int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6)); + + const int8x16_t qy0 = vld1q_s8(y[i].qs + 0); + const int8x16_t qy1 = vld1q_s8(y[i].qs + 16); + const int8x16_t qy2 = vld1q_s8(y[i].qs + 32); + const int8x16_t qy3 = vld1q_s8(y[i].qs + 48); + const int8x16_t qy4 = vld1q_s8(y[i].qs + 64); + const int8x16_t qy5 = vld1q_s8(y[i].qs + 80); + const int8x16_t qy6 = vld1q_s8(y[i].qs + 96); + const int8x16_t qy7 = vld1q_s8(y[i].qs + 112); + const int8x16_t qy8 = vld1q_s8(y[i].qs + 128); + const int8x16_t qy9 = vld1q_s8(y[i].qs + 144); + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vdotq_s32(sumi0, sqx0, qy0); + sumi1 = vdotq_s32(sumi1, sqx1, qy1); + sumi0 = vdotq_s32(sumi0, sqx2, qy2); + sumi1 = vdotq_s32(sumi1, sqx3, qy3); + sumi0 = vdotq_s32(sumi0, sqx4, qy4); + sumi1 = vdotq_s32(sumi1, sqx5, qy5); + sumi0 = vdotq_s32(sumi0, sqx6, qy6); + sumi1 = vdotq_s32(sumi1, sqx7, qy7); + sumi0 = vdotq_s32(sumi0, sqx8, qy8); + sumi1 = vdotq_s32(sumi1, sqx9, qy9); +#else + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9)); +#endif + } + + // last 16 bytes of 5-element, along with the 4 bytes of 4 elements + { + uint8x16_t qx0 = vld1q_u8(x[i].qs + 32); + uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3)); + uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9)); + uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27)); + uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81)); + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned + uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh)); + qx5 = vmulq_u8(qx5, shift); + + // multiply by 3 and keep the 2 bits above 8 bits + int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6)); + int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6)); + int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6)); + int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6)); + int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6)); + int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6)); + + const int8x16_t qy0 = vld1q_s8(y[i].qs + 160); + const int8x16_t qy1 = vld1q_s8(y[i].qs + 176); + const int8x16_t qy2 = vld1q_s8(y[i].qs + 192); + const int8x16_t qy3 = vld1q_s8(y[i].qs + 208); + const int8x16_t qy4 = vld1q_s8(y[i].qs + 224); + const int8x16_t qy5 = vld1q_s8(y[i].qs + 240); + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vdotq_s32(sumi0, sqx0, qy0); + sumi1 = vdotq_s32(sumi1, sqx1, qy1); + sumi0 = vdotq_s32(sumi0, sqx2, qy2); + sumi1 = vdotq_s32(sumi1, sqx3, qy3); + sumi0 = vdotq_s32(sumi0, sqx4, qy4); + sumi1 = vdotq_s32(sumi1, sqx5, qy5); +#else + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); +#endif + } + + const int16x8_t ysum0 = vld1q_s16(y[i].bsums); + const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vaddq_s32(sumi0, sumi1); + sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); + + sumf += d * (float) vaddvq_s32(sumi0); +#else + sumi0 = vaddq_s16(sumi0, sumi1); + sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1)); + + sumf += d * (float) vaddlvq_s16(sumi0); +#endif + } + + *s = sumf; + +#elif defined(__AVX2__) + __m256 sumf = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + // 16-bit sums + __m256i sumi0 = _mm256_setzero_si256(); + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + + // first 32 bytes of 5 elements + { + __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs)); + // 8-bit multiplies with shifts, masks and adds + __m256i qx1 = _mm256_add_epi8(qx0, _mm256_add_epi8(qx0, qx0)); // 1 * 3 + __m256i qx2 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx0, 3), _mm256_set1_epi8(-8)), qx0); // 1 * 9 + __m256i qx3 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx1, 3), _mm256_set1_epi8(-8)), qx1); // 3 * 9 + __m256i qx4 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx2, 3), _mm256_set1_epi8(-8)), qx2); // 9 * 9 + + // TODO: can _mm256_mulhi_epu16 be faster even if 16-bits? + + // Cancel the +1 from avg so that it behaves like a halving add + qx0 = _mm256_subs_epu8(qx0, _mm256_set1_epi8(1)); + qx1 = _mm256_subs_epu8(qx1, _mm256_set1_epi8(1)); + qx2 = _mm256_subs_epu8(qx2, _mm256_set1_epi8(1)); + qx3 = _mm256_subs_epu8(qx3, _mm256_set1_epi8(1)); + qx4 = _mm256_subs_epu8(qx4, _mm256_set1_epi8(1)); + // Multiply by 3 and get the top 2 bits + qx0 = _mm256_avg_epu8(qx0, _mm256_avg_epu8(qx0, _mm256_setzero_si256())); + qx1 = _mm256_avg_epu8(qx1, _mm256_avg_epu8(qx1, _mm256_setzero_si256())); + qx2 = _mm256_avg_epu8(qx2, _mm256_avg_epu8(qx2, _mm256_setzero_si256())); + qx3 = _mm256_avg_epu8(qx3, _mm256_avg_epu8(qx3, _mm256_setzero_si256())); + qx4 = _mm256_avg_epu8(qx4, _mm256_avg_epu8(qx4, _mm256_setzero_si256())); + qx0 = _mm256_and_si256(_mm256_srli_epi16(qx0, 6), _mm256_set1_epi8(3)); + qx1 = _mm256_and_si256(_mm256_srli_epi16(qx1, 6), _mm256_set1_epi8(3)); + qx2 = _mm256_and_si256(_mm256_srli_epi16(qx2, 6), _mm256_set1_epi8(3)); + qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3)); + qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3)); + + const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 0)); + const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 32)); + const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 64)); + const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 96)); + const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128)); + + qx0 = _mm256_maddubs_epi16(qx0, qy0); + qx1 = _mm256_maddubs_epi16(qx1, qy1); + qx2 = _mm256_maddubs_epi16(qx2, qy2); + qx3 = _mm256_maddubs_epi16(qx3, qy3); + qx4 = _mm256_maddubs_epi16(qx4, qy4); + + sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1)); + sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3)); + sumi2 = _mm256_add_epi16(sumi2, qx4); + } + + // last 16 bytes of 5-element, along with the 4 bytes of 4 elements + { + __m128i qx0 = _mm_loadu_si128((const __m128i *) (x[i].qs + 32)); + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned + __m256i qx5_l = _mm256_cvtepu8_epi16(_mm_set1_epi32(qh)); + __m128i qx1 = _mm_add_epi8(qx0, _mm_add_epi8(qx0, qx0)); // 1 * 3 + __m128i qx2 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx0, 3), _mm_set1_epi8(-8)), qx0); // 1 * 9 + __m128i qx3 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx1, 3), _mm_set1_epi8(-8)), qx1); // 3 * 9 + __m128i qx4 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx2, 3), _mm_set1_epi8(-8)), qx2); // 9 * 9 + __m256i qx01 = MM256_SET_M128I(qx1, qx0); + __m256i qx23 = MM256_SET_M128I(qx3, qx2); + + // avx2 does not have 8-bit multiplies, so 16-bit it is. + qx5_l = _mm256_mullo_epi16(qx5_l, _mm256_set_epi16(27, 27, 27, 27, 9, 9, 9, 9, 3, 3, 3, 3, 1, 1, 1, 1)); + qx5_l = _mm256_and_si256(qx5_l, _mm256_set1_epi16(0xFF)); + __m128i qx5 = _mm_packus_epi16(_mm256_castsi256_si128(qx5_l), _mm256_extracti128_si256(qx5_l, 1)); + + __m256i qx45 = MM256_SET_M128I(qx5, qx4); + + // Cancel the +1 from avg so that it behaves like a halving add + qx01 = _mm256_subs_epu8(qx01, _mm256_set1_epi8(1)); + qx23 = _mm256_subs_epu8(qx23, _mm256_set1_epi8(1)); + qx45 = _mm256_subs_epu8(qx45, _mm256_set1_epi8(1)); + // Multiply by 3 and get the top 2 bits + qx01 = _mm256_avg_epu8(qx01, _mm256_avg_epu8(qx01, _mm256_setzero_si256())); + qx23 = _mm256_avg_epu8(qx23, _mm256_avg_epu8(qx23, _mm256_setzero_si256())); + qx45 = _mm256_avg_epu8(qx45, _mm256_avg_epu8(qx45, _mm256_setzero_si256())); + qx01 = _mm256_and_si256(_mm256_srli_epi16(qx01, 6), _mm256_set1_epi8(3)); + qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3)); + qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3)); + + const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160)); + const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192)); + const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224)); + + qx01 = _mm256_maddubs_epi16(qx01, qy01); + qx23 = _mm256_maddubs_epi16(qx23, qy23); + qx45 = _mm256_maddubs_epi16(qx45, qy45); + + sumi0 = _mm256_add_epi16(sumi0, qx01); + sumi1 = _mm256_add_epi16(sumi1, qx23); + sumi2 = _mm256_add_epi16(sumi2, qx45); + } + + const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums); + const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)); + + sumi0 = _mm256_sub_epi16(sumi0, ysum); + sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2)); + sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1)); + + sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf); + } + + *s = hsum_float_8(sumf); + +#else + const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243}; + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + int sum = 0; + + for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) { + for (size_t l = 0; l < 5; ++l) { + for (size_t m = 0; m < 32; ++m) { + uint8_t q = x[i].qs[j + m] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[j*5 + l*32 + m]; + } + } + } + for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) { + for (size_t l = 0; l < 5; ++l) { + for (size_t m = 0; m < 16; ++m) { + uint8_t q = x[i].qs[j + m] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[j*5 + l*16 + m]; + } + } + } + + for (size_t l = 0; l < 4; ++l) { + for (size_t j = 0; j < sizeof(x->qh); ++j) { + uint8_t q = x[i].qh[j] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j]; + } + } + + sumf += (float) sum * (GGML_FP16_TO_FP32(x[i].d) * y[i].d); + } + + *s = sumf; +#endif +} + +void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq2_0 * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + float sumf = 0.0f; + + const uint8x16_t m3 = vdupq_n_u8(3); + + for (int i = 0; i < nb; ++i) { +#if defined(__ARM_FEATURE_DOTPROD) + int32x4_t sumi0 = vdupq_n_s32(0); + int32x4_t sumi1 = vdupq_n_s32(0); +#else + int16x8_t sumi0 = vdupq_n_s16(0); + int16x8_t sumi1 = vdupq_n_s16(0); +#endif + + for (size_t j = 0; j < sizeof(x->qs); j += 32) { + uint8x16_t qx0 = vld1q_u8(x[i].qs + j); + uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16); + uint8x16_t qx2 = vshrq_n_u8(qx0, 2); + uint8x16_t qx3 = vshrq_n_u8(qx1, 2); + uint8x16_t qx4 = vshrq_n_u8(qx0, 4); + uint8x16_t qx5 = vshrq_n_u8(qx1, 4); + uint8x16_t qx6 = vshrq_n_u8(qx0, 6); + uint8x16_t qx7 = vshrq_n_u8(qx1, 6); + + int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3)); + int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3)); + int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3)); + int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3)); + int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3)); + int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3)); + int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3)); + int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3)); + + const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0); + const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16); + const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32); + const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48); + const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64); + const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80); + const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96); + const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112); + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vdotq_s32(sumi0, sqx0, qy0); + sumi1 = vdotq_s32(sumi1, sqx1, qy1); + sumi0 = vdotq_s32(sumi0, sqx2, qy2); + sumi1 = vdotq_s32(sumi1, sqx3, qy3); + sumi0 = vdotq_s32(sumi0, sqx4, qy4); + sumi1 = vdotq_s32(sumi1, sqx5, qy5); + sumi0 = vdotq_s32(sumi0, sqx6, qy6); + sumi1 = vdotq_s32(sumi1, sqx7, qy7); +#else + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7)); +#endif + } + + const int16x8_t ysum0 = vld1q_s16(y[i].bsums); + const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vaddq_s32(sumi0, sumi1); + sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); + + sumf += d * (float) vaddvq_s32(sumi0); +#else + sumi0 = vaddq_s16(sumi0, sumi1); + sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1)); + + sumf += d * (float) vaddlvq_s16(sumi0); +#endif + } + + *s = sumf; + +#elif defined(__AVX2__) + __m256 sumf = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + // 16-bit sums, because 256*127 still fits + __m256i sumi0 = _mm256_setzero_si256(); + __m256i sumi1 = _mm256_setzero_si256(); + + for (size_t j = 0; j < sizeof(x->qs); j += 32) { + __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs + j)); + __m256i qx1 = _mm256_srli_epi16(qx0, 2); + __m256i qx2 = _mm256_srli_epi16(qx0, 4); + __m256i qx3 = _mm256_srli_epi16(qx0, 6); + + // 0, 1, 2 (should not be 3) + qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3)); + qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3)); + qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3)); + qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3)); + + const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 0)); + const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32)); + const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64)); + const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96)); + + qx0 = _mm256_maddubs_epi16(qx0, qy0); + qx1 = _mm256_maddubs_epi16(qx1, qy1); + qx2 = _mm256_maddubs_epi16(qx2, qy2); + qx3 = _mm256_maddubs_epi16(qx3, qy3); + + sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1)); + sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3)); + } + + const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums); + const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)); + + sumi0 = _mm256_add_epi16(sumi0, sumi1); + sumi0 = _mm256_sub_epi16(sumi0, ysum); + sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1)); + + sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf); + } + + *s = hsum_float_8(sumf); + +#else + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + int32_t sumi = 0; + + for (size_t j = 0; j < sizeof(x->qs); j += 32) { + for (size_t l = 0; l < 4; ++l) { + for (size_t k = 0; k < 32; ++k) { + sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1); + } + } + } + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + sumf += (float) sumi * d; + } + + *s = sumf; +#endif +} + +void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q2_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + const uint8x16_t m3 = vdupq_n_u8(0x3); + const uint8x16_t m4 = vdupq_n_u8(0xF); + + const int32x4_t vzero = vdupq_n_s32(0); + + ggml_int8x16x2_t q2bytes; + uint8_t aux[16]; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const uint8_t * restrict sc = x[i].scales; + + const uint8x16_t mins_and_scales = vld1q_u8(sc); + const uint8x16_t scales = vandq_u8(mins_and_scales, m4); + vst1q_u8(aux, scales); + + const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); + const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums); + const ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}}; + const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), + vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); + const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), + vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1]))); + sum += dmin * vaddvq_s32(vaddq_s32(s0, s1)); + + int isum = 0; + int is = 0; + +// We use this macro instead of a function call because for some reason +// the code runs 2-3% slower, even if the function is declared inline +#define MULTIPLY_ACCUM_WITH_SCALE(index)\ + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; + +#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\ + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ + MULTIPLY_ACCUM_WITH_SCALE((index)); + + for (int j = 0; j < QK_K/128; ++j) { + const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32; + + ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); + + MULTIPLY_ACCUM_WITH_SCALE(0); + + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2); + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4); + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6); + + is += 8; + } + + sum += d * isum; + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m128i m4 = _mm_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); + const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + const __m256i mins = _mm256_cvtepi8_epi16(mins8); + const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums)); + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc); + + const __m256i all_scales = _mm256_cvtepi8_epi16(scales8); + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; + + __m256i sumi = _mm256_setzero_si256(); + + for (int j = 0; j < QK_K/128; ++j) { + + const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32; + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + const __m256i q2_0 = _mm256_and_si256(q2bits, m3); + const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3); + const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3); + const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3); + + __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); + __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); + __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2); + __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3); + + p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0); + p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1); + p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2); + p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3); + + p0 = _mm256_add_epi32(p0, p1); + p2 = _mm256_add_epi32(p2, p3); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2)); + } + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(0x3); + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(0x2); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // load mins and scales from block_q2_K.scales[QK_K/16] + const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales16 = _mm_and_si128(mins_and_scales, m4); + const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + const __m128i mins_0 = _mm_cvtepi8_epi16(mins16); + const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16)); + + // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2 + const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0])); + const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8])); + + // sumf += -dmin * summs in 32bits*8 + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc); + + const __m128i scales_0 = _mm_cvtepi8_epi16(scales16); + const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16)); + const __m128i scales[2] = { scales_0, scales_1 }; + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + for (int j = 0; j < QK_K/128; ++j) { + + // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K] + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + // load 2bits*16*8 from block_q2_K.qs[QK_K/4] + __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; + const __m128i q2_0 = _mm_and_si128(q2bits, m3); + const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; + const __m128i q2_1 = _mm_and_si128(q2bits, m3); + const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + + // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8 + __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0); + __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1); + __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2); + __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3); + __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4); + __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5); + __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6); + __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7); + + // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8 + __m128i shuffle = _mm_set1_epi16(0x0100); + p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0); + shuffle = _mm_add_epi16(shuffle, m2); + p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1); + shuffle = _mm_add_epi16(shuffle, m2); + p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2); + shuffle = _mm_add_epi16(shuffle, m2); + p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3); + shuffle = _mm_add_epi16(shuffle, m2); + p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4); + shuffle = _mm_add_epi16(shuffle, m2); + p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5); + shuffle = _mm_add_epi16(shuffle, m2); + p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6); + shuffle = _mm_add_epi16(shuffle, m2); + p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7); + + p0 = _mm_add_epi32(p0, p1); + p2 = _mm_add_epi32(p2, p3); + p4 = _mm_add_epi32(p4, p5); + p6 = _mm_add_epi32(p6, p7); + + // isum in 32bits*4*2 + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6)); + } + + // sumf += dall * isum - dmin * summs in 32bits + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc); + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + float sumf = 0; + uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + size_t vl = 16; + + vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); + vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); + + vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); + + vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); + vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); + vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); + vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + + sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); + + vl = 32; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); + + uint8_t is=0; + int isum=0; + + for (int j = 0; j < QK_K/128; ++j) { + // load Q2 + vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); + + vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); + vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl); + vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl); + vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl); + + // duplicate scale elements for product + vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl); + vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl); + vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl); + vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl); + + vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); + vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); + vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); + vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + + // load Q8 + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl); + vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl); + + vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); + vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); + vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); + vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(isum1); + + q2+=32; q8+=128; is=8; + + } + + sumf += dall * isum; + + } + + *s = sumf; + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0x3); + const vector signed char lowScaleMask = vec_splats((signed char)0xF); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v6 = vec_splats((unsigned char)0x6); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin)); + vector float vdmin = vec_mul(vxmin, vyd); + + vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); + vector signed short q8ysums1 = vec_xl(16, y[i].bsums); + + vector signed char q2xmins = (vector signed char)vec_xl( 0, x[i].scales); + vector signed char vscales = vec_and(q2xmins, lowScaleMask); + + q2xmins = vec_sr(q2xmins, v4); + vector signed short q2xmins0 = vec_unpackh(q2xmins); + vector signed short q2xmins1 = vec_unpackl(q2xmins); + + vector signed int prod0 = vec_mule(q2xmins0, q8ysums0); + vector signed int prod1 = vec_mulo(q2xmins0, q8ysums0); + vector signed int prod2 = vec_mule(q2xmins1, q8ysums1); + vector signed int prod3 = vec_mulo(q2xmins1, q8ysums1); + + vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); + vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); + vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); + vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + vector signed int vsumi4 = v0; + vector signed int vsumi5 = v0; + vector signed int vsumi6 = v0; + vector signed int vsumi7 = v0; + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/128; ++j) { + __builtin_prefetch(q2, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q2); + vector signed char qxs1 = (vector signed char)vec_xl(16, q2); + q2 += 32; + + vector unsigned char q2x00 = (vector unsigned char)vec_and(qxs0, lowMask); + vector unsigned char q2x01 = (vector unsigned char)vec_and(vec_sr(qxs0, v2), lowMask); + vector unsigned char q2x02 = (vector unsigned char)vec_and(vec_sr(qxs0, v4), lowMask); + vector unsigned char q2x03 = (vector unsigned char)vec_and(vec_sr(qxs0, v6), lowMask); + vector unsigned char q2x10 = (vector unsigned char)vec_and(qxs1, lowMask); + vector unsigned char q2x11 = (vector unsigned char)vec_and(vec_sr(qxs1, v2), lowMask); + vector unsigned char q2x12 = (vector unsigned char)vec_and(vec_sr(qxs1, v4), lowMask); + vector unsigned char q2x13 = (vector unsigned char)vec_and(vec_sr(qxs1, v6), lowMask); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl( 16, q8); + vector signed char q8y01 = vec_xl( 32, q8); + vector signed char q8y11 = vec_xl( 48, q8); + vector signed char q8y02 = vec_xl( 64, q8); + vector signed char q8y12 = vec_xl( 80, q8); + vector signed char q8y03 = vec_xl( 96, q8); + vector signed char q8y13 = vec_xl(112, q8); + q8 += 128; + + vector signed int qv0 = vec_msum(q8y00, q2x00, v0); + vector signed int qv1 = vec_msum(q8y01, q2x01, v0); + vector signed int qv2 = vec_msum(q8y02, q2x02, v0); + vector signed int qv3 = vec_msum(q8y03, q2x03, v0); + vector signed int qv4 = vec_msum(q8y10, q2x10, v0); + vector signed int qv5 = vec_msum(q8y11, q2x11, v0); + vector signed int qv6 = vec_msum(q8y12, q2x12, v0); + vector signed int qv7 = vec_msum(q8y13, q2x13, v0); + + vector signed short vscales_07 = vec_unpackh(vscales); + vector signed int vscales_03 = vec_unpackh(vscales_07); + vector signed int vscales_47 = vec_unpackl(vscales_07); + vector signed int vs0 = vec_splat(vscales_03, 0); + vector signed int vs1 = vec_splat(vscales_03, 1); + vector signed int vs2 = vec_splat(vscales_03, 2); + vector signed int vs3 = vec_splat(vscales_03, 3); + vector signed int vs4 = vec_splat(vscales_47, 0); + vector signed int vs5 = vec_splat(vscales_47, 1); + vector signed int vs6 = vec_splat(vscales_47, 2); + vector signed int vs7 = vec_splat(vscales_47, 3); + vscales = vec_sld(vscales, vscales, 8); + + vsumi0 = vec_add(vec_mul(qv0, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv1, vs2), vsumi1); + vsumi2 = vec_add(vec_mul(qv2, vs4), vsumi2); + vsumi3 = vec_add(vec_mul(qv3, vs6), vsumi3); + vsumi4 = vec_add(vec_mul(qv4, vs1), vsumi4); + vsumi5 = vec_add(vec_mul(qv5, vs3), vsumi5); + vsumi6 = vec_add(vec_mul(qv6, vs5), vsumi6); + vsumi7 = vec_add(vec_mul(qv7, vs7), vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#elif defined __loongarch_asx + + const __m256i m3 = __lasx_xvreplgr2vr_b(3); + const __m128i m4 = __lsx_vreplgr2vr_b(0xF); + + __m256 acc = (__m256)__lasx_xvldi(0); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0); + const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4); + const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4); + const __m256i mins = lasx_ext8_16(mins8); + const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0)); + + acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc); + + const __m256i all_scales = lasx_ext8_16(scales8); + const __m128i l_scales = lasx_extracti128(all_scales, 0); + const __m128i h_scales = lasx_extracti128(all_scales, 1); + const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)}; + + __m256i sumi = __lasx_xvldi(0); + + for (int j = 0; j < QK_K/128; ++j) { + + const __m256i q2bits = __lasx_xvld((const __m256i*)q2, 0); q2 += 32; + + const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + + const __m256i q2_0 = __lasx_xvand_v(q2bits, m3); + const __m256i q2_1 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 2), m3); + const __m256i q2_2 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 4), m3); + const __m256i q2_3 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 6), m3); + + __m256i p0 = lasx_maddubs_h(q2_0, q8_0); + __m256i p1 = lasx_maddubs_h(q2_1, q8_1); + __m256i p2 = lasx_maddubs_h(q2_2, q8_2); + __m256i p3 = lasx_maddubs_h(q2_3, q8_3); + + p0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(0)), p0); + p1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(1)), p1); + p2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(2)), p2); + p3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(3)), p3); + + p0 = __lasx_xvadd_w(p0, p1); + p2 = __lasx_xvadd_w(p2, p3); + + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p0, p2)); + } + + acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc); + + } + + *s = hsum_float_8(acc); + +#else + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + int summs = 0; + for (int j = 0; j < 16; ++j) { + summs += y[i].bsums[j] * (sc[j] >> 4); + } + + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + int isum = 0; + int is = 0; + int d; + for (int k = 0; k < QK_K/128; ++k) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + d = sc[is++] & 0xF; + int isuml = 0; + for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + d = sc[is++] & 0xF; + isuml = 0; + for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + shift += 2; + q8 += 32; + } + q2 += 32; + } + sumf += dall * isum - dmin * summs; + } + *s = sumf; +#endif +} + +void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + uint32_t aux[3]; + uint32_t utmp[4]; + + const uint8x16_t m3b = vdupq_n_u8(0x3); + const int32x4_t vzero = vdupq_n_s32(0); + + const uint8x16_t m0 = vdupq_n_u8(1); + const uint8x16_t m1 = vshlq_n_u8(m0, 1); + const uint8x16_t m2 = vshlq_n_u8(m0, 2); + const uint8x16_t m3 = vshlq_n_u8(m0, 3); + const int8_t m32 = 32; + + ggml_int8x16x4_t q3bytes; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); + + ggml_uint8x16x4_t q3h; + + int32_t isum = 0; + + // Set up scales + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= m32; + + for (int j = 0; j < QK_K/128; ++j) { + + const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32; + const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64; + const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64; + + q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); + q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); + q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); + q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); + + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3]; + + scale += 4; + + q3h.val[0] = vbicq_u8(m2, qhbits.val[0]); + q3h.val[1] = vbicq_u8(m2, qhbits.val[1]); + q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); + q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); + + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3]; + + scale += 4; + + if (j == 0) { + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4); + } + + } + sum += d * isum; + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i mone = _mm256_set1_epi8(1); + const __m128i m32 = _mm_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // Set up scales + memcpy(aux, x[i].scales, 12); + __m128i scales128 = _mm_set_epi32( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = _mm_sub_epi8(scales128, m32); + const __m256i all_scales = _mm256_cvtepi8_epi16(scales128); + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; + + // high bit + const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask); + + // integer accumulator + __m256i sumi = _mm256_setzero_si256(); + + int bit = 0; + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits + const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32; + + // prepare low and high bits + const __m256i q3l_0 = _mm256_and_si256(q3bits, m3); + const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3); + const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3); + const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3); + const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); + __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2); + __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3); + + __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); + __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2); + __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + // multiply with scales + p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); + p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); + p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); + + // accumulate + p16_0 = _mm256_add_epi32(p16_0, p16_1); + p16_2 = _mm256_add_epi32(p16_2, p16_3); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2)); + + } + + // multiply with block scale and accumulate + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + const __m128i mone = _mm_set1_epi8(1); + const __m128i m32 = _mm_set1_epi8(32); + const __m128i m2 = _mm_set1_epi8(2); + + __m256 acc = _mm256_setzero_ps(); + + const uint32_t *aux; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // Set up scales + aux = (const uint32_t *)x[i].scales; + __m128i scales128 = _mm_set_epi32( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = _mm_sub_epi8(scales128, m32); + const __m128i scales_0 = _mm_cvtepi8_epi16(scales128); + const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128)); + const __m128i scales[2] = { scales_0, scales_1 }; + + // high bit *128*2 from block_q3_K.hmask[QK_K/8] + const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]); + const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]); + + // integer accumulator + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4] + const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; + const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; + + // prepare low and high bits + const int bit = j << 2; + + const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3); + const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3); + const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2); + const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2); + + const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3); + const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3); + const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2); + const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2); + + const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3); + const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3); + const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2); + const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2); + + const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3); + const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3); + const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2); + const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2); + + // load Q8 quants from block_q8_K.qs[QK_K] + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0); + __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1); + __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2); + __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3); + __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4); + __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5); + __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6); + __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7); + + __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1); + __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2); + __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3); + __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4); + __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5); + __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6); + __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + p16_4 = _mm_sub_epi16(p16_4, q8s_4); + p16_5 = _mm_sub_epi16(p16_5, q8s_5); + p16_6 = _mm_sub_epi16(p16_6, q8s_6); + p16_7 = _mm_sub_epi16(p16_7, q8s_7); + + // multiply with scales + __m128i shuffle = _mm_set1_epi16(0x0100); + p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0); + shuffle = _mm_add_epi16(shuffle, m2); + p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1); + shuffle = _mm_add_epi16(shuffle, m2); + p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2); + shuffle = _mm_add_epi16(shuffle, m2); + p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3); + shuffle = _mm_add_epi16(shuffle, m2); + p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4); + shuffle = _mm_add_epi16(shuffle, m2); + p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5); + shuffle = _mm_add_epi16(shuffle, m2); + p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6); + shuffle = _mm_add_epi16(shuffle, m2); + p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7); + + // accumulate + p16_0 = _mm_add_epi32(p16_0, p16_1); + p16_2 = _mm_add_epi32(p16_2, p16_3); + p16_4 = _mm_add_epi32(p16_4, p16_5); + p16_6 = _mm_add_epi32(p16_6, p16_7); + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6)); + + } + + // multiply with block scale and accumulate + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + uint32_t aux[3]; + uint32_t utmp[4]; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + + size_t vl = 32; + uint8_t m = 1; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); + + int sum_t = 0; + + for (int j = 0; j < QK_K; j += 128) { + + vl = 32; + + // load Q3 + vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); + + vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); + vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); + vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); + vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); + + // compute mask for subtraction + vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); + vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); + vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; + + // load Q8 and take product with Q3 + vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + + vl = 16; + + // retrieve lane to multiply with scale + vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); + vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); + vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); + vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); + vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); + vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); + vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); + vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q3 += 32; q8 += 128; scale += 8; + + } + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + sumf += d*sum_t; + + } + + *s = sumf; + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0x3); + const vector signed char lowMask1 = vec_splats((int8_t)0xf); + const vector signed char lowMask2 = vec_splats((int8_t)0x30); + const vector int v0 = vec_splats((int32_t)0); + const vector signed char v1 = vec_splats((signed char)0x1); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v3 = vec_splats((unsigned char)0x3); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector unsigned char v6 = vec_splats((unsigned char)0x6); + const vector signed char off = vec_splats((signed char)0x20); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + UNUSED(kmask1); + UNUSED(kmask2); + + vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); + vector signed char u1 = vec_and(u0, lowMask1); + vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); + vector signed char u3 = (vector signed char)vec_mergeh((vector signed int)u2, (vector signed int)vec_sr(u2, v2)); + vector signed char u30 = vec_sl(vec_and(u3, lowMask), v4); + vector signed char u31 = vec_and(u3, lowMask2); + + u1 = vec_or(u1, u30); + u2 = vec_or(vec_sr(u0, v4), u31); + + vector signed char vscales = (vector signed char)vec_mergeh((vector signed long long)u1, (vector signed long long)u2); + vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].hmask); + vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].hmask); + + vscales = vec_sub(vscales, off); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + vector signed int vsumi4 = v0; + vector signed int vsumi5 = v0; + vector signed int vsumi6 = v0; + vector signed int vsumi7 = v0; + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/128; ++j) { + __builtin_prefetch(q3, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q3); + vector signed char qxs1 = (vector signed char)vec_xl(16, q3); + q3 += 32; + + //the low 2 bits + vector signed char qxs00 = vec_and(qxs0, lowMask); + vector signed char qxs01 = vec_and(vec_sr(qxs0, v2), lowMask); + vector signed char qxs02 = vec_and(vec_sr(qxs0, v4), lowMask); + vector signed char qxs03 = vec_and(vec_sr(qxs0, v6), lowMask); + vector signed char qxs10 = vec_and(qxs1, lowMask); + vector signed char qxs11 = vec_and(vec_sr(qxs1, v2), lowMask); + vector signed char qxs12 = vec_and(vec_sr(qxs1, v4), lowMask); + vector signed char qxs13 = vec_and(vec_sr(qxs1, v6), lowMask); + + //the 3rd bit + vector signed char qxh00 = vec_sl(vec_andc(v1, qxhs0), v2); + vector signed char qxh01 = vec_sl(vec_andc(v1, vec_sr(qxhs0, (vector unsigned char)v1)), v2); + vector signed char qxh02 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v2)), v2); + vector signed char qxh03 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v3)), v2); + vector signed char qxh10 = vec_sl(vec_andc(v1, qxhs1), v2); + vector signed char qxh11 = vec_sl(vec_andc(v1, vec_sr(qxhs1, (vector unsigned char)v1)), v2); + vector signed char qxh12 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v2)), v2); + vector signed char qxh13 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v3)), v2); + qxhs0 = vec_sr(qxhs0, v4); + qxhs1 = vec_sr(qxhs1, v4); + + vector signed char q3x00 = vec_sub(qxs00, qxh00); + vector signed char q3x01 = vec_sub(qxs01, qxh01); + vector signed char q3x02 = vec_sub(qxs02, qxh02); + vector signed char q3x03 = vec_sub(qxs03, qxh03); + vector signed char q3x10 = vec_sub(qxs10, qxh10); + vector signed char q3x11 = vec_sub(qxs11, qxh11); + vector signed char q3x12 = vec_sub(qxs12, qxh12); + vector signed char q3x13 = vec_sub(qxs13, qxh13); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl( 16, q8); + vector signed char q8y01 = vec_xl( 32, q8); + vector signed char q8y11 = vec_xl( 48, q8); + vector signed char q8y02 = vec_xl( 64, q8); + vector signed char q8y12 = vec_xl( 80, q8); + vector signed char q8y03 = vec_xl( 96, q8); + vector signed char q8y13 = vec_xl(112, q8); + q8 += 128; + + vector signed short vscales_h = vec_unpackh(vscales); + vector signed short vs0 = vec_splat(vscales_h, 0); + vector signed short vs1 = vec_splat(vscales_h, 1); + vector signed short vs2 = vec_splat(vscales_h, 2); + vector signed short vs3 = vec_splat(vscales_h, 3); + vector signed short vs4 = vec_splat(vscales_h, 4); + vector signed short vs5 = vec_splat(vscales_h, 5); + vector signed short vs6 = vec_splat(vscales_h, 6); + vector signed short vs7 = vec_splat(vscales_h, 7); + vscales = vec_sld(vscales, vscales, 8); + + vector signed short qv00 = vec_add(vec_mule(q3x00, q8y00), vec_mulo(q3x00, q8y00)); + vector signed short qv01 = vec_add(vec_mule(q3x01, q8y01), vec_mulo(q3x01, q8y01)); + vector signed short qv02 = vec_add(vec_mule(q3x02, q8y02), vec_mulo(q3x02, q8y02)); + vector signed short qv03 = vec_add(vec_mule(q3x03, q8y03), vec_mulo(q3x03, q8y03)); + vector signed short qv10 = vec_add(vec_mule(q3x10, q8y10), vec_mulo(q3x10, q8y10)); + vector signed short qv11 = vec_add(vec_mule(q3x11, q8y11), vec_mulo(q3x11, q8y11)); + vector signed short qv12 = vec_add(vec_mule(q3x12, q8y12), vec_mulo(q3x12, q8y12)); + vector signed short qv13 = vec_add(vec_mule(q3x13, q8y13), vec_mulo(q3x13, q8y13)); + + vsumi0 = vec_msum(qv00, vs0, vsumi0); + vsumi1 = vec_msum(qv01, vs2, vsumi1); + vsumi2 = vec_msum(qv02, vs4, vsumi2); + vsumi3 = vec_msum(qv03, vs6, vsumi3); + vsumi4 = vec_msum(qv10, vs1, vsumi4); + vsumi5 = vec_msum(qv11, vs3, vsumi5); + vsumi6 = vec_msum(qv12, vs5, vsumi6); + vsumi7 = vec_msum(qv13, vs7, vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#elif defined __loongarch_asx + + const __m256i m3 = __lasx_xvreplgr2vr_b(3); + const __m256i mone = __lasx_xvreplgr2vr_b(1); + const __m128i m32 = __lsx_vreplgr2vr_b(32); + + __m256 acc = (__m256)__lasx_xvldi(0); + + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + // Set up scales + memcpy(aux, x[i].scales, 12); + __m128i scales128 = lsx_set_w( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = __lsx_vsub_b(scales128, m32); + const __m256i all_scales = lasx_ext8_16(scales128); + const __m128i l_scales = lasx_extracti128(all_scales, 0); + const __m128i h_scales = lasx_extracti128(all_scales, 1); + const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)}; + + // high bit + const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0); + + // integer accumulator + __m256i sumi = __lasx_xvldi(0); + + int bit = 0; + int is = 0; + __m256i xvbit; + + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits + const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32; + + xvbit = __lasx_xvreplgr2vr_h(bit); + // prepare low and high bits + const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3); + const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); + ++bit; + + xvbit = __lasx_xvreplgr2vr_h(bit); + const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3); + const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); + ++bit; + + xvbit = __lasx_xvreplgr2vr_h(bit); + const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3); + const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); + ++bit; + + xvbit = __lasx_xvreplgr2vr_h(bit); + const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3); + const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); + ++bit; + + // load Q8 quants + const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use lasx_maddubs_h, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m256i q8s_0 = lasx_maddubs_h(q3h_0, q8_0); + __m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1); + __m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2); + __m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3); + + __m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0); + __m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1); + __m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2); + __m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3); + + p16_0 = __lasx_xvsub_h(p16_0, q8s_0); + p16_1 = __lasx_xvsub_h(p16_1, q8s_1); + p16_2 = __lasx_xvsub_h(p16_2, q8s_2); + p16_3 = __lasx_xvsub_h(p16_3, q8s_3); + + // multiply with scales + p16_0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); + p16_1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); + p16_2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); + p16_3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); + + // accumulate + p16_0 = __lasx_xvadd_w(p16_0, p16_1); + p16_2 = __lasx_xvadd_w(p16_2, p16_3); + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2)); + } + // multiply with block scale and accumulate + acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);//FIXME + } + + *s = hsum_float_8(acc); + +#else + // scalar version + // This function is written like this so the compiler can manage to vectorize most of it + // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the + // manually vectorized version above. Every other version I tried would run at least 4 times slower. + // The ideal situation would be if we could just write the code once, and the compiler would + // automatically produce the best possible set of machine instructions, instead of us having to manually + // write vectorized versions for AVX, ARM_NEON, etc. + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + uint32_t auxs[4]; + const int8_t * scales = (const int8_t*)auxs; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + q3 += 32; + } + a = aux8; + + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} + +void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#ifdef __ARM_FEATURE_SVE + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, K_SCALE_SIZE); + + uint32x2_t mins8 = { 0 }; + mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); + mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); + + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[0] &= kmask1; + + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + sumf -= dmin * vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const int vector_length = ggml_cpu_get_sve_cnt()*8; + const svuint8_t m4b = svdup_n_u8(0xf); + const svint32_t mzero = svdup_n_s32(0); + svint32_t sumi1 = svdup_n_s32(0); + svint32_t sumi1_1 = svdup_n_s32(0); + svint32_t sumi1_2 = svdup_n_s32(0); + svint32_t sumi2 = svdup_n_s32(0); + svint32_t sumi2_1 = svdup_n_s32(0); + svint32_t sumi2_2 = svdup_n_s32(0); + switch (vector_length) { + case 128: + { + for (int j = 0; j < QK_K/64; ++j) { + svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b)); + svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; + sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); + q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b)); + q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; + sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); + + q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4)); + q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; + sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); + q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4)); + q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; + sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); + q4 += 32; + } + sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2); + sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2); + sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2))); + } break; + case 256: + case 512: + { + for (int j = 0; j < QK_K/64; ++j) { + const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32; + svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b)); + svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32; + sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); + + q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4)); + q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32; + sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); + } + sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2))); + } break; + default: + assert(false && "Unsupported vector length"); + break; + } + } + *s = sumf; +#elif __ARM_NEON + const uint8x16_t m4b = vdupq_n_u8(0xf); + const int32x4_t mzero = vdupq_n_s32(0); + + ggml_int8x16x2_t q4bytes; + ggml_int8x16x2_t q8bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + + uint32x2_t mins8 = { 0 }; + mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); + mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); + + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[0] &= kmask1; + + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + sumf -= dmin * vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + int32_t sumi1 = 0; + int32_t sumi2 = 0; + + for (int j = 0; j < QK_K/64; ++j) { + const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32; + + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + sumi1 += vaddvq_s32(p1) * scales[2*j+0]; + + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + + sumi2 += vaddvq_s32(p2) * scales[2*j+1]; + } + + sumf += d * (sumi1 + sumi2); + + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + __m128 acc_m = _mm_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m); + + const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + const __m256i scales = MM256_SET_M128I(sc128, sc128); + + __m256i sumi = _mm256_setzero_si256(); + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4l = _mm256_and_si256(q4bits, m4); + const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); + + const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); + p16l = _mm256_madd_epi16(scale_l, p16l); + + const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); + p16h = _mm256_madd_epi16(scale_h, p16h); + const __m256i sumj = _mm256_add_epi32(p16l, p16h); + + sumi = _mm256_add_epi32(sumi, sumj); + } + + __m256 vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + + } + + acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); + acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); + + *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(0x2); + + __m256 acc = _mm256_setzero_ps(); + __m128 acc_m = _mm_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i scales = _mm_cvtepu8_epi16(utmps); + const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); + + const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); + const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); + const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); + const __m128i prod = _mm_madd_epi16(mins, q8s); + acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m); + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + __m128i shuffle = _mm_set1_epi16(0x0100); + for (int j = 0; j < QK_K/64; ++j) { + + const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + + __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4l_0 = _mm_and_si128(q4bits, m4); + const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); + q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4l_1 = _mm_and_si128(q4bits, m4); + const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); + + const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0); + p16l = _mm_madd_epi16(scale_l, p16l); + sumi_0 = _mm_add_epi32(sumi_0, p16l); + const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + p16l = _mm_maddubs_epi16(q4l_1, q8l_1); + p16l = _mm_madd_epi16(scale_l, p16l); + sumi_1 = _mm_add_epi32(sumi_1, p16l); + + const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0); + p16h = _mm_madd_epi16(scale_h, p16h); + sumi_0 = _mm_add_epi32(sumi_0, p16h); + const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + p16h = _mm_maddubs_epi16(q4h_1, q8h_1); + p16h = _mm_madd_epi16(scale_h, p16h); + sumi_1 = _mm_add_epi32(sumi_1, p16h); + + } + + __m256 vd = _mm256_set1_ps(d); + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); + + } + + acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); + acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); + + *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + +#elif defined __riscv_v_intrinsic + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + size_t vl = 8; + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + vl = 32; + + int32_t sum_1 = 0; + int32_t sum_2 = 0; + + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + + for (int j = 0; j < QK_K/64; ++j) { + // load Q4 + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + + // load Q8 and multiply it with lower Q4 nibble + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); + + sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + + // load Q8 and multiply it with upper Q4 nibble + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + + sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + + q4 += 32; q8 += 64; + + } + + sumf += d*(sum_1 + sum_2); + + } + + *s = sumf; + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed char lowMask1 = vec_splats((int8_t)0x3f); + const vector signed char lowMask2 = vec_splats((int8_t)0x30); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v2 = vec_splats((uint8_t)2); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin)); + vector float vdmin = vec_mul(vxmin, vyd); + + vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); + vector signed short q8ysums1 = vec_xl(16, y[i].bsums); + + UNUSED(kmask1); + UNUSED(kmask2); + UNUSED(kmask3); + UNUSED(utmp); + + vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); + vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2); + vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); + vector signed char u3 = vec_sr(u2, v4); + + vector signed char u30 = u1; + vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3); + + u1 = vec_and(u0, lowMask1); + u2 = vec_or(u30, u31); + + vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2); + + vector signed short vscales = vec_unpackh(utmps); + vector signed short q4xmins = vec_unpackl(utmps); + vector signed short q4xmins0 = vec_mergeh(q4xmins, q4xmins); + vector signed short q4xmins1 = vec_mergel(q4xmins, q4xmins); + + vector signed int prod0 = vec_mule(q4xmins0, q8ysums0); + vector signed int prod1 = vec_mule(q4xmins1, q8ysums1); + vector signed int prod2 = vec_mulo(q4xmins0, q8ysums0); + vector signed int prod3 = vec_mulo(q4xmins1, q8ysums1); + + vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); + vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); + vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); + vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/64; j+=2) { + __builtin_prefetch(q4, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q4); + vector signed char qxs1 = (vector signed char)vec_xl(16, q4); + vector signed char qxs2 = (vector signed char)vec_xl(32, q4); + vector signed char qxs3 = (vector signed char)vec_xl(48, q4); + q4 += 64; + + vector unsigned char q4x00 = (vector unsigned char)vec_and(qxs0, lowMask); + vector unsigned char q4x01 = (vector unsigned char)vec_sr(qxs0, v4); + vector unsigned char q4x10 = (vector unsigned char)vec_and(qxs1, lowMask); + vector unsigned char q4x11 = (vector unsigned char)vec_sr(qxs1, v4); + vector unsigned char q4x20 = (vector unsigned char)vec_and(qxs2, lowMask); + vector unsigned char q4x21 = (vector unsigned char)vec_sr(qxs2, v4); + vector unsigned char q4x30 = (vector unsigned char)vec_and(qxs3, lowMask); + vector unsigned char q4x31 = (vector unsigned char)vec_sr(qxs3, v4); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl( 16, q8); + vector signed char q8y01 = vec_xl( 32, q8); + vector signed char q8y11 = vec_xl( 48, q8); + vector signed char q8y20 = vec_xl( 64, q8); + vector signed char q8y30 = vec_xl( 80, q8); + vector signed char q8y21 = vec_xl( 96, q8); + vector signed char q8y31 = vec_xl(112, q8); + q8 += 128; + + vector signed int qv00 = vec_msum(q8y00, q4x00, v0); + vector signed int qv01 = vec_msum(q8y01, q4x01, v0); + vector signed int qv10 = vec_msum(q8y10, q4x10, v0); + vector signed int qv11 = vec_msum(q8y11, q4x11, v0); + vector signed int qv20 = vec_msum(q8y20, q4x20, v0); + vector signed int qv21 = vec_msum(q8y21, q4x21, v0); + vector signed int qv30 = vec_msum(q8y30, q4x30, v0); + vector signed int qv31 = vec_msum(q8y31, q4x31, v0); + + vector signed int vscales_h = vec_unpackh(vscales); + vector signed int vs0 = vec_splat(vscales_h, 0); + vector signed int vs1 = vec_splat(vscales_h, 1); + vector signed int vs2 = vec_splat(vscales_h, 2); + vector signed int vs3 = vec_splat(vscales_h, 3); + vscales = vec_sld(vscales, vscales, 8); + + vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv01, vs1), vsumi1); + vsumi2 = vec_add(vec_mul(qv20, vs2), vsumi2); + vsumi3 = vec_add(vec_mul(qv21, vs3), vsumi3); + + vsumi0 = vec_add(vec_mul(qv10, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv11, vs1), vsumi1); + vsumi2 = vec_add(vec_mul(qv30, vs2), vsumi2); + vsumi3 = vec_add(vec_mul(qv31, vs3), vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#elif defined __loongarch_asx + GGML_UNUSED(kmask1); + GGML_UNUSED(kmask2); + GGML_UNUSED(kmask3); + + const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); + + __m256 acc = (__m256)__lasx_xvldi(0); + __m128 acc_m = (__m128)__lsx_vldi(0); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0); + const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1)); + const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s); + acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m); + + const __m128i sc128 = lasx_extracti128(mins_and_scales, 0); + const __m256i scales = lasx_insertf128(sc128, sc128); + + __m256i sumi = __lasx_xvldi(0); + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_l = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_h = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; + const __m256i q4l = __lasx_xvand_v(q4bits, m4); + const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4); + + const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + __m256i p16l = lasx_maddubs_h(q4l, q8l); + p16l = lasx_madd_h(scale_l, p16l); + + const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + __m256i p16h = lasx_maddubs_h(q4h, q8h); + p16h = lasx_madd_h(scale_h, p16h); + const __m256i sumj = __lasx_xvadd_w(p16l, p16h); + + sumi = __lasx_xvadd_w(sumi, sumj); + } + + __m256 vd = __lasx_xvreplfr2vr_s(d); + acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc); + + } + + acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee)); + __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0); + acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1); + + + ft_union fi; + fi.i = __lsx_vpickve2gr_w(acc_m, 0); + *s = hsum_float_8(acc) + fi.f ; +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + a += 32; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + a += 32; q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#ifdef __ARM_NEON + const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + const int32x4_t mzero = vdupq_n_s32(0); + + ggml_int8x16x4_t q5bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8); + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + int32_t sumi_mins = vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); + + ggml_uint8x16x4_t q5h; + + int32_t sumi = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32; + const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; + + q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3); + q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3); + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2); + + q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); + q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); + q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); + q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); + + sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; + sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; + } + + sumf += d * sumi - dmin * sumi_mins; + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m128i mzero = _mm_setzero_si128(); + const __m256i mone = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); + summs += dmin * _mm_extract_epi32(hsum, 0); + + const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + const __m256i scales = MM256_SET_M128I(sc128, sc128); + + const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh); + __m256i hmask = mone; + + __m256i sumi = _mm256_setzero_si256(); + + int bit = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32; + + const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); + const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); + const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1); + + p16_0 = _mm256_madd_epi16(scale_0, p16_0); + p16_1 = _mm256_madd_epi16(scale_1, p16_1); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + + } + + __m256 vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc) + summs; + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i mzero = _mm_setzero_si128(); + const __m128i mone = _mm_set1_epi8(1); + const __m128i m2 = _mm_set1_epi8(2); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i scales = _mm_cvtepu8_epi16(utmps); + const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); + + const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); + const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); + const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); + const __m128i prod = _mm_madd_epi16(mins, q8s); + const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); + summs += dmin * _mm_extract_epi32(hsum, 0); + + const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]); + const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]); + __m128i hmask = mone; + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + int bit = 0; + + __m128i shuffle = _mm_set1_epi16(0x0100); + for (int j = 0; j < QK_K/64; ++j) { + + const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + + const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; + const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; + + __m128i q5l_0 = _mm_and_si128(q5bits_0, m4); + __m128i q5l_1 = _mm_and_si128(q5bits_1, m4); + __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); + __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); + __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0); + __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1); + hmask = _mm_slli_epi16(hmask, 1); + + __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1); + p16_0 = _mm_madd_epi16(scale_0, p16_0); + p16_1 = _mm_madd_epi16(scale_0, p16_1); + + q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4); + q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4); + q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); + q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); + q5_0 = _mm_add_epi8(q5l_0, q5h_0); + q5_1 = _mm_add_epi8(q5l_1, q5h_1); + hmask = _mm_slli_epi16(hmask, 1); + + q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0); + __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1); + p16_2 = _mm_madd_epi16(scale_1, p16_2); + p16_3 = _mm_madd_epi16(scale_1, p16_3); + + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + + } + + __m256 vd = _mm256_set1_ps(d); + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); + + } + + *s = hsum_float_8(acc) + summs; + +#elif defined __riscv_v_intrinsic + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + float sums = 0.0; + + size_t vl; + + for (int i = 0; i < nb; ++i) { + + vl = 8; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + + vl = 32; + int32_t aux32 = 0; + int is = 0; + + uint8_t m = 1; + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl); + + for (int j = 0; j < QK_K/64; ++j) { + // load Q5 and Q8 + vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl); + vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl); + + // compute mask for addition + vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl)); + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_mu(vmask_1, q5_a, q5_a, 16, vl); + m <<= 1; + + vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl)); + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_mu(vmask_2, q5_l, q5_l, 16, vl); + m <<= 1; + + vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl); + vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl); + + vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl); + vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl); + + vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl); + vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl); + + aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2); + q5 += 32; q8 += 64; + + } + + vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1); + sums += __riscv_vfmv_f_s_f32m1_f32(vaux); + + } + + *s = sumf+sums; + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed char lowMask1 = vec_splats((int8_t)0x3f); + const vector signed char lowMask2 = vec_splats((int8_t)0x30); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v1 = vec_splats((unsigned char)0x1); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v3 = vec_splats((unsigned char)0x3); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin)); + vector float vdmin = vec_mul(vxmin, vyd); + + UNUSED(kmask1); + UNUSED(kmask2); + UNUSED(kmask3); + UNUSED(utmp); + + vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); + vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2); + vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); + vector signed char u3 = vec_sr(u2, v4); + + vector signed char u30 = u1; + vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3); + + u1 = vec_and(u0, lowMask1); + u2 = vec_or(u30, u31); + + vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2); + + vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); + vector signed short q8ysums1 = vec_xl(16, y[i].bsums); + + vector signed short vscales = vec_unpackh(utmps); + + vector signed short q5xmins = vec_unpackl(utmps); + vector signed short q5xmins0 = vec_mergeh(q5xmins, q5xmins); + vector signed short q5xmins1 = vec_mergel(q5xmins, q5xmins); + + vector signed int prod0 = vec_mule(q5xmins0, q8ysums0); + vector signed int prod1 = vec_mule(q5xmins1, q8ysums1); + vector signed int prod2 = vec_mulo(q5xmins0, q8ysums0); + vector signed int prod3 = vec_mulo(q5xmins1, q8ysums1); + + vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); + vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); + vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); + vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); + + vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh); + vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].qh); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/64; ++j) { + __builtin_prefetch(q5, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q5); + vector signed char qxs1 = (vector signed char)vec_xl(16, q5); + q5 += 32; + + vector signed char qxs00 = vec_and(qxs0, lowMask); + vector signed char qxs01 = vec_sr(qxs0, v4); + vector signed char qxs10 = vec_and(qxs1, lowMask); + vector signed char qxs11 = vec_sr(qxs1, v4); + + vector signed char q5h00 = vec_sl(vec_and((vector signed char)v1, qxhs0), v4); + vector signed char q5h01 = vec_sl(vec_and((vector signed char)v2, qxhs0), v3); + vector signed char q5h10 = vec_sl(vec_and((vector signed char)v1, qxhs1), v4); + vector signed char q5h11 = vec_sl(vec_and((vector signed char)v2, qxhs1), v3); + qxhs0 = vec_sr(qxhs0, v2); + qxhs1 = vec_sr(qxhs1, v2); + + vector unsigned char q5x00 = (vector unsigned char)vec_or(q5h00, qxs00); + vector unsigned char q5x01 = (vector unsigned char)vec_or(q5h01, qxs01); + vector unsigned char q5x10 = (vector unsigned char)vec_or(q5h10, qxs10); + vector unsigned char q5x11 = (vector unsigned char)vec_or(q5h11, qxs11); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl(16, q8); + vector signed char q8y01 = vec_xl(32, q8); + vector signed char q8y11 = vec_xl(48, q8); + q8 += 64; + + vector signed int qv00 = vec_msum(q8y00, q5x00, v0); + vector signed int qv01 = vec_msum(q8y01, q5x01, v0); + vector signed int qv10 = vec_msum(q8y10, q5x10, v0); + vector signed int qv11 = vec_msum(q8y11, q5x11, v0); + + vector signed int vscales_h = vec_unpackh(vscales); + vector signed int vs0 = vec_splat(vscales_h, 0); + vector signed int vs1 = vec_splat(vscales_h, 1); + vscales = vec_sld(vscales, vscales, 12); + + vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv10, vs0), vsumi1); + vsumi2 = vec_add(vec_mul(qv01, vs1), vsumi2); + vsumi3 = vec_add(vec_mul(qv11, vs1), vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#elif defined __loongarch_asx + GGML_UNUSED(kmask1); + GGML_UNUSED(kmask2); + GGML_UNUSED(kmask3); + + const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); + const __m128i mzero = __lsx_vldi(0); + const __m256i mone = __lasx_xvreplgr2vr_b(1); + + __m256 acc = (__m256)__lasx_xvldi(0); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0); + const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1)); + const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s); + const __m128i hsum = lsx_hadd_w(lsx_hadd_w(prod, mzero), mzero); + summs += dmin * __lsx_vpickve2gr_w(hsum, 0); //TODO check + + const __m128i sc128 = lasx_extracti128(mins_and_scales, 0); + const __m256i scales = lasx_insertf128(sc128, sc128); + + const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0); + __m256i hmask = mone; + + __m256i sumi = __lasx_xvldi(0); + + int bit = 0; + __m256i xvbit; + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_0 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_1 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32; + + xvbit = __lasx_xvreplgr2vr_h(bit++); + const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4); + const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4); + const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0); + hmask = __lasx_xvslli_h(hmask, 1); + + xvbit = __lasx_xvreplgr2vr_h(bit++); + const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4); + const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4); + const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1); + hmask = __lasx_xvslli_h(hmask, 1); + + const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + + __m256i p16_0 = lasx_maddubs_h(q5_0, q8_0); + __m256i p16_1 = lasx_maddubs_h(q5_1, q8_1); + + p16_0 = lasx_madd_h(scale_0, p16_0); + p16_1 = lasx_madd_h(scale_1, p16_1); + + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1)); + + } + + __m256 vd = __lasx_xvreplfr2vr_s(d); + acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc); + + } + + *s = hsum_float_8(acc) + summs; + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + float sum = 0; + + const uint8x16_t m4b = vdupq_n_u8(0xF); + const int32x4_t vzero = vdupq_n_s32(0); + //const int8x16_t m32s = vdupq_n_s8(32); + + const uint8x16_t mone = vdupq_n_u8(3); + + ggml_int8x16x4_t q6bytes; + ggml_uint8x16x4_t q6h; + + for (int i = 0; i < nb; ++i) { + + const float d_all = GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums); + const int8x16_t scales = vld1q_s8(scale); + const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}}; + + const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), + vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])), + vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1])))); + int32_t isum_mins = vaddvq_s32(prod); + + int32_t isum = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32; + ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64; + ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; + + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 2); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); + //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); + //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s); + //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s); + q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])); + q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])); + q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); + q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); + + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; + + scale += 4; + + q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; + + shifted = vshrq_n_u8(qhbits.val[0], 4); + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 4); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[0], 6); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 6); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s); + //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s); + //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s); + //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s); + q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])); + q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])); + q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); + q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); + + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; + scale += 4; + } + //sum += isum * d_all * y[i].d; + sum += d_all * y[i].d * (isum - 32 * isum_mins); + + } + *s = sum; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(3); + const __m256i m32s = _mm256_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + + __m256i sumi = _mm256_setzero_si256(); + + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); + const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); + is += 4; + + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; + + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); + const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); + const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); + + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); + const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); + const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); + __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); + __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); + + __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); + __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); + __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); + p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2); + p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3)); + + } + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + const __m128i m15 = _mm_set1_epi8(15); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + // handle the q6_k -32 offset separately using bsums + const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums); + const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1); + const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales); + const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8)); + const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5); + const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5); + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16; + const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16; + + const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4); + const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4); + const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2); + const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2); + const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48)); + const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48)); + const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2); + const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2); + + const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + + const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0); + const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1); + const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2); + const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3); + const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4); + const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5); + const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6); + const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7); + + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1); + __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2); + __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3); + __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4); + __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5); + __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6); + __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7); + + const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); + const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); + is += 4; + + p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1); + p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); + p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3); + p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4); + p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5); + p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6); + p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7); + + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7)); + + } + + sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0); + sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1); + const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc); + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + size_t vl; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + int sum_t = 0; + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + vl = 32; + + // load qh + vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + + // load Q6 + vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); + vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); + + vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); + vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); + vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); + vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); + + vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); + vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); + vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); + vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); + + vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); + vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); + vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); + vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); + + vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); + vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); + vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); + vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); + + // load Q8 and take product + vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + + vl = 16; + + vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); + vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); + vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); + vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); + vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); + vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); + vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); + vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q6 += 64; qh += 32; q8 += 128; is=8; + + } + + sumf += d * sum_t; + + } + + *s = sumf; + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v3 = vec_splats((unsigned char)0x3); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector unsigned char v6 = vec_splats((unsigned char)0x6); + const vector signed char off = vec_splats((signed char)0x20); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + vector signed int vsumi4 = v0; + vector signed int vsumi5 = v0; + vector signed int vsumi6 = v0; + vector signed int vsumi7 = v0; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict qs = x[i].scales; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/128; ++j) { + __builtin_prefetch(q6, 0, 0); + __builtin_prefetch(qh, 0, 0); + __builtin_prefetch(q8, 0, 0); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q6); + vector signed char qxs1 = (vector signed char)vec_xl(16, q6); + vector signed char qxs2 = (vector signed char)vec_xl(32, q6); + vector signed char qxs3 = (vector signed char)vec_xl(48, q6); + q6 += 64; + + vector signed char qxs00 = vec_and(qxs0, lowMask); + vector signed char qxs01 = vec_sr(qxs0, v4); + vector signed char qxs10 = vec_and(qxs1, lowMask); + vector signed char qxs11 = vec_sr(qxs1, v4); + vector signed char qxs20 = vec_and(qxs2, lowMask); + vector signed char qxs21 = vec_sr(qxs2, v4); + vector signed char qxs30 = vec_and(qxs3, lowMask); + vector signed char qxs31 = vec_sr(qxs3, v4); + + vector signed char qxhs0 = (vector signed char)vec_xl( 0, qh); + vector signed char qxhs1 = (vector signed char)vec_xl(16, qh); + qh += 32; + + vector signed char qxh00 = vec_sl(vec_and((vector signed char)v3, qxhs0), v4); + vector signed char qxh01 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v4)), v4); + vector signed char qxh10 = vec_sl(vec_and((vector signed char)v3, qxhs1), v4); + vector signed char qxh11 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v4)), v4); + vector signed char qxh20 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v2)), v4); + vector signed char qxh21 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v6)), v4); + vector signed char qxh30 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v2)), v4); + vector signed char qxh31 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v6)), v4); + + vector signed char q6x00 = vec_sub(vec_or(qxh00, qxs00), off); + vector signed char q6x01 = vec_sub(vec_or(qxh01, qxs01), off); + vector signed char q6x10 = vec_sub(vec_or(qxh10, qxs10), off); + vector signed char q6x11 = vec_sub(vec_or(qxh11, qxs11), off); + vector signed char q6x20 = vec_sub(vec_or(qxh20, qxs20), off); + vector signed char q6x21 = vec_sub(vec_or(qxh21, qxs21), off); + vector signed char q6x30 = vec_sub(vec_or(qxh30, qxs30), off); + vector signed char q6x31 = vec_sub(vec_or(qxh31, qxs31), off); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl( 16, q8); + vector signed char q8y20 = vec_xl( 32, q8); + vector signed char q8y30 = vec_xl( 48, q8); + vector signed char q8y01 = vec_xl( 64, q8); + vector signed char q8y11 = vec_xl( 80, q8); + vector signed char q8y21 = vec_xl( 96, q8); + vector signed char q8y31 = vec_xl(112, q8); + q8 += 128; + + vector signed short qv00 = vec_add(vec_mule(q6x00, q8y00), vec_mulo(q6x00, q8y00)); + vector signed short qv10 = vec_add(vec_mule(q6x10, q8y10), vec_mulo(q6x10, q8y10)); + vector signed short qv20 = vec_add(vec_mule(q6x20, q8y20), vec_mulo(q6x20, q8y20)); + vector signed short qv30 = vec_add(vec_mule(q6x30, q8y30), vec_mulo(q6x30, q8y30)); + vector signed short qv01 = vec_add(vec_mule(q6x01, q8y01), vec_mulo(q6x01, q8y01)); + vector signed short qv11 = vec_add(vec_mule(q6x11, q8y11), vec_mulo(q6x11, q8y11)); + vector signed short qv21 = vec_add(vec_mule(q6x21, q8y21), vec_mulo(q6x21, q8y21)); + vector signed short qv31 = vec_add(vec_mule(q6x31, q8y31), vec_mulo(q6x31, q8y31)); + + vector signed short vscales = vec_unpackh(vec_xl_len(qs, 8)); + qs += 8; + + vector signed short vs0 = vec_splat(vscales, 0); + vector signed short vs1 = vec_splat(vscales, 1); + vector signed short vs2 = vec_splat(vscales, 2); + vector signed short vs3 = vec_splat(vscales, 3); + vector signed short vs4 = vec_splat(vscales, 4); + vector signed short vs5 = vec_splat(vscales, 5); + vector signed short vs6 = vec_splat(vscales, 6); + vector signed short vs7 = vec_splat(vscales, 7); + + vsumi0 = vec_msum(qv00, vs0, vsumi0); + vsumi1 = vec_msum(qv01, vs4, vsumi1); + vsumi2 = vec_msum(qv10, vs1, vsumi2); + vsumi3 = vec_msum(qv11, vs5, vsumi3); + vsumi4 = vec_msum(qv20, vs2, vsumi4); + vsumi5 = vec_msum(qv21, vs6, vsumi5); + vsumi6 = vec_msum(qv30, vs3, vsumi6); + vsumi7 = vec_msum(qv31, vs7, vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#elif defined __loongarch_asx + + const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); + const __m256i m2 = __lasx_xvreplgr2vr_b(3); + const __m256i m32s = __lasx_xvreplgr2vr_b(32); + + __m256 acc = (__m256)__lasx_xvldi(0); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m128i scales = __lsx_vld((const __m128i*)x[i].scales, 0); + + __m256i sumi = __lasx_xvldi(0); + + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2)); + const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3)); + is += 4; + + const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; + const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; + const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32; + + const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(q4bitsH, m2), 4); + const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 2), m2), 4); + const __m256i q4h_2 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 4), m2), 4); + const __m256i q4h_3 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 6), m2), 4); + + const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0); + const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(q4bits2, m4), q4h_1); + const __m256i q4_2 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_2); + const __m256i q4_3 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits2, 4), m4), q4h_3); + + const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + + __m256i q8s_0 = lasx_maddubs_h(m32s, q8_0); + __m256i q8s_1 = lasx_maddubs_h(m32s, q8_1); + __m256i q8s_2 = lasx_maddubs_h(m32s, q8_2); + __m256i q8s_3 = lasx_maddubs_h(m32s, q8_3); + + __m256i p16_0 = lasx_maddubs_h(q4_0, q8_0); + __m256i p16_1 = lasx_maddubs_h(q4_1, q8_1); + __m256i p16_2 = lasx_maddubs_h(q4_2, q8_2); + __m256i p16_3 = lasx_maddubs_h(q4_3, q8_3); + + p16_0 = __lasx_xvsub_h(p16_0, q8s_0); + p16_1 = __lasx_xvsub_h(p16_1, q8s_1); + p16_2 = __lasx_xvsub_h(p16_2, q8s_2); + p16_3 = __lasx_xvsub_h(p16_3, q8s_3); + + p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0); + p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1); + p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2); + p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3); + + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1)); + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3)); + } + + acc = __lasx_xvfmadd_s((__m256)__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc); + } + + *s = hsum_float_8(acc); + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; + } + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +#if defined (__AVX__) || defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx) +static const int8_t keven_signs_q2xs[1024] = { + 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, + 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, + 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1, + 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, + 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, + 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1, + 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1, + 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, + 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1, + 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1, + 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1, + 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1, + 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1, + 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1, + 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1, + 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, + 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1, + 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1, + 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1, + 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1, + 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1, + 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, + 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, + 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, + 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1, + 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, + 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, + 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1, + 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, + 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, + 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, + 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, +}; +#endif + +void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xxs * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + ggml_int8x16x4_t q2u; + ggml_int8x16x4_t q2s; + ggml_int8x16x4_t q8b; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + float sumf1 = 0, sumf2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1]))); + q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3]))); + q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9]))); + q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11]))); + q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); + q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); + q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127)))); + q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127)))); + q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); + q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); + q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); + q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]); + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]); + sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28)); + sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28)); + } + sumf += d*(sumf1 + sumf2); + } + *s = 0.25f * sumf; + +#elif defined(__AVX2__) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); + const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); + const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127], + signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const uint16_t ls1 = aux32[1] >> 28; + const uint16_t ls2 = aux32[3] >> 28; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#elif defined(__AVX__) + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); + const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]); + const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); + const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]); + const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]); + const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); + const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]); + const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0); + const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1); + const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0); + const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1); + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + const uint16_t ls1 = aux32[1] >> 28; + const uint16_t ls2 = aux32[3] >> 28; + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); + sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); + sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); + sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); + sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#elif defined(__POWER9_VECTOR__) + const vector int v0 = vec_splats((int32_t)0); + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q2, 0, 1); + __builtin_prefetch(q8, 0, 1); + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + memcpy(aux32, q2, 4*sizeof(uint32_t)); + q2 += 8; + + vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1])}; + vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3])}; + vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9])}; + vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11])}; + + vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((aux32[1] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 7) & 127))}; + vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127))}; + vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((aux32[3] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 7) & 127))}; + vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127))}; + + vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0); + vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1); + vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2); + vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); + + const uint16_t ls0 = aux32[1] >> 28; + const uint16_t ls1 = aux32[3] >> 28; + + vector signed short vscales01 = vec_splats((int16_t)(2*ls0+1)); + vector signed short vscales23 = vec_splats((int16_t)(2*ls1+1)); + + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = 0.125f * vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + __m256 accumf = (__m256)__lasx_xvldi(0); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + + const __m256i q2_1 = lasx_set_d(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); + const __m256i q2_2 = lasx_set_d(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); + const __m256i s2_1 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i s2_2 = lasx_set_d(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127], + signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); + const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1); + const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2); + const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); + const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); + const uint16_t ls1 = aux32[1] >> 28; + const uint16_t ls2 = aux32[3] >> 28; + const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); + const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); + sumi1 = __lasx_xvadd_w(sumi1, p1); + sumi2 = __lasx_xvadd_w(sumi2, p2); + } + + accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); + } + + *s = 0.125f * hsum_float_8(accumf); + +#else + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(aux32, q2, 2*sizeof(uint32_t)); + q2 += 4; + const uint32_t ls = 2*(aux32[1] >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); + const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls; + } + sumf += d * bsum; + } + *s = 0.125f * sumf; +#endif +} + +void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xs * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + ggml_int8x16x4_t q2u; + ggml_int8x16x4_t q2s; + ggml_int8x16x4_t q8b; + + int32x4x4_t scales32; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const uint8x8_t scales8 = vld1_u8(x[i].scales); + const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf)); + const uint8x8_t scales_h = vshr_n_u8(scales8, 4); + uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h)); + scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1)); + const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales)); + const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales)); + scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1))); + scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1))); + scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2))); + scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2))); + int32x4_t sumi = vdupq_n_s32(0); + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511)))); + q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511)))); + q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511)))); + q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511)))); + q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9)))); + q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9)))); + q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9)))); + q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9)))); + q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); + q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); + q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); + q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); + const int32x4_t p1 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]); + const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]); + const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]); + const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]); + const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4)); + sumi = vmlaq_s32(sumi, p, scales32.val[ib64]); + q2 += 8; + } + sumf += d*vaddvq_s32(sumi); + } + *s = 0.125f * sumf; + +#elif defined(__AVX2__) + + const __m256i mone = _mm256_set1_epi8(1); + static const char block_sign_shuffle_mask_1[32] = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, + 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, + }; + static const char block_sign_shuffle_mask_2[32] = { + 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, + 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, + }; + static const uint8_t bit_selector_mask_bytes[32] = { + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes); + const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1); + const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2); + + static const uint8_t k_bit_helper[32] = { + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + }; + const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper); + const __m256i m511 = _mm256_set1_epi16(511); + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + + uint64_t aux64; + + // somewhat hacky, but gives a significant boost in performance + __m256i aux_gindex; + const uint16_t * gindex = (const uint16_t *)&aux_gindex; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + __m128i stmp = _mm_set1_epi64x(aux64); + stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4)); + const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1); + + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { + + const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16; + aux_gindex = _mm256_and_si256(q2_data, m511); + + const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9); + const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13); + const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper); + + const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting); + const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits); + + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + + const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]], + iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]); + const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]], + iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]); + const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]], + iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]); + const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]], + iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); + + const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits); + const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1); + const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l); + const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h); + + __m256i signs; + signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone)); + + signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone)); + + signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone)); + + signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone)); + + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const __m256i dot3 = _mm256_maddubs_epi16(q2_3, q8s_3); + const __m256i dot4 = _mm256_maddubs_epi16(q2_4, q8s_4); + + const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0))); + const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1))); + const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2))); + const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3))); + + sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1)); + sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2)); + sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3)); + sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4)); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#elif defined(__AVX__) + const __m128i mone = _mm_set1_epi8(1); + static const char block_sign_shuffle_mask_1[32] = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, + 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, + }; + static const char block_sign_shuffle_mask_2[32] = { + 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, + 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, + }; + static const uint8_t bit_selector_mask_bytes[32] = { + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes); + const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1); + const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1); + const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1); + const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2); + const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1); + + static const uint8_t k_bit_helper[32] = { + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + }; + const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper); + const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1); + const __m128i m511 = _mm_set1_epi16(511); + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + + uint64_t aux64; + + // somewhat hacky, but gives a significant boost in performance + __m256i aux_gindex; + const uint16_t * gindex = (const uint16_t *)&aux_gindex; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + __m128i stmp = _mm_set1_epi64x(aux64); + stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4)); + const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1); + + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { + + const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2); + const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1); q2 += 16; + aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511)); + + const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9); + const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9); + const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13); + const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13); + const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0); + const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1); + + const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, partial_sign_bits_for_counting_0); + const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1); + const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0); + const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1); + + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + + const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]); + const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]); + const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]); + const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]); + const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]); + const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]); + const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); + const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]); + + // AVX2 full_signs_1 is full_sign_bits_0 here + // AVX2 full_signs_2 is full_sign_bits_1 here + __m128i signs_0, signs_1; + signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0); + signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1); + signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); + signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); + const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone)); + const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone)); + + signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0); + signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1); + signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); + signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); + const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone)); + const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone)); + + signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0); + signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1); + signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); + signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); + const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone)); + const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone)); + + signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0); + signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1); + signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); + signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); + const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone)); + const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone)); + + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + const __m128i dot3_0 = _mm_maddubs_epi16(q2_3_0, q8s_3_0); + const __m128i dot3_1 = _mm_maddubs_epi16(q2_3_1, q8s_3_1); + const __m128i dot4_0 = _mm_maddubs_epi16(q2_4_0, q8s_4_0); + const __m128i dot4_1 = _mm_maddubs_epi16(q2_4_1, q8s_4_1); + + __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)); + const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp); + const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); + sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)); + const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp); + const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); + sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)); + const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp); + const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); + sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)); + const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp); + const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); + + sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0)); + sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1)); + sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0)); + sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1)); + sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0)); + sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1)); + sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0)); + sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1)); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#elif defined(__loongarch_asx) + + const __m256i mone = __lasx_xvreplgr2vr_b(1); + static const char block_sign_shuffle_mask_1[32] = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, + 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, + }; + static const char block_sign_shuffle_mask_2[32] = { + 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, + 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, + }; + static const uint8_t bit_selector_mask_bytes[32] = { + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m256i bit_selector_mask = __lasx_xvld((const __m256i*)bit_selector_mask_bytes, 0); + const __m256i block_sign_shuffle_1 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_1, 0); + const __m256i block_sign_shuffle_2 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_2, 0); + + static const uint8_t k_bit_helper[32] = { + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + }; + const __m256i bit_helper = __lasx_xvld((const __m256i*)k_bit_helper, 0); + const __m256i m511 = __lasx_xvreplgr2vr_h(511); + const __m128i m4 = __lsx_vreplgr2vr_b(0xf); + const __m128i m1 = __lsx_vreplgr2vr_b(1); + + uint64_t aux64; + + // somewhat hacky, but gives a significant boost in performance + __m256i aux_gindex; + const uint16_t * gindex = (const uint16_t *)&aux_gindex; + + __m256 accumf = (__m256)__lasx_xvldi(0); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + __m128i stmp = __lsx_vreplgr2vr_d(aux64); + stmp = __lsx_vilvl_b( __lsx_vand_v(__lsx_vsrli_h(stmp, 4), m4), __lsx_vand_v(stmp, m4)); + const __m128i scales = __lsx_vadd_b(__lsx_vslli_h(stmp, 1), m1); + + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { + + const __m256i q2_data = __lasx_xvld((const __m256i*)q2, 0); q2 += 16; + aux_gindex = __lasx_xvand_v(q2_data, m511); + + const __m256i partial_sign_bits = __lasx_xvsrli_h(q2_data, 9); + const __m256i partial_sign_bits_upper = __lasx_xvsrli_h(q2_data, 13); + const __m256i partial_sign_bits_for_counting = __lasx_xvxor_v(partial_sign_bits, partial_sign_bits_upper); + + const __m256i odd_bits = lasx_shuffle_b(bit_helper, partial_sign_bits_for_counting); + const __m256i full_sign_bits = __lasx_xvor_v(partial_sign_bits, odd_bits); + + const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_3 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_4 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + + const __m256i q2_1 = lasx_set_d(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]], + iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]); + const __m256i q2_2 = lasx_set_d(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]], + iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]); + const __m256i q2_3 = lasx_set_d(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]], + iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]); + const __m256i q2_4 = lasx_set_d(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]], + iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); + + const __m128i full_signs_l = lasx_extracti128(full_sign_bits, 0); + const __m128i full_signs_h = lasx_extracti128(full_sign_bits, 1); + const __m256i full_signs_1 = lasx_insertf128(full_signs_l, full_signs_l); + const __m256i full_signs_2 = lasx_insertf128(full_signs_h, full_signs_h); + + __m256i signs; + signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_1); + signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_1 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_1); + + signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_2); + signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_2 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_2); + + signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_1); + signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_3 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_3); + + signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_2); + signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_4 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_4); + + const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); + const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); + const __m256i dot3 = lasx_maddubs_h(q2_3, q8s_3); + const __m256i dot4 = lasx_maddubs_h(q2_4, q8s_4); + + const __m256i sc1 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+0))); + const __m256i sc2 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+1))); + const __m256i sc3 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+2))); + const __m256i sc4 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+3))); + + sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot1, sc1)); + sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot2, sc2)); + sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot3, sc3)); + sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot4, sc4)); + } + + accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); +#elif defined(__POWER9_VECTOR__) + const vector int v0 = vec_splats((int32_t)0); + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint16_t * restrict q2 = x[i].qs; + const uint8_t * restrict sc = x[i].scales; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/64; ++j) { + __builtin_prefetch(q2, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xs_grid + (q2[0] & 511)), *(const int64_t *)(iq2xs_grid + (q2[1] & 511))}; + vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xs_grid + (q2[2] & 511)), *(const int64_t *)(iq2xs_grid + (q2[3] & 511))}; + vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xs_grid + (q2[4] & 511)), *(const int64_t *)(iq2xs_grid + (q2[5] & 511))}; + vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xs_grid + (q2[6] & 511)), *(const int64_t *)(iq2xs_grid + (q2[7] & 511))}; + + vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((q2[0] >> 9))), *(const int64_t *)(signs64 + ((q2[1] >> 9)))}; + vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((q2[2] >> 9))), *(const int64_t *)(signs64 + ((q2[3] >> 9)))}; + vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((q2[4] >> 9))), *(const int64_t *)(signs64 + ((q2[5] >> 9)))}; + vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((q2[6] >> 9))), *(const int64_t *)(signs64 + ((q2[7] >> 9)))}; + q2 += 8; + + vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0); + vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1); + vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2); + vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); + + const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); + const uint16_t ls1 = (uint16_t)(sc[0] >> 4); + const uint16_t ls2 = (uint16_t)(sc[1] & 0xf); + const uint16_t ls3 = (uint16_t)(sc[1] >> 4); + sc += 2; + + vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1)); + vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1)); + vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1)); + vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1)); + + vsumi0 = vec_msum(qv0, vscales0, vsumi0); + vsumi1 = vec_msum(qv1, vscales1, vsumi1); + vsumi2 = vec_msum(qv2, vscales2, vsumi2); + vsumi3 = vec_msum(qv3, vscales3, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = 0.125f * vec_extract(vsumf0, 0); +#else + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const uint8_t * restrict sc = x[i].scales; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1; + const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1; + int32_t sumi = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls1; + sumi = 0; + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls2; + q2 += 4; + } + sumf += d * bsum; + } + *s = 0.125f * sumf; +#endif +} + +void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_s * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + + const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1); + const uint8x16_t mask2 = vld1q_u8(k_mask2); + const uint8x16_t m1 = vdupq_n_u8(1); + const int32x4_t vzero = vdupq_n_s32(0); + + uint8x16x2_t vs; + ggml_int8x16x4_t q2s; + ggml_int8x16x4_t q8b; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); + const int8_t * restrict q8 = y[i].qs; + + int sumi1 = 0, sumi2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + q2s.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[0] | ((qh[ib32+0] << 8) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[1] | ((qh[ib32+0] << 6) & 0x300))))); + q2s.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[2] | ((qh[ib32+0] << 4) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[3] | ((qh[ib32+0] << 2) & 0x300))))); + q2s.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[4] | ((qh[ib32+1] << 8) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[5] | ((qh[ib32+1] << 6) & 0x300))))); + q2s.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[6] | ((qh[ib32+1] << 4) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300))))); + qs += 8; + + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16))); + vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vceqq_u8(vs.val[0], mask2); + vs.val[1] = vceqq_u8(vs.val[1], mask2); + + q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]); + q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]); + + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16))); + vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vceqq_u8(vs.val[0], mask2); + vs.val[1] = vceqq_u8(vs.val[1], mask2); + + signs += 4; + + q2s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[2]); + q2s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[3]); + + const int32x4_t p1 = ggml_vdotq_s32(vzero, q2s.val[0], q8b.val[0]); + const int32x4_t p2 = ggml_vdotq_s32(vzero, q2s.val[1], q8b.val[1]); + const int32x4_t p3 = ggml_vdotq_s32(vzero, q2s.val[2], q8b.val[2]); + const int32x4_t p4 = ggml_vdotq_s32(vzero, q2s.val[3], q8b.val[3]); + + sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32+0] & 0xf)); + sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32+0] >> 4)); + sumi1 += vaddvq_s32(p3) * (1 + 2*(x[i].scales[ib32+1] & 0xf)); + sumi2 += vaddvq_s32(p4) * (1 + 2*(x[i].scales[ib32+1] >> 4)); + } + sumf += d*(sumi1 + sumi2); + } + + *s = 0.125f * sumf; + +#elif defined(__AVX2__) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + + const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); + const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); + + uint64_t aux64; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); + const int8_t * restrict q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1); + const __m256i scales16 = _mm256_cvtepi8_epi16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15 + + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q2_1 = _mm256_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], + iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)], + iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], + iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); + const __m256i q2_2 = _mm256_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], + iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)], + iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], + iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); + qs += 8; + + __m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1); + + aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2); + + signs += 4; + + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1 + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3 + + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+0))); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+1))); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#elif defined(__AVX__) + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + + const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1); + const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1); + const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2); + const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1); + + uint64_t aux64; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); + const int8_t * restrict q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1); + const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8); + const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8)); + + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], + iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); + const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], + iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]); + const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], + iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); + const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], + iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]); + qs += 8; + + __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16)); + __m128i aux128_1 = aux128_0; + aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); + aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); + const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); + const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); + const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0); + const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1); + + aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16)); + aux128_1 = aux128_0; + aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); + aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); + const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); + const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); + const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0); + const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1); + + signs += 4; + + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0))); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1))); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0))); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1))); + sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); + sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); + sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); + sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#elif defined(__POWER9_VECTOR__) + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + + const vector int v0 = vec_splats((int32_t)0); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const vector unsigned char mask0 = vec_xl( 0, k_mask1); + const vector unsigned char mask1 = vec_xl(16, k_mask1); + const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint8_t * restrict q2 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); + const uint8_t * restrict sc = x[i].scales; + const int8_t * restrict q8 = y[i].qs; + + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q2, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed long long aux64x2_0 = {*(const int64_t *)(iq2s_grid + (q2[0] | ((qh[0] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[1] | ((qh[0] << 6) & 0x300)))}; + vector signed long long aux64x2_1 = {*(const int64_t *)(iq2s_grid + (q2[2] | ((qh[0] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[3] | ((qh[0] << 2) & 0x300)))}; + vector signed long long aux64x2_2 = {*(const int64_t *)(iq2s_grid + (q2[4] | ((qh[1] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[5] | ((qh[1] << 6) & 0x300)))}; + vector signed long long aux64x2_3 = {*(const int64_t *)(iq2s_grid + (q2[6] | ((qh[1] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[7] | ((qh[1] << 2) & 0x300)))}; + q2 += 8; + qh += 2; + + vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]); + vector signed char vsigns23 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]); + signs += 4; + + vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0); + vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1); + vector signed char vsigns2 = vec_perm(vsigns23, vsigns23, mask0); + vector signed char vsigns3 = vec_perm(vsigns23, vsigns23, mask1); + + vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2); + vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2); + vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2); + vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2); + + vector signed char q2x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux64x2_0), vsigns0); + vector signed char q2x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux64x2_1), vsigns1); + vector signed char q2x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux64x2_2), vsigns2); + vector signed char q2x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux64x2_3), vsigns3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); + + const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); + const uint16_t ls1 = (uint16_t)(sc[0] >> 4); + const uint16_t ls2 = (uint16_t)(sc[1] & 0xf); + const uint16_t ls3 = (uint16_t)(sc[1] >> 4); + sc += 2; + + vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1)); + vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1)); + vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1)); + vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1)); + + vsumi0 = vec_msum(qv0, vscales0, vsumi0); + vsumi1 = vec_msum(qv1, vscales1, vsumi1); + vsumi2 = vec_msum(qv2, vscales2, vsumi2); + vsumi3 = vec_msum(qv3, vscales3, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = 0.125f * vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + + const __m128i m4 = __lsx_vreplgr2vr_b(0xf); + const __m128i m1 = __lsx_vreplgr2vr_b(1); + + const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0); + const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0); + uint64_t aux64; + + __m256 accumf = (__m256)__lasx_xvldi(0); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); + const int8_t * restrict q8 = y[i].qs; + + __m128i tmp1; + memcpy(&aux64, x[i].scales, 8); + tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64, 0); + tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64 >> 4, 1); + const __m128i scales8 = __lsx_vadd_b(__lsx_vslli_h(__lsx_vand_v(tmp1, m4), 1), m1); + const __m256i scales16 = lasx_ext8_16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15 + + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q2_1 = lasx_set_d(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], + iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)], + iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], + iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); + const __m256i q2_2 = lasx_set_d(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], + iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)], + iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], + iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); + qs += 8; + + __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | ((uint32_t) signs[1] << 16)); + aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); + const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2); + const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1); + + aux256 = __lasx_xvreplgr2vr_w(signs[2] | ((uint32_t) signs[3] << 16)); + aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); + const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2); + const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2); + + signs += 4; + + const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1 + const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3 + + const __m256i p1 = lasx_madd_h(dot1, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+0))); + const __m256i p2 = lasx_madd_h(dot2, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+1))); + sumi1 = __lasx_xvadd_w(sumi1, p1); + sumi2 = __lasx_xvadd_w(sumi2, p2); + } + + accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); + } + + *s = 0.125f * hsum_float_8(accumf); + +#else + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint8_t * signs = qs + QK_K/8; + + int bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf); + int ls2 = 1 + 2*(x[i].scales[ib32] >> 4); + int sumi1 = 0, sumi2 = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += ls1 * sumi1 + ls2 * sumi2; + qs += 4; + signs += 4; + } + + sumf += d * bsum; + } + + *s = 0.125f * sumf; + +#endif + +} + +void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_xxs * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + ggml_int8x16x4_t q3s; + ggml_int8x16x4_t q8b; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; + float sumf1 = 0, sumf2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t); + const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]); + const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]); + const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]); + const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]); + q3 += 16; + q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127)))); + q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127)))); + q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); + q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); + q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0)); + q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1)); + q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2)); + q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3)); + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); + sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28)); + sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28)); + } + sumf += d*(sumf1 + sumf2); + } + *s = 0.5f * sumf; + +#elif defined(__AVX2__) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + memcpy(aux32, gas, 8); gas += 8; + const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127], + signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); + const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const uint16_t ls1 = aux32[0] >> 28; + const uint16_t ls2 = aux32[1] >> 28; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.25f * hsum_float_8(accumf); + +#elif defined(__AVX__) + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]); + q3 += 8; + const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]); + q3 += 8; + memcpy(aux32, gas, 8); gas += 8; + const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); + const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]); + const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]); + const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0); + const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1); + const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0); + const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1); + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + const uint16_t ls1 = aux32[0] >> 28; + const uint16_t ls2 = aux32[1] >> 28; + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); + sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); + sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); + sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); + sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = 0.25f * hsum_float_8(accumf); + +#elif defined(__POWER9_VECTOR__) + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + const vector int v0 = vec_splats((int32_t)0); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint8_t * restrict q3 = x[i].qs; + const uint32_t * restrict signs = (const uint32_t *)(x[i].qs + QK_K/4); + const int8_t * restrict q8 = y[i].qs; + +#pragma GCC unroll 1 + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q3, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector unsigned int aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]}; + vector unsigned int aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]}; + vector unsigned int aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]}; + vector unsigned int aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]}; + q3 += 16; + + vector unsigned long long aux64x2_0 = {(uint64_t)(signs64[(signs[0] >> 0) & 127]), (uint64_t)(signs64[(signs[0] >> 7) & 127])}; + vector unsigned long long aux64x2_1 = {(uint64_t)(signs64[(signs[0] >> 14) & 127]), (uint64_t)(signs64[(signs[0] >> 21) & 127])}; + vector unsigned long long aux64x2_2 = {(uint64_t)(signs64[(signs[1] >> 0) & 127]), (uint64_t)(signs64[(signs[1] >> 7) & 127])}; + vector unsigned long long aux64x2_3 = {(uint64_t)(signs64[(signs[1] >> 14) & 127]), (uint64_t)(signs64[(signs[1] >> 21) & 127])}; + + vector signed char q3x0 = vec_mul((vector signed char)aux64x2_0, (vector signed char)aux32x4_0); + vector signed char q3x1 = vec_mul((vector signed char)aux64x2_1, (vector signed char)aux32x4_1); + vector signed char q3x2 = vec_mul((vector signed char)aux64x2_2, (vector signed char)aux32x4_2); + vector signed char q3x3 = vec_mul((vector signed char)aux64x2_3, (vector signed char)aux32x4_3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3)); + + const uint16_t ls0 = (uint16_t)(signs[0] >> 28); + const uint16_t ls1 = (uint16_t)(signs[1] >> 28); + signs += 2; + + vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); + vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); + + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = 0.25f * vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + __m256 accumf = (__m256)__lasx_xvldi(0); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q2_1 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + const __m256i q2_2 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + memcpy(aux32, gas, 8); gas += 8; + + const __m256i s2_1 = lasx_set_d(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127], + signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); + const __m256i s2_2 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1); + const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2); + const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); + const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); + const uint16_t ls1 = aux32[0] >> 28; + const uint16_t ls2 = aux32[1] >> 28; + + const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); + const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); + sumi1 = __lasx_xvadd_w(sumi1, p1); + sumi2 = __lasx_xvadd_w(sumi2, p2); + } + + accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); + } + + *s = 0.25f * hsum_float_8(accumf); + +#else + + uint32_t aux32; + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t); + const uint32_t ls = 2*(aux32 >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]); + const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]); + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + q3 += 8; + bsum += sumi * ls; + } + sumf += d * bsum; + } + *s = 0.25f * sumf; +#endif +} + +void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_s * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + typedef union { + uint16x8_t vec_index; + uint16_t index[8]; + } vec_index_t; + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + + static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; + + const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1); + const uint8x16_t mask2 = vld1q_u8(k_mask2); + + const int16x8_t hshift = vld1q_s16(k_shift); + const uint16x8_t m256 = vdupq_n_u16(256); + const uint8x16_t m1 = vdupq_n_u8(1); + + uint8x16x2_t vs; + ggml_int8x16x4_t q3s; + ggml_int8x16x4_t q8b; + vec_index_t idx; + + uint32_t scales32[2]; + const uint8_t * scales8 = (const uint8_t *)scales32; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)x[i].signs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(scales32, x[i].scales, 4); + scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; + scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; + + int sumi1 = 0, sumi2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + + const uint8x16_t idx_l = vld1q_u8(qs); qs += 16; + idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256)); + const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]], + iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]); + const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]], + iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]); + idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256)); + const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]], + iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]); + const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]], + iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]); + + + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16))); + vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); + vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1); + + q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0)); + q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1)); + + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16))); + vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); + vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1); + + signs += 4; + + q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2)); + q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3)); + + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); + + sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0]; + sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4]; + } + sumf += d*(sumi1 + sumi2); + } + *s = sumf; + +#elif defined(__AVX2__) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); + const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); + + const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8); + const __m256i idx_mask = _mm256_set1_epi32(256); + + typedef union { + __m256i vec[2]; + uint32_t index[16]; + } index_t; + + index_t idx; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)x[i].signs; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); qs += 16; + idx.vec[0] = _mm256_set1_epi32(qh[ib32+0]); + idx.vec[1] = _mm256_set1_epi32(qh[ib32+1]); + idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask); + idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask); + idx.vec[0] = _mm256_or_si256(idx.vec[0], _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l))); + idx.vec[1] = _mm256_or_si256(idx.vec[1], _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l, 1))); + + // At leat on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange. + //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4); + //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4); + const __m256i q2_1 = _mm256_set_epi32( + iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]], + iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]] + ); + const __m256i q2_2 = _mm256_set_epi32( + iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]], + iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]] + ); + + __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1); + + aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2); + + signs += 4; + + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; + const uint16_t ls2 = x[i].scales[ib32/2] >> 4; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = hsum_float_8(accumf); + +#elif defined(__AVX__) + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1); + const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1); + const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2); + const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1); + + const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256); + const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16); + const __m128i idx_mask = _mm_set1_epi32(256); + + typedef union { + __m128i vec[4]; + uint32_t index[16]; + } index_t; + + index_t idx; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)x[i].signs; + const int8_t * restrict q8 = y[i].qs; + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs); + const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp); + const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16; + idx.vec[0] = _mm_set1_epi32(qh[ib32+0]); + idx.vec[1] = idx.vec[0]; + idx.vec[2] = _mm_set1_epi32(qh[ib32+1]); + idx.vec[3] = idx.vec[2]; + + idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask); + idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask); + idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask); + idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask); + + idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0)); + idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8))); + idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1)); + idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8))); + + const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]); + const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]); + const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]); + const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]); + + __m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16)); + __m128i aux128_1 = aux128_0; + aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); + aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); + const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); + const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); + const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0); + const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1); + + aux128_0 = _mm_set1_epi32(signs[2] | (signs[3] << 16)); + aux128_1 = aux128_0; + aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); + aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); + const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); + const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); + const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0); + const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1); + + signs += 4; + + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; + const uint16_t ls2 = x[i].scales[ib32/2] >> 4; + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); + sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); + sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); + sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); + sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = hsum_float_8(accumf); + +#elif defined(__POWER9_VECTOR__) + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + + const vector int v0 = vec_splats((int32_t)0); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const vector unsigned char mask0 = vec_xl( 0, k_mask1); + const vector unsigned char mask1 = vec_xl(16, k_mask1); + const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].signs); + const uint8_t * restrict sc = x[i].scales; + const int8_t * restrict q8 = y[i].qs; + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q3, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector unsigned int aux32x4_0 = {iq3s_grid[q3[ 0] | ((qh[0] << 8) & 256)], iq3s_grid[q3[ 1] | ((qh[0] << 7) & 256)], + iq3s_grid[q3[ 2] | ((qh[0] << 6) & 256)], iq3s_grid[q3[ 3] | ((qh[0] << 5) & 256)]}; + vector unsigned int aux32x4_1 = {iq3s_grid[q3[ 4] | ((qh[0] << 4) & 256)], iq3s_grid[q3[ 5] | ((qh[0] << 3) & 256)], + iq3s_grid[q3[ 6] | ((qh[0] << 2) & 256)], iq3s_grid[q3[ 7] | ((qh[0] << 1) & 256)]}; + vector unsigned int aux32x4_2 = {iq3s_grid[q3[ 8] | ((qh[1] << 8) & 256)], iq3s_grid[q3[ 9] | ((qh[1] << 7) & 256)], + iq3s_grid[q3[10] | ((qh[1] << 6) & 256)], iq3s_grid[q3[11] | ((qh[1] << 5) & 256)]}; + vector unsigned int aux32x4_3 = {iq3s_grid[q3[12] | ((qh[1] << 4) & 256)], iq3s_grid[q3[13] | ((qh[1] << 3) & 256)], + iq3s_grid[q3[14] | ((qh[1] << 2) & 256)], iq3s_grid[q3[15] | ((qh[1] << 1) & 256)]}; + q3 += 16; + qh += 2; + + vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]); + vector signed char vsigns02 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]); + signs += 4; + + vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0); + vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1); + vector signed char vsigns2 = vec_perm(vsigns02, vsigns02, mask0); + vector signed char vsigns3 = vec_perm(vsigns02, vsigns02, mask1); + + vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2); + vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2); + vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2); + vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2); + + vector signed char q3x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux32x4_0), vsigns0); + vector signed char q3x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux32x4_1), vsigns1); + vector signed char q3x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux32x4_2), vsigns2); + vector signed char q3x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux32x4_3), vsigns3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3)); + + const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); + const uint16_t ls1 = (uint16_t)(sc[0] >> 4); + sc ++; + + vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); + vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); + + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0); + const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0); + + __m256i idx_shift = lasx_set_w(1, 2, 3, 4, 5, 6, 7, 8); + const __m256i idx_mask = __lasx_xvreplgr2vr_w(256); + + typedef union { + __m256i vec[2]; + uint32_t index[16]; + } index_t; + + index_t idx; + + __m256 accumf = (__m256)__lasx_xvldi(0); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)x[i].signs; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i idx_l = lasx_extu8_16(__lsx_vld(qs, 0)); qs += 16; + idx.vec[0] = __lasx_xvreplgr2vr_w(qh[ib32+0]); + idx.vec[1] = __lasx_xvreplgr2vr_w(qh[ib32+1]); + idx.vec[0] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[0], idx_shift), idx_mask); + idx.vec[1] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[1], idx_shift), idx_mask); + idx.vec[0] = __lasx_xvor_v(idx.vec[0], lasx_ext16_32(lasx_extracti128(idx_l, 0))); + idx.vec[1] = __lasx_xvor_v(idx.vec[1], lasx_ext16_32(lasx_extracti128(idx_l, 1))); + + // At leat on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange. + //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4); + //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4); + const __m256i q2_1 = lasx_set_w( + iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]], + iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]] + ); + const __m256i q2_2 = lasx_set_w( + iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]], + iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]] + ); + + __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | (signs[1] << 16)); + aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); + const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2); + const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1); + + aux256 = __lasx_xvreplgr2vr_w(signs[2] | (signs[3] << 16)); + aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); + const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2); + const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2); + + signs += 4; + + const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); + const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); + const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; + const uint16_t ls2 = x[i].scales[ib32/2] >> 4; + const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); + const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); + sumi1 = __lasx_xvadd_w(sumi1, p1); + sumi2 = __lasx_xvadd_w(sumi2, p2); + } + + accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); + } + + *s = hsum_float_8(accumf); + +#else + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint8_t * restrict signs = x[i].signs; + const int8_t * restrict q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1; + const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + qs += 8; + signs += 4; + bsum += sumi * ls1; + sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + qs += 8; + signs += 4; + bsum += sumi * ls2; + } + sumf += d * bsum; + } + *s = sumf; +#endif +} + +#if defined(__AVX2__) +static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { + const __m256i ax = _mm256_sign_epi8(x, x); + const __m256i sy = _mm256_sign_epi8(y, x); + return _mm256_maddubs_epi16(ax, sy); +} +#elif defined(__loongarch_asx) +static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { + const __m256i ax = __lasx_xvsigncov_b(x, x); + const __m256i sy = __lasx_xvsigncov_b(x, y); + __m256i tmp1, tmp2, tmp3; + tmp1 = __lasx_xvmulwev_h_bu_b(ax, sy); + tmp2 = __lasx_xvmulwod_h_bu_b(ax, sy); + tmp3 = __lasx_xvadd_h(tmp1, tmp2); + return __lasx_xvsat_h(tmp3, 15); +} +#endif + +void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined __ARM_NEON + + ggml_int8x16x4_t q1b; + ggml_int8x16x4_t q8b; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + int sumi1 = 0, sumi2 = 0, sumi3 = 0; + + for (int ib = 0; ib < QK_K/32; ib += 2) { + + q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[ib+0] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[ib+0] << 5) & 0x700))))); + q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[ib+0] << 2) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[ib+0] >> 1) & 0x700))))); + q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[ib+1] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[ib+1] << 5) & 0x700))))); + q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[ib+1] << 2) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[ib+1] >> 1) & 0x700))))); + qs += 8; + + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]); + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]); + + const int ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; + const int ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; + sumi1 += vaddvq_s32(p1) * ls1; + sumi2 += vaddvq_s32(p2) * ls2; + sumi3 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * ls1 * (qh[ib+0] & 0x8000 ? -1 : 1) + + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * ls2 * (qh[ib+1] & 0x8000 ? -1 : 1); + + } + + sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2 + IQ1S_DELTA * sumi3); + } + + *s = sumf; + +#elif defined __AVX2__ + + __m256 accum = _mm256_setzero_ps(); + float accum1 = 0; + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + __m256i sumi = _mm256_setzero_si256(); + int sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], + iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]); + const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], + iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]); + qs += 8; + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); + const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); + const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; + const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(ls1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(ls2)); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2)); + sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 + + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; + } + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + accum = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), accum); + accum1 += d * sumi1; + + } + + *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; + +#elif defined __AVX__ + __m256 accum = _mm256_setzero_ps(); + float accum1 = 0; + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + int sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q1b_1_0 = _mm_set_epi64x(iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]); + const __m128i q1b_1_1 = _mm_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)]); + const __m128i q1b_2_0 = _mm_set_epi64x(iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]); + const __m128i q1b_2_1 = _mm_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)]); + qs += 8; + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + + const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0); + const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1); + const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0); + const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1); + const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; + const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(ls1)); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(ls1)); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(ls2)); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(ls2)); + + sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0)); + sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1)); + sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 + + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; + } + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum); + accum1 += d * sumi1; + + } + + *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; + +#elif defined(__POWER9_VECTOR__) + const vector unsigned char v0 = vec_splats((unsigned char)0x0); + const vector unsigned short vsign = vec_splats((unsigned short)0x8000); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = vec_splats((int32_t)0); + vector signed int vsumi1 = vec_splats((int32_t)0); + vector signed int vsumi2 = vec_splats((int32_t)0); + vector signed int vsumi3 = vec_splats((int32_t)0); + vector signed int vsumi8 = vec_splats((int32_t)0); + + const uint8_t * restrict q1 = x[i].qs; + const uint16_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + const int16_t * restrict qs = y[i].bsums; + + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q1, 0, 1); + __builtin_prefetch(qh, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed long long aux64x2_0 = {*(const int64_t *)(iq1s_grid + (q1[0] | ((qh[0] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[1] | ((qh[0] << 5) & 0x700)))}; + vector signed long long aux64x2_1 = {*(const int64_t *)(iq1s_grid + (q1[2] | ((qh[0] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[3] | ((qh[0] >> 1) & 0x700)))}; + vector signed long long aux64x2_2 = {*(const int64_t *)(iq1s_grid + (q1[4] | ((qh[1] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[5] | ((qh[1] << 5) & 0x700)))}; + vector signed long long aux64x2_3 = {*(const int64_t *)(iq1s_grid + (q1[6] | ((qh[1] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[7] | ((qh[1] >> 1) & 0x700)))}; + q1 += 8; + + vector signed char q1x0 = (vector signed char)aux64x2_0; + vector signed char q1x1 = (vector signed char)aux64x2_1; + vector signed char q1x2 = (vector signed char)aux64x2_2; + vector signed char q1x3 = (vector signed char)aux64x2_3; + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q1x0, q8y0), vec_mulo(q1x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q1x1, q8y1), vec_mulo(q1x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q1x2, q8y2), vec_mulo(q1x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q1x3, q8y3), vec_mulo(q1x3, q8y3)); + + const uint16_t ls0 = (uint16_t)((qh[0] >> 12) & 7); + const uint16_t ls1 = (uint16_t)((qh[1] >> 12) & 7); + + vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); + vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); + vector signed short vscales = vec_sld(vscales23, vscales01, 8); + + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); + + vector signed short q8ysums = vec_xl_len(qs, 8); + qs += 4; + q8ysums = vec_mergeh(q8ysums, (vector signed short)v0); + + vector signed short qxh = (vector signed short)vec_sld(vec_splats(qh[1]), vec_splats(qh[0]), 8); + qh += 2; + vector __bool short vsel = vec_cmpge(qxh, (vector signed short)v0); + + vector signed short q8ysum = vec_sel((vector signed short)vec_xor((vector unsigned short)q8ysums, vsign), q8ysums, vsel); + + vsumi8 = vec_add(vec_mule(q8ysum, vscales), vsumi8); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + + vsumf0 = vec_madd(vec_ctf(vsumi8, 0), vec_mul(vd, vec_splats(IQ1S_DELTA)), vsumf0); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + + __m256 accum = (__m256)__lasx_xvldi(0); + float accum1 = 0; + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + __m256i sumi = __lasx_xvldi(0); + int sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ib += 2) { + __m256i q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)], 0); + q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], 1); + q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], 2); + q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], 3); + + __m256i q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)], 0); + q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], 1); + q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], 2); + q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], 3); + + qs += 8; + const __m256i q8b_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8b_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + + const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); + const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); + const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; + const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; + + __m256i tmp1, tmp5, tmp6; + tmp1 = __lasx_xvreplgr2vr_h(ls1); + tmp5 = __lasx_xvmulwev_w_h(dot1, tmp1); + tmp6 = __lasx_xvmulwod_w_h(dot1, tmp1); + const __m256i p1 = __lasx_xvadd_w(tmp5, tmp6); + + tmp1 = __lasx_xvreplgr2vr_h(ls2); + tmp5 = __lasx_xvmulwev_w_h(dot2, tmp1); + tmp6 = __lasx_xvmulwod_w_h(dot2, tmp1); + const __m256i p2 = __lasx_xvadd_w(tmp5, tmp6); + + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p1, p2)); + sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 + + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; + } + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), accum); + accum1 += d * sumi1; + } + + *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; + +#else + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + int sumi = 0, sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ++ib) { + const int ls = 2*((qh[ib] >> 12) & 7) + 1; + const int delta = qh[ib] & 0x8000 ? -1 : 1; + int lsum = 0; + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8))); + for (int j = 0; j < 8; ++j) { + lsum += q8[j] * grid[j]; + } + q8 += 8; + } + sumi += ls * lsum; + sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]); + qs += 4; + } + + sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; + +#endif +} + +void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + +#if defined __ARM_NEON + const int32x4_t mask = vdupq_n_s32(0x7); + const int32x4_t mone = vdupq_n_s32(1); + const int32x4_t mzero = vdupq_n_s32(0); + + ggml_int8x16x4_t deltas; + deltas.val[0] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(+1)); + deltas.val[1] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(+1)); + deltas.val[2] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(-1)); + deltas.val[3] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(-1)); + + ggml_int8x16x4_t q1b; + ggml_int8x16x4_t q8b; + + uint32_t aux32; + const uint8_t * aux8 = (const uint8_t *)&aux32; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + int32x4_t sumi1 = mzero; + int32x4_t sumi2 = mzero; + + for (int ib = 0; ib < QK_K/32; ib += 2) { + + q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 4) & 0x700))))); + q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[1] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[1] << 4) & 0x700))))); + q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[2] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[2] << 4) & 0x700))))); + q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[3] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[3] << 4) & 0x700))))); + + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + + const int32x4_t p1 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1])); + const int32x4_t p2 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3])); + const int32x4_t p12 = vpaddq_s32(p1, p2); + + const uint32_t * qh32 = (const uint32_t *)qh; // we are 4-byte aligned, so we can do that + aux32 = ((qh32[0] >> 3) & 0x01010101) | ((qh32[0] >> 6) & 0x02020202); + + const int32x4_t p3 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[0]], q8b.val[0]), ggml_vdotq_s32(mzero, deltas.val[aux8[1]], q8b.val[1])); + const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3])); + const int32x4_t p34 = vpaddq_s32(p3, p4); + + int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9); + + scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone); + + sumi1 = vmlaq_s32(sumi1, scales_4, p12); + sumi2 = vmlaq_s32(sumi2, scales_4, p34); + + qs += 8; qh += 4; + + } + + sumf += y[i].d * GGML_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2)); + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i mask = _mm256_set1_epi16(0x7); + const __m256i mone = _mm256_set1_epi16(1); + + __m256 accum1 = _mm256_setzero_ps(); + __m256 accum2 = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m256i q1b_1 = _mm256_set_epi64x( + iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)], + iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)] + ); + const __m256i q1b_2 = _mm256_set_epi64x( + iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)], + iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)] + ); + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); + const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); + + const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101, + qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + const __m256i delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101, + qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + + const __m256i dot3 = mul_add_epi8(delta1, q8b_1); + const __m256i dot4 = mul_add_epi8(delta2, q8b_2); + + __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0)); + __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6)); + + scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone); + scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone); + const __m256i p1 = _mm256_madd_epi16(dot1, scale1); + const __m256i p2 = _mm256_madd_epi16(dot2, scale2); + const __m256i p3 = _mm256_madd_epi16(dot3, scale1); + const __m256i p4 = _mm256_madd_epi16(dot4, scale2); + + sumi1 = _mm256_add_epi32(sumi1, _mm256_add_epi32(p1, p2)); + sumi2 = _mm256_add_epi32(sumi2, _mm256_add_epi32(p3, p4)); + + qs += 8; qh += 4; + } + + const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16)); + + accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1); + accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2); + } + + *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2); + +#elif defined __AVX__ + const __m128i mask = _mm_set1_epi16(0x7); + const __m128i mone = _mm_set1_epi16(1); + + __m256 accum1 = _mm256_setzero_ps(); + __m256 accum2 = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q1b_1_0 = _mm_set_epi64x( + iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]); + const __m128i q1b_1_1 = _mm_set_epi64x( + iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)]); + const __m128i q1b_2_0 = _mm_set_epi64x( + iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]); + const __m128i q1b_2_1 = _mm_set_epi64x( + iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)]); + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + + const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0); + const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1); + const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0); + const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1); + + const __m128i delta1_0 = _mm_set_epi64x(qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + const __m128i delta1_1 = _mm_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + const __m128i delta2_0 = _mm_set_epi64x(qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + const __m128i delta2_1 = _mm_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + + const __m128i dot3_0 = mul_add_epi8_sse(delta1_0, q8b_1_0); + const __m128i dot3_1 = mul_add_epi8_sse(delta1_1, q8b_1_1); + const __m128i dot4_0 = mul_add_epi8_sse(delta2_0, q8b_2_0); + const __m128i dot4_1 = mul_add_epi8_sse(delta2_1, q8b_2_1); + + __m128i scale1_0 = _mm_set1_epi16(sc[ib/2] >> 0); + __m128i scale1_1 = _mm_set1_epi16(sc[ib/2] >> 3); + __m128i scale2_0 = _mm_set1_epi16(sc[ib/2] >> 6); + __m128i scale2_1 = _mm_set1_epi16(sc[ib/2] >> 9); + + scale1_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_0, mask), 1), mone); + scale1_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_1, mask), 1), mone); + scale2_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_0, mask), 1), mone); + scale2_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_1, mask), 1), mone); + const __m128i p1_0 = _mm_madd_epi16(dot1_0, scale1_0); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, scale1_1); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, scale2_0); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, scale2_1); + const __m128i p3_0 = _mm_madd_epi16(dot3_0, scale1_0); + const __m128i p3_1 = _mm_madd_epi16(dot3_1, scale1_1); + const __m128i p4_0 = _mm_madd_epi16(dot4_0, scale2_0); + const __m128i p4_1 = _mm_madd_epi16(dot4_1, scale2_1); + + sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0)); + sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1)); + sumi2_0 = _mm_add_epi32(sumi2_0, _mm_add_epi32(p3_0, p4_0)); + sumi2_1 = _mm_add_epi32(sumi2_1, _mm_add_epi32(p3_1, p4_1)); + + qs += 8; qh += 4; + } + + const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16)); + + accum1 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum1); + accum2 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi2_1, sumi2_0))), accum2); + } + + *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2); + +#else + + int sum1[2], sum2[2], delta[4]; + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + int sumi1 = 0, sumi2 = 0; + for (int ib = 0; ib < QK_K/32; ++ib) { + delta[0] = qh[0] & 0x08 ? -1 : 1; + delta[1] = qh[0] & 0x80 ? -1 : 1; + delta[2] = qh[1] & 0x08 ? -1 : 1; + delta[3] = qh[1] & 0x80 ? -1 : 1; + sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0; + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700))); + int lsum1 = 0, lsum2 = 0; + for (int j = 0; j < 8; ++j) { + lsum1 += q8[j] * grid[j]; + lsum2 += q8[j]; + } + q8 += 8; + sum1[l/2] += lsum1; + sum2[l/2] += lsum2*delta[l]; + } + + const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1; + const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1; + + sumi1 += sum1[0] * ls1 + sum1[1] * ls2; + sumi2 += sum2[0] * ls1 + sum2[1] * ls2; + qs += 4; + qh += 2; + } + + sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; + +#endif +} + +void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK4_NL == 0); + static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); + + const block_iq4_nl * restrict x = vx; + const block_q8_0 * restrict y = vy; + + const int nb = n / QK4_NL; + + int ib = 0; + float sumf = 0; + +#if defined __ARM_NEON + const int8x16_t values = vld1q_s8(kvalues_iq4nl); + const uint8x16_t m4b = vdupq_n_u8(0x0f); + uint8x16x2_t q4bits; + int8x16x4_t q4b; + int8x16x4_t q8b; + int32x4_t prod_1, prod_2; + + for (; ib + 1 < nb; ib += 2) { + + q4bits.val[0] = vld1q_u8(x[ib + 0].qs); + q4bits.val[1] = vld1q_u8(x[ib + 1].qs); + q8b.val[0] = vld1q_s8(y[ib + 0].qs); + q8b.val[1] = vld1q_s8(y[ib + 0].qs + 16); + q8b.val[2] = vld1q_s8(y[ib + 1].qs); + q8b.val[3] = vld1q_s8(y[ib + 1].qs + 16); + + q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); + q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); + q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); + q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); + + prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); + prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); + + sumf += + GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) + + GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2); + } + +#elif defined __AVX2__ + + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + const __m256i mone = _mm256_set1_epi16(1); + + __m256 accum1 = _mm256_setzero_ps(); + __m256 accum2 = _mm256_setzero_ps(); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs); + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs); + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs); + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs); + const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); + const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b))); + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const __m256i p_1 = _mm256_madd_epi16(p16_1, mone); + const __m256i p_2 = _mm256_madd_epi16(p16_2, mone); + accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)), + _mm256_cvtepi32_ps(p_1), accum1); + accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)), + _mm256_cvtepi32_ps(p_2), accum2); + } + + sumf = hsum_float_8(_mm256_add_ps(accum1, accum2)); + +#elif defined __AVX__ + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + + __m256 accum = _mm256_setzero_ps(); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs); + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs); + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1); + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1); + + const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)); + const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)); + const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)); + const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)); + + const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1); + const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d); + accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum); + } + + sumf = hsum_float_8(accum); + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + + const vector signed char values = vec_xl( 0, kvalues_iq4nl); + +#pragma GCC unroll 4 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + vector signed char q4x0 = vec_and(qxs, lowMask); + vector signed char q4x1 = vec_sr(qxs, v4); + + q4x0 = vec_perm(values, values, (vector unsigned char)q4x0); + q4x1 = vec_perm(values, values, (vector unsigned char)q4x1); + + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); + + vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi1 = vec_sum4s(qv1, vsumi1); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + } + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + +#elif defined (__loongarch_asx) + + const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0); + const __m128i m4b = __lsx_vreplgr2vr_b(0x0f); + const __m256i mone = __lasx_xvreplgr2vr_h(1); + + __m256 accum1 = (__m256)__lasx_xvldi(0); + __m256 accum2 = (__m256)__lasx_xvldi(0); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[ib + 0].qs, 0); + const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[ib + 1].qs, 0); + const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[ib + 0].qs, 0); + const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[ib + 1].qs, 0); + const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)), + lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b))); + const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)), + lsx_shuffle_b(values128, __lsx_vand_v(q4bits_2, m4b))); + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const __m256i p_1 = lasx_madd_h(p16_1, mone); + const __m256i p_2 = lasx_madd_h(p16_2, mone); + accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)), + __lasx_xvffint_s_w(p_1), accum1); + accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)), + __lasx_xvffint_s_w(p_2), accum2); + } + + sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2)); + +#endif + for (; ib < nb; ++ib) { + const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf]; + sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4]; + } + sumf += d * (sumi1 + sumi2); + } + *s = sumf; +} + +void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined __ARM_NEON + const int8x16_t values = vld1q_s8(kvalues_iq4nl); + const uint8x16_t m4b = vdupq_n_u8(0x0f); + ggml_uint8x16x2_t q4bits; + ggml_int8x16x4_t q4b; + ggml_int8x16x4_t q8b; + int32x4_t prod_1, prod_2; + + float sumf = 0; + + for (int ibl = 0; ibl < nb; ++ibl) { + + const int8_t * q8 = y[ibl].qs; + const uint8_t * q4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; + + int sumi1 = 0, sumi2 = 0; + for (int ib = 0; ib < QK_K/64; ++ib) { + + q4bits = ggml_vld1q_u8_x2(q4); q4 += 32; + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + + q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); + q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); + q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); + q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); + + prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); + prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); + + int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32; + int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; + h >>= 4; + sumi1 += vaddvq_s32(prod_1) * ls1; + sumi2 += vaddvq_s32(prod_2) * ls2; + + } + + sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + + __m256 accum = _mm256_setzero_ps(); + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + uint16_t sh = x[ibl].scales_h; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)qs); qs += 16; + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs); qs += 16; + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); + const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b))); + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; + sh >>= 4; + const __m256i p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1)); + const __m256i p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2)); + sumi1 = _mm256_add_epi32(p_1, sumi1); + sumi2 = _mm256_add_epi32(p_2, sumi2); + } + accum = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), + _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accum); + } + + *s = hsum_float_8(accum); + +#elif defined __AVX__ + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + + __m256 accum = _mm256_setzero_ps(); + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + uint16_t sh = x[ibl].scales_h; + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)qs); qs += 16; + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)qs); qs += 16; + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)); + const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)); + const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)); + const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)); + const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0); + const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1); + const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0); + const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1); + const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; + sh >>= 4; + const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, _mm_set1_epi16(ls1)); + const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, _mm_set1_epi16(ls1)); + const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, _mm_set1_epi16(ls2)); + const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, _mm_set1_epi16(ls2)); + sumi1_0 = _mm_add_epi32(p_1_0, sumi1_0); + sumi1_1 = _mm_add_epi32(p_1_1, sumi1_1); + sumi2_0 = _mm_add_epi32(p_2_0, sumi2_0); + sumi2_1 = _mm_add_epi32(p_2_1, sumi2_1); + } + __m128i sumi12_0 = _mm_add_epi32(sumi1_0, sumi2_0); + __m128i sumi12_1 = _mm_add_epi32(sumi1_1, sumi2_1); + accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), + _mm256_cvtepi32_ps(MM256_SET_M128I(sumi12_1, sumi12_0))), accum); + } + + *s = hsum_float_8(accum); + +#elif defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const vector signed char values = vec_xl( 0, kvalues_iq4nl); + + for (int ibl = 0; ibl < nb; ++ibl) { + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ibl].d)); + vector float vyd = vec_splats(y[ibl].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + uint16_t h = x[ibl].scales_h; + + const uint8_t * restrict q4 = x[ibl].qs; + const uint8_t * restrict sc = x[ibl].scales_l; + const int8_t * restrict q8 = y[ibl].qs; + + for (int ib = 0; ib < QK_K/64; ib ++ ) { + __builtin_prefetch(q4, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q4); + vector signed char qxs1 = (vector signed char)vec_xl(16, q4); + q4 += 32; + + vector signed char q4x00 = (vector signed char)vec_and(qxs0, lowMask); + vector signed char q4x01 = (vector signed char)vec_sr(qxs0, v4); + vector signed char q4x10 = (vector signed char)vec_and(qxs1, lowMask); + vector signed char q4x11 = (vector signed char)vec_sr(qxs1, v4); + + q4x00 = vec_perm(values, values, (vector unsigned char)q4x00); + q4x01 = vec_perm(values, values, (vector unsigned char)q4x01); + q4x10 = vec_perm(values, values, (vector unsigned char)q4x10); + q4x11 = vec_perm(values, values, (vector unsigned char)q4x11); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q4x00, q8y0), vec_mulo(q4x00, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q4x01, q8y1), vec_mulo(q4x01, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q4x10, q8y2), vec_mulo(q4x10, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q4x11, q8y3), vec_mulo(q4x11, q8y3)); + + const uint16_t ls0 = (uint16_t)(((sc[0] & 0xf) | ((h << 4) & 0x30)) - 32); + const uint16_t ls1 = (uint16_t)(((sc[0] >> 4) | ((h << 2) & 0x30)) - 32); + h >>= 4; + sc ++; + + vector signed short vscales01 = vec_splats((int16_t)ls0); + vector signed short vscales23 = vec_splats((int16_t)ls1); + + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + + const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0); + const __m128i m4b = __lsx_vreplgr2vr_b(0x0f); + + __m256 accum = (__m256)__lasx_xvldi(0); + __m256i tmp1; + __m128i tmp0, tmp2, tmp3, tmp4, mask_8f, mask; + + mask_8f = __lsx_vreplgr2vr_b(0x8f); + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + uint16_t sh = x[ibl].scales_h; + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + __m128i zero = __lsx_vldi(0); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16; + const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16; + const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b), mask_8f); + tmp0 = __lsx_vori_b(tmp2, 0x10); + mask = __lsx_vsle_b(zero, tmp2); + tmp3 = __lsx_vand_v(tmp0, mask); + tmp3 = __lsx_vshuf_b(values128, zero, tmp3); + + tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_1, m4b), mask_8f); + tmp0 = __lsx_vori_b(tmp2, 0x10); + mask = __lsx_vsle_b(zero, tmp2); + tmp4 = __lsx_vand_v(tmp0, mask); + tmp4 = __lsx_vshuf_b(values128, zero, tmp4); + + const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4); + + tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f); + tmp0 = __lsx_vori_b(tmp2, 0x10); + mask = __lsx_vsle_b(zero, tmp2); + tmp3 = __lsx_vand_v(tmp0, mask); + tmp3 = __lsx_vshuf_b(values128, zero, tmp3); + + tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_2, m4b), mask_8f); + tmp0 = __lsx_vori_b(tmp2, 0x10); + mask = __lsx_vsle_b(zero, tmp2); + tmp4 = __lsx_vand_v(tmp0, mask); + tmp4 = __lsx_vshuf_b(values128, zero, tmp4); + + const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4); + + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; + sh >>= 4; + __m256i tmp5, tmp6; + tmp1 = __lasx_xvreplgr2vr_h(ls1); + tmp5 = __lasx_xvmulwev_w_h(p16_1, tmp1); + tmp6 = __lasx_xvmulwod_w_h(p16_1, tmp1); + const __m256i p_1 = __lasx_xvadd_w(tmp5, tmp6); + tmp1 = __lasx_xvreplgr2vr_h(ls2); + tmp5 = __lasx_xvmulwev_w_h(p16_2, tmp1); + tmp6 = __lasx_xvmulwod_w_h(p16_2, tmp1); + const __m256i p_2 = __lasx_xvadd_w(tmp5, tmp6); + sumi1 = __lasx_xvadd_w(p_1, sumi1); + sumi2 = __lasx_xvadd_w(p_2, sumi2); + } + accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), + __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accum); + } + + *s = hsum_float_8(accum); + +#else + float sumf = 0; + for (int ibl = 0; ibl < nb; ++ibl) { + const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d; + uint16_t h = x[ibl].scales_h; + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + for (int ib = 0; ib < QK_K/32; ib += 2) { + const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30); + const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30); + h >>= 4; + const float d1 = d4d8*(ls1 - 32); + const float d2 = d4d8*(ls2 - 32); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d1 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + sumi1 = sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d2 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + } + } + *s = sumf; +#endif +} + +// ============================ 4-bit non-linear quants + +void quantize_row_iq4_nl(const float * restrict x, void * restrict y, int64_t k) { + assert(k % QK4_NL == 0); + quantize_row_iq4_nl_ref(x, y, k); +} + +void quantize_row_iq4_xs(const float * restrict x, void * restrict y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq4_xs(x, y, 1, k, NULL); +} diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.h b/ggml/src/ggml-cpu/ggml-cpu-quants.h new file mode 100644 index 000000000..e33d9d473 --- /dev/null +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.h @@ -0,0 +1,63 @@ +#pragma once + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" + +#include "ggml.h" + +// GGML CPU internal header + +#ifdef __cplusplus +extern "C" { +#endif + +// Quantization +void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); + +void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); + +void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); + +void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); + +// Dot product +void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq2_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-cpu/ggml-cpu-traits.cpp b/ggml/src/ggml-cpu/ggml-cpu-traits.cpp new file mode 100644 index 000000000..62a0712da --- /dev/null +++ b/ggml/src/ggml-cpu/ggml-cpu-traits.cpp @@ -0,0 +1,36 @@ +#include "ggml-cpu-traits.h" + +#include "ggml-backend-impl.h" +#include "ggml-backend.h" + +namespace ggml::cpu { +tensor_traits::~tensor_traits() {} + +extra_buffer_type::~extra_buffer_type() {} +} // namespace ggml::cpu + +bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) { + for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) { + if (extra && extra->context) { + auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context; + auto tensor_traits = buf_extra->get_tensor_traits(op); + if (tensor_traits && tensor_traits->compute_forward(params, op)) { + return true; + } + } + } + return false; +} + +bool ggml_cpu_extra_work_size(int n_threads, const struct ggml_tensor * op, size_t * size) { + for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) { + if (extra && extra->context) { + auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context; + auto tensor_traits = buf_extra->get_tensor_traits(op); + if (tensor_traits && tensor_traits->work_size(n_threads, op, *size)) { + return true; + } + } + } + return false; +} diff --git a/ggml/src/ggml-cpu/ggml-cpu-traits.h b/ggml/src/ggml-cpu/ggml-cpu-traits.h new file mode 100644 index 000000000..99a6186b1 --- /dev/null +++ b/ggml/src/ggml-cpu/ggml-cpu-traits.h @@ -0,0 +1,38 @@ +#pragma once +#include "ggml-backend-impl.h" +#include "ggml-cpu-impl.h" +#include "ggml.h" + +#ifdef __cplusplus +# include +extern "C" { +#endif + +// return true if op part of extra "accelerator" +bool ggml_cpu_extra_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op); +bool ggml_cpu_extra_work_size(int n_threads, const struct ggml_tensor * op, size_t * size); + +#ifdef __cplusplus +} + +namespace ggml::cpu { +// register in tensor->extra +class tensor_traits { + public: + virtual ~tensor_traits(); + virtual bool work_size(int n_threads, const struct ggml_tensor * op, size_t & size) = 0; + virtual bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) = 0; +}; + +class extra_buffer_type { + public: + virtual ~extra_buffer_type(); + virtual bool supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) = 0; + virtual tensor_traits * get_tensor_traits(const struct ggml_tensor * op) = 0; +}; +} // namespace ggml::cpu + +// implemented in ggml-cpu.cpp. +std::vector & ggml_backend_cpu_get_extra_buffers_type(); + +#endif diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c new file mode 100644 index 000000000..e809f05d2 --- /dev/null +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -0,0 +1,14392 @@ +#define _CRT_SECURE_NO_DEPRECATE // Disables "unsafe" warnings on Windows +#define _USE_MATH_DEFINES // For M_PI on MSVC + +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-cpu-traits.h" +#include "ggml-cpu-impl.h" +#include "ggml-cpu.h" +#include "ggml-impl.h" +#include "ggml-quants.h" +#include "ggml-cpu-quants.h" +#include "ggml-threading.h" +#include "amx/amx.h" +#include "ggml.h" + +#if defined(_MSC_VER) || defined(__MINGW32__) +#include // using malloc.h with MSC/MINGW +#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if defined(__gnu_linux__) +#include +#endif + +#ifdef GGML_USE_OPENMP +#include +#endif + +#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8) +#undef GGML_USE_LLAMAFILE +#endif + +#ifdef GGML_USE_LLAMAFILE +#include "llamafile/sgemm.h" +#endif + +#if defined(_MSC_VER) +// disable "possible loss of data" to avoid hundreds of casts +// we should just be careful :) +#pragma warning(disable: 4244 4267) + +// disable POSIX deprecation warnings +// these functions are never going away, anyway +#pragma warning(disable: 4996) + +// unreachable code because of multiple instances of code after GGML_ABORT +#pragma warning(disable: 4702) +#endif + +// Note: once we move threading into a separate C++ file +// will use std::hardware_destructive_interference_size instead of hardcoding it here +// and we'll use C++ attribute syntax. +#define GGML_CACHE_LINE 64 + +#if defined(__clang__) || defined(__GNUC__) +#define GGML_CACHE_ALIGN __attribute__((aligned(GGML_CACHE_LINE))) +#endif + +#if defined(__has_feature) +#if __has_feature(thread_sanitizer) +#define GGML_TSAN_ENABLED 1 +#endif +#else // __has_feature +#if defined(__SANITIZE_THREAD__) +#define GGML_TSAN_ENABLED 1 +#endif +#endif // __has_feature + +#define UNUSED GGML_UNUSED +#define SWAP(x, y, T) do { T SWAP = x; (x) = y; (y) = SWAP; } while (0) + +#if defined(GGML_USE_ACCELERATE) +#include +#endif + +// floating point type used to accumulate sums +typedef double ggml_float; + +#define GGML_GELU_FP16 +#define GGML_GELU_QUICK_FP16 + +#define GGML_SOFT_MAX_UNROLL 4 +#define GGML_VEC_DOT_UNROLL 2 +#define GGML_VEC_MAD_UNROLL 32 + +// +// global data +// + +// precomputed gelu table for f16 (128 KB) +static ggml_fp16_t ggml_table_gelu_f16[1 << 16]; + +// precomputed quick gelu table for f16 (128 KB) +static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16]; + +#if defined(__ARM_ARCH) +struct ggml_arm_arch_features_type { + int has_neon; + int has_dotprod; + int has_i8mm; + int has_sve; + int sve_cnt; +} ggml_arm_arch_features = {-1, -1, -1, -1, 0}; +#endif + + +#if defined(_WIN32) + +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX + #define NOMINMAX +#endif +#include + +#if defined(_MSC_VER) && !defined(__clang__) +#define GGML_CACHE_ALIGN __declspec(align(GGML_CACHE_LINE)) + +typedef volatile LONG atomic_int; +typedef atomic_int atomic_bool; +typedef atomic_int atomic_flag; + +#define ATOMIC_FLAG_INIT 0 + +typedef enum { + memory_order_relaxed, + memory_order_consume, + memory_order_acquire, + memory_order_release, + memory_order_acq_rel, + memory_order_seq_cst +} memory_order; + +static void atomic_store(atomic_int * ptr, LONG val) { + InterlockedExchange(ptr, val); +} +static void atomic_store_explicit(atomic_int * ptr, LONG val, memory_order mo) { + // TODO: add support for explicit memory order + InterlockedExchange(ptr, val); +} +static LONG atomic_load(atomic_int * ptr) { + return InterlockedCompareExchange(ptr, 0, 0); +} +static LONG atomic_load_explicit(atomic_int * ptr, memory_order mo) { + // TODO: add support for explicit memory order + return InterlockedCompareExchange(ptr, 0, 0); +} +static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) { + return InterlockedExchangeAdd(ptr, inc); +} +static LONG atomic_fetch_add_explicit(atomic_int * ptr, LONG inc, memory_order mo) { + // TODO: add support for explicit memory order + return InterlockedExchangeAdd(ptr, inc); +} +static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) { + return InterlockedExchange(ptr, 1); +} +static void atomic_flag_clear(atomic_flag * ptr) { + InterlockedExchange(ptr, 0); +} +static void atomic_thread_fence(memory_order mo) { + MemoryBarrier(); +} +#else // clang +#include +#endif + +typedef HANDLE pthread_t; + +typedef DWORD thread_ret_t; +static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) { + (void) unused; + HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL); + if (handle == NULL) + { + return EAGAIN; + } + + *out = handle; + return 0; +} + +static int pthread_join(pthread_t thread, void * unused) { + (void) unused; + int ret = (int) WaitForSingleObject(thread, INFINITE); + CloseHandle(thread); + return ret; +} + +static int sched_yield (void) { + Sleep (0); + return 0; +} +#else + +#include +#include +#include +#if defined(__FreeBSD__) +#include +#endif + +typedef void * thread_ret_t; + +#include +#include +#include + +#endif + +typedef pthread_t ggml_thread_t; + +#if defined(__APPLE__) +#include +#include +#include +#endif + +// +// cache line +// + +#if defined(__cpp_lib_hardware_interference_size) +#define CACHE_LINE_SIZE hardware_destructive_interference_size +#else +#if defined(__POWER9_VECTOR__) +#define CACHE_LINE_SIZE 128 +#else +#define CACHE_LINE_SIZE 64 +#endif +#endif + +static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); + + +static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc); +static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc); +static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc); + +static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { + [GGML_TYPE_F32] = { + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, + .vec_dot_type = GGML_TYPE_F32, + .nrows = 1, + }, + [GGML_TYPE_F16] = { + .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row, + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16, + .vec_dot_type = GGML_TYPE_F16, + .nrows = 1, + }, + [GGML_TYPE_Q4_0] = { + .from_float = quantize_row_q4_0, + .vec_dot = ggml_vec_dot_q4_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, +#if defined (__ARM_FEATURE_MATMUL_INT8) + .nrows = 2, +#else + .nrows = 1, +#endif + }, + [GGML_TYPE_Q4_1] = { + .from_float = quantize_row_q4_1, + .vec_dot = ggml_vec_dot_q4_1_q8_1, + .vec_dot_type = GGML_TYPE_Q8_1, +#if defined (__ARM_FEATURE_MATMUL_INT8) + .nrows = 2, +#else + .nrows = 1, +#endif + }, + [GGML_TYPE_Q5_0] = { + .from_float = quantize_row_q5_0, + .vec_dot = ggml_vec_dot_q5_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + }, + [GGML_TYPE_Q5_1] = { + .from_float = quantize_row_q5_1, + .vec_dot = ggml_vec_dot_q5_1_q8_1, + .vec_dot_type = GGML_TYPE_Q8_1, + .nrows = 1, + }, + [GGML_TYPE_Q8_0] = { + .from_float = quantize_row_q8_0, + .vec_dot = ggml_vec_dot_q8_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, +#if defined (__ARM_FEATURE_MATMUL_INT8) + .nrows = 2, +#else + .nrows = 1, +#endif + }, + [GGML_TYPE_Q8_1] = { + .from_float = quantize_row_q8_1, + .vec_dot_type = GGML_TYPE_Q8_1, + .nrows = 1, + }, + [GGML_TYPE_Q2_K] = { + .from_float = quantize_row_q2_K, + .vec_dot = ggml_vec_dot_q2_K_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, + [GGML_TYPE_Q3_K] = { + .from_float = quantize_row_q3_K, + .vec_dot = ggml_vec_dot_q3_K_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, + [GGML_TYPE_Q4_K] = { + .from_float = quantize_row_q4_K, + .vec_dot = ggml_vec_dot_q4_K_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, + [GGML_TYPE_Q5_K] = { + .from_float = quantize_row_q5_K, + .vec_dot = ggml_vec_dot_q5_K_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, + [GGML_TYPE_Q6_K] = { + .from_float = quantize_row_q6_K, + .vec_dot = ggml_vec_dot_q6_K_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, + [GGML_TYPE_IQ2_XXS] = { + .from_float = NULL, + .vec_dot = ggml_vec_dot_iq2_xxs_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, + [GGML_TYPE_IQ2_XS] = { + .from_float = NULL, + .vec_dot = ggml_vec_dot_iq2_xs_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, + [GGML_TYPE_IQ3_XXS] = { + // NOTE: from_float for iq3 and iq2_s was removed because these quants require initialization in ggml_quantize_init + //.from_float = quantize_row_iq3_xxs, + .vec_dot = ggml_vec_dot_iq3_xxs_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, + [GGML_TYPE_IQ3_S] = { + //.from_float = quantize_row_iq3_s, + .vec_dot = ggml_vec_dot_iq3_s_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, + [GGML_TYPE_IQ2_S] = { + //.from_float = quantize_row_iq2_s, + .vec_dot = ggml_vec_dot_iq2_s_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, + [GGML_TYPE_IQ1_S] = { + .from_float = NULL, + .vec_dot = ggml_vec_dot_iq1_s_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, + [GGML_TYPE_IQ1_M] = { + .from_float = NULL, + .vec_dot = ggml_vec_dot_iq1_m_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, + [GGML_TYPE_IQ4_NL] = { + .from_float = quantize_row_iq4_nl, + .vec_dot = ggml_vec_dot_iq4_nl_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + }, + [GGML_TYPE_IQ4_XS] = { + .from_float = quantize_row_iq4_xs, + .vec_dot = ggml_vec_dot_iq4_xs_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, + [GGML_TYPE_Q8_K] = { + .from_float = quantize_row_q8_K, + }, + [GGML_TYPE_BF16] = { + .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row, + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16, + .vec_dot_type = GGML_TYPE_BF16, + .nrows = 1, + }, + [GGML_TYPE_TQ1_0] = { + .from_float = quantize_row_tq1_0, + .vec_dot = ggml_vec_dot_tq1_0_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, + [GGML_TYPE_TQ2_0] = { + .from_float = quantize_row_tq2_0, + .vec_dot = ggml_vec_dot_tq2_0_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, +}; + +const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) { + return &type_traits_cpu[type]; +} + +// +// simd mappings +// + +// we define a common set of C macros which map to specific intrinsics based on the current architecture +// we then implement the fundamental computation operations below using only these macros +// adding support for new architectures requires to define the corresponding SIMD macros +// +// GGML_F32_STEP / GGML_F16_STEP +// number of elements to process in a single step +// +// GGML_F32_EPR / GGML_F16_EPR +// number of elements to fit in a single register +// + +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) + +#define GGML_SIMD + +// F32 NEON + +#define GGML_F32_STEP 16 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 float32x4_t +#define GGML_F32x4_ZERO vdupq_n_f32(0.0f) +#define GGML_F32x4_SET1(x) vdupq_n_f32(x) +#define GGML_F32x4_LOAD vld1q_f32 +#define GGML_F32x4_STORE vst1q_f32 +#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c) +#define GGML_F32x4_ADD vaddq_f32 +#define GGML_F32x4_MUL vmulq_f32 +#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \ + } \ + (res) = (ggml_float) GGML_F32x4_REDUCE_ONE((x)[0]); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 NEON + +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + #define GGML_F16_STEP 32 + #define GGML_F16_EPR 8 + + #define GGML_F16x8 float16x8_t + #define GGML_F16x8_ZERO vdupq_n_f16(0.0f) + #define GGML_F16x8_SET1(x) vdupq_n_f16(x) + #define GGML_F16x8_LOAD(x) vld1q_f16((const ggml_fp16_internal_t *)(x)) + #define GGML_F16x8_STORE vst1q_f16 + #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) + #define GGML_F16x8_ADD vaddq_f16 + #define GGML_F16x8_MUL vmulq_f16 + #define GGML_F16x8_REDUCE(res, x) \ + do { \ + int offset = GGML_F16_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ + } \ + const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \ + const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \ + (res) = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \ + } while (0) + + #define GGML_F16_VEC GGML_F16x8 + #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO + #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 + #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), (r)[i]) + #define GGML_F16_VEC_FMA GGML_F16x8_FMA + #define GGML_F16_VEC_ADD GGML_F16x8_ADD + #define GGML_F16_VEC_MUL GGML_F16x8_MUL + #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE +#else + // if FP16 vector arithmetic is not supported, we use FP32 instead + // and take advantage of the vcvt_ functions to convert to/from FP16 + + #define GGML_F16_STEP 16 + #define GGML_F16_EPR 4 + + #define GGML_F32Cx4 float32x4_t + #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f) + #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x) + #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const ggml_fp16_internal_t *)(x))) + #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y)) + #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) + #define GGML_F32Cx4_ADD vaddq_f32 + #define GGML_F32Cx4_MUL vmulq_f32 + #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE + + #define GGML_F16_VEC GGML_F32Cx4 + #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO + #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 + #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i]) + #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA + #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD + #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL + #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE +#endif + +#elif defined(__AVX512F__) + +#define GGML_SIMD + +// F32 AVX512 + +#define GGML_F32_STEP 64 +#define GGML_F32_EPR 16 + +#define GGML_F32x16 __m512 +#define GGML_F32x16_ZERO _mm512_setzero_ps() +#define GGML_F32x16_SET1(x) _mm512_set1_ps(x) +#define GGML_F32x16_LOAD _mm512_loadu_ps +#define GGML_F32x16_STORE _mm512_storeu_ps +// _mm512_fmadd_ps is defined in AVX512F so no guard is required +#define GGML_F32x16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a) +#define GGML_F32x16_ADD _mm512_add_ps +#define GGML_F32x16_MUL _mm512_mul_ps +#define GGML_F32x16_REDUCE(res, x) \ +do { \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm512_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm512_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm512_add_ps(x[i], x[offset+i]); \ + } \ + res = (ggml_float) _mm512_reduce_add_ps(x[0]); \ +} while (0) + +// TODO: is this optimal ? + +#define GGML_F32_VEC GGML_F32x16 +#define GGML_F32_VEC_ZERO GGML_F32x16_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x16_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x16_LOAD +#define GGML_F32_VEC_STORE GGML_F32x16_STORE +#define GGML_F32_VEC_FMA GGML_F32x16_FMA +#define GGML_F32_VEC_ADD GGML_F32x16_ADD +#define GGML_F32_VEC_MUL GGML_F32x16_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x16_REDUCE + +// F16 AVX512 + +// F16 AVX + +#define GGML_F16_STEP 64 +#define GGML_F16_EPR 16 + +// AVX512 has FP16 extension (AVX512_FP16) but I don't have it on my machine so I use FP32 instead + +#define GGML_F32Cx16 __m512 +#define GGML_F32Cx16_ZERO _mm512_setzero_ps() +#define GGML_F32Cx16_SET1(x) _mm512_set1_ps(x) + +// unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F +// so F16C guard isn't required +#define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x))) +#define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0)) + +#define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a) +#define GGML_F32Cx16_ADD _mm512_add_ps +#define GGML_F32Cx16_MUL _mm512_mul_ps +#define GGML_F32Cx16_REDUCE(res, x) \ +do { \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm512_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm512_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm512_add_ps(x[i], x[offset+i]); \ + } \ + res = (ggml_float) _mm512_reduce_add_ps(x[0]); \ +} while (0) + +#define GGML_F16_VEC GGML_F32Cx16 +#define GGML_F16_VEC_ZERO GGML_F32Cx16_ZERO +#define GGML_F16_VEC_SET1 GGML_F32Cx16_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx16_LOAD(p) +#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx16_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F32Cx16_FMA +#define GGML_F16_VEC_ADD GGML_F32Cx16_ADD +#define GGML_F16_VEC_MUL GGML_F32Cx16_MUL + +#define GGML_F16_VEC_REDUCE GGML_F32Cx16_REDUCE +#elif defined(__AVX__) + +#define GGML_SIMD + +// F32 AVX + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 8 + +#define GGML_F32x8 __m256 +#define GGML_F32x8_ZERO _mm256_setzero_ps() +#define GGML_F32x8_SET1(x) _mm256_set1_ps(x) +#define GGML_F32x8_LOAD _mm256_loadu_ps +#define GGML_F32x8_STORE _mm256_storeu_ps +#if defined(__FMA__) + #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a) +#else + #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a) +#endif +#define GGML_F32x8_ADD _mm256_add_ps +#define GGML_F32x8_MUL _mm256_mul_ps +#define GGML_F32x8_REDUCE(res, x) \ +do { \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm256_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm256_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm256_add_ps(x[i], x[offset+i]); \ + } \ + const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \ + _mm256_extractf128_ps(x[0], 1)); \ + const __m128 t1 = _mm_hadd_ps(t0, t0); \ + res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \ +} while (0) +// TODO: is this optimal ? + +#define GGML_F32_VEC GGML_F32x8 +#define GGML_F32_VEC_ZERO GGML_F32x8_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x8_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x8_LOAD +#define GGML_F32_VEC_STORE GGML_F32x8_STORE +#define GGML_F32_VEC_FMA GGML_F32x8_FMA +#define GGML_F32_VEC_ADD GGML_F32x8_ADD +#define GGML_F32_VEC_MUL GGML_F32x8_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE + +// F16 AVX + +#define GGML_F16_STEP 32 +#define GGML_F16_EPR 8 + +// F16 arithmetic is not supported by AVX, so we use F32 instead + +#define GGML_F32Cx8 __m256 +#define GGML_F32Cx8_ZERO _mm256_setzero_ps() +#define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x) + +#if defined(__F16C__) +// the _mm256_cvt intrinsics require F16C +#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) +#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) +#else +static inline __m256 __avx_f32cx8_load(const ggml_fp16_t * x) { + float tmp[8]; + + for (int i = 0; i < 8; i++) { + tmp[i] = GGML_FP16_TO_FP32(x[i]); + } + + return _mm256_loadu_ps(tmp); +} +static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) { + float arr[8]; + + _mm256_storeu_ps(arr, y); + + for (int i = 0; i < 8; i++) + x[i] = GGML_FP32_TO_FP16(arr[i]); +} +#define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x) +#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y) +#endif + +#define GGML_F32Cx8_FMA GGML_F32x8_FMA +#define GGML_F32Cx8_ADD _mm256_add_ps +#define GGML_F32Cx8_MUL _mm256_mul_ps +#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE + +#define GGML_F16_VEC GGML_F32Cx8 +#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO +#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p) +#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA +#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD +#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL +#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE + +#elif defined(__POWER9_VECTOR__) + +#define GGML_SIMD + +// F32 POWER9 + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 vector float +#define GGML_F32x4_ZERO 0.0f +#define GGML_F32x4_SET1 vec_splats +#define GGML_F32x4_LOAD(p) vec_xl(0, p) +#define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p) +#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a) +#define GGML_F32x4_ADD vec_add +#define GGML_F32x4_MUL vec_mul +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = vec_add(x[i], x[offset+i]); \ + } \ + res = vec_extract(x[0], 0) + \ + vec_extract(x[0], 1) + \ + vec_extract(x[0], 2) + \ + vec_extract(x[0], 3); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 POWER9 +#define GGML_F16_STEP GGML_F32_STEP +#define GGML_F16_EPR GGML_F32_EPR +#define GGML_F16_VEC GGML_F32x4 +#define GGML_F16_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F16_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F16_VEC_FMA GGML_F32x4_FMA +#define GGML_F16_VEC_ADD GGML_F32x4_ADD +#define GGML_F16_VEC_MUL GGML_F32x4_MUL +#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE +// Use vec_xl, not vec_ld, in case the load address is not aligned. +#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \ + vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \ + vec_extract_fp32_from_shortl(vec_xl(0, p)) +#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i] +#define GGML_F16_VEC_STORE(p, r, i) \ + if (i & 0x1) \ + vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)], \ + r[i - GGML_ENDIAN_BYTE(0)]), \ + 0, p - GGML_F16_EPR) + +#elif defined(__wasm_simd128__) + +#define GGML_SIMD + +// F32 WASM + +#define GGML_F32_STEP 16 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 v128_t +#define GGML_F32x4_ZERO wasm_f32x4_splat(0.0f) +#define GGML_F32x4_SET1(x) wasm_f32x4_splat(x) +#define GGML_F32x4_LOAD wasm_v128_load +#define GGML_F32x4_STORE wasm_v128_store +#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a) +#define GGML_F32x4_ADD wasm_f32x4_add +#define GGML_F32x4_MUL wasm_f32x4_mul +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + res = wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 WASM + +#define GGML_F16_STEP 16 +#define GGML_F16_EPR 4 + +inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) { + float tmp[4]; + + tmp[0] = GGML_FP16_TO_FP32(p[0]); + tmp[1] = GGML_FP16_TO_FP32(p[1]); + tmp[2] = GGML_FP16_TO_FP32(p[2]); + tmp[3] = GGML_FP16_TO_FP32(p[3]); + + return wasm_v128_load(tmp); +} + +inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) { + float tmp[4]; + + wasm_v128_store(tmp, x); + + p[0] = GGML_FP32_TO_FP16(tmp[0]); + p[1] = GGML_FP32_TO_FP16(tmp[1]); + p[2] = GGML_FP32_TO_FP16(tmp[2]); + p[3] = GGML_FP32_TO_FP16(tmp[3]); +} + +#define GGML_F16x4 v128_t +#define GGML_F16x4_ZERO wasm_f32x4_splat(0.0f) +#define GGML_F16x4_SET1(x) wasm_f32x4_splat(x) +#define GGML_F16x4_LOAD(x) __wasm_f16x4_load(x) +#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y) +#define GGML_F16x4_FMA GGML_F32x4_FMA +#define GGML_F16x4_ADD wasm_f32x4_add +#define GGML_F16x4_MUL wasm_f32x4_mul +#define GGML_F16x4_REDUCE(res, x) \ +{ \ + int offset = GGML_F16_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ + } \ + res = wasm_f32x4_extract_lane(x[0], 0) + \ + wasm_f32x4_extract_lane(x[0], 1) + \ + wasm_f32x4_extract_lane(x[0], 2) + \ + wasm_f32x4_extract_lane(x[0], 3); \ +} + +#define GGML_F16_VEC GGML_F16x4 +#define GGML_F16_VEC_ZERO GGML_F16x4_ZERO +#define GGML_F16_VEC_SET1 GGML_F16x4_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F16x4_LOAD(p) +#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F16x4_FMA +#define GGML_F16_VEC_ADD GGML_F16x4_ADD +#define GGML_F16_VEC_MUL GGML_F16x4_MUL +#define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE + +#elif defined(__SSE3__) + +#define GGML_SIMD + +// F32 SSE + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 __m128 +#define GGML_F32x4_ZERO _mm_setzero_ps() +#define GGML_F32x4_SET1(x) _mm_set1_ps(x) +#define GGML_F32x4_LOAD _mm_loadu_ps +#define GGML_F32x4_STORE _mm_storeu_ps +#if defined(__FMA__) + // TODO: Does this work? + #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a) +#else + #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a) +#endif +#define GGML_F32x4_ADD _mm_add_ps +#define GGML_F32x4_MUL _mm_mul_ps +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm_add_ps(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm_add_ps(x[i], x[offset+i]); \ + } \ + const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \ + res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \ +} +// TODO: is this optimal ? + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 SSE + +#define GGML_F16_STEP 32 +#define GGML_F16_EPR 4 + +static inline __m128 __sse_f16x4_load(const ggml_fp16_t * x) { + float tmp[4]; + + tmp[0] = GGML_FP16_TO_FP32(x[0]); + tmp[1] = GGML_FP16_TO_FP32(x[1]); + tmp[2] = GGML_FP16_TO_FP32(x[2]); + tmp[3] = GGML_FP16_TO_FP32(x[3]); + + return _mm_loadu_ps(tmp); +} + +static inline void __sse_f16x4_store(ggml_fp16_t * x, __m128 y) { + float arr[4]; + + _mm_storeu_ps(arr, y); + + x[0] = GGML_FP32_TO_FP16(arr[0]); + x[1] = GGML_FP32_TO_FP16(arr[1]); + x[2] = GGML_FP32_TO_FP16(arr[2]); + x[3] = GGML_FP32_TO_FP16(arr[3]); +} + +#define GGML_F32Cx4 __m128 +#define GGML_F32Cx4_ZERO _mm_setzero_ps() +#define GGML_F32Cx4_SET1(x) _mm_set1_ps(x) +#define GGML_F32Cx4_LOAD(x) __sse_f16x4_load(x) +#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y) +#define GGML_F32Cx4_FMA GGML_F32x4_FMA +#define GGML_F32Cx4_ADD _mm_add_ps +#define GGML_F32Cx4_MUL _mm_mul_ps +#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE + +#define GGML_F16_VEC GGML_F32Cx4 +#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO +#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) +#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA +#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD +#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL +#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE + +#elif defined(__loongarch_asx) + +#define GGML_SIMD + +// F32 LASX +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 8 + +#define GGML_F32x8 __m256 +#define GGML_F32x8_ZERO (__m256)__lasx_xvldi(0) +#define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x)) +#define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0) +#define GGML_F32x8_STORE(x,y) __lasx_xvst((y), (x), 0) +#define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a) +#define GGML_F32x8_ADD __lasx_xvfadd_s +#define GGML_F32x8_MUL __lasx_xvfmul_s +#define GGML_F32x8_REDUCE(res, x) \ +do { \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \ + } \ + float *tmp_p = (float *)&x[0]; \ + res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7]; \ +} while (0) +// TODO: is this optimal ? + +#define GGML_F32_VEC GGML_F32x8 +#define GGML_F32_VEC_ZERO GGML_F32x8_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x8_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x8_LOAD +#define GGML_F32_VEC_STORE GGML_F32x8_STORE +#define GGML_F32_VEC_FMA GGML_F32x8_FMA +#define GGML_F32_VEC_ADD GGML_F32x8_ADD +#define GGML_F32_VEC_MUL GGML_F32x8_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE + +// F16 LASX + +#define GGML_F16_STEP 32 +#define GGML_F16_EPR 8 + +// F16 arithmetic is not supported by AVX, so we use F32 instead + +#define GGML_F32Cx8 __m256 +#define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0) +#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x)) + +static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) { + float tmp[8]; + + for (int i = 0; i < 8; i++) { + tmp[i] = GGML_FP16_TO_FP32(x[i]); + } + + return (__m256)__lasx_xvld(tmp, 0); +} +static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) { + float arr[8]; + + __lasx_xvst(y, arr, 0); + + for (int i = 0; i < 8; i++) { + x[i] = GGML_FP32_TO_FP16(arr[i]); + } +} +#define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x) +#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y) + +#define GGML_F32Cx8_FMA GGML_F32x8_FMA +#define GGML_F32Cx8_ADD __lasx_xvfadd_s +#define GGML_F32Cx8_MUL __lasx_xvfmul_s +#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE + +#define GGML_F16_VEC GGML_F32Cx8 +#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO +#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p) +#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA +#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD +#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL +#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE + +#elif defined(__loongarch_sx) + +#define GGML_SIMD + +// F32 LSX + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 __m128 +#define GGML_F32x4_ZERO __lsx_vldi(0) +#define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0) +#define GGML_F32x4_LOAD(x) __lsx_vld((x), 0) +#define GGML_F32x4_STORE((x),(y)) __lsx_vst((y), (x), 0) +#define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a) +#define GGML_F32x4_ADD __lsx_vfadd_s +#define GGML_F32x4_MUL __lsx_vfmul_s +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \ + } \ + __m128i tmp = __lsx_vsrli_d((__m128i) x[0], 32); \ + tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, x[0]); \ + tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \ + const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \ + tmp = __lsx_vsrli_d((__m128i) t0, 32); \ + tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, t0); \ + tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \ + res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 LSX + +#define GGML_F16_STEP 32 +#define GGML_F16_EPR 4 + +static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) { + float tmp[4]; + + tmp[0] = GGML_FP16_TO_FP32(x[0]); + tmp[1] = GGML_FP16_TO_FP32(x[1]); + tmp[2] = GGML_FP16_TO_FP32(x[2]); + tmp[3] = GGML_FP16_TO_FP32(x[3]); + + return __lsx_vld(tmp, 0); +} + +static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { + float arr[4]; + + __lsx_vst(y, arr, 0); + + x[0] = GGML_FP32_TO_FP16(arr[0]); + x[1] = GGML_FP32_TO_FP16(arr[1]); + x[2] = GGML_FP32_TO_FP16(arr[2]); + x[3] = GGML_FP32_TO_FP16(arr[3]); +} + +#define GGML_F32Cx4 __m128 +#define GGML_F32Cx4_ZERO __lsx_vldi(0) +#define GGML_F32Cx4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0) +#define GGML_F32Cx4_LOAD(x) __lsx_f16x4_load(x) +#define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y) +#define GGML_F32Cx4_FMA GGML_F32x4_FMA +#define GGML_F32Cx4_ADD __lsx_vfadd_s +#define GGML_F32Cx4_MUL __lsx_vfmul_s +#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE + +#define GGML_F16_VEC GGML_F32Cx4 +#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO +#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) +#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA +#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD +#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL +#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE + +#endif + +// GGML_F32_ARR / GGML_F16_ARR +// number of registers to use per step +#ifdef GGML_SIMD +#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR) +#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR) +#endif + +// +// Threading defs +// + +typedef pthread_t ggml_thread_t; + +#if defined(_WIN32) + +typedef CONDITION_VARIABLE ggml_cond_t; +typedef SRWLOCK ggml_mutex_t; + +#define ggml_mutex_init(m) InitializeSRWLock(m) +#define ggml_mutex_destroy(m) +#define ggml_mutex_lock(m) AcquireSRWLockExclusive(m) +#define ggml_mutex_unlock(m) ReleaseSRWLockExclusive(m) +#define ggml_mutex_lock_shared(m) AcquireSRWLockShared(m) +#define ggml_mutex_unlock_shared(m) ReleaseSRWLockShared(m) + +#define ggml_cond_init(c) InitializeConditionVariable(c) +#define ggml_cond_destroy(c) +#define ggml_cond_wait(c, m) SleepConditionVariableSRW(c, m, INFINITE, CONDITION_VARIABLE_LOCKMODE_SHARED) +#define ggml_cond_broadcast(c) WakeAllConditionVariable(c) + +#define ggml_thread_create pthread_create +#define ggml_thread_join pthread_join + +#else + +typedef pthread_cond_t ggml_cond_t; +typedef pthread_mutex_t ggml_mutex_t; + +#define ggml_mutex_init(m) pthread_mutex_init(m, NULL) +#define ggml_mutex_destroy(m) pthread_mutex_destroy(m) +#define ggml_mutex_lock(m) pthread_mutex_lock(m) +#define ggml_mutex_unlock(m) pthread_mutex_unlock(m) +#define ggml_mutex_lock_shared(m) pthread_mutex_lock(m) +#define ggml_mutex_unlock_shared(m) pthread_mutex_unlock(m) + +#define ggml_lock_init(x) UNUSED(x) +#define ggml_lock_destroy(x) UNUSED(x) +#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64)) +#define ggml_lock_lock(x) _mm_pause() +#else +#define ggml_lock_lock(x) UNUSED(x) +#endif +#define ggml_lock_unlock(x) UNUSED(x) + +#define GGML_LOCK_INITIALIZER 0 +#define ggml_cond_init(c) pthread_cond_init(c, NULL) +#define ggml_cond_destroy(c) pthread_cond_destroy(c) +#define ggml_cond_wait(c, m) pthread_cond_wait(c, m) +#define ggml_cond_broadcast(c) pthread_cond_broadcast(c) + +#define ggml_thread_create pthread_create +#define ggml_thread_join pthread_join + +#endif + +// Threadpool def +struct ggml_threadpool { + ggml_mutex_t mutex; // mutex for cond.var + ggml_cond_t cond; // cond.var for waiting for new work + + struct ggml_cgraph * cgraph; + struct ggml_cplan * cplan; + + // synchronization primitives + atomic_int n_graph; // incremented when there is work to be done (i.e each graph) + atomic_int GGML_CACHE_ALIGN n_barrier; + atomic_int GGML_CACHE_ALIGN n_barrier_passed; + atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads. + + // these are atomic as an annotation for thread-sanitizer + atomic_bool stop; // Used for stopping the threadpool altogether + atomic_bool pause; // Used for pausing the threadpool or individual threads + atomic_int abort; // Used for aborting processing of a graph + + struct ggml_compute_state * workers; // per thread state + int n_threads_max; // number of threads in the pool + atomic_int n_threads_cur; // number of threads used in the current graph + + int32_t prio; // Scheduling priority + uint32_t poll; // Polling level (0 - no polling) + + enum ggml_status ec; +}; + +// Per-thread state +struct ggml_compute_state { +#ifndef GGML_USE_OPENMP + ggml_thread_t thrd; + bool cpumask[GGML_MAX_N_THREADS]; + int last_graph; + bool pending; +#endif + struct ggml_threadpool * threadpool; + int ith; +}; + +// +// fundamental operations +// + +inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } +inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } + +inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } +inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } + +inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } +inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } +inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } +inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; } +inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } +inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } +inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } +inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } +inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } +inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } +inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } +inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } + +static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + +#if defined(GGML_SIMD) + float sumf = 0.0f; + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += x[i]*y[i]; + } +#else + // scalar + ggml_float sumf = 0.0; + for (int i = 0; i < n; ++i) { + sumf += (ggml_float)(x[i]*y[i]); + } +#endif + + *s = sumf; +} + +static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + int i = 0; + ggml_float sumf = 0; + +#if defined(__AVX512BF16__) + __m512 c1 = _mm512_setzero_ps(); + __m512 c2 = _mm512_setzero_ps(); + for (; i + 64 <= n; i += 64) { + c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))), + m512bh(_mm512_loadu_si512((y + i)))); + c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))), + m512bh(_mm512_loadu_si512((y + i + 32)))); + } + sumf += (ggml_float)_mm512_reduce_add_ps(c1); + sumf += (ggml_float)_mm512_reduce_add_ps(c2); + +#elif defined(__AVX512F__) +#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16)) + __m512 c1 = _mm512_setzero_ps(); + __m512 c2 = _mm512_setzero_ps(); + for (; i + 32 <= n; i += 32) { + c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1); + c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2); + } + sumf += (ggml_float)_mm512_reduce_add_ps(c1); + sumf += (ggml_float)_mm512_reduce_add_ps(c2); + +#undef LOAD +#elif defined(__AVX2__) || defined(__AVX__) +#if defined(__AVX2__) +#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)) +#else +#define LOAD(p) _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)), (_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_bsrli_si128(_mm_loadu_si128((const __m128i *)(p)), 8)), 16)), 1)) +#endif + __m256 c1 = _mm256_setzero_ps(); + __m256 c2 = _mm256_setzero_ps(); + __m256 c3 = _mm256_setzero_ps(); + __m256 c4 = _mm256_setzero_ps(); + for (; i + 32 <= n; i += 32) { + c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1); + c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2); + c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3); + c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4); + } + __m128 g; + c1 = _mm256_add_ps(_mm256_add_ps(c1, c3), + _mm256_add_ps(c2, c4)); + g = _mm_add_ps(_mm256_extractf128_ps(c1, 1), + _mm256_castps256_ps128(c1)); + g = _mm_add_ps(g, _mm_movehl_ps(g, g)); + g = _mm_add_ss(g, _mm_movehdup_ps(g)); + sumf += (ggml_float)_mm_cvtss_f32(g); + +#undef LOAD +#endif + + for (; i < n; ++i) { + sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) * + GGML_BF16_TO_FP32(y[i])); + } + *s = sumf; +} + +static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + ggml_float sumf = 0.0; + +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + + sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F16_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i])); + } +#else + for (int i = 0; i < n; ++i) { + sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i])); + } +#endif + + *s = sumf; +} + +// compute GGML_VEC_DOT_UNROLL dot products at once +// xs - x row stride in bytes +inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { + ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 }; + + ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL]; + + for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { + x[i] = (ggml_fp16_t *) ((char *) xv + i*xs); + } + +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } }; + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + + for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { + ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j); + + sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); + } + } + } + + // reduce sum0..sum3 to sum0 + for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { + GGML_F16_VEC_REDUCE(sumf[k], sum[k]); + } + + // leftovers + for (int i = np; i < n; ++i) { + for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i])); + } + } +#else + for (int i = 0; i < n; ++i) { + for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { + sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i])); + } + } +#endif + + for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { + s[i] = sumf[i]; + } +} + +inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] += x[i]*v; + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] += x[i]*v; + } +#endif +} + +inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + } +#endif +} + +// xs and vs are byte strides of x and v +inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { + + const float * restrict x[GGML_VEC_MAD_UNROLL]; + const float * restrict v[GGML_VEC_MAD_UNROLL]; + + for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) { + x[i] = (const float *) ((const char *) xv + i*xs); + v[i] = (const float *) ((const char *) vv + i*vs); + } + +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL]; + + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + vx[k] = GGML_F32_VEC_SET1(v[k][0]); + } + + GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]); + } + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + for (int i = np; i < n; ++i) { + y[i] += x[k][i]*v[k][0]; + } + } +#else + // scalar + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + for (int i = 0; i < n; ++i) { + y[i] += x[k][i]*v[k][0]; + } + } +#endif +} + +//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } +inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { +#if defined(GGML_USE_ACCELERATE) + vDSP_vsmul(y, 1, &v, y, 1, n); +#elif defined(GGML_SIMD) + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_MUL(ay[j], vx); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] *= v; + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] *= v; + } +#endif +} + +inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_MUL(ay[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#endif +} + +inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); } +inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } +inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } +inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); } +inline static void ggml_vec_sin_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]); } +inline static void ggml_vec_cos_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]); } +inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } +inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } +inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } +inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); } +inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); } +inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } +inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); } +inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); } +// TODO: optimize performance +inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } +inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } +inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); } + +static const float GELU_COEF_A = 0.044715f; +static const float GELU_QUICK_COEF = -1.702f; +static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +inline static float ggml_gelu_f32(float x) { + return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + y[i] = ggml_table_gelu_f16[i16[i]]; + } +} + +#ifdef GGML_GELU_FP16 +inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + if (x[i] <= -10.0f) { + y[i] = 0.0f; + } else if (x[i] >= 10.0f) { + y[i] = x[i]; + } else { + ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]); + } + } +} +#else +inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_gelu_f32(x[i]); + } +} +#endif + +inline static float ggml_gelu_quick_f32(float x) { + return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x))); +} + +//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { +// const uint16_t * i16 = (const uint16_t *) x; +// for (int i = 0; i < n; ++i) { +// y[i] = ggml_table_gelu_quick_f16[i16[i]]; +// } +//} + +#ifdef GGML_GELU_QUICK_FP16 +inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { + uint16_t t; + for (int i = 0; i < n; ++i) { + ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); + memcpy(&t, &fp16, sizeof(uint16_t)); + y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]); + } +} +#else +inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + y[i] = ggml_gelu_quick_f32(x[i]); + } +} +#endif + +// Sigmoid Linear Unit (SiLU) function +inline static float ggml_silu_f32(float x) { + return x/(1.0f + expf(-x)); +} + +#if __FINITE_MATH_ONLY__ +#error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix" +#error "ref: https://github.com/ggerganov/llama.cpp/pull/7154#issuecomment-2143844461" +#endif + +#if defined(__ARM_NEON) && defined(__aarch64__) + +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline static float32x4_t ggml_v_expf(float32x4_t x) { + const float32x4_t r = vdupq_n_f32(0x1.8p23f); + const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f)); + const float32x4_t n = vsubq_f32(z, r); + const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n, + vdupq_n_f32(0x1.7f7d1cp-20f)); + const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23); + const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1)))); + const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126)); + const float32x4_t u = vmulq_f32(b, b); + const float32x4_t j = vfmaq_f32( + vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b), + vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b), + vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u); + if (!vpaddd_u64(vreinterpretq_u64_u32(c))) + return vfmaq_f32(k, j, k); + const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000)); + const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000))); + const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d)); + return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1), + vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j))); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static float32x4_t ggml_v_silu(float32x4_t x) { + const float32x4_t one = vdupq_n_f32(1.0f); + const float32x4_t zero = vdupq_n_f32(0.0f); + const float32x4_t neg_x = vsubq_f32(zero, x); + const float32x4_t exp_neg_x = ggml_v_expf(neg_x); + const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); + return vdivq_f32(x, one_plus_exp_neg_x); +} + +#elif defined(__AVX512F__) && defined(__AVX512DQ__) + +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline static __m512 ggml_v_expf(__m512 x) { + const __m512 r = _mm512_set1_ps(0x1.8p23f); + const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r); + const __m512 n = _mm512_sub_ps(z, r); + const __m512 b = + _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f), + _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x)); + const __mmask16 d = + _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ); + const __m512 u = _mm512_mul_ps(b, b); + const __m512 j = _mm512_fmadd_ps( + _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b, + _mm512_set1_ps(0x1.573e2ep-5f)), + u, + _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b, + _mm512_set1_ps(0x1.fffdb6p-2f))), + u, + _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F))); + const __m512 res = _mm512_scalef_ps(j, n); + if (_mm512_kortestz(d, d)) + return res; + const __m512 zero = _mm512_setzero_ps(); + const __m512 alt = _mm512_mask_blend_ps( + _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero); + return _mm512_mask_blend_ps(d, res, alt); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static __m512 ggml_v_silu(__m512 x) { + const __m512 one = _mm512_set1_ps(1); + const __m512 zero = _mm512_setzero_ps(); + const __m512 neg_x = _mm512_sub_ps(zero, x); + const __m512 exp_neg_x = ggml_v_expf(neg_x); + const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); + return _mm512_div_ps(x, one_plus_exp_neg_x); +} + +#elif defined(__AVX2__) && defined(__FMA__) + +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline static __m256 ggml_v_expf(__m256 x) { + const __m256 r = _mm256_set1_ps(0x1.8p23f); + const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r); + const __m256 n = _mm256_sub_ps(z, r); + const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f), + _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x)); + const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23); + const __m256 k = _mm256_castsi256_ps( + _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1)))); + const __m256i c = _mm256_castps_si256( + _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), + _mm256_set1_ps(126), _CMP_GT_OQ)); + const __m256 u = _mm256_mul_ps(b, b); + const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b, + _mm256_set1_ps(0x1.573e2ep-5f)), u, + _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b, + _mm256_set1_ps(0x1.fffdb6p-2f))), + u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b)); + if (!_mm256_movemask_ps(_mm256_castsi256_ps(c))) + return _mm256_fmadd_ps(j, k, k); + const __m256i g = _mm256_and_si256( + _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)), + _mm256_set1_epi32(0x82000000u)); + const __m256 s1 = + _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u))); + const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g)); + const __m256i d = _mm256_castps_si256( + _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), + _mm256_set1_ps(192), _CMP_GT_OQ)); + return _mm256_or_ps( + _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)), + _mm256_andnot_ps( + _mm256_castsi256_ps(d), + _mm256_or_ps( + _mm256_and_ps(_mm256_castsi256_ps(c), + _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)), + _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k))))); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static __m256 ggml_v_silu(__m256 x) { + const __m256 one = _mm256_set1_ps(1); + const __m256 zero = _mm256_setzero_ps(); + const __m256 neg_x = _mm256_sub_ps(zero, x); + const __m256 exp_neg_x = ggml_v_expf(neg_x); + const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); + return _mm256_div_ps(x, one_plus_exp_neg_x); +} + +#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON + +#if defined(__FMA__) +#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z) +#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z) +#else +#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z) +#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y)) +#endif + +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline static __m128 ggml_v_expf(__m128 x) { + const __m128 r = _mm_set1_ps(0x1.8p23f); + const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r); + const __m128 n = _mm_sub_ps(z, r); + const __m128 b = + NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x)); + const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23); + const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1)))); + const __m128i c = + _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126))); + const __m128 u = _mm_mul_ps(b, b); + const __m128 j = + MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u, + MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))), + u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b)); + if (!_mm_movemask_epi8(c)) + return MADD128(j, k, k); + const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())), + _mm_set1_epi32(0x82000000u)); + const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u))); + const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g)); + const __m128i d = + _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192))); + return _mm_or_ps( + _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)), + _mm_andnot_ps(_mm_castsi128_ps(d), + _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)), + _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k))))); +} + +// computes silu x/(1+exp(-x)) in single precision vector +inline static __m128 ggml_v_silu(__m128 x) { + const __m128 one = _mm_set1_ps(1); + const __m128 zero = _mm_setzero_ps(); + const __m128 neg_x = _mm_sub_ps(zero, x); + const __m128 exp_neg_x = ggml_v_expf(neg_x); + const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x); + return _mm_div_ps(x, one_plus_exp_neg_x); +} + +#endif // __ARM_NEON / __AVX2__ / __SSE2__ + +static void ggml_vec_silu_f32(const int n, float * y, const float * x) { + int i = 0; +#if defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i))); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i))); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i))); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i))); + } +#endif + for (; i < n; ++i) { + y[i] = ggml_silu_f32(x[i]); + } +} + +static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) { + int i = 0; + ggml_float sum = 0; +#if defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i), + _mm512_set1_ps(max))); + _mm512_storeu_ps(y + i, val); + sum += (ggml_float)_mm512_reduce_add_ps(val); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i), + _mm256_set1_ps(max))); + _mm256_storeu_ps(y + i, val); + __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1), + _mm256_castps256_ps128(val)); + val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2)); + val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2)); + sum += (ggml_float)_mm_cvtss_f32(val2); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i), + _mm_set1_ps(max))); + _mm_storeu_ps(y + i, val); +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) + val = _mm_add_ps(val, _mm_movehl_ps(val, val)); + val = _mm_add_ss(val, _mm_movehdup_ps(val)); +#else + __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1)); + val = _mm_add_ps(val, tmp); + tmp = _mm_movehl_ps(tmp, val); + val = _mm_add_ss(val, tmp); +#endif + sum += (ggml_float)_mm_cvtss_f32(val); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i), + vdupq_n_f32(max))); + vst1q_f32(y + i, val); + sum += (ggml_float)vaddvq_f32(val); + } +#endif + for (; i < n; ++i) { + float val = expf(x[i] - max); + sum += (ggml_float)val; + y[i] = val; + } + return sum; +} + +static ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) { + // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_i) + + int i = 0; + ggml_float sum = 0; + for (; i < n; ++i) { + float val = x[i] - max; + y[i] = val; + sum += (ggml_float)expf(val); + } + return sum = (ggml_float)logf(sum); +} + +inline static float ggml_silu_backward_f32(float x, float dy) { + const float s = 1.0f/(1.0f + expf(-x)); + return dy*s*(1.0f + x*(1.0f - s)); +} + +inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) { + for (int i = 0; i < n; ++i) { + dx[i] = ggml_silu_backward_f32(x[i], dy[i]); + } +} + +inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { +#ifndef GGML_USE_ACCELERATE + ggml_float sum = 0.0; + for (int i = 0; i < n; ++i) { + sum += (ggml_float)x[i]; + } + *s = sum; +#else + vDSP_sve(x, 1, s, n); +#endif +} + +inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) { + ggml_float sum = 0.0; + for (int i = 0; i < n; ++i) { + sum += (ggml_float)x[i]; + } + *s = sum; +} + +inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) { + float sum = 0.0f; + for (int i = 0; i < n; ++i) { + sum += GGML_FP16_TO_FP32(x[i]); + } + *s = sum; +} + +inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) { + float sum = 0.0f; + for (int i = 0; i < n; ++i) { + sum += GGML_BF16_TO_FP32(x[i]); + } + *s = sum; +} + +inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { +#ifndef GGML_USE_ACCELERATE + float max = -INFINITY; + for (int i = 0; i < n; ++i) { + max = MAX(max, x[i]); + } + *s = max; +#else + vDSP_maxv(x, 1, s, n); +#endif +} + +inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { + ggml_vec_norm_f32(n, s, x); + *s = 1.f/(*s); +} + +inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) { + float max = -INFINITY; + int idx = 0; + for (int i = 0; i < n; ++i) { + max = MAX(max, x[i]); + if (max == x[i]) { idx = i; } + } + *s = idx; +} + +// Helpers for polling loops +#if defined(__aarch64__) && ( defined(__clang__) || defined(__GNUC__) ) +static inline void ggml_thread_cpu_relax(void) { + __asm__ volatile("yield" ::: "memory"); +} +#elif defined(__x86_64__) +static inline void ggml_thread_cpu_relax(void) { + _mm_pause(); +} +#else +static inline void ggml_thread_cpu_relax(void) {;} +#endif + +// +// NUMA support +// + +#define GGML_NUMA_MAX_NODES 8 +#define GGML_NUMA_MAX_CPUS 512 + +struct ggml_numa_node { + uint32_t cpus[GGML_NUMA_MAX_CPUS]; // hardware threads on this node + uint32_t n_cpus; +}; + +struct ggml_numa_nodes { + enum ggml_numa_strategy numa_strategy; + struct ggml_numa_node nodes[GGML_NUMA_MAX_NODES]; + uint32_t n_nodes; + uint32_t total_cpus; // hardware threads on system + uint32_t current_node; // node on which main process is execting +#if defined(__gnu_linux__) + cpu_set_t cpuset; // cpuset from numactl +#else + uint32_t cpuset; // no NUMA support outside of Linux at this time. Use a portable datatype +#endif +}; + +// +// ggml state +// + +struct ggml_state { + struct ggml_numa_nodes numa; +}; + +static struct ggml_state g_state = {0}; + +void ggml_barrier(struct ggml_threadpool * tp) { + int n_threads = atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed); + if (n_threads == 1) { + return; + } + +#ifdef GGML_USE_OPENMP + #pragma omp barrier +#else + int n_passed = atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed); + + // enter barrier (full seq-cst fence) + int n_barrier = atomic_fetch_add_explicit(&tp->n_barrier, 1, memory_order_seq_cst); + + if (n_barrier == (n_threads - 1)) { + // last thread + atomic_store_explicit(&tp->n_barrier, 0, memory_order_relaxed); + + // exit barrier (fill seq-cst fence) + atomic_fetch_add_explicit(&tp->n_barrier_passed, 1, memory_order_seq_cst); + return; + } + + // wait for other threads + while (atomic_load_explicit(&tp->n_barrier_passed, memory_order_relaxed) == n_passed) { + ggml_thread_cpu_relax(); + } + + // exit barrier (full seq-cst fence) + // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead + #ifdef GGML_TSAN_ENABLED + atomic_fetch_add_explicit(&tp->n_barrier_passed, 0, memory_order_seq_cst); + #else + atomic_thread_fence(memory_order_seq_cst); + #endif +#endif +} + +#if defined(__gnu_linux__) +static cpu_set_t ggml_get_numa_affinity(void) { + cpu_set_t cpuset; + pthread_t thread; + thread = pthread_self(); + CPU_ZERO(&cpuset); + pthread_getaffinity_np(thread, sizeof(cpu_set_t), &cpuset); + return cpuset; +} +#else +static uint32_t ggml_get_numa_affinity(void) { + return 0; // no NUMA support +} +#endif + +void ggml_numa_init(enum ggml_numa_strategy numa_flag) { + if (g_state.numa.n_nodes > 0) { + fprintf(stderr, "ggml_numa_init: NUMA already initialized\n"); + + return; + } + +#if defined(__gnu_linux__) + struct stat st; + char path[256]; + int rv; + + // set numa scheme + g_state.numa.numa_strategy = numa_flag; + + GGML_PRINT_DEBUG("numa strategy %u\n",g_state.numa.numa_strategy); + + g_state.numa.cpuset = ggml_get_numa_affinity(); + + // enumerate nodes + while (g_state.numa.n_nodes < GGML_NUMA_MAX_NODES) { + rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u", g_state.numa.n_nodes); + GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); + if (stat(path, &st) != 0) { break; } + ++g_state.numa.n_nodes; + } + + // enumerate CPUs + while (g_state.numa.total_cpus < GGML_NUMA_MAX_CPUS) { + rv = snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%u", g_state.numa.total_cpus); + GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); + if (stat(path, &st) != 0) { break; } + ++g_state.numa.total_cpus; + } + + GGML_PRINT_DEBUG("found %u numa nodes, %u CPUs\n", g_state.numa.n_nodes, g_state.numa.total_cpus); + + // figure out which node we're on + uint current_cpu; + int getcpu_ret = 0; +#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 33) || defined(__COSMOPOLITAN__) + getcpu_ret = getcpu(¤t_cpu, &g_state.numa.current_node); +#else + // old glibc doesn't have a wrapper for this call. Fall back on direct syscall +# if !defined(SYS_getcpu) && defined(SYS_get_cpu) +# define SYS_getcpu SYS_get_cpu // some older glibc versions use this name +# endif + getcpu_ret = syscall(SYS_getcpu, ¤t_cpu, &g_state.numa.current_node); +#endif + + if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1 || getcpu_ret != 0) { + g_state.numa.n_nodes = 0; + return; + } + + GGML_PRINT_DEBUG("found our process on numa node %u, CPU %u\n", g_state.numa.current_node, current_cpu); + + for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) { + struct ggml_numa_node * node = &g_state.numa.nodes[n]; + GGML_PRINT_DEBUG("CPUs on node %u:", n); + node->n_cpus = 0; + for (uint32_t c = 0; c < g_state.numa.total_cpus; ++c) { + rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u/cpu%u", n, c); + GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); + if (stat(path, &st) == 0) { + node->cpus[node->n_cpus++] = c; + GGML_PRINT_DEBUG(" %u", c); + } + } + GGML_PRINT_DEBUG("\n"); + } + + if (ggml_is_numa()) { + FILE *fptr = fopen("/proc/sys/kernel/numa_balancing", "r"); + if (fptr != NULL) { + char buf[42]; + if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) { + GGML_LOG_WARN("/proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n"); + } + fclose(fptr); + } + } +#else + UNUSED(numa_flag); + // TODO +#endif +} + +bool ggml_is_numa(void) { + return g_state.numa.n_nodes > 1; +} + +#if defined(__ARM_ARCH) + +#if defined(__linux__) && defined(__aarch64__) +#include +#elif defined(__APPLE__) +#include +#endif + +#if !defined(HWCAP2_I8MM) +#define HWCAP2_I8MM (1 << 13) +#endif + +static void ggml_init_arm_arch_features(void) { +#if defined(__linux__) && defined(__aarch64__) + uint32_t hwcap = getauxval(AT_HWCAP); + uint32_t hwcap2 = getauxval(AT_HWCAP2); + + ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD); + ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP); + ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM); + ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE); + +#if defined(__ARM_FEATURE_SVE) + ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); +#endif +#elif defined(__APPLE__) + int oldp = 0; + size_t size = sizeof(oldp); + if (sysctlbyname("hw.optional.AdvSIMD", &oldp, &size, NULL, 0) != 0) { + oldp = 0; + } + ggml_arm_arch_features.has_neon = oldp; + + if (sysctlbyname("hw.optional.arm.FEAT_DotProd", &oldp, &size, NULL, 0) != 0) { + oldp = 0; + } + ggml_arm_arch_features.has_dotprod = oldp; + + if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) { + oldp = 0; + } + ggml_arm_arch_features.has_i8mm = oldp; + + ggml_arm_arch_features.has_sve = 0; + ggml_arm_arch_features.sve_cnt = 0; +#else +// Run-time CPU feature detection not implemented for this platform, fallback to compile time +#if defined(__ARM_NEON) + ggml_arm_arch_features.has_neon = 1; +#else + ggml_arm_arch_features.has_neon = 0; +#endif + +#if defined(__ARM_FEATURE_MATMUL_INT8) + ggml_arm_arch_features.has_i8mm = 1; +#else + ggml_arm_arch_features.has_i8mm = 0; +#endif + +#if defined(__ARM_FEATURE_SVE) + ggml_arm_arch_features.has_sve = 1; + ggml_arm_arch_features.sve_cnt = 16; +#else + ggml_arm_arch_features.has_sve = 0; + ggml_arm_arch_features.sve_cnt = 0; +#endif +#endif +} +#endif + +struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) { + GGML_ASSERT(!ggml_get_no_alloc(ctx)); + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); + + ggml_set_i32(result, value); + + return result; +} + +struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) { + GGML_ASSERT(!ggml_get_no_alloc(ctx)); + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + + ggml_set_f32(result, value); + + return result; +} + +struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { + const int n = ggml_nrows(tensor); + const int nc = tensor->ne[0]; + const size_t n1 = tensor->nb[1]; + + char * const data = tensor->data; + + switch (tensor->type) { + case GGML_TYPE_I8: + { + assert(tensor->nb[0] == sizeof(int8_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I16: + { + assert(tensor->nb[0] == sizeof(int16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I32: + { + assert(tensor->nb[0] == sizeof(int32_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_F16: + { + assert(tensor->nb[0] == sizeof(ggml_fp16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value)); + } + } break; + case GGML_TYPE_BF16: + { + assert(tensor->nb[0] == sizeof(ggml_fp16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value)); + } + } break; + case GGML_TYPE_F32: + { + assert(tensor->nb[0] == sizeof(float)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f32(nc, (float *)(data + i*n1), value); + } + } break; + default: + { + GGML_ABORT("fatal error"); + } + } + + return tensor; +} + +struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { + const int n = ggml_nrows(tensor); + const int nc = tensor->ne[0]; + const size_t n1 = tensor->nb[1]; + + char * const data = tensor->data; + + switch (tensor->type) { + case GGML_TYPE_I8: + { + assert(tensor->nb[0] == sizeof(int8_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I16: + { + assert(tensor->nb[0] == sizeof(int16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_I32: + { + assert(tensor->nb[0] == sizeof(int32_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); + } + } break; + case GGML_TYPE_F16: + { + assert(tensor->nb[0] == sizeof(ggml_fp16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value)); + } + } break; + case GGML_TYPE_BF16: + { + assert(tensor->nb[0] == sizeof(ggml_bf16_t)); + for (int i = 0; i < n; i++) { + ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value)); + } + } break; + case GGML_TYPE_F32: + { + assert(tensor->nb[0] == sizeof(float)); + for (int i = 0; i < n; i++) { + ggml_vec_set_f32(nc, (float *)(data + i*n1), value); + } + } break; + default: + { + GGML_ABORT("fatal error"); + } + } + + return tensor; +} + +int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { + if (!ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]); + } + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + return ((int8_t *)(tensor->data))[i]; + } + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + return ((int16_t *)(tensor->data))[i]; + } + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + return ((int32_t *)(tensor->data))[i]; + } + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); + } + case GGML_TYPE_BF16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t)); + return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]); + } + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + return ((float *)(tensor->data))[i]; + } + default: + { + GGML_ABORT("fatal error"); + } + } +} + +void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { + if (!ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value); + return; + } + switch (tensor->type) { + case GGML_TYPE_I8: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + ((int8_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + ((int16_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + ((int32_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); + ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); + } break; + case GGML_TYPE_BF16: + { + GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t)); + ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value); + } break; + case GGML_TYPE_F32: + { + GGML_ASSERT(tensor->nb[0] == sizeof(float)); + ((float *)(tensor->data))[i] = value; + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case GGML_TYPE_I8: + return ((int8_t *) data)[0]; + case GGML_TYPE_I16: + return ((int16_t *) data)[0]; + case GGML_TYPE_I32: + return ((int32_t *) data)[0]; + case GGML_TYPE_F16: + return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]); + case GGML_TYPE_BF16: + return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]); + case GGML_TYPE_F32: + return ((float *) data)[0]; + default: + GGML_ABORT("fatal error"); + } +} + +void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case GGML_TYPE_I8: + { + ((int8_t *)(data))[0] = value; + } break; + case GGML_TYPE_I16: + { + ((int16_t *)(data))[0] = value; + } break; + case GGML_TYPE_I32: + { + ((int32_t *)(data))[0] = value; + } break; + case GGML_TYPE_F16: + { + ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value); + } break; + case GGML_TYPE_BF16: + { + ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value); + } break; + case GGML_TYPE_F32: + { + ((float *)(data))[0] = value; + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { + if (!ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]); + } + switch (tensor->type) { + case GGML_TYPE_I8: + { + return ((int8_t *)(tensor->data))[i]; + } + case GGML_TYPE_I16: + { + return ((int16_t *)(tensor->data))[i]; + } + case GGML_TYPE_I32: + { + return ((int32_t *)(tensor->data))[i]; + } + case GGML_TYPE_F16: + { + return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); + } + case GGML_TYPE_BF16: + { + return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]); + } + case GGML_TYPE_F32: + { + return ((float *)(tensor->data))[i]; + } + default: + { + GGML_ABORT("fatal error"); + } + } +} + +void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { + if (!ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value); + return; + } + switch (tensor->type) { + case GGML_TYPE_I8: + { + ((int8_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I16: + { + ((int16_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_I32: + { + ((int32_t *)(tensor->data))[i] = value; + } break; + case GGML_TYPE_F16: + { + ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); + } break; + case GGML_TYPE_BF16: + { + ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value); + } break; + case GGML_TYPE_F32: + { + ((float *)(tensor->data))[i] = value; + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case GGML_TYPE_I8: + return ((int8_t *) data)[0]; + case GGML_TYPE_I16: + return ((int16_t *) data)[0]; + case GGML_TYPE_I32: + return ((int32_t *) data)[0]; + case GGML_TYPE_F16: + return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]); + case GGML_TYPE_BF16: + return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]); + case GGML_TYPE_F32: + return ((float *) data)[0]; + default: + GGML_ABORT("fatal error"); + } +} + +void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case GGML_TYPE_I8: + { + ((int8_t *)(data))[0] = value; + } break; + case GGML_TYPE_I16: + { + ((int16_t *)(data))[0] = value; + } break; + case GGML_TYPE_I32: + { + ((int32_t *)(data))[0] = value; + } break; + case GGML_TYPE_F16: + { + ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value); + } break; + case GGML_TYPE_BF16: + { + ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value); + } break; + case GGML_TYPE_F32: + { + ((float *)(data))[0] = value; + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +//////////////////////////////////////////////////////////////////////////////// + +// ggml_compute_forward_dup + +static void ggml_compute_forward_dup_same_cont( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); + GGML_ASSERT(src0->type == dst->type); + + const size_t nb0 = ggml_type_size(src0->type); + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + // parallelize by elements + const int ne = ggml_nelements(dst); + const int dr = (ne + nth - 1) / nth; + const int ie0 = dr * ith; + const int ie1 = MIN(ie0 + dr, ne); + + if (ie0 < ie1) { + memcpy( + ((char *) dst->data + ie0*nb0), + ((char *) src0->data + ie0*nb0), + (ie1 - ie0) * nb0); + } +} + +static void ggml_compute_forward_dup_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + + GGML_TENSOR_UNARY_OP_LOCALS + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (src0->type == dst->type && + ne00 == ne0 && + nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) { + // copy by rows + const size_t rs = ne00*nb00; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ir0; i01 < ir1; i01++) { + memcpy( + ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + rs); + } + } + } + return; + } + + // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy + + if (ggml_is_contiguous(dst)) { + if (nb00 == sizeof(ggml_fp16_t)) { + if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + const size_t rs = ne00 * nb00; + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, rs); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + for (int i00 = 0; i00 < ne00; i00++) { + dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (ggml_get_type_traits_cpu(dst->type)->from_float) { + ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float; + float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; + + size_t id = 0; + size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + for (int i00 = 0; i00 < ne00; i00++) { + src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]); + } + + quantize_row_q(src0_f32, dst_ptr + id, ne00); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else { + GGML_ABORT("fatal error"); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else { + GGML_ABORT("fatal error"); // TODO: implement + } + } + return; + } + + // dst counters + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + if (dst->type == GGML_TYPE_F16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t)); + + if (++i10 == ne00) { + i10 = 0; + if (++i11 == ne01) { + i11 = 0; + if (++i12 == ne02) { + i12 = 0; + if (++i13 == ne03) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == GGML_TYPE_F32) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else { + GGML_ABORT("fatal error"); // TODO: implement + } +} + +static void ggml_compute_forward_dup_bf16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + + GGML_TENSOR_UNARY_OP_LOCALS + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (src0->type == dst->type && + ne00 == ne0 && + nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) { + // copy by rows + const size_t rs = ne00*nb00; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ir0; i01 < ir1; i01++) { + memcpy( + ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + rs); + } + } + } + return; + } + + // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy + + if (ggml_is_contiguous(dst)) { + if (nb00 == sizeof(ggml_bf16_t)) { + if (dst->type == GGML_TYPE_BF16) { + size_t id = 0; + const size_t rs = ne00 * nb00; + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, rs); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + for (int i00 = 0; i00 < ne00; i00++) { + dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00])); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + for (int i00 = 0; i00 < ne00; i00++) { + dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (ggml_get_type_traits_cpu(dst->type)->from_float) { + ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float; + float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; + + size_t id = 0; + size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + for (int i00 = 0; i00 < ne00; i00++) { + src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]); + } + + quantize_row_q(src0_f32, dst_ptr + id, ne00); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else { + GGML_ABORT("fatal error"); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_BF16) { + size_t id = 0; + ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr)); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else { + GGML_ABORT("fatal error"); // TODO: implement + } + } + return; + } + + // dst counters + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + if (dst->type == GGML_TYPE_BF16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t)); + + if (++i10 == ne00) { + i10 = 0; + if (++i11 == ne01) { + i11 = 0; + if (++i12 == ne02) { + i12 = 0; + if (++i13 == ne03) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr)); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == GGML_TYPE_F32) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else { + GGML_ABORT("fatal error"); // TODO: implement + } +} + +static void ggml_compute_forward_dup_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + + GGML_TENSOR_UNARY_OP_LOCALS + + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (src0->type == dst->type && + ne00 == ne0 && + nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) { + // copy by rows + const size_t rs = ne00*nb00; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ir0; i01 < ir1; i01++) { + memcpy( + ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + rs); + } + } + } + return; + } + + if (ggml_is_contiguous(dst)) { + // TODO: simplify + if (nb00 == sizeof(float)) { + if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + const size_t rs = ne00 * nb00; + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, rs); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else if (ggml_get_type_traits_cpu(dst->type)->from_float) { + ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float; + + size_t id = 0; + size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); + char * dst_ptr = (char *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + quantize_row_q(src0_ptr, dst_ptr + id, ne00); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else { + GGML_ABORT("fatal error"); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else if (dst->type == GGML_TYPE_BF16) { + size_t id = 0; + ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + id += ne00 * ir0; + for (int i01 = ir0; i01 < ir1; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr); + id++; + } + } + id += ne00 * (ne01 - ir1); + } + } + } else { + GGML_ABORT("fatal error"); // TODO: implement + } + } + + return; + } + + // dst counters + + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + if (dst->type == GGML_TYPE_F32) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + memcpy(dst_ptr, src0_ptr, sizeof(float)); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else if (dst->type == GGML_TYPE_BF16) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + } else { + GGML_ABORT("fatal error"); // TODO: implement + } +} + +// A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy. +static void ggml_compute_forward_dup_bytes( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + GGML_ASSERT(src0->type == dst->type); + + GGML_TENSOR_UNARY_OP_LOCALS; + + if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) { + ggml_compute_forward_dup_same_cont(params, dst); + return; + } + + const size_t type_size = ggml_type_size(src0->type); + const int ith = params->ith; // thread index + const int nth = params->nth; // number of threads + + + // parallelize by rows + const int nr = ne01; + // number of rows per thread + const int dr = (nr + nth - 1) / nth; + // row range for this thread + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (src0->type == dst->type && + ne00 == ne0 && + nb00 == type_size && nb0 == type_size) { + // copy by rows + const size_t rs = ne00 * type_size; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ir0; i01 < ir1; i01++) { + memcpy( + ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + rs); + } + } + } + return; + } + + if (ggml_is_contiguous(dst)) { + size_t id = 0; + char * dst_ptr = (char *) dst->data; + const size_t rs = ne00 * type_size; + + if (nb00 == type_size) { + // src0 is contigous on first dimension, copy by rows + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int64_t i01 = ir0; i01 < ir1; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, rs); + id += rs; + } + id += rs * (ne01 - ir1); + } + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + id += rs * ir0; + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03; + memcpy(dst_ptr + id, src0_ptr, type_size); + + id += type_size; + } + } + id += rs * (ne01 - ir1); + } + } + } + + return; + } + + // dst counters + + int64_t i10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + i10 += ne00 * ir0; + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + + memcpy(dst_ptr, src0_ptr, type_size); + + if (++i10 == ne0) { + i10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + i10 += ne00 * (ne01 - ir1); + while (i10 >= ne0) { + i10 -= ne0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } +} + +static void ggml_compute_forward_dup_q( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + const enum ggml_type type = src0->type; + ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float; + + size_t qk = ggml_blck_size(type); + const int64_t nr = ggml_nelements(src1) / qk; + + // destination must be contiguous in the first dimension + GGML_ASSERT(nb10 == ggml_type_size(dst->type)); + // must either have first dimension large enough to hold a row, or fully contiguous + GGML_ASSERT((ne10 % qk) == 0 || ggml_is_contiguous(dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + + uint32_t i = ir * qk; + + const int64_t i03 = i/(ne00 * ne01 * ne02); + const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); + const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; + const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; + const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; + + const int64_t i13 = i/(ne10 * ne11 * ne12); + const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + + dequantize_row_q( + (const void *) ((char *) src0->data + x_offset), + (float *) ((char *) dst->data + dst_offset), qk); + } +} + +static void ggml_compute_forward_dup( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (src0->type == dst->type) { + ggml_compute_forward_dup_bytes(params, dst); + return; + } + + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_dup_f16(params, dst); + } break; + case GGML_TYPE_BF16: + { + ggml_compute_forward_dup_bf16(params, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_dup_f32(params, dst); + } break; + default: + { + if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) { + ggml_compute_forward_dup_q(params, dst); + break; + } + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_add + +static void ggml_compute_forward_add_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(float)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0; r < nr0; ++r) { +#ifdef GGML_USE_ACCELERATE + vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10); +#else + ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); +#endif + } + } + } else { + // src1 is not contiguous + for (int ir = ir0; ir < ir1; ++ir) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int64_t i10 = i0 % ne10; + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); + + dst_ptr[i0] = src0_ptr[i0] + *src1_ptr; + } + } + } +} + +static void ggml_compute_forward_add_f16_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + if (dst->type == GGML_TYPE_F32) { + GGML_ASSERT( nb0 == sizeof(float)); + } + else { + GGML_ASSERT(dst->type == GGML_TYPE_F16); + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + } + + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(float)) { + if (dst->type == GGML_TYPE_F16) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); + } + } + } else { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]; + } + } + } + } + else { + // src1 is not contiguous + GGML_ABORT("fatal error"); + } +} + +static void ggml_compute_forward_add_bf16_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_BF16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + if (dst->type == GGML_TYPE_F32) { + GGML_ASSERT( nb0 == sizeof(float)); + } + else { + GGML_ASSERT(dst->type == GGML_TYPE_BF16); + GGML_ASSERT( nb0 == sizeof(ggml_bf16_t)); + } + + GGML_ASSERT(nb00 == sizeof(ggml_bf16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(float)) { + if (dst->type == GGML_TYPE_BF16) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); + } + } + } else { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]; + } + } + } + } + else { + // src1 is not contiguous + GGML_ABORT("fatal error"); + } +} + +static void ggml_compute_forward_add_f16_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F16); + + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(ggml_fp16_t)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(src1_ptr[i])); + } + } + } + else { + // src1 is not contiguous + GGML_ABORT("fatal error"); + } +} + +static void ggml_compute_forward_add_bf16_bf16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_BF16); + GGML_ASSERT(src1->type == GGML_TYPE_BF16); + GGML_ASSERT(dst->type == GGML_TYPE_BF16); + + GGML_ASSERT( nb0 == sizeof(ggml_bf16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_bf16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(ggml_bf16_t)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + ggml_bf16_t * src1_ptr = (ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + GGML_BF16_TO_FP32(src1_ptr[i])); + } + } + } + else { + // src1 is not contiguous + GGML_ABORT("fatal error"); + } +} + +static void ggml_compute_forward_add_q_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const enum ggml_type type = src0->type; + const enum ggml_type dtype = dst->type; + ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float; + ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dtype)->from_float; + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(type)); + GGML_ASSERT(nb10 == sizeof(float)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ggml_is_quantized(src0->type)); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 indices + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + + // src1 and dst are same shape as src0 => same indices + const int i13 = i03; + const int i12 = i02; + const int i11 = i01; + + const int i3 = i03; + const int i2 = i02; + const int i1 = i01; + + void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); + float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)); + void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + assert(ne00 % 32 == 0); + + // unquantize row from src0 to temp buffer + dequantize_row_q(src0_row, wdata, ne00); + // add src1 + ggml_vec_acc_f32(ne00, wdata, src1_row); + // quantize row to dst + if (quantize_row_q != NULL) { + quantize_row_q(wdata, dst_row, ne00); + } else { + memcpy(dst_row, wdata, ne0*nb0); + } + } +} + +static void ggml_compute_forward_add( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + if (src1->type == GGML_TYPE_F32) { + ggml_compute_forward_add_f32(params, dst); + } + else { + GGML_ABORT("fatal error"); + } + } break; + case GGML_TYPE_F16: + { + if (src1->type == GGML_TYPE_F16) { + ggml_compute_forward_add_f16_f16(params, dst); + } + else if (src1->type == GGML_TYPE_F32) { + ggml_compute_forward_add_f16_f32(params, dst); + } + else { + GGML_ABORT("fatal error"); + } + } break; + case GGML_TYPE_BF16: + { + if (src1->type == GGML_TYPE_BF16) { + ggml_compute_forward_add_bf16_bf16(params, dst); + } + else if (src1->type == GGML_TYPE_F32) { + ggml_compute_forward_add_bf16_f32(params, dst); + } + else { + GGML_ABORT("fatal error"); + } + } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ2_S: + { + ggml_compute_forward_add_q_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_add1 + +static void ggml_compute_forward_add1_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + +#ifdef GGML_USE_ACCELERATE + UNUSED(ggml_vec_add1_f32); + + vDSP_vadd( + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, + (float *) ((char *) src1->data), 0, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, + ne0); +#else + ggml_vec_add1_f32(ne0, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), + *(float *) src1->data); +#endif + } +} + +static void ggml_compute_forward_add1_f16_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + // scalar to add + const float v = *(float *) src1->data; + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F16); + + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v); + } + } +} + +static void ggml_compute_forward_add1_f16_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + // scalar to add + const float v = GGML_FP16_TO_FP32(*(ggml_fp16_t *) src1->data); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F16); + + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v); + } + } +} + +static void ggml_compute_forward_add1_q_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + // scalar to add + const float v = *(float *) src1->data; + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS + + const enum ggml_type type = src0->type; + ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float; + ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(type)->from_float; + + // we don't support permuted src0 + GGML_ASSERT(nb00 == ggml_type_size(type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ggml_is_quantized(src0->type)); + GGML_ASSERT(dst->type == src0->type); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith; + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03)); + void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb0 )); + + assert(ne0 % 32 == 0); + + // unquantize row from src0 to temp buffer + dequantize_row_q(src0_row, wdata, ne0); + // add src1 + ggml_vec_acc1_f32(ne0, wdata, v); + // quantize row to dst + quantize_row_q(wdata, dst_row, ne0); + } +} + +static void ggml_compute_forward_add1_bf16_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + // scalar to add + const float v = *(float *) src1->data; + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_BF16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_BF16); + + GGML_ASSERT( nb0 == sizeof(ggml_bf16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_bf16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v); + } + } +} + +static void ggml_compute_forward_add1_bf16_bf16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_scalar(src1)); + + // scalar to add + const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_BF16); + GGML_ASSERT(src1->type == GGML_TYPE_BF16); + GGML_ASSERT(dst->type == GGML_TYPE_BF16); + + GGML_ASSERT( nb0 == sizeof(ggml_bf16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_bf16_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); + ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v); + } + } +} + +static void ggml_compute_forward_add1( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_add1_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + if (src1->type == GGML_TYPE_F16) { + ggml_compute_forward_add1_f16_f16(params, dst); + } + else if (src1->type == GGML_TYPE_F32) { + ggml_compute_forward_add1_f16_f32(params, dst); + } + else { + GGML_ABORT("fatal error"); + } + } break; + case GGML_TYPE_BF16: + { + if (src1->type == GGML_TYPE_BF16) { + ggml_compute_forward_add1_bf16_bf16(params, dst); + } + else if (src1->type == GGML_TYPE_F32) { + ggml_compute_forward_add1_bf16_f32(params, dst); + } + else { + GGML_ABORT("fatal error"); + } + } break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ2_S: + { + ggml_compute_forward_add1_q_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_acc + +static void ggml_compute_forward_acc_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); + + // view src0 and dst with these strides and data offset inbytes during acc + // nb0 is implicitly element_size because src0 and dst are contiguous + size_t nb1 = ((int32_t *) dst->op_params)[0]; + size_t nb2 = ((int32_t *) dst->op_params)[1]; + size_t nb3 = ((int32_t *) dst->op_params)[2]; + size_t offset = ((int32_t *) dst->op_params)[3]; + bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace) { + if (params->ith == 0) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + memcpy( + ((char *) dst->data), + ((char *) src0->data), + ggml_nbytes(dst)); + } + ggml_barrier(params->threadpool); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src1); + const int nc = src1->ne[0]; + + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) + + // src0 and dst as viewed during acc + const size_t nb0 = ggml_element_size(src0); + + const size_t nb00 = nb0; + const size_t nb01 = nb1; + const size_t nb02 = nb2; + const size_t nb03 = nb3; + + GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0 + (ne11 == 0 ? 0 : ne11-1)*nb1 + (ne12 == 0 ? 0 : ne12-1)*nb2 + (ne13 == 0 ? 0 : ne13-1)*nb3 < ggml_nbytes(dst)); + GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0)); + + GGML_ASSERT(nb10 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are viewed with shape of src1 and offset + // => same indices + const int i3 = ir/(ne12*ne11); + const int i2 = (ir - i3*ne12*ne11)/ne11; + const int i1 = (ir - i3*ne12*ne11 - i2*ne11); + +#ifdef GGML_USE_ACCELERATE + vDSP_vadd( + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1, + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), 1, nc); +#else + ggml_vec_add_f32(nc, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); +#endif + } +} + +static void ggml_compute_forward_acc( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_acc_f32(params, dst); + } break; + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ2_S: + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_sub + +static void ggml_compute_forward_sub_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + if (nb10 == sizeof(float)) { + for (int ir = ir0; ir < ir1; ++ir) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0; r < nr0; ++r) { +#ifdef GGML_USE_ACCELERATE + vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10); +#else + ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); +#endif + } + } + } else { + // src1 is not contiguous + for (int ir = ir0; ir < ir1; ++ir) { + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int64_t i10 = i0 % ne10; + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); + + dst_ptr[i0] = src0_ptr[i0] - *src1_ptr; + } + } + } +} + +static void ggml_compute_forward_sub( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sub_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_mul + +static void ggml_compute_forward_mul_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + if (nb10 == sizeof(float)) { + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0 ; r < nr0; ++r) { +#ifdef GGML_USE_ACCELERATE + UNUSED(ggml_vec_mul_f32); + + vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10); +#else + ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); +#endif + } + } + } else { + // src1 is not contiguous + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + for (int64_t i0 = 0; i0 < ne00; ++i0) { + const int64_t i10 = i0 % ne10; + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); + + dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr); + } + } + } +} + +static void ggml_compute_forward_mul( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src1->type == GGML_TYPE_F32 && "only f32 src1 supported for now"); + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_mul_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_div + +static void ggml_compute_forward_div_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + if (nb10 == sizeof(float)) { + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0; r < nr0; ++r) { +#ifdef GGML_USE_ACCELERATE + UNUSED(ggml_vec_div_f32); + + vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10); +#else + ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); +#endif + } + } + } else { + // src1 is not contiguous + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + for (int64_t i0 = 0; i0 < ne00; ++i0) { + const int64_t i10 = i0 % ne10; + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); + + dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr); + } + } + } +} + +static void ggml_compute_forward_div( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_div_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_sqr + +static void ggml_compute_forward_sqr_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sqr_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sqr( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sqr_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_sqrt + +static void ggml_compute_forward_sqrt_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert( dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sqrt_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sqrt( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sqrt_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_log + +static void ggml_compute_forward_log_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + GGML_ASSERT( dst->nb[0] == sizeof(float)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_log_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_log( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_log_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_sin + +static void ggml_compute_forward_sin_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + GGML_ASSERT( dst->nb[0] == sizeof(float)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_sin_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sin( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sin_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_cos + +static void ggml_compute_forward_cos_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + GGML_ASSERT( dst->nb[0] == sizeof(float)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_cos_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_cos( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_cos_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_sum + +static void ggml_compute_forward_sum_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_scalar(dst)); + assert(src0->nb[0] == sizeof(float)); + + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + + ggml_float sum = 0; + ggml_float row_sum = 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + ggml_vec_sum_f32_ggf(ne00, + &row_sum, + (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); + sum += row_sum; + } + } + } + ((float *) dst->data)[0] = sum; +} + +static void ggml_compute_forward_sum_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_scalar(dst)); + + assert(src0->nb[0] == sizeof(ggml_fp16_t)); + + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + + float sum = 0; + float row_sum = 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + ggml_vec_sum_f16_ggf(ne00, + &row_sum, + (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03)); + sum += row_sum; + } + } + } + ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum); +} + +static void ggml_compute_forward_sum_bf16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_scalar(dst)); + + assert(src0->nb[0] == sizeof(ggml_bf16_t)); + + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + + float sum = 0; + float row_sum = 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + ggml_vec_sum_bf16_ggf(ne00, + &row_sum, + (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03)); + sum += row_sum; + } + } + } + ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum); +} + +static void ggml_compute_forward_sum( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sum_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_sum_f16(params, dst); + } break; + case GGML_TYPE_BF16: + { + ggml_compute_forward_sum_bf16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_sum_rows + +static void ggml_compute_forward_sum_rows_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(dst->nb[0] == sizeof(float)); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne0 == 1); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + for (int64_t i3 = 0; i3 < ne03; i3++) { + for (int64_t i2 = 0; i2 < ne02; i2++) { + for (int64_t i1 = 0; i1 < ne01; i1++) { + float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); + float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); + float row_sum = 0; + ggml_vec_sum_f32(ne00, &row_sum, src_row); + dst_row[0] = row_sum; + } + } + } +} + +static void ggml_compute_forward_sum_rows( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sum_rows_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_mean + +static void ggml_compute_forward_mean_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(src0->nb[0] == sizeof(float)); + + GGML_TENSOR_UNARY_OP_LOCALS + + assert(ne0 == 1); + assert(ne1 == ne01); + assert(ne2 == ne02); + assert(ne3 == ne03); + + UNUSED(ne0); + UNUSED(ne1); + UNUSED(ne2); + UNUSED(ne3); + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + ggml_vec_sum_f32(ne00, + (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); + + *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00; + } + } + } +} + +static void ggml_compute_forward_mean( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_mean_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_argmax + +static void ggml_compute_forward_argmax_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(src0->nb[0] == sizeof(float)); + assert(dst->nb[0] == sizeof(float)); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + + const size_t nb01 = src0->nb[1]; + const size_t nb0 = dst->nb[0]; + + for (int64_t i1 = 0; i1 < ne01; i1++) { + float * src = (float *) ((char *) src0->data + i1*nb01); + int32_t * dst_ = (int32_t *) ((char *) dst->data + i1*nb0); + int v = 0; + ggml_vec_argmax_f32(ne00, &v, src); + dst_[0] = v; + } +} + +static void ggml_compute_forward_argmax( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_argmax_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_count_equal + +static void ggml_compute_forward_count_equal_i32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS; + + GGML_ASSERT(src0->type == GGML_TYPE_I32); + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + GGML_ASSERT(ggml_is_scalar(dst)); + GGML_ASSERT(dst->type == GGML_TYPE_I64); + + const int64_t nr = ggml_nrows(src0); + + const int ith = params->ith; + const int nth = params->nth; + + int64_t * sums = (int64_t *) params->wdata; + int64_t sum_thread = 0; + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir / (ne02*ne01); + const int64_t i02 = (ir - i03*ne03) / ne01; + const int64_t i01 = ir - i03*ne03 - i02*ne02; + + const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01; + const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11; + + for (int64_t i00 = 0; i00 < ne00; ++i00) { + const int32_t val0 = *((const int32_t *) (data0 + i00*nb00)); + const int32_t val1 = *((const int32_t *) (data1 + i00*nb10)); + + sum_thread += val0 == val1; + } + } + if (ith != 0) { + sums[ith] = sum_thread; + } + ggml_barrier(params->threadpool); + + if (ith != 0) { + return; + } + + for (int ith_other = 1; ith_other < nth; ++ith_other) { + sum_thread += sums[ith_other]; + } + *((int64_t *) dst->data) = sum_thread; +} + +static void ggml_compute_forward_count_equal( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_I32: + { + ggml_compute_forward_count_equal_i32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_repeat + +static void ggml_compute_forward_repeat_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(ggml_can_repeat(src0, dst)); + + GGML_TENSOR_UNARY_OP_LOCALS + + // guaranteed to be an integer due to the check in ggml_can_repeat + const int nr0 = (int)(ne0/ne00); + const int nr1 = (int)(ne1/ne01); + const int nr2 = (int)(ne2/ne02); + const int nr3 = (int)(ne3/ne03); + + // TODO: support for transposed / permuted tensors + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + // TODO: maybe this is not optimal? + for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne03; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne02; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne01; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + ggml_vec_cpy_f32(ne00, + (float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0), + (float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01)); + } + } + } + } + } + } + } +} + +static void ggml_compute_forward_repeat_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(ggml_can_repeat(src0, dst)); + + GGML_TENSOR_UNARY_OP_LOCALS + + // guaranteed to be an integer due to the check in ggml_can_repeat + const int nr0 = (int)(ne0/ne00); + const int nr1 = (int)(ne1/ne01); + const int nr2 = (int)(ne2/ne02); + const int nr3 = (int)(ne3/ne03); + + // TODO: support for transposed / permuted tensors + GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + // TODO: maybe this is not optimal? + for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne03; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne02; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne01; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0); + ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01); + // ggml_vec_cpy_f16(ne00, y, x) + for (int i = 0; i < ne00; ++i) { + y[i] = x[i]; + } + } + } + } + } + } + } + } +} + +static void ggml_compute_forward_repeat( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_I16: + { + ggml_compute_forward_repeat_f16(params, dst); + } break; + case GGML_TYPE_F32: + case GGML_TYPE_I32: + { + ggml_compute_forward_repeat_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_repeat_back + +static void ggml_compute_forward_repeat_back_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(ggml_can_repeat(dst, src0)); + + GGML_TENSOR_UNARY_OP_LOCALS + + // guaranteed to be an integer due to the check in ggml_can_repeat + const int nr0 = (int)(ne00/ne0); + const int nr1 = (int)(ne01/ne1); + const int nr2 = (int)(ne02/ne2); + const int nr3 = (int)(ne03/ne3); + + // TODO: support for transposed / permuted tensors + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + if (ggml_is_contiguous(dst)) { + ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); + } else { + for (int k3 = 0; k3 < ne3; k3++) { + for (int k2 = 0; k2 < ne2; k2++) { + for (int k1 = 0; k1 < ne1; k1++) { + ggml_vec_set_f32(ne0, + (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3), + 0); + } + } + } + } + + // TODO: maybe this is not optimal? + for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne3; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne2; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne1; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + ggml_vec_acc_f32(ne0, + (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1), + (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00)); + } + } + } + } + } + } + } +} + +static void ggml_compute_forward_repeat_back( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_repeat_back_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_concat + +static void ggml_compute_forward_concat_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t dim = ggml_get_op_params_i32(dst, 0); + + GGML_ASSERT(dim >= 0 && dim < 4); + + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = src0->ne[dim]; + + const float * x; + + // TODO: smarter multi-theading + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = ith; i2 < ne2; i2 += nth) { + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03); + } else { + x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13); + } + + float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + + *y = *x; + } + } + } + } +} + +static void ggml_compute_forward_concat( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + case GGML_TYPE_I32: + { + ggml_compute_forward_concat_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_abs + +static void ggml_compute_forward_abs_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + ggml_vec_abs_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_abs( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_abs_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_sgn + +static void ggml_compute_forward_sgn_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + ggml_vec_sgn_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sgn( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sgn_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_neg + +static void ggml_compute_forward_neg_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + ggml_vec_neg_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_neg( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_neg_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_step + +static void ggml_compute_forward_step_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + ggml_vec_step_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_step( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_step_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_tanh + +static void ggml_compute_forward_tanh_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + ggml_vec_tanh_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_tanh( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_tanh_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_elu + +static void ggml_compute_forward_elu_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + ggml_vec_elu_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_elu( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_elu_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_relu + +static void ggml_compute_forward_relu_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + ggml_vec_relu_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_relu( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_relu_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_sigmoid + +static void ggml_compute_forward_sigmoid_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + ggml_vec_sigmoid_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_sigmoid( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_sigmoid_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_gelu + +static void ggml_compute_forward_gelu_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_gelu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_gelu( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_gelu_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_gelu_quick + +static void ggml_compute_forward_gelu_quick_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_gelu_quick_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_gelu_quick( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_gelu_quick_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_silu + +static void ggml_compute_forward_silu_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_silu_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_silu( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_silu_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} +// ggml_compute_forward_leaky_relu + +static void ggml_compute_forward_leaky_relu_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + float negative_slope; + memcpy(&negative_slope, dst->op_params, sizeof(float)); + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + ggml_vec_leaky_relu_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope); + } +} + +static void ggml_compute_forward_leaky_relu( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_leaky_relu_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_silu_back + +static void ggml_compute_forward_silu_back_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * grad = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + assert(ggml_is_contiguous_1(grad)); + assert(ggml_is_contiguous_1(src1)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src1, dst)); + assert(ggml_are_same_shape(src1, grad)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1->ne[0]; + const int nr = ggml_nrows(src1); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_silu_backward_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src1->data + i1*(src1->nb[1])), + (float *) ((char *) grad->data + i1*(grad->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_silu_back( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_silu_back_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + + +static void ggml_compute_forward_hardswish_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + ggml_vec_hardswish_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} +static void ggml_compute_forward_hardswish( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_hardswish_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_compute_forward_hardsigmoid_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + ggml_vec_hardsigmoid_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_hardsigmoid( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_hardsigmoid_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_compute_forward_exp_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + ggml_vec_exp_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_exp( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_exp_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + + +// ggml_compute_forward_norm + +static void ggml_compute_forward_norm_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + GGML_ASSERT(eps >= 0.0f); + + // TODO: optimize + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + ggml_float sum = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += (ggml_float)x[i00]; + } + + float mean = sum/ne00; + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + ggml_float sum2 = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + float v = x[i00] - mean; + y[i00] = v; + sum2 += (ggml_float)(v*v); + } + + float variance = sum2/ne00; + const float scale = 1.0f/sqrtf(variance + eps); + + ggml_vec_scale_f32(ne00, y, scale); + } + } + } +} + +static void ggml_compute_forward_norm( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_norm_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_group_rms_norm + +static void ggml_compute_forward_rms_norm_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + GGML_ASSERT(eps >= 0.0f); + + // TODO: optimize + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + ggml_float sum = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += (ggml_float)(x[i00] * x[i00]); + } + + const float mean = sum/ne00; + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + memcpy(y, x, ne00 * sizeof(float)); + // for (int i00 = 0; i00 < ne00; i00++) { + // y[i00] = x[i00]; + // } + + const float scale = 1.0f/sqrtf(mean + eps); + + ggml_vec_scale_f32(ne00, y, scale); + } + } + } +} + +static void ggml_compute_forward_rms_norm( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_rms_norm_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_compute_forward_rms_norm_back_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output + const struct ggml_tensor * src1 = dst->src[1]; // src1 from forward pass + + GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1)); + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_BINARY_OP_LOCALS + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + // TODO: optimize + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + // src1 is same shape as src0 => same indices + const int64_t i11 = i01; + const int64_t i12 = i02; + const int64_t i13 = i03; + + const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + const float * x = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13); + + ggml_float sum_xx = 0.0; + ggml_float sum_xdz = 0.0; + + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum_xx += (ggml_float)(x[i00] * x[i00]); + sum_xdz += (ggml_float)(x[i00] * dz[i00]); + } + + //const float mean = (float)(sum_xx)/ne00; + const float mean_eps = (float)(sum_xx)/ne00 + eps; + const float sum_eps = (float)(sum_xx) + eps*ne00; + //const float mean_xdz = (float)(sum_xdz)/ne00; + // we could cache rms from forward pass to improve performance. + // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms. + //const float rms = sqrtf(mean_eps); + const float rrms = 1.0f / sqrtf(mean_eps); + //const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3) + + { + // z = rms_norm(x) + // + // rms_norm(src1) = + // scale( + // src1, + // div( + // 1, + // sqrt( + // add( + // scale( + // sum( + // sqr( + // src1)), + // (1.0/N)), + // eps)))); + + // postorder: + // ## op args grad + // 00 param src1 grad[#00] + // 01 const 1 + // 02 sqr (#00) grad[#02] + // 03 sum (#02) grad[#03] + // 04 const 1/N + // 05 scale (#03, #04) grad[#05] + // 06 const eps + // 07 add (#05, #06) grad[#07] + // 08 sqrt (#07) grad[#08] + // 09 div (#01,#08) grad[#09] + // 10 scale (#00,#09) grad[#10] + // + // backward pass, given grad[#10] + // #10: scale + // grad[#00] += scale(grad[#10],#09) + // grad[#09] += sum(mul(grad[#10],#00)) + // #09: div + // grad[#08] += neg(mul(grad[#09], div(#09,#08))) + // #08: sqrt + // grad[#07] += mul(grad[#08], div(0.5, #08)) + // #07: add + // grad[#05] += grad[#07] + // #05: scale + // grad[#03] += scale(grad[#05],#04) + // #03: sum + // grad[#02] += repeat(grad[#03], #02) + // #02: + // grad[#00] += scale(mul(#00, grad[#02]), 2.0) + // + // substitute and simplify: + // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0) + // grad[#02] = repeat(grad[#03], #02) + // grad[#02] = repeat(scale(grad[#05],#04), #02) + // grad[#02] = repeat(scale(grad[#07],#04), #02) + // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02) + // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02) + // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02) + // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02) + // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0) + // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0) + // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0) + // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N))) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps)) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps))) + // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps)) + // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps)) + // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps)) + // a = b*c + d*e + // a = b*c*f/f + d*e*f/f + // a = (b*c*f + d*e*f)*(1/f) + // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c)) + // a = (b + d*e/c)*c + // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps) + // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms + // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms + // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms + // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms + // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms + // a = (dz + x*div(-mean_xdz,mean_eps))*rrms + // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms) + // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) + // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) + } + // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) + // post-order: + // dx := x + // dx := scale(dx,-mean_xdz/mean_eps) + // dx := add(dx, dz) + // dx := scale(dx, rrms) + float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps) + ggml_vec_cpy_f32 (ne00, dx, x); + // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps); + ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps); + ggml_vec_acc_f32 (ne00, dx, dz); + ggml_vec_scale_f32(ne00, dx, rrms); + } + } + } +} + +static void ggml_compute_forward_rms_norm_back( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_rms_norm_back_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_group_norm + +static void ggml_compute_forward_group_norm_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + // TODO: optimize + + float eps; + memcpy(&eps, dst->op_params + 1, sizeof(float)); + + int n_channels = src0->ne[2]; + int n_groups = dst->op_params[0]; + int n_channels_per_group = (n_channels + n_groups - 1) / n_groups; + for (int i = ith; i < n_groups; i += nth) { + int start = i * n_channels_per_group; + int end = start + n_channels_per_group; + if (end > n_channels) { + end = n_channels; + } + int step = end - start; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + ggml_float sum = 0.0; + for (int64_t i02 = start; i02 < end; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03); + + ggml_float sumr = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + sumr += (ggml_float)x[i00]; + } + sum += sumr; + } + } + const float mean = sum / (ne00 * ne01 * step); + + ggml_float sum2 = 0.0; + for (int64_t i02 = start; i02 < end; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03); + + float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3); + + ggml_float sumr = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + float v = x[i00] - mean; + y[i00] = v; + sumr += (ggml_float)(v * v); + } + sum2 += sumr; + } + } + const float variance = sum2 / (ne00 * ne01 * step); + const float scale = 1.0f / sqrtf(variance + eps); + + for (int64_t i02 = start; i02 < end; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3); + ggml_vec_scale_f32(ne00, y, scale); + } + } + } + } +} + +static void ggml_compute_forward_group_norm( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_group_norm_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_mul_mat + +static void ggml_compute_forward_mul_mat_one_chunk( + const struct ggml_compute_params * params, + struct ggml_tensor * dst, + const enum ggml_type type, + const int64_t num_rows_per_vec_dot, + const int64_t ir0_start, + const int64_t ir0_end, + const int64_t ir1_start, + const int64_t ir1_end) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + const bool src1_cont = ggml_is_contiguous(src1); + + ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot; + enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type; + + // broadcast factors + const int64_t r2 = ne12 / ne02; + const int64_t r3 = ne13 / ne03; + + //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end); + + // threads with no work simply yield (not sure if it helps) + if (ir0_start >= ir0_end || ir1_start >= ir1_end) { + return; + } + + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t row_size = ggml_row_size(vec_dot_type, ne10); + + assert(ne12 % ne02 == 0); + assert(ne13 % ne03 == 0); + + // block-tiling attempt + const int64_t blck_0 = 16; + const int64_t blck_1 = 16; + + const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11; + + // attempt to reduce false-sharing (does not seem to make a difference) + // 16 * 2, accounting for mmla kernels + float tmp[32]; + + for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { + for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { + for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) { + const int64_t i13 = (ir1 / (ne12 * ne1)); + const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1; + const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); + + // broadcast src0 into src1 + const int64_t i03 = i13 / r3; + const int64_t i02 = i12 / r2; + + const int64_t i1 = i11; + const int64_t i2 = i12; + const int64_t i3 = i13; + + const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03); + + // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides + // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using + // the original src1 data pointer, so we should index using the indices directly + // TODO: this is a bit of a hack, we should probably have a better way to handle this + const char * src1_col = (const char*)wdata + + (src1_cont || src1->type != vec_dot_type + ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size + : (i11 * nb11 + i12 * nb12 + i13 * nb13)); + float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); + + //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) { + // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); + //} + + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { + vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); + } + + for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) { + memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float)); + } + } + } + } +} + +static void ggml_compute_forward_mul_mat( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + enum ggml_type const vec_dot_type = type_traits_cpu[src0->type].vec_dot_type; + ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float; + int64_t const vec_dot_num_rows = type_traits_cpu[src0->type].nrows; + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(src0->type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + + // TODO: extract to "extra_op" +#if GGML_USE_LLAMAFILE + // broadcast factors + const int64_t r2 = ne12 / ne02; + const int64_t r3 = ne13 / ne03; + + const bool src1_cont = ggml_is_contiguous(src1); + + if (src1_cont) { + for (int64_t i13 = 0; i13 < ne13; i13++) + for (int64_t i12 = 0; i12 < ne12; i12++) + if (!llamafile_sgemm(params, + ne01, ne11, ne00/ggml_blck_size(src0->type), + (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, + nb01/ggml_type_size(src0->type), + (const char *)src1->data + i12*nb12 + i13*nb13, + nb11/ggml_type_size(src1->type), + (char *)dst->data + i12*nb2 + i13*nb3, + nb1/ggml_type_size(dst->type), + src0->type, + src1->type, + dst->type)) + goto UseGgmlGemm1; + return; + } +UseGgmlGemm1:; +#endif + + if (src1->type != vec_dot_type) { + char * wdata = params->wdata; + + const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); + const size_t nbw2 = nbw1*ne11; + const size_t nbw3 = nbw2*ne12; + + assert(params->wsize >= ne13*nbw3); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = ith; i11 < ne11; i11 += nth) { + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), + ne10); + } + } + } + } + + if (ith == 0) { + // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. + atomic_store_explicit(¶ms->threadpool->current_chunk, nth, memory_order_relaxed); + } + + ggml_barrier(params->threadpool); + +#if GGML_USE_LLAMAFILE + if (src1->type != vec_dot_type) { + const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t row_size = ggml_row_size(vec_dot_type, ne10); + + for (int64_t i13 = 0; i13 < ne13; i13++) + for (int64_t i12 = 0; i12 < ne12; i12++) + if (!llamafile_sgemm(params, + ne01, ne11, ne00/ggml_blck_size(src0->type), + (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, + nb01/ggml_type_size(src0->type), + (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, + row_size/ggml_type_size(vec_dot_type), + (char *)dst->data + i12*nb2 + i13*nb3, + nb1/ggml_type_size(dst->type), + src0->type, + vec_dot_type, + dst->type)) + goto UseGgmlGemm2; + return; + } +UseGgmlGemm2:; +#endif + + // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers) + const int64_t nr0 = ne0; + + // This is the size of the rest of the dimensions of the result + const int64_t nr1 = ne1 * ne2 * ne3; + + // Now select a reasonable chunk size. + int chunk_size = 16; + + // We need to step up the size if it's small + if (nr0 == 1 || nr1 == 1) { + chunk_size = 64; + } + + // distribute the work across the inner or outer loop based on which one is larger + // The number of chunks in the 0/1 dim. + // CEIL(nr0/chunk_size) + int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size; + int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size; + + // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread. + // Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915 + // In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that. + if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) { + // distribute the thread work across the inner or outer loop based on which one is larger + nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows + nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows + } + + // The number of elements in each chunk + const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; + const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; + + // The first chunk comes from our thread_id, the rest will get auto-assigned. + int current_chunk = ith; + + while (current_chunk < nchunk0 * nchunk1) { + const int64_t ith0 = current_chunk % nchunk0; + const int64_t ith1 = current_chunk / nchunk0; + + const int64_t ir0_start = dr0 * ith0; + const int64_t ir0_end = MIN(ir0_start + dr0, nr0); + + const int64_t ir1_start = dr1 * ith1; + const int64_t ir1_end = MIN(ir1_start + dr1, nr1); + + // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols + int64_t num_rows_per_vec_dot = vec_dot_num_rows; + + // these checks are needed to avoid crossing dim1 boundaries + // can be optimized, but the logic would become more complicated, so keeping it like this for simplicity + if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) { + num_rows_per_vec_dot = 1; + } + + ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end); + + if (nth >= nchunk0 * nchunk1) { + break; + } + + current_chunk = atomic_fetch_add_explicit(¶ms->threadpool->current_chunk, 1, memory_order_relaxed); + } +} + +// ggml_compute_forward_mul_mat_id + +static void ggml_compute_forward_mul_mat_id( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * ids = dst->src[2]; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const enum ggml_type type = src0->type; + + const bool src1_cont = ggml_is_contiguous(src1); + + ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot; + enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type; + ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float; + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + // row groups + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_expert + + char * wdata_src1_end = (src1->type == vec_dot_type) ? + (char *) params->wdata : + (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t)); + + struct mmid_row_mapping { + int32_t i1; + int32_t i2; + }; + + int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] + struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11] + + if (src1->type != vec_dot_type) { + char * wdata = params->wdata; + + const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); + const size_t nbw2 = nbw1*ne11; + const size_t nbw3 = nbw2*ne12; + + assert(params->wsize >= ne13*nbw3); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = ith; i11 < ne11; i11 += nth) { + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), + ne10); + } + } + } + } + +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)] + + if (ith == 0) { + // initialize matrix_row_counts + memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); + + // group rows by src0 matrix + for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { + for (int id = 0; id < n_ids; ++id) { + const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]); + + assert(i02 >= 0 && i02 < n_as); + + MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1}; + matrix_row_counts[i02] += 1; + } + } + } + + ggml_barrier(params->threadpool); + + // compute each matrix multiplication in sequence + for (int cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + + if (cne1 == 0) { + continue; + } + + const char * src0_cur = (const char *) src0->data + cur_a*nb02; + + const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t row_size = ggml_row_size(vec_dot_type, ne10); + + const int64_t nr0 = ne01; // src0 rows + const int64_t nr1 = cne1; // src1 rows + + // distribute the thread work across the inner or outer loop based on which one is larger + + const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows + const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows + + const int64_t ith0 = ith % nth0; + const int64_t ith1 = ith / nth0; + + const int64_t dr0 = (nr0 + nth0 - 1)/nth0; + const int64_t dr1 = (nr1 + nth1 - 1)/nth1; + + const int64_t ir010 = dr0*ith0; + const int64_t ir011 = MIN(ir010 + dr0, nr0); + + const int64_t ir110 = dr1*ith1; + const int64_t ir111 = MIN(ir110 + dr1, nr1); + + // threads with no work simply yield (not sure if it helps) + //if (ir010 >= ir011 || ir110 >= ir111) { + // sched_yield(); + // continue; + //} + + // block-tiling attempt + const int64_t blck_0 = 16; + const int64_t blck_1 = 16; + + // attempt to reduce false-sharing (does not seem to make a difference) + float tmp[16]; + + for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { + for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { + for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { + const int64_t _i12 = ir1; // logical row index for this expert + + struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12); + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides + // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using + // the original src1 data pointer, so we should index using the indices directly + // TODO: this is a bit of a hack, we should probably have a better way to handle this + const char * src1_col = (const char *) wdata + + (src1_cont || src1->type != vec_dot_type + ? (i11 + i12*ne11)*row_size + : (i11*nb11 + i12*nb12)); + + float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2)); + + //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { + // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); + //} + + for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { + vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1); + } + + memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); + } + } + } + } + +#undef MMID_MATRIX_ROW +} + +// ggml_compute_forward_out_prod + +static void ggml_compute_forward_out_prod_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne10); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + GGML_ASSERT(ne2 % ne02 == 0); + GGML_ASSERT(ne3 % ne03 == 0); + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == sizeof(float)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + // GGML_ASSERT(nb0 <= nb1); + // GGML_ASSERT(nb1 <= nb2); + // GGML_ASSERT(nb2 <= nb3); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + + if (ith == 0) { + ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); + } + ggml_barrier(params->threadpool); + + // dst[:,:,:,:] = 0 + // for i2,i3: + // for i1: + // for i01: + // for i0: + // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] + + // parallelize by last three dimensions + + // total rows in dst + const int64_t nr = ne1*ne2*ne3; + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + // block-tiling attempt + const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32); + const int64_t blck_1 = 16; + + // dps == dst per src0, used for group query attention + const int64_t dps2 = ne2 / ne02; + const int64_t dps3 = ne3 / ne03; + + for (int64_t bir = ir0; bir < ir1; bir += blck_1) { + const int64_t bir1 = MIN(bir + blck_1, ir1); + for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) { + const int64_t bne01 = MIN(bi01 + blck_0, ne01); + for (int64_t ir = bir; ir < bir1; ++ir) { + // dst indices + const int64_t i3 = ir/(ne2*ne1); + const int64_t i2 = (ir - i3*ne2*ne1)/ne1; + const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); + + const int64_t i02 = i2 / dps2; + const int64_t i03 = i3 / dps3; + + //const int64_t i10 = i1; + const int64_t i12 = i2; + const int64_t i13 = i3; + +#if GGML_VEC_MAD_UNROLL > 2 + const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL); + for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1); + } + for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + ggml_vec_mad_f32(ne0, d, s0, *s1); + } +#else + for (int64_t i01 = bi01; i01 < bne01; ++i01) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + ggml_vec_mad_f32(ne0, d, s0, *s1); + } +#endif + } + } + } +} + +static void ggml_compute_forward_out_prod_q_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int ith = params->ith; + const int nth = params->nth; + + const enum ggml_type type = src0->type; + ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float; + + GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(ne03 == ne13); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 dim0 + GGML_ASSERT(nb00 == ggml_type_size(type)); + + // dst dim0 cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + // GGML_ASSERT(nb0 <= nb1); + // GGML_ASSERT(nb1 <= nb2); + // GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne10); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + + if (ith == 0) { + ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); + } + ggml_barrier(params->threadpool); + + // parallelize by last three dimensions + + // total rows in dst + const int64_t nr = ne1*ne2*ne3; + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + // dst[:,:,:,:] = 0 + // for i2,i3: + // for i1: + // for i01: + // for i0: + // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] + + float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith; + + for (int64_t ir = ir0; ir < ir1; ++ir) { + // dst indices + const int64_t i3 = ir/(ne2*ne1); + const int64_t i2 = (ir - i3*ne2*ne1)/ne1; + const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); + + const int64_t i02 = i2; + const int64_t i03 = i3; + + //const int64_t i10 = i1; + const int64_t i12 = i2; + const int64_t i13 = i3; + + for (int64_t i01 = 0; i01 < ne01; ++i01) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + dequantize_row_q(s0, wdata, ne0); + ggml_vec_mad_f32(ne0, d, wdata, *s1); + } + } +} + +static void ggml_compute_forward_out_prod( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ2_S: + { + ggml_compute_forward_out_prod_q_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + GGML_ABORT("fatal error"); // todo + // ggml_compute_forward_out_prod_f16_f32(params, dst); + } + case GGML_TYPE_F32: + { + ggml_compute_forward_out_prod_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_scale + +static void ggml_compute_forward_scale_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + // scale factor + float v; + memcpy(&v, dst->op_params, sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const size_t nb01 = src0->nb[1]; + + const size_t nb1 = dst->nb[1]; + + for (int i1 = ir0; i1 < ir1; i1++) { + if (dst->data != src0->data) { + // src0 is same shape as dst => same indices + memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); + } + ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v); + } +} + +static void ggml_compute_forward_scale( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_scale_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_set + +static void ggml_compute_forward_set_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); + + // view src0 and dst with these strides and data offset inbytes during set + // nb0 is implicitly element_size because src0 and dst are contiguous + size_t nb1 = ((int32_t *) dst->op_params)[0]; + size_t nb2 = ((int32_t *) dst->op_params)[1]; + size_t nb3 = ((int32_t *) dst->op_params)[2]; + size_t offset = ((int32_t *) dst->op_params)[3]; + bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace) { + if (params->ith == 0) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + memcpy( + ((char *) dst->data), + ((char *) src0->data), + ggml_nbytes(dst)); + } + ggml_barrier(params->threadpool); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src1); + const int nc = src1->ne[0]; + + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) + + // src0 and dst as viewed during set + const size_t nb0 = ggml_element_size(src0); + + const int im0 = (ne10 == 0 ? 0 : ne10-1); + const int im1 = (ne11 == 0 ? 0 : ne11-1); + const int im2 = (ne12 == 0 ? 0 : ne12-1); + const int im3 = (ne13 == 0 ? 0 : ne13-1); + + GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst)); + + GGML_ASSERT(nb10 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are viewed with shape of src1 and offset + // => same indices + const int i3 = ir/(ne12*ne11); + const int i2 = (ir - i3*ne12*ne11)/ne11; + const int i1 = (ir - i3*ne12*ne11 - i2*ne11); + + ggml_vec_cpy_f32(nc, + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), + (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); + } +} + +static void ggml_compute_forward_set_i32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); + + // view src0 and dst with these strides and data offset inbytes during set + // nb0 is implicitly element_size because src0 and dst are contiguous + size_t nb1 = ((int32_t *) dst->op_params)[0]; + size_t nb2 = ((int32_t *) dst->op_params)[1]; + size_t nb3 = ((int32_t *) dst->op_params)[2]; + size_t offset = ((int32_t *) dst->op_params)[3]; + bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace) { + if (params->ith == 0) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + memcpy( + ((char *) dst->data), + ((char *) src0->data), + ggml_nbytes(dst)); + } + ggml_barrier(params->threadpool); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src1); + const int nc = src1->ne[0]; + + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) + + // src0 and dst as viewed during set + const size_t nb0 = ggml_element_size(src0); + + const int im0 = (ne10 == 0 ? 0 : ne10-1); + const int im1 = (ne11 == 0 ? 0 : ne11-1); + const int im2 = (ne12 == 0 ? 0 : ne12-1); + const int im3 = (ne13 == 0 ? 0 : ne13-1); + + GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst)); + + GGML_ASSERT(nb10 == sizeof(int32_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are viewed with shape of src1 and offset + // => same indices + const int i3 = ir/(ne12*ne11); + const int i2 = (ir - i3*ne12*ne11)/ne11; + const int i1 = (ir - i3*ne12*ne11 - i2*ne11); + + ggml_vec_cpy_i32(nc, + (int32_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), + (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); + } +} + +static void ggml_compute_forward_set( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_set_f32(params, dst); + } break; + case GGML_TYPE_I32: + { + ggml_compute_forward_set_i32(params, dst); + } break; + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ2_S: + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_cpy + +static void ggml_compute_forward_cpy( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + ggml_compute_forward_dup(params, dst); +} + +// ggml_compute_forward_cont + +static void ggml_compute_forward_cont( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + ggml_compute_forward_dup(params, dst); +} + +// ggml_compute_forward_reshape + +static void ggml_compute_forward_reshape( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + // NOP + UNUSED(params); + UNUSED(dst); +} + +// ggml_compute_forward_view + +static void ggml_compute_forward_view( + const struct ggml_compute_params * params, + const struct ggml_tensor * dst) { + // NOP + UNUSED(params); + UNUSED(dst); +} + +// ggml_compute_forward_permute + +static void ggml_compute_forward_permute( + const struct ggml_compute_params * params, + const struct ggml_tensor * dst) { + // NOP + UNUSED(params); + UNUSED(dst); +} + +// ggml_compute_forward_transpose + +static void ggml_compute_forward_transpose( + const struct ggml_compute_params * params, + const struct ggml_tensor * dst) { + // NOP + UNUSED(params); + UNUSED(dst); +} + +// ggml_compute_forward_get_rows + +static void ggml_compute_forward_get_rows_q( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); + + const enum ggml_type type = src0->type; + ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float; + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == ggml_type_size(type)); + assert(ggml_nrows(dst) == nr); + + const int ith = params->ith; + const int nth = params->nth; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i/(ne11*ne10); + const int64_t i11 = (i - i12*ne11*ne10)/ne10; + const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + GGML_ASSERT(i01 >= 0 && i01 < ne01); + + dequantize_row_q( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + } +} + +static void ggml_compute_forward_get_rows_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(ggml_fp16_t)); + assert(ggml_nrows(dst) == nr); + + const int ith = params->ith; + const int nth = params->nth; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i/(ne11*ne10); + const int64_t i11 = (i - i12*ne11*ne10)/ne10; + const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + GGML_ASSERT(i01 >= 0 && i01 < ne01); + + ggml_fp16_to_fp32_row( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + } +} + +static void ggml_compute_forward_get_rows_bf16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(ggml_bf16_t)); + assert(ggml_nrows(dst) == nr); + + const int ith = params->ith; + const int nth = params->nth; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i/(ne11*ne10); + const int64_t i11 = (i - i12*ne11*ne10)/ne10; + const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + GGML_ASSERT(i01 >= 0 && i01 < ne01); + + ggml_bf16_to_fp32_row( + (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); + } +} + +static void ggml_compute_forward_get_rows_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(float)); + assert(ggml_nrows(dst) == nr); + + const int ith = params->ith; + const int nth = params->nth; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i/(ne11*ne10); + const int64_t i11 = (i - i12*ne11*ne10)/ne10; + const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + GGML_ASSERT(i01 >= 0 && i01 < ne01); + + ggml_vec_cpy_f32(nc, + (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), + (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03)); + } +} + +static void ggml_compute_forward_get_rows( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ2_S: + { + ggml_compute_forward_get_rows_q(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_get_rows_f16(params, dst); + } break; + case GGML_TYPE_BF16: + { + ggml_compute_forward_get_rows_bf16(params, dst); + } break; + case GGML_TYPE_F32: + case GGML_TYPE_I32: + { + ggml_compute_forward_get_rows_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } + + //static bool first = true; + //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); + //if (first) { + // first = false; + //} else { + // for (int k = 0; k < dst->ne[1]; ++k) { + // for (int j = 0; j < dst->ne[0]/16; ++j) { + // for (int i = 0; i < 16; ++i) { + // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); + // } + // printf("\n"); + // } + // printf("\n"); + // } + // printf("\n"); + // exit(0); + //} +} + +// ggml_compute_forward_get_rows_back + +static void ggml_compute_forward_get_rows_back_f32_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(ggml_is_contiguous(dst)); + + // ggml_compute_forward_dup_same_cont(params, opt0, dst); + + memset(dst->data, 0, ggml_nbytes(dst)); + + const int nc = src0->ne[0]; + const int nr = ggml_nelements(src1); + + GGML_ASSERT( dst->ne[0] == nc); + GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + for (int j = 0; j < nc; ++j) { + ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j]; + ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_FP16_TO_FP32(v); + } + } +} + +static void ggml_compute_forward_get_rows_back_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + if (params->ith != 0) { + return; + } + + GGML_ASSERT(ggml_is_contiguous(dst)); + + // ggml_compute_forward_dup_same_cont(params, opt0, dst); + + memset(dst->data, 0, ggml_nbytes(dst)); + + const int nc = src0->ne[0]; + const int nr = ggml_nelements(src1); + + GGML_ASSERT( dst->ne[0] == nc); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < nr; ++i) { + const int r = ((int32_t *) src1->data)[i]; + + ggml_vec_add_f32(nc, + (float *) ((char *) dst->data + r*dst->nb[1]), + (float *) ((char *) dst->data + r*dst->nb[1]), + (float *) ((char *) src0->data + i*src0->nb[1])); + } +} + +static void ggml_compute_forward_get_rows_back( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_get_rows_back_f32_f16(params, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_get_rows_back_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } + + //static bool first = true; + //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); + //if (first) { + // first = false; + //} else { + // for (int k = 0; k < dst->ne[1]; ++k) { + // for (int j = 0; j < dst->ne[0]/16; ++j) { + // for (int i = 0; i < 16; ++i) { + // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); + // } + // printf("\n"); + // } + // printf("\n"); + // } + // printf("\n"); + // exit(0); + //} +} + +// ggml_compute_forward_diag + +static void ggml_compute_forward_diag_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + // TODO: handle transposed/permuted matrices + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne00 == ne0); + GGML_ASSERT(ne00 == ne1); + GGML_ASSERT(ne01 == 1); + GGML_ASSERT(ne02 == ne2); + GGML_ASSERT(ne03 == ne3); + + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb0 == sizeof(float)); + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + for (int i1 = 0; i1 < ne1; i1++) { + float * d = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02); + for (int i0 = 0; i0 < i1; i0++) { + d[i0] = 0; + } + d[i1] = s[i1]; + for (int i0 = i1+1; i0 < ne0; i0++) { + d[i0] = 0; + } + } + } + } +} + +static void ggml_compute_forward_diag( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_diag_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_diag_mask_inf + +static void ggml_compute_forward_diag_mask_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst, + const float value) { + + const struct ggml_tensor * src0 = dst->src[0]; + + const int ith = params->ith; + const int nth = params->nth; + + const int n_past = ((int32_t *) dst->op_params)[0]; + const bool inplace = src0->data == dst->data; + + GGML_ASSERT(n_past >= 0); + + if (!inplace) { + if (ith == 0) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); + memcpy( + ((char *) dst->data), + ((char *) src0->data), + ggml_nbytes(dst)); + } + ggml_barrier(params->threadpool); + } + + // TODO: handle transposed/permuted matrices + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + const int nr = src0->ne[1]; + const int nz = n/nr; + + GGML_ASSERT( dst->nb[0] == sizeof(float)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + for (int k = 0; k < nz; k++) { + for (int j = ith; j < nr; j += nth) { + for (int i = n_past; i < nc; i++) { + if (i > n_past + j) { + *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value; + } + } + } + } +} + +static void ggml_compute_forward_diag_mask_inf( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_compute_forward_diag_mask_zero( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_diag_mask_f32(params, dst, 0); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_soft_max + +static void ggml_compute_forward_soft_max_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + assert(ggml_is_contiguous(dst)); + assert(ggml_are_same_shape(src0, dst)); + + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + + // TODO: handle transposed/permuted matrices + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + //const int64_t ne11 = src1 ? src1->ne[1] : 1; + + // TODO: is this supposed to be ceil instead of floor? + // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 + const uint32_t n_head = ne02; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + + for (int i1 = ir0; i1 < ir1; i1++) { + // ALiBi + const uint32_t h = (i1/ne01)%ne02; // head + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + + float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); + float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); + + // broadcast the mask across rows + ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; + + ggml_vec_cpy_f32 (nc, wp, sp); + ggml_vec_scale_f32(nc, wp, scale); + if (mp_f32) { + if (use_f16) { + for (int i = 0; i < nc; ++i) { + wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); + } + } else { + for (int i = 0; i < nc; ++i) { + wp[i] += slope*mp_f32[i]; + } + } + } + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(wp[i])); + } +#endif + + float max = -INFINITY; + ggml_vec_max_f32(nc, &max, wp); + + ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max); + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(nc, dp, sum); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(dp[i])); + assert(!isinf(dp[i])); + } +#endif + } +} + +static void ggml_compute_forward_soft_max( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_soft_max_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + + +// ggml_compute_forward_soft_max_ext_back + +static void ggml_compute_forward_soft_max_ext_back_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_are_same_shape(src1, dst)); + + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + + GGML_ASSERT(max_bias == 0.0f); + + // TODO: handle transposed/permuted matrices + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float *dy = (float *)((char *) src0->data + i1*src0->nb[1]); + float *y = (float *)((char *) src1->data + i1*src1->nb[1]); + float *dx = (float *)((char *) dst->data + i1*dst->nb[1]); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(dy[i])); + assert(!isnan(y[i])); + } +#endif + // Jii = yi - yi*yi + // Jij = -yi*yj + // J = diag(y)-y.T*y + // dx = J * dy + // dxk = sum_i(Jki * dyi) + // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk + // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk + // dxk = sum_i(-yk*yi * dyi) + yk*dyk + // dxk = -yk * sum_i(yi * dyi) + yk*dyk + // dxk = -yk * dot(y, dy) + yk*dyk + // dxk = yk * (- dot(y, dy) + dyk) + // dxk = yk * (dyk - dot(y, dy)) + // + // post-order: + // dot_y_dy := dot(y, dy) + // dx := dy + // dx := dx - dot_y_dy + // dx := dx * y + + // linear runtime, no additional memory + float dot_y_dy = 0; + ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1); + ggml_vec_cpy_f32 (nc, dx, dy); + ggml_vec_acc1_f32 (nc, dx, -dot_y_dy); + ggml_vec_mul_f32 (nc, dx, dx, y); + ggml_vec_scale_f32(nc, dx, scale); + +#ifndef NDEBUG + for (int i = 0; i < nc; ++i) { + assert(!isnan(dx[i])); + assert(!isinf(dx[i])); + } +#endif + } +} + +static void ggml_compute_forward_soft_max_ext_back( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_soft_max_ext_back_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_clamp + +static void ggml_compute_forward_clamp_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + float min; + float max; + memcpy(&min, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + const size_t nb00 = src0->nb[0]; + const size_t nb01 = src0->nb[1]; + + const size_t nb0 = dst->nb[0]; + const size_t nb1 = dst->nb[1]; + + GGML_ASSERT( nb0 == sizeof(float)); + GGML_ASSERT(nb00 == sizeof(float)); + + for (int j = ith; j < n; j += nth) { + float * dst_ptr = (float *) ((char *) dst->data + j*nb1); + float * src0_ptr = (float *) ((char *) src0->data + j*nb01); + + for (int i = 0; i < nc; i++) { + dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min); + } + } +} + +static void ggml_compute_forward_clamp( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_clamp_f32(params, dst); + } break; + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_Q8_K: + case GGML_TYPE_I8: + case GGML_TYPE_I16: + case GGML_TYPE_I32: + case GGML_TYPE_I64: + case GGML_TYPE_F64: + case GGML_TYPE_COUNT: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_rope + +static float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / MAX(0.001f, high - low); + return 1 - MIN(1, MAX(0, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +static void rope_yarn( + float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, + float * cos_theta, float * sin_theta) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); + } + *cos_theta = cosf(theta) * mscale; + *sin_theta = sinf(theta) * mscale; +} + +static void ggml_rope_cache_init( + float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, + float * cache, float sin_sign, float theta_scale) { + // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py + float theta = theta_base; + for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0/2] : 1.0f; + rope_yarn( + theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] + ); + cache[i0 + 1] *= sin_sign; + + theta *= theta_scale; + } +} + +static void ggml_mrope_cache_init( + float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects, + float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, + float * cache, float sin_sign, float theta_scale) { + // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py + float theta_t = theta_base_t; + float theta_h = theta_base_h; + float theta_w = theta_base_w; + float theta_e = theta_base_e; // extra position id for vision encoder + int sect_dims = sections[0] + sections[1] + sections[2] + sections[3]; + int sec_w = sections[1] + sections[0]; + int sec_e = sections[2] + sec_w; + GGML_ASSERT(sect_dims <= ne0); + + for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0/2] : 1.0f; + + int sector = (i0 / 2) % sect_dims; + if (indep_sects) { + // compute theta independently for each dim sections + // (i.e. reset corresponding theta when `i0` go from one section to another) + if (sector == 0) { + theta_t = theta_base_t; + } + else if (sector == sections[0]) { + theta_h = theta_base_h;; + } + else if (sector == sec_w) { + theta_w = theta_base_w; + } + else if (sector == sec_e) { + theta_e = theta_base_e; + } + } + + float theta = theta_t; + if (sector >= sections[0] && sector < sec_w) { + theta = theta_h; + } + else if (sector >= sec_w && sector < sec_w + sections[2]) { + theta = theta_w; + } + else if (sector >= sec_w + sections[2]) { + theta = theta_e; + } + + rope_yarn( + theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] + ); + cache[i0 + 1] *= sin_sign; + + theta_t *= theta_scale; + theta_w *= theta_scale; + theta_h *= theta_scale; + theta_e *= theta_scale; + } +} + +static void ggml_compute_forward_rope_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst, + const bool forward) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src2 = dst->src[2]; + + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + int sections[4]; + + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + //const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4); + + GGML_TENSOR_UNARY_OP_LOCALS + + //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); + //printf("n_past = %d, ne2 = %d\n", n_past, ne2); + + GGML_ASSERT(nb00 == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(dst); + + GGML_ASSERT(n_dims <= ne0); + GGML_ASSERT(n_dims % 2 == 0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + // row index used to determine which thread to use + int ir = 0; + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_mrope) { + GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0); + } + + if (is_vision) { + GGML_ASSERT(n_dims == ne0/2); + } + + const float * freq_factors = NULL; + if (src2 != NULL) { + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float *) src2->data; + } + + // backward process uses inverse rotation by cos and sin. + // cos and sin build a rotation matrix, where the inverse is the transpose. + // this essentially just switches the sign of sin. + const float sin_sign = forward ? 1.0f : -1.0f; + + const int32_t * pos = (const int32_t *) src1->data; + + for (int64_t i3 = 0; i3 < ne3; i3++) { // batch + for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len + + float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; + if (!is_mrope) { + const int64_t p = pos[i2]; + ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } + else { + const int64_t p_t = pos[i2]; + const int64_t p_h = pos[i2 + ne2]; + const int64_t p_w = pos[i2 + ne2 * 2]; + const int64_t p_e = pos[i2 + ne2 * 3]; + ggml_mrope_cache_init( + p_t, p_h, p_w, p_e, sections, is_vision, + freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } + + for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads + if (ir++ < ir0) continue; + if (ir > ir1) break; + + if (is_neox || is_mrope) { + if (is_vision){ + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = i0/2; + + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims] = x0*sin_theta + x1*cos_theta; + } + } else { + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = i0/2; + + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } + } + } else { + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[1]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; + } + } + + if (is_vision) { + for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { + const int64_t ic = i0/2; + + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims] = x0*sin_theta + x1*cos_theta; + } + } else { + // fill the remain channels with data from src tensor + for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } + } + } + } +} + +// TODO: deduplicate f16/f32 code +static void ggml_compute_forward_rope_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst, + const bool forward) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src2 = dst->src[2]; + + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + int sections[4]; + + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + //const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4); + + + GGML_TENSOR_UNARY_OP_LOCALS + + //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); + //printf("n_past = %d, ne2 = %d\n", n_past, ne2); + + GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(dst); + + GGML_ASSERT(n_dims <= ne0); + GGML_ASSERT(n_dims % 2 == 0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + // row index used to determine which thread to use + int ir = 0; + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_mrope) { + GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0); + } + + if (is_vision) { + GGML_ASSERT(n_dims == ne0/2); + } + + const float * freq_factors = NULL; + if (src2 != NULL) { + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float *) src2->data; + } + + // backward process uses inverse rotation by cos and sin. + // cos and sin build a rotation matrix, where the inverse is the transpose. + // this essentially just switches the sign of sin. + const float sin_sign = forward ? 1.0f : -1.0f; + + const int32_t * pos = (const int32_t *) src1->data; + + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = 0; i2 < ne2; i2++) { + + float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; + if (!is_mrope) { + const int64_t p = pos[i2]; + ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } + else { + const int64_t p_t = pos[i2]; + const int64_t p_h = pos[i2 + ne2]; + const int64_t p_w = pos[i2 + ne2 * 2]; + const int64_t p_e = pos[i2 + ne2 * 3]; + ggml_mrope_cache_init( + p_t, p_h, p_w, p_e, sections, is_vision, + freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } + + for (int64_t i1 = 0; i1 < ne1; i1++) { + if (ir++ < ir0) continue; + if (ir > ir1) break; + + if (is_neox || is_mrope) { + if (is_vision) { + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = i0/2; + + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; + + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = GGML_FP16_TO_FP32(src[0]); + const float x1 = GGML_FP16_TO_FP32(src[n_dims]); + + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } + } else { + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = i0/2; + + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; + + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = GGML_FP16_TO_FP32(src[0]); + const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); + + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } + } + } else { + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; + + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = GGML_FP16_TO_FP32(src[0]); + const float x1 = GGML_FP16_TO_FP32(src[1]); + + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } + } + + if (is_vision) { + for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { + const int64_t ic = i0/2; + + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; + + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = GGML_FP16_TO_FP32(src[0]); + const float x1 = GGML_FP16_TO_FP32(src[n_dims]); + + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } + } else { + for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } + } + } + } +} + +static void ggml_compute_forward_rope( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_rope_f16(params, dst, true); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_rope_f32(params, dst, true); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_rope_back + +static void ggml_compute_forward_rope_back( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_rope_f16(params, dst, false); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_rope_f32(params, dst, false); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_conv_transpose_1d + +static void ggml_compute_forward_conv_transpose_1d_f16_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00*ne01*ne02; + + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (ith == 0) { + memset(params->wdata, 0, params->wsize); + + // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); + ggml_fp16_t * dst_data = wdata + i01*ne00*ne02; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ne02 + i02] = src[i00]; + } + } + } + } + + // permute source data (src1) from (L x Cin) to (Cin x L) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk; + ggml_fp16_t * dst_data = wdata; + + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + for (int64_t i10 = 0; i10 < ne10; i10++) { + dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]); + } + } + } + + // need to zero dst since we are accumulating into it + memset(dst->data, 0, ggml_nbytes(dst)); + } + ggml_barrier(params->threadpool); + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + + // total rows in dst + const int nr = ne1; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + ggml_fp16_t * const wdata_src = wdata + nk; + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00; + for (int i10 = 0; i10 < ne10; i10++) { + const int i1n = i10*ne11; + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + ggml_vec_dot_f16(ne02, &v, 0, + (ggml_fp16_t *) wdata_src + i1n, 0, + (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1); + dst_data[i10*s0 + i00] += v; + } + } + } +} + +static void ggml_compute_forward_conv_transpose_1d_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00*ne01*ne02; + + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (ith == 0) { + memset(params->wdata, 0, params->wsize); + + // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) + { + float * const wdata = (float *) params->wdata + 0; + + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); + float * dst_data = wdata + i01*ne00*ne02; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i00*ne02 + i02] = src[i00]; + } + } + } + } + + // prepare source data (src1) + { + float * const wdata = (float *) params->wdata + nk; + float * dst_data = wdata; + + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + for (int64_t i10 = 0; i10 < ne10; i10++) { + dst_data[i10*ne11 + i11] = src[i10]; + } + } + } + + // need to zero dst since we are accumulating into it + memset(dst->data, 0, ggml_nbytes(dst)); + } + ggml_barrier(params->threadpool); + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + + // total rows in dst + const int nr = ne1; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float * const wdata = (float *) params->wdata + 0; + float * const wdata_src = wdata + nk; + + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i1*nb1); + float * wdata_kernel = wdata + i1*ne02*ne00; + for (int i10 = 0; i10 < ne10; i10++) { + const int i1n = i10*ne11; + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + ggml_vec_dot_f32(ne02, &v, 0, + wdata_src + i1n, 0, + wdata_kernel + i00*ne02, 0, 1); + dst_data[i10*s0 + i00] += v; + } + } + } +} + +static void ggml_compute_forward_conv_transpose_1d( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_conv_transpose_1d_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_im2col_f32 +// src0: kernel [OC, IC, KH, KW] +// src1: image [N, IC, IH, IW] +// dst: result [N, OH, OW, IC*KH*KW] +static void ggml_compute_forward_im2col_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t N = is_2D ? ne13 : ne12; + const int64_t IC = is_2D ? ne12 : ne11; + const int64_t IH = is_2D ? ne11 : 1; + const int64_t IW = ne10; + + const int64_t KH = is_2D ? ne01 : 1; + const int64_t KW = ne00; + + const int64_t OH = is_2D ? ne2 : 1; + const int64_t OW = ne1; + + int ofs0 = is_2D ? nb13 : nb12; + int ofs1 = is_2D ? nb12 : nb11; + + GGML_ASSERT(nb10 == sizeof(float)); + + // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] + { + float * const wdata = (float *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t ioh = 0; ioh < OH; ioh++) { // 1 + for (int64_t iow = 0; iow < OW; iow++) { + for (int64_t iic = ith; iic < IC; iic += nth) { + + // micro kernel + float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] + const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW] + + for (int64_t ikh = 0; ikh < KH; ikh++) { // 1 + for (int64_t ikw = 0; ikw < KW; ikw++) { + const int64_t iiw = iow*s0 + ikw*d0 - p0; + const int64_t iih = ioh*s1 + ikh*d1 - p1; + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0; + } else { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]); + } + } + } + } + } + } + } + } +} + + +// ggml_compute_forward_im2col_f16 +// src0: kernel [OC, IC, KH, KW] +// src1: image [N, IC, IH, IW] +// dst: result [N, OH, OW, IC*KH*KW] +static void ggml_compute_forward_im2col_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t N = is_2D ? ne13 : ne12; + const int64_t IC = is_2D ? ne12 : ne11; + const int64_t IH = is_2D ? ne11 : 1; + const int64_t IW = ne10; + + const int64_t KH = is_2D ? ne01 : 1; + const int64_t KW = ne00; + + const int64_t OH = is_2D ? ne2 : 1; + const int64_t OW = ne1; + + int ofs0 = is_2D ? nb13 : nb12; + int ofs1 = is_2D ? nb12 : nb11; + + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t ioh = 0; ioh < OH; ioh++) { // 1 + for (int64_t iow = 0; iow < OW; iow++) { + for (int64_t iic = ith; iic < IC; iic += nth) { + + // micro kernel + ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] + const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW] + + for (int64_t ikh = 0; ikh < KH; ikh++) { // 1 + for (int64_t ikw = 0; ikw < KW; ikw++) { + const int64_t iiw = iow*s0 + ikw*d0 - p0; + const int64_t iih = ioh*s1 + ikh*d1 - p1; + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0; + } else { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]); + } + } + } + } + } + } + } + } +} + +static void ggml_compute_forward_im2col( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + switch (dst->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_im2col_f16(params, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_im2col_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_im2col_back_f32 + +static void ggml_compute_forward_im2col_back_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output + const struct ggml_tensor * src1 = dst->src[1]; // convolution kernel + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t N = is_2D ? ne3 : ne2; + const int64_t IC = is_2D ? ne2 : ne1; + const int64_t IH = is_2D ? ne1 : 1; + const int64_t IW = ne0; + + const int64_t KH = is_2D ? ne11 : 1; + const int64_t KW = ne10; + + const int64_t OH = is_2D ? ne02 : 1; + const int64_t OW = ne01; + + int ofs0 = is_2D ? nb3 : nb2; + int ofs1 = is_2D ? nb2 : nb1; + + GGML_ASSERT(nb0 == sizeof(float)); + + // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] + { + float * const wdata = (float *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t iic = ith; iic < IC; iic += nth) { + for (int64_t iih = 0; iih < IH; iih++) { + for (int64_t iiw = 0; iiw < IW; iiw++) { + + // micro kernel + float grad = 0.0f; + for (int64_t ikh = 0; ikh < KH; ikh++) { + for (int64_t ikw = 0; ikw < KW; ikw++) { + // For s0 > 1 some values were skipped over in the forward pass. + // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well. + const int64_t tmpw = (iiw + p0 - ikw*d0); + if (tmpw % s0 != 0) { + continue; + } + const int64_t iow = tmpw / s0; + + // Equivalent logic as above except for s1. + int64_t ioh; + if (is_2D) { + const int64_t tmph = iih + p1 - ikh*d1; + + if (tmph % s1 != 0) { + continue; + } + + ioh = tmph / s1; + } else { + ioh = 0; + } + + if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) { + continue; + } + + const float * const grad_in = (const float *) src0->data + + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] + grad += grad_in[iic*(KH*KW) + ikh*KW + ikw]; + } + } + float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW] + dst_data[iih*IW + iiw] = grad; + } + } + } + } + } +} + +// ggml_compute_forward_conv_transpose_2d + +static void ggml_compute_forward_conv_transpose_2d( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const int nk = ne00*ne01*ne02*ne03; + + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + + if (ith == 0) { + memset(params->wdata, 0, params->wsize); + + // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02); + ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03; + for (int64_t i01 = 0; i01 < ne01; i01++) { + for (int64_t i00 = 0; i00 < ne00; i00++) { + dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00]; + } + } + } + } + } + + // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh) + { + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk; + for (int i12 = 0; i12 < ne12; i12++) { + for (int i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11); + ggml_fp16_t * dst_data = wdata + i11*ne10*ne12; + for (int i10 = 0; i10 < ne10; i10++) { + dst_data[i10*ne12 + i12] = GGML_FP32_TO_FP16(src[i10]); + } + } + } + } + + memset(dst->data, 0, ggml_nbytes(dst)); + } + ggml_barrier(params->threadpool); + + const int32_t stride = ggml_get_op_params_i32(dst, 0); + + // total patches in dst + const int np = ne2; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + const int ip0 = dp*ith; + const int ip1 = MIN(ip0 + dp, np); + + ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + ggml_fp16_t * const wdata_src = wdata + nk; + + for (int i2 = ip0; i2 < ip1; i2++) { // Cout + float * dst_data = (float *)((char *) dst->data + i2*nb2); + ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03; + for (int i11 = 0; i11 < ne11; i11++) { + for (int i10 = 0; i10 < ne10; i10++) { + const int i1n = i11*ne10*ne12 + i10*ne12; + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + ggml_vec_dot_f16(ne03, &v, 0, + wdata_src + i1n, 0, + wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1); + dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v; + } + } + } + } + } +} + +// ggml_compute_forward_pool_1d_sk_p0 + +static void ggml_compute_forward_pool_1d_sk_p0( + const struct ggml_compute_params * params, + const enum ggml_op_pool op, + const int k, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src = dst->src[0]; + + assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16); + + if (params->ith != 0) { + return; + } + + const char * cdata = (const char *)src->data; + const char * const data_end = cdata + ggml_nbytes(src); + float * drow = (float *)dst->data; + + const int64_t rs = dst->ne[0]; + + while (cdata < data_end) { + const void * srow = (const void *)cdata; + int j = 0; + for (int64_t i = 0; i < rs; ++i) { + switch (op) { + case GGML_OP_POOL_AVG: drow[i] = 0; break; + case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break; + case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); + } + for (int ki = 0; ki < k; ++ki) { + const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]); + switch (op) { + case GGML_OP_POOL_AVG: drow[i] += srow_j; break; + case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break; + case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); + } + ++j; + } + switch (op) { + case GGML_OP_POOL_AVG: drow[i] /= k; break; + case GGML_OP_POOL_MAX: break; + case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); + } + } + + cdata += src->nb[1]; + drow += rs; + } +} + +// ggml_compute_forward_pool_1d + +static void ggml_compute_forward_pool_1d( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const int32_t * opts = (const int32_t *)dst->op_params; + enum ggml_op_pool op = opts[0]; + const int k0 = opts[1]; + const int s0 = opts[2]; + const int p0 = opts[3]; + GGML_ASSERT(p0 == 0); // padding not supported + GGML_ASSERT(k0 == s0); // only s = k supported + + ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst); +} + +// ggml_compute_forward_pool_2d + +static void ggml_compute_forward_pool_2d( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src = dst->src[0]; + + assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16); + + if (params->ith != 0) { + return; + } + + const int32_t * opts = (const int32_t *)dst->op_params; + enum ggml_op_pool op = opts[0]; + const int k0 = opts[1]; + const int k1 = opts[2]; + const int s0 = opts[3]; + const int s1 = opts[4]; + const int p0 = opts[5]; + const int p1 = opts[6]; + const char * cdata = (const char*)src->data; + const char * const data_end = cdata + ggml_nbytes(src); + + const int64_t px = dst->ne[0]; + const int64_t py = dst->ne[1]; + const int64_t pa = px * py; + + float * dplane = (float *)dst->data; + + const int ka = k0 * k1; + const int offset0 = -p0; + const int offset1 = -p1; + + while (cdata < data_end) { + for (int oy = 0; oy < py; ++oy) { + float * const drow = dplane + oy * px; + for (int ox = 0; ox < px; ++ox) { + float * const out = drow + ox; + switch (op) { + case GGML_OP_POOL_AVG: *out = 0; break; + case GGML_OP_POOL_MAX: *out = -FLT_MAX; break; + case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); + } + + const int ix = offset0 + ox * s0; + const int iy = offset1 + oy * s1; + + for (int ky = 0; ky < k1; ++ky) { + if (iy + ky < 0 || iy + ky >= src->ne[1]) continue; + const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky)); + for (int kx = 0; kx < k0; ++kx) { + int j = ix + kx; + if (j < 0 || j >= src->ne[0]) continue; + const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]); + switch (op) { + case GGML_OP_POOL_AVG: *out += srow_j; break; + case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break; + case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); + } + } + } + switch (op) { + case GGML_OP_POOL_AVG: *out /= ka; break; + case GGML_OP_POOL_MAX: break; + case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); + } + } + } + + cdata += src->nb[2]; + dplane += pa; + } +} + +// ggml_compute_forward_pool_2d_back + +static void ggml_compute_forward_pool_2d_back( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src = dst->src[0]; + const struct ggml_tensor * dstf = dst->src[1]; // forward tensor of dst + + assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + + if (params->ith != 0) { + return; + } + + const int32_t * opts = (const int32_t *)dst->op_params; + enum ggml_op_pool op = opts[0]; + const int k0 = opts[1]; + const int k1 = opts[2]; + const int s0 = opts[3]; + const int s1 = opts[4]; + const int p0 = opts[5]; + const int p1 = opts[6]; + + char * cdata = (char *) dst->data; + const char * cdataf = (const char *) dstf->data; + const char * const data_end = cdata + ggml_nbytes(dst); + + GGML_ASSERT(params->ith == 0); + memset(cdata, 0, ggml_nbytes(dst)); + + const int64_t px = src->ne[0]; + const int64_t py = src->ne[1]; + const int64_t pa = px * py; + + const float * splane = (const float *) src->data; + + const int ka = k0 * k1; + const int offset0 = -p0; + const int offset1 = -p1; + + while (cdata < data_end) { + for (int oy = 0; oy < py; ++oy) { + const float * const srow = splane + oy * px; + for (int ox = 0; ox < px; ++ox) { + const float grad0 = srow[ox]; + + const int ix = offset0 + ox * s0; + const int iy = offset1 + oy * s1; + + if (op == GGML_OP_POOL_MAX) { + float maxval = -FLT_MAX; + int kxmax = -1; + int kymax = -1; + + for (int ky = 0; ky < k1; ++ky) { + if (iy + ky < 0 || iy + ky >= dst->ne[1]) { + continue; + } + const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky)); + for (int kx = 0; kx < k0; ++kx) { + int j = ix + kx; + if (j < 0 || j >= dst->ne[0]) { + continue; + } + + const float val = dst->type == GGML_TYPE_F32 ? + ((const float *) drowf)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]); + if (val <= maxval) { + continue; + } + + maxval = val; + kxmax = kx; + kymax = ky; + } + } + + if (kxmax == -1 || kymax == -1) { + continue; + } + + void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax)); + const int j = ix + kxmax; + if (dst->type == GGML_TYPE_F32) { + ((float *) drow)[j] += grad0; + } else { + ((ggml_fp16_t *) drow)[j] = GGML_FP32_TO_FP16(grad0 + GGML_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j])); + } + } else if (op == GGML_OP_POOL_AVG) { + const float grad = grad0 / ka; + + for (int ky = 0; ky < k1; ++ky) { + if (iy + ky < 0 || iy + ky >= dst->ne[1]) { + continue; + } + void * drow = (void *)(cdata + dst->nb[1] * (iy + ky)); + for (int kx = 0; kx < k0; ++kx) { + int j = ix + kx; + if (j < 0 || j >= dst->ne[0]) { + continue; + } + + if (dst->type == GGML_TYPE_F32) { + ((float *) drow)[j] += grad; + } else { + ((ggml_fp16_t *) drow)[j] += GGML_FP32_TO_FP16(grad); + } + } + } + } else { + GGML_ASSERT(false); + } + } + } + + cdata += dst->nb[2]; + cdataf += dst->nb[2]; + splane += pa; + } +} + +// ggml_compute_forward_upscale + +static void ggml_compute_forward_upscale_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + const float sf0 = (float)ne0/src0->ne[0]; + const float sf1 = (float)ne1/src0->ne[1]; + const float sf2 = (float)ne2/src0->ne[2]; + const float sf3 = (float)ne3/src0->ne[3]; + + // TODO: optimize + + for (int64_t i3 = 0; i3 < ne3; i3++) { + const int64_t i03 = i3 / sf3; + for (int64_t i2 = ith; i2 < ne2; i2 += nth) { + const int64_t i02 = i2 / sf2; + for (int64_t i1 = 0; i1 < ne1; i1++) { + const int64_t i01 = i1 / sf1; + for (int64_t i0 = 0; i0 < ne0; i0++) { + const int64_t i00 = i0 / sf0; + + const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + + *y = *x; + } + } + } + } +} + +static void ggml_compute_forward_upscale( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_upscale_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + + +// ggml_compute_forward_pad + +static void ggml_compute_forward_pad_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT( dst->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float * dst_ptr = (float *) dst->data; + + // TODO: optimize + + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = ith; i1 < ne1; i1 += nth) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + for (int64_t i3 = 0; i3 < ne3; ++i3) { + const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; + + const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + dst_ptr[dst_idx] = *src_ptr; + } else { + dst_ptr[dst_idx] = 0; + } + } + } + } + } +} + +static void ggml_compute_forward_pad( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_pad_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_pad_reflect_1d + +static void ggml_compute_forward_pad_reflect_1d( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int ith = params->ith; + const int nth = params->nth; + + const int32_t * opts = (const int32_t *) dst->op_params; + const int p0 = opts[0]; + const int p1 = opts[1]; + + GGML_TENSOR_UNARY_OP_LOCALS + + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = 0; i2 < ne2; i2++) { + for (int64_t i1 = ith; i1 < ne1; i1 += nth) { + float * left = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + p0*nb0); + float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0); + + ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); + + for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0]; } + for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; } + } + } + } +} + +// ggml_compute_forward_arange + +static void ggml_compute_forward_arange_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + GGML_ASSERT(dst->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const float start = ggml_get_op_params_f32(dst, 0); + const float stop = ggml_get_op_params_f32(dst, 1); + const float step = ggml_get_op_params_f32(dst, 2); + + const int64_t steps = (int64_t) ceilf((stop - start) / step); + + GGML_ASSERT(ggml_nelements(dst) == steps); + + for (int64_t i = ith; i < steps; i+= nth) { + float value = start + step * i; + ((float *)dst->data)[i] = value; + } +} + +static void ggml_compute_forward_arange( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + switch (dst->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_arange_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_compute_forward_timestep_embedding_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + const int dim = ggml_get_op_params_i32(dst, 0); + const int max_period = ggml_get_op_params_i32(dst, 1); + + int half = dim / 2; + + for (int64_t i = 0; i < ne00; i++) { + float * embed_data = (float *)((char *) dst->data + i*nb1); + for (int64_t j = ith; j < half; j += nth) { + float timestep = ((float *)src0->data)[i]; + float freq = (float)expf(-logf(max_period) * j / half); + float arg = timestep * freq; + embed_data[j] = cosf(arg); + embed_data[j + half] = sinf(arg); + } + if (dim % 2 != 0 && ith == 0) { + embed_data[dim] = 0.f; + } + } +} + +static void ggml_compute_forward_timestep_embedding( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_timestep_embedding_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_argsort + +static void ggml_compute_forward_argsort_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(nb0 == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + enum ggml_sort_order order = (enum ggml_sort_order) ggml_get_op_params_i32(dst, 0); + + for (int64_t i = ith; i < nr; i += nth) { + int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); + const float * src_data = (float *)((char *) src0->data + i*nb01); + + for (int64_t j = 0; j < ne0; j++) { + dst_data[j] = j; + } + + // C doesn't have a functional sort, so we do a bubble sort instead + for (int64_t j = 0; j < ne0; j++) { + for (int64_t k = j + 1; k < ne0; k++) { + if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) || + (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) { + int32_t tmp = dst_data[j]; + dst_data[j] = dst_data[k]; + dst_data[k] = tmp; + } + } + } + } +} + +static void ggml_compute_forward_argsort( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_argsort_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_flash_attn_ext + +static void ggml_compute_forward_flash_attn_ext_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne2 == N); + + // input tensor rows must be contiguous + GGML_ASSERT(nbq0 == ggml_type_size(q->type)); + GGML_ASSERT(nbk0 == ggml_type_size(k->type)); + GGML_ASSERT(nbv0 == ggml_type_size(v->type)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev0 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nev0 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + // broadcast factors + const int64_t rk2 = neq2/nek2; + const int64_t rk3 = neq3/nek3; + + const int64_t rv2 = neq2/nev2; + const int64_t rv3 = neq3/nev3; + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + enum ggml_type const k_vec_dot_type = type_traits_cpu[k->type].vec_dot_type; + ggml_from_float_t const q_to_vec_dot = type_traits_cpu[k_vec_dot_type].from_float; + ggml_vec_dot_t const kq_vec_dot = type_traits_cpu[k->type].vec_dot; + ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float; + + GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type"); + GGML_ASSERT(v_to_float && "fattn: unsupported V-type"); + + // loop over n_batch and n_head + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + const uint32_t h = iq2; // head index + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + + float S = 0.0f; // sum + float M = -INFINITY; // maximum KQ value + + float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator + float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer + ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator + ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16 + + if (v->type == GGML_TYPE_F16) { + memset(VKQ16, 0, D*sizeof(ggml_fp16_t)); + } else { + memset(VKQ32, 0, D*sizeof(float)); + } + + const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); + q_to_vec_dot(pq, Q_q, D); + + // online softmax / attention + // loop over n_kv and n_head_kv + // ref: https://arxiv.org/pdf/2112.05682.pdf + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f; + if (mv == -INFINITY) { + continue; + } + + float s; // KQ value + + const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); + kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1); + + s = s*scale; // scale KQ value + + if (logit_softcap != 0.0f) { + s = logit_softcap*tanhf(s); + } + + s += mv; // apply mask + + const float Mold = M; + + float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value + float vs = 1.0f; // post-softmax KQ value, expf(s - M) + + const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); + + if (v->type == GGML_TYPE_F16) { + if (s > M) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + ggml_vec_scale_f16(D, VKQ16, ms); + } else { + // no new maximum, ms == 1.0f, vs != 1.0f + vs = expf(s - M); + } + + // V += v*expf(s - M) + ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs); + } else { + if (s > M) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + ggml_vec_scale_f32(D, VKQ32, ms); + } else { + // no new maximum, ms == 1.0f, vs != 1.0f + vs = expf(s - M); + } + + v_to_float(v_data, V32, D); + + // V += v*expf(s - M) + ggml_vec_mad_f32(D, VKQ32, V32, vs); + } + + S = S*ms + vs; // scale and increment sum with partial sum + } + + if (v->type == GGML_TYPE_F16) { + for (int64_t d = 0; d < D; ++d) { + VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]); + } + } + + // V /= S + const float S_inv = 1.0f/S; + ggml_vec_scale_f32(D, VKQ32, S_inv); + + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // original + //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); + + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1); + } +} + +static void ggml_compute_forward_flash_attn_ext( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + switch (dst->op_params[3]) { + case GGML_PREC_DEFAULT: + case GGML_PREC_F32: + { + // uses F32 accumulators + ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_flash_attn_back + +static void ggml_compute_forward_flash_attn_back_f32( + const struct ggml_compute_params * params, + const bool masked, + struct ggml_tensor * dst) { + + const struct ggml_tensor * q = dst->src[0]; + const struct ggml_tensor * k = dst->src[1]; + const struct ggml_tensor * v = dst->src[2]; + const struct ggml_tensor * d = dst->src[3]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ned, d, ne) + GGML_TENSOR_LOCALS(size_t, nbd, d, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + const int64_t P = nek1 - N; + const int64_t M = P + N; + + const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); + const int mxDM = MAX(D, Mup); + + // GGML_ASSERT(ne0 == D); + // GGML_ASSERT(ne1 == N); + GGML_ASSERT(P >= 0); + + GGML_ASSERT(nbq0 == sizeof(float)); + GGML_ASSERT(nbk0 == sizeof(float)); + GGML_ASSERT(nbv0 == sizeof(float)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev1 == D); + GGML_ASSERT(ned0 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nek1 == N + P); + GGML_ASSERT(nev1 == D); + GGML_ASSERT(ned1 == N); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + if (ith == 0) { + memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3); + } + ggml_barrier(params->threadpool); + + const int64_t elem_q = ggml_nelements(q); + const int64_t elem_k = ggml_nelements(k); + + enum ggml_type result_type = dst->type; + GGML_ASSERT(ggml_blck_size(result_type) == 1); + const size_t tsize = ggml_type_size(result_type); + + const size_t offs_q = 0; + const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN); + const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN); + + void * grad_q = (char *) dst->data; + void * grad_k = (char *) dst->data + offs_k; + void * grad_v = (char *) dst->data + offs_v; + + const size_t nbgq1 = nb0*neq0; + const size_t nbgq2 = nb0*neq0*neq1; + const size_t nbgq3 = nb0*neq0*neq1*neq2; + + const size_t nbgk1 = nb0*nek0; + const size_t nbgk2 = nb0*nek0*nek1; + const size_t nbgk3 = nb0*nek0*nek1*neq2; + + const size_t nbgv1 = nb0*nev0; + const size_t nbgv2 = nb0*nev0*nev1; + const size_t nbgv3 = nb0*nev0*nev1*neq2; + + // parallelize by k rows using ggml_vec_dot_f32 + + // total rows in k + const int nr = nek2*nek3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float scale = 1.0f/sqrtf(D); + + //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + + // how often k2 (and v2) is repeated in q2 + int nrep = neq2/nek2; + + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int ik3 = ir/(nek2); + const int ik2 = ir - ik3*nek2; + + const int iq3 = ik3; + const int id3 = ik3; + const int iv3 = ik3; + const int iv2 = ik2; + + for (int irep = 0; irep < nrep; ++irep) { + const int iq2 = ik2 + irep*nek2; + const int id2 = iq2; + + // (ik2 + irep*nek2) % nek2 == ik2 + for (int iq1 = 0; iq1 < neq1; ++iq1) { + const int id1 = iq1; + + // not sure about CACHE_LINE_SIZE_F32.. + // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset? + float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32); + float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32); + + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } + + const int64_t masked_begin = masked ? (P + iq1 + 1) : M; + for (int64_t ic = 0; ic < masked_begin; ++ic) { + // k indices + const int ik1 = ic; + + // S indices + const int i1 = ik1; + + ggml_vec_dot_f32(neq0, + S + i1, 0, + (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0, + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1); + } + + // scale + ggml_vec_scale_f32(masked_begin, S, scale); + + for (int64_t i = masked_begin; i < M; i++) { + S[i] = -INFINITY; + } + + // softmax + // exclude known -INF S[..] values from max and loop + // dont forget to set their SM values to zero + { + float max = -INFINITY; + ggml_vec_max_f32(masked_begin, &max, S); + + ggml_float sum = 0.0; + { +#ifdef GGML_SOFT_MAX_ACCELERATE + max = -max; + vDSP_vsadd(SM, 1, &max, SM, 1, Mup); + vvexpf(SM, SM, &Mup); + ggml_vec_sum_f32(Mup, &sum, SM); +#else + sum = ggml_vec_soft_max_f32(Mup, SM, S, max); +#endif + } + + assert(sum > 0.0); + + sum = 1.0/sum; + ggml_vec_scale_f32(masked_begin, SM, sum); + + } + + // step-by-step explanation + { + // forward-process shape grads from backward process + // parallel_for ik2,ik3: + // for irep: + // iq2 = ik2 + irep*nek2 + // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur] + // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur] + // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur] + // for iq1: + // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur + // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur + // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4 + // S0 = -Inf [D,1,1,1] + // ~S1[i] = dot(kcur[:D,i], qcur) + // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale + // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P) + // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur + // ~S5[i] = dot(vcur[:,i], S4) + // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3] + // ~dst[i,iq1,iq2,iq3] = S5[i] ^ + // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3] + // dst backward-/ grad[dst] = d + // + // output gradients with their dependencies: + // + // grad[kcur] = grad[S1].T @ qcur + // grad[S1] = diag_mask_zero(grad[S3], P) * scale + // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // grad[S4] = grad[S5] @ vcur + // grad[S4] = d[:D,id1,id2,id3] @ vcur + // grad[qcur] = grad[S1] @ kcur + // grad[vcur] = grad[S5].T @ S4 + // grad[vcur] = d[:D,id1,id2,id3].T @ S4 + // + // in post-order: + // + // S1 = qcur @ kcur.T + // S2 = S1 * scale + // S3 = diag_mask_inf(S2, P) + // S4 = softmax(S3) + // grad[S4] = d[:D,id1,id2,id3] @ vcur + // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // grad[S1] = diag_mask_zero(grad[S3], P) * scale + // grad[qcur] = grad[S1] @ kcur + // grad[kcur] = grad[S1].T @ qcur + // grad[vcur] = d[:D,id1,id2,id3].T @ S4 + // + // using less variables (SM=S4): + // + // S = diag_mask_inf(qcur @ kcur.T * scale, P) + // SM = softmax(S) + // S = d[:D,iq1,iq2,iq3] @ vcur + // dot_SM_gradSM = dot(SM, S) + // S = SM * (S - dot(SM, S)) + // S = diag_mask_zero(S, P) * scale + // + // grad[q][:D,iq1,iq2,iq3] += S @ kcur + // grad[k][:D,:M,ik2,ik3] += S.T @ qcur + // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM + } + + // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3] + // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3] + // for ic: + // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3] + // exclude known future zero S[..] values from operation + ggml_vec_set_f32(masked_begin, S, 0); + for (int64_t ic = 0; ic < D; ++ic) { + ggml_vec_mad_f32(masked_begin, + S, + (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3))); + } + + // S = SM * (S - dot(SM, S)) + float dot_SM_gradSM = 0; + ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1); + ggml_vec_acc1_f32(M, S, -dot_SM_gradSM); + ggml_vec_mul_f32 (masked_begin, S, S, SM); + + // S = diag_mask_zero(S, P) * scale + // already done by above ggml_vec_set_f32 + + // exclude known zero S[..] values from operation + ggml_vec_scale_f32(masked_begin, S, scale); + + // S shape [M,1] + // SM shape [M,1] + // kcur shape [D,M] + // qcur shape [D,1] + // vcur shape [M,D] + + // grad[q][:D,iq1,iq2,iq3] += S @ kcur + // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M] + // for ic: + // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3] + // exclude known zero S[..] values from loop + for (int64_t ic = 0; ic < masked_begin; ++ic) { + ggml_vec_mad_f32(D, + (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)), + (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)), + S[ic]); + } + + // grad[k][:D,:M,iq2,iq3] += S.T @ qcur + // for ic: + // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0] + // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0] + // exclude known zero S[..] values from loop + for (int64_t ic = 0; ic < masked_begin; ++ic) { + ggml_vec_mad_f32(D, + (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)), + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), + S[ic]); + } + + // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM + // for ic: + // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M] + // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M] + // exclude known zero SM[..] values from mad + for (int64_t ic = 0; ic < D; ++ic) { + ggml_vec_mad_f32(masked_begin, + (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)), + SM, + *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3))); + } + } + } + } +} + +static void ggml_compute_forward_flash_attn_back( + const struct ggml_compute_params * params, + const bool masked, + struct ggml_tensor * dst) { + + const struct ggml_tensor * q = dst->src[0]; + + switch (q->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_flash_attn_back_f32(params, masked, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_ssm_conv + +static void ggml_compute_forward_ssm_conv_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; // conv_x + const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src1->ne[0]; // d_conv + const int ncs = src0->ne[0]; // d_conv - 1 + n_t + const int nr = src0->ne[1]; // d_inner + const int n_t = dst->ne[1]; // tokens per sequence + const int n_s = dst->ne[2]; // number of sequences in the batch + + GGML_ASSERT( dst->ne[0] == nr); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + const int ir = ir1 - ir0; + + for (int i3 = 0; i3 < n_s; ++i3) { + for (int i2 = 0; i2 < n_t; ++i2) { + // {d_conv - 1 + n_t, d_inner, n_seqs} + // sliding window + const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} + const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner} + float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s} + + // TODO: transpose the output for smaller strides for big batches? + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // rowwise dot product + // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision + float sumf = 0.0f; + + // d_conv + for (int i0 = 0; i0 < nc; ++i0) { + sumf += s[i0 + i1*ncs] * c[i0 + i1*nc]; + } + x[i1] = sumf; + } + } + } +} + +static void ggml_compute_forward_ssm_conv( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + switch (dst->src[0]->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_ssm_conv_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_ssm_scan + +static void ggml_compute_forward_ssm_scan_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; // s + const struct ggml_tensor * src1 = dst->src[1]; // x + const struct ggml_tensor * src2 = dst->src[2]; // dt + const struct ggml_tensor * src3 = dst->src[3]; // A + const struct ggml_tensor * src4 = dst->src[4]; // B + const struct ggml_tensor * src5 = dst->src[5]; // C + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // d_inner + const int64_t n_t = src1->ne[1]; // number of tokens per sequence + const int64_t n_s = src0->ne[2]; // number of sequences in the batch + + GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + GGML_ASSERT(src2->nb[0] == sizeof(float)); + GGML_ASSERT(src3->nb[0] == sizeof(float)); + GGML_ASSERT(src4->nb[0] == sizeof(float)); + GGML_ASSERT(src5->nb[0] == sizeof(float)); + // required for the dot product between s and C + GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); + // required for per-sequence offsets for states + GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); + // required to get correct offset for state destination (i.e. src1->nb[3]) + GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + const int ir = ir1 - ir0; + + for (int i3 = 0; i3 < n_s; ++i3) { + for (int i2 = 0; i2 < n_t; ++i2) { + const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} + const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} + const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} + float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} + + // use the output as the source for the next token-wise iterations + if (i2 > 0) { s0 = s; } + + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 + float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; + float x_dt = x[i1] * dt_soft_plus; + float sumf = 0.0f; + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + int i = i0 + i1*nc; + // state = prev_state * dA + dB * x + float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[i0]; + s[i] = state; + } + y[i1] = sumf; + } + } + } +} + +static void ggml_compute_forward_ssm_scan( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + switch (dst->src[0]->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_ssm_scan_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_win_part + +static void ggml_compute_forward_win_part_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + UNUSED(params); + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + + const int32_t nep0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t nep1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t w = ((const int32_t *)(dst->op_params))[2]; + + assert(ne00 == ne0); + assert(ne3 == nep0*nep1); + + // TODO: optimize / multi-thread + for (int py = 0; py < nep1; ++py) { + for (int px = 0; px < nep0; ++px) { + const int64_t i3 = py*nep0 + px; + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int64_t i02 = py*w + i2; + const int64_t i01 = px*w + i1; + const int64_t i00 = i0; + + const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0; + const int64_t j = i02*ne01*ne00 + i01*ne00 + i00; + + if (py*w + i2 >= ne02 || px*w + i1 >= ne01) { + ((float *) dst->data)[i] = 0.0f; + } else { + ((float *) dst->data)[i] = ((float *) src0->data)[j]; + } + } + } + } + } + } +} + +static void ggml_compute_forward_win_part( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_win_part_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_win_unpart + +static void ggml_compute_forward_win_unpart_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + UNUSED(params); + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + + const int32_t w = ((const int32_t *)(dst->op_params))[0]; + + // padding + const int px = (w - ne1%w)%w; + //const int py = (w - ne2%w)%w; + + const int npx = (px + ne1)/w; + //const int npy = (py + ne2)/w; + + assert(ne0 == ne00); + + // TODO: optimize / multi-thread + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int ip2 = i2/w; + const int ip1 = i1/w; + + const int64_t i02 = i2%w; + const int64_t i01 = i1%w; + const int64_t i00 = i0; + + const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00; + const int64_t j = i2*ne1*ne0 + i1*ne0 + i0; + + ((float *) dst->data)[j] = ((float *) src0->data)[i]; + } + } + } +} + +static void ggml_compute_forward_win_unpart( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_win_unpart_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +//gmml_compute_forward_unary + +static void ggml_compute_forward_unary( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const enum ggml_unary_op op = ggml_get_unary_op(dst); + + switch (op) { + case GGML_UNARY_OP_ABS: + { + ggml_compute_forward_abs(params, dst); + } break; + case GGML_UNARY_OP_SGN: + { + ggml_compute_forward_sgn(params, dst); + } break; + case GGML_UNARY_OP_NEG: + { + ggml_compute_forward_neg(params, dst); + } break; + case GGML_UNARY_OP_STEP: + { + ggml_compute_forward_step(params, dst); + } break; + case GGML_UNARY_OP_TANH: + { + ggml_compute_forward_tanh(params, dst); + } break; + case GGML_UNARY_OP_ELU: + { + ggml_compute_forward_elu(params, dst); + } break; + case GGML_UNARY_OP_RELU: + { + ggml_compute_forward_relu(params, dst); + } break; + case GGML_UNARY_OP_SIGMOID: + { + ggml_compute_forward_sigmoid(params, dst); + } break; + case GGML_UNARY_OP_GELU: + { + ggml_compute_forward_gelu(params, dst); + } break; + case GGML_UNARY_OP_GELU_QUICK: + { + ggml_compute_forward_gelu_quick(params, dst); + } break; + case GGML_UNARY_OP_SILU: + { + ggml_compute_forward_silu(params, dst); + } break; + case GGML_UNARY_OP_HARDSWISH: + { + ggml_compute_forward_hardswish(params, dst); + } break; + case GGML_UNARY_OP_HARDSIGMOID: + { + ggml_compute_forward_hardsigmoid(params, dst); + } break; + case GGML_UNARY_OP_EXP: + { + ggml_compute_forward_exp(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_get_rel_pos + +static void ggml_compute_forward_get_rel_pos_f16( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + UNUSED(params); + + const struct ggml_tensor * src0 = dst->src[0]; + + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322 + + GGML_TENSOR_UNARY_OP_LOCALS + + const int64_t w = ne1; + + ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data; + ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data; + + for (int64_t i2 = 0; i2 < ne2; ++i2) { + for (int64_t i1 = 0; i1 < ne1; ++i1) { + const int64_t pos = (w - i1 - 1) + i2; + for (int64_t i0 = 0; i0 < ne0; ++i0) { + dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0]; + } + } + } +} + +static void ggml_compute_forward_get_rel_pos( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + { + ggml_compute_forward_get_rel_pos_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_add_rel_pos + +static void ggml_compute_forward_add_rel_pos_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src2 = dst->src[2]; + + const bool inplace = (bool) ((int32_t *) dst->op_params)[0]; + if (!inplace) { + if (params->ith == 0) { + memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst)); + } + ggml_barrier(params->threadpool); + } + // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359 + + float * src1_data = (float *) src1->data; + float * src2_data = (float *) src2->data; + float * dst_data = (float *) dst->data; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + const int ith = params->ith; + const int nth = params->nth; + + // total patches in dst + const int np = ne13; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + const int ip0 = dp*ith; + const int ip1 = MIN(ip0 + dp, np); + + for (int64_t i13 = ip0; i13 < ip1; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10; + for (int64_t i10 = 0; i10 < ne10; ++i10) { + const int64_t jp0 = jp1 + i10; + const float src1_e = src1_data[jp0]; + const float src2_e = src2_data[jp0]; + + const int64_t jdh = jp0 * ne10; + const int64_t jdw = jdh - (ne10 - 1) * i10; + + for (int64_t j = 0; j < ne10; ++j) { + dst_data[jdh + j ] += src2_e; + dst_data[jdw + j*ne10] += src1_e; + } + } + } + } + } +} + +static void ggml_compute_forward_add_rel_pos( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_add_rel_pos_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_rwkv_wkv6 + +static void ggml_compute_forward_rwkv_wkv6_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + const int64_t T = dst->src[1]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t HEADS = dst->src[1]->ne[1]; + const int64_t n_seqs = dst->src[5]->ne[1]; + const int64_t head_size = C / HEADS; + + float * dst_data = (float *) dst->data; + float * state = ((float *) dst->data) + C * T; + + const int ith = params->ith; + const int nth = params->nth; + + if (ith >= HEADS) { + return; + } + + const int h_start = (HEADS * ith) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; + + float * k = (float *) dst->src[0]->data; + float * v = (float *) dst->src[1]->data; + float * r = (float *) dst->src[2]->data; + float * time_faaaa = (float *) dst->src[3]->data; + float * time_decay = (float *) dst->src[4]->data; + + size_t t_stride = HEADS * head_size; // Same to C + + size_t h_stride = C / HEADS; + GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS + size_t h_stride_2d = head_size * head_size; + + if (ith == 0) { + memset(dst_data, 0, T * C * sizeof(float)); + } + ggml_barrier(params->threadpool); + + + #if defined(__AVX__) && !defined(__AVX512F__) + #define GGML_F32X GGML_F32x8 + #define GGML_F32X_SET1 GGML_F32x8_SET1 + #define GGML_F32X_LOAD GGML_F32x8_LOAD + #define GGML_F32X_STORE GGML_F32x8_STORE + #define GGML_F32X_MUL GGML_F32x8_MUL + #define GGML_F32X_FMA GGML_F32x8_FMA + #define WKV_VECTOR_SIZE 8 + #elif defined(__AVX512F__) + #define GGML_F32X GGML_F32x16 + #define GGML_F32X_SET1 GGML_F32x16_SET1 + #define GGML_F32X_LOAD GGML_F32x16_LOAD + #define GGML_F32X_STORE GGML_F32x16_STORE + #define GGML_F32X_MUL GGML_F32x16_MUL + #define GGML_F32X_FMA GGML_F32x16_FMA + #define WKV_VECTOR_SIZE 16 + #elif defined(__ARM_NEON) && defined(__aarch64__) + #define GGML_F32X GGML_F32x4 + #define GGML_F32X_SET1 GGML_F32x4_SET1 + #define GGML_F32X_LOAD GGML_F32x4_LOAD + #define GGML_F32X_STORE GGML_F32x4_STORE + #define GGML_F32X_MUL GGML_F32x4_MUL + #define GGML_F32X_FMA GGML_F32x4_FMA + #define WKV_VECTOR_SIZE 4 + #endif + + #ifdef WKV_VECTOR_SIZE + const int64_t vec_count = head_size / WKV_VECTOR_SIZE; + + for (int64_t t = 0; t < T; t++) { + size_t t_offset = t * t_stride; + size_t state_offset = head_size * C * (t / (T / n_seqs)); + float * state_cur = state + state_offset; + float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; + + for (int64_t h = h_start; h < h_end; h++) { + size_t h_offset = h * h_stride; + size_t t_h_offset = t_offset + h_offset; + size_t h_2d_offset = h * h_stride_2d; + + for (int64_t i = 0; i < head_size; i++) { + size_t t_h_i_offset = t_h_offset + i; + size_t h_i_offset = h_offset + i; + size_t h_2d_i_offset = h_2d_offset + i * h_stride; + + float k_val = k[t_h_i_offset]; + float r_val = r[t_h_i_offset]; + float time_faaaa_val = time_faaaa[h_i_offset]; + float time_decay_val = time_decay[t_h_i_offset]; + + // Broadcast scalar values to vectors + GGML_F32X k_vec = GGML_F32X_SET1(k_val); + GGML_F32X r_vec = GGML_F32X_SET1(r_val); + GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val); + GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val); + + for (int64_t j = 0; j < vec_count; j++) { + size_t base_j = j * WKV_VECTOR_SIZE; + size_t t_h_j_offset = t_h_offset + base_j; + size_t h_2d_i_j_offset = h_2d_i_offset + base_j; + + // Load x elements at once + GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]); + GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]); + GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]); + + // Compute kv = v * k + GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec); + + // Compute temp = kv * time_faaaa + prev_state + GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec); + + // Update dst: dst += temp * r + dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec); + GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec); + + // Update state: state = prev_state * time_decay + kv + GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec); + GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec); + } + + // Handle remaining elements, this will not be used. + for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) { + size_t t_h_j_offset = t_h_offset + j; + size_t h_2d_i_j_offset = h_2d_i_offset + j; + float v_val = v[t_h_j_offset]; + float kv_val = v_val * k_val; + float prev_state_val = state_prev[h_2d_i_j_offset]; + float temp_val = kv_val * time_faaaa_val + prev_state_val; + dst_data[t_h_j_offset] += temp_val * r_val; + state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val; + } + } + } + } + + #else + // basically fused operations: + // dst = r @ (time_faaaa * (k @ v) + state), + // state = time_decay * state + (k @ v), + // recursive through each token + for (int64_t t = 0; t < T; t++) { + size_t t_offset = t * t_stride; + size_t state_offset = head_size * C * (t / (T / n_seqs)); + float * state_cur = state + state_offset; + float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; + + for (int64_t h = h_start; h < h_end; h++) { + size_t h_offset = h * h_stride; + size_t t_h_offset = t_offset + h_offset; + size_t h_2d_offset = h * h_stride_2d; + + for (int64_t i = 0; i < head_size; i++) { + size_t t_h_i_offset = t_h_offset + i; + size_t h_i_offset = h_offset + i; + size_t h_2d_i_offset = h_2d_offset + i * h_stride; + + float k_val = k[t_h_i_offset]; + float r_val = r[t_h_i_offset]; + float time_faaaa_val = time_faaaa[h_i_offset]; + // RWKV v6: different time_decay for each token. + float time_decay_val = time_decay[t_h_i_offset]; + + for (int64_t j = 0; j < head_size; j++) { + size_t t_h_j_offset = t_h_offset + j; + size_t h_2d_i_j_offset = h_2d_i_offset + j; + + float v_val = v[t_h_j_offset]; + float kv_val = v_val * k_val; + float prev_state_val = state_prev[h_2d_i_j_offset]; + float temp_val = kv_val * time_faaaa_val + prev_state_val; + dst_data[t_h_j_offset] += temp_val * r_val; + state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val; + } + } + } + } + #endif +} + + +static void ggml_compute_forward_rwkv_wkv6( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_rwkv_wkv6_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_gla + +static void ggml_compute_forward_gla_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + const int64_t T = dst->src[1]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t HEADS = dst->src[1]->ne[1]; + const int64_t n_seqs = dst->src[4]->ne[1]; + const int64_t head_size = C / HEADS; + const float scale = ggml_get_op_params_f32(dst, 0); + + float * dst_data = (float *) dst->data; + float * state = ((float *) dst->data) + C * T; + + const int ith = params->ith; + const int nth = params->nth; + + if (ith >= HEADS) { + return; + } + + const int h_start = (HEADS * ith) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; + + float * k = (float *) dst->src[0]->data; + float * v = (float *) dst->src[1]->data; + float * q = (float *) dst->src[2]->data; + float * g = (float *) dst->src[3]->data; + + size_t t_stride = HEADS * head_size; // Same to C + + size_t h_stride = C / HEADS; + GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS + size_t h_stride_2d = head_size * head_size; + + if (ith == 0) { + memset(dst_data, 0, T * C * sizeof(float)); + } + ggml_barrier(params->threadpool); + + + #if defined(__AVX__) && !defined(__AVX512F__) + #define GGML_F32X GGML_F32x8 + #define GGML_F32X_SET1 GGML_F32x8_SET1 + #define GGML_F32X_LOAD GGML_F32x8_LOAD + #define GGML_F32X_STORE GGML_F32x8_STORE + #define GGML_F32X_MUL GGML_F32x8_MUL + #define GGML_F32X_FMA GGML_F32x8_FMA + #define GLA_VECTOR_SIZE 8 + #elif defined(__AVX512F__) + #define GGML_F32X GGML_F32x16 + #define GGML_F32X_SET1 GGML_F32x16_SET1 + #define GGML_F32X_LOAD GGML_F32x16_LOAD + #define GGML_F32X_STORE GGML_F32x16_STORE + #define GGML_F32X_MUL GGML_F32x16_MUL + #define GGML_F32X_FMA GGML_F32x16_FMA + #define GLA_VECTOR_SIZE 16 + #elif defined(__ARM_NEON) && defined(__aarch64__) + #define GGML_F32X GGML_F32x4 + #define GGML_F32X_SET1 GGML_F32x4_SET1 + #define GGML_F32X_LOAD GGML_F32x4_LOAD + #define GGML_F32X_STORE GGML_F32x4_STORE + #define GGML_F32X_MUL GGML_F32x4_MUL + #define GGML_F32X_FMA GGML_F32x4_FMA + #define GLA_VECTOR_SIZE 4 + #endif + + #ifdef GLA_VECTOR_SIZE + const int64_t vec_count = head_size / GLA_VECTOR_SIZE; + + for (int64_t t = 0; t < T; t++) { + size_t t_offset = t * t_stride; + size_t state_offset = head_size * C * (t / (T / n_seqs)); + float * state_cur = state + state_offset; + float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset; + + for (int64_t h = h_start; h < h_end; h++) { + size_t h_offset = h * h_stride; + size_t t_h_offset = t_offset + h_offset; + size_t h_2d_offset = h * h_stride_2d; + + for (int64_t i = 0; i < head_size; i++) { + size_t t_h_i_offset = t_h_offset + i; + size_t h_2d_i_offset = h_2d_offset + i * h_stride; + + float k_val = k[t_h_i_offset]; + float q_val = q[t_h_i_offset] * scale; + float g_val = g[t_h_i_offset]; + + // Broadcast scalar values to vectors + GGML_F32X k_vec = GGML_F32X_SET1(k_val); + GGML_F32X q_vec = GGML_F32X_SET1(q_val); + GGML_F32X g_vec = GGML_F32X_SET1(g_val); + + for (int64_t j = 0; j < vec_count; j++) { + size_t base_j = j * GLA_VECTOR_SIZE; + size_t t_h_j_offset = t_h_offset + base_j; + size_t h_2d_i_j_offset = h_2d_i_offset + base_j; + + // Load x elements at once + GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]); + GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]); + GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]); + + // Compute kv = v * k + GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec); + + // Compute temp = prev_state * g + kv + GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec); + + // Update dst: dst += temp * q + dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec); + GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec); + + // Update state + GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec); + } + + // Handle remaining elements, this will not be used. + for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) { + size_t t_h_j_offset = t_h_offset + j; + size_t h_2d_i_j_offset = h_2d_i_offset + j; + float v_val = v[t_h_j_offset]; + float kv_val = v_val * k_val; + float prev_state_val = state_prev[h_2d_i_j_offset]; + float temp_val = kv_val + prev_state_val * g_val; + dst_data[t_h_j_offset] += temp_val * q_val; + state_cur[h_2d_i_j_offset] = temp_val; + } + } + } + } + + #else + for (int64_t t = 0; t < T; t++) { + size_t t_offset = t * t_stride; + size_t state_offset = head_size * C * (t / (T / n_seqs)); + float * state_cur = state + state_offset; + float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset; + + for (int64_t h = h_start; h < h_end; h++) { + size_t h_offset = h * h_stride; + size_t t_h_offset = t_offset + h_offset; + size_t h_2d_offset = h * h_stride_2d; + + for (int64_t i = 0; i < head_size; i++) { + size_t t_h_i_offset = t_h_offset + i; + size_t h_2d_i_offset = h_2d_offset + i * h_stride; + + float k_val = k[t_h_i_offset]; + float q_val = q[t_h_i_offset] * scale; + float g_val = g[t_h_i_offset]; + + for (int64_t j = 0; j < head_size; j++) { + size_t t_h_j_offset = t_h_offset + j; + size_t h_2d_i_j_offset = h_2d_i_offset + j; + + float v_val = v[t_h_j_offset]; + float kv_val = v_val * k_val; + float prev_state_val = state_prev[h_2d_i_j_offset]; + float temp_val = prev_state_val * g_val + kv_val; + dst_data[t_h_j_offset] += temp_val * q_val; + state_cur[h_2d_i_j_offset] = temp_val; + } + } + } + } + #endif +} + + +static void ggml_compute_forward_gla( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_gla_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_map_unary + +static void ggml_compute_forward_map_unary_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst, + const ggml_unary_op_f32_t fun) { + + const struct ggml_tensor * src0 = dst->src[0]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + fun(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void ggml_compute_forward_map_unary( + const struct ggml_compute_params * params, + struct ggml_tensor * dst, + const ggml_unary_op_f32_t fun) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_map_unary_f32(params, dst, fun); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_map_binary + +static void ggml_compute_forward_map_binary_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst, + const ggml_binary_op_f32_t fun) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + if (params->ith != 0) { + return; + } + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(src1)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + + const int n = ggml_nrows(src0); + const int nc = src0->ne[0]; + + for (int i = 0; i < n; i++) { + fun(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1])), + (float *) ((char *) src1->data + i*(src1->nb[1]))); + } +} + +static void ggml_compute_forward_map_binary( + const struct ggml_compute_params * params, + struct ggml_tensor * dst, + const ggml_binary_op_f32_t fun) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_map_binary_f32(params, dst, fun); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_map_custom1 + +static void ggml_compute_forward_map_custom1_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst, + const ggml_custom1_op_f32_t fun) { + + const struct ggml_tensor * a = dst->src[0]; + + if (params->ith != 0) { + return; + } + + fun(dst, a); +} + +// ggml_compute_forward_map_custom2 + +static void ggml_compute_forward_map_custom2_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst, + const ggml_custom2_op_f32_t fun) { + + const struct ggml_tensor * a = dst->src[0]; + const struct ggml_tensor * b = dst->src[1]; + + if (params->ith != 0) { + return; + } + + fun(dst, a, b); +} + +// ggml_compute_forward_map_custom3 + +static void ggml_compute_forward_map_custom3_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst, + const ggml_custom3_op_f32_t fun) { + + const struct ggml_tensor * a = dst->src[0]; + const struct ggml_tensor * b = dst->src[1]; + const struct ggml_tensor * c = dst->src[1]; + + if (params->ith != 0) { + return; + } + + fun(dst, a, b, c); +} + +// ggml_compute_forward_map_custom1 + +static void ggml_compute_forward_map_custom1( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * a = dst->src[0]; + + struct ggml_map_custom1_op_params p; + memcpy(&p, dst->op_params, sizeof(p)); + + p.fun(dst, a, params->ith, params->nth, p.userdata); +} + +// ggml_compute_forward_map_custom2 + +static void ggml_compute_forward_map_custom2( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * a = dst->src[0]; + const struct ggml_tensor * b = dst->src[1]; + + struct ggml_map_custom2_op_params p; + memcpy(&p, dst->op_params, sizeof(p)); + + p.fun(dst, a, b, params->ith, params->nth, p.userdata); +} + +// ggml_compute_forward_map_custom3 + +static void ggml_compute_forward_map_custom3( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * a = dst->src[0]; + const struct ggml_tensor * b = dst->src[1]; + const struct ggml_tensor * c = dst->src[2]; + + struct ggml_map_custom3_op_params p; + memcpy(&p, dst->op_params, sizeof(p)); + + p.fun(dst, a, b, c, params->ith, params->nth, p.userdata); +} + +// ggml_compute_forward_cross_entropy_loss + +static void ggml_compute_forward_cross_entropy_loss_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type)); + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + GGML_ASSERT(ggml_is_scalar(dst)); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + // TODO: handle transposed/permuted matrices + const int64_t nc = src0->ne[0]; + const int64_t nr = ggml_nrows(src0); + + const int ith = params->ith; + const int nth = params->nth; + + float * sums = (float *) params->wdata; + float * st = ((float *) params->wdata) + nth + ith*nc; + float sum_thread = 0.0f; + + GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc)); + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + for (int64_t i1 = ir0; i1 < ir1; ++i1) { + const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]); + const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]); + +#ifndef NDEBUG + for (int64_t i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(s0[i])); + assert(!isnan(s1[i])); + } +#endif + + float max = -INFINITY; + ggml_vec_max_f32(nc, &max, s0); + const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max); + assert(sum_softmax >= 0.0); + + ggml_vec_add1_f32(nc, st, st, -sum_softmax); + ggml_vec_mul_f32(nc, st, st, s1); + + float sum_st = 0.0f; + ggml_vec_sum_f32(nc, &sum_st, st); + sum_thread += sum_st; + +#ifndef NDEBUG + for (int64_t i = 0; i < nc; ++i) { + assert(!isnan(st[i])); + assert(!isinf(st[i])); + } +#endif + } + sums[ith] = sum_thread; + ggml_barrier(params->threadpool); + + if (ith == 0) { + float * dp = (float *) dst->data; + ggml_vec_sum_f32(nth, dp, sums); + dp[0] *= -1.0f / (float) nr; + } +} + +static void ggml_compute_forward_cross_entropy_loss( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_cross_entropy_loss_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_cross_entropy_loss_back + +static void ggml_compute_forward_cross_entropy_loss_back_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * grad = dst->src[0]; // gradient of forward pass output + const struct ggml_tensor * src0f = dst->src[1]; // src0 of forward pass + const struct ggml_tensor * src1f = dst->src[2]; // src1 of forward pass + + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous(src0f)); + GGML_ASSERT(ggml_is_contiguous(src1f)); + GGML_ASSERT(ggml_is_contiguous(grad)); + GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst)); + + const int64_t ith = params->ith; + const int64_t nth = params->nth; + + // TODO: handle transposed/permuted matrices + const int64_t nc = src0f->ne[0]; + const int64_t nr = ggml_nrows(src0f); + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + const float d_by_nr = ((const float *) grad->data)[0] / (float) nr; + + for (int64_t i1 = ir0; i1 < ir1; i1++) { + float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); + const float * s0 = (const float *)((const char *) src0f->data + i1*src0f->nb[1]); + const float * s1 = (const float *)((const char *) src1f->data + i1*src1f->nb[1]); + +#ifndef NDEBUG + for (int64_t i = 0; i < nc; ++i) { + //printf("p[%d] = %f\n", i, p[i]); + assert(!isnan(s0[i])); + assert(!isnan(s1[i])); + } +#endif + + // soft_max + float max = -INFINITY; + ggml_vec_max_f32(nc, &max, s0); + const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max); + assert(sum > 0.0); + ggml_vec_scale_f32(nc, ds0, 1.0/sum); + + // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr + ggml_vec_sub_f32(nc, ds0, ds0, s1); + ggml_vec_scale_f32(nc, ds0, d_by_nr); + +#ifndef NDEBUG + for (int64_t i = 0; i < nc; ++i) { + assert(!isnan(ds0[i])); + assert(!isinf(ds0[i])); + } +#endif + } +} + +static void ggml_compute_forward_cross_entropy_loss_back( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_cross_entropy_loss_back_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_compute_forward_opt_step_adamw_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src0_grad = dst->src[1]; + const struct ggml_tensor * src0_grad_m = dst->src[2]; + const struct ggml_tensor * src0_grad_v = dst->src[3]; + const struct ggml_tensor * adamw_params = dst->src[4]; + + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad)); + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m)); + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v)); + GGML_ASSERT(ggml_nelements(adamw_params) == 7); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS + GGML_ASSERT(nb00 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + const float * adamw_params_ptr = ggml_get_data_f32(adamw_params); + const float alpha = adamw_params_ptr[0]; + const float beta1 = adamw_params_ptr[1]; + const float beta2 = adamw_params_ptr[2]; + const float eps = adamw_params_ptr[3]; + const float wd = adamw_params_ptr[4]; + const float beta1h = adamw_params_ptr[5]; + const float beta2h = adamw_params_ptr[6]; + + for (int ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const size_t offset = i03*nb03 + i02*nb02 + i01*nb01; + + float * w = (float *) ((char *) src0->data + offset); // weight + const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad + float * m = (float *) ((char *) src0_grad_m->data + offset); + float * v = (float *) ((char *) src0_grad_v->data + offset); + + for (int i00 = 0; i00 < ne00; ++i00) { + m[i00] = m[i00]*beta1 + g[i00]*(1.0f - beta1); + v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2); + + const float mh = m[i00]*beta1h; + const float vh = sqrtf(v[i00]*beta2h) + eps; + + // The weight decay is applied independently of the Adam momenta m and v. + // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss. + // See: https://arxiv.org/pdf/1711.05101v3.pdf + w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh; + } + } +} + +static void ggml_compute_forward_opt_step_adamw( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_opt_step_adamw_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} +///////////////////////////////// + +static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { + GGML_ASSERT(params); + + if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) { + return; + } + + // extra_buffer op? + if (ggml_cpu_extra_compute_forward(params, tensor)) return; + + switch (tensor->op) { + case GGML_OP_DUP: + { + ggml_compute_forward_dup(params, tensor); + } break; + case GGML_OP_ADD: + { + ggml_compute_forward_add(params, tensor); + } break; + case GGML_OP_ADD1: + { + ggml_compute_forward_add1(params, tensor); + } break; + case GGML_OP_ACC: + { + ggml_compute_forward_acc(params, tensor); + } break; + case GGML_OP_SUB: + { + ggml_compute_forward_sub(params, tensor); + } break; + case GGML_OP_MUL: + { + ggml_compute_forward_mul(params, tensor); + } break; + case GGML_OP_DIV: + { + ggml_compute_forward_div(params, tensor); + } break; + case GGML_OP_SQR: + { + ggml_compute_forward_sqr(params, tensor); + } break; + case GGML_OP_SQRT: + { + ggml_compute_forward_sqrt(params, tensor); + } break; + case GGML_OP_LOG: + { + ggml_compute_forward_log(params, tensor); + } break; + case GGML_OP_SIN: + { + ggml_compute_forward_sin(params, tensor); + } break; + case GGML_OP_COS: + { + ggml_compute_forward_cos(params, tensor); + } break; + case GGML_OP_SUM: + { + ggml_compute_forward_sum(params, tensor); + } break; + case GGML_OP_SUM_ROWS: + { + ggml_compute_forward_sum_rows(params, tensor); + } break; + case GGML_OP_MEAN: + { + ggml_compute_forward_mean(params, tensor); + } break; + case GGML_OP_ARGMAX: + { + ggml_compute_forward_argmax(params, tensor); + } break; + case GGML_OP_COUNT_EQUAL: + { + ggml_compute_forward_count_equal(params, tensor); + } break; + case GGML_OP_REPEAT: + { + ggml_compute_forward_repeat(params, tensor); + } break; + case GGML_OP_REPEAT_BACK: + { + ggml_compute_forward_repeat_back(params, tensor); + } break; + case GGML_OP_CONCAT: + { + ggml_compute_forward_concat(params, tensor); + } break; + case GGML_OP_SILU_BACK: + { + ggml_compute_forward_silu_back(params, tensor); + } break; + case GGML_OP_NORM: + { + ggml_compute_forward_norm(params, tensor); + } break; + case GGML_OP_RMS_NORM: + { + ggml_compute_forward_rms_norm(params, tensor); + } break; + case GGML_OP_RMS_NORM_BACK: + { + ggml_compute_forward_rms_norm_back(params, tensor); + } break; + case GGML_OP_GROUP_NORM: + { + ggml_compute_forward_group_norm(params, tensor); + } break; + case GGML_OP_MUL_MAT: + { + ggml_compute_forward_mul_mat(params, tensor); + } break; + case GGML_OP_MUL_MAT_ID: + { + ggml_compute_forward_mul_mat_id(params, tensor); + } break; + case GGML_OP_OUT_PROD: + { + ggml_compute_forward_out_prod(params, tensor); + } break; + case GGML_OP_SCALE: + { + ggml_compute_forward_scale(params, tensor); + } break; + case GGML_OP_SET: + { + ggml_compute_forward_set(params, tensor); + } break; + case GGML_OP_CPY: + { + ggml_compute_forward_cpy(params, tensor); + } break; + case GGML_OP_CONT: + { + ggml_compute_forward_cont(params, tensor); + } break; + case GGML_OP_RESHAPE: + { + ggml_compute_forward_reshape(params, tensor); + } break; + case GGML_OP_VIEW: + { + ggml_compute_forward_view(params, tensor); + } break; + case GGML_OP_PERMUTE: + { + ggml_compute_forward_permute(params, tensor); + } break; + case GGML_OP_TRANSPOSE: + { + ggml_compute_forward_transpose(params, tensor); + } break; + case GGML_OP_GET_ROWS: + { + ggml_compute_forward_get_rows(params, tensor); + } break; + case GGML_OP_GET_ROWS_BACK: + { + ggml_compute_forward_get_rows_back(params, tensor); + } break; + case GGML_OP_DIAG: + { + ggml_compute_forward_diag(params, tensor); + } break; + case GGML_OP_DIAG_MASK_INF: + { + ggml_compute_forward_diag_mask_inf(params, tensor); + } break; + case GGML_OP_DIAG_MASK_ZERO: + { + ggml_compute_forward_diag_mask_zero(params, tensor); + } break; + case GGML_OP_SOFT_MAX: + { + ggml_compute_forward_soft_max(params, tensor); + } break; + case GGML_OP_SOFT_MAX_BACK: + { + ggml_compute_forward_soft_max_ext_back(params, tensor); + } break; + case GGML_OP_ROPE: + { + ggml_compute_forward_rope(params, tensor); + } break; + case GGML_OP_ROPE_BACK: + { + ggml_compute_forward_rope_back(params, tensor); + } break; + case GGML_OP_CLAMP: + { + ggml_compute_forward_clamp(params, tensor); + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + ggml_compute_forward_conv_transpose_1d(params, tensor); + } break; + case GGML_OP_IM2COL: + { + ggml_compute_forward_im2col(params, tensor); + } break; + case GGML_OP_IM2COL_BACK: + { + ggml_compute_forward_im2col_back_f32(params, tensor); + } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + ggml_compute_forward_conv_transpose_2d(params, tensor); + } break; + case GGML_OP_POOL_1D: + { + ggml_compute_forward_pool_1d(params, tensor); + } break; + case GGML_OP_POOL_2D: + { + ggml_compute_forward_pool_2d(params, tensor); + } break; + case GGML_OP_POOL_2D_BACK: + { + ggml_compute_forward_pool_2d_back(params, tensor); + } break; + case GGML_OP_UPSCALE: + { + ggml_compute_forward_upscale(params, tensor); + } break; + case GGML_OP_PAD: + { + ggml_compute_forward_pad(params, tensor); + } break; + case GGML_OP_PAD_REFLECT_1D: + { + ggml_compute_forward_pad_reflect_1d(params, tensor); + } break; + case GGML_OP_ARANGE: + { + ggml_compute_forward_arange(params, tensor); + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + ggml_compute_forward_timestep_embedding(params, tensor); + } break; + case GGML_OP_ARGSORT: + { + ggml_compute_forward_argsort(params, tensor); + } break; + case GGML_OP_LEAKY_RELU: + { + ggml_compute_forward_leaky_relu(params, tensor); + } break; + case GGML_OP_FLASH_ATTN_EXT: + { + ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } break; + case GGML_OP_FLASH_ATTN_BACK: + { + int32_t t = ggml_get_op_params_i32(tensor, 0); + GGML_ASSERT(t == 0 || t == 1); + bool masked = t != 0; + ggml_compute_forward_flash_attn_back(params, masked, tensor); + } break; + case GGML_OP_SSM_CONV: + { + ggml_compute_forward_ssm_conv(params, tensor); + } break; + case GGML_OP_SSM_SCAN: + { + ggml_compute_forward_ssm_scan(params, tensor); + } break; + case GGML_OP_WIN_PART: + { + ggml_compute_forward_win_part(params, tensor); + } break; + case GGML_OP_WIN_UNPART: + { + ggml_compute_forward_win_unpart(params, tensor); + } break; + case GGML_OP_UNARY: + { + ggml_compute_forward_unary(params, tensor); + } break; + case GGML_OP_GET_REL_POS: + { + ggml_compute_forward_get_rel_pos(params, tensor); + } break; + case GGML_OP_ADD_REL_POS: + { + ggml_compute_forward_add_rel_pos(params, tensor); + } break; + case GGML_OP_RWKV_WKV6: + { + ggml_compute_forward_rwkv_wkv6(params, tensor); + } break; + case GGML_OP_GATED_LINEAR_ATTN: + { + ggml_compute_forward_gla(params, tensor); + } break; + case GGML_OP_MAP_UNARY: + { + ggml_unary_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + ggml_compute_forward_map_unary(params, tensor, fun); + } + break; + case GGML_OP_MAP_BINARY: + { + ggml_binary_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + ggml_compute_forward_map_binary(params, tensor, fun); + } + break; + case GGML_OP_MAP_CUSTOM1_F32: + { + ggml_custom1_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + ggml_compute_forward_map_custom1_f32(params, tensor, fun); + } + break; + case GGML_OP_MAP_CUSTOM2_F32: + { + ggml_custom2_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + ggml_compute_forward_map_custom2_f32(params, tensor, fun); + } + break; + case GGML_OP_MAP_CUSTOM3_F32: + { + ggml_custom3_op_f32_t fun; + memcpy(&fun, tensor->op_params, sizeof(fun)); + ggml_compute_forward_map_custom3_f32(params, tensor, fun); + } + break; + case GGML_OP_MAP_CUSTOM1: + { + ggml_compute_forward_map_custom1(params, tensor); + } + break; + case GGML_OP_MAP_CUSTOM2: + { + ggml_compute_forward_map_custom2(params, tensor); + } + break; + case GGML_OP_MAP_CUSTOM3: + { + ggml_compute_forward_map_custom3(params, tensor); + } + break; + case GGML_OP_CROSS_ENTROPY_LOSS: + { + ggml_compute_forward_cross_entropy_loss(params, tensor); + } + break; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + ggml_compute_forward_cross_entropy_loss_back(params, tensor); + } + break; + case GGML_OP_OPT_STEP_ADAMW: + { + ggml_compute_forward_opt_step_adamw(params, tensor); + } + break; + case GGML_OP_NONE: + { + // nop + } break; + case GGML_OP_COUNT: + { + GGML_ABORT("fatal error"); + } + } +} + +// Android's libc implementation "bionic" does not support setting affinity +#if defined(__gnu_linux__) +static void set_numa_thread_affinity(int thread_n) { + if (!ggml_is_numa()) { + return; + } + + int node_num; + int rv; + size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus); + + switch(g_state.numa.numa_strategy) { + case GGML_NUMA_STRATEGY_DISTRIBUTE: + // run thread on node_num thread_n / (threads per node) + node_num = thread_n % g_state.numa.n_nodes; + break; + case GGML_NUMA_STRATEGY_ISOLATE: + // run thread on current_node + node_num = g_state.numa.current_node; + break; + case GGML_NUMA_STRATEGY_NUMACTL: + // use the cpuset that numactl gave us + rv = pthread_setaffinity_np(pthread_self(), setsize, &g_state.numa.cpuset); + if (rv) { + fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n",strerror(rv)); + } + return; + default: + return; + } + + struct ggml_numa_node * node = &g_state.numa.nodes[node_num]; + + cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus); + CPU_ZERO_S(setsize, cpus); + for (size_t i = 0; i < node->n_cpus; ++i) { + CPU_SET_S(node->cpus[i], setsize, cpus); + } + + rv = pthread_setaffinity_np(pthread_self(), setsize, cpus); + if (rv) { + fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv)); + } + + CPU_FREE(cpus); +} + +static void clear_numa_thread_affinity(void) { + if (!ggml_is_numa()) { + return; + } + + size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus); + + cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus); + CPU_ZERO_S(setsize, cpus); + for (unsigned i = 0; i < g_state.numa.total_cpus; ++i) { + CPU_SET_S(i, setsize, cpus); + } + + int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus); + if (rv) { + fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv)); + } + + CPU_FREE(cpus); +} +#else +// TODO: Windows etc. +// (the linux implementation may also work on BSD, someone should test) +static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n); } +static void clear_numa_thread_affinity(void) {} +#endif + +static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { + int n_tasks = 0; + + if (ggml_is_empty(node)) { + // no need to multi-thread a no-op + n_tasks = 1; + return n_tasks; + } + + switch (node->op) { + case GGML_OP_CPY: + case GGML_OP_DUP: + case GGML_OP_CONT: + case GGML_OP_ADD: + case GGML_OP_ADD1: + case GGML_OP_ACC: + { + n_tasks = n_threads; + } break; + case GGML_OP_SUB: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_LOG: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + case GGML_OP_ARGMAX: + { + n_tasks = 1; + } break; + case GGML_OP_COUNT_EQUAL: + { + n_tasks = n_threads; + } break; + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_LEAKY_RELU: + { + n_tasks = 1; + } break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SGN: + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_HARDSWISH: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_EXP: + { + n_tasks = 1; + } break; + + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_SILU: + { + n_tasks = n_threads; + } break; + default: + GGML_ABORT("fatal error"); + } + break; + case GGML_OP_SILU_BACK: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_GROUP_NORM: + case GGML_OP_CONCAT: + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + case GGML_OP_OUT_PROD: + { + n_tasks = n_threads; + } break; + case GGML_OP_GET_ROWS: + { + // FIXME: get_rows can use additional threads, but the cost of launching additional threads + // decreases performance with GPU offloading + //n_tasks = n_threads; + n_tasks = 1; + } break; + case GGML_OP_SCALE: + case GGML_OP_SET: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_GET_ROWS_BACK: + case GGML_OP_DIAG: + { + n_tasks = 1; + } break; + case GGML_OP_DIAG_MASK_ZERO: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX_BACK: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_ADD_REL_POS: + { + n_tasks = n_threads; + } break; + case GGML_OP_CLAMP: + { + n_tasks = 1; //TODO + } break; + case GGML_OP_SOFT_MAX: + { + n_tasks = MIN(n_threads, ggml_nrows(node->src[0])); + } break; + case GGML_OP_IM2COL: + case GGML_OP_IM2COL_BACK: + case GGML_OP_CONV_TRANSPOSE_1D: + case GGML_OP_CONV_TRANSPOSE_2D: + { + n_tasks = n_threads; + } break; + case GGML_OP_POOL_1D: + case GGML_OP_POOL_2D: + case GGML_OP_POOL_2D_BACK: + { + n_tasks = 1; + } break; + case GGML_OP_UPSCALE: + case GGML_OP_PAD: + case GGML_OP_PAD_REFLECT_1D: + case GGML_OP_ARANGE: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_ARGSORT: + case GGML_OP_FLASH_ATTN_EXT: + case GGML_OP_FLASH_ATTN_BACK: + case GGML_OP_SSM_CONV: + case GGML_OP_SSM_SCAN: + { + n_tasks = n_threads; + } break; + case GGML_OP_WIN_PART: + case GGML_OP_WIN_UNPART: + case GGML_OP_GET_REL_POS: + case GGML_OP_RWKV_WKV6: + case GGML_OP_GATED_LINEAR_ATTN: + case GGML_OP_MAP_UNARY: + case GGML_OP_MAP_BINARY: + case GGML_OP_MAP_CUSTOM1_F32: + case GGML_OP_MAP_CUSTOM2_F32: + case GGML_OP_MAP_CUSTOM3_F32: + { + n_tasks = 1; + } break; + case GGML_OP_MAP_CUSTOM1: + { + struct ggml_map_custom1_op_params p; + memcpy(&p, node->op_params, sizeof(p)); + if (p.n_tasks == GGML_N_TASKS_MAX) { + n_tasks = n_threads; + } else { + n_tasks = MIN(p.n_tasks, n_threads); + } + } break; + case GGML_OP_MAP_CUSTOM2: + { + struct ggml_map_custom2_op_params p; + memcpy(&p, node->op_params, sizeof(p)); + if (p.n_tasks == GGML_N_TASKS_MAX) { + n_tasks = n_threads; + } else { + n_tasks = MIN(p.n_tasks, n_threads); + } + } break; + case GGML_OP_MAP_CUSTOM3: + { + struct ggml_map_custom3_op_params p; + memcpy(&p, node->op_params, sizeof(p)); + if (p.n_tasks == GGML_N_TASKS_MAX) { + n_tasks = n_threads; + } else { + n_tasks = MIN(p.n_tasks, n_threads); + } + } break; + case GGML_OP_CROSS_ENTROPY_LOSS: + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + case GGML_OP_OPT_STEP_ADAMW: + { + n_tasks = n_threads; + } break; + case GGML_OP_NONE: + { + n_tasks = 1; + } break; + case GGML_OP_COUNT: + { + GGML_ABORT("fatal error"); + } + default: + { + fprintf(stderr, "%s: op not implemented: ", __func__); + if (node->op < GGML_OP_COUNT) { + fprintf(stderr, "%s\n", ggml_op_name(node->op)); + } else { + fprintf(stderr, "%d\n", node->op); + } + GGML_ABORT("fatal error"); + } + } + + assert(n_tasks > 0); + + return n_tasks; +} + +static thread_ret_t ggml_graph_compute_secondary_thread(void* data); + +#if defined(_WIN32) +#include "windows.h" + +// TODO: support > 64 CPUs +static bool ggml_thread_apply_affinity(bool * mask) { + HANDLE h = GetCurrentThread(); + uint64_t bitmask = 0ULL; + + assert(GGML_MAX_N_THREADS >= 64); + + for (int32_t i = 0; i < 8; i++) { + int32_t idx = i * 8; + uint8_t val = 0; + val |= mask[idx + 0] << 0; + val |= mask[idx + 1] << 1; + val |= mask[idx + 2] << 2; + val |= mask[idx + 3] << 3; + val |= mask[idx + 4] << 4; + val |= mask[idx + 5] << 5; + val |= mask[idx + 6] << 6; + val |= mask[idx + 7] << 7; + bitmask |= (uint64_t)val << idx; + } + + for (int32_t i = 64; i < GGML_MAX_N_THREADS; i++) { + if (mask[i]) { + fprintf(stderr, "warn: setting thread-affinity for > 64 CPUs isn't supported on windows!\n"); + break; + } + } + + DWORD_PTR m = (DWORD_PTR)bitmask; + + m = SetThreadAffinityMask(h, m); + + return m != 0; +} + +static bool ggml_thread_apply_priority(int32_t prio) { + // Note that on Windows the Process Priority Class must be updated in order to set Thread priority. + // This is up to the applications. + DWORD p = THREAD_PRIORITY_NORMAL; + switch (prio) { + case GGML_SCHED_PRIO_NORMAL: p = THREAD_PRIORITY_NORMAL; break; + case GGML_SCHED_PRIO_MEDIUM: p = THREAD_PRIORITY_ABOVE_NORMAL; break; + case GGML_SCHED_PRIO_HIGH: p = THREAD_PRIORITY_HIGHEST; break; + case GGML_SCHED_PRIO_REALTIME: p = THREAD_PRIORITY_TIME_CRITICAL; break; + } + + if (prio == GGML_SCHED_PRIO_NORMAL) { + // Keep inherited policy/priority + return true; + } + + if (!SetThreadPriority(GetCurrentThread(), p)) { + fprintf(stderr, "warn: failed to set thread priority %d : (%d)\n", prio, (int) GetLastError()); + return false; + } + + return true; +} + +#elif defined(__APPLE__) +#include +#include + +static bool ggml_thread_apply_affinity(const bool * mask) { + // Not supported on Apple platforms + UNUSED(mask); + return true; +} + +static bool ggml_thread_apply_priority(int32_t prio) { + struct sched_param p; + int32_t policy = SCHED_OTHER; + switch (prio) { + case GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break; + case GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break; + case GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break; + case GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO; p.sched_priority = 90; break; + } + + if (prio == GGML_SCHED_PRIO_NORMAL) { + // Keep inherited policy/priority + return true; + } + + int32_t err = pthread_setschedparam(pthread_self(), policy, &p); + if (err != 0) { + fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err); + return false; + } + + return true; +} + +#elif defined(__gnu_linux__) +// TODO: this may not work on BSD, to be verified + +static bool ggml_thread_apply_affinity(const bool * mask) { + cpu_set_t cpuset; + int err; + + CPU_ZERO(&cpuset); + + for (uint32_t i = 0; i < GGML_MAX_N_THREADS; i++) { + if (mask[i]) { + GGML_PRINT_DEBUG("Thread %lx: adding %d to cpuset\n", pthread_self(), i); + CPU_SET(i, &cpuset); + } + } + +#ifdef __ANDROID__ + err = sched_setaffinity(0, sizeof(cpuset), &cpuset); + if (err < 0) { + err = errno; + } +#else + err = pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset); +#endif + if (err != 0) { + fprintf(stderr, "warn: failed to set affinity mask 0x%llx : %s (%d)\n", (unsigned long long)mask, strerror(err), err); + return false; + } + + return true; +} + +static bool ggml_thread_apply_priority(int32_t prio) { + struct sched_param p; + int32_t policy = SCHED_OTHER; + switch (prio) { + case GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break; + case GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break; + case GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break; + case GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO; p.sched_priority = 90; break; + } + + if (prio == GGML_SCHED_PRIO_NORMAL) { + // Keep inherited policy/priority + return true; + } + + int32_t err = pthread_setschedparam(pthread_self(), policy, &p); + if (err != 0) { + fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err); + return false; + } + + return true; +} + +#else // unsupported platforms + +static bool ggml_thread_apply_affinity(const bool * mask) { + UNUSED(mask); + return true; +} + +static bool ggml_thread_apply_priority(int32_t prio) { + UNUSED(prio); + return true; +} + +#endif + +static bool ggml_thread_cpumask_is_valid(const bool * mask) { + for (int i = 0; i < GGML_MAX_N_THREADS; i++) { + if (mask[i]) { return true; } + } + return false; +} + +static void ggml_thread_cpumask_next(const bool * global_mask, bool * local_mask, bool strict, int32_t* iter) { + if (!strict) { + memcpy(local_mask, global_mask, GGML_MAX_N_THREADS); + return; + } else { + memset(local_mask, 0, GGML_MAX_N_THREADS); + int32_t base_idx = *iter; + for (int32_t i = 0; i < GGML_MAX_N_THREADS; i++) { + int32_t idx = base_idx + i; + if (idx >= GGML_MAX_N_THREADS) { + // Just a cheaper modulo + idx -= GGML_MAX_N_THREADS; + } + if (global_mask[idx]) { + local_mask[idx] = 1; + *iter = idx + 1; + return; + } + } + } +} + +void ggml_threadpool_free(struct ggml_threadpool* threadpool) { + if (!threadpool) return; + + const int n_threads = threadpool->n_threads_max; + +#ifndef GGML_USE_OPENMP + struct ggml_compute_state* workers = threadpool->workers; + + ggml_mutex_lock(&threadpool->mutex); + + threadpool->stop = true; + threadpool->pause = false; + + ggml_cond_broadcast(&threadpool->cond); + ggml_mutex_unlock(&threadpool->mutex); + + for (int j = 1; j < n_threads; j++) { + int32_t rc = ggml_thread_join(workers[j].thrd, NULL); + GGML_ASSERT(rc == GGML_EXIT_SUCCESS || rc == GGML_EXIT_ABORTED); + UNUSED(rc); + } + + ggml_mutex_destroy(&threadpool->mutex); + ggml_cond_destroy(&threadpool->cond); +#endif // GGML_USE_OPENMP + + const size_t workers_size = sizeof(struct ggml_compute_state) * n_threads; + ggml_aligned_free(threadpool->workers, workers_size); + ggml_aligned_free(threadpool, sizeof(struct ggml_threadpool)); +} + +#ifndef GGML_USE_OPENMP +// pause/resume must be called under mutex +static void ggml_threadpool_pause_locked(struct ggml_threadpool * threadpool) { + GGML_PRINT_DEBUG("Pausing threadpool\n"); + threadpool->pause = true; + ggml_cond_broadcast(&threadpool->cond); +} + +static void ggml_threadpool_resume_locked(struct ggml_threadpool * threadpool) { + GGML_PRINT_DEBUG("Resuming threadpool\n"); + threadpool->pause = false; + ggml_cond_broadcast(&threadpool->cond); +} +#endif + +void ggml_threadpool_pause(struct ggml_threadpool * threadpool) { +#ifndef GGML_USE_OPENMP + ggml_mutex_lock(&threadpool->mutex); + if (!threadpool->pause) { + ggml_threadpool_pause_locked(threadpool); + } + ggml_mutex_unlock(&threadpool->mutex); +#else + UNUSED(threadpool); +#endif +} + +void ggml_threadpool_resume(struct ggml_threadpool * threadpool) { +#ifndef GGML_USE_OPENMP + ggml_mutex_lock(&threadpool->mutex); + if (threadpool->pause) { + ggml_threadpool_resume_locked(threadpool); + } + ggml_mutex_unlock(&threadpool->mutex); +#else + UNUSED(threadpool); +#endif +} + +struct ggml_cplan ggml_graph_plan( + const struct ggml_cgraph * cgraph, + int n_threads, + struct ggml_threadpool * threadpool) { + + if (threadpool == NULL) { + //GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads); + } + if (n_threads <= 0) { + n_threads = threadpool ? threadpool->n_threads_max : GGML_DEFAULT_N_THREADS; + } + + size_t work_size = 0; + + struct ggml_cplan cplan; + memset(&cplan, 0, sizeof(struct ggml_cplan)); + + int max_tasks = 1; + + // thread scheduling for the different operations + work buffer size estimation + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + + const int n_tasks = ggml_get_n_tasks(node, n_threads); + + max_tasks = MAX(max_tasks, n_tasks); + + size_t cur = 0; + + if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) { + + switch (node->op) { + case GGML_OP_CPY: + case GGML_OP_DUP: + { + if (ggml_is_quantized(node->type) || + // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32 + (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) || + (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) { + cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; + } + } break; + case GGML_OP_ADD: + case GGML_OP_ADD1: + { + if (ggml_is_quantized(node->src[0]->type)) { + cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; + } + } break; + case GGML_OP_ACC: + { + if (ggml_is_quantized(node->src[0]->type)) { + cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks; + } + } break; + case GGML_OP_COUNT_EQUAL: + { + cur = ggml_type_size(node->type)*n_tasks; + } break; + case GGML_OP_MUL_MAT: + { + const enum ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type; + + if (node->src[1]->type != vec_dot_type) { + cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1])); + } + } break; + case GGML_OP_MUL_MAT_ID: + { + cur = 0; + const struct ggml_tensor * src0 = node->src[0]; + const struct ggml_tensor * src1 = node->src[1]; + const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type; + if (src1->type != vec_dot_type) { + cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)); + } + const int n_as = src0->ne[2]; + cur += GGML_PAD(cur, sizeof(int64_t)); // align + cur += n_as * sizeof(int64_t); // matrix_row_counts + cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows + } break; + case GGML_OP_OUT_PROD: + { + if (ggml_is_quantized(node->src[0]->type)) { + cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; + } + } break; + case GGML_OP_SOFT_MAX: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + { + cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + GGML_ASSERT(node->src[0]->ne[3] == 1); + GGML_ASSERT(node->src[1]->ne[2] == 1); + GGML_ASSERT(node->src[1]->ne[3] == 1); + + const int64_t ne00 = node->src[0]->ne[0]; // K + const int64_t ne01 = node->src[0]->ne[1]; // Cout + const int64_t ne02 = node->src[0]->ne[2]; // Cin + const int64_t ne10 = node->src[1]->ne[0]; // L + const int64_t ne11 = node->src[1]->ne[1]; // Cin + + if ((node->src[0]->type == GGML_TYPE_F16 || + node->src[0]->type == GGML_TYPE_BF16) && + node->src[1]->type == GGML_TYPE_F32) { + cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02; + cur += sizeof(ggml_fp16_t)*ne10*ne11; + } else if (node->src[0]->type == GGML_TYPE_F32 && + node->src[1]->type == GGML_TYPE_F32) { + cur += sizeof(float)*ne00*ne01*ne02; + cur += sizeof(float)*ne10*ne11; + } else { + GGML_ABORT("fatal error"); + } + } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + const int64_t ne00 = node->src[0]->ne[0]; // W + const int64_t ne01 = node->src[0]->ne[1]; // H + const int64_t ne02 = node->src[0]->ne[2]; // Channels Out + const int64_t ne03 = node->src[0]->ne[3]; // Channels In + + const int64_t ne10 = node->src[1]->ne[0]; // W + const int64_t ne11 = node->src[1]->ne[1]; // H + const int64_t ne12 = node->src[1]->ne[2]; // Channels In + + cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03; + cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; + } break; + case GGML_OP_FLASH_ATTN_EXT: + { + const int64_t ne00 = node->src[0]->ne[0]; // D + + cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread + } break; + case GGML_OP_FLASH_ATTN_BACK: + { + const int64_t D = node->src[0]->ne[0]; + const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); + const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back + if (node->src[1]->type == GGML_TYPE_F32) { + cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 + } else if (node->src[1]->type == GGML_TYPE_F16) { + cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 + } else if (node->src[1]->type == GGML_TYPE_BF16) { + cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 + } + } break; + + case GGML_OP_CROSS_ENTROPY_LOSS: + { + cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); + } break; + case GGML_OP_COUNT: + { + GGML_ABORT("fatal error"); + } + default: + break; + } + } + + work_size = MAX(work_size, cur); + } + + if (work_size > 0) { + work_size += CACHE_LINE_SIZE*(n_threads); + } + + cplan.threadpool = threadpool; + cplan.n_threads = MIN(max_tasks, n_threads); + cplan.work_size = work_size; + cplan.work_data = NULL; + + return cplan; +} + +static thread_ret_t ggml_graph_compute_thread(void * data) { + struct ggml_compute_state * state = (struct ggml_compute_state *) data; + struct ggml_threadpool * tp = state->threadpool; + + const struct ggml_cgraph * cgraph = tp->cgraph; + const struct ggml_cplan * cplan = tp->cplan; + + set_numa_thread_affinity(state->ith); + + struct ggml_compute_params params = { + /*.ith =*/ state->ith, + /*.nth =*/ atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed), + /*.wsize =*/ cplan->work_size, + /*.wdata =*/ cplan->work_data, + /*.threadpool=*/ tp, + }; + + for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) { + struct ggml_tensor * node = cgraph->nodes[node_n]; + + ggml_compute_forward(¶ms, node); + + if (state->ith == 0 && cplan->abort_callback && + cplan->abort_callback(cplan->abort_callback_data)) { + atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed); + tp->ec = GGML_STATUS_ABORTED; + } + + ggml_barrier(state->threadpool); + } + + return 0; +} + +#ifndef GGML_USE_OPENMP + +// check if thread is active +static inline bool ggml_graph_compute_thread_active(struct ggml_compute_state * state) { + struct ggml_threadpool * threadpool = state->threadpool; + int n_threads = atomic_load_explicit(&threadpool->n_threads_cur, memory_order_relaxed); + return (state->ith < n_threads); +} + +// check if thread is ready to proceed (exit from polling or sleeping) +static inline bool ggml_graph_compute_thread_ready(struct ggml_compute_state * state) { + struct ggml_threadpool * threadpool = state->threadpool; + + if (state->pending || threadpool->stop || threadpool->pause) { return true; } + + // check for new graph/work + int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed); + if (new_graph != state->last_graph) { + state->pending = ggml_graph_compute_thread_active(state); + state->last_graph = new_graph; + } + + return state->pending; +} + +// sync thread state after polling +static inline void ggml_graph_compute_thread_sync(struct ggml_compute_state * state) { + // TSAN doesn't support standalone fence yet, we use a dummy read-modify-write instead + #ifdef GGML_TSAN_ENABLED + atomic_fetch_add_explicit(&state->threadpool->n_graph, 0, memory_order_seq_cst); + #else + atomic_thread_fence(memory_order_seq_cst); + #endif + UNUSED(state); +} + +static inline bool ggml_graph_compute_poll_for_work(struct ggml_compute_state * state) { + struct ggml_threadpool * threadpool = state->threadpool; + + // Skip polling for unused threads + if (!ggml_graph_compute_thread_active(state)) { + return state->pending; + } + + // This seems to make 0 ... 100 a decent range for polling level across modern processors. + // Perhaps, we can adjust it dynamically based on load and things. + const uint64_t n_rounds = 1024UL * 128 * threadpool->poll; + + for (uint64_t i=0; !ggml_graph_compute_thread_ready(state) && i < n_rounds; i++) { + // No new work. Keep polling. + ggml_thread_cpu_relax(); + } + + return state->pending; +} + +static inline bool ggml_graph_compute_check_for_work(struct ggml_compute_state * state) { + struct ggml_threadpool * threadpool = state->threadpool; + + if (ggml_graph_compute_poll_for_work(state)) { + ggml_graph_compute_thread_sync(state); + return state->pending; + } + + ggml_mutex_lock_shared(&threadpool->mutex); + while (!ggml_graph_compute_thread_ready(state)) { + // No new work. Wait for the signal. + GGML_PRINT_DEBUG("thread #%d waiting for work (sleeping)\n", state->ith); + ggml_cond_wait(&threadpool->cond, &threadpool->mutex); + } + ggml_mutex_unlock_shared(&threadpool->mutex); + + return state->pending; +} + +static thread_ret_t ggml_graph_compute_secondary_thread(void* data) { + struct ggml_compute_state * state = (struct ggml_compute_state *) data; + struct ggml_threadpool * threadpool = state->threadpool; + + ggml_thread_apply_priority(threadpool->prio); + if (ggml_thread_cpumask_is_valid(state->cpumask)) { + ggml_thread_apply_affinity(state->cpumask); + } + + while (true) { + // Check if we need to sleep + while (threadpool->pause) { + GGML_PRINT_DEBUG("thread #%d inside pause loop\n", state->ith); + ggml_mutex_lock_shared(&threadpool->mutex); + if (threadpool->pause) { + ggml_cond_wait(&threadpool->cond, &threadpool->mutex); + } + GGML_PRINT_DEBUG("thread #%d resuming after wait\n", state->ith); + ggml_mutex_unlock_shared(&threadpool->mutex); + } + + // This needs to be checked for after the cond_wait + if (threadpool->stop) break; + + // Check if there is new work + // The main thread is the only one that can dispatch new work + + ggml_graph_compute_check_for_work(state); + if (state->pending) { + state->pending = false; + + ggml_graph_compute_thread(state); + } + } + + return (thread_ret_t) 0; +} + +// Start processing new graph +static void ggml_graph_compute_kickoff(struct ggml_threadpool * threadpool, int n_threads) +{ + // Always take the mutex here because the worker threads are doing hybrid poll/wait + + ggml_mutex_lock(&threadpool->mutex); + + GGML_PRINT_DEBUG("threadpool: n_threads_cur %d n_threads %d\n", threadpool->n_threads_cur, n_threads); + + // Update the number of active threads + atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed); + + // Indicate the graph is ready to be processed + // We need the full seq-cst fence here because of the polling threads (used in thread_sync) + atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_seq_cst); + + if (threadpool->pause) { + // Update main thread prio and affinity to match the threadpool settings + ggml_thread_apply_priority(threadpool->prio); + if (ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) { + ggml_thread_apply_affinity(threadpool->workers[0].cpumask); + } + + // resume does cond broadcast + ggml_threadpool_resume_locked(threadpool); + } else { + ggml_cond_broadcast(&threadpool->cond); + } + + ggml_mutex_unlock(&threadpool->mutex); +} + +#endif // GGML_USE_OPENMP + +static struct ggml_threadpool * ggml_threadpool_new_impl( + struct ggml_threadpool_params * tpp, + struct ggml_cgraph * cgraph, + struct ggml_cplan * cplan) { + + struct ggml_threadpool * threadpool = + ggml_aligned_malloc(sizeof(struct ggml_threadpool)); + { + threadpool->cgraph = cgraph; + threadpool->cplan = cplan; + threadpool->n_graph = 0; + threadpool->n_barrier = 0; + threadpool->n_barrier_passed = 0; + threadpool->current_chunk = 0; + threadpool->stop = false; + threadpool->pause = tpp->paused; + threadpool->abort = -1; + threadpool->workers = NULL; + threadpool->n_threads_max = tpp->n_threads; + threadpool->n_threads_cur = tpp->n_threads; + threadpool->poll = tpp->poll; + threadpool->prio = tpp->prio; + threadpool->ec = GGML_STATUS_SUCCESS; + } + + // Allocate and init workers state + const size_t workers_size = sizeof(struct ggml_compute_state) * tpp->n_threads; + struct ggml_compute_state * workers = ggml_aligned_malloc(workers_size); + + memset(workers, 0, workers_size); + for (int j = 0; j < tpp->n_threads; j++) { + workers[j].threadpool = threadpool; + workers[j].ith = j; + } + + threadpool->workers = workers; + +#ifndef GGML_USE_OPENMP + ggml_mutex_init(&threadpool->mutex); + ggml_cond_init(&threadpool->cond); + + // Spin the threads for all workers, and update CPU placements. + // Place the main thread last (towards the higher numbered CPU cores). + + int32_t cpumask_iter = 0; + + for (int j = 1; j < tpp->n_threads; j++) { + ggml_thread_cpumask_next(tpp->cpumask, workers[j].cpumask, tpp->strict_cpu, &cpumask_iter); + + int32_t rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_secondary_thread, &workers[j]); + GGML_ASSERT(rc == 0); + } + + ggml_thread_cpumask_next(tpp->cpumask, workers[0].cpumask, tpp->strict_cpu, &cpumask_iter); + + if (!threadpool->pause) { + // Update main thread prio and affinity at the start, otherwise we'll do it in resume + ggml_thread_apply_priority(threadpool->prio); + if (ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) { + ggml_thread_apply_affinity(threadpool->workers[0].cpumask); + } + } +#endif // GGML_USE_OPENMP + + return threadpool; +} + +struct ggml_threadpool * ggml_threadpool_new(struct ggml_threadpool_params * tpp) { + return ggml_threadpool_new_impl(tpp, NULL, NULL); +} + +enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) { + ggml_cpu_init(); + + GGML_ASSERT(cplan); + GGML_ASSERT(cplan->n_threads > 0); + GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL); + + int n_threads = cplan->n_threads; + struct ggml_threadpool * threadpool = cplan->threadpool; + + bool disposable_threadpool = false; + + if (threadpool == NULL) { + //GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads); + disposable_threadpool = true; + + struct ggml_threadpool_params ttp = ggml_threadpool_params_default(n_threads); + threadpool = ggml_threadpool_new_impl(&ttp, cgraph, cplan); + } else { + // Reset some of the parameters that need resetting + // No worker threads should be accessing the parameters below at this stage + threadpool->cgraph = cgraph; + threadpool->cplan = cplan; + threadpool->current_chunk = 0; + threadpool->abort = -1; + threadpool->ec = GGML_STATUS_SUCCESS; + } + +#ifdef GGML_USE_OPENMP + if (n_threads > 1) { + #pragma omp parallel num_threads(n_threads) + { + #pragma omp single + { + // update the number of threads from the actual number of threads that we got from OpenMP + n_threads = omp_get_num_threads(); + atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed); + } + + ggml_graph_compute_thread(&threadpool->workers[omp_get_thread_num()]); + } + } else { + atomic_store_explicit(&threadpool->n_threads_cur, 1, memory_order_relaxed); + ggml_graph_compute_thread(&threadpool->workers[0]); + } +#else + if (n_threads > threadpool->n_threads_max) { + GGML_LOG_WARN("cplan requested more threads (%d) than available (%d)\n", n_threads, threadpool->n_threads_max); + n_threads = threadpool->n_threads_max; + } + + // Kick all threads to start the new graph + ggml_graph_compute_kickoff(threadpool, n_threads); + + // This is a work thread too + ggml_graph_compute_thread(&threadpool->workers[0]); +#endif + + // don't leave affinity set on the main thread + clear_numa_thread_affinity(); + + enum ggml_status ret = threadpool->ec; + + if (disposable_threadpool) { + ggml_threadpool_free(threadpool); + } + + return ret; +} + +enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) { + struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads, NULL); + + cplan.work_data = (uint8_t *)ggml_new_buffer(ctx, cplan.work_size); + + return ggml_graph_compute(cgraph, &cplan); +} + + +int ggml_cpu_has_avx(void) { +#if defined(__AVX__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_avx_vnni(void) { +#if defined(__AVXVNNI__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_avx2(void) { +#if defined(__AVX2__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_avx512(void) { +#if defined(__AVX512F__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_avx512_vbmi(void) { +#if defined(__AVX512VBMI__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_avx512_vnni(void) { +#if defined(__AVX512VNNI__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_avx512_bf16(void) { +#if defined(__AVX512BF16__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_amx_int8(void) { +#if defined(__AMX_INT8__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_fma(void) { +#if defined(__FMA__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_arm_fma(void) { +#if defined(__ARM_FEATURE_FMA) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_riscv_v(void) { +#if defined(__riscv_v_intrinsic) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_f16c(void) { +#if defined(__F16C__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_fp16_va(void) { +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_wasm_simd(void) { +#if defined(__wasm_simd128__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_llamafile(void) { +#if defined(GGML_USE_LLAMAFILE) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_sse3(void) { +#if defined(__SSE3__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_ssse3(void) { +#if defined(__SSSE3__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_vsx(void) { +#if defined(__POWER9_VECTOR__) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_neon(void) { +#if defined(__ARM_ARCH) && defined(__ARM_NEON) + return ggml_arm_arch_features.has_neon; +#else + return 0; +#endif +} + +int ggml_cpu_has_dotprod(void) { +#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_DOTPROD) + return ggml_arm_arch_features.has_dotprod; +#else + return 0; +#endif +} + +int ggml_cpu_has_sve(void) { +#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE) + return ggml_arm_arch_features.has_sve; +#else + return 0; +#endif +} + +int ggml_cpu_has_matmul_int8(void) { +#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_MATMUL_INT8) + return ggml_arm_arch_features.has_i8mm; +#else + return 0; +#endif +} + +int ggml_cpu_get_sve_cnt(void) { +#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE) + return ggml_arm_arch_features.sve_cnt; +#else + return 0; +#endif +} + +void ggml_cpu_init(void) { + // needed to initialize f16 tables + { + struct ggml_init_params params = { 0, NULL, false }; + struct ggml_context * ctx = ggml_init(params); + ggml_free(ctx); + } + + ggml_critical_section_start(); + + static bool is_first_call = true; + + if (is_first_call) { + // initialize GELU, Quick GELU, SILU and EXP F32 tables + { + const uint64_t t_start = ggml_time_us(); UNUSED(t_start); + + for (int i = 0; i < (1 << 16); ++i) { + union { + uint16_t u16; + ggml_fp16_t fp16; + } u = {i}; + float f = GGML_FP16_TO_FP32(u.fp16); + ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); + ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f)); + } + + const uint64_t t_end = ggml_time_us(); UNUSED(t_end); + + GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0); + } + +#if defined(__ARM_ARCH) + ggml_init_arm_arch_features(); +#endif + + is_first_call = false; + } + + ggml_critical_section_end(); +} diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp new file mode 100644 index 000000000..2ccb4b472 --- /dev/null +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -0,0 +1,637 @@ +#include "ggml-backend.h" +#include "ggml-backend-impl.h" +#include "ggml-cpu.h" +#include "ggml-cpu-aarch64.h" +#include "ggml-cpu-traits.h" +#include "ggml-impl.h" +#include "amx/amx.h" + +#include +#include +#include + +#ifdef GGML_USE_CPU_HBM +#include "ggml-cpu-hbm.h" +#endif + +#if defined(__APPLE__) +#include +#include +#endif + +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX + #define NOMINMAX +#endif +#include +#endif + +// ggml-backend interface + +std::vector& ggml_backend_cpu_get_extra_buffers_type() { + static std::vector bufts = []() { + std::vector bufts; + +#if defined(__AMX_INT8__) && defined(__AVX512VNNI__) + if (ggml_backend_amx_buffer_type()) { + bufts.push_back(ggml_backend_amx_buffer_type()); + } +#endif + +#ifdef GGML_USE_CPU_AARCH64 + if (ggml_backend_cpu_aarch64_buffer_type()) { + bufts.push_back(ggml_backend_cpu_aarch64_buffer_type()); + } +#endif + + bufts.push_back(NULL); + + return bufts; + }(); + + return bufts; +} + +static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_type(ggml_backend_dev_t device) { + return ggml_backend_cpu_get_extra_buffers_type().data(); + + GGML_UNUSED(device); +} + +static bool ggml_backend_cpu_is_extra_buffer_type(ggml_backend_buffer_type_t buft) { + for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) { + if (extra && extra == buft) return true; + } + return false; +} + +// CPU backend - backend (stream) + +struct ggml_backend_cpu_context { + int n_threads; + ggml_threadpool_t threadpool; + + uint8_t * work_data; + size_t work_size; + + ggml_abort_callback abort_callback; + void * abort_callback_data; +}; + +static const char * ggml_backend_cpu_get_name(ggml_backend_t backend) { + return "CPU"; + + GGML_UNUSED(backend); +} + +static void ggml_backend_cpu_free(ggml_backend_t backend) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; + delete[] cpu_ctx->work_data; + delete cpu_ctx; + delete backend; +} + +struct ggml_backend_plan_cpu { + struct ggml_cplan cplan; + struct ggml_cgraph cgraph; +}; + +static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; + + struct ggml_backend_plan_cpu * cpu_plan = new ggml_backend_plan_cpu; + + cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool); + cpu_plan->cgraph = *cgraph; // FIXME: deep copy + + if (cpu_plan->cplan.work_size > 0) { + cpu_plan->cplan.work_data = new uint8_t[cpu_plan->cplan.work_size]; + if (cpu_plan->cplan.work_data == NULL) { + delete cpu_plan; + return NULL; + } + } + + cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback; + cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data; + + return cpu_plan; +} + +static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan; + + delete[] cpu_plan->cplan.work_data; + delete cpu_plan; + + GGML_UNUSED(backend); +} + +static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { + struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan; + + return ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan); + + GGML_UNUSED(backend); +} + +static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; + + struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool); + + if (cpu_ctx->work_size < cplan.work_size) { + delete[] cpu_ctx->work_data; + cpu_ctx->work_data = new uint8_t[cplan.work_size]; + if (cpu_ctx->work_data == NULL) { + cpu_ctx->work_size = 0; + return GGML_STATUS_ALLOC_FAILED; + } + cpu_ctx->work_size = cplan.work_size; + } + cplan.work_data = (uint8_t *)cpu_ctx->work_data; + + cplan.abort_callback = cpu_ctx->abort_callback; + cplan.abort_callback_data = cpu_ctx->abort_callback_data; + + return ggml_graph_compute(cgraph, &cplan); +} + +static const struct ggml_backend_i ggml_backend_cpu_i = { + /* .get_name = */ ggml_backend_cpu_get_name, + /* .free = */ ggml_backend_cpu_free, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create, + /* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute, + /* .graph_compute = */ ggml_backend_cpu_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, +}; + +static ggml_guid_t ggml_backend_cpu_guid(void) { + static ggml_guid guid = { 0xaa, 0x67, 0xc7, 0x43, 0x96, 0xe6, 0xa3, 0x8a, 0xe3, 0xaf, 0xea, 0x92, 0x36, 0xbc, 0xfc, 0x89 }; + return &guid; +} + +ggml_backend_t ggml_backend_cpu_init(void) { + // initialize CPU backend now to avoid slowing the first graph computation + ggml_cpu_init(); + + struct ggml_backend_cpu_context * ctx = new ggml_backend_cpu_context; + if (ctx == NULL) { + return NULL; + } + + ctx->n_threads = GGML_DEFAULT_N_THREADS; + ctx->threadpool = NULL; + ctx->work_data = NULL; + ctx->work_size = 0; + ctx->abort_callback = NULL; + ctx->abort_callback_data = NULL; + + ggml_backend_t cpu_backend = new ggml_backend { + /* .guid = */ ggml_backend_cpu_guid(), + /* .interface = */ ggml_backend_cpu_i, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ ctx, + }; + + if (cpu_backend == NULL) { + delete ctx; + return NULL; + } + + return cpu_backend; +} + +bool ggml_backend_is_cpu(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cpu_guid()); +} + +void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) { + GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); + + struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; + ctx->n_threads = n_threads; +} + +void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool) { + GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); + + struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; + + if (ctx->threadpool && ctx->threadpool != threadpool) { + // already had a different threadpool, pause/suspend it before switching + ggml_threadpool_pause(ctx->threadpool); + } + ctx->threadpool = threadpool; +} + +void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) { + GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); + + struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; + ctx->abort_callback = abort_callback; + ctx->abort_callback_data = abort_callback_data; +} + +// CPU backend - device + +struct ggml_backend_cpu_device_context { + std::string description = "CPU"; + + ggml_backend_cpu_device_context() { +#ifdef __APPLE__ + size_t len = 0; + if (!sysctlbyname("machdep.cpu.brand_string", NULL, &len, NULL, 0)) { + description.resize(len); + sysctlbyname("machdep.cpu.brand_string", &description[0], &len, NULL, 0); // NOLINT + } +#elif defined(__linux__) + FILE * f = fopen("/proc/cpuinfo", "r"); + if (f) { + char buf[1024]; + while (fgets(buf, sizeof(buf), f)) { + if (strncmp(buf, "model name", 10) == 0) { + char * p = strchr(buf, ':'); + if (p) { + p++; + while (std::isspace(*p)) { + p++; + } + while (std::isspace(p[strlen(p) - 1])) { + p[strlen(p) - 1] = '\0'; + } + description = p; + break; + } + } + } + fclose(f); + } +#elif defined(_WIN32) + HKEY hKey; + if (RegOpenKeyEx(HKEY_LOCAL_MACHINE, + TEXT("HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0"), + 0, + KEY_READ, + &hKey) == ERROR_SUCCESS) { + DWORD cpu_brand_size = 0; + if (RegQueryValueExA(hKey, + TEXT("ProcessorNameString"), + NULL, + NULL, + NULL, + &cpu_brand_size) == ERROR_SUCCESS) { + description.resize(cpu_brand_size); + if (RegQueryValueExA(hKey, + TEXT("ProcessorNameString"), + NULL, + NULL, + (LPBYTE)&description[0], // NOLINT + &cpu_brand_size) == ERROR_SUCCESS) { + if (description.find('\0') != std::string::npos) { + description.resize(description.find('\0')); + } + } + } + RegCloseKey(hKey); + } +#endif + } +}; + +static const char * ggml_backend_cpu_device_get_name(ggml_backend_dev_t dev) { + return "CPU"; + + GGML_UNUSED(dev); +} + +static const char * ggml_backend_cpu_device_get_description(ggml_backend_dev_t dev) { + struct ggml_backend_cpu_device_context * ctx = (struct ggml_backend_cpu_device_context *)dev->context; + + return ctx->description.c_str(); +} + +static void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + // TODO + *free = 0; + *total = 0; + + GGML_UNUSED(dev); +} + +static enum ggml_backend_dev_type ggml_backend_cpu_device_get_type(ggml_backend_dev_t dev) { + return GGML_BACKEND_DEVICE_TYPE_CPU; + + GGML_UNUSED(dev); +} + +static void ggml_backend_cpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_cpu_device_get_name(dev); + props->description = ggml_backend_cpu_device_get_description(dev); + props->type = ggml_backend_cpu_device_get_type(dev); + ggml_backend_cpu_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ true, + /* .events = */ false, + }; +} + +static ggml_backend_t ggml_backend_cpu_device_init_backend(ggml_backend_dev_t dev, const char * params) { + return ggml_backend_cpu_init(); + + GGML_UNUSED(dev); + GGML_UNUSED(params); +} + +static ggml_backend_buffer_type_t ggml_backend_cpu_device_get_buffer_type(ggml_backend_dev_t dev) { + return ggml_backend_cpu_buffer_type(); + + GGML_UNUSED(dev); +} + +static ggml_backend_buffer_t ggml_backend_cpu_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + return ggml_backend_cpu_buffer_from_ptr(ptr, size); + + GGML_UNUSED(dev); + GGML_UNUSED(max_tensor_size); +} + +static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + + if (op->op == GGML_OP_NONE || op->op == GGML_OP_RESHAPE || op->op == GGML_OP_VIEW || op->op == GGML_OP_PERMUTE || op->op == GGML_OP_TRANSPOSE) { + return true; + } + + // extra_buffer_op? + for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) { + if (extra) { + auto buf_extra = (ggml::cpu::extra_buffer_type*) extra->context; + if (buf_extra && buf_extra->supports_op(dev, op)) { + return true; + } + } + } + + // the other case need host buffer. + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (op->src[i] && op->src[i]->buffer && !ggml_backend_buft_is_host(op->src[i]->buffer->buft)) { + return false; + } + } + + switch (op->op) { + case GGML_OP_CPY: + return + op->type != GGML_TYPE_IQ3_XXS && + op->type != GGML_TYPE_IQ3_S && + op->type != GGML_TYPE_IQ2_XXS && + op->type != GGML_TYPE_IQ2_XS && + op->type != GGML_TYPE_IQ2_S && + op->type != GGML_TYPE_IQ1_S && + op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float + case GGML_OP_MUL_MAT: + return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type; + case GGML_OP_SOFT_MAX_BACK: { + if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32) { + return false; + } + float max_bias = 0.0f; + + memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float)); + + return max_bias == 0.0f; + } + case GGML_OP_IM2COL_BACK: + return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; + case GGML_OP_OUT_PROD: + return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) && + src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + default: + return true; + } +} + +static bool ggml_backend_cpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + return ggml_backend_buft_is_host(buft) || ggml_backend_cpu_is_extra_buffer_type(buft); + GGML_UNUSED(dev); +} + +static const struct ggml_backend_device_i ggml_backend_cpu_device_i = { + /* .get_name = */ ggml_backend_cpu_device_get_name, + /* .get_description = */ ggml_backend_cpu_device_get_description, + /* .get_memory = */ ggml_backend_cpu_device_get_memory, + /* .get_type = */ ggml_backend_cpu_device_get_type, + /* .get_props = */ ggml_backend_cpu_device_get_props, + /* .init_backend = */ ggml_backend_cpu_device_init_backend, + /* .get_buffer_type = */ ggml_backend_cpu_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_cpu_device_buffer_from_host_ptr, + /* .supports_op = */ ggml_backend_cpu_device_supports_op, + /* .supports_buft = */ ggml_backend_cpu_device_supports_buft, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +// CPU backend - backend (reg) + +static const char * ggml_backend_cpu_reg_get_name(ggml_backend_reg_t reg) { + return "CPU"; + + GGML_UNUSED(reg); +} + +static size_t ggml_backend_cpu_reg_get_device_count(ggml_backend_reg_t reg) { + return 1; + + GGML_UNUSED(reg); +} + +static ggml_backend_dev_t ggml_backend_cpu_reg_get_device(ggml_backend_reg_t reg, size_t index) { + GGML_ASSERT(index == 0); + + static ggml_backend_cpu_device_context ctx; + static ggml_backend_device ggml_backend_cpu_device = { + /* .iface = */ ggml_backend_cpu_device_i, + /* .reg = */ reg, + /* .context = */ &ctx, + }; + + return &ggml_backend_cpu_device; +} + +// This is intended to replace the the ggml_cpu_has_* functions when loading the CPU backend dynamically, +// and additionally to allow other backends to expose their own list of features that applications can query using the same API +static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t reg) { + static std::vector features = []() { + ggml_cpu_init(); + + std::vector features; + if (ggml_cpu_has_sse3()) { + features.push_back({ "SSE3", "1" }); + } + if (ggml_cpu_has_ssse3()) { + features.push_back({ "SSSE3", "1" }); + } + if (ggml_cpu_has_avx()) { + features.push_back({ "AVX", "1" }); + } + if (ggml_cpu_has_avx_vnni()) { + features.push_back({ "AVX_VNNI", "1" }); + } + if (ggml_cpu_has_avx2()) { + features.push_back({ "AVX2", "1" }); + } + if (ggml_cpu_has_f16c()) { + features.push_back({ "F16C", "1" }); + } + if (ggml_cpu_has_fma()) { + features.push_back({ "FMA", "1" }); + } + if (ggml_cpu_has_avx512()) { + features.push_back({ "AVX512", "1" }); + } + if (ggml_cpu_has_avx512_vbmi()) { + features.push_back({ "AVX512_VBMI", "1" }); + } + if (ggml_cpu_has_avx512_vnni()) { + features.push_back({ "AVX512_VNNI", "1" }); + } + if (ggml_cpu_has_avx512_bf16()) { + features.push_back({ "AVX512_BF16", "1" }); + } + if (ggml_cpu_has_amx_int8()) { + features.push_back({ "AMX_INT8", "1" }); + } + if (ggml_cpu_has_neon()) { + features.push_back({ "NEON", "1" }); + } + if (ggml_cpu_has_arm_fma()) { + features.push_back({ "ARM_FMA", "1" }); + } + if (ggml_cpu_has_fp16_va()) { + features.push_back({ "FP16_VA", "1" }); + } + if (ggml_cpu_has_matmul_int8()) { + features.push_back({ "MATMUL_INT8", "1" }); + } + if (ggml_cpu_has_sve()) { + features.push_back({ "SVE", "1" }); + } + if (ggml_cpu_has_dotprod()) { + features.push_back({ "DOTPROD", "1" }); + } + if (ggml_cpu_has_matmul_int8()) { + features.push_back({ "MATMUL_INT8", "1" }); + } + if (ggml_cpu_get_sve_cnt() > 0) { + static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt()); + features.push_back({ "SVE_CNT", sve_cnt.c_str() }); + } + if (ggml_cpu_has_riscv_v()) { + features.push_back({ "RISCV_V", "1" }); + } + if (ggml_cpu_has_vsx()) { + features.push_back({ "VSX", "1" }); + } + if (ggml_cpu_has_wasm_simd()) { + features.push_back({ "WASM_SIMD", "1" }); + } + if (ggml_cpu_has_llamafile()) { + features.push_back({ "LLAMAFILE", "1" }); + } + #ifdef GGML_USE_ACCELERATE + features.push_back({ "ACCELERATE", "1" }); + #endif + #ifdef GGML_USE_CPU_HBM + features.push_back({ "CPU_HBM", "1" }); + #endif + #ifdef GGML_USE_OPENMP + features.push_back({ "OPENMP", "1" }); + #endif + #ifdef GGML_USE_CPU_AARCH64 + features.push_back({ "AARCH64_REPACK", "1" }); + #endif + + features.push_back({ nullptr, nullptr }); + + return features; + }(); + + return features.data(); + + GGML_UNUSED(reg); +} + +static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const char * name) { + if (strcmp(name, "ggml_backend_set_n_threads") == 0) { + ggml_backend_set_n_threads_t fct = ggml_backend_cpu_set_n_threads; + return (void *)fct; + } + if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) { + ggml_backend_dev_get_extra_bufts_t fct = ggml_backend_cpu_device_get_extra_buffers_type; + return (void *)fct; + } + if (strcmp(name, "ggml_backend_get_features") == 0) { + return (void *)ggml_backend_cpu_get_features; + } + if (strcmp(name, "ggml_backend_set_abort_callback") == 0) { + return (void *)ggml_backend_cpu_set_abort_callback; + } + if (strcmp(name, "ggml_backend_cpu_numa_init") == 0) { + return (void *)ggml_numa_init; + } + if (strcmp(name, "ggml_backend_cpu_is_numa") == 0) { + return (void *)ggml_is_numa; + } + + // threadpool - TODO: move to ggml-base + if (strcmp(name, "ggml_threadpool_new") == 0) { + return (void *)ggml_threadpool_new; + } + if (strcmp(name, "ggml_threadpool_free") == 0) { + return (void *)ggml_threadpool_free; + } + if (strcmp(name, "ggml_backend_cpu_set_threadpool") == 0) { + return (void *)ggml_backend_cpu_set_threadpool; + } + + return NULL; + + GGML_UNUSED(reg); +} + +static const struct ggml_backend_reg_i ggml_backend_cpu_reg_i = { + /* .get_name = */ ggml_backend_cpu_reg_get_name, + /* .get_device_count = */ ggml_backend_cpu_reg_get_device_count, + /* .get_device = */ ggml_backend_cpu_reg_get_device, + /* .get_proc_address = */ ggml_backend_cpu_get_proc_address, +}; + +ggml_backend_reg_t ggml_backend_cpu_reg(void) { + // init CPU feature detection + ggml_cpu_init(); + + static struct ggml_backend_reg ggml_backend_cpu_reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_cpu_reg_i, + /* .context = */ NULL, + }; + + return &ggml_backend_cpu_reg; +} + +GGML_BACKEND_DL_IMPL(ggml_backend_cpu_reg) diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp new file mode 100644 index 000000000..c22a66287 --- /dev/null +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -0,0 +1,2597 @@ +// Copyright 2024 Mozilla Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS +// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// +// _ _ ___ _ _ ___ +// | |_(_)_ _ _ _| _ ) | /_\ / __| +// | _| | ' \ || | _ \ |__ / _ \\__ \. +// \__|_|_||_\_, |___/____/_/ \_\___/ +// |__/ +// +// BASIC LINEAR ALGEBRA SUBPROGRAMS +// +// +// This file implements multithreaded CPU matrix multiplication for the +// common contiguous use case C = Aᵀ * B. These kernels are designed to +// have excellent performance[1] for matrices that fit in the CPU cache +// without imposing any overhead such as cache filling or malloc calls. +// +// This implementation does not guarantee any upper bound with rounding +// errors, which grow along with k. Our goal's to maximally exploit the +// hardware for performance, and then use whatever resources remain for +// improving numerical accuracy. +// +// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online]. +// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024]. + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Wpedantic" +#pragma GCC diagnostic ignored "-Wignored-attributes" +#endif + +#include "sgemm.h" +#include "ggml-impl.h" +#include "ggml-cpu-impl.h" +#include "ggml-quants.h" + +#include +#include + +#ifdef _MSC_VER +#define NOINLINE __declspec(noinline) +#else +#define NOINLINE __attribute__((__noinline__)) +#endif + +#if defined(__ARM_NEON) || defined(__AVX512F__) +#define VECTOR_REGISTERS 32 +#else +#define VECTOR_REGISTERS 16 +#endif + +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) + +namespace { + +inline float unhalf(ggml_fp16_t d) { + return GGML_FP16_TO_FP32(d); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// VECTORIZED ARITHMETIC OPERATIONS + +#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); } +inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); } +inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); } +#endif // __SSE__ + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); } +inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); } +inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); } +#endif // __AVX__ + +#if defined(__AVX512F__) +inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); } +inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); } +inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); } +#endif // __AVX512F__ + +#if defined(__ARM_NEON) +inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); } +inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); } +inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); } +#endif // __ARM_NEON + +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); } +inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); } +inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +#if defined(__MMA__) +typedef vector unsigned char vec_t; +typedef __vector_quad acc_t; +#endif +//////////////////////////////////////////////////////////////////////////////////////////////////// +// VECTORIZED FUSED MULTIPLY ADD + +/** + * Computes a * b + c. + */ +template +inline U madd(T a, T b, U c) { + return add(mul(a, b), c); +} + +#if defined(__FMA__) +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +template <> +inline __m256 madd(__m256 a, __m256 b, __m256 c) { + return _mm256_fmadd_ps(a, b, c); +} +#endif +#if defined(__AVX512F__) +template <> +inline __m512 madd(__m512 a, __m512 b, __m512 c) { + return _mm512_fmadd_ps(a, b, c); +} +#endif +#if defined(__AVX512BF16__) +template <> +inline __m512 madd(__m512bh a, __m512bh b, __m512 c) { + return _mm512_dpbf16_ps(c, a, b); +} +template <> +inline __m256 madd(__m256bh a, __m256bh b, __m256 c) { + return _mm256_dpbf16_ps(c, a, b); +} +#endif +#endif + +#if defined(__ARM_FEATURE_FMA) +template <> +inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) { + return vfmaq_f32(c, b, a); +} +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) +template <> +inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) { + return vfmaq_f16(c, b, a); +} +#endif +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// VECTORIZED HORIZONTAL SUM + +#if defined(__ARM_NEON) +inline float hsum(float32x4_t x) { + return vaddvq_f32(x); +} +#endif // __ARM_NEON + +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) +inline float hsum(float16x8_t x) { + return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)), + vcvt_f32_f16(vget_high_f16(x)))); +} +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +inline float hsum(__m128 x) { +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) + x = _mm_add_ps(x, _mm_movehl_ps(x, x)); + x = _mm_add_ss(x, _mm_movehdup_ps(x)); +#else + __m128 t; + t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1)); + x = _mm_add_ps(x, t); + t = _mm_movehl_ps(t, x); + x = _mm_add_ss(x, t); +#endif + return _mm_cvtss_f32(x); +} +#endif + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +inline float hsum(__m256 x) { + return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1), + _mm256_castps256_ps128(x))); +} +#endif // __AVX__ + +#if defined(__AVX512F__) +inline float hsum(__m512 x) { + return _mm512_reduce_add_ps(x); +} +#endif // __AVX512F__ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// VECTORIZED MEMORY LOADING + +template T load(const U *); + +#if defined(__ARM_NEON) +template <> inline float32x4_t load(const float *p) { + return vld1q_f32(p); +} +#if !defined(_MSC_VER) +// FIXME: this should check for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +template <> inline float16x8_t load(const ggml_fp16_t *p) { + return vld1q_f16((const float16_t *)p); +} +template <> inline float32x4_t load(const ggml_fp16_t *p) { + return vcvt_f32_f16(vld1_f16((const float16_t *)p)); +} +#endif // _MSC_VER +#endif // __ARM_NEON + +#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +template <> inline __m128 load(const float *p) { + return _mm_loadu_ps(p); +} +#endif // __SSE__ + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +template <> inline __m256 load(const float *p) { + return _mm256_loadu_ps(p); +} +#endif // __AVX__ + +#if defined(__AVX2__) || defined(__AVX512F__) +template <> inline __m256 load(const ggml_bf16_t *p) { + return _mm256_castsi256_ps( + _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)p)), 16)); +} +#endif // __AVX2__ + +#if defined(__F16C__) +template <> inline __m256 load(const ggml_fp16_t *p) { + return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p)); +} +#endif // __F16C__ + +#if defined(__AVX512F__) +template <> inline __m512 load(const float *p) { + return _mm512_loadu_ps(p); +} +template <> inline __m512 load(const ggml_fp16_t *p) { + return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p)); +} +template <> inline __m512 load(const ggml_bf16_t *p) { + return _mm512_castsi512_ps( + _mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)p)), 16)); +} +#endif // __AVX512F__ + +#if defined(__AVX512BF16__) +template <> inline __m512bh load(const ggml_bf16_t *p) { + return (__m512bh)_mm512_loadu_ps((const float *)p); +} +template <> inline __m256bh load(const ggml_bf16_t *p) { + return (__m256bh)_mm256_loadu_ps((const float *)p); +} +template <> inline __m512bh load(const float *p) { + return _mm512_cvtne2ps_pbh(_mm512_loadu_ps(p + 16), _mm512_loadu_ps(p)); +} +template <> inline __m256bh load(const float *p) { + return _mm512_cvtneps_pbh(_mm512_loadu_ps(p)); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// CONSTANTS + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; +static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl); +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// FLOATING POINT MATRIX MULTIPLICATION + +template +static inline int64_t BLOCK_SIZE(size_t m) { + const int64_t NB_BLOC_M = (m + M - 1) / M; + return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1; +} + +static constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) { + return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1); +} + +template +class tinyBLAS { + public: + tinyBLAS(const ggml_compute_params * params, int64_t k, + const TA *A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc) + : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) { + } + + bool matmul(int64_t m, int64_t n) { + if (k % KN != 0) + return false; + // compute RM for only need tile with size RM&RM-1 +#if VECTOR_REGISTERS == 32 + if (m % 16 == 0 && (m/16 >= params->nth)) { + const int64_t SIZE_N = BLOCK_SIZE<6>(n); + mnpack<4, 6, 4>(m, n, SIZE_N, 12); + return true; + } + if (m % 8 == 0 ) { + const int64_t SIZE_N = BLOCK_SIZE<6>(n); + mnpack<4, 6, 2>(m, n, SIZE_N, 12); + return true; + } + if (m % 4 == 0) { + const int64_t SIZE_N = BLOCK_SIZE<6>(n); + mnpack<4, 6, 1>(m, n, SIZE_N, 12); + return true; + } +#else // VECTOR_REGISTERS == 16 + if (m % 16 == 0 && (m/16 >= params->nth)) { + const int64_t SIZE_N = BLOCK_SIZE<3>(n); + mnpack<4, 3, 4>(m, n, SIZE_N, 24); + return true; + } + if (m % 8 == 0 ) { + const int64_t SIZE_N = BLOCK_SIZE<3>(n); + mnpack<4, 3, 2>(m, n, SIZE_N, 24); + return true; + } + if (m % 4 == 0) { + const int64_t SIZE_N = BLOCK_SIZE<3>(n); + mnpack<4, 3, 1>(m, n, SIZE_N, 24); + return true; + } +#endif + return false; + } + + private: + template + inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) { + if (SIZE_N == RN) { + return gemm(m, n, BN); + } + if constexpr (RN > 1) { + return mnpack(m, n, SIZE_N, BN); + } else { + GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N); + GGML_ASSERT(false); // we have miss something. + } + } + + template + inline void gemm_bloc(int64_t ii, int64_t jj) { + D Cv[RN][RM] = {}; + for (int64_t l = 0; l < k; l += KN) { + // help compiler for op order. + if constexpr (RM <= RN) { + V Av[RM]; + for (int64_t i = 0; i < RM; ++i) { + Av[i] = load(A + lda * (ii + i) + l); + } + for (int64_t j = 0; j < RN; ++j) { + V Bv = load(B + ldb * (jj + j) + l); + for (int64_t i = 0; i < RM; ++i) { + Cv[j][i] = madd(Av[i], Bv, Cv[j][i]); + } + } + } else { + V Bv[RN]; + for (int64_t j = 0; j < RN; ++j) { + Bv[j] = load(B + ldb * (jj + j) + l); + } + for (int64_t i = 0; i < RM; ++i) { + V Av = load(A + lda * (ii + i) + l); + for (int64_t j = 0; j < RN; ++j) { + Cv[j][i] = madd(Av, Bv[j], Cv[j][i]); + } + } + } + } + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) + C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + } + + template + NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) { + static std::atomic current_chunk; + + GGML_ASSERT(m % (RM * BM) == 0); + const int64_t ytiles = m / (RM * BM); + const int64_t xtiles = (n + RN -1) / RN; + const int64_t jj_RN = (xtiles - (xtiles * RN - n)); + + // "round" bloc_size to "nearest" BN + const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN; + const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1; + const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles)); + const int64_t nb_job = ytiles * NB_BN; + + if (params->ith == 0) { + GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles); + // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. + std::atomic_store_explicit(¤t_chunk, (int64_t)params->nth, std::memory_order_relaxed); + } + + ggml_barrier(params->threadpool); + + int64_t job = params->ith; + while (job < nb_job) { + const int64_t ii = (job % ytiles) * RM * BM; + const int64_t jb = job / ytiles; + const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN); + const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN); + + const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN); + const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN); + const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN; + + for (int64_t bi = 0; bi < BM * RM; bi += RM) { + int64_t jj = jj0; + for (; jj < jj1; jj += RN) { + gemm_bloc(ii + bi, jj); + } + if constexpr (RN > 1) { + for (; jj < jj2; jj += RN - 1) { + gemm_bloc(ii + bi, jj); + } + } + GGML_ASSERT(jj == jj2); + } + + // next step. + job = std::atomic_fetch_add_explicit(¤t_chunk, (int64_t)1, std::memory_order_relaxed); + } + + ggml_barrier(params->threadpool); + return; + } + + const ggml_compute_params * params; + const TA *const A; + const TB *const B; + TC *const C; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; +}; + +////////////////////////////////////////////////////////////////////////////////////////// +// QUANT ZERO MATRIX MULTIPLICATION + +#if defined(__ARM_FEATURE_DOTPROD) +template +class tinyBLAS_Q0_ARM { + public: + tinyBLAS_Q0_ARM(int64_t k, + const TA *A, int64_t lda, + const block_q8_0 *B, int64_t ldb, + float *C, int64_t ldc, + int ith, int nth) + : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + } + + void matmul(int64_t m, int64_t n) { + mnpack(0, m, 0, n); + } + + private: + NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t mc, nc, mp, np; + switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) { + case 0x33: + mc = 3; + nc = 3; + gemm<3, 3>(m0, m, n0, n); + break; + case 0x32: + mc = 3; + nc = 2; + gemm<3, 2>(m0, m, n0, n); + break; + case 0x23: + mc = 2; + nc = 3; + gemm<2, 3>(m0, m, n0, n); + break; + case 0x22: + mc = 2; + nc = 2; + gemm<2, 2>(m0, m, n0, n); + break; + case 0x31: + mc = 3; + nc = 1; + gemm<3, 1>(m0, m, n0, n); + break; + case 0x13: + mc = 1; + nc = 3; + gemm<1, 3>(m0, m, n0, n); + break; + case 0x21: + mc = 2; + nc = 1; + gemm<2, 1>(m0, m, n0, n); + break; + case 0x12: + mc = 1; + nc = 2; + gemm<1, 2>(m0, m, n0, n); + break; + case 0x11: + mc = 1; + nc = 1; + gemm<1, 1>(m0, m, n0, n); + break; + default: + return; + } + mp = m0 + (m - m0) / mc * mc; + np = n0 + (n - n0) / nc * nc; + mnpack(mp, m, n0, np); + mnpack(m0, m, np, n); + } + + template + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + float32x4_t Cv[RN][RM] = {}; + for (int64_t l = 0; l < k; ++l) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) + Cv[j][i] = vmlaq_n_f32(Cv[j][i], + vcvtq_f32_s32(vdotq_s32( + vdotq_s32(vdupq_n_s32(0), + load_lo(A + lda * (ii + i) + l), + load_lo(B + ldb * (jj + j) + l)), + load_hi(A + lda * (ii + i) + l), + load_hi(B + ldb * (jj + j) + l))), + unhalf(A[lda * (ii + i) + l].d) * + unhalf(B[ldb * (jj + j) + l].d)); + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) + C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + } + } + + inline int8x16_t load_lo(const block_q8_0 *b) { + return vld1q_s8(b->qs); + } + + inline int8x16_t load_hi(const block_q8_0 *b) { + return vld1q_s8(b->qs + 16); + } + + inline int8x16_t load_lo(const block_q4_0 *b) { + return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs), + vdupq_n_u8(0x0f))), + vdupq_n_s8(0x8)); + } + + inline int8x16_t load_hi(const block_q4_0 *b) { + return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)), + vdupq_n_s8(0x8)); + } + + const TA *const A; + const block_q8_0 *const B; + float *const C; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; + const int ith; + const int nth; +}; +#endif // __ARM_FEATURE_DOTPROD + +#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) +template +class tinyBLAS_Q0_AVX { + public: + tinyBLAS_Q0_AVX(int64_t k, + const TA *A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc, + int ith, int nth) + : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + } + + void matmul(int64_t m, int64_t n) { + mnpack(0, m, 0, n); + } + + private: + void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t mc, nc, mp, np; + switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) { +#if VECTOR_REGISTERS == 32 + case 0x44: + mc = 4; + nc = 4; +#if defined(__AVX2__) && defined(__F16C__) + gemm4xN<4>(m0, m, n0, n); +#else + gemm<4, 4>(m0, m, n0, n); +#endif + break; + case 0x43: + mc = 4; + nc = 3; +#if defined(__AVX2__) && defined(__F16C__) + gemm4xN<3>(m0, m, n0, n); +#else + gemm<4, 3>(m0, m, n0, n); +#endif + break; + case 0x34: + mc = 3; + nc = 4; +#if defined(__AVX2__) && defined(__F16C__) + gemmMx4<3>(m0, m, n0, n); +#else + gemm<3, 4>(m0, m, n0, n); +#endif + break; + case 0x33: + mc = 3; + nc = 3; + gemm<3, 3>(m0, m, n0, n); + break; + case 0x42: + mc = 4; + nc = 2; +#if defined(__AVX2__) && defined(__F16C__) + gemm4xN<2>(m0, m, n0, n); +#else + gemm<4, 2>(m0, m, n0, n); +#endif + break; + case 0x24: + mc = 2; + nc = 4; +#if defined(__AVX2__) && defined(__F16C__) + gemmMx4<2>(m0, m, n0, n); +#else + gemm<2, 4>(m0, m, n0, n); +#endif + break; +#else + case 0x44: + case 0x43: + case 0x42: + mc = 4; + nc = 2; +#if defined(__AVX2__) && defined(__F16C__) + gemm4xN<2>(m0, m, n0, n); +#else + gemm<4, 2>(m0, m, n0, n); +#endif + break; + case 0x34: + case 0x24: + mc = 2; + nc = 4; +#if defined(__AVX2__) && defined(__F16C__) + gemmMx4<2>(m0, m, n0, n); +#else + gemm<2, 4>(m0, m, n0, n); +#endif + break; + case 0x33: +#endif + case 0x32: + mc = 3; + nc = 2; + gemm<3, 2>(m0, m, n0, n); + break; + case 0x23: + mc = 2; + nc = 3; + gemm<2, 3>(m0, m, n0, n); + break; + case 0x41: + mc = 4; + nc = 1; +#if defined(__AVX2__) && defined(__F16C__) + gemm4xN<1>(m0, m, n0, n); +#else + gemm<4, 1>(m0, m, n0, n); +#endif + break; + case 0x22: + mc = 2; + nc = 2; + gemm<2, 2>(m0, m, n0, n); + break; + case 0x14: + mc = 1; + nc = 4; +#if defined(__AVX2__) && defined(__F16C__) + gemmMx4<1>(m0, m, n0, n); +#else + gemm<1, 4>(m0, m, n0, n); +#endif + break; + case 0x31: + mc = 3; + nc = 1; + gemm<3, 1>(m0, m, n0, n); + break; + case 0x13: + mc = 1; + nc = 3; + gemm<1, 3>(m0, m, n0, n); + break; + case 0x21: + mc = 2; + nc = 1; + gemm<2, 1>(m0, m, n0, n); + break; + case 0x12: + mc = 1; + nc = 2; + gemm<1, 2>(m0, m, n0, n); + break; + case 0x11: + mc = 1; + nc = 1; + gemm<1, 1>(m0, m, n0, n); + break; + default: + return; + } + mp = m0 + (m - m0) / mc * mc; + np = n0 + (n - n0) / nc * nc; + mnpack(mp, m, n0, np); + mnpack(m0, m, np, n); + } + +#if defined(__AVX2__) && defined(__F16C__) +// Templated functions for gemm of dimensions 4xN + template + NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / 4; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * 4; + int64_t jj = n0 + job % xtiles * RN; + __m256 Cv[RN][4] = {}; + for (int64_t l = 0; l < k; ++l) { + uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d); + // Convert delta values for four blocks to float values + __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta)); + __m256i avec0 = load(A + lda * (ii + 0) + l); + __m256i avec1 = load(A + lda * (ii + 1) + l); + __m256i avec2 = load(A + lda * (ii + 2) + l); + __m256i avec3 = load(A + lda * (ii + 3) + l); + for (int64_t j = 0; j < RN; ++j) { + __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d)); + // Computation of product of delta values for four blocks and replicate it across 256 bit lane + __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db)); + dvec = _mm256_permute2f128_ps(dvec ,dvec, 0); + // Computation of dot product and multiplication with appropriate delta value products + Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0), + updot(_mm256_sign_epi8(avec0, avec0), + _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)), + Cv[j][0]); + Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85), + updot(_mm256_sign_epi8(avec1, avec1), + _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)), + Cv[j][1]); + Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170), + updot(_mm256_sign_epi8(avec2, avec2), + _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)), + Cv[j][2]); + Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255), + updot(_mm256_sign_epi8(avec3, avec3), + _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)), + Cv[j][3]); + } + } + + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < 4; ++i) + C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + } + } + + // Templated functions for gemm of dimensions Mx4 + template + NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / 4; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * 4; + __m256 Cv[4][RM] = {}; + for (int64_t l = 0; l < k; ++l) { + uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d); + // Convert delta values for four blocks to float values + __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta)); + __m256i bvec0 = load(B + ldb * (jj + 0) + l); + __m256i bvec1 = load(B + ldb * (jj + 1) + l); + __m256i bvec2 = load(B + ldb * (jj + 2) + l); + __m256i bvec3 = load(B + ldb * (jj + 3) + l); + for (int64_t i = 0; i < RM; ++i) { + __m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d))); + // Computation of product of delta values for four blocks and replicate it across 256 bit lane + __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db)); + dvec = _mm256_permute2f128_ps(dvec ,dvec, 0); + // Computation of dot product and multiplication with appropriate delta value products + Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0), + updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), + load(A + lda * (ii + i) + l)), + _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))), + Cv[0][i]); + Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85), + updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), + load(A + lda * (ii + i) + l)), + _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))), + Cv[1][i]); + Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170), + updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), + load(A + lda * (ii + i) + l)), + _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))), + Cv[2][i]); + Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255), + updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), + load(A + lda * (ii + i) + l)), + _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))), + Cv[3][i]); + } + } + for (int64_t j = 0; j < 4; ++j) + for (int64_t i = 0; i < RM; ++i) + C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + } + } +#endif + + template + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + __m256 Cv[RN][RM] = {}; + for (int64_t l = 0; l < k; ++l) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) { +#if defined(__AVX2__) + __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), + load(A + lda * (ii + i) + l)), + _mm256_sign_epi8(load(B + ldb * (jj + j) + l), + load(A + lda * (ii + i) + l))); +#else + __m128i ali0 = load0(A + lda * (ii + i) + l); + __m128i ali1 = load1(A + lda * (ii + i) + l); + __m128i blj0 = load0(B + ldb * (jj + j) + l); + __m128i blj1 = load1(B + ldb * (jj + j) + l); + + __m128i sepAA0 = _mm_sign_epi8(ali0, ali0); + __m128i sepAA1 = _mm_sign_epi8(ali1, ali1); + __m128i sepBA0 = _mm_sign_epi8(blj0, ali0); + __m128i sepBA1 = _mm_sign_epi8(blj1, ali1); + + // updot + const __m128i oneFill = _mm_set1_epi16(1); + __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0); + __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1); + __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0))); +#endif + Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * + unhalf(B[ldb * (jj + j) + l].d)), + udTmp, + Cv[j][i]); + } + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) + C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + } + } + + inline __m256i load(const block_q8_0 *b) { + return _mm256_loadu_si256((const __m256i *)b->qs); + } + + inline __m128i load0(const block_q8_0 *b) { + return _mm_loadu_si128((const __m128i *)b->qs); + } + + inline __m128i load1(const block_q8_0 *b) { + return _mm_loadu_si128(((const __m128i *)b->qs) + 1); + } + + inline __m256i load(const block_q4_0 *b) { + return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); + } + + inline __m128i load0(const block_q4_0 *b) { + const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8)); + } + + inline __m128i load1(const block_q4_0 *b) { + const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8)); + } + + inline __m256i load(const block_q5_0 *b) { + return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh)); + } + + inline __m128i load0(const block_q5_0* b) { + const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + uint32_t x32; + memcpy(&x32, b->qh, sizeof(uint32_t)); + __m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x); + __m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1), + _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe), + _mm_shuffle_epi8(_mm_set1_epi32(x32), + _mm_set_epi64x(0x0101010101010101, 0x0000000000000000)))); + bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0)); + return _mm_or_si128(qxl, bytesl); + } + + inline __m128i load1(const block_q5_0* b) { + const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + uint32_t x32; + memcpy(&x32, b->qh, sizeof(uint32_t)); + __m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)); + __m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1), + _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe), + _mm_shuffle_epi8(_mm_set1_epi32(x32), + _mm_set_epi64x(0x0303030303030303, 0x0202020202020202)))); + bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0)); + return _mm_or_si128(qxh, bytesh); + } + + inline __m256i load(const block_iq4_nl *b) { + return MM256_SET_M128I(load1(b), load0(b)); + } + + inline __m128i load0(const block_iq4_nl *b) { + const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), x)); + } + + inline __m128i load1(const block_iq4_nl *b) { + const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4))); + } + + inline __m256 updot(__m256i u, __m256i s) { + __m256i res; +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) + res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s); +#elif defined(__AVXVNNI__) + res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s); +#else + res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s)); +#endif + return _mm256_cvtepi32_ps(res); + } + + static inline __m256i denibble(const uint8_t *p) { + __m128i x = _mm_loadu_si128((const __m128i *)p); + return _mm256_and_si256(_mm256_set1_epi8(15), + _mm256_insertf128_si256(_mm256_castsi128_si256(x), + _mm_srli_epi16(x, 4), 1)); + } + + static inline __m256i bittobyte(const uint8_t *p) { + uint32_t x32; + memcpy(&x32, p, sizeof(uint32_t)); + __m256i bytes = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1), + _mm256_or_si256(_mm256_set1_epi64x(0x7fbfdfeff7fbfdfe), + _mm256_shuffle_epi8(_mm256_set1_epi32(x32), + _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, + 0x0101010101010101, 0x0000000000000000)))); + return _mm256_andnot_si256(bytes, _mm256_set1_epi8((char)0xF0)); + } + + const TA *const A; + const TB *const B; + TC *const C; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; + const int ith; + const int nth; +}; +#endif // __AVX__ + +//PPC Implementation +#if defined(__MMA__) + +#define SAVE_ACC(ACC, ii, jj) \ + __builtin_mma_disassemble_acc(vec_C, ACC); \ + for (int I = 0; I < 4; I++) { \ + for (int J = 0; J < 4; J++) { \ + *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \ + } \ + } \ + +template +class tinyBLAS_Q0_PPC { + public: + tinyBLAS_Q0_PPC(int64_t k, + const TA *A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc, + int ith, int nth) + : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + } + + void matmul(int64_t m, int64_t n) { + mnpack(0, m, 0, n); + } + + private: + + template + inline void save_res(int ii, int jj, int idx, vector float* fin_res) { + for (int I = 0; I < RM; I++) { + for (int J = 0; J < RN; J++) { + *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J); + } + } + } + + template + inline void compute(acc_t* ACC, int c_idx, int s_idx, std::array& comparray, vector float* vs, vector float* fin_res) { + vector signed int vec_C[4]; + vector float CA[4] = {0}; + vector float res[4] = {0}; + __builtin_mma_disassemble_acc(vec_C, ACC); + for (int i = 0; i < 4; i++) { + CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0)); + res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]); + fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]); + } + } + + template + void packNormal(const TA* a, int64_t lda, int rows, int cols, VA* vec, bool flip) { + int64_t i, j; + TA *aoffset = NULL; + VA *vecOffset = NULL; + TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; + TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; + __vector_pair C1, C2, C3, C4, C5, C6, C7, C8; + VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0}; + VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0}; + VB t1, t2, t3, t4, t5, t6, t7, t8; + vector unsigned char xor_vector; + uint8_t flip_vec = 0x80; + xor_vector = vec_splats(flip_vec); + vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; + vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; + vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; + vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; + + aoffset = const_cast(a); + vecOffset = vec; + j = (rows >> 3); + if (j > 0) { + do { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + aoffset4 = aoffset3 + lda; + aoffset5 = aoffset4 + lda; + aoffset6 = aoffset5 + lda; + aoffset7 = aoffset6 + lda; + aoffset8 = aoffset7 + lda; + aoffset += 8 * lda; + + i = (cols >> 3); + if (i > 0) { + do { + C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs); + C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs); + C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); + C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs); + C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs); + C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs); + C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs); + C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs); + + __builtin_vsx_disassemble_pair(c1, &C1); + __builtin_vsx_disassemble_pair(c2, &C2); + __builtin_vsx_disassemble_pair(c3, &C3); + __builtin_vsx_disassemble_pair(c4, &C4); + __builtin_vsx_disassemble_pair(c5, &C5); + __builtin_vsx_disassemble_pair(c6, &C6); + __builtin_vsx_disassemble_pair(c7, &C7); + __builtin_vsx_disassemble_pair(c8, &C8); + + t1 = vec_perm(c1[0], c2[0], swiz1); + t2 = vec_perm(c1[0], c2[0], swiz2); + t3 = vec_perm(c3[0], c4[0], swiz1); + t4 = vec_perm(c3[0], c4[0], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset+16); + vec_xst(t7, 0, vecOffset+32); + vec_xst(t8, 0, vecOffset+48); + + t1 = vec_perm(c1[1], c2[1], swiz1); + t2 = vec_perm(c1[1], c2[1], swiz2); + t3 = vec_perm(c3[1], c4[1], swiz1); + t4 = vec_perm(c3[1], c4[1], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset+64); + vec_xst(t6, 0, vecOffset+80); + vec_xst(t7, 0, vecOffset+96); + vec_xst(t8, 0, vecOffset+112); + + t1 = vec_perm(c5[0], c6[0], swiz1); + t2 = vec_perm(c5[0], c6[0], swiz2); + t3 = vec_perm(c7[0], c8[0], swiz1); + t4 = vec_perm(c7[0], c8[0], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset+128); + vec_xst(t6, 0, vecOffset+144); + vec_xst(t7, 0, vecOffset+160); + vec_xst(t8, 0, vecOffset+176); + + t1 = vec_perm(c5[1], c6[1], swiz1); + t2 = vec_perm(c5[1], c6[1], swiz2); + t3 = vec_perm(c7[1], c8[1], swiz1); + t4 = vec_perm(c7[1], c8[1], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset+192); + vec_xst(t6, 0, vecOffset+208); + vec_xst(t7, 0, vecOffset+224); + vec_xst(t8, 0, vecOffset+240); + + aoffset1 += lda; + aoffset2 += lda; + aoffset3 += lda; + aoffset4 += lda; + aoffset5 += lda; + aoffset6 += lda; + aoffset7 += lda; + aoffset8 += lda; + vecOffset += 256; + i--; + } while(i > 0); + } + j--; + } while(j > 0); + } + + if (rows & 4) { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + aoffset4 = aoffset3 + lda; + aoffset += 4 * lda; + + i = (cols >> 3); + if (i > 0) { + do { + C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs); + C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs); + C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); + C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs); + + __builtin_vsx_disassemble_pair(c1, &C1); + __builtin_vsx_disassemble_pair(c2, &C2); + __builtin_vsx_disassemble_pair(c3, &C3); + __builtin_vsx_disassemble_pair(c4, &C4); + + t1 = vec_perm(c1[0], c2[0], swiz1); + t2 = vec_perm(c1[0], c2[0], swiz2); + t3 = vec_perm(c3[0], c4[0], swiz1); + t4 = vec_perm(c3[0], c4[0], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset+16); + vec_xst(t7, 0, vecOffset+32); + vec_xst(t8, 0, vecOffset+48); + + t1 = vec_perm(c1[1], c2[1], swiz1); + t2 = vec_perm(c1[1], c2[1], swiz2); + t3 = vec_perm(c3[1], c4[1], swiz1); + t4 = vec_perm(c3[1], c4[1], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset+64); + vec_xst(t6, 0, vecOffset+80); + vec_xst(t7, 0, vecOffset+96); + vec_xst(t8, 0, vecOffset+112); + + aoffset1 += lda; + aoffset2 += lda; + aoffset3 += lda; + aoffset4 += lda; + vecOffset += 128; + i--; + } while(i > 0); + } + } + if (rows & 3) { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + i = (cols >> 3); + if (i > 0) { + do { + switch(rows) { + case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); + __builtin_vsx_disassemble_pair(c3, &C3); + case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs); + __builtin_vsx_disassemble_pair(c2, &C2); + case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs); + __builtin_vsx_disassemble_pair(c1, &C1); + break; + } + t1 = vec_perm(c1[0], c2[0], swiz1); + t2 = vec_perm(c1[0], c2[0], swiz2); + t3 = vec_perm(c3[0], c4[0], swiz1); + t4 = vec_perm(c3[0], c4[0], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset+16); + vec_xst(t7, 0, vecOffset+32); + vec_xst(t8, 0, vecOffset+48); + + t1 = vec_perm(c1[1], c2[1], swiz1); + t2 = vec_perm(c1[1], c2[1], swiz2); + t3 = vec_perm(c3[1], c4[1], swiz1); + t4 = vec_perm(c3[1], c4[1], swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset+64); + vec_xst(t6, 0, vecOffset+80); + vec_xst(t7, 0, vecOffset+96); + vec_xst(t8, 0, vecOffset+112); + + aoffset1 += lda; + aoffset2 += lda; + aoffset3 += lda; + vecOffset += 128; + i--; + } while(i > 0); + } + } + } + + void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t mc, nc, mp, np; + int m_rem = MIN(m - m0, 8); + int n_rem = MIN(n - n0, 8); + // TO-DO: KERNEL_16x8 and KERNEL_8x16 are having some performance + // issues. After resolving them, below code will be enabled. + /*if (m_rem >= 16 && n_rem >= 8) { + mc = 16; + nc = 8; + gemm<16,8>(m0, m, n0, n); + } else if(m_rem >= 8 && n_rem >= 16) { + mc = 8; + nc = 16; + gemm<8,16>(m0, m, n0, n); + }*/ + if (m_rem >= 8 && n_rem >= 8) { + mc = 8; + nc = 8; + gemm<8,8>(m0, m, n0, n); + } else if (m_rem >= 4 && n_rem >= 8) { + mc = 4; + nc = 8; + gemm<4,8>(m0, m, n0, n); + } else if (m_rem >= 8 && n_rem >= 4) { + mc = 8; + nc = 4; + gemm<8,4>(m0, m, n0, n); + } else if (m_rem >= 4 && n_rem >= 4) { + mc = 4; + nc = 4; + gemm_small<4, 4>(m0, m, n0, n); + } else if ((m_rem < 4) && (n_rem > 4)) { + nc = 4; + switch(m_rem) { + case 1: + mc = 1; + gemm_small<1, 4>(m0, m, n0, n); + break; + case 2: + mc = 2; + gemm_small<2, 4>(m0, m, n0, n); + break; + case 3: + mc = 3; + gemm_small<3, 4>(m0, m, n0, n); + break; + default: + return; + } + } else if ((m_rem > 4) && (n_rem < 4)) { + mc = 4; + switch(n_rem) { + case 1: + nc = 1; + gemm_small<4, 1>(m0, m, n0, n); + break; + case 2: + nc = 2; + gemm_small<4, 2>(m0, m, n0, n); + break; + case 3: + nc = 3; + gemm_small<4, 3>(m0, m, n0, n); + break; + default: + return; + } + } else { + switch((m_rem << 4) | n_rem) { + case 0x43: + mc = 4; + nc = 3; + gemm_small<4, 3>(m0, m, n0, n); + break; + case 0x42: + mc = 4; + nc = 2; + gemm_small<4, 2>(m0, m, n0, n); + break; + case 0x41: + mc = 4; + nc = 1; + gemm_small<4, 1>(m0, m, n0, n); + break; + case 0x34: + mc = 3; + nc = 4; + gemm_small<3, 4>(m0, m, n0, n); + break; + case 0x33: + mc = 3; + nc = 3; + gemm_small<3, 3>(m0, m, n0, n); + break; + case 0x32: + mc = 3; + nc = 2; + gemm_small<3, 2>(m0, m, n0, n); + break; + case 0x31: + mc = 3; + nc = 1; + gemm_small<3, 1>(m0, m, n0, n); + break; + case 0x24: + mc = 2; + nc = 4; + gemm_small<2, 4>(m0, m, n0, n); + break; + case 0x23: + mc = 2; + nc = 3; + gemm_small<2, 3>(m0, m, n0, n); + break; + case 0x22: + mc = 2; + nc = 2; + gemm_small<2, 2>(m0, m, n0, n); + break; + case 0x21: + mc = 2; + nc = 1; + gemm_small<2, 1>(m0, m, n0, n); + break; + case 0x14: + mc = 1; + nc = 4; + gemm_small<1, 4>(m0, m, n0, n); + break; + case 0x13: + mc = 1; + nc = 3; + gemm_small<1, 3>(m0, m, n0, n); + break; + case 0x12: + mc = 1; + nc = 2; + gemm_small<1, 2>(m0, m, n0, n); + break; + case 0x11: + mc = 1; + nc = 1; + gemm_small<1, 1>(m0, m, n0, n); + break; + default: + return; + } + } + mp = m0 + (m - m0) / mc * mc; + np = n0 + (n - n0) / nc * nc; + mnpack(mp, m, n0, np); + mnpack(m0, m, np, n); + } + + void KERNEL_4x8(int64_t ii, int64_t jj) { + vec_t vec_A[8], vec_B[16] = {0}; + acc_t acc_0, acc_1; + std::array comparray; + vector float fin_res[8] = {0}; + vector float vs[8] = {0}; + for (int l = 0; l < k; l++) { + __builtin_mma_xxsetaccz(&acc_0); + __builtin_mma_xxsetaccz(&acc_1); + packNormal((A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false); + packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); + for(int x = 0; x < 8; x++) { + __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]); + } + for (int I = 0; I<4; I++) { + for (int J = 0; J<4; J++) { + *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); + *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d)); + } + } + auto aoffset = A+(ii*lda)+l; + for (int i = 0; i < 4; i++) { + comparray[i] = 0; + int ca = 0; + const int8_t *at = aoffset->qs; + for (int j = 0; j < 32; j++) + ca += (int)*at++; + comparray[i] = ca; + aoffset += lda; + } + compute<4>(&acc_0, 0, 0, comparray, vs, fin_res); + compute<4>(&acc_1, 0, 4, comparray, vs, fin_res); + } + save_res<4, 4>(ii, jj, 0, fin_res); + save_res<4, 4>(ii, jj+4, 4, fin_res); + } + + void KERNEL_8x4(int64_t ii, int64_t jj) { + vec_t vec_A[16], vec_B[8] = {0}; + acc_t acc_0, acc_1; + std::array comparray; + vector float fin_res[8] = {0}; + vector float vs[8] = {0}; + for (int l = 0; l < k; l++) { + __builtin_mma_xxsetaccz(&acc_0); + __builtin_mma_xxsetaccz(&acc_1); + packNormal((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + packNormal((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true); + for(int x = 0; x < 8; x++) { + __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]); + } + for (int I = 0; I<8; I++) { + for (int J = 0; J<4; J++) { + *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); + } + } + auto aoffset = A+(ii*lda)+l; + for (int i = 0; i < 8; i++) { + comparray[i] = 0; + int ca = 0; + const int8_t *at = aoffset->qs; + for (int j = 0; j < 32; j++) + ca += (int)*at++; + comparray[i] = ca; + aoffset += lda; + } + compute<8>(&acc_0, 0, 0, comparray, vs, fin_res); + compute<8>(&acc_1, 4, 4, comparray, vs, fin_res); + } + save_res<4, 4>(ii, jj, 0, fin_res); + save_res<4, 4>(ii+4, jj, 4, fin_res); + } + + void KERNEL_8x8(int64_t ii, int64_t jj) { + vec_t vec_A[16], vec_B[16] = {0}; + acc_t acc_0, acc_1, acc_2, acc_3; + std::array comparray; + vector float fin_res[16] = {0}; + vector float vs[16] = {0}; + for (int l = 0; l < k; l++) { + __builtin_mma_xxsetaccz(&acc_0); + __builtin_mma_xxsetaccz(&acc_1); + __builtin_mma_xxsetaccz(&acc_2); + __builtin_mma_xxsetaccz(&acc_3); + packNormal((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); + for(int x = 0; x < 8; x++) { + __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]); + __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]); + __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]); + } + for (int I = 0; I<8; I++) { + for (int J = 0; J<4; J++) { + *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); + *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d)); + } + } + auto aoffset = A+(ii*lda)+l; + for (int i = 0; i < 8; i++) { + comparray[i] = 0; + int ca = 0; + const int8_t *at = aoffset->qs; + for (int j = 0; j < 32; j++) + ca += (int)*at++; + comparray[i] = ca; + aoffset += lda; + } + compute<8>(&acc_0, 0, 0, comparray, vs, fin_res); + compute<8>(&acc_1, 4, 4, comparray, vs, fin_res); + compute<8>(&acc_2, 0, 8, comparray, vs, fin_res); + compute<8>(&acc_3, 4, 12, comparray, vs, fin_res); + } + save_res<4, 4>(ii, jj, 0, fin_res); + save_res<4, 4>(ii+4, jj, 4, fin_res); + save_res<4, 4>(ii, jj+4, 8, fin_res); + save_res<4, 4>(ii+4, jj+4, 12, fin_res); + } + + template + void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + vec_t vec_A[8], vec_B[8] = {0}; + vector signed int vec_C[4]; + acc_t acc_0; + + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + std::array comparray; + vector float res[4] = {0}; + vector float fin_res[4] = {0}; + vector float vs[4] = {0}; + vector float CA[4] = {0}; + __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value + __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value + for (int l = 0; l < k; l++) { + __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead + __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead + __builtin_mma_xxsetaccz(&acc_0); + packNormal((A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false); + packNormal((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true); + for(int x = 0; x < 8; x+=4) { + __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]); + __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]); + __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]); + } + for (int I = 0; Id) * unhalf((B+((jj+J)*ldb)+l)->d)); + } + } + __builtin_mma_disassemble_acc(vec_C, &acc_0); + auto aoffset = A+(ii*lda)+l; + for (int i = 0; i < RM; i++) { + comparray[i] = 0; + int ca = 0; + const int8_t *at = aoffset->qs; + for (int j = 0; j < 32; j++) + ca += (int)*at++; + comparray[i] = ca; + aoffset += lda; + } + + for (int i = 0; i < RM; i++) { + CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0)); + res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]); + fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]); + } + } + save_res(ii, jj, 0, fin_res); + } + } + + template + inline void kernel(int64_t ii, int64_t jj) { + if constexpr(RM == 4 && RN == 8) { + KERNEL_4x8(ii,jj); + } else if constexpr(RM == 8 && RN == 4) { + KERNEL_8x4(ii,jj); + } else if constexpr(RM == 8 && RN == 8) { + KERNEL_8x8(ii,jj); + } else { + static_assert(false, "RN/RM values not supported"); + } + } + + template + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + kernel(ii, jj); + } + } + + const TA *const A; + const TB *const B; + TC *C; + TA *At; + TB *Bt; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; + const int ith; + const int nth; +}; + +template +class tinyBLAS_PPC { + public: + tinyBLAS_PPC(int64_t k, + const TA *A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc, + int ith, int nth) + : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + } + + void matmul(int64_t m, int64_t n) { + mnpack(0, m, 0, n); + } + + private: + + void (tinyBLAS_PPC::*kernel)(int64_t, int64_t); + + template + void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) { + int64_t i, j; + TA *aoffset = NULL, *boffset = NULL; + TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; + TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; + __vector_pair C1, C2, C3, C4, C5, C6, C7, C8; + VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; + VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; + VA t1, t2, t3, t4, t5, t6, t7, t8; + aoffset = const_cast(a); + boffset = vec; + j = (rows >> 3); + if (j > 0) { + do { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + aoffset4 = aoffset3 + lda; + aoffset5 = aoffset4 + lda; + aoffset6 = aoffset5 + lda; + aoffset7 = aoffset6 + lda; + aoffset8 = aoffset7 + lda; + aoffset += 8 * lda; + i = (cols >> 3); + if (i > 0) { + do { + C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1); + C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2); + C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3); + C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4); + C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5); + C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6); + C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7); + C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8); + __builtin_vsx_disassemble_pair(c1, &C1); + __builtin_vsx_disassemble_pair(c2, &C2); + __builtin_vsx_disassemble_pair(c3, &C3); + __builtin_vsx_disassemble_pair(c4, &C4); + __builtin_vsx_disassemble_pair(c5, &C5); + __builtin_vsx_disassemble_pair(c6, &C6); + __builtin_vsx_disassemble_pair(c7, &C7); + __builtin_vsx_disassemble_pair(c8, &C8); + + t1 = vec_mergeh(c1[0], c2[0]); + t2 = vec_mergeh(c3[0], c4[0]); + t3 = vec_mergeh(c5[0], c6[0]); + t4 = vec_mergeh(c7[0], c8[0]); + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + vec_xst(t5, 0, boffset); + vec_xst(t6, 0, boffset+4); + vec_xst(t7, 0, boffset+8); + vec_xst(t8, 0, boffset+12); + + t1 = vec_mergel(c1[0], c2[0]); + t2 = vec_mergel(c3[0], c4[0]); + t3 = vec_mergel(c5[0], c6[0]); + t4 = vec_mergel(c7[0], c8[0]); + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + vec_xst(t5, 0, boffset+16); + vec_xst(t6, 0, boffset+20); + vec_xst(t7, 0, boffset+24); + vec_xst(t8, 0, boffset+28); + + t1 = vec_mergeh(c1[1], c2[1]); + t2 = vec_mergeh(c3[1], c4[1]); + t3 = vec_mergeh(c5[1], c6[1]); + t4 = vec_mergeh(c7[1], c8[1]); + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + vec_xst(t5, 0, boffset+32); + vec_xst(t6, 0, boffset+36); + vec_xst(t7, 0, boffset+40); + vec_xst(t8, 0, boffset+44); + + t1 = vec_mergel(c1[1], c2[1]); + t2 = vec_mergel(c3[1], c4[1]); + t3 = vec_mergel(c5[1], c6[1]); + t4 = vec_mergel(c7[1], c8[1]); + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + vec_xst(t5, 0, boffset+48); + vec_xst(t6, 0, boffset+52); + vec_xst(t7, 0, boffset+56); + vec_xst(t8, 0, boffset+60); + + aoffset1 += 8*lda; + aoffset2 += 8*lda; + aoffset3 += 8*lda; + aoffset4 += 8*lda; + boffset += 64; + i--; + } while(i > 0); + } + if (cols & 4) { + c1[0] = vec_xl(0, aoffset1); + c2[0] = vec_xl(0, aoffset2); + c3[0] = vec_xl(0, aoffset3); + c4[0] = vec_xl(0, aoffset4); + c5[0] = vec_xl(0, aoffset5); + c6[0] = vec_xl(0, aoffset6); + c7[0] = vec_xl(0, aoffset7); + c8[0] = vec_xl(0, aoffset8); + + t1 = vec_mergeh(c1[0], c2[0]); + t2 = vec_mergeh(c3[0], c4[0]); + t3 = vec_mergeh(c5[0], c6[0]); + t4 = vec_mergeh(c7[0], c8[0]); + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + vec_xst(t5, 0, boffset); + vec_xst(t6, 0, boffset+4); + vec_xst(t7, 0, boffset+8); + vec_xst(t8, 0, boffset+12); + + t1 = vec_mergel(c1[0], c2[0]); + t2 = vec_mergel(c3[0], c4[0]); + t3 = vec_mergel(c5[0], c6[0]); + t4 = vec_mergel(c7[0], c8[0]); + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + vec_xst(t5, 0, boffset+16); + vec_xst(t6, 0, boffset+20); + vec_xst(t7, 0, boffset+24); + vec_xst(t8, 0, boffset+28); + } + j--; + } while(j > 0); + } + + if (rows & 4) { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + aoffset4 = aoffset3 + lda; + aoffset += 4 * lda; + i = (cols >> 3); + if (i > 0) { + do { + C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1); + C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2); + C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3); + C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4); + __builtin_vsx_disassemble_pair(c1, &C1); + __builtin_vsx_disassemble_pair(c2, &C2); + __builtin_vsx_disassemble_pair(c3, &C3); + __builtin_vsx_disassemble_pair(c4, &C4); + + t1 = vec_mergeh(c1[0], c2[0]); + t2 = vec_mergeh(c3[0], c4[0]); + t3 = vec_mergel(c1[0], c2[0]); + t4 = vec_mergel(c3[0], c4[0]); + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t1, t2, 3); + t7 = vec_xxpermdi(t3, t4, 0); + t8 = vec_xxpermdi(t3, t4, 3); + vec_xst(t5, 0, boffset); + vec_xst(t6, 0, boffset+4); + vec_xst(t7, 0, boffset+8); + vec_xst(t8, 0, boffset+12); + + t1 = vec_mergeh(c1[1], c2[1]); + t2 = vec_mergeh(c3[1], c4[1]); + t3 = vec_mergel(c1[1], c2[1]); + t4 = vec_mergel(c3[1], c4[1]); + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t1, t2, 3); + t7 = vec_xxpermdi(t3, t4, 0); + t8 = vec_xxpermdi(t3, t4, 3); + vec_xst(t5, 0, boffset+16); + vec_xst(t6, 0, boffset+20); + vec_xst(t7, 0, boffset+24); + vec_xst(t8, 0, boffset+28); + + aoffset1 += 8*lda; + aoffset2 += 8*lda; + aoffset3 += 8*lda; + aoffset4 += 8*lda; + boffset += 32; + i--; + } while(i > 0); + } + + if (cols & 4) { + c1[0] = vec_xl(0, aoffset1); + c2[0] = vec_xl(0, aoffset2); + c3[0] = vec_xl(0, aoffset3); + c4[0] = vec_xl(0, aoffset4); + + t1 = vec_mergeh(c1[0], c2[0]); + t2 = vec_mergeh(c3[0], c4[0]); + t3 = vec_xxpermdi(t1, t2, 0); + t4 = vec_xxpermdi(t1, t2, 3); + vec_xst(t3, 0, boffset); + vec_xst(t4, 0, boffset+4); + + t1 = vec_mergel(c1[0], c2[0]); + t2 = vec_mergel(c3[0], c4[0]); + t3 = vec_xxpermdi(t1, t2, 0); + t4 = vec_xxpermdi(t1, t2, 3); + vec_xst(t3, 0, boffset+8); + vec_xst(t4, 0, boffset+12); + } + } + if (rows & 3) { + aoffset1 = aoffset; + aoffset2 = aoffset1 + lda; + aoffset3 = aoffset2 + lda; + if (cols & 4) { + c1[0] = vec_xl(0, aoffset1); + c2[0] = vec_xl(0, aoffset2); + c3[0] = vec_xl(0, aoffset3); + + t1 = vec_mergeh(c1[0], c2[0]); + t2 = vec_mergeh(c3[0], c4[0]); + t3 = vec_xxpermdi(t1, t2, 0); + t4 = vec_xxpermdi(t1, t2, 3); + vec_xst(t3, 0, boffset); + vec_xst(t4, 0, boffset+4); + + t1 = vec_mergel(c1[0], c2[0]); + t2 = vec_mergel(c3[0], c4[0]); + t3 = vec_xxpermdi(t1, t2, 0); + t4 = vec_xxpermdi(t1, t2, 3); + vec_xst(t3, 0, boffset+8); + vec_xst(t4, 0, boffset+12); + } + } + } + void KERNEL_4x4(int64_t ii, int64_t jj) { + vec_t vec_A[4], vec_B[4], vec_C[4]; + acc_t acc_0; + __builtin_mma_xxsetaccz(&acc_0); + for (int l = 0; l < k; l+=4) { + packTranspose(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]); + } + SAVE_ACC(&acc_0, ii, jj); + } + + void KERNEL_4x8(int64_t ii, int64_t jj) { + vec_t vec_A[4], vec_B[8], vec_C[4]; + acc_t acc_0, acc_1; + __builtin_mma_xxsetaccz(&acc_0); + __builtin_mma_xxsetaccz(&acc_1); + for (int64_t l = 0; l < k; l+=4) { + packTranspose(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]); + __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]); + __builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]); + __builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]); + __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]); + } + SAVE_ACC(&acc_0, ii, jj); + SAVE_ACC(&acc_1, ii, jj+4); + } + + void KERNEL_8x4(int64_t ii, int64_t jj) { + vec_t vec_A[8], vec_B[4], vec_C[4]; + acc_t acc_0, acc_1; + __builtin_mma_xxsetaccz(&acc_0); + __builtin_mma_xxsetaccz(&acc_1); + for (int64_t l = 0; l < k; l+=4) { + packTranspose(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B); + __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]); + __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]); + __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]); + __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]); + __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]); + __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]); + __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]); + __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]); + } + SAVE_ACC(&acc_0, ii, jj); + SAVE_ACC(&acc_1, ii+4, jj); + } + + void KERNEL_8x8(int64_t ii, int64_t jj) { + vec_t vec_A[16], vec_B[16], vec_C[4]; + acc_t acc_0, acc_1, acc_2, acc_3; + __builtin_mma_xxsetaccz(&acc_0); + __builtin_mma_xxsetaccz(&acc_1); + __builtin_mma_xxsetaccz(&acc_2); + __builtin_mma_xxsetaccz(&acc_3); + for (int l = 0; l < k; l+=8) { + packTranspose(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B); + for(int x = 0; x < 16; x+=2) { + __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]); + __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]); + __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]); + __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]); + } + } + SAVE_ACC(&acc_0, ii, jj); + SAVE_ACC(&acc_1, ii, jj+4); + SAVE_ACC(&acc_2, ii+4, jj); + SAVE_ACC(&acc_3, ii+4, jj+4); + } + + void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t mc, nc, mp, np; + int m_rem = MIN(m - m0, 16); + int n_rem = MIN(n - n0, 16); + if (m_rem >= 16 && n_rem >= 8) { + mc = 8; + nc = 8; + gemm<8,8>(m0, m, n0, n); + } else if(m_rem >= 8 && n_rem >= 16) { + mc = 8; + nc = 8; + gemm<8,8>(m0, m, n0, n); + } else if (m_rem >= 8 && n_rem >= 8) { + mc = 8; + nc = 8; + gemm<8,8>(m0, m, n0, n); + } else if (m_rem >= 4 && n_rem >= 8) { + mc = 4; + nc = 8; + gemm<4,8>(m0, m, n0, n); + } else if (m_rem >= 8 && n_rem >= 4) { + mc = 8; + nc = 4; + gemm<8,4>(m0, m, n0, n); + } else if (m_rem >= 4 && n_rem >= 4) { + mc = 4; + nc = 4; + gemm<4,4>(m0, m, n0, n); + } else if ((m_rem < 4) && (n_rem > 4)) { + nc = 4; + switch(m_rem) { + case 1: + mc = 1; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 2: + mc = 2; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 3: + mc = 3; + gemm_small(m0, m, n0, n, mc, nc); + break; + default: + return; + } + } else if ((m_rem > 4) && (n_rem < 4)) { + mc = 4; + switch(n_rem) { + case 1: + nc = 1; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 2: + nc = 2; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 3: + nc = 3; + gemm_small(m0, m, n0, n, mc, nc); + break; + default: + return; + } + } else { + switch((m_rem << 4) | n_rem) { + case 0x43: + mc = 4; + nc = 3; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x42: + mc = 4; + nc = 2; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x41: + mc = 4; + nc = 1; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x34: + mc = 3; + nc = 4; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x33: + mc = 3; + nc = 3; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x32: + mc = 3; + nc = 2; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x31: + mc = 3; + nc = 1; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x24: + mc = 2; + nc = 4; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x23: + mc = 2; + nc = 3; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x22: + mc = 2; + nc = 2; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x21: + mc = 2; + nc = 1; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x14: + mc = 1; + nc = 4; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x13: + mc = 1; + nc = 3; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x12: + mc = 1; + nc = 2; + gemm_small(m0, m, n0, n, mc, nc); + break; + case 0x11: + mc = 1; + nc = 1; + gemm_small(m0, m, n0, n, mc, nc); + break; + default: + return; + } + } + mp = m0 + (m - m0) / mc * mc; + np = n0 + (n - n0) / nc * nc; + mnpack(mp, m, n0, np); + mnpack(m0, m, np, n); + } + + void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + vec_t vec_C[4]; + acc_t acc_0; + __builtin_mma_xxsetaccz(&acc_0); + vec_t vec_A[4], vec_B[4]; + for (int l=0; l= 4 && RM == 1) { + TA* a = const_cast(A+(ii)*lda+l); + packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B); + vec_A[0] = (vec_t)vec_xl(0,a); + vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1)); + vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2)); + vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3)); + } else { + packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B); + } + __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]); + __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]); + } + __builtin_mma_disassemble_acc(vec_C, &acc_0); + for (int I = 0; I < RM; I++) { + for (int J = 0; J < RN; J++) { + *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J); + } + } + } + } + + template + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (RM == 4 && RN == 4) { + kernel = &tinyBLAS_PPC::KERNEL_4x4; + } else if (RM == 4 && RN == 8) { + kernel = &tinyBLAS_PPC::KERNEL_4x8; + } else if (RM == 8 && RN == 4) { + kernel = &tinyBLAS_PPC::KERNEL_8x4; + } else if (RM == 8 && RN == 8) { + kernel = &tinyBLAS_PPC::KERNEL_8x8; + } + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + (this->*kernel)(ii, jj); + } + } + + const TA *const A; + const TB *const B; + TC *C; + TA *At; + TB *Bt; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; + const int ith; + const int nth; +}; +#endif +} // namespace + +/** + * Performs optimized matrix multiplication on CPU. + * + * This subroutine may compute C = Aᵀ * B with column major ordering. + * Despite its name, this isn't a generalized implementation. Work is + * only performed when a handwritten kernel is written and available. + * Otherwise the caller should fall back to a general matmul routine. + * + * For example, for single-threaded single-precision GEMM you can say + * + * llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, + * 0, 1, + * GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32); + * + * @param m is rows in `A` and `C` + * @param n is cols in `B` and `C` + * @param k is cols in `A` and rows in `B` + * @param A is first input matrix (always transposed) + * @param lda is row stride of `A` + * @param B is second input matrix (never transposed) + * @param ldb is row stride of `B` + * @param C is input/output array of output matrices + * @param ldc is row stride of `C` + * @param ith is thread id (must be less than `nth`) + * @param nth is number of threads (must be greater than zero) + * @param Atype is GGML data type of `A` + * @param Btype is GGML data type of `B` + * @param Ctype is GGML data type of `C` + * @return true if this function was able to service the matmul request + */ +bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64_t n, int64_t k, + const void *A, int64_t lda, const void *B, int64_t ldb, void *C, + int64_t ldc, int Atype, int Btype, int Ctype) { + + assert(m >= 0); + assert(n >= 0); + assert(k >= 0); + assert(lda >= k); + assert(ldb >= k); + assert(ldc >= m); + assert(params->nth > 0); + assert(params->ith < params->nth); + + // only enable sgemm for prompt processing + if (n < 2) + return false; + + if (Ctype != GGML_TYPE_F32) + return false; + + switch (Atype) { + + case GGML_TYPE_F32: { + if (Btype != GGML_TYPE_F32) + return false; +#if defined(__AVX512F__) + tinyBLAS<16, __m512, __m512, float, float, float> tb{ params, + k, (const float *)A, lda, + (const float *)B, ldb, + (float *)C, ldc}; + return tb.matmul(m, n); +#elif defined(__AVX__) || defined(__AVX2__) + tinyBLAS<8, __m256, __m256, float, float, float> tb{ params, + k, (const float *)A, lda, + (const float *)B, ldb, + (float *)C, ldc}; + return tb.matmul(m, n); +#elif defined(__ARM_NEON) + if (n < 4) + return false; + tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params, + k, (const float *)A, lda, + (const float *)B, ldb, + (float *)C, ldc}; + return tb.matmul(m, n); +#elif defined(__MMA__) + if (k % 8) + return false; + tinyBLAS_PPC tb{ + k, (const float *)A, lda, + (const float *)B, ldb, + (float *)C, ldc, + params->ith, params->nth}; + tb.matmul(m, n); + return true; +#else + return false; +#endif + } + + case GGML_TYPE_BF16: { +#if defined(__AVX512BF16__) + if (Btype == GGML_TYPE_BF16) { + tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k, + (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + return tb.matmul(m, n); + } +#elif defined(__AVX512F__) + if (Btype == GGML_TYPE_BF16) { + tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k, + (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + return tb.matmul(m, n); + } +#elif defined(__AVX2__) + if (Btype == GGML_TYPE_BF16) { + tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k, + (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + return tb.matmul(m, n); + } +#endif + return false; + } + case GGML_TYPE_F16: { +#if defined(__AVX512F__) + if (Btype == GGML_TYPE_F16) { + tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k, + (const ggml_fp16_t *)A, lda, + (const ggml_fp16_t *)B, ldb, + (float *)C, ldc}; + return tb.matmul(m, n); + } +#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__) + if (Btype == GGML_TYPE_F16) { + tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k, + (const ggml_fp16_t *)A, lda, + (const ggml_fp16_t *)B, ldb, + (float *)C, ldc}; + return tb.matmul(m, n); + } +#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) + if (n < 8) + return false; + if (Btype == GGML_TYPE_F16) { + tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params, + k, (const ggml_fp16_t *)A, lda, + (const ggml_fp16_t *)B, ldb, + (float *)C, ldc}; + return tb.matmul(m, n); + } +#elif defined(__ARM_NEON) && !defined(_MSC_VER) + if (Btype == GGML_TYPE_F32) { + tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ params, + k, (const ggml_fp16_t *)A, lda, + (const float *)B, ldb, + (float *)C, ldc}; + return tb.matmul(m, n); + } +#endif + return false; + } + + case GGML_TYPE_Q8_0: { + if (Btype != GGML_TYPE_Q8_0) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) + tinyBLAS_Q0_AVX tb{ + k, (const block_q8_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + params->ith, params->nth}; + tb.matmul(m, n); + return true; +#elif defined(__ARM_FEATURE_DOTPROD) + tinyBLAS_Q0_ARM tb{ + k, (const block_q8_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + params->ith, params->nth}; + tb.matmul(m, n); + return true; + +#elif defined(__MMA__) + if (n < 8 && n != 4) + return false; + if (m < 8 && m != 4) + return false; + tinyBLAS_Q0_PPC tb{ + k, (const block_q8_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + params->ith, params->nth}; + tb.matmul(m, n); + return true; + +#else + return false; +#endif + } + + case GGML_TYPE_Q4_0: { + if (Btype != GGML_TYPE_Q8_0) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) + tinyBLAS_Q0_AVX tb{ + k, (const block_q4_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + params->ith, params->nth}; + tb.matmul(m, n); + return true; +#elif defined(__ARM_FEATURE_DOTPROD) + tinyBLAS_Q0_ARM tb{ + k, (const block_q4_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + params->ith, params->nth}; + tb.matmul(m, n); + return true; +#else + return false; +#endif + } + + case GGML_TYPE_Q5_0: { + if (Btype != GGML_TYPE_Q8_0) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) + tinyBLAS_Q0_AVX tb{ + k, (const block_q5_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + params->ith, params->nth}; + tb.matmul(m, n); + return true; +#else + return false; +#endif + } + + case GGML_TYPE_IQ4_NL: { + if (Btype != GGML_TYPE_Q8_0) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) + tinyBLAS_Q0_AVX tb{ + k, (const block_iq4_nl *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + params->ith, params->nth}; + tb.matmul(m, n); + return true; +#else + return false; +#endif + } + + default: + return false; + } + + (void)params; + (void)m; + (void)n; + (void)k; + (void)A; + (void)lda; + (void)B; + (void)ldb; + (void)C; + (void)ldc; + (void)Atype; + (void)Btype; + (void)Ctype; +} diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.h b/ggml/src/ggml-cpu/llamafile/sgemm.h new file mode 100644 index 000000000..3d2909515 --- /dev/null +++ b/ggml/src/ggml-cpu/llamafile/sgemm.h @@ -0,0 +1,14 @@ +#pragma once +#include +#include +#ifdef __cplusplus +extern "C" { +#endif + +bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t, int64_t, int64_t, + const void *, int64_t, const void *, int64_t, void *, int64_t, + int, int, int); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt new file mode 100644 index 000000000..14761650f --- /dev/null +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -0,0 +1,152 @@ +cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES + +find_package(CUDAToolkit) + +if (CUDAToolkit_FOUND) + message(STATUS "CUDA Toolkit found") + + if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + # native == GPUs available at build time + # 52 == Maxwell, lowest CUDA 12 standard + # 60 == P100, FP16 CUDA intrinsics + # 61 == Pascal, __dp4a instruction (per-byte integer dot product) + # 70 == V100, FP16 tensor cores + # 75 == Turing, int8 tensor cores + if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24") + set(CMAKE_CUDA_ARCHITECTURES "native") + elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) + set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75") + else() + set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75") + endif() + endif() + message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") + + enable_language(CUDA) + + file(GLOB GGML_HEADERS_CUDA "*.cuh") + list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h") + + file(GLOB GGML_SOURCES_CUDA "*.cu") + file(GLOB SRCS "template-instances/fattn-wmma*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "template-instances/mmq*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + + if (GGML_CUDA_FA_ALL_QUANTS) + file(GLOB SRCS "template-instances/fattn-vec*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) + else() + file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + endif() + + ggml_add_backend_library(ggml-cuda + ${GGML_HEADERS_CUDA} + ${GGML_SOURCES_CUDA} + ) + + add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE}) + + if (GGML_CUDA_GRAPHS) + add_compile_definitions(GGML_CUDA_USE_GRAPHS) + endif() + + if (GGML_CUDA_FORCE_MMQ) + add_compile_definitions(GGML_CUDA_FORCE_MMQ) + endif() + + if (GGML_CUDA_FORCE_CUBLAS) + add_compile_definitions(GGML_CUDA_FORCE_CUBLAS) + endif() + + if (GGML_CUDA_NO_VMM) + add_compile_definitions(GGML_CUDA_NO_VMM) + endif() + + if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) + add_compile_definitions(GGML_CUDA_F16) + endif() + + if (GGML_CUDA_NO_PEER_COPY) + add_compile_definitions(GGML_CUDA_NO_PEER_COPY) + endif() + + if (GGML_STATIC) + if (WIN32) + # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library + target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas CUDA::cublasLt) + else () + target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) + endif() + else() + target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt) + endif() + + if (GGML_CUDA_NO_VMM) + # No VMM requested, no need to link directly with the cuda driver lib (libcuda.so) + else() + target_link_libraries(ggml-cuda PRIVATE CUDA::cuda_driver) + endif() + + set(CUDA_CXX_FLAGS "") + + set(CUDA_FLAGS -use_fast_math) + + if (GGML_FATAL_WARNINGS) + list(APPEND CUDA_FLAGS -Werror all-warnings) + endif() + + if (GGML_ALL_WARNINGS AND NOT MSVC) + set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c) + if (NOT CMAKE_CUDA_HOST_COMPILER STREQUAL "") + list(APPEND NVCC_CMD -ccbin ${CMAKE_CUDA_HOST_COMPILER}) + endif() + + execute_process( + COMMAND ${NVCC_CMD} -Xcompiler --version + OUTPUT_VARIABLE CUDA_CCFULLVER + ERROR_QUIET + ) + + if (NOT CUDA_CCFULLVER MATCHES clang) + set(CUDA_CCID "GNU") + execute_process( + COMMAND ${NVCC_CMD} -Xcompiler "-dumpfullversion -dumpversion" + OUTPUT_VARIABLE CUDA_CCVER + ERROR_QUIET + ) + else() + if (CUDA_CCFULLVER MATCHES Apple) + set(CUDA_CCID "AppleClang") + else() + set(CUDA_CCID "Clang") + endif() + string(REGEX REPLACE "^.* version ([0-9.]*).*$" "\\1" CUDA_CCVER ${CUDA_CCFULLVER}) + endif() + + message("-- CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}") + + ggml_get_flags(${CUDA_CCID} ${CUDA_CCVER}) + list(APPEND CUDA_CXX_FLAGS ${CXX_FLAGS} ${GF_CXX_FLAGS}) # This is passed to -Xcompiler later + endif() + + if (NOT MSVC) + list(APPEND CUDA_CXX_FLAGS -Wno-pedantic) + endif() + + list(JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument + + if (NOT CUDA_CXX_FLAGS_JOINED STREQUAL "") + list(APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED}) + endif() + + target_compile_options(ggml-cuda PRIVATE "$<$:${CUDA_FLAGS}>") +else() + message(FATAL_ERROR "CUDA Toolkit not found") +endif() diff --git a/ggml/src/ggml-cuda/argmax.cu b/ggml/src/ggml-cuda/argmax.cu new file mode 100644 index 000000000..5340eedc0 --- /dev/null +++ b/ggml/src/ggml-cuda/argmax.cu @@ -0,0 +1,91 @@ +#include +#include + +#include "argmax.cuh" +#include "common.cuh" +#include "sum.cuh" + +static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) { + const int64_t row = blockIdx.x; + + float maxval = -FLT_MAX; + int argmax = -1; + const float * rowx = x + row * ncols; + + for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) { + const float val = rowx[col]; + if (val > maxval) { + maxval = val; + argmax = col; + } + } + +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); + const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); + if (val > maxval) { + maxval = val; + argmax = col; + } + } + + const int n_warps = blockDim.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + if (n_warps > 1) { + constexpr int max_warps = 1024 / WARP_SIZE; + __shared__ float shared_maxval[max_warps]; + __shared__ int shared_argmax[max_warps]; + if (lane_id == 0) { + shared_maxval[warp_id] = maxval; + shared_argmax[warp_id] = argmax; + } + + __syncthreads(); + + if (warp_id == 0) { + if (lane_id < n_warps) { + maxval = shared_maxval[lane_id]; + argmax = shared_argmax[lane_id]; + } +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); + const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); + if (val > maxval) { + maxval = val; + argmax = col; + } + } + } + } + + if (warp_id == 0 && lane_id == 0) { + dst[row] = argmax; + } +} + +void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + const float * src0_d = (const float *) src0->data; + int32_t * dst_d = (int32_t *) dst->data; + + cudaStream_t stream = ctx.stream(); + + const int64_t num_blocks = nrows; + const int64_t num_threads = std::min(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE); + const dim3 blocks_dim(num_threads, 1, 1); + const dim3 blocks_num(num_blocks, 1, 1); + + argmax_f32<<>>(src0_d, dst_d, ne00); +} diff --git a/ggml/src/ggml-cuda/argmax.cuh b/ggml/src/ggml-cuda/argmax.cuh new file mode 100644 index 000000000..5b7223adc --- /dev/null +++ b/ggml/src/ggml-cuda/argmax.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index e1390a041..ce4b9cfb5 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -1,4 +1,5 @@ #include "binbcast.cuh" +#include static __device__ __forceinline__ float op_repeat(const float a, const float b) { return b; @@ -90,6 +91,35 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); } +template +static __global__ void k_repeat_back( + const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const size_t s00, const size_t s01, const size_t s02, const size_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) { + + const int64_t tid0 = int64_t(blockIdx.x)*blockDim.x + threadIdx.x; + const int64_t tid1 = int64_t(blockIdx.y)*blockDim.y + threadIdx.y; + const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z; + const int64_t tid2 = tid23 % ne2; + const int64_t tid3 = tid23 / ne2; + + if (tid0 >= ne0) { + return; + } + + T sum = 0; + for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) { + for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) { + for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) { + for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) { + sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00]; + } + } + } + } + dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum; +} + template struct bin_bcast_cuda { template @@ -247,6 +277,18 @@ struct bin_bcast_cuda { } }; +template +static void repeat_back_cuda( + const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const size_t s00, const size_t s01, const size_t s02, const size_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { + + const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3); + k_repeat_back<<>> + (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3); +} + template static void ggml_cuda_op_bin_bcast( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, @@ -286,3 +328,34 @@ void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_bin_bcast>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream()); } + +void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0->type == dst->type); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_can_repeat(dst, src0)); + + cudaStream_t stream = ctx.stream(); + + GGML_TENSOR_UNARY_OP_LOCALS; + + GGML_ASSERT(ne2*ne3 <= (1 << 15)); + + const size_t ts = ggml_type_size(src0->type); + const size_t s00 = nb00 / ts; + const size_t s01 = nb01 / ts; + const size_t s02 = nb02 / ts; + const size_t s03 = nb03 / ts; + + switch (dst->type) { + case GGML_TYPE_F32: { + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; + repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream); + } break; + default: { + GGML_ASSERT(false); + } break; + } +} diff --git a/ggml/src/ggml-cuda/binbcast.cuh b/ggml/src/ggml-cuda/binbcast.cuh index 198c9ef6f..3ac1c9b03 100644 --- a/ggml/src/ggml-cuda/binbcast.cuh +++ b/ggml/src/ggml-cuda/binbcast.cuh @@ -5,3 +5,5 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index eb39b6d23..8d8d3932e 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -6,7 +6,7 @@ #include #include -#if defined(GGML_USE_HIPBLAS) +#if defined(GGML_USE_HIP) #define GGML_COMMON_DECL_HIP #define GGML_COMMON_IMPL_HIP #else @@ -26,13 +26,13 @@ #include #include -#if defined(GGML_USE_HIPBLAS) +#if defined(GGML_USE_HIP) #include "vendors/hip.h" #elif defined(GGML_USE_MUSA) #include "vendors/musa.h" #else #include "vendors/cuda.h" -#endif // defined(GGML_USE_HIPBLAS) +#endif // defined(GGML_USE_HIP) #define STRINGIZE_IMPL(...) #__VA_ARGS__ #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__) @@ -41,15 +41,28 @@ #define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed) #define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons -#define CC_PASCAL 600 -#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products -#define CC_VOLTA 700 -#define CC_TURING 750 -#define CC_AMPERE 800 -#define CC_OFFSET_AMD 1000000 -#define CC_RDNA1 (CC_OFFSET_AMD + 1010) -#define CC_RDNA2 (CC_OFFSET_AMD + 1030) -#define CC_RDNA3 (CC_OFFSET_AMD + 1100) +#define GGML_CUDA_CC_PASCAL 600 +#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products +#define GGML_CUDA_CC_VOLTA 700 +#define GGML_CUDA_CC_TURING 750 +#define GGML_CUDA_CC_AMPERE 800 +#define GGML_CUDA_CC_OFFSET_AMD 0x1000000 + +// GCN/CNDA, wave size is 64 +#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16 +#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue +#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a +#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers +#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing +#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 + +// RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32 +#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000 +#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a +#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA + +#define GGML_CUDA_CC_QY1 210 +#define GGML_CUDA_CC_QY2 220 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses @@ -95,7 +108,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str) -#if !defined(GGML_USE_HIPBLAS) +#if !defined(GGML_USE_HIP) static const char * cu_get_error_str(CUresult err) { const char * err_str; cuGetErrorString(err, &err_str); @@ -118,46 +131,54 @@ typedef float dfloat; // dequantize float typedef float2 dfloat2; #endif // GGML_CUDA_F16 -#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL +#if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) +#define GGML_USE_VMM +#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) + +#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL #define FP16_AVAILABLE -#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL +#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 #define FAST_FP16_AVAILABLE #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #define FP16_MMA_AVAILABLE -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING #define INT8_MMA_AVAILABLE -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING + +#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1) +#define FLASH_ATTN_AVAILABLE +#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1) static constexpr bool fast_fp16_available(const int cc) { - return cc >= CC_PASCAL && cc != 610; + return cc >= GGML_CUDA_CC_PASCAL && cc != 610; } static constexpr bool fp16_mma_available(const int cc) { - return cc < CC_OFFSET_AMD && cc >= CC_VOLTA; + return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA; } static constexpr bool int8_mma_available(const int cc) { - return cc < CC_OFFSET_AMD && cc >= CC_TURING; + return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING; } [[noreturn]] static __device__ void no_device_code( const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) { -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n", file_name, line, function_name, arch); GGML_UNUSED(arch_list); #else printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n", file_name, line, function_name, arch, arch_list); -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) __trap(); GGML_UNUSED(no_device_code); // suppress unused function warning @@ -169,41 +190,46 @@ static __device__ void no_device_code( #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.") #endif // __CUDA_ARCH__ +template +static __device__ __forceinline__ int warp_reduce_sum(int x) { +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + return __reduce_add_sync(0xffffffff, x); +#else +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, offset, width); + } + return x; +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +} + +template static __device__ __forceinline__ float warp_reduce_sum(float x) { #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - x += __shfl_xor_sync(0xffffffff, x, mask, 32); + for (int offset = width/2; offset > 0; offset >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, offset, width); } return x; } +template static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32); - a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32); + for (int offset = width/2; offset > 0; offset >>= 1) { + a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width); + a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width); } return a; } +template static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #ifdef FP16_AVAILABLE - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32); - reinterpret_cast(a.x) += __low2half(a_other); - reinterpret_cast(a.y) += __high2half(a_other); + for (int offset = width/2; offset > 0; offset >>= 1) { + a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width)); } return a; -#else -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32)); - } - return a; -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #else NO_DEVICE_CODE; @@ -211,10 +237,11 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #endif // FP16_AVAILABLE } +template static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + for (int offset = width/2; offset > 0; offset >>= 1) { + x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width)); } return x; } @@ -222,11 +249,11 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) { #ifdef FP16_AVAILABLE -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX return __float2half(fmaxf(__half2float(a), __half2float(b))); #else return __hmax(a, b); -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX #else NO_DEVICE_CODE; @@ -236,35 +263,34 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b } static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) - -#if CUDART_VERSION >= CUDART_HMAX +#if defined(GGML_USE_HIP) && HIP_VERSION >= 50700000 + return half2(__hmax(a.x, b.x), __hmax(a.y, b.y)); +#elif !defined(GGML_USE_HIP) && CUDART_VERSION >= CUDART_HMAX return __hmax2(a, b); -#else +#elif !defined(GGML_USE_HIP) half2 ret; reinterpret_cast(ret.x) = __float2half(fmaxf( __low2float(a), __low2float(b))); reinterpret_cast(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b))); return ret; -#endif // CUDART_VERSION >= CUDART_HMAX - #else GGML_UNUSED(a); GGML_UNUSED(b); NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +#endif } +template static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000) #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + for (int offset = width/2; offset > 0; offset >>= 1) { + x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width)); } return x; #else GGML_UNUSED(x); NO_DEVICE_CODE; -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000) } #if CUDART_VERSION < CUDART_HMASK @@ -276,7 +302,7 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half #endif // CUDART_VERSION < CUDART_HMASK static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) { -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2) c = __builtin_amdgcn_sdot4(a, b, c, false); #elif defined(RDNA3) @@ -302,17 +328,17 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i #endif return c; -#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) -#if __CUDA_ARCH__ >= MIN_CC_DP4A +#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A return __dp4a(a, b, c); -#else // __CUDA_ARCH__ >= MIN_CC_DP4A +#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A const int8_t * a8 = (const int8_t *) &a; const int8_t * b8 = (const int8_t *) &b; return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3]; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) } // TODO: move to ggml-common.h @@ -487,6 +513,7 @@ struct ggml_cuda_device_info { bool vmm; // virtual memory support size_t vmm_granularity; // granularity of virtual memory size_t total_vram; + int warp_size; // Number of threads in a dispatch }; cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {}; @@ -559,7 +586,7 @@ struct ggml_tensor_extra_gpu { }; -#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS) +#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS) #define USE_CUDA_GRAPH #endif @@ -569,6 +596,7 @@ struct ggml_graph_node_properties { int64_t ne[GGML_MAX_DIMS]; size_t nb[GGML_MAX_DIMS]; void * src_address[GGML_MAX_SRC]; + int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; }; struct ggml_cuda_graph { diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index dac10ec36..aafbaf803 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -94,7 +94,9 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, int n } // non-contiguous kernel (slow) -static __global__ void concat_f32_non_cont( +template +static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) + concat_f32_non_cont( const char * src0, const char * src1, char * dst, @@ -121,22 +123,28 @@ static __global__ void concat_f32_non_cont( uint64_t nb0, uint64_t nb1, uint64_t nb2, - uint64_t nb3, - int32_t dim) { + uint64_t nb3){ + static_assert(dim >= 0 && dim <= 3, "dim must be in [0, 3]"); + const int64_t i3 = blockIdx.z; const int64_t i2 = blockIdx.y; const int64_t i1 = blockIdx.x; - int64_t o[4] = {0, 0, 0, 0}; - o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); - const float * x; - for (int i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { + for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); } else { - x = (const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); + if constexpr (dim == 0) { + x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10); + } else if constexpr (dim == 1) { + x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10); + } else if constexpr (dim == 2) { + x = (const float *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10); + } else if constexpr (dim == 3) { + x = (const float *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10); + } } float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); @@ -182,15 +190,32 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { } } else { dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]); - concat_f32_non_cont<<>>( - (const char *)src0->data, - (const char *)src1->data, - ( char *)dst->data, + auto launch_kernel = [&](auto dim) { + concat_f32_non_cont<<>>( + (const char *) src0->data, (const char *) src1->data, (char *) dst->data, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3], - dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim); + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]); + }; + switch (dim) { + case 0: + launch_kernel(std::integral_constant{}); + break; + case 1: + launch_kernel(std::integral_constant{}); + break; + case 2: + launch_kernel(std::integral_constant{}); + break; + case 3: + launch_kernel(std::integral_constant{}); + break; + default: + GGML_ABORT("Invalid dim: %d", dim); + break; + } } } diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index c0a444707..5b0dfacef 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -26,7 +26,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __ template static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) { -#if __CUDA_ARCH__ >= CC_PASCAL +#if __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE; const int64_t i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x; @@ -64,7 +64,7 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h GGML_UNUSED(y); GGML_UNUSED(k); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= CC_PASCAL +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL } template @@ -599,7 +599,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { case GGML_TYPE_Q5_1: return dequantize_block_cuda; case GGML_TYPE_Q8_0: - if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) { + if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_PASCAL) { return dequantize_block_q8_0_f16_cuda; } return dequantize_block_cuda; @@ -680,6 +680,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq3_s_cuda; case GGML_TYPE_F16: return convert_unary_cuda; + case GGML_TYPE_BF16: + return convert_unary_cuda; default: return nullptr; } diff --git a/ggml/src/ggml-cuda/count-equal.cu b/ggml/src/ggml-cuda/count-equal.cu new file mode 100644 index 000000000..08898115d --- /dev/null +++ b/ggml/src/ggml-cuda/count-equal.cu @@ -0,0 +1,64 @@ +#include "common.cuh" +#include "count-equal.cuh" + +#include + +template +static __global__ void count_equal(const T * __restrict__ x, const T * __restrict__ y, int64_t * __restrict__ dst, const int64_t dk, const int64_t k) { + const int64_t i0 = (int64_t) blockIdx.x*dk; + const int64_t i1 = min(i0 + dk, k); + + int nequal = 0; + + for (int64_t i = i0 + threadIdx.x; i < i1; i += WARP_SIZE) { + const T xi = x[i]; + const T yi = y[i]; + nequal += xi == yi; + } + + nequal = warp_reduce_sum(nequal); + + if (threadIdx.x != 0) { + return; + } + + atomicAdd((int *) dst, nequal); +} + +void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == src1->type); + GGML_ASSERT( dst->type == GGML_TYPE_I64); + + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + int64_t * dst_d = (int64_t *) dst->data; + + cudaStream_t stream = ctx.stream(); + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + + const int64_t ne = ggml_nelements(src0); + GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int"); + const int64_t dne = GGML_PAD((ne + 4*nsm - 1) / (4*nsm), CUDA_COUNT_EQUAL_CHUNK_SIZE); + + CUDA_CHECK(cudaMemsetAsync(dst_d, 0, ggml_nbytes(dst), stream)); + + const dim3 blocks_dim(WARP_SIZE, 1, 1); + const dim3 blocks_num(std::min((int64_t)4*nsm, (ne + CUDA_COUNT_EQUAL_CHUNK_SIZE - 1)/CUDA_COUNT_EQUAL_CHUNK_SIZE), 1, 1); + + switch (src0->type) { + case GGML_TYPE_I32: { + const int * src0_d = (const int *) src0->data; + const int * src1_d = (const int *) src1->data; + count_equal<<>>(src0_d, src1_d, dst_d, dne, ne); + } break; + default: + GGML_ASSERT(false); + break; + } +} diff --git a/ggml/src/ggml-cuda/count-equal.cuh b/ggml/src/ggml-cuda/count-equal.cuh new file mode 100644 index 000000000..8467da79e --- /dev/null +++ b/ggml/src/ggml-cuda/count-equal.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_COUNT_EQUAL_CHUNK_SIZE 128 + +void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 51deb75fd..54c0f66d2 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -81,6 +81,17 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { } } +static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { + const block_q8_0 * xi = (const block_q8_0 *) cxi; + float * dsti = (float *) cdsti; + + const float d = (float)xi->d; + + for (int j = 0; j < QK8_0; j++) { + dsti[j] = xi->qs[j] * d; + } +} + static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) { const float * xi = (const float *) cxi; block_q4_0 * dsti = (block_q4_0 *) cdsti; @@ -288,6 +299,32 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, cpy_blck(cx + x_offset, cdst + dst_offset); } +template +static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13) { + const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; + + if (i >= ne) { + return; + } + + const int i03 = i/(ne00 * ne01 * ne02); + const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); + const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; + const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; + const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; + + const int i13 = i/(ne10 * ne11 * ne12); + const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + + cpy_blck(cx + x_offset, cdst + dst_offset); +} + static void ggml_cpy_f16_f32_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -329,6 +366,16 @@ static void ggml_cpy_f32_q8_0_cuda( (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } +static void ggml_cpy_q8_0_f32_cuda( + const char * cx, char * cdst, const int ne, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + + const int num_blocks = ne; + cpy_q_f32<<>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + static void ggml_cpy_f32_q4_0_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, @@ -437,6 +484,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { @@ -471,6 +520,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { return (void*) cpy_f32_f16; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { return (void*) cpy_f32_q; + } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { + return (void*) cpy_q_f32; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { return (void*) cpy_f32_q; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { diff --git a/ggml/src/ggml-cuda/cpy.cuh b/ggml/src/ggml-cuda/cpy.cuh index 796167426..28b06cdda 100644 --- a/ggml/src/ggml-cuda/cpy.cuh +++ b/ggml/src/ggml-cuda/cpy.cuh @@ -1,6 +1,6 @@ #include "common.cuh" -#define CUDA_CPY_BLOCK_SIZE 32 +#define CUDA_CPY_BLOCK_SIZE 64 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1); diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cu b/ggml/src/ggml-cuda/cross-entropy-loss.cu index 5575a90f6..27599a2b0 100644 --- a/ggml/src/ggml-cuda/cross-entropy-loss.cu +++ b/ggml/src/ggml-cuda/cross-entropy-loss.cu @@ -5,72 +5,92 @@ #include #include -static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) { - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE; +template +static __global__ void cross_entropy_loss_f32( + const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) { + extern __shared__ float tmp[]; - const int ne_tmp = WARP_SIZE*nclasses; - - extern __shared__ float tmp_all[]; - float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp; - float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp; - - // Each warp first loads ne_tmp logits/labels into shared memory: - for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) { - const int ig = i0*nclasses + i; // ig == i global - - tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f; - tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f; - } - - // Each thread in the warp then calculates the cross entropy loss for a single row. - // TODO: pad in order to avoid shared memory bank conflicts. + logits += int64_t(blockIdx.x)*nclasses; + labels += int64_t(blockIdx.x)*nclasses; // Find maximum for softmax: - float max = -INFINITY; - for (int i = 0; i < nclasses; ++i) { - max = fmaxf(max, tmp_logits[lane_id*nclasses + i]); + float max_logit = -INFINITY; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float val = logits[i]; + max_logit = fmaxf(max_logit, val); + + if (use_shared) { + tmp[i] = val; + } } + max_logit = warp_reduce_max(max_logit); // Calculate log(softmax(logits)) which is just logits - max: float sum = 0.0f; - for (int i = 0; i < nclasses; ++i) { - float val = tmp_logits[lane_id*nclasses + i] - max; - sum += expf(val); - tmp_logits[lane_id*nclasses + i] = val; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float logit_i = use_shared ? tmp[i] : logits[i]; + sum += expf(logit_i - max_logit); } + sum = warp_reduce_sum(sum); sum = logf(sum); // log(exp(logits - max) / sum) = (logits - max) - log(sum) float loss = 0.0f; - for (int i = 0; i < nclasses; ++i) { - loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i]; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float logit_i = use_shared ? tmp[i] : logits[i]; + loss += (logit_i - max_logit - sum) * labels[i]; } loss = -warp_reduce_sum(loss) / (float)k; - __syncthreads(); - - if (lane_id == 0) { - tmp_all[warp_id] = loss; - } - - __syncthreads(); - - if (warp_id != 0) { - return; - } - - loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f; - loss = warp_reduce_sum(loss); - - if (lane_id != 0) { + if (threadIdx.x != 0) { return; } dst[blockIdx.x] = loss; } +template +static __global__ void cross_entropy_loss_back_f32( + const float * __restrict__ grad, const float * __restrict__ logits, const float * __restrict__ labels, + float * __restrict__ dst, const int nclasses) { + extern __shared__ float tmp[]; + + logits += int64_t(blockIdx.x)*nclasses; + labels += int64_t(blockIdx.x)*nclasses; + dst += int64_t(blockIdx.x)*nclasses; + + float maxval = -INFINITY; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float val = logits[i]; + maxval = fmaxf(maxval, val); + + if (use_shared) { + tmp[i] = val; + } + } + maxval = warp_reduce_max(maxval); + + float sum = 0.0f; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float val = expf((use_shared ? tmp[i] : logits[i]) - maxval); + sum += val; + + if (use_shared) { + tmp[i] = val; + } else { + dst[i] = val; + } + } + sum = warp_reduce_sum(sum); + const float sm_scale = 1.0f/sum; + + const float d_by_nrows = *grad/gridDim.x; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float val = use_shared ? tmp[i] : dst[i]; + dst[i] = (val*sm_scale - labels[i])*d_by_nrows; + } +} + void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -93,14 +113,77 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_cuda_pool & pool = ctx.pool(); cudaStream_t stream = ctx.stream(); - const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1); - const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1); - const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float); + const dim3 blocks_dim(WARP_SIZE, 1, 1); + const dim3 blocks_num(nrows, 1, 1); + const size_t nbytes_shared = ne00*sizeof(float); + + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; ggml_cuda_pool_alloc dst_tmp(pool, blocks_num.x); - cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); + if (nbytes_shared <= smpbo) { +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo)); + shared_memory_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); + } else { + cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); + } + CUDA_CHECK(cudaGetLastError()); // Combine results from individual blocks: sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream); } + +void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * grad = dst->src[0]; + const ggml_tensor * src0f = dst->src[1]; + const ggml_tensor * src1f = dst->src[2]; + + GGML_ASSERT(src0f->type == GGML_TYPE_F32); + GGML_ASSERT(src1f->type == GGML_TYPE_F32); + GGML_ASSERT( grad->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_scalar(grad)); + GGML_ASSERT(ggml_is_contiguous(src0f)); + GGML_ASSERT(ggml_is_contiguous(src1f)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0f, src1f)); + GGML_ASSERT(ggml_are_same_shape(src0f, dst)); + + const int64_t ne00 = src0f->ne[0]; + const int64_t nrows = ggml_nrows(src0f); + + const float * grad_d = (const float *) grad->data; + const float * src0f_d = (const float *) src0f->data; + const float * src1f_d = (const float *) src1f->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + const dim3 blocks_dim(WARP_SIZE, 1, 1); + const dim3 blocks_num(nrows, 1, 1); + const size_t nbytes_shared = ne00*sizeof(float); + + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + + if (nbytes_shared <= smpbo) { +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo)); + shared_memory_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) + cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00); + } else { + cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00); + } +} diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cuh b/ggml/src/ggml-cuda/cross-entropy-loss.cuh index 9d7b8b0f0..9ec7152ff 100644 --- a/ggml/src/ggml-cuda/cross-entropy-loss.cuh +++ b/ggml/src/ggml-cuda/cross-entropy-loss.cuh @@ -3,3 +3,5 @@ #define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256 void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu deleted file mode 100644 index 96a5adef5..000000000 --- a/ggml/src/ggml-cuda/dmmv.cu +++ /dev/null @@ -1,683 +0,0 @@ -#include "dmmv.cuh" -#include "dequantize.cuh" -#include "convert.cuh" - -#ifndef K_QUANTS_PER_ITERATION -#define K_QUANTS_PER_ITERATION 2 -#else -static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); -#endif - -static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); - - const int row = blockIdx.x*blockDim.y + threadIdx.y; - if (row > nrows) return; - - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q2_K * x = (const block_q2_K *)vx + ib0; - - float tmp = 0; // partial sum for thread in warp - - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 - - const int step = 16/K_QUANTS_PER_ITERATION; - - const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - step*im; // 0...15 or 0...7 - - const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2 - const int q_offset = 32*im + l0; - const int s_offset = 8*im; - const int y_offset = 128*im + l0; - - uint32_t aux[4]; - const uint8_t * d = (const uint8_t *)aux; - const uint8_t * m = (const uint8_t *)(aux + 2); - - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + y_offset; - const uint8_t * q = x[i].qs + q_offset; - - const float dall = __low2half(x[i].dm); - const float dmin = __high2half(x[i].dm); - - const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset); - aux[0] = a[0] & 0x0f0f0f0f; - aux[1] = a[1] & 0x0f0f0f0f; - aux[2] = (a[0] >> 4) & 0x0f0f0f0f; - aux[3] = (a[1] >> 4) & 0x0f0f0f0f; - - float sum1 = 0, sum2 = 0; - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3) - + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3) - + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3) - + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3) - + y[l+16] * d[1] * ((q[l+16] >> 0) & 3) - + y[l+48] * d[3] * ((q[l+16] >> 2) & 3) - + y[l+80] * d[5] * ((q[l+16] >> 4) & 3) - +y[l+112] * d[7] * ((q[l+16] >> 6) & 3); - sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6] - + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7]; - - } - tmp += dall * sum1 - dmin * sum2; - - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (threadIdx.x == 0) { - dst[row] = tmp; - } -} - -static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - const int row = blockIdx.x*blockDim.y + threadIdx.y; - if (row > nrows) return; - - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q3_K * x = (const block_q3_K *)vx + ib0; - - float tmp = 0; // partial sum for thread in warp - - const uint16_t kmask1 = 0x0303; - const uint16_t kmask2 = 0x0f0f; - - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 - - const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop - const int step = 16/K_QUANTS_PER_ITERATION; - const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - step*im; // 0....15 or 0...7 - - const uint8_t m = 1 << (4*im); - - const int l0 = n*in; // 0...15 or 0...14 in steps of 2 - const int q_offset = 32*im + l0; - const int y_offset = 128*im + l0; - - uint16_t utmp[4]; - const int8_t * s = (const int8_t *)utmp; - - const uint16_t s_shift = 4*im; - - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + y_offset; - const uint8_t * q = x[i].qs + q_offset; - const uint8_t * h = x[i].hmask + l0; - - const uint16_t * a = (const uint16_t *)x[i].scales; - utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4); - utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4); - utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4); - utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4); - - const float d = x[i].d; - - float sum = 0; - for (int l = 0; l < n; ++l) { - sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4)) - + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4)) - + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4)) - + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4)); - sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4)) - + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4)) - + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4)) - + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4)); - } - tmp += d * sum; - - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (threadIdx.x == 0) { - dst[row] = tmp; - } -} - -static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - const int row = blockIdx.x*blockDim.y + threadIdx.y; - if (row > nrows) return; - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q4_K * x = (const block_q4_K *)vx + ib0; - - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 - - const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 - - const int il = tid/step; // 0...3 - const int ir = tid - step*il; // 0...7 or 0...3 - const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 - - const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const int in = il%2; - - const int l0 = n*(2*ir + in); - const int q_offset = 32*im + l0; - const int y_offset = 64*im + l0; - - uint16_t aux[4]; - const uint8_t * sc = (const uint8_t *)aux; - -#if K_QUANTS_PER_ITERATION == 2 - uint32_t q32[4]; - const uint8_t * q4 = (const uint8_t *)q32; -#else - uint16_t q16[4]; - const uint8_t * q4 = (const uint8_t *)q16; -#endif - - float tmp = 0; // partial sum for thread in warp - - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - - const float * y1 = yy + i*QK_K + y_offset; - const float * y2 = y1 + 128; - - const float dall = __low2half(x[i].dm); - const float dmin = __high2half(x[i].dm); - - const uint16_t * a = (const uint16_t *)x[i].scales; - aux[0] = a[im+0] & kmask1; - aux[1] = a[im+2] & kmask1; - aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); - aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); - -#if K_QUANTS_PER_ITERATION == 2 - const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset); - const uint32_t * q2 = q1 + 16; - - q32[0] = q1[0] & 0x0f0f0f0f; - q32[1] = q1[0] & 0xf0f0f0f0; - q32[2] = q2[0] & 0x0f0f0f0f; - q32[3] = q2[0] & 0xf0f0f0f0; - - float4 s = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - for (int l = 0; l < 4; ++l) { - s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4]; - s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12]; - smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; - } - tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; -#else - const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset); - const uint16_t * q2 = q1 + 32; - - q16[0] = q1[0] & 0x0f0f; - q16[1] = q1[0] & 0xf0f0; - q16[2] = q2[0] & 0x0f0f; - q16[3] = q2[0] & 0xf0f0; - - float4 s = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - for (int l = 0; l < 2; ++l) { - s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2]; - s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6]; - smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; - } - tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; -#endif - - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (tid == 0) { - dst[row] = tmp; - } -} - -static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) { - - const int row = blockIdx.x; - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q5_K * x = (const block_q5_K *)vx + ib0; - - float tmp = 0; // partial sum for thread in warp - - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int tid = threadIdx.x/2; // 0...15 - const int ix = threadIdx.x%2; - - const int il = tid/4; // 0...3 - const int ir = tid - 4*il;// 0...3 - const int n = 2; - - const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const int in = il%2; - - const int l0 = n*(2*ir + in); - const int q_offset = 32*im + l0; - const int y_offset = 64*im + l0; - - const uint8_t hm1 = 1 << (2*im); - const uint8_t hm2 = hm1 << 4; - - uint16_t aux[4]; - const uint8_t * sc = (const uint8_t *)aux; - - uint16_t q16[8]; - const uint8_t * q4 = (const uint8_t *)q16; - - for (int i = ix; i < num_blocks_per_row; i += 2) { - - const uint8_t * ql1 = x[i].qs + q_offset; - const uint8_t * qh = x[i].qh + l0; - const float * y1 = yy + i*QK_K + y_offset; - const float * y2 = y1 + 128; - - const float dall = __low2half(x[i].dm); - const float dmin = __high2half(x[i].dm); - - const uint16_t * a = (const uint16_t *)x[i].scales; - aux[0] = a[im+0] & kmask1; - aux[1] = a[im+2] & kmask1; - aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); - aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); - - float4 sum = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - const uint16_t * q1 = (const uint16_t *)ql1; - const uint16_t * q2 = q1 + 32; - q16[0] = q1[0] & 0x0f0f; - q16[1] = q1[8] & 0x0f0f; - q16[2] = (q1[0] >> 4) & 0x0f0f; - q16[3] = (q1[8] >> 4) & 0x0f0f; - q16[4] = q2[0] & 0x0f0f; - q16[5] = q2[8] & 0x0f0f; - q16[6] = (q2[0] >> 4) & 0x0f0f; - q16[7] = (q2[8] >> 4) & 0x0f0f; - for (int l = 0; l < n; ++l) { - sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0)) - + y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0)); - sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0)) - + y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0)); - sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0)) - + y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0)); - sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0)) - + y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0)); - smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3] - + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7]; - } - tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin; - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (threadIdx.x == 0) { - dst[row] = tmp; - } -} - -static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); - - const int row = blockIdx.x*blockDim.y + threadIdx.y; - if (row > nrows) return; - - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q6_K * x = (const block_q6_K *)vx + ib0; - - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 - - const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 - - const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - step*im; // 0...15 or 0...7 - -#if K_QUANTS_PER_ITERATION == 1 - const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 - const int is = 0; -#else - const int l0 = 4 * in; // 0, 4, 8, ..., 28 - const int is = in / 4; -#endif - const int ql_offset = 64*im + l0; - const int qh_offset = 32*im + l0; - const int s_offset = 8*im + is; - const int y_offset = 128*im + l0; - - float tmp = 0; // partial sum for thread in warp - - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + y_offset; - const uint8_t * ql = x[i].ql + ql_offset; - const uint8_t * qh = x[i].qh + qh_offset; - const int8_t * s = x[i].scales + s_offset; - - const float d = x[i].d; - -#if K_QUANTS_PER_ITERATION == 1 - float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) - + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) - + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) - + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) - + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) - + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) - + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) - +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); - tmp += sum; -#else - float sum = 0; - for (int l = 0; l < 4; ++l) { - sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) - + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) - + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) - + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); - } - tmp += sum; -#endif - - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (tid == 0) { - dst[row] = tmp; - } -} - -static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ - const half * x = (const half *) vx; - - // automatic half -> float type cast if dfloat == float - v.x = x[ib + iqs + 0]; - v.y = x[ib + iqs + 1]; -} - -static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? dequantize_q4_0 : - type == GGML_TYPE_Q4_1 ? dequantize_q4_1 : - type == GGML_TYPE_Q5_0 ? dequantize_q5_0 : - type == GGML_TYPE_Q5_1 ? dequantize_q5_1 : - type == GGML_TYPE_Q8_0 ? dequantize_q8_0 : - type == GGML_TYPE_F16 ? convert_f16 : - nullptr; -} - -template -static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { - constexpr int qk = ggml_cuda_type_traits::qk; // quantized weights per x block - constexpr int qr = ggml_cuda_type_traits::qr; // number of quantized weights per data value in x block - constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type); - - const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y; - - if (row >= nrows) { - return; - } - - const int tid = threadIdx.x; - - const int iter_stride = 2*GGML_CUDA_DMMV_X; - const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter - const int y_offset = qr == 1 ? 1 : qk/2; - -// partial sum for each thread -#ifdef GGML_CUDA_F16 - half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics -#else - float tmp = 0.0f; -#endif // GGML_CUDA_F16 - - for (int i = 0; i < ncols; i += iter_stride) { - const int col = i + vals_per_iter*tid; - const int64_t ib = ((int64_t)row*ncols + col)/qk; // x block index - const int iqs = (col%qk)/qr; // x quant index - const int iybs = col - col%qk; // y block start index - -// processing >2 values per i iter is faster for fast GPUs -#pragma unroll - for (int j = 0; j < vals_per_iter; j += 2) { - // process 2 vals per j iter - - // dequantize - // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val - dfloat2 v; - dequantize_kernel(vx, ib, iqs + j/qr, v); - - // matrix multiplication - // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 -#ifdef GGML_CUDA_F16 - tmp += __hmul2(v, { - y[iybs + iqs + j/qr + 0], - y[iybs + iqs + j/qr + y_offset] - }); -#else - tmp += v.x * y[iybs + iqs + j/qr + 0]; - tmp += v.y * y[iybs + iqs + j/qr + y_offset]; -#endif // GGML_CUDA_F16 - } - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (tid == 0) { -#ifdef GGML_CUDA_F16 - dst[row] = tmp.x + tmp.y; -#else - dst[row] = tmp; -#endif // GGML_CUDA_F16 - } -} - -static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2 - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_q2_k<<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_q3_k<<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_q4_k<<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const dim3 block_dims(32, 1, 1); - dequantize_mul_mat_vec_q5_k<<>>(vx, y, dst, ncols); -} - -static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_q6_k<<>>(vx, y, dst, ncols, nrows); -} - -static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -void ggml_cuda_op_dequantize_mul_mat_vec( - ggml_backend_cuda_context & ctx, - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, - const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, - const int64_t src1_padded_row_size, cudaStream_t stream) { - GGML_UNUSED(ctx); - const int64_t ne00 = src0->ne[0]; - const int64_t row_diff = row_high - row_low; - - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics -#ifdef GGML_CUDA_F16 - ggml_cuda_pool_alloc src1_dfloat_a(ctx.pool()); - half * src1_dfloat = nullptr; // dfloat == half - - bool src1_convert_f16 = - src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || - src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || - src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; - - if (src1_convert_f16) { - src1_dfloat = src1_dfloat_a.alloc(ne00); - const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); - GGML_ASSERT(to_fp16_cuda != nullptr); - to_fp16_cuda(src1_ddf_i, src1_dfloat, ne00, stream); - } -#else - const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion -#endif // GGML_CUDA_F16 - - switch (src0->type) { - case GGML_TYPE_Q4_0: - dequantize_mul_mat_vec_q4_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q4_1: - dequantize_mul_mat_vec_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_0: - dequantize_mul_mat_vec_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_1: - dequantize_mul_mat_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q8_0: - dequantize_mul_mat_vec_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q2_K: - dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q3_K: - dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q4_K: - dequantize_mul_mat_vec_q4_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_K: - dequantize_mul_mat_vec_q5_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q6_K: - dequantize_mul_mat_vec_q6_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_F16: - convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - default: - GGML_ABORT("fatal error"); - break; - } - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_ddq_i); - GGML_UNUSED(src1_ncols); - GGML_UNUSED(src1_padded_row_size); -} - -bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) { - return src0_type == GGML_TYPE_Q4_0 || src0_type == GGML_TYPE_Q4_1 || - src0_type == GGML_TYPE_Q5_0 || src0_type == GGML_TYPE_Q5_1 || - src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K || - src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K || - src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K || - src0_type == GGML_TYPE_F16; -} diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 1fb5c09c3..ee9752da6 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -517,9 +517,9 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { } template // D == head size -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_combine_results( const float * __restrict__ VKQ_parts, const float2 * __restrict__ VKQ_meta, diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 342f2eb66..4d314dacb 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -5,9 +5,9 @@ #define FATTN_KQ_STRIDE_TILE_F16 64 template // D == head size -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_tile_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -259,7 +259,7 @@ static __global__ void flash_attn_tile_ext_f16( } half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]); - kqsum_j = warp_reduce_sum(kqsum_j); + kqsum_j = warp_reduce_sum((float)kqsum_j); #pragma unroll for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) { diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 827437ca0..bb3360447 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -5,9 +5,9 @@ #define FATTN_KQ_STRIDE_TILE_F32 32 template // D == head size -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_tile_ext_f32( const char * __restrict__ Q, const char * __restrict__ K, @@ -44,13 +44,17 @@ static __global__ void flash_attn_tile_ext_f32( const int ne1, const int ne2, const int ne3) { +#ifndef FLASH_ATTN_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // FLASH_ATTN_AVAILABLE // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; return; } - //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + // In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index 448a9a905..34a2992c7 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -2,9 +2,9 @@ #include "fattn-common.cuh" template // D == head size -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_vec_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -196,7 +196,7 @@ static __global__ void flash_attn_vec_ext_f16( #pragma unroll for (int j = 0; j < ncols; ++j) { half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); - sum = warp_reduce_sum(sum); + sum = warp_reduce_sum((float)sum); if (use_logit_softcap) { sum = logit_softcap*tanhf(sum); @@ -220,7 +220,6 @@ static __global__ void flash_attn_vec_ext_f16( for (int j = 0; j < ncols; ++j) { half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j]; - kqmax_new_j = warp_reduce_max(kqmax_new_j); if (threadIdx.x == 0) { kqmax_shared[j][threadIdx.y] = kqmax_new_j; } @@ -265,7 +264,7 @@ static __global__ void flash_attn_vec_ext_f16( #pragma unroll for (int j = 0; j < ncols; ++j) { - kqsum[j] = warp_reduce_sum(kqsum[j]); + kqsum[j] = warp_reduce_sum((float)kqsum[j]); if (threadIdx.x == 0) { kqsum_shared[j][threadIdx.y] = kqsum[j]; } @@ -280,7 +279,7 @@ static __global__ void flash_attn_vec_ext_f16( } kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; - kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); + kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]); half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ])); if (parallel_blocks == 1) { diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index bf5125902..a28fc8b7f 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -2,9 +2,9 @@ #include "fattn-common.cuh" template // D == head size -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_vec_ext_f32( const char * __restrict__ Q, const char * __restrict__ K, @@ -206,7 +206,6 @@ static __global__ void flash_attn_vec_ext_f32( for (int j = 0; j < ncols; ++j) { float kqmax_new_j = kqmax_new_arr[j]; - kqmax_new_j = warp_reduce_max(kqmax_new_j); if (threadIdx.x == 0) { kqmax_shared[j][threadIdx.y] = kqmax_new_j; } diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index b10d19d93..860d0e6dc 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -7,9 +7,9 @@ // D == head size, VKQ_stride == num VKQ rows calculated in parallel: template -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(nwarps*WARP_SIZE, 1) -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index f87f33b3e..0b26b0f8e 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -13,9 +13,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; - const int32_t precision = KQV->op_params[3]; + const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); - if (precision != GGML_PREC_DEFAULT) { + if (prec != GGML_PREC_DEFAULT) { if (Q->ne[1] <= 32 || Q->ne[0] > 128) { constexpr int cols_per_block = 16; switch (Q->ne[0]) { @@ -152,7 +152,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g } \ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_tensor * Q = dst->src[1]; + ggml_tensor * Q = dst->src[0]; ggml_tensor * K = dst->src[1]; ggml_tensor * V = dst->src[2]; @@ -227,7 +227,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg } \ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - ggml_tensor * Q = dst->src[1]; + ggml_tensor * Q = dst->src[0]; ggml_tensor * K = dst->src[1]; ggml_tensor * V = dst->src[2]; @@ -301,11 +301,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_set_device(ctx.device); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - const int32_t precision = KQV->op_params[3]; + const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); // On AMD the tile kernels perform poorly, use the vec kernel instead: - if (cc >= CC_OFFSET_AMD) { - if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { + if (cc >= GGML_CUDA_CC_OFFSET_AMD) { + if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); } else { ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); @@ -314,7 +314,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } if (!fast_fp16_available(cc)) { - if (Q->ne[1] <= 8) { + if (Q->ne[1] <= 8 || Q->ne[0] == 256) { ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); } else { ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); @@ -332,7 +332,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) { - if (precision == GGML_PREC_DEFAULT) { + if (prec == GGML_PREC_DEFAULT) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); return; } else if(Q->ne[0] <= 128) { diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 4c3703238..4cef53a98 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -3,15 +3,15 @@ template static __global__ void k_get_rows( - const void * src0, const int32_t * src1, dst_t * dst, - int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ - /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ - /*size_t s0,*/ size_t s1, size_t s2, size_t s3, - /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, - size_t s10, size_t s11, size_t s12/*, size_t s13*/) { + const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, + const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ + /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ + /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, + /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, + const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2; - const int i10 = blockDim.y*blockIdx.y + threadIdx.y; + const int i10 = blockDim.y*blockIdx.y + threadIdx.y; const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12; const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12; @@ -22,10 +22,10 @@ static __global__ void k_get_rows( const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03; + const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03; - const int ib = i00/qk; // block index - const int iqs = (i00%qk)/qr; // quant index + const int ib = i00/qk; // block index + const int iqs = (i00%qk)/qr; // quant index const int iybs = i00 - i00%qk; // dst block start index const int y_offset = qr == 1 ? 1 : qk/2; @@ -39,15 +39,15 @@ static __global__ void k_get_rows( template static __global__ void k_get_rows_float( - const src0_t * src0, const int32_t * src1, dst_t * dst, - int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/ - /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/ - /*size_t s0,*/ size_t s1, size_t s2, size_t s3, - /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03, - size_t s10, size_t s11, size_t s12/*, size_t s13*/) { + const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, + const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ + /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ + /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, + /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, + const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - const int i00 = blockIdx.x*blockDim.x + threadIdx.x; - const int i10 = blockDim.y*blockIdx.y + threadIdx.y; + const int i00 = blockIdx.x*blockDim.x + threadIdx.x; + const int i10 = blockDim.y*blockIdx.y + threadIdx.y; const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12; const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12; @@ -58,14 +58,38 @@ static __global__ void k_get_rows_float( const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; - const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03); + const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03); dst_row[i00] = src0_row[i00]; } +template +static __global__ void k_get_rows_back_float( + const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t nrows_grad) { + const int col = blockIdx.x*blockDim.x + threadIdx.x; + + if (col >= ncols) { + return; + } + + const int dst_row = blockIdx.y*blockDim.y + threadIdx.y; + + float sum = 0.0f; + + for (int64_t i = 0; i < nrows_grad; ++i) { + if (rows[i] != dst_row) { + continue; + } + sum += grad[i*ncols + col]; + } + + dst[dst_row*ncols + col] = sum; +} + template -static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { +static void get_rows_cuda( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { GGML_TENSOR_BINARY_OP_LOCALS @@ -87,22 +111,25 @@ static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, gg GGML_ASSERT(ne00 % 2 == 0); k_get_rows<<>>( - src0_dd, src1_dd, dst_dd, - ne00, /*ne01, ne02, ne03,*/ - /*ne10, ne11,*/ ne12, /*ne13,*/ - /* s0,*/ s1, s2, s3, - /* nb00,*/ nb01, nb02, nb03, - s10, s11, s12/*, s13*/); + src0_dd, src1_dd, dst_dd, + ne00, /*ne01, ne02, ne03,*/ + /*ne10, ne11,*/ ne12, /*ne13,*/ + /* s0,*/ s1, s2, s3, + /* nb00,*/ nb01, nb02, nb03, + s10, s11, s12/*, s13*/); GGML_UNUSED(dst); } template -static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { +static void get_rows_cuda_float( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { GGML_TENSOR_BINARY_OP_LOCALS + GGML_ASSERT(ne13 == 1); + const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; const dim3 block_nums(block_num_x, ne10, ne11*ne12); @@ -119,12 +146,12 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr //const size_t s13 = nb13 / ggml_element_size(src1); k_get_rows_float<<>>( - src0_dd, src1_dd, dst_dd, - ne00, /*ne01, ne02, ne03,*/ - /*ne10, ne11,*/ ne12, /*ne13,*/ - /* s0,*/ s1, s2, s3, - /* nb00,*/ nb01, nb02, nb03, - s10, s11, s12/*, s13*/); + src0_dd, src1_dd, dst_dd, + ne00, /*ne01, ne02, ne03,*/ + /*ne10, ne11,*/ ne12, /*ne13,*/ + /* s0,*/ s1, s2, s3, + /* nb00,*/ nb01, nb02, nb03, + s10, s11, s12/*, s13*/); GGML_UNUSED(dst); } @@ -132,42 +159,41 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - const float * src0_d = (const float *)src0->data; - const float * src1_d = (const float *)src1->data; - float * dst_d = (float *)dst->data; + + const void * src0_d = (const void *) src0->data; + const int32_t * src1_d = (const int32_t *) src1->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type)); - GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); - - const int32_t * src1_i32 = (const int32_t *) src1_d; + GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); switch (src0->type) { case GGML_TYPE_F16: - get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream); + get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_F32: - get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda_float(src0, src1, dst, (const float *) src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_Q4_0: - get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_Q4_1: - get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_Q5_0: - get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_Q5_1: - get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); break; case GGML_TYPE_Q8_0: - get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream); + get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); break; default: // TODO: k-quants @@ -175,3 +201,34 @@ void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { break; } } + +void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output + const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass + + GGML_TENSOR_BINARY_OP_LOCALS + + const float * src0_d = (const float *) src0->data; + const int32_t * src1_d = (const int32_t *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + GGML_ASSERT(ne02*ne03 == 1); + GGML_ASSERT(ne12*ne13 == 1); + GGML_ASSERT(ne2*ne3 == 1); + + const dim3 block_dims(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1); + const int block_num_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE; + const dim3 block_nums(block_num_x, ne1, 1); + + k_get_rows_back_float<<>>(src0_d, src1_d, dst_d, ne00, ne10); +} diff --git a/ggml/src/ggml-cuda/getrows.cuh b/ggml/src/ggml-cuda/getrows.cuh index bbf130232..a1ca643f1 100644 --- a/ggml/src/ggml-cuda/getrows.cuh +++ b/ggml/src/ggml-cuda/getrows.cuh @@ -1,5 +1,8 @@ #include "common.cuh" #define CUDA_GET_ROWS_BLOCK_SIZE 256 +#define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256 void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu similarity index 68% rename from ggml/src/ggml-cuda.cu rename to ggml/src/ggml-cuda/ggml-cuda.cu index 982316f56..383131c77 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1,26 +1,30 @@ #include "ggml-cuda.h" -#include "ggml.h" +#include "ggml-impl.h" #include "ggml-backend-impl.h" #include "ggml-cuda/common.cuh" #include "ggml-cuda/acc.cuh" #include "ggml-cuda/arange.cuh" +#include "ggml-cuda/argmax.cuh" #include "ggml-cuda/argsort.cuh" #include "ggml-cuda/binbcast.cuh" #include "ggml-cuda/clamp.cuh" #include "ggml-cuda/concat.cuh" #include "ggml-cuda/conv-transpose-1d.cuh" #include "ggml-cuda/convert.cuh" +#include "ggml-cuda/count-equal.cuh" #include "ggml-cuda/cpy.cuh" #include "ggml-cuda/cross-entropy-loss.cuh" #include "ggml-cuda/diagmask.cuh" -#include "ggml-cuda/dmmv.cuh" #include "ggml-cuda/fattn.cuh" #include "ggml-cuda/getrows.cuh" #include "ggml-cuda/im2col.cuh" #include "ggml-cuda/mmq.cuh" +#include "ggml-cuda/mmv.cuh" #include "ggml-cuda/mmvq.cuh" #include "ggml-cuda/norm.cuh" +#include "ggml-cuda/opt-step-adamw.cuh" +#include "ggml-cuda/out-prod.cuh" #include "ggml-cuda/pad.cuh" #include "ggml-cuda/pool2d.cuh" #include "ggml-cuda/quantize.cuh" @@ -32,10 +36,13 @@ #include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" +#include "ggml-cuda/wkv6.cuh" +#include "ggml-cuda/gla.cuh" #include #include #include +#include #include #include #include @@ -53,54 +60,16 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); -static void ggml_cuda_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) { - GGML_UNUSED(level); - GGML_UNUSED(user_data); - fprintf(stderr, "%s", msg); -} - -ggml_log_callback ggml_cuda_log_callback = ggml_cuda_default_log_callback; -void * ggml_cuda_log_user_data = NULL; - -GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback, void * user_data) { - ggml_cuda_log_callback = log_callback; - ggml_cuda_log_user_data = user_data; -} - -#define GGML_CUDA_LOG_INFO(...) ggml_cuda_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__) -#define GGML_CUDA_LOG_WARN(...) ggml_cuda_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__) -#define GGML_CUDA_LOG_ERROR(...) ggml_cuda_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) - -GGML_ATTRIBUTE_FORMAT(2, 3) -static void ggml_cuda_log(enum ggml_log_level level, const char * format, ...) { - if (ggml_cuda_log_callback != NULL) { - va_list args; - va_start(args, format); - char buffer[128]; - int len = vsnprintf(buffer, 128, format, args); - if (len < 128) { - ggml_cuda_log_callback(level, buffer, ggml_cuda_log_user_data); - } else { - std::vector buffer2(len + 1); // vsnprintf adds a null terminator - va_end(args); - va_start(args, format); - vsnprintf(&buffer2[0], buffer2.size(), format, args); - ggml_cuda_log_callback(level, buffer2.data(), ggml_cuda_log_user_data); - } - va_end(args); - } -} - [[noreturn]] void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) { int id = -1; // in case cudaGetDevice fails - cudaGetDevice(&id); + (void)cudaGetDevice(&id); - GGML_CUDA_LOG_ERROR("CUDA error: %s\n", msg); - GGML_CUDA_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line); - GGML_CUDA_LOG_ERROR(" %s\n", stmt); - // abort with GGML_ASSERT to get a stack trace - GGML_ABORT("CUDA error"); + GGML_LOG_ERROR(GGML_CUDA_NAME " error: %s\n", msg); + GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line); + GGML_LOG_ERROR(" %s\n", stmt); + // abort with GGML_ABORT to get a stack trace + GGML_ABORT(GGML_CUDA_NAME " error"); } // this is faster on Windows @@ -124,7 +93,7 @@ int ggml_cuda_get_device() { static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { ggml_cuda_set_device(device); -#if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA) +#if defined(GGML_USE_HIP) && defined(GGML_HIP_UMA) auto res = hipMallocManaged(ptr, size); if (res == hipSuccess) { // if error we "need" to know why... @@ -133,7 +102,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) return res; #else -#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) +#if !defined(GGML_USE_HIP) cudaError_t err; if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) { @@ -146,24 +115,90 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) return err; #else return cudaMalloc(ptr, size); -#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) +#endif // !defined(GGML_USE_HIP) #endif } +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) +static int ggml_cuda_parse_id(char devName[]) { + // A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp + // these values are not stable so this is susceptible to breakage + // https://github.com/ROCm/clr/blob/amd-staging/rocclr/device/device.cpp + int archMajor = 0x0; + int archMinor = 0x0; + int archNum = GGML_CUDA_CC_OFFSET_AMD; + int archLen = strlen(devName); + char archName[archLen + 1]; + + // strip leading 'gfx' while copying into our buffer + if (archLen > 3) { + strcpy(archName, &devName[3]); + archLen -= 3; + } + + // trim trailing :xnack- or :sramecc- statuses + archLen = strcspn(archName, ":"); + archName[archLen] = '\0'; + + // tease out the version information + if (archLen > 8) { + // versions labeled generic use '-' as delimiter + // strip the trailing "-generic" then iterate through what remains + if ((strstr(archName, "-generic"))) { + archName[archLen - 8] = '\0'; + char * pch; + if ((pch = strtok(archName, "-"))) { + archMajor = (int)strtoul(pch, 0, 16); + if ((pch = strtok(NULL, "-"))) { + archMinor = 0x10 * (int)strtoul(pch, 0, 16); + } + } + } + } else if (archLen >= 3) { + // last two digits should be the minor * 0x10 + stepping + archMinor = (int)strtoul(&archName[archLen - 2], 0, 16); + archName[archLen - 2] = '\0'; + + // only the major version remains + archMajor = (int)strtoul(archName, 0, 16); + } + archNum += archMajor * 0x100; + archNum += archMinor; + return archNum; +} +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) + static ggml_cuda_device_info ggml_cuda_init() { #ifdef __HIP_PLATFORM_AMD__ // Workaround for a rocBLAS bug when using multiple graphics cards: // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346 - rocblas_initialize(); - CUDA_CHECK(cudaDeviceSynchronize()); + { + int major_version = 0; + size_t version_length = 0; + if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) { + std::string version(version_length, '\0'); + if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) { + version.resize(::strlen(version.c_str())); + int parsed_value = 0; + if (std::from_chars(version.c_str(), version.c_str() + version.length(), parsed_value).ec == std::errc()) { + major_version = parsed_value; + } + } + } + if (major_version < 4) { + GGML_LOG_DEBUG(GGML_CUDA_NAME " calling rocblas_initialize as a workaround for a rocBLAS bug\n"); + rocblas_initialize(); + CUDA_CHECK(cudaDeviceSynchronize()); + } + } #endif ggml_cuda_device_info info = {}; cudaError_t err = cudaGetDeviceCount(&info.device_count); if (err != cudaSuccess) { - GGML_CUDA_LOG_ERROR("%s: failed to initialize " GGML_CUDA_NAME ": %s\n", __func__, cudaGetErrorString(err)); + GGML_LOG_ERROR("%s: failed to initialize " GGML_CUDA_NAME ": %s\n", __func__, cudaGetErrorString(err)); return info; } @@ -171,20 +206,20 @@ static ggml_cuda_device_info ggml_cuda_init() { int64_t total_vram = 0; #ifdef GGML_CUDA_FORCE_MMQ - GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__); + GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__); #else - GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__); + GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__); #endif // GGML_CUDA_FORCE_MMQ #ifdef GGML_CUDA_FORCE_CUBLAS - GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__); + GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__); #else - GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__); + GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__); #endif // GGML_CUDA_FORCE_CUBLAS - GGML_CUDA_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count); + GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count); for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; -#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) +#if defined(GGML_USE_VMM) CUdevice device; CU_CHECK(cuDeviceGet(&device, id)); CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device)); @@ -196,25 +231,41 @@ static ggml_cuda_device_info ggml_cuda_init() { alloc_prop.location.id = id; CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); } -#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) +#endif // defined(GGML_USE_VMM) info.devices[id].vmm = !!device_vmm; cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); - GGML_CUDA_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); info.default_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; - info.devices[id].nsm = prop.multiProcessorCount; - info.devices[id].smpb = prop.sharedMemPerBlock; -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + info.devices[id].nsm = prop.multiProcessorCount; + info.devices[id].smpb = prop.sharedMemPerBlock; + info.devices[id].warp_size = prop.warpSize; +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) info.devices[id].smpbo = prop.sharedMemPerBlock; - info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD; + + info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName); + if ((info.devices[id].cc & 0xff00) == 0x0) { + GGML_LOG_WARN("invalid architecture ID received for device %d %s: %s cc %d.%d\n", + id, prop.name, prop.gcnArchName, prop.major, prop.minor); + + // Fallback to prop.major and prop.minor + if (prop.major > 0) { + info.devices[id].cc = GGML_CUDA_CC_OFFSET_AMD + prop.major * 0x100; + info.devices[id].cc += prop.minor * 0x10; + } + } + GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n", + id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff, + device_vmm ? "yes" : "no", prop.warpSize); #else info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].cc = 100*prop.major + 10*prop.minor; -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", + id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) } for (int id = 0; id < info.device_count; ++id) { @@ -309,7 +360,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { *actual_size = look_ahead_size; pool_size += look_ahead_size; #ifdef DEBUG_CUDA_MALLOC - GGML_CUDA_LOG_INFO("%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz, + GGML_LOG_INFO("%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz, (uint32_t)(max_size / 1024 / 1024), (uint32_t)(pool_size / 1024 / 1024), (uint32_t)(size / 1024 / 1024)); #endif return ptr; @@ -324,7 +375,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { return; } } - GGML_CUDA_LOG_WARN("Cuda buffer pool full, increase MAX_CUDA_BUFFERS\n"); + GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n"); ggml_cuda_set_device(device); CUDA_CHECK(cudaFree(ptr)); pool_size -= size; @@ -332,7 +383,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { }; // pool with virtual memory -#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) +#if defined(GGML_USE_VMM) struct ggml_cuda_pool_vmm : public ggml_cuda_pool { static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB @@ -341,6 +392,9 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { size_t pool_used = 0; size_t pool_size = 0; size_t granularity; +#if defined(GGML_USE_HIP) + std::vector> mappings; +#endif explicit ggml_cuda_pool_vmm(int device) : device(device), @@ -349,7 +403,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { ~ggml_cuda_pool_vmm() { if (pool_addr != 0) { +#if defined(GGML_USE_HIP) + // Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285 + for (std::pair & mapping : mappings) { + CU_CHECK(cuMemUnmap(mapping.first, mapping.second)); + } +#else CU_CHECK(cuMemUnmap(pool_addr, pool_size)); +#endif CU_CHECK(cuMemAddressFree(pool_addr, CUDA_POOL_VMM_MAX_SIZE)); } } @@ -382,7 +443,11 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { } // map at the end of the pool - CU_CHECK(cuMemMap(pool_addr + pool_size, reserve_size, 0, handle, 0)); + CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size); + CU_CHECK(cuMemMap(start_ptr, reserve_size, 0, handle, 0)); +#if defined(GGML_USE_HIP) + mappings.push_back({start_ptr, reserve_size}); +#endif // the memory allocation handle is no longer needed after mapping CU_CHECK(cuMemRelease(handle)); @@ -392,7 +457,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; access.location.id = device; access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - CU_CHECK(cuMemSetAccess(pool_addr + pool_size, reserve_size, &access, 1)); + CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1)); // add to the pool pool_size += reserve_size; @@ -404,7 +469,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { GGML_ASSERT(pool_addr != 0); - void * ptr = (void *) (pool_addr + pool_used); + void * ptr = (void *) ((CUdeviceptr)((char *)(pool_addr) + pool_used)); *actual_size = size; pool_used += size; @@ -423,17 +488,17 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { pool_used -= size; // all deallocations must be in reverse order of the allocations - GGML_ASSERT(ptr == (void *) (pool_addr + pool_used)); + GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used)); } }; -#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) +#endif // defined(GGML_USE_VMM) std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device) { -#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) +#if defined(GGML_USE_VMM) if (ggml_cuda_info().devices[device].vmm) { return std::unique_ptr(new ggml_cuda_pool_vmm(device)); } -#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) +#endif // defined(GGML_USE_VMM) return std::unique_ptr(new ggml_cuda_pool_leg(device)); } @@ -454,26 +519,21 @@ struct ggml_backend_cuda_buffer_context { } }; -GGML_CALL static const char * ggml_backend_cuda_buffer_get_name(ggml_backend_buffer_t buffer) { - ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; - return ctx->name.c_str(); -} - -GGML_CALL static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) { - return buffer->iface.get_name == ggml_backend_cuda_buffer_get_name; -} - -GGML_CALL static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) { +static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; delete ctx; } -GGML_CALL static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) { +static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) { + return buffer->iface.free_buffer == ggml_backend_cuda_buffer_free_buffer; +} + +static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; return ctx->dev_ptr; } -GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { +static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; if (tensor->view_src != NULL) { @@ -493,7 +553,15 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t } } -GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { +static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); +} + +static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; ggml_cuda_set_device(ctx->device); @@ -501,7 +569,7 @@ GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); } -GGML_CALL static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { +static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; ggml_cuda_set_device(ctx->device); @@ -509,7 +577,7 @@ GGML_CALL static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); } -GGML_CALL static bool ggml_backend_cuda_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { +static bool ggml_backend_cuda_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { if (ggml_backend_buffer_is_cuda(src->buffer)) { ggml_backend_cuda_buffer_context * src_ctx = (ggml_backend_cuda_buffer_context *)src->buffer->context; ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *)dst->buffer->context; @@ -530,7 +598,7 @@ GGML_CALL static bool ggml_backend_cuda_buffer_cpy_tensor(ggml_backend_buffer_t GGML_UNUSED(buffer); } -GGML_CALL static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { +static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; ggml_cuda_set_device(ctx->device); @@ -539,11 +607,11 @@ GGML_CALL static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffe CUDA_CHECK(cudaDeviceSynchronize()); } -static ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = { - /* .get_name = */ ggml_backend_cuda_buffer_get_name, +static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = { /* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer, /* .get_base = */ ggml_backend_cuda_buffer_get_base, /* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor, + /* .memset_tensor = */ ggml_backend_cuda_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor, /* .cpy_tensor = */ ggml_backend_cuda_buffer_cpy_tensor, @@ -557,29 +625,27 @@ struct ggml_backend_cuda_buffer_type_context { std::string name; }; -GGML_CALL static const char * ggml_backend_cuda_buffer_type_name(ggml_backend_buffer_type_t buft) { +static const char * ggml_backend_cuda_buffer_type_get_name(ggml_backend_buffer_type_t buft) { ggml_backend_cuda_buffer_type_context * ctx = (ggml_backend_cuda_buffer_type_context *)buft->context; return ctx->name.c_str(); } static bool ggml_backend_buft_is_cuda(ggml_backend_buffer_type_t buft) { - return buft->iface.get_name == ggml_backend_cuda_buffer_type_name; + return buft->iface.get_name == ggml_backend_cuda_buffer_type_get_name; } -GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { +static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context; ggml_cuda_set_device(buft_ctx->device); - size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0 - void * dev_ptr; cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device); if (err != cudaSuccess) { // clear the error - cudaGetLastError(); - GGML_CUDA_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err)); + (void)cudaGetLastError(); + GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err)); return nullptr; } @@ -588,13 +654,13 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size); } -GGML_CALL static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { +static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { return 128; GGML_UNUSED(buft); } -GGML_CALL static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { +static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { size_t size = ggml_nbytes(tensor); int64_t ne0 = tensor->ne[0]; @@ -609,8 +675,8 @@ GGML_CALL static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backen GGML_UNUSED(buft); } -static ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = { - /* .get_name = */ ggml_backend_cuda_buffer_type_name, +static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = { + /* .get_name = */ ggml_backend_cuda_buffer_type_get_name, /* .alloc_buffer = */ ggml_backend_cuda_buffer_type_alloc_buffer, /* .get_alignment = */ ggml_backend_cuda_buffer_type_get_alignment, /* .get_max_size = */ NULL, // defaults to SIZE_MAX @@ -618,7 +684,7 @@ static ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = { /* .is_host = */ NULL, }; -GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) { +ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) { static std::mutex mutex; std::lock_guard lock(mutex); @@ -631,9 +697,10 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) { static bool ggml_backend_cuda_buffer_type_initialized = false; if (!ggml_backend_cuda_buffer_type_initialized) { - for (int i = 0; i < GGML_CUDA_MAX_DEVICES; i++) { + for (int i = 0; i < ggml_backend_cuda_get_device_count(); i++) { ggml_backend_cuda_buffer_types[i] = { /* .iface = */ ggml_backend_cuda_buffer_type_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), i), /* .context = */ new ggml_backend_cuda_buffer_type_context{i, GGML_CUDA_NAME + std::to_string(i)}, }; } @@ -680,7 +747,9 @@ static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_spl } struct ggml_backend_cuda_split_buffer_type_context { + int main_device; std::array tensor_split; + std::string name; }; struct ggml_backend_cuda_split_buffer_context { @@ -703,30 +772,20 @@ struct ggml_backend_cuda_split_buffer_context { std::vector tensor_extras; }; -GGML_CALL static const char * ggml_backend_cuda_split_buffer_get_name(ggml_backend_buffer_t buffer) { - return GGML_CUDA_NAME "_Split"; - GGML_UNUSED(buffer); -} - -static bool ggml_backend_buffer_is_cuda_split(ggml_backend_buffer_t buffer) { - return buffer->iface.get_name == ggml_backend_cuda_split_buffer_get_name; - GGML_UNUSED(ggml_backend_buffer_is_cuda_split); // only used in debug builds currently, avoid unused function warning in release builds -} - -GGML_CALL static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) { +static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context; delete ctx; } -GGML_CALL static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) { +static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) { // the pointers are stored in the tensor extras, this is just a dummy address and never dereferenced return (void *)0x1000; GGML_UNUSED(buffer); } -GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { +static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context; @@ -774,7 +833,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_bu tensor->extra = extra; } -GGML_CALL static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { +static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { // split tensors must always be set in their entirety at once GGML_ASSERT(offset == 0); GGML_ASSERT(size == ggml_nbytes(tensor)); @@ -812,7 +871,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buf } } -GGML_CALL static void ggml_backend_cuda_split_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { +static void ggml_backend_cuda_split_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { // split tensors must always be set in their entirety at once GGML_ASSERT(offset == 0); GGML_ASSERT(size == ggml_nbytes(tensor)); @@ -850,16 +909,16 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_get_tensor(ggml_backend_buf } } -GGML_CALL static void ggml_backend_cuda_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { +static void ggml_backend_cuda_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { GGML_UNUSED(buffer); GGML_UNUSED(value); } -static struct ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = { - /* .get_name = */ ggml_backend_cuda_split_buffer_get_name, +static const ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = { /* .free_buffer = */ ggml_backend_cuda_split_buffer_free_buffer, /* .get_base = */ ggml_backend_cuda_split_buffer_get_base, /* .init_tensor = */ ggml_backend_cuda_split_buffer_init_tensor, + /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_cuda_split_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cuda_split_buffer_get_tensor, /* .cpy_tensor = */ NULL, @@ -869,17 +928,17 @@ static struct ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = { // cuda split buffer type -GGML_CALL static const char * ggml_backend_cuda_split_buffer_type_name(ggml_backend_buffer_type_t buft) { - return GGML_CUDA_NAME "_Split"; +static const char * ggml_backend_cuda_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context; - GGML_UNUSED(buft); + return ctx->name.c_str(); } static bool ggml_backend_buft_is_cuda_split(ggml_backend_buffer_type_t buft) { - return buft->iface.get_name == ggml_backend_cuda_split_buffer_type_name; + return buft->iface.get_name == ggml_backend_cuda_split_buffer_type_get_name; } -GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { +static ggml_backend_buffer_t ggml_backend_cuda_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point // instead, we allocate them for each tensor separately in init_tensor // however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated, @@ -889,13 +948,13 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_split_buffer_type_alloc return ggml_backend_buffer_init(buft, ggml_backend_cuda_split_buffer_interface, ctx, size); } -GGML_CALL static size_t ggml_backend_cuda_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { +static size_t ggml_backend_cuda_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { return 128; GGML_UNUSED(buft); } -GGML_CALL static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { +static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context; size_t total_size = 0; @@ -922,14 +981,14 @@ GGML_CALL static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_ return total_size; } -GGML_CALL static bool ggml_backend_cuda_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) { +static bool ggml_backend_cuda_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) { return false; GGML_UNUSED(buft); } -static ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface = { - /* .get_name = */ ggml_backend_cuda_split_buffer_type_name, +static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface = { + /* .get_name = */ ggml_backend_cuda_split_buffer_type_get_name, /* .alloc_buffer = */ ggml_backend_cuda_split_buffer_type_alloc_buffer, /* .get_alignment = */ ggml_backend_cuda_split_buffer_type_get_alignment, /* .get_max_size = */ NULL, // defaults to SIZE_MAX @@ -937,11 +996,11 @@ static ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host, }; -GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split) { +ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split) { static std::mutex mutex; std::lock_guard lock(mutex); - static std::map, struct ggml_backend_buffer_type> buft_map; + static std::map>, struct ggml_backend_buffer_type> buft_map; std::array tensor_split_arr = {}; @@ -959,35 +1018,35 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const f } } - auto it = buft_map.find(tensor_split_arr); + auto it = buft_map.find({main_device, tensor_split_arr}); if (it != buft_map.end()) { return &it->second; } + auto * ctx = new ggml_backend_cuda_split_buffer_type_context{ + main_device, + tensor_split_arr, + GGML_CUDA_NAME + std::to_string(main_device) + "_Split", + }; struct ggml_backend_buffer_type buft { /* .iface = */ ggml_backend_cuda_split_buffer_type_interface, - /* .context = */ new ggml_backend_cuda_split_buffer_type_context{tensor_split_arr}, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), main_device), + /* .context = */ ctx, }; - auto result = buft_map.emplace(tensor_split_arr, buft); + auto result = buft_map.emplace(std::make_pair(main_device, tensor_split_arr), buft); return &result.first->second; } // host buffer type -GGML_CALL static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_type_t buft) { +static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_type_t buft) { return GGML_CUDA_NAME "_Host"; GGML_UNUSED(buft); } -GGML_CALL static const char * ggml_backend_cuda_host_buffer_name(ggml_backend_buffer_t buffer) { - return GGML_CUDA_NAME "_Host"; - - GGML_UNUSED(buffer); -} - -GGML_CALL static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { +static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { CUDA_CHECK(cudaFreeHost(buffer->context)); } @@ -1000,8 +1059,8 @@ static void * ggml_cuda_host_malloc(size_t size) { cudaError_t err = cudaMallocHost((void **) &ptr, size); if (err != cudaSuccess) { // clear the error - cudaGetLastError(); - GGML_CUDA_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__, + (void)cudaGetLastError(); + GGML_LOG_DEBUG("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__, size / 1024.0 / 1024.0, cudaGetErrorString(err)); return nullptr; } @@ -1009,7 +1068,7 @@ static void * ggml_cuda_host_malloc(size_t size) { return ptr; } -GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { +static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { void * ptr = ggml_cuda_host_malloc(size); if (ptr == nullptr) { @@ -1019,13 +1078,12 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_ ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); buffer->buft = buft; - buffer->iface.get_name = ggml_backend_cuda_host_buffer_name; buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer; return buffer; } -GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() { +ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() { static struct ggml_backend_buffer_type ggml_backend_cuda_buffer_type_host = { /* .iface = */ { /* .get_name = */ ggml_backend_cuda_host_buffer_type_name, @@ -1035,6 +1093,7 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() { /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), 0), /* .context = */ nullptr, }; @@ -1059,120 +1118,12 @@ typedef void (*ggml_cuda_op_mul_mat_t)( #define MUL_MAT_SRC1_COL_STRIDE 128 -static __global__ void mul_mat_p021_f16_f32( - const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) { - - const half * x = (const half *) vx; - - const int row_x = blockDim.y*blockIdx.y + threadIdx.y; - const int channel = blockDim.z*blockIdx.z + threadIdx.z; - const int channel_x = channel / (nchannels_y / nchannels_x); - - const int nrows_y = ncols_x; - const int nrows_dst = nrows_x; - const int row_dst = row_x; - - float tmp = 0.0f; - - for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { - const int col_x = col_x0 + threadIdx.x; - - if (col_x >= ncols_x) { - break; - } - - // x is transposed and permuted - const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x; - const float xi = __half2float(x[ix]); - - const int row_y = col_x; - - // y is not transposed but permuted - const int iy = channel*nrows_y + row_y; - - tmp += xi * y[iy]; - } - - // dst is not transposed and not permuted - const int idst = channel*nrows_dst + row_dst; - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (threadIdx.x == 0) { - dst[idst] = tmp; - } -} - -static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous - const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, - const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) { - - const half * x = (const half *) vx; - - const int row_x = blockDim.y*blockIdx.y + threadIdx.y; - const int channel = blockDim.z*blockIdx.z + threadIdx.z; - const int channel_x = channel / channel_x_divisor; - - const int nrows_y = ncols_x; - const int nrows_dst = nrows_x; - const int row_dst = row_x; - - const int idst = channel*nrows_dst + row_dst; - - float tmp = 0.0f; - - for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { - const int col_x = col_x0 + threadIdx.x; - - if (col_x >= ncols_x) { - break; - } - - const int row_y = col_x; - - const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; - const int iy = channel*nrows_y + row_y; - - const float xi = __half2float(x[ix]); - - tmp += xi * y[iy]; - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (threadIdx.x == 0) { - dst[idst] = tmp; - } -} - -static void ggml_mul_mat_p021_f16_f32_cuda( - const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, - const int nchannels_x, const int nchannels_y, cudaStream_t stream) { - - const dim3 block_nums(1, nrows_x, nchannels_y); - const dim3 block_dims(WARP_SIZE, 1, 1); - mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y); -} - -static void ggml_mul_mat_vec_nc_f16_f32_cuda( - const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x, - const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) { - - const dim3 block_nums(1, nrows_x, nchannels_y); - const dim3 block_dims(WARP_SIZE, 1, 1); - mul_mat_vec_nc_f16_f32<<>> - (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x); -} - static cudaError_t ggml_cuda_cpy_tensor_2d( void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) { GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer)); - char * src_ptr = (char *) src->data; - char * dst_ptr = (char *) dst; + const char * src_ptr = (const char *) src->data; + char * dst_ptr = (char *) dst; const int64_t ne0 = src->ne[0]; const int64_t nb0 = src->nb[0]; @@ -1182,7 +1133,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d( const enum ggml_type type = src->type; const int64_t ts = ggml_type_size(type); const int64_t bs = ggml_blck_size(type); - int64_t i1_diff = i1_high - i1_low; + const int64_t i1_diff = i1_high - i1_low; const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3; if (nb0 == ts && nb1 == ts*ne0/bs) { @@ -1228,7 +1179,9 @@ static void ggml_cuda_op_mul_mat_cublas( const int compute_capability = ggml_cuda_info().devices[id].cc; - if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { + const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT; + + if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) { // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 ggml_cuda_pool_alloc src0_as_f16(ctx.pool(id)); if (src0->type != GGML_TYPE_F16) { @@ -1249,23 +1202,38 @@ static void ggml_cuda_op_mul_mat_cublas( to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream); } const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get(); - ggml_cuda_pool_alloc dst_f16(ctx.pool(id), row_diff*src1_ncols); - - const half alpha_f16 = 1.0f; - const half beta_f16 = 0.0f; CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); - CUBLAS_CHECK( - cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, - row_diff, src1_ncols, ne10, - &alpha_f16, src0_ptr, CUDA_R_16F, ne00, - src1_ptr, CUDA_R_16F, ne10, - &beta_f16, dst_f16.get(), CUDA_R_16F, ldc, - CUBLAS_COMPUTE_16F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); - to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); + if (compute_capability == GGML_CUDA_CC_CDNA) { + const float alpha = 1.0f; + const float beta = 0.0f; + CUBLAS_CHECK( + cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, + &beta, dst_dd_i, CUDA_R_32F, ldc, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } else { + ggml_cuda_pool_alloc dst_f16(ctx.pool(id), row_diff*src1_ncols); + + const half alpha_f16 = 1.0f; + const half beta_f16 = 0.0f; + + CUBLAS_CHECK( + cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha_f16, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, + &beta_f16, dst_f16.get(), CUDA_R_16F, ldc, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); + } } else { ggml_cuda_pool_alloc src0_ddq_as_f32(ctx.pool(id)); ggml_cuda_pool_alloc src1_ddq_as_f32(ctx.pool(id)); @@ -1336,11 +1304,17 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0); if (err != cudaErrorPeerAccessAlreadyEnabled) { CUDA_CHECK(err); + } else { + // reset the error + (void)cudaGetLastError(); } } else { cudaError_t err = cudaDeviceDisablePeerAccess(id_other); if (err != cudaErrorPeerAccessNotEnabled) { CUDA_CHECK(err); + } else { + // reset the error + (void)cudaGetLastError(); } } } @@ -1358,7 +1332,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { static cudaError_t ggml_cuda_Memcpy2DPeerAsync( void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) { -#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices cudaMemcpy3DPeerParms p = {}; p.dstDevice = dstDevice; @@ -1372,7 +1346,7 @@ static cudaError_t ggml_cuda_Memcpy2DPeerAsync( GGML_UNUSED(dstDevice); GGML_UNUSED(srcDevice); return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream); -#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } static void ggml_cuda_op_mul_mat( @@ -1420,7 +1394,7 @@ static void ggml_cuda_op_mul_mat( const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING); - const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer); + const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft); GGML_ASSERT(!(split && ne02 > 1)); GGML_ASSERT(!(split && ne03 > 1)); GGML_ASSERT(!(split && ne02 < ne12)); @@ -1499,14 +1473,24 @@ static void ggml_cuda_op_mul_mat( if (src0_is_contiguous) { dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data; } else { - dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), ggml_nbytes(src0)); + // If src0 is not contiguous it will be copied to a temporary buffer. + // This buffer needs to be cleared entirely because multiple regions will function as padding. + const size_t nbytes_data = ggml_nbytes(src0); + const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); + dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding); + // TODO: remove this for MUSA once the Guilty Lockup issue is resolved +#ifndef GGML_USE_MUSA + CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream)); +#else // GGML_USE_MUSA + CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream)); +#endif // !GGML_USE_MUSA } - // If src0 is on a temporary compute buffers (partial offloading) there may be some padding that needs to be cleared: + // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared: if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) { - const int64_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00); - const int64_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); - CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream)); + const size_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00); + const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); + CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream)); } if (src1_on_device && src1_is_contiguous) { @@ -1677,58 +1661,6 @@ static void ggml_cuda_op_mul_mat( } } -static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); - GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer)); - GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation - GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - - const int64_t ne12 = src1->ne[2]; - - cudaStream_t main_stream = ctx.stream(); - - void * src0_ddq = src0->data; - float * src1_ddf = (float *) src1->data; - float * dst_ddf = (float *) dst->data; - - ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream); -} - -static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(!ggml_is_transposed(src0)); - GGML_ASSERT(!ggml_is_transposed(src1)); - GGML_ASSERT(!ggml_is_permuted(src0)); - GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer)); - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - - const int64_t nb01 = src0->nb[1]; - const int64_t nb02 = src0->nb[2]; - - const int64_t ne12 = src1->ne[2]; - - cudaStream_t main_stream = ctx.stream(); - - void * src0_ddq = src0->data; - float * src1_ddf = (float *) src1->data; - float * dst_ddf = (float *) dst->data; - - const int64_t row_stride_x = nb01 / sizeof(half); - const int64_t channel_stride_x = nb02 / sizeof(half); - - ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); -} - static __global__ void k_compute_batched_ptrs( const half * src0_as_f16, const half * src1_as_f16, char * dst, const void ** ptrs_src, void ** ptrs_dst, @@ -1818,6 +1750,12 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co beta = &beta_f32; } + if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) { + cu_compute_type = CUBLAS_COMPUTE_32F; + alpha = &alpha_f32; + beta = &beta_f32; + } + GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % ne03 == 0); @@ -1900,23 +1838,19 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co } static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer); + const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft); - bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type) + bool use_mul_mat_vec = (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1; - bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) + && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; - bool use_mul_mat_q = ggml_is_quantized(src0->type) + bool use_mul_mat_q = ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; - // if mmvq is available it's a better choice than dmmv: -#ifndef GGML_CUDA_FORCE_DMMV - use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; -#endif // GGML_CUDA_FORCE_DMMV - - bool any_gpus_with_slow_fp16 = false; + bool any_gpus_with_slow_fp16 = false; + bool any_gpus_without_fp16_mma = false; if (split) { ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context; @@ -1927,14 +1861,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor continue; } - const int cc = ggml_cuda_info().devices[id].cc; - use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); + const int cc = ggml_cuda_info().devices[id].cc; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); + any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc); } } else { - const int cc = ggml_cuda_info().devices[ctx.device].cc; - use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); + const int cc = ggml_cuda_info().devices[ctx.device].cc; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); + any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc); } // debug helpers @@ -1945,18 +1881,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); - if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { - // FP32 precision KQ single-batch for batch size 1 without FlashAttention - ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst); - } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { - // FP32 precision KQV single-batch for batch size 1 without FlashAttention - ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst); + if (!split && use_mul_mat_vec && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) { + // the custom F16 vector kernel can be used over batched cuBLAS GEMM + // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention) + ggml_cuda_mul_mat_vec(ctx, src0, src1, dst); } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { - // KQ + KQV multi-batch without FlashAttention + // general KQ + KQV multi-batch without FlashAttention ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); - } else if (use_dequantize_mul_mat_vec) { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr); + } else if (use_mul_mat_vec) { + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr); } else if (use_mul_mat_vec_q) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); } else if (use_mul_mat_q) { @@ -2027,7 +1961,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * GGML_TENSOR_BINARY_OP_LOCALS - GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers"); + GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers"); cudaStream_t stream = ctx.stream(); @@ -2160,17 +2094,29 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) { // why is this here instead of mul_mat? - if (dst->src[0] != nullptr && ggml_backend_buffer_is_cuda_split(dst->src[0]->buffer)) { + if (dst->src[0] != nullptr && ggml_backend_buft_is_cuda_split(dst->src[0]->buffer->buft)) { ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device); } switch (dst->op) { + case GGML_OP_ARGMAX: + ggml_cuda_argmax(ctx, dst); + break; + case GGML_OP_COUNT_EQUAL: + ggml_cuda_count_equal(ctx, dst); + break; case GGML_OP_REPEAT: ggml_cuda_op_repeat(ctx, dst); break; + case GGML_OP_REPEAT_BACK: + ggml_cuda_op_repeat_back(ctx, dst); + break; case GGML_OP_GET_ROWS: ggml_cuda_op_get_rows(ctx, dst); break; + case GGML_OP_GET_ROWS_BACK: + ggml_cuda_op_get_rows_back(ctx, dst); + break; case GGML_OP_DUP: ggml_cuda_dup(ctx, dst); break; @@ -2201,6 +2147,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_NEG: ggml_cuda_op_neg(ctx, dst); break; + case GGML_UNARY_OP_STEP: + ggml_cuda_op_step(ctx, dst); + break; case GGML_UNARY_OP_GELU: ggml_cuda_op_gelu(ctx, dst); break; @@ -2225,6 +2174,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_HARDSWISH: ggml_cuda_op_hardswish(ctx, dst); break; + case GGML_UNARY_OP_EXP: + ggml_cuda_op_exp(ctx, dst); + break; default: return false; } @@ -2253,12 +2205,18 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_LEAKY_RELU: ggml_cuda_op_leaky_relu(ctx, dst); break; + case GGML_OP_SILU_BACK: + ggml_cuda_op_silu_back(ctx, dst); + break; case GGML_OP_RMS_NORM: ggml_cuda_op_rms_norm(ctx, dst); break; + case GGML_OP_RMS_NORM_BACK: + ggml_cuda_op_rms_norm_back(ctx, dst); + break; case GGML_OP_MUL_MAT: if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) { - GGML_CUDA_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]); + GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]); return false; } else { ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst); @@ -2267,6 +2225,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_MUL_MAT_ID: ggml_cuda_mul_mat_id(ctx, dst); break; + case GGML_OP_OUT_PROD: + ggml_cuda_out_prod(ctx, dst); + break; case GGML_OP_SCALE: ggml_cuda_op_scale(ctx, dst); break; @@ -2297,9 +2258,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SOFT_MAX: ggml_cuda_op_soft_max(ctx, dst); break; + case GGML_OP_SOFT_MAX_BACK: + ggml_cuda_op_soft_max_back(ctx, dst); + break; case GGML_OP_ROPE: ggml_cuda_op_rope(ctx, dst); break; + case GGML_OP_ROPE_BACK: + ggml_cuda_op_rope_back(ctx, dst); + break; case GGML_OP_IM2COL: ggml_cuda_op_im2col(ctx, dst); break; @@ -2324,13 +2291,25 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CROSS_ENTROPY_LOSS: ggml_cuda_cross_entropy_loss(ctx, dst); break; + case GGML_OP_RWKV_WKV6: + ggml_cuda_op_rwkv_wkv6(ctx, dst); + break; + case GGML_OP_GATED_LINEAR_ATTN: + ggml_cuda_op_gated_linear_attn(ctx, dst); + break; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + ggml_cuda_cross_entropy_loss_back(ctx, dst); + break; + case GGML_OP_OPT_STEP_ADAMW: + ggml_cuda_opt_step_adamw(ctx, dst); + break; default: return false; } cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { - GGML_CUDA_LOG_ERROR("%s: %s failed\n", __func__, ggml_op_desc(dst)); + GGML_LOG_ERROR("%s: %s failed\n", __func__, ggml_op_desc(dst)); CUDA_CHECK(err); } @@ -2341,26 +2320,20 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg // backend -GGML_CALL static const char * ggml_backend_cuda_name(ggml_backend_t backend) { +static const char * ggml_backend_cuda_get_name(ggml_backend_t backend) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; return cuda_ctx->name.c_str(); } -GGML_CALL static void ggml_backend_cuda_free(ggml_backend_t backend) { +static void ggml_backend_cuda_free(ggml_backend_t backend) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; delete cuda_ctx; delete backend; } -GGML_CALL static ggml_backend_buffer_type_t ggml_backend_cuda_get_default_buffer_type(ggml_backend_t backend) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; - - return ggml_backend_cuda_buffer_type(cuda_ctx->device); -} - -GGML_CALL static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { +static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; @@ -2369,7 +2342,7 @@ GGML_CALL static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream())); } -GGML_CALL static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { +static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; @@ -2378,7 +2351,7 @@ GGML_CALL static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream())); } -GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { +static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer; ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer; @@ -2399,7 +2372,7 @@ GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_ if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) { #ifndef NDEBUG - GGML_CUDA_LOG_WARN("%s: backend and buffer devices do not match\n", __func__); + GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__); #endif return false; } @@ -2433,7 +2406,7 @@ GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_ return true; } -GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { +static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; CUDA_CHECK(cudaStreamSynchronize(cuda_ctx->stream())); @@ -2441,6 +2414,67 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { GGML_UNUSED(backend); } +#ifdef USE_CUDA_GRAPH +static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, + std::vector & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) { + + // Loop over nodes in GGML graph to obtain info needed for CUDA graph + cuda_ctx->cuda_graph->updated_kernel_arg.clear(); + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } + + if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) { + use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__); +#endif + } + + if (node->op == GGML_OP_MUL_MAT_ID) { + use_cuda_graph = false; // This node type is not supported by CUDA graph capture +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__); +#endif + } + + if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) { + // disable CUDA graphs for batch size > 1 for now. + // Changes in batch size or context size can cause changes to the grid size of some kernels. + use_cuda_graph = false; +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); +#endif + } + + if (node->op == GGML_OP_CPY) { + // store the copy op parameter which changes with each token. + cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data)); + // store a pointer to each copy op CUDA kernel to identify it later + void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); + if (!ptr) { + use_cuda_graph = false; +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__); +#endif + } else { + if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) { + ggml_cuda_cpy_fn_ptrs.push_back(ptr); + } + } + } + + if (!use_cuda_graph) { + break; + } + } + + return use_cuda_graph; +} + static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { graph_node_properties->node_address = node->data; graph_node_properties->node_op = node->op; @@ -2451,6 +2485,7 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p for (int i = 0; i < GGML_MAX_SRC; i++) { graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; } + memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS); } static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { @@ -2482,146 +2517,119 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra return false; } } + + if (node->op == GGML_OP_SCALE && + memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { + return false; + } + return true; } -GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; +static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) { - ggml_cuda_set_device(cuda_ctx->device); + if (cuda_graph_update_required) { + // Extract nodes from graph + // First call with null argument gets number of nodes in graph + CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes)); + // Subsequent call with non-null argument gets nodes + cuda_ctx->cuda_graph->nodes.clear(); + cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes); + cuda_ctx->cuda_graph->params.clear(); + cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes); + if (cuda_ctx->cuda_graph->num_nodes > 0) { + CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes)); -#ifdef USE_CUDA_GRAPH - static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); - - // Objects required for CUDA Graph - if (cuda_ctx->cuda_graph == nullptr) { - cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); - } - - bool use_cuda_graph = true; - bool cuda_graph_update_required = false; - // vector of pointers to CUDA cpy kernels, which are required to identify - // kernel parameters which need updated in the graph for each token - std::vector ggml_cuda_cpy_fn_ptrs; - - if (cuda_ctx->cuda_graph->graph == nullptr) { - if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) { - cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; -#ifndef NDEBUG - GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to GPU architecture\n", __func__); -#endif - } - } - - // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly, - // or previous graph capture failure. - // Also disable for multi-gpu for now. TO DO investigate - if (disable_cuda_graphs_due_to_env - || cuda_ctx->cuda_graph->disable_due_to_gpu_arch - || cuda_ctx->cuda_graph->disable_due_to_too_many_updates - || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) { - use_cuda_graph = false; - } - - if (use_cuda_graph) { - if (cuda_ctx->cuda_graph->instance == nullptr) { - cuda_graph_update_required = true; - } - - // Check if the graph size has changed - if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { - cuda_graph_update_required = true; - cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes); - } - - // Loop over nodes in GGML graph to determine if CUDA graph update is required - // and store properties to allow this comparison for the next token - for (int i = 0; i < cgraph->n_nodes; i++) { - bool has_matching_properties = true; - if (!cuda_graph_update_required) { - has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); - } - if (!has_matching_properties) { - cuda_graph_update_required = true; - } - set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); - } - - // Loop over nodes in GGML graph to obtain info needed for CUDA graph - cuda_ctx->cuda_graph->updated_kernel_arg.clear(); - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor * node = cgraph->nodes[i]; - - if (node->src[0] && ggml_backend_buffer_is_cuda_split(node->src[0]->buffer)) { - use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture -#ifndef NDEBUG - GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to split buffer\n", __func__); -#endif - } - - if (node->op == GGML_OP_MUL_MAT_ID) { - use_cuda_graph = false; // This node type is not supported by CUDA graph capture -#ifndef NDEBUG - GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to mul_mat_id\n", __func__); -#endif - } - - if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) { - // disable CUDA graphs for batch size > 1 for now. - // Changes in batch size or context size can cause changes to the grid size of some kernels. - use_cuda_graph = false; -#ifndef NDEBUG - GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); -#endif - } - - if (node->op == GGML_OP_CPY) { - // store the copy op parameter which changes with each token. - cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data)); - // store a pointer to each copy op CUDA kernel to identify it later - void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); - if (!ptr) { - use_cuda_graph = false; -#ifndef NDEBUG - GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to unsupported copy op\n", __func__); -#endif - } else { - if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) { - ggml_cuda_cpy_fn_ptrs.push_back(ptr); + // Loop over nodes, and extract kernel parameters from each node + for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { + cudaGraphNodeType node_type; + CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type)); + if (node_type == cudaGraphNodeTypeKernel) { + cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime + if (stat == cudaErrorInvalidDeviceFunction) { + // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node. + // We don't need to update blas nodes, so clear error and move on. + (void)cudaGetLastError(); + } else { + GGML_ASSERT(stat == cudaSuccess); } } } - - if (!use_cuda_graph) { - break; + } + } else { + // One of the arguments to the copy kernel is updated for each token, hence we need to + // replace that argument with the updated value in the CUDA graph + // on update steps, the live parameters will already be captured + int k = 0; + for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { + if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) { + char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++); + cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr; + CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i])); } } - - // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. - if (use_cuda_graph && cuda_graph_update_required) { - cuda_ctx->cuda_graph->number_consecutive_updates++; - } else { - cuda_ctx->cuda_graph->number_consecutive_updates = 0; - } - - if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) { - cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true; -#ifndef NDEBUG - GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); -#endif - } } +} - if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture - CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); - } +static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { -#else - bool use_cuda_graph = false; bool cuda_graph_update_required = false; -#endif // USE_CUDA_GRAPH - bool graph_evaluated_or_captured = false; + if (cuda_ctx->cuda_graph->instance == nullptr) { + cuda_graph_update_required = true; + } + + // Check if the graph size has changed + if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { + cuda_graph_update_required = true; + cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes); + } + + // Loop over nodes in GGML graph to determine if CUDA graph update is required + // and store properties to allow this comparison for the next token + for (int i = 0; i < cgraph->n_nodes; i++) { + bool has_matching_properties = true; + if (!cuda_graph_update_required) { + has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + } + if (!has_matching_properties) { + cuda_graph_update_required = true; + } + set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + } + + return cuda_graph_update_required; +} + +static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { + + cudaGraphExecUpdateResultInfo result_info; +#ifdef __HIP_PLATFORM_AMD__ + hipGraphNode_t errorNode; + hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info); +#else + cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); +#endif + if (stat == cudaErrorGraphExecUpdateFailure) { +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__); +#endif + + // The pre-existing graph exec cannot be updated due to violated constraints + // so instead clear error and re-instantiate + (void)cudaGetLastError(); + CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance)); + cuda_ctx->cuda_graph->instance = nullptr; + CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + } else { + GGML_ASSERT(stat == cudaSuccess); + } +} +#endif + +static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, + [[maybe_unused]] std::vector & ggml_cuda_cpy_fn_ptrs, bool & graph_evaluated_or_captured, bool & use_cuda_graph, + bool & cuda_graph_update_required) { while (!graph_evaluated_or_captured) { // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. @@ -2639,14 +2647,15 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t for (int j = 0; j < GGML_MAX_SRC; j++) { if (node->src[j] != nullptr) { assert(node->src[j]->buffer); - assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer)); + assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || + ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft)); } } #endif bool ok = ggml_cuda_compute_forward(*cuda_ctx, node); if (!ok) { - GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); } GGML_ASSERT(ok); } @@ -2658,19 +2667,8 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph)); cuda_ctx->cuda_graph->graph = nullptr; } - CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph)); -#if 0 - if (disable_cuda_graphs_due_to_failed_capture) { - use_cuda_graph = false; - cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true; -#ifndef NDEBUG - GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to failed graph capture\n", __func__); -#endif - } else { - graph_evaluated_or_captured = true; // CUDA graph has been captured - } -#endif + CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph)); graph_evaluated_or_captured = true; // CUDA graph has been captured } else { graph_evaluated_or_captured = true; // ggml graph has been directly evaluated @@ -2683,80 +2681,289 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t } // Perform update to graph (if required for this token), and change copy parameter (required for every token) - - if (cuda_graph_update_required) { - // Extract nodes from graph - // First call with null argument gets number of nodes in graph - CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes)); - // Subsequent call with non-null argument gets nodes - cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes); - cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes); - if (cuda_ctx->cuda_graph->num_nodes > 0) { - CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes)); - - // Loop over nodes, and extract kernel parameters from each node - for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { - cudaGraphNodeType node_type; - CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type)); - if (node_type == cudaGraphNodeTypeKernel) { - cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime - if (stat == cudaErrorInvalidDeviceFunction) { - // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node. - // We don't need to update blas nodes, so clear error and move on. - cudaGetLastError(); - } else { - GGML_ASSERT(stat == cudaSuccess); - } - } - } - } - } - - // One of the arguments to the copy kernel is updated for each token, hence we need to - // replace that argument with the updated value in the CUDA graph - if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured - int k = 0; - for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { - if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) { - char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++); - cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr; - CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i])); - } - } - } + maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required); // Update graph executable - cudaGraphExecUpdateResultInfo result_info; - cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); - if (stat == cudaErrorGraphExecUpdateFailure) { -#ifndef NDEBUG - GGML_CUDA_LOG_ERROR("%s: CUDA graph update failed\n", __func__); -#endif - // The pre-existing graph exec cannot be updated due to violated constraints - // so instead clear error and re-instantiate - cudaGetLastError(); - CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance)); - cuda_ctx->cuda_graph->instance = nullptr; - CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); - } else { - GGML_ASSERT(stat == cudaSuccess); - } + update_cuda_graph_executable(cuda_ctx); + // Launch graph CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); #else graph_evaluated_or_captured = true; -#endif // USE_CUDA_GRAPH +#endif // USE_CUDA_GRAPH } +} + +static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + + ggml_cuda_set_device(cuda_ctx->device); + + // vector of pointers to CUDA cpy kernels, which are required to identify + // kernel parameters which need updated in the graph for each token + std::vector ggml_cuda_cpy_fn_ptrs; + +#ifdef USE_CUDA_GRAPH + static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); + + // Objects required for CUDA Graph + if (cuda_ctx->cuda_graph == nullptr) { + cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); + } + + bool use_cuda_graph = true; + bool cuda_graph_update_required = false; + + if (cuda_ctx->cuda_graph->graph == nullptr) { + if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) { + cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); +#endif + } + } + + // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly, + // or previous graph capture failure. + // Also disable for multi-gpu for now. TO DO investigate + if (disable_cuda_graphs_due_to_env + || cuda_ctx->cuda_graph->disable_due_to_gpu_arch + || cuda_ctx->cuda_graph->disable_due_to_too_many_updates + || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) { + use_cuda_graph = false; + } + + if (use_cuda_graph) { + cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); + + use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, + ggml_cuda_cpy_fn_ptrs, use_cuda_graph); + + // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. + if (use_cuda_graph && cuda_graph_update_required) { + cuda_ctx->cuda_graph->number_consecutive_updates++; + } else { + cuda_ctx->cuda_graph->number_consecutive_updates = 0; + } + + if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) { + cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true; +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); +#endif + } + } + + if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture + CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); + } + +#else + bool use_cuda_graph = false; + bool cuda_graph_update_required = false; +#endif // USE_CUDA_GRAPH + + bool graph_evaluated_or_captured = false; + + evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); return GGML_STATUS_SUCCESS; } -GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; +static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + + CUDA_CHECK(cudaEventRecord((cudaEvent_t)event->context, cuda_ctx->stream())); +} + +static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + + if (ggml_backend_is_cuda(backend)) { + CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), (cudaEvent_t)event->context, 0)); + } else { +#if 0 + // untested + auto wait_fn = [](void * user_data) { + ggml_backend_event_t event = (ggml_backend_event_t)user_data; + ggml_backend_event_synchronize(event); + }; + + CUDA_CHECK(cudaLaunchHostFunc(cuda_ctx->stream(), wait_fn, event)); +#endif + GGML_ABORT("fatal error"); + } +} + +static const ggml_backend_i ggml_backend_cuda_interface = { + /* .get_name = */ ggml_backend_cuda_get_name, + /* .free = */ ggml_backend_cuda_free, + /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async, + /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async, + /* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async, + /* .synchronize = */ ggml_backend_cuda_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_cuda_graph_compute, + /* .event_record = */ ggml_backend_cuda_event_record, + /* .event_wait = */ ggml_backend_cuda_event_wait, +}; + +static ggml_guid_t ggml_backend_cuda_guid() { + static ggml_guid guid = { 0x2c, 0xdd, 0xe8, 0x1c, 0x65, 0xb3, 0x65, 0x73, 0x6a, 0x12, 0x88, 0x61, 0x1c, 0xc9, 0xdc, 0x25 }; + return &guid; +} + +bool ggml_backend_is_cuda(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cuda_guid()); +} + +int ggml_backend_cuda_get_device_count() { + return ggml_cuda_info().device_count; +} + +void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size) { + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); + snprintf(description, description_size, "%s", prop.name); +} + +void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total) { + ggml_cuda_set_device(device); + + CUDA_CHECK(cudaMemGetInfo(free, total)); +} + +bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) { + if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) { + return false; + } + +#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA) + cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly); + if (err != cudaSuccess) { + // clear the error + (void)cudaGetLastError(); + + GGML_LOG_DEBUG("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__, + size / 1024.0 / 1024.0, cudaGetErrorString(err)); + return false; + } + return true; +#else + return false; +#endif +} + +void ggml_backend_cuda_unregister_host_buffer(void * buffer) { + if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) { + return; + } + + cudaError_t err = cudaHostUnregister(buffer); + if (err != cudaSuccess) { + // clear the error + (void)cudaGetLastError(); + } +} + + +// backend device + +struct ggml_backend_cuda_device_context { + int device; + std::string name; + std::string description; +}; + +static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + return ctx->name.c_str(); +} + +static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + return ctx->description.c_str(); +} + +static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaMemGetInfo(free, total)); +} + +static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) { + GGML_UNUSED(dev); + return GGML_BACKEND_DEVICE_TYPE_GPU; +} + +static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + props->name = ggml_backend_cuda_device_get_name(dev); + props->description = ggml_backend_cuda_device_get_description(dev); + props->type = ggml_backend_cuda_device_get_type(dev); + ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total); + + bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; +#ifdef GGML_CUDA_NO_PEER_COPY + bool events = false; +#else + bool events = true; +#endif + + props->caps = { + /* .async = */ true, + /* .host_buffer = */ host_buffer, + /* .buffer_from_host_ptr = */ false, + /* .events = */ events, + }; +} + +static ggml_backend_t ggml_backend_cuda_device_init_backend(ggml_backend_dev_t dev, const char * params) { + GGML_UNUSED(params); + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + return ggml_backend_cuda_init(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_cuda_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + return ggml_backend_cuda_buffer_type(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_cuda_device_get_host_buffer_type(ggml_backend_dev_t dev) { + GGML_UNUSED(dev); + return ggml_backend_cuda_host_buffer_type(); +} + +// TODO: move these functions here +static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context; + + // split buffers can only be used with GGML_OP_MUL_MAT + if (op->op != GGML_OP_MUL_MAT) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (op->src[i] && op->src[i]->buffer && ggml_backend_buft_is_cuda_split(op->src[i]->buffer->buft)) { + return false; + } + } + } + + // check if all the sources are allocated on this device + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (op->src[i] && op->src[i]->buffer && ggml_backend_buft_is_cuda(op->src[i]->buffer->buft)) { + ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)op->src[i]->buffer->buft->context; + if (buft_ctx->device != dev_ctx->device) { + return false; + } + } + } + switch (op->op) { case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_STEP: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: @@ -2765,6 +2972,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_EXP: return ggml_is_contiguous(op->src[0]); default: return false; @@ -2775,12 +2983,29 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons { struct ggml_tensor * a = op->src[0]; struct ggml_tensor * b = op->src[1]; + // for small weight matrices the active device can end up without any rows, don't use row split in those cases + // this avoids some edge cases (and the performance would not be good anyways) + if (a->buffer && ggml_backend_buft_is_cuda_split(a->buffer->buft)) { + ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) a->buffer->buft->context; + int64_t row_low; + int64_t row_high; + get_row_split(&row_low, &row_high, a, buft_ctx->tensor_split, dev_ctx->device); + if (row_low == row_high) { + return false; + } + } if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) { return false; } if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) { return false; } +#ifdef GGML_USE_MUSA + if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 && + !ggml_is_transposed(a) && !ggml_is_transposed(b)) { + return false; + } +#endif // GGML_USE_MUSA switch (a->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: @@ -2804,11 +3029,19 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_BF16: +#ifdef GGML_USE_MUSA + if (a->type == GGML_TYPE_Q3_K) { + return false; + } +#endif // GGML_USE_MUSA return true; default: return false; } } break; + case GGML_OP_OUT_PROD: + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_GET_ROWS: { switch (op->src[0]->type) { @@ -2824,6 +3057,10 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons return false; } } break; + case GGML_OP_GET_ROWS_BACK: + { + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1; + } break; case GGML_OP_CPY: { ggml_type src0_type = op->src[0]->type; @@ -2837,6 +3074,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) { return true; } + if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) { + return true; + } if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) { return true; } @@ -2864,7 +3104,22 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons return false; } break; case GGML_OP_DUP: + { + ggml_type src0_type = op->src[0]->type; + return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; + } break; + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: + { + return true; + } break; case GGML_OP_REPEAT: + { + ggml_type src0_type = op->src[0]->type; + return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; + } break; + case GGML_OP_REPEAT_BACK: + return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15); case GGML_OP_CONCAT: { ggml_type src0_type = op->src[0]->type; @@ -2879,18 +3134,24 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons } return false; } break; + case GGML_OP_SILU_BACK: + return ggml_is_contiguous(op->src[0]); + break; + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0; + break; case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: - case GGML_OP_NORM: case GGML_OP_ADD: case GGML_OP_ADD1: case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: - case GGML_OP_RMS_NORM: case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SQRT: @@ -2903,10 +3164,18 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: return true; + case GGML_OP_SOFT_MAX_BACK: { + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float)); + return max_bias == 0.0f; + } case GGML_OP_ROPE: - return ggml_is_contiguous(op->src[0]); + case GGML_OP_ROPE_BACK: { + const size_t ts = ggml_type_size(op->src[0]->type); + const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2]; + return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts; + } case GGML_OP_IM2COL: - return op->src[0]->type == GGML_TYPE_F16; case GGML_OP_POOL_2D: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: @@ -2918,224 +3187,275 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_LEAKY_RELU: + case GGML_OP_RWKV_WKV6: + case GGML_OP_GATED_LINEAR_ATTN: return true; - case GGML_OP_FLASH_ATTN_EXT: -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128; -#else - if (op->src[0]->ne[0] == 128) { - return true; + case GGML_OP_FLASH_ATTN_EXT: { +#ifndef FLASH_ATTN_AVAILABLE + return false; +#endif + if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) { + return false; } if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) { return true; } - return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA && - op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; + if (op->src[0]->ne[0] == 128) { + return true; + } + if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) { + return true; + } + const int cc = ggml_cuda_info().devices[dev_ctx->device].cc; + return cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; + } case GGML_OP_CROSS_ENTROPY_LOSS: + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + case GGML_OP_OPT_STEP_ADAMW: return true; -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) default: return false; } - - GGML_UNUSED(backend); } -GGML_CALL static bool ggml_backend_cuda_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { - if (ggml_backend_buft_is_cuda_split(buft)) { - return true; - } - - if (ggml_backend_buft_is_cuda(buft)) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; - ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context; - return buft_ctx->device == cuda_ctx->device; - } - - return false; +static bool ggml_backend_cuda_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + return (ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev; } -GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) { +static int64_t get_op_batch_size(const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_GET_ROWS: + return 0; + case GGML_OP_MUL_MAT: + return op->ne[1]; + case GGML_OP_MUL_MAT_ID: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + return op->ne[2]; + default: + return ggml_nrows(op); + } +} + +static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { const int min_batch_size = 32; - return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) || - (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID); + return get_op_batch_size(op) >= min_batch_size; - GGML_UNUSED(backend); + GGML_UNUSED(dev); } -static ggml_backend_event_t ggml_backend_cuda_event_new(ggml_backend_t backend) { +static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_t dev) { #ifdef GGML_CUDA_NO_PEER_COPY return nullptr; #else - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *)dev->context; - ggml_cuda_set_device(cuda_ctx->device); + ggml_cuda_set_device(dev_ctx->device); cudaEvent_t event; CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); return new ggml_backend_event { - /* .backend = */ backend, + /* .device = */ dev, /* .context = */ event, }; #endif } -static void ggml_backend_cuda_event_free(ggml_backend_event_t event) { - CUDA_CHECK(cudaEventDestroy((cudaEvent_t)event->context)); +static void ggml_backend_cuda_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) { + GGML_UNUSED(dev); + CUDA_CHECK(cudaEventDestroy((cudaEvent_t)event->context)); delete event; } -static void ggml_backend_cuda_event_record(ggml_backend_event_t event) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)event->backend->context; - - CUDA_CHECK(cudaEventRecord((cudaEvent_t)event->context, cuda_ctx->stream())); -} - -static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; - - if (ggml_backend_is_cuda(event->backend)) { - CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), (cudaEvent_t)event->context, 0)); - } else { -#if 0 - // untested - auto wait_fn = [](void * user_data) { - ggml_backend_event_t event = (ggml_backend_event_t)user_data; - ggml_backend_event_synchronize(event); - }; - - CUDA_CHECK(cudaLaunchHostFunc(cuda_ctx->stream(), wait_fn, event)); -#endif - GGML_ABORT("fatal error"); - } -} - -static void ggml_backend_cuda_event_synchronize(ggml_backend_event_t event) { +static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) { + GGML_UNUSED(dev); CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context)); } -static ggml_backend_i ggml_backend_cuda_interface = { - /* .get_name = */ ggml_backend_cuda_name, - /* .free = */ ggml_backend_cuda_free, - /* .get_default_buffer_type = */ ggml_backend_cuda_get_default_buffer_type, - /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async, - /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async, - /* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async, - /* .synchronize = */ ggml_backend_cuda_synchronize, - /* .graph_plan_create = */ NULL, - /* .graph_plan_free = */ NULL, - /* .graph_plan_update = */ NULL, - /* .graph_plan_compute = */ NULL, - /* .graph_compute = */ ggml_backend_cuda_graph_compute, - /* .supports_op = */ ggml_backend_cuda_supports_op, - /* .supports_buft = */ ggml_backend_cuda_supports_buft, - /* .offload_op = */ ggml_backend_cuda_offload_op, - /* .event_new = */ ggml_backend_cuda_event_new, - /* .event_free = */ ggml_backend_cuda_event_free, - /* .event_record = */ ggml_backend_cuda_event_record, - /* .event_wait = */ ggml_backend_cuda_event_wait, - /* .event_synchronize = */ ggml_backend_cuda_event_synchronize, +static const ggml_backend_device_i ggml_backend_cuda_device_interface = { + /* .get_name = */ ggml_backend_cuda_device_get_name, + /* .get_description = */ ggml_backend_cuda_device_get_description, + /* .get_memory = */ ggml_backend_cuda_device_get_memory, + /* .get_type = */ ggml_backend_cuda_device_get_type, + /* .get_props = */ ggml_backend_cuda_device_get_props, + /* .init_backend = */ ggml_backend_cuda_device_init_backend, + /* .get_buffer_type = */ ggml_backend_cuda_device_get_buffer_type, + /* .get_host_buffer_type = */ ggml_backend_cuda_device_get_host_buffer_type, + /* .buffer_from_host_ptr = */ NULL, + /* .supports_op = */ ggml_backend_cuda_device_supports_op, + /* .supports_buft = */ ggml_backend_cuda_device_supports_buft, + /* .offload_op = */ ggml_backend_cuda_device_offload_op, + /* .event_new = */ ggml_backend_cuda_device_event_new, + /* .event_free = */ ggml_backend_cuda_device_event_free, + /* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize, }; -static ggml_guid_t ggml_backend_cuda_guid() { - static ggml_guid guid = { 0x2c, 0xdd, 0xe8, 0x1c, 0x65, 0xb3, 0x65, 0x73, 0x6a, 0x12, 0x88, 0x61, 0x1c, 0xc9, 0xdc, 0x25 }; - return &guid; +// backend reg + +struct ggml_backend_cuda_reg_context { + std::vector devices; +}; + +static const char * ggml_backend_cuda_reg_get_name(ggml_backend_reg_t reg) { + GGML_UNUSED(reg); + return GGML_CUDA_NAME; } -GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) { +static size_t ggml_backend_cuda_reg_get_device_count(ggml_backend_reg_t reg) { + ggml_backend_cuda_reg_context * ctx = (ggml_backend_cuda_reg_context *)reg->context; + return ctx->devices.size(); +} + +static ggml_backend_dev_t ggml_backend_cuda_reg_get_device(ggml_backend_reg_t reg, size_t index) { + ggml_backend_cuda_reg_context * ctx = (ggml_backend_cuda_reg_context *)reg->context; + GGML_ASSERT(index < ctx->devices.size()); + return ctx->devices[index]; +} + +static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t reg) { + static std::vector features = []() { + std::vector features; + #define _STRINGIFY(...) #__VA_ARGS__ + #define STRINGIFY(...) _STRINGIFY(__VA_ARGS__) + + #ifdef __CUDA_ARCH_LIST__ + features.push_back({ "ARCHS", STRINGIFY(__CUDA_ARCH_LIST__) }); + #endif + + #ifdef GGML_CUDA_FORCE_MMQ + features.push_back({ "FORCE_MMQ", "1" }); + #endif + + #ifdef GGML_CUDA_FORCE_CUBLAS + features.push_back({ "FORCE_CUBLAS", "1" }); + #endif + + #ifndef GGML_USE_VMM + features.push_back({ "NO_VMM", "1" }); + #endif + + #ifdef GGML_CUDA_NO_PEER_COPY + features.push_back({ "NO_PEER_COPY", "1" }); + #endif + + #ifdef GGML_CUDA_F16 + features.push_back({ "F16", "1" }); + #endif + + #ifdef GGML_CUDA_USE_GRAPHS + features.push_back({ "USE_GRAPHS", "1" }); + #endif + + #ifdef GGML_CUDA_PEER_MAX_BATCH_SIZE + features.push_back({ "PEER_MAX_BATCH_SIZE", STRINGIFY(GGML_CUDA_PEER_MAX_BATCH_SIZE) }); + #endif + + #ifdef GGML_CUDA_FA_ALL_QUANTS + features.push_back({ "FA_ALL_QUANTS", "1" }); + #endif + + #undef _STRINGIFY + #undef STRINGIFY + + features.push_back({ nullptr, nullptr }); + + return features; + }(); + + return features.data(); + + GGML_UNUSED(reg); +} + +static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) { + GGML_UNUSED(reg); + if (strcmp(name, "ggml_backend_split_buffer_type") == 0) { + return (void *)ggml_backend_cuda_split_buffer_type; + } + if (strcmp(name, "ggml_backend_register_host_buffer") == 0) { + return (void *)ggml_backend_cuda_register_host_buffer; + } + if (strcmp(name, "ggml_backend_unregister_host_buffer") == 0) { + return (void *)ggml_backend_cuda_unregister_host_buffer; + } + if (strcmp(name, "ggml_backend_get_features") == 0) { + return (void *)ggml_backend_cuda_get_features; + } + return nullptr; +} + +static const ggml_backend_reg_i ggml_backend_cuda_reg_interface = { + /* .get_name = */ ggml_backend_cuda_reg_get_name, + /* .get_device_count = */ ggml_backend_cuda_reg_get_device_count, + /* .get_device = */ ggml_backend_cuda_reg_get_device, + /* .get_proc_address = */ ggml_backend_cuda_reg_get_proc_address, +}; + +// backend registry +ggml_backend_reg_t ggml_backend_cuda_reg() { + static ggml_backend_reg reg; + static bool initialized = false; + + { + static std::mutex mutex; + std::lock_guard lock(mutex); + if (!initialized) { + ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context; + + for (int i = 0; i < ggml_cuda_info().device_count; i++) { + ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context; + dev_ctx->device = i; + dev_ctx->name = GGML_CUDA_NAME + std::to_string(i); + + ggml_cuda_set_device(i); + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); + dev_ctx->description = prop.name; + + ggml_backend_dev_t dev = new ggml_backend_device { + /* .iface = */ ggml_backend_cuda_device_interface, + /* .reg = */ ®, + /* .context = */ dev_ctx + }; + ctx->devices.push_back(dev); + } + + reg = ggml_backend_reg { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_cuda_reg_interface, + /* .context = */ ctx + }; + } + + initialized = true; + } + + return ® +} + +ggml_backend_t ggml_backend_cuda_init(int device) { if (device < 0 || device >= ggml_backend_cuda_get_device_count()) { - GGML_CUDA_LOG_ERROR("%s: invalid device %d\n", __func__, device); + GGML_LOG_ERROR("%s: invalid device %d\n", __func__, device); return nullptr; } ggml_backend_cuda_context * ctx = new ggml_backend_cuda_context(device); if (ctx == nullptr) { - GGML_CUDA_LOG_ERROR("%s: failed to allocate context\n", __func__); + GGML_LOG_ERROR("%s: failed to allocate context\n", __func__); return nullptr; } ggml_backend_t cuda_backend = new ggml_backend { /* .guid = */ ggml_backend_cuda_guid(), /* .interface = */ ggml_backend_cuda_interface, - /* .context = */ ctx + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device), + /* .context = */ ctx, }; return cuda_backend; } -GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend) { - return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cuda_guid()); -} - -GGML_CALL int ggml_backend_cuda_get_device_count() { - return ggml_cuda_info().device_count; -} - -GGML_CALL void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size) { - cudaDeviceProp prop; - CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); - snprintf(description, description_size, "%s", prop.name); -} - -GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total) { - ggml_cuda_set_device(device); - - CUDA_CHECK(cudaMemGetInfo(free, total)); -} - -GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) { - if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) { - return false; - } - -#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA) - cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly); - if (err != cudaSuccess) { - // clear the error - cudaGetLastError(); - - GGML_CUDA_LOG_WARN("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__, - size / 1024.0 / 1024.0, cudaGetErrorString(err)); - return false; - } - return true; -#else - return false; -#endif -} - -GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer) { - if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) { - return; - } - - cudaError_t err = cudaHostUnregister(buffer); - if (err != cudaSuccess) { - // clear the error - cudaGetLastError(); - } -} - -// backend registry -GGML_CALL static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * user_data) { - ggml_backend_t cuda_backend = ggml_backend_cuda_init((int) (intptr_t) user_data); - return cuda_backend; - - GGML_UNUSED(params); -} - -extern "C" GGML_CALL int ggml_backend_cuda_reg_devices(); - -GGML_CALL int ggml_backend_cuda_reg_devices() { - int device_count = ggml_backend_cuda_get_device_count(); - //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization - for (int i = 0; i < device_count; i++) { - char name[128]; - snprintf(name, sizeof(name), "%s%d", GGML_CUDA_NAME, i); - ggml_backend_register(name, ggml_backend_reg_cuda_init, ggml_backend_cuda_buffer_type(i), (void *) (intptr_t) i); - } - return device_count; -} +GGML_BACKEND_DL_IMPL(ggml_backend_cuda_reg) diff --git a/ggml/src/ggml-cuda/gla.cu b/ggml/src/ggml-cuda/gla.cu new file mode 100644 index 000000000..f7d615a82 --- /dev/null +++ b/ggml/src/ggml-cuda/gla.cu @@ -0,0 +1,93 @@ +#include "common.cuh" +#include "gla.cuh" + +template +static __global__ void gated_linear_attn_f32(const int B, const int T, const int C, const int H, const float scale, + const float * k, const float * v, const float * r, const float * td, const float * s, float * dst) { + const int tid = threadIdx.x; + const int bid = blockIdx.x; + + const int head_size = HEAD_SIZE; + const int batch_i = bid / H; + const int head_i = bid % H; + const int state_size = C * head_size; + const int n_seq_tokens = T / B; + + float state[head_size]; + __shared__ float _k[head_size], _r[head_size], _td[head_size]; + + #pragma unroll + for (int i = 0; i < head_size; i++) { + state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid]; + } + + for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) { + __syncthreads(); + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + __syncthreads(); + + const float _v = v[t]; + float y = 0; + for (int j = 0; j < head_size; j += 4) { + const float4 & k = (float4 &)(_k[j]); + const float4 & r = (float4 &)(_r[j]); + const float4 & td = (float4 &)(_td[j]); + float4 & s = (float4 &)(state[j]); + float4 kv; + + kv.x = k.x * _v; + kv.y = k.y * _v; + kv.z = k.z * _v; + kv.w = k.w * _v; + + s.x = s.x * td.x + kv.x; + s.y = s.y * td.y + kv.y; + s.z = s.z * td.z + kv.z; + s.w = s.w * td.w + kv.w; + + y += r.x * s.x; + y += r.y * s.y; + y += r.z * s.z; + y += r.w * s.w; + } + dst[t] = y * scale; + } + + #pragma unroll + for (int i = 0; i < head_size; i++) { + dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i]; + } +} + +void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const float * k_d = (const float *)dst->src[0]->data; + const float * v_d = (const float *)dst->src[1]->data; + const float * r_d = (const float *)dst->src[2]->data; + const float * td_d = (const float *)dst->src[3]->data; + const float * s_d = (const float *)dst->src[4]->data; + + const int64_t B = dst->src[4]->ne[1]; + const int64_t T = dst->src[0]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[1]; + + float scale; + memcpy(&scale, (float*)dst->op_params, sizeof(float)); + + float * dst_d = (float *)dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == 64 || C / H == 128); + + + if (C / H == 64) { + gated_linear_attn_f32<64><<>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d); + } else { + gated_linear_attn_f32<128><<>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d); + } +} diff --git a/ggml/src/ggml-cuda/gla.cuh b/ggml/src/ggml-cuda/gla.cuh new file mode 100644 index 000000000..2c82ad7dd --- /dev/null +++ b/ggml/src/ggml-cuda/gla.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu index 3d0d8d4e6..86a54e42b 100644 --- a/ggml/src/ggml-cuda/im2col.cu +++ b/ggml/src/ggml-cuda/im2col.cu @@ -69,7 +69,6 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); - GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); @@ -92,9 +91,9 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int64_t OH = is_2D ? dst->ne[2] : 1; const int64_t OW = dst->ne[1]; - const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 - const int64_t batch = src1->ne[3]; - const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32 + const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 + const int64_t batch = src1->ne[is_2D ? 3 : 2]; + const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 if(dst->type == GGML_TYPE_F16) { im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index a452a3cc3..7d11540af 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -171,7 +171,7 @@ struct mma_int_C_I16J8 { __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) { #ifdef INT8_MMA_AVAILABLE -#if __CUDA_ARCH__ >= CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0])); @@ -183,7 +183,7 @@ struct mma_int_C_I16J8 { asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" : "+r"(x[2]), "+r"(x[3]) : "r"(mma_A.x[1]), "r"(mma_B.x[0])); -#endif // __CUDA_ARCH__ >= CC_AMPERE +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else GGML_UNUSED(mma_A); GGML_UNUSED(mma_B); @@ -193,7 +193,7 @@ struct mma_int_C_I16J8 { __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) { #ifdef INT8_MMA_AVAILABLE -#if __CUDA_ARCH__ >= CC_AMPERE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1])); @@ -211,7 +211,7 @@ struct mma_int_C_I16J8 { asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" : "+r"(x[2]), "+r"(x[3]) : "r"(mma_A.x[3]), "r"(mma_B.x[1])); -#endif // __CUDA_ARCH__ >= CC_AMPERE +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #else GGML_UNUSED(mma_A); GGML_UNUSED(mma_B); diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 78d70cd7a..270251df4 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -8,8 +8,6 @@ void ggml_cuda_op_mul_mat_q( const int64_t ne00 = src0->ne[0]; - const int64_t nb01 = src0->nb[1]; - const int64_t ne10 = src1->ne[0]; const int64_t ne11 = src1->ne[1]; GGML_ASSERT(ne10 % QK8_1 == 0); @@ -17,7 +15,7 @@ void ggml_cuda_op_mul_mat_q( const int64_t ne0 = dst->ne[0]; const int64_t row_diff = row_high - row_low; - const int64_t stride00 = nb01 / ggml_type_size(src0->type); + const int64_t stride00 = ne00 / ggml_blck_size(src0->type); int id = ggml_cuda_get_device(); const int compute_capability = ggml_cuda_info().devices[id].cc; @@ -26,7 +24,11 @@ void ggml_cuda_op_mul_mat_q( // nrows_dst == nrows of the matrix that the kernel writes into const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; - const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst}; + // The stream-k decomposition is only faster for recent NVIDIA GPUs. + // Also its fixup needs to allocate a temporary buffer in the memory pool. + // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer. + const bool use_stream_k = compute_capability >= GGML_CUDA_CC_VOLTA && compute_capability < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11; + const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k}; switch (src0->type) { case GGML_TYPE_Q4_0: @@ -134,7 +136,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return true; } - if (cc < MIN_CC_DP4A) { + if (cc < GGML_CUDA_CC_DP4A) { return false; } @@ -142,9 +144,9 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return true; #endif //GGML_CUDA_FORCE_MMQ - if (cc < CC_OFFSET_AMD) { - return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + if (cc < GGML_CUDA_CC_OFFSET_AMD) { + return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } - return cc < CC_RDNA3 || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + return (cc < GGML_CUDA_CC_RDNA3 && cc != GGML_CUDA_CC_CDNA && cc != GGML_CUDA_CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index e8a957447..3cd508a1d 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -89,9 +89,9 @@ struct tile_x_sizes { static constexpr int get_mmq_x_max_host(const int cc) { return int8_mma_available(cc) ? 128 : #ifdef GGML_CUDA_FORCE_MMQ - cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64; + cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128 : 64; #else - cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64; + cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64; #endif // GGML_CUDA_FORCE_MMQ } @@ -100,43 +100,43 @@ static constexpr __device__ int get_mmq_x_max_device() { return 128; #else // INT8_MMA_AVAILABLE -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) return 128; -#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) -#if __CUDA_ARCH__ >= CC_VOLTA +#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #ifdef GGML_CUDA_FORCE_MMQ return MMQ_DP4A_MAX_BATCH_SIZE; #else // GGML_CUDA_FORCE_MMQ return 128; #endif // GGML_CUDA_FORCE_MMQ -#else // __CUDA_ARCH__ >= CC_VOLTA +#else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA return 64; -#endif // __CUDA_ARCH__ >= CC_VOLTA +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #endif // INT8_MMA_AVAILABLE } static constexpr int get_mmq_y_host(const int cc) { - return cc >= CC_OFFSET_AMD ? (cc == CC_RDNA1 ? 64 : 128) : (cc >= CC_VOLTA ? 128 : 64); + return cc >= GGML_CUDA_CC_OFFSET_AMD ? (cc == GGML_CUDA_CC_RDNA1 ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64); } static constexpr __device__ int get_mmq_y_device() { -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #if defined(RDNA1) return 64; #else return 128; #endif // defined RDNA1 #else -#if __CUDA_ARCH__ >= CC_VOLTA +#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA return 128; #else return 64; -#endif // __CUDA_ARCH__ >= CC_VOLTA -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) } #define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} @@ -2569,17 +2569,17 @@ static __device__ void mul_mat_q_process_tile( // The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598 template -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) __launch_bounds__(WARP_SIZE*nwarps, 2) -#endif // defined(RDNA3) || defined(RDNA2) +#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) #else -#if __CUDA_ARCH__ >= CC_VOLTA +#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA __launch_bounds__(WARP_SIZE*nwarps, 1) #else __launch_bounds__(WARP_SIZE*nwarps, 2) -#endif // __CUDA_ARCH__ >= CC_VOLTA -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) static __global__ void mul_mat_q( const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup, const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) { @@ -2594,7 +2594,7 @@ static __global__ void mul_mat_q( constexpr int mmq_y = get_mmq_y_device(); // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: -#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA +#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA { constexpr bool fixup = false; mul_mat_q_process_tile @@ -2602,7 +2602,7 @@ static __global__ void mul_mat_q( blockIdx.x, blockIdx.y, 0, ne00/qk); return; } -#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA +#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA const int64_t blocks_per_ne00 = ne00 / qk; constexpr int blocks_per_iter = MMQ_ITER_K / qk; @@ -2742,6 +2742,7 @@ struct mmq_args { int64_t ne00; int64_t ne01; int64_t stride01; int64_t ne10; int64_t ne11; int64_t stride11; int64_t ne0; + bool use_stream_k; }; template @@ -2764,21 +2765,20 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const int shmem = mmq_get_shmem(mmq_x, mmq_y, cc); -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; if (!shmem_limit_raised[id]) { CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); shmem_limit_raised[id] = true; } -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) const int nty = (args.ne01 + mmq_y - 1) / mmq_y; const int ntx = (args.ne11 + mmq_x - 1) / mmq_x; const dim3 block_nums_xy_tiling(nty, ntx, 1); - const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD; - if (!use_stream_k) { + if (!args.use_stream_k) { if (args.ne01 % mmq_y == 0) { constexpr bool need_check = false; mul_mat_q<<>> @@ -2825,7 +2825,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda const int mmq_x_max = get_mmq_x_max_host(cc); const int mmq_y = get_mmq_y_host(cc); const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y; - const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD; + const bool use_stream_k = cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD; int mmq_x_best = 0; int nparts_best = INT_MAX; diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu new file mode 100644 index 000000000..ac45f2d17 --- /dev/null +++ b/ggml/src/ggml-cuda/mmv.cu @@ -0,0 +1,261 @@ +#include "common.cuh" +#include "mmv.cuh" + +template +static __global__ void mul_mat_vec( + const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row, + const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) { + const int64_t row = blockIdx.x; + const int64_t channel = blockIdx.z; + const int tid = threadIdx.x; + + x += (channel/channel_ratio)*stride_channel_x + row*stride_row; + y += channel *stride_channel_y; + dst += channel *stride_channel_dst; + + const float2 * y2 = (const float2 *) y; + + extern __shared__ char data_mmv[]; + float * buf_iw = (float *) data_mmv; + + if (block_size > WARP_SIZE) { + if (tid < WARP_SIZE) { + buf_iw[tid] = 0.0f; + } + __syncthreads(); + } + + float sumf; + + if constexpr (std::is_same::value) { + const half2 * x2 = (const half2 *) x; + + if (std::is_same::value) { + sumf = 0.0f; + + for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + const float2 tmpx = __half22float2(x2[col2]); + const float2 tmpy = y2[col2]; + sumf += tmpx.x * tmpy.x; + sumf += tmpx.y * tmpy.y; + } + } else { +#ifdef FP16_AVAILABLE + half2 sumh2 = make_half2(0.0f, 0.0f); + + for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + const float2 tmp = y2[col2]; + sumh2 += x2[col2] * make_half2(tmp.x, tmp.y); + } + + sumf = __low2float(sumh2) + __high2float(sumh2); +#else + NO_DEVICE_CODE; +#endif // FP16_AVAILABLE + } + } else if constexpr (std::is_same::value) { + const int * x2 = (const int *) x; + sumf = 0.0f; + + for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + const int tmpx = x2[col2]; + const float2 tmpy = y2[col2]; + sumf += float(reinterpret_cast(&tmpx)[0]) * tmpy.x; + sumf += float(reinterpret_cast(&tmpx)[1]) * tmpy.y; + } + } else { + static_assert(std::is_same::value, "unsupported type"); + } + + sumf = warp_reduce_sum(sumf); + + if (block_size > WARP_SIZE) { + buf_iw[tid/WARP_SIZE] = sumf; + __syncthreads(); + if (tid >= WARP_SIZE) { + return; + } + sumf = buf_iw[tid]; + sumf = warp_reduce_sum(sumf); + } + + if (tid != 0) { + return; + } + + dst[row] = sumf; +} + +template +static void launch_mul_mat_vec_cuda( + const T * x, const float * y, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, + cudaStream_t stream) { + GGML_ASSERT(ncols % 2 == 0); + GGML_ASSERT(stride_row % 2 == 0); + GGML_ASSERT(nchannels_y % nchannels_x == 0); + const int64_t channel_ratio = nchannels_y / nchannels_x; + + int64_t block_size_best = WARP_SIZE; + int64_t niter_best = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE); + for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) { + const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size); + if (niter < niter_best) { + niter_best = niter; + block_size_best = block_size; + } + } + + const int smem = WARP_SIZE*sizeof(float); + const dim3 block_nums(nrows, 1, nchannels_y); + const dim3 block_dims(block_size_best, 1, 1); + switch (block_size_best) { + case 32: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 64: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 96: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 128: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 160: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 192: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 224: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 256: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } +} + +template +static void mul_mat_vec_cuda( + const T * x, const float * y, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, + enum ggml_prec prec, cudaStream_t stream) { + switch (prec) { + case GGML_PREC_DEFAULT: { + launch_mul_mat_vec_cuda(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, + stride_channel_x, stride_channel_y, stride_channel_dst, stream); + } break; + case GGML_PREC_F32: { + launch_mul_mat_vec_cuda(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, + stride_channel_x, stride_channel_y, stride_channel_dst, stream); + } break; + } +} + +void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + + GGML_ASSERT(src1->ne[1] == 1); + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; + + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + const int64_t ne02 = src0->ne[2]; + const int64_t ne12 = src1->ne[2]; + GGML_ASSERT(dst->ne[2] == ne12); + + GGML_ASSERT(src0->ne[3] == 1); + GGML_ASSERT(src1->ne[3] == 1); + GGML_ASSERT( dst->ne[3] == 1); + + const int64_t stride_row = src0->nb[1] / ggml_type_size(src0->type); + const int64_t channel_stride_x = src0->nb[2] / ggml_type_size(src0->type); + const int64_t channel_stride_y = src1->nb[2] / ggml_type_size(src1->type); + const int64_t channel_stride_dst = dst->nb[2] / ggml_type_size( dst->type); + + switch (src0->type) { + case GGML_TYPE_F16: { + const half * src0_d = (const half *) src0->data; + mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, + channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream()); + } break; + case GGML_TYPE_BF16: { + const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; + mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, + channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream()); + } break; + default: + GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); + } +} + +void ggml_cuda_op_mul_mat_vec( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t row_diff = row_high - row_low; + + GGML_ASSERT(src1_ncols == 1); + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; + + + // ggml_cuda_op provides single, contiguous matrices + const int64_t stride_row = ne00; + const int64_t nchannels_x = 1; + const int64_t nchannels_y = 1; + const int64_t channel_stride_x = 0; + const int64_t channel_stride_y = 0; + const int64_t channel_stride_dst = 0; + + switch (src0->type) { + case GGML_TYPE_F16: { + const half * src0_d = (const half *) src0_dd_i; + mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, + nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream); + } break; + case GGML_TYPE_BF16: { + const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; + mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, + nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream); + } break; + default: + GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); + } + + GGML_UNUSED(ctx); + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_ddq_i); + GGML_UNUSED(src1_ncols); + GGML_UNUSED(src1_padded_row_size); +} diff --git a/ggml/src/ggml-cuda/dmmv.cuh b/ggml/src/ggml-cuda/mmv.cuh similarity index 55% rename from ggml/src/ggml-cuda/dmmv.cuh rename to ggml/src/ggml-cuda/mmv.cuh index e727eb97f..78a1cd4a6 100644 --- a/ggml/src/ggml-cuda/dmmv.cuh +++ b/ggml/src/ggml-cuda/mmv.cuh @@ -1,20 +1,12 @@ #include "common.cuh" -// dmmv = dequantize_mul_mat_vec +// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available +#define MMV_MAX_ROWS 512 -// TODO: remove this? -#ifndef GGML_CUDA_DMMV_X -#define GGML_CUDA_DMMV_X 32 -#endif +void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); -#ifndef GGML_CUDA_MMV_Y -#define GGML_CUDA_MMV_Y 1 -#endif - -void ggml_cuda_op_dequantize_mul_mat_vec( +void ggml_cuda_op_mul_mat_vec( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); - -bool ggml_cuda_dmmv_type_supported(ggml_type src0_type); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 7dbbc9939..4fb466ca0 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -48,10 +48,10 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { } template -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) // tell the compiler to use as many registers as it wants, see nwarps definition below __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { @@ -62,13 +62,13 @@ static __global__ void mul_mat_vec_q( constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3)) +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3)) constexpr int nwarps = 1; constexpr int rows_per_cuda_block = 1; #else constexpr int nwarps = ncols_y <= 4 ? 4 : 2; constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2; -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; const int row0 = rows_per_cuda_block*blockIdx.x; @@ -142,7 +142,7 @@ static void mul_mat_vec_q_cuda( int64_t nwarps = 1; int64_t rows_per_cuda_block = 1; - if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2 + if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2 switch(ncols_y) { case 1: nwarps = 4; @@ -166,6 +166,7 @@ static void mul_mat_vec_q_cuda( break; } } + const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block; const dim3 block_nums(nblocks, 1, 1); const dim3 block_dims(WARP_SIZE, nwarps, 1); diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 133e219f0..d991ec972 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -5,20 +5,24 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; - float2 mean_var = make_float2(0.f, 0.f); + x += int64_t(row)*ncols; + dst += int64_t(row)*ncols; + + float2 mean_var = make_float2(0.0f, 0.0f); for (int col = tid; col < ncols; col += block_size) { - const float xi = x[row*ncols + col]; + const float xi = x[col]; mean_var.x += xi; mean_var.y += xi * xi; } // sum up partial sums mean_var = warp_reduce_sum(mean_var); - if (block_size > WARP_SIZE) { + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); __shared__ float2 s_sum[32]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; if (lane_id == 0) { s_sum[warp_id] = mean_var; } @@ -32,7 +36,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c const float inv_std = rsqrtf(var + eps); for (int col = tid; col < ncols; col += block_size) { - dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std; + dst[col] = (x[col] - mean) * inv_std; } } @@ -40,14 +44,8 @@ template static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) { // blockIdx.x: num_groups idx // threadIdx.x: block_size idx - int start = blockIdx.x * group_size; - int end = start + group_size; - - start += threadIdx.x; - - if (end >= ne_elements) { - end = ne_elements; - } + const int start = blockIdx.x*group_size + threadIdx.x; + const int end = min(blockIdx.x*group_size + group_size, ne_elements); float tmp = 0.0f; // partial sum for thread in warp @@ -56,10 +54,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr } tmp = warp_reduce_sum(tmp); - if (block_size > WARP_SIZE) { + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); __shared__ float s_sum[32]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; if (lane_id == 0) { s_sum[warp_id] = tmp; } @@ -68,11 +67,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr tmp = warp_reduce_sum(tmp); } - float mean = tmp / group_size; + const float mean = tmp / group_size; tmp = 0.0f; for (int j = start; j < end; j += block_size) { - float xi = x[j] - mean; + const float xi = x[j] - mean; dst[j] = xi; tmp += xi * xi; } @@ -80,8 +79,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr tmp = warp_reduce_sum(tmp); if (block_size > WARP_SIZE) { __shared__ float s_sum[32]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; if (lane_id == 0) { s_sum[warp_id] = tmp; } @@ -90,8 +89,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr tmp = warp_reduce_sum(tmp); } - float variance = tmp / group_size; - float scale = rsqrtf(variance + eps); + const float variance = tmp / group_size; + const float scale = rsqrtf(variance + eps); for (int j = start; j < end; j += block_size) { dst[j] *= scale; } @@ -102,19 +101,23 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; + x += int64_t(row)*ncols; + dst += int64_t(row)*ncols; + float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += block_size) { - const float xi = x[row*ncols + col]; + const float xi = x[col]; tmp += xi * xi; } // sum up partial sums tmp = warp_reduce_sum(tmp); - if (block_size > WARP_SIZE) { + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); __shared__ float s_sum[32]; - int warp_id = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; if (lane_id == 0) { s_sum[warp_id] = tmp; } @@ -127,12 +130,63 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol const float scale = rsqrtf(mean + eps); for (int col = tid; col < ncols; col += block_size) { - dst[row*ncols + col] = scale * x[row*ncols + col]; + dst[col] = scale * x[col]; + } +} + +template +static __global__ void rms_norm_back_f32( + const float * grad, const float * xf, float * dst, const int ncols, const float eps) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + grad += int64_t(row)*ncols; + xf += int64_t(row)*ncols; + dst += int64_t(row)*ncols; + + float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass + float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs + + for (int col = tid; col < ncols; col += block_size) { + const float xfi = xf[col]; + sum_xx += xfi * xfi; + sum_xg += xfi * grad[col]; + } + + // sum up partial sums + sum_xx = warp_reduce_sum(sum_xx); + sum_xg = warp_reduce_sum(sum_xg); + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); + __shared__ float s_sum_xx[32]; + __shared__ float s_sum_xg[32]; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum_xx[warp_id] = sum_xx; + s_sum_xg[warp_id] = sum_xg; + } + __syncthreads(); + + sum_xx = s_sum_xx[lane_id]; + sum_xx = warp_reduce_sum(sum_xx); + + sum_xg = s_sum_xg[lane_id]; + sum_xg = warp_reduce_sum(sum_xg); + } + + const float mean_eps = sum_xx / ncols + eps; + const float sum_eps = sum_xx + ncols*eps; + + const float scale_grad = rsqrtf(mean_eps); + const float scale_x = -scale_grad * sum_xg/sum_eps; + + for (int col = tid; col < ncols; col += block_size) { + dst[col] = scale_grad*grad[col] + scale_x*xf[col]; } } static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { - GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); norm_f32<<>>(x, dst, ncols, eps); @@ -142,7 +196,8 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i } } -static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) { +static void group_norm_f32_cuda( + const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) { if (group_size < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); group_norm_f32<<>>(x, dst, group_size, ne_elements, eps); @@ -153,7 +208,6 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou } static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { - GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); rms_norm_f32<<>>(x, dst, ncols, eps); @@ -163,6 +217,16 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con } } +static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + rms_norm_back_f32<<>>(grad, xf, dst, ncols, eps); + } else { + const dim3 block_dims(1024, 1, 1); + rms_norm_back_f32<1024><<>>(grad, xf, dst, ncols, eps); + } +} + void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; @@ -179,6 +243,7 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float eps; memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); } @@ -198,6 +263,7 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) float eps; memcpy(&eps, dst->op_params + 1, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream); @@ -219,6 +285,33 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float eps; memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); } + +void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * grad = dst->src[0]; // gradients + const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass + + const float * grad_d = (const float *) grad->data; + const float * src0f_d = (const float *) src0f->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(grad)); + + GGML_ASSERT( grad->type == GGML_TYPE_F32); + GGML_ASSERT(src0f->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0f->ne[0]; + const int64_t nrows = ggml_nrows(src0f); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); + + rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream); +} diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh index 431a8f74d..d63d34380 100644 --- a/ggml/src/ggml-cuda/norm.cuh +++ b/ggml/src/ggml-cuda/norm.cuh @@ -5,3 +5,5 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/opt-step-adamw.cu b/ggml/src/ggml-cuda/opt-step-adamw.cu new file mode 100644 index 000000000..35154f299 --- /dev/null +++ b/ggml/src/ggml-cuda/opt-step-adamw.cu @@ -0,0 +1,78 @@ +#include "ggml-impl.h" +#include "opt-step-adamw.cuh" + +#include + +static __global__ void opt_step_adamw_f32( + float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v, + const float * __restrict__ pars, const int64_t k) { + + const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x; + + if (i >= k) { + return; + } + + const float alpha = pars[0]; + const float beta1 = pars[1]; + const float beta2 = pars[2]; + const float eps = pars[3]; + const float wd = pars[4]; + const float beta1h = pars[5]; + const float beta2h = pars[6]; + + const float gi = g[i]; + const float gmi = g_m[i]*beta1 + gi*(1.0f - beta1); + const float gvi = g_v[i]*beta2 + gi*gi*(1.0f - beta2); + + g_m[i] = gmi; + g_v[i] = gvi; + + const float mh = gmi*beta1h; + const float vh = sqrtf(gvi*beta2h) + eps; + + x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh; +} + +static void opt_step_adamw_f32_cuda( + float * x, const float * g, float * g_m, float * g_v, const float * pars, const int64_t k, cudaStream_t stream) { + + const dim3 block_dims(CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1); + const dim3 block_nums((k + CUDA_OPT_STEP_ADAMW_BLOCK_SIZE - 1) / CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1); + opt_step_adamw_f32<<>>(x, g, g_m, g_v, pars, k); +} + +void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src0_grad = dst->src[1]; + const ggml_tensor * src0_grad_m = dst->src[2]; + const ggml_tensor * src0_grad_v = dst->src[3]; + const ggml_tensor * adamw_params = dst->src[4]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src0_grad->type == GGML_TYPE_F32); + GGML_ASSERT(src0_grad_m->type == GGML_TYPE_F32); + GGML_ASSERT(src0_grad_v->type == GGML_TYPE_F32); + GGML_ASSERT(adamw_params->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src0_grad)); + GGML_ASSERT(ggml_is_contiguous(src0_grad_m)); + GGML_ASSERT(ggml_is_contiguous(src0_grad_v)); + GGML_ASSERT(ggml_is_contiguous(adamw_params)); + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad)); + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m)); + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v)); + GGML_ASSERT(ggml_nelements(adamw_params) == 7); + + float * src0_d = (float *) src0->data; + const float * src0_grad_d = (const float *) src0_grad->data; + float * src0_grad_m_d = (float *) src0_grad_m->data; + float * src0_grad_v_d = (float *) src0_grad_v->data; + const float * adamw_params_d = (const float *) adamw_params->data; + + cudaStream_t stream = ctx.stream(); + + const int64_t ne = ggml_nelements(src0); + + opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, adamw_params_d, ne, stream); +} diff --git a/ggml/src/ggml-cuda/opt-step-adamw.cuh b/ggml/src/ggml-cuda/opt-step-adamw.cuh new file mode 100644 index 000000000..58d6f6e5d --- /dev/null +++ b/ggml/src/ggml-cuda/opt-step-adamw.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_OPT_STEP_ADAMW_BLOCK_SIZE 256 + +void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/out-prod.cu b/ggml/src/ggml-cuda/out-prod.cu new file mode 100644 index 000000000..c9b2b699c --- /dev/null +++ b/ggml/src/ggml-cuda/out-prod.cu @@ -0,0 +1,68 @@ +#include "out-prod.cuh" + +#include + +void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ne01 == ne11); + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne10); + + GGML_ASSERT(ne2 % src0->ne[2] == 0); + GGML_ASSERT(ne3 % src0->ne[3] == 0); + + GGML_ASSERT(ne2 == src1->ne[2]); + GGML_ASSERT(ne3 == src1->ne[3]); + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + cublasHandle_t handle = ctx.cublas_handle(); + + const float alpha = 1.0f; + const float beta = 0.0f; + + CUBLAS_CHECK(cublasSetStream(handle, stream)); + + const int64_t lda = nb01 / sizeof(float); + const int64_t ldc = nb1 / sizeof(float); + + const bool src1_T = ggml_is_transposed(src1); + const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T; + const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); + GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float)); + + // data strides in dimensions 2/3 + const size_t s02 = nb02 / sizeof(float); + const size_t s03 = nb03 / sizeof(float); + const size_t s12 = nb12 / sizeof(float); + const size_t s13 = nb13 / sizeof(float); + const size_t s2 = nb2 / sizeof(float); + const size_t s3 = nb3 / sizeof(float); + + // dps == dst per src0, used for group query attention + const int64_t dps2 = ne2 / ne02; + const int64_t dps3 = ne3 / ne03; + + // TODO batched matrix multiplication + for (int64_t i3 = 0; i3 < ne3; ++i3) { + for (int64_t i2 = 0; i2 < ne2; ++i2) { + CUBLAS_CHECK( + cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, + ne0, ne1, ne01, + &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda, + src1_d + i3 *s13 + i2 *s12, ldb, + &beta, dst_d + i3 *s3 + i2 *s2, ldc)); + } + } +} diff --git a/ggml/src/ggml-cuda/out-prod.cuh b/ggml/src/ggml-cuda/out-prod.cuh new file mode 100644 index 000000000..a0046f5f8 --- /dev/null +++ b/ggml/src/ggml-cuda/out-prod.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 45408ce86..1702e4ce2 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -69,8 +69,8 @@ static __global__ void quantize_mmq_q8_1( // Exchange max. abs. value between vals_per_scale/4 threads. #pragma unroll - for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) { - amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE)); + for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE)); } float sum; @@ -79,8 +79,8 @@ static __global__ void quantize_mmq_q8_1( // Exchange calculate sum across vals_per_sum/4 threads. #pragma unroll - for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) { - sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE); + for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) { + sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE); } } diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 88f586d68..18f691b2d 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -4,6 +4,11 @@ struct rope_corr_dims { float v[2]; }; + +struct mrope_sections { + int v[4]; +}; + static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / max(0.001f, high - low); return 1.0f - min(1.0f, max(0.0f, y)); @@ -11,9 +16,10 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +template static __device__ void rope_yarn( - float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale, - float * cos_theta, float * sin_theta) { + const float theta_extrap, const float freq_scale, const rope_corr_dims corr_dims, const int64_t i0, const float ext_factor, + float mscale, float & cos_theta, float & sin_theta) { // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; float theta = theta_interp; @@ -24,24 +30,28 @@ static __device__ void rope_yarn( // Get n-d magnitude scaling corrected for interpolation mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); } - *cos_theta = cosf(theta) * mscale; - *sin_theta = sinf(theta) * mscale; + cos_theta = cosf(theta) * mscale; + sin_theta = sinf(theta) * mscale; + if (!forward) { + sin_theta *= -1.0f; + } } -template +template static __global__ void rope_norm( - const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) { + const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { return; } - const int row = blockDim.x*blockIdx.x + threadIdx.x; + const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= n_dims) { - const int i = row*ne0 + i0; + const int i = row_dst*ne0 + i0; dst[i + 0] = x[i + 0]; dst[i + 1] = x[i + 1]; @@ -49,39 +59,43 @@ static __global__ void rope_norm( return; } - const int i = row*ne0 + i0; - const int i2 = row/p_delta_rows; + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; - const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); + const int idst = row_dst*ne0 + i0; + const int ix = channel_x*s2 + row_x*s1 + i0; + + const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; float cos_theta; float sin_theta; - rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); - const float x0 = x[i + 0]; - const float x1 = x[i + 1]; + const float x0 = x[ix + 0]; + const float x1 = x[ix + 1]; - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + 1] = x0*sin_theta + x1*cos_theta; + dst[idst + 0] = x0*cos_theta - x1*sin_theta; + dst[idst + 1] = x0*sin_theta + x1*cos_theta; } -template +template static __global__ void rope_neox( - const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) { + const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne0) { return; } - const int row = blockDim.x*blockIdx.x + threadIdx.x; + const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= n_dims) { - const int i = row*ne0 + i0; + const int i = row_dst*ne0 + i0; dst[i + 0] = x[i + 0]; dst[i + 1] = x[i + 1]; @@ -89,29 +103,140 @@ static __global__ void rope_neox( return; } - const int i = row*ne0 + i0/2; - const int i2 = row/p_delta_rows; + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; - const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); + const int idst = row_dst*ne0 + i0/2; + const int ix = channel_x*s2 + row_x*s1 + i0/2; + + const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; float cos_theta; float sin_theta; - rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); - const float x0 = x[i + 0]; - const float x1 = x[i + n_dims/2]; + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims/2]; - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; + dst[idst + 0] = x0*cos_theta - x1*sin_theta; + dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta; } -template +template +static __global__ void rope_multi( + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, + const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) { + const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); + + if (i0 >= ne0) { + return; + } + + const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; + + if (i0 >= n_dims) { + const int i = row_dst*ne0 + i0; + + dst[i + 0] = x[i + 0]; + dst[i + 1] = x[i + 1]; + + return; + } + + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; + + const int idst = row_dst*ne0 + i0/2; + const int ix = channel_x*s2 + row_x*s1 + i0/2; + + const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; + const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (sector < sections.v[0]) { + theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + } + else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); + } + else if (sector >= sec_w && sector < sec_w + sections.v[2]) { + theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); + } + else if (sector >= sec_w + sections.v[2]) { + theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); + } + + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + + float cos_theta; + float sin_theta; + + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); + + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims/2]; + + dst[idst + 0] = x0*cos_theta - x1*sin_theta; + dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta; +} + +template +static __global__ void rope_vision( + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float * freq_factors, const mrope_sections sections) { + const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); + + if (i0 >= ne0) { + return; + } + + const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; + + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; + + const int idst = row_dst*ne0 + i0/2; + const int ix = channel_x*s2 + row_x*s1 + i0/2; + + const int sect_dims = sections.v[0] + sections.v[1]; + const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (sector < sections.v[0]) { + const int p = sector; + theta_base = pos[channel_x]*powf(theta_scale, p); + } + else if (sector >= sections.v[0] && sector < sec_w) { + const int p = sector - sections.v[0]; + theta_base = pos[channel_x + ne2]*powf(theta_scale, p); + } + + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + + float cos_theta; + float sin_theta; + + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); + + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims]; + + dst[idst + 0] = x0*cos_theta - x1*sin_theta; + dst[idst + n_dims] = x0*sin_theta + x1*cos_theta; +} + +template static void rope_norm_cuda( - const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { + const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -120,22 +245,21 @@ static void rope_norm_cuda( const float theta_scale = powf(freq_base, -2.0f/n_dims); if (freq_factors == nullptr) { - rope_norm<<>>( - x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors - ); + rope_norm<<>>( + x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors); } else { - rope_norm<<>>( - x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors - ); + rope_norm<<>>( + x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors); } } -template +template static void rope_neox_cuda( - const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { + const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); @@ -144,48 +268,66 @@ static void rope_neox_cuda( const float theta_scale = powf(freq_base, -2.0f/n_dims); if (freq_factors == nullptr) { - rope_neox<<>>( - x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors - ); + rope_neox<<>>( + x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors); } else { - rope_neox<<>>( - x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors - ); + rope_neox<<>>( + x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors); } } -static void rope_norm_cuda_f16( - const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { +template +static void rope_multi_cuda( + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) { + GGML_ASSERT(ne0 % 2 == 0); + const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); + const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const dim3 block_nums(nr, n_blocks_x, 1); - rope_norm_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + if (freq_factors == nullptr) { + rope_multi<<>>( + x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections); + } else { + rope_multi<<>>( + x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections); + } } -static void rope_norm_cuda_f32( - const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { +template +static void rope_vision_cuda( + const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, + const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) { + GGML_ASSERT(ne0 % 2 == 0); + const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); + const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const dim3 block_nums(nr, n_blocks_x, 1); + // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq) + // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE); - rope_norm_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + if (freq_factors == nullptr) { + rope_vision<<>>( + x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections); + } else { + rope_vision<<>>( + x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections); + } } -static void rope_neox_cuda_f16( - const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { - - rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); -} - -static void rope_neox_cuda_f32( - const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream -) { - - rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); -} - -void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +template +void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; const ggml_tensor * src2 = dst->src[2]; @@ -196,20 +338,24 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); - GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); GGML_ASSERT(src0->type == dst->type); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; + const int64_t ne00 = src0->ne[0]; // head dims + const int64_t ne01 = src0->ne[1]; // num heads + const int64_t ne02 = src0->ne[2]; // num heads const int64_t nr = ggml_nrows(src0); + const size_t s01 = src0->nb[1] / ggml_type_size(src0->type); + const size_t s02 = src0->nb[2] / ggml_type_size(src0->type); + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; //const int n_ctx = ((int32_t *) dst->op_params)[3]; const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + mrope_sections sections; // RoPE alteration for extended context float freq_base; @@ -225,8 +371,19 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(§ions.v, (int32_t *) dst->op_params + 11, sizeof(int)*4); const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_mrope) { + GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0); + } + + if (is_vision) { + GGML_ASSERT(n_dims == ne00/2); + } const int32_t * pos = (const int32_t *) src1_d; @@ -241,31 +398,59 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { // compute if (is_neox) { if (src0->type == GGML_TYPE_F32) { - rope_neox_cuda_f32( - (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, stream - ); + rope_neox_cuda( + (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_neox_cuda_f16( - (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, stream - ); + rope_neox_cuda( + (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); + } else { + GGML_ABORT("fatal error"); + } + } else if (is_mrope && !is_vision) { + if (src0->type == GGML_TYPE_F32) { + rope_multi_cuda( + (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); + } else if (src0->type == GGML_TYPE_F16) { + rope_multi_cuda( + (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); + } else { + GGML_ABORT("fatal error"); + } + } else if (is_vision) { + if (src0->type == GGML_TYPE_F32) { + rope_vision_cuda( + (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); + } else if (src0->type == GGML_TYPE_F16) { + rope_vision_cuda( + (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); } else { GGML_ABORT("fatal error"); } } else { if (src0->type == GGML_TYPE_F32) { - rope_norm_cuda_f32( - (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, stream - ); + rope_norm_cuda( + (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_norm_cuda_f16( - (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, stream - ); + rope_norm_cuda( + (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } else { GGML_ABORT("fatal error"); } } } + +void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_rope_impl(ctx, dst); +} + +void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_rope_impl(ctx, dst); +} diff --git a/ggml/src/ggml-cuda/rope.cuh b/ggml/src/ggml-cuda/rope.cuh index 0f787a0b2..9139f3b22 100644 --- a/ggml/src/ggml-cuda/rope.cuh +++ b/ggml/src/ggml-cuda/rope.cuh @@ -3,3 +3,5 @@ #define CUDA_ROPE_BLOCK_SIZE 256 void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index c24abae1f..da377200e 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -1,5 +1,7 @@ #include "common.cuh" +#include "ggml.h" #include "softmax.cuh" +#include template static __device__ __forceinline__ float t2f32(T val) { @@ -11,14 +13,26 @@ __device__ float __forceinline__ t2f32(half val) { return __half2float(val); } -template -static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { +// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled. +// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here. +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif +template +static __global__ void soft_max_f32( + const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, + const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; const int tid = threadIdx.x; const int rowx = blockIdx.x; const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension + x += int64_t(rowx)*ncols; + mask += int64_t(rowy)*ncols * (mask != nullptr); + dst += int64_t(rowx)*ncols; + const int block_size = block_size_template == 0 ? blockDim.x : block_size_template; const int warp_id = threadIdx.x / WARP_SIZE; @@ -29,7 +43,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst extern __shared__ float data_soft_max_f32[]; float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication // shared memory buffer to cache values between iterations: - float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols; + float * vals = use_shared ? buf_iw + WARP_SIZE : dst; float max_val = -INFINITY; @@ -41,10 +55,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst break; } - const int64_t ix = (int64_t)rowx*ncols + col; - const int64_t iy = (int64_t)rowy*ncols + col; - - const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f); + const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f); vals[col] = val; max_val = max(max_val, val); @@ -110,8 +121,32 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst return; } - const int64_t idst = (int64_t)rowx*ncols + col; - dst[idst] = vals[col] * inv_sum; + dst[col] = vals[col] * inv_sum; + } +} +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +static __global__ void soft_max_back_f32( + const float * grad, const float * dstf, float * dst, const int ncols, const float scale) { + const int tid = threadIdx.x; + const int rowx = blockIdx.x; + + grad += int64_t(rowx)*ncols; + dstf += int64_t(rowx)*ncols; + dst += int64_t(rowx)*ncols; + + float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients + + for (int col = tid; col < ncols; col += WARP_SIZE) { + dgf_dot += dstf[col]*grad[col]; + } + + dgf_dot = warp_reduce_sum(dgf_dot); + + for (int col = tid; col < ncols; col += WARP_SIZE) { + dst[col] = scale * (grad[col] - dgf_dot) * dstf[col]; } } @@ -121,7 +156,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; const dim3 block_dims(nth, 1, 1); const dim3 block_nums(nrows_x, 1, 1); - const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); + const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); const uint32_t n_head = nrows_x/nrows_y; @@ -131,50 +166,68 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); // FIXME: this limit could be raised by ~2-4x on Ampere or newer - if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { + if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { switch (ncols_x) { case 32: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 64: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 128: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 256: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 512: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 1024: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 2048: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; case 4096: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; default: - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + soft_max_f32<<>> + (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); break; } } else { - const size_t shmem_low = WARP_SIZE*sizeof(float); - soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); + const size_t nbytes_shared_low = WARP_SIZE*sizeof(float); + soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2); } } +static void soft_max_back_f32_cuda( + const float * grad, const float * dstf, float * dst, + const int ncols, const int nrows, const float scale, cudaStream_t stream) { + const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_nums(nrows, 1, 1); + + soft_max_back_f32<<>>(grad, dstf, dst, ncols, scale); +} + void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - const float * src0_d = (const float *)src0->data; - const void * src1_d = src1 ? (const void *)src1->data : nullptr; + const float * src0_d = (const float *) src0->data; + const void * src1_d = src1 ? (const void *) src1->data : nullptr; + float * dst_d = (float *) dst->data; - float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); @@ -189,18 +242,42 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float scale = 1.0f; float max_bias = 0.0f; - memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); if (use_f16) { - const half * src1_dd = (const half *)src1_d; - - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); } else { - const float * src1_dd = (const float *)src1_d; - - soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); + soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream); } } + +void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // grad + const ggml_tensor * src1 = dst->src[1]; // forward pass output + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + + GGML_ASSERT(max_bias == 0.0f); + + soft_max_back_f32_cuda(src0_d, src1_d, dst_d, ncols, nrows, scale, stream); +} diff --git a/ggml/src/ggml-cuda/softmax.cuh b/ggml/src/ggml-cuda/softmax.cuh index 4ef4ff86c..93dfee835 100644 --- a/ggml/src/ggml-cuda/softmax.cuh +++ b/ggml/src/ggml-cuda/softmax.cuh @@ -3,3 +3,5 @@ #define CUDA_SOFT_MAX_BLOCK_SIZE 1024 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/sum.cu b/ggml/src/ggml-cuda/sum.cu index 0d5e953ee..e0dafc1d2 100644 --- a/ggml/src/ggml-cuda/sum.cu +++ b/ggml/src/ggml-cuda/sum.cu @@ -1,15 +1,19 @@ +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700 +#define USE_CUB +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700 + +#ifdef USE_CUB +#include +using namespace cub; +#endif // USE_CUB + #include "sumrows.cuh" #include "sum.cuh" #include -#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) -#include -using namespace cub; -#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) - void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) { -#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) +#ifdef USE_CUB size_t tmp_size = 0; DeviceReduce::Sum(nullptr, tmp_size, x, dst, ne, stream); ggml_cuda_pool_alloc tmp_alloc(pool, tmp_size); @@ -19,7 +23,7 @@ void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int // For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14. sum_rows_f32_cuda(x, dst, ne, 1, stream); GGML_UNUSED(pool); -#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) +#endif // USE_CUB } void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 8ac669f94..6b21f407d 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -10,6 +10,16 @@ static __global__ void neg_f32(const float * x, float * dst, const int k) { dst[i] = -x[i]; } +static __global__ void step_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = x[i] > 0.0f; +} + static __global__ void gelu_f32(const float * x, float * dst, const int k) { const float GELU_COEF_A = 0.044715f; const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; @@ -41,6 +51,19 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) { dst[i] = x[i] / (1.0f + expf(-x[i])); } +static __global__ void silu_back_f32( + const float * grad, const float * xf, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + const float xfi = xf[i]; + const float s = 1.0f / (1.0f + expf(-xfi)); + dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s)); +} + static __global__ void tanh_f32(const float * x, float * dst, int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -85,6 +108,15 @@ static __global__ void hardswish_f32(const float * x, float * dst, const int k) dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } +static __global__ void exp_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = expf(x[i]); +} + static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) { const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -134,6 +166,11 @@ static void neg_f32_cuda(const float * x, float * dst, const int k, cudaStream_t neg_f32<<>>(x, dst, k); } +static void step_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_STEP_BLOCK_SIZE - 1) / CUDA_STEP_BLOCK_SIZE; + step_f32<<>>(x, dst, k); +} + static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; gelu_f32<<>>(x, dst, k); @@ -149,6 +186,11 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_ silu_f32<<>>(x, dst, k); } +static void silu_back_f32_cuda(const float * grad, const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; + silu_back_f32<<>>(grad, x, dst, k); +} + static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE; tanh_f32<<>>(x, dst, k); @@ -174,6 +216,11 @@ static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaSt hardswish_f32<<>>(x, dst, k); } +static void exp_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_EXP_BLOCK_SIZE - 1) / CUDA_EXP_BLOCK_SIZE; + exp_f32<<>>(x, dst, k); +} + static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) { const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; leaky_relu_f32<<>>(x, dst, k, negative_slope); @@ -213,6 +260,20 @@ void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { neg_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); } +void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + step_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); +} + void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; @@ -241,6 +302,24 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); } +void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // input from forward pass + const ggml_tensor * src1 = dst->src[1]; // grads of forward pass output + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + silu_back_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(src0), stream); +} + void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; @@ -325,6 +404,20 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); } +void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + exp_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); +} + void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index ed2ffc461..e7f62643a 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -1,12 +1,15 @@ #include "common.cuh" #define CUDA_NEG_BLOCK_SIZE 256 +#define CUDA_STEP_BLOCK_SIZE 256 #define CUDA_GELU_BLOCK_SIZE 256 #define CUDA_SILU_BLOCK_SIZE 256 +#define CUDA_SILU_BACK_BLOCK_SIZE 256 #define CUDA_TANH_BLOCK_SIZE 256 #define CUDA_RELU_BLOCK_SIZE 256 #define CUDA_SIGMOID_BLOCK_SIZE 256 #define CUDA_HARDSIGMOID_BLOCK_SIZE 256 +#define CUDA_EXP_BLOCK_SIZE 256 #define CUDA_HARDSWISH_BLOCK_SIZE 256 #define CUDA_SQR_BLOCK_SIZE 256 #define CUDA_SQRT_BLOCK_SIZE 256 @@ -15,10 +18,14 @@ void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst); @@ -29,6 +36,8 @@ void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h index db9f6a165..1746b0732 100644 --- a/ggml/src/ggml-cuda/vendors/cuda.h +++ b/ggml/src/ggml-cuda/vendors/cuda.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #if CUDART_VERSION < 11020 diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index d0c377255..8594093f0 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -3,6 +3,7 @@ #include #include #include +#include #ifdef __HIP_PLATFORM_AMD__ // for rocblas_initialize() #include "rocblas/rocblas.h" @@ -18,6 +19,12 @@ #define CUBLAS_TF32_TENSOR_OP_MATH 0 #define CUDA_R_16F HIPBLAS_R_16F #define CUDA_R_32F HIPBLAS_R_32F +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported +#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended +#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned +#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice +#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite +#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6 #define cublasCreate hipblasCreate @@ -30,6 +37,7 @@ #define cublasSetStream hipblasSetStream #define cublasSgemm hipblasSgemm #define cublasStatus_t hipblasStatus_t +#define cublasOperation_t hipblasOperation_t #define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6 #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess @@ -72,6 +80,21 @@ #define cudaMemGetInfo hipMemGetInfo #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize #define cudaSetDevice hipSetDevice +#define cuDeviceGet hipDeviceGet +#define CUdevice hipDevice_t +#define CUdeviceptr hipDeviceptr_t +#define cuMemUnmap hipMemUnmap +#define CUmemAccessDesc hipMemAccessDesc +#define cuMemAddressFree hipMemAddressFree +#define cuMemRelease hipMemRelease +#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t +#define cuMemCreate hipMemCreate +#define cuMemAddressReserve hipMemAddressReserve +#define cuMemMap hipMemMap +#define cuMemSetAccess hipMemSetAccess +#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity +#define CUmemAllocationProp hipMemAllocationProp +#define cuDeviceGetAttribute hipDeviceGetAttribute #define cudaStreamCreateWithFlags hipStreamCreateWithFlags #define cudaStreamDestroy hipStreamDestroy #define cudaStreamFireAndForget hipStreamFireAndForget @@ -79,6 +102,28 @@ #define cudaStreamPerThread hipStreamPerThread #define cudaStreamSynchronize hipStreamSynchronize #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) +#define cudaGraphExec_t hipGraphExec_t +#define cudaGraphNode_t hipGraphNode_t +#define cudaKernelNodeParams hipKernelNodeParams +#define cudaKernelNodeParams hipKernelNodeParams +#define cudaGraphExecDestroy hipGraphExecDestroy +#define cudaGraphLaunch hipGraphLaunch +#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure +#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult +#define cudaGraphNodeType hipGraphNodeType +#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel +#define cudaGraphInstantiate hipGraphInstantiate +#define cudaStreamEndCapture hipStreamEndCapture +#define cudaGraphDestroy hipGraphDestroy +#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams +#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction +#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams +#define cudaGraphNodeGetType hipGraphNodeGetType +#define cudaGraphGetNodes hipGraphGetNodes +#define cudaGraphExecUpdate hipGraphExecUpdate +#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed +#define cudaStreamBeginCapture hipStreamBeginCapture +#define cudaGraph_t hipGraph_t #define cudaStream_t hipStream_t #define cudaSuccess hipSuccess #define __trap() do { abort(); __builtin_unreachable(); } while(0) @@ -94,6 +139,14 @@ #define __CUDA_ARCH__ 1300 +#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) +#define GCN +#endif + +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) +#define CDNA +#endif + #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \ defined(__gfx1150__) || defined(__gfx1151__) #define RDNA3 @@ -112,6 +165,8 @@ #define __has_builtin(x) 0 #endif +typedef hip_bfloat16 nv_bfloat16; + typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); static __device__ __forceinline__ int __vsubss4(const int a, const int b) { diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index e50a103ac..6cc1b69ee 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #define CUBLAS_COMPUTE_16F CUDA_R_16F #define CUBLAS_COMPUTE_32F CUDA_R_32F @@ -26,6 +27,7 @@ #define cublasSetStream mublasSetStream #define cublasSgemm mublasSgemm #define cublasStatus_t mublasStatus_t +#define cublasOperation_t mublasOperation_t #define cublasGetStatusString mublasStatus_to_string #define cudaDataType_t musaDataType_t #define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer @@ -56,6 +58,7 @@ #define cudaLaunchHostFunc musaLaunchHostFunc #define cudaMalloc musaMalloc #define cudaMallocHost musaMallocHost +#define cudaMallocManaged musaMallocManaged #define cudaMemcpy musaMemcpy #define cudaMemcpyAsync musaMemcpyAsync #define cudaMemcpyPeerAsync musaMemcpyPeerAsync @@ -131,41 +134,4 @@ #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed #define cudaStreamEndCapture musaStreamEndCapture -// XXX: Clang builtins mapping -#define __vsub4 __vsub4_musa -#define __vcmpeq4 __vcmpeq4_musa -#define __vcmpne4 __vcmpne4_musa - -#ifndef __has_builtin - #define __has_builtin(x) 0 -#endif - -typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); - -static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) { - return __vsubss4(a, b); -} - -static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) { - const uint8x4_t& va = reinterpret_cast(a); - const uint8x4_t& vb = reinterpret_cast(b); - unsigned int c; - uint8x4_t& vc = reinterpret_cast(c); -#pragma unroll - for (int i = 0; i < 4; ++i) { - vc[i] = va[i] == vb[i] ? 0xff : 0x00; - } - return c; -} - -static __device__ __forceinline__ unsigned int __vcmpne4_musa(unsigned int a, unsigned int b) { - const uint8x4_t& va = reinterpret_cast(a); - const uint8x4_t& vb = reinterpret_cast(b); - unsigned int c; - uint8x4_t& vc = reinterpret_cast(c); -#pragma unroll - for (int i = 0; i < 4; ++i) { - vc[i] = va[i] == vb[i] ? 0x00 : 0xff; - } - return c; -} +typedef mt_bfloat16 nv_bfloat16; diff --git a/ggml/src/ggml-cuda/wkv6.cu b/ggml/src/ggml-cuda/wkv6.cu new file mode 100644 index 000000000..bbdafbee5 --- /dev/null +++ b/ggml/src/ggml-cuda/wkv6.cu @@ -0,0 +1,89 @@ +#include "common.cuh" +#include "wkv6.cuh" + +static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) { + const int tid = threadIdx.x; + const int bid = blockIdx.x; + + const int head_size = CUDA_WKV_BLOCK_SIZE; + const int batch_i = bid / H; + const int head_i = bid % H; + const int state_size = C * head_size; + const int n_seq_tokens = T / B; + + float state[head_size]; + __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size]; + + #pragma unroll + for (int i = 0; i < head_size; i++) { + state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid]; + } + + __syncthreads(); + _tf[tid] = tf[head_i * head_size + tid]; + __syncthreads(); + + for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) { + __syncthreads(); + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + __syncthreads(); + + const float _v = v[t]; + float y = 0; + for (int j = 0; j < head_size; j += 4) { + const float4& k = (float4&)(_k[j]); + const float4& r = (float4&)(_r[j]); + const float4& tf = (float4&)(_tf[j]); + const float4& td = (float4&)(_td[j]); + float4& s = (float4&)(state[j]); + float4 kv; + + kv.x = k.x * _v; + kv.y = k.y * _v; + kv.z = k.z * _v; + kv.w = k.w * _v; + + y += r.x * (tf.x * kv.x + s.x); + y += r.y * (tf.y * kv.y + s.y); + y += r.z * (tf.z * kv.z + s.z); + y += r.w * (tf.w * kv.w + s.w); + + s.x = s.x * td.x + kv.x; + s.y = s.y * td.y + kv.y; + s.z = s.z * td.z + kv.z; + s.w = s.w * td.w + kv.w; + } + dst[t] = y; + } + + #pragma unroll + for (int i = 0; i < head_size; i++) { + dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i]; + } +} + +void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const float * k_d = (const float *)dst->src[0]->data; + const float * v_d = (const float *)dst->src[1]->data; + const float * r_d = (const float *)dst->src[2]->data; + const float * tf_d = (const float *)dst->src[3]->data; + const float * td_d = (const float *)dst->src[4]->data; + const float * s_d = (const float *)dst->src[5]->data; + + const int64_t B = dst->src[5]->ne[1]; + const int64_t T = dst->src[0]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[1]; + + float * dst_d = (float *)dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64 + + rwkv_wkv_f32<<>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d); +} diff --git a/ggml/src/ggml-cuda/wkv6.cuh b/ggml/src/ggml-cuda/wkv6.cuh new file mode 100644 index 000000000..a7124ee51 --- /dev/null +++ b/ggml/src/ggml-cuda/wkv6.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_WKV_BLOCK_SIZE 64 + +void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt new file mode 100644 index 000000000..7a877bdc1 --- /dev/null +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -0,0 +1,118 @@ +if (NOT EXISTS $ENV{ROCM_PATH}) + if (NOT EXISTS /opt/rocm) + set(ROCM_PATH /usr) + else() + set(ROCM_PATH /opt/rocm) + endif() +else() + set(ROCM_PATH $ENV{ROCM_PATH}) +endif() + +list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) +list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake") + +# CMake on Windows doesn't support the HIP language yet +if (WIN32) + set(CXX_IS_HIPCC TRUE) +else() + string(REGEX MATCH "hipcc(\.bat)?$" CXX_IS_HIPCC "${CMAKE_CXX_COMPILER}") +endif() + +if (CXX_IS_HIPCC) + if (LINUX) + if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++") + endif() + + message(WARNING "Setting hipcc as the C++ compiler is legacy behavior." + " Prefer setting the HIP compiler directly. See README for details.") + endif() +else() + # Forward AMDGPU_TARGETS to CMAKE_HIP_ARCHITECTURES. + if (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) + set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS}) + endif() + cmake_minimum_required(VERSION 3.21) + enable_language(HIP) +endif() + +find_package(hip REQUIRED) +find_package(hipblas REQUIRED) +find_package(rocblas REQUIRED) + +if (${hip_VERSION} VERSION_LESS 5.5) + message(FATAL_ERROR "At least ROCM/HIP V5.5 is required") +endif() + +message(STATUS "HIP and hipBLAS found") + +file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh") +list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h") + +file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu") +file(GLOB SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu") +list(APPEND GGML_SOURCES_ROCM ${SRCS}) +file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") +list(APPEND GGML_SOURCES_ROCM ${SRCS}) + +if (GGML_CUDA_FA_ALL_QUANTS) + file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu") + list(APPEND GGML_SOURCES_ROCM ${SRCS}) + add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) +else() + file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") + list(APPEND GGML_SOURCES_ROCM ${SRCS}) + file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") + list(APPEND GGML_SOURCES_ROCM ${SRCS}) + file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu") + list(APPEND GGML_SOURCES_ROCM ${SRCS}) +endif() + +ggml_add_backend_library(ggml-hip + ${GGML_HEADERS_ROCM} + ${GGML_SOURCES_ROCM} + ) + +# TODO: do not use CUDA definitions for HIP +if (NOT GGML_BACKEND_DL) + target_compile_definitions(ggml PUBLIC GGML_USE_CUDA) +endif() + +add_compile_definitions(GGML_USE_HIP) + +if (GGML_HIP_UMA) + add_compile_definitions(GGML_HIP_UMA) +endif() + +if (GGML_CUDA_FORCE_MMQ) + add_compile_definitions(GGML_CUDA_FORCE_MMQ) +endif() + +if (GGML_CUDA_FORCE_CUBLAS) + add_compile_definitions(GGML_CUDA_FORCE_CUBLAS) +endif() + +if (GGML_CUDA_NO_PEER_COPY) + add_compile_definitions(GGML_CUDA_NO_PEER_COPY) +endif() + +if (GGML_HIP_GRAPHS) + add_compile_definitions(GGML_HIP_GRAPHS) +endif() + +if (GGML_HIP_NO_VMM) + add_compile_definitions(GGML_HIP_NO_VMM) +endif() + +if (CXX_IS_HIPCC) + set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX) + target_link_libraries(ggml-hip PRIVATE hip::device) +else() + set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE HIP) +endif() + +if (GGML_STATIC) + message(FATAL_ERROR "Static linking not supported for HIP/ROCm") +endif() + +target_link_libraries(ggml-hip PRIVATE ggml-base hip::host roc::rocblas roc::hipblas) diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 961f3c67b..eab017889 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -1,32 +1,482 @@ #pragma once -#include "ggml.h" - // GGML internal header +#include "ggml.h" +#include "gguf.h" + #include +#include #include // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/ -#include #include -#include // memcpy -#include // fabsf +#include +#include -#undef MIN -#undef MAX +#ifdef __ARM_FEATURE_SVE +#include +#endif // __ARM_FEATURE_SVE -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#if defined(__ARM_NEON) && !defined(__CUDACC__) +// if YCM cannot find , make a symbolic link to it, for example: +// +// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ +// +#include +#endif -#if defined(_MSC_VER) +#if defined(__F16C__) +#include +#endif -#define m512bh(p) p -#define m512i(p) p +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef MIN +# define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +#ifndef MAX +# define MAX(a, b) ((a) > (b) ? (a) : (b)) +#endif + +// required for mmap as gguf only guarantees 32-byte alignment +#define TENSOR_ALIGNMENT 32 + +// static_assert should be a #define, but if it's not, +// fall back to the _Static_assert C11 keyword. +// if C99 - static_assert is noop +// ref: https://stackoverflow.com/a/53923785/4039976 +#ifndef __cplusplus + #ifndef static_assert + #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L) + #define static_assert(cond, msg) _Static_assert(cond, msg) + #else + #define static_assert(cond, msg) struct global_scope_noop_trick + #endif + #endif +#endif + +static inline int ggml_up32(int n) { + return (n + 31) & ~31; +} + +//static inline int ggml_up64(int n) { +// return (n + 63) & ~63; +//} + +static inline int ggml_up(int n, int m) { + // assert m is a power of 2 + GGML_ASSERT((m & (m - 1)) == 0); + return (n + m - 1) & ~(m - 1); +} + +// +// logging +// + +GGML_ATTRIBUTE_FORMAT(2, 3) +GGML_API void ggml_log_internal (enum ggml_log_level level, const char * format, ...); +GGML_API void ggml_log_callback_default(enum ggml_log_level level, const char * text, void * user_data); + +#define GGML_LOG(...) ggml_log_internal(GGML_LOG_LEVEL_NONE , __VA_ARGS__) +#define GGML_LOG_INFO(...) ggml_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) +#define GGML_LOG_WARN(...) ggml_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) +#define GGML_LOG_ERROR(...) ggml_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +#define GGML_LOG_DEBUG(...) ggml_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) +#define GGML_LOG_CONT(...) ggml_log_internal(GGML_LOG_LEVEL_CONT , __VA_ARGS__) + +#define GGML_DEBUG 0 + +#if (GGML_DEBUG >= 1) +#define GGML_PRINT_DEBUG(...) GGML_LOG_DEBUG(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG(...) +#endif + +#if (GGML_DEBUG >= 5) +#define GGML_PRINT_DEBUG_5(...) GGML_LOG_DEBUG(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG_5(...) +#endif + +#if (GGML_DEBUG >= 10) +#define GGML_PRINT_DEBUG_10(...) GGML_LOG_DEBUG(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG_10(...) +#endif + +// tensor params + +static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params, size_t params_size) { + GGML_ASSERT(tensor != NULL); // silence -Warray-bounds warnings + assert(params_size <= GGML_MAX_OP_PARAMS); + memcpy(tensor->op_params, params, params_size); +} + +static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_t i) { + assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t)); + return ((const int32_t *)(tensor->op_params))[i]; +} + +static float ggml_get_op_params_f32(const struct ggml_tensor * tensor, uint32_t i) { + assert(i < GGML_MAX_OP_PARAMS / sizeof(float)); + return ((const float *)(tensor->op_params))[i]; +} + +static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) { + assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t)); + ((int32_t *)(tensor->op_params))[i] = value; +} + +static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, float value) { + assert(i < GGML_MAX_OP_PARAMS / sizeof(float)); + ((float *)(tensor->op_params))[i] = value; +} + +struct ggml_map_custom1_op_params { + ggml_custom1_op_t fun; + int n_tasks; + void * userdata; +}; + +struct ggml_map_custom2_op_params { + ggml_custom2_op_t fun; + int n_tasks; + void * userdata; +}; + +struct ggml_map_custom3_op_params { + ggml_custom3_op_t fun; + int n_tasks; + void * userdata; +}; + +// bitset + +typedef uint32_t ggml_bitset_t; + +static_assert(sizeof(ggml_bitset_t) == 4, "bitset_t constants must be updated"); +#define BITSET_SHR 5 // log2(sizeof(ggml_bitset_t)*8) +#define BITSET_MASK (sizeof(ggml_bitset_t)*8 - 1) + +static size_t ggml_bitset_size(size_t n) { + return (n + BITSET_MASK) >> BITSET_SHR; +} + +static inline bool ggml_bitset_get(const ggml_bitset_t * bitset, size_t i) { + return !!(bitset[i >> BITSET_SHR] & (1u << (i & BITSET_MASK))); +} + +static inline void ggml_bitset_set(ggml_bitset_t * bitset, size_t i) { + bitset[i >> BITSET_SHR] |= (1u << (i & BITSET_MASK)); +} + +static inline void ggml_bitset_clear(ggml_bitset_t * bitset, size_t i) { + bitset[i >> BITSET_SHR] &= ~(1u << (i & BITSET_MASK)); +} + +// hash set + +#define GGML_HASHSET_FULL ((size_t)-1) +#define GGML_HASHSET_ALREADY_EXISTS ((size_t)-2) + +struct ggml_hash_set { + size_t size; + ggml_bitset_t * used; // whether or not the keys are in use i.e. set + struct ggml_tensor ** keys; // actual tensors in the set, keys[i] is only defined if ggml_bitset_get(used, i) +}; + +struct ggml_hash_set ggml_hash_set_new(size_t size); +void ggml_hash_set_free(struct ggml_hash_set * hash_set); + +// returns the minimum size for a hash set that can hold min_sz elements +size_t ggml_hash_size(size_t min_sz); + +// remove all elements from the hash set +void ggml_hash_set_reset(struct ggml_hash_set * hash_set); + +// returns true if key is in the hash set +static bool ggml_hash_contains(const struct ggml_hash_set * hash_set, struct ggml_tensor * key); + +// returns GGML_HASHSET_FULL if table is full, otherwise the current index of the key or where it should be inserted +static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, const struct ggml_tensor * key); + +// returns GGML_HASHSET_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full +static size_t ggml_hash_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key); + +// return index, asserts if table is full +static size_t ggml_hash_find_or_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key); + +// hash function for ggml_tensor +static inline size_t ggml_hash(const struct ggml_tensor * p) { + // the last 4 bits are always zero due to alignment + return (size_t)(uintptr_t)p >> 4; +} + +static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, const struct ggml_tensor * key) { + size_t h = ggml_hash(key) % hash_set->size; + + // linear probing + size_t i = h; + while (ggml_bitset_get(hash_set->used, i) && hash_set->keys[i] != key) { + i = (i + 1) % hash_set->size; + if (i == h) { + // visited all hash table entries -> not found + return GGML_HASHSET_FULL; + } + } + return i; +} + +static bool ggml_hash_contains(const struct ggml_hash_set * hash_set, struct ggml_tensor * key) { + size_t i = ggml_hash_find(hash_set, key); + return i != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, i); +} + +static size_t ggml_hash_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key) { + size_t h = ggml_hash(key) % hash_set->size; + + // linear probing + size_t i = h; + do { + if (!ggml_bitset_get(hash_set->used, i)) { + ggml_bitset_set(hash_set->used, i); + hash_set->keys[i] = key; + return i; + } + if (hash_set->keys[i] == key) { + return GGML_HASHSET_ALREADY_EXISTS; + } + i = (i + 1) % hash_set->size; + } while (i != h); + + // visited all hash table entries -> not found + GGML_ABORT("fatal error"); +} + +static size_t ggml_hash_find_or_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key) { + size_t h = ggml_hash(key) % hash_set->size; + + // linear probing + size_t i = h; + do { + if (!ggml_bitset_get(hash_set->used, i)) { + ggml_bitset_set(hash_set->used, i); + hash_set->keys[i] = key; + return i; + } + if (hash_set->keys[i] == key) { + return i; + } + i = (i + 1) % hash_set->size; + } while (i != h); + + // visited all hash table entries -> not found + GGML_ABORT("fatal error"); +} + +// computation graph + +enum ggml_cgraph_eval_order { + GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0, + GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT, + GGML_CGRAPH_EVAL_ORDER_COUNT +}; + +struct ggml_cgraph { + int size; // maximum number of nodes/leafs/grads/grad_accs + int n_nodes; // number of nodes currently in use + int n_leafs; // number of leafs currently in use + + struct ggml_tensor ** nodes; // tensors with data that can change if the graph is evaluated + struct ggml_tensor ** grads; // the outputs of these tensors are the gradients of the nodes + struct ggml_tensor ** grad_accs; // accumulators for node gradients + struct ggml_tensor ** leafs; // tensors with constant data + + struct ggml_hash_set visited_hash_set; + + enum ggml_cgraph_eval_order order; +}; + +// returns a slice of cgraph with nodes [i0, i1) +// the slice does not have leafs or gradients +// if you need the gradients, get them from the original graph +struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph, int i0, int i1); + +// Memory allocation + +GGML_API void * ggml_aligned_malloc(size_t size); +GGML_API void ggml_aligned_free(void * ptr, size_t size); + +// FP16 to FP32 conversion + +#if defined(__ARM_NEON) + #if defined(_MSC_VER) || (defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) + typedef uint16_t ggml_fp16_internal_t; + #else + typedef __fp16 ggml_fp16_internal_t; + #endif +#endif + +#if defined(__ARM_NEON) && !defined(_MSC_VER) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) + #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) + #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) + + #define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) + + static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { + ggml_fp16_internal_t tmp; + memcpy(&tmp, &h, sizeof(ggml_fp16_t)); + return (float)tmp; + } + + static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { + ggml_fp16_t res; + ggml_fp16_internal_t tmp = f; + memcpy(&res, &tmp, sizeof(ggml_fp16_t)); + return res; + } + +#elif defined(__F16C__) + + #ifdef _MSC_VER + #define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x))) + #define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0) + #else + #define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x) + #define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0) + #endif + +#elif defined(__POWER9_VECTOR__) + + #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) + #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) + /* the inline asm below is about 12% faster than the lookup method */ + #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x) + #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) + + static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { + register float f; + register double d; + __asm__( + "mtfprd %0,%2\n" + "xscvhpdp %0,%0\n" + "frsp %1,%0\n" : + /* temp */ "=d"(d), + /* out */ "=f"(f): + /* in */ "r"(h)); + return f; + } + + static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { + register double d; + register ggml_fp16_t r; + __asm__( /* xscvdphp can work on double or single precision */ + "xscvdphp %0,%2\n" + "mffprd %1,%0\n" : + /* temp */ "=d"(d), + /* out */ "=r"(r): + /* in */ "f"(f)); + return r; + } #else -#define m512bh(p) (__m512bh)(p) -#define m512i(p) (__m512i)(p) + // FP16 <-> FP32 + // ref: https://github.com/Maratyszcza/FP16 + static inline float fp32_from_bits(uint32_t w) { + union { + uint32_t as_bits; + float as_value; + } fp32; + fp32.as_bits = w; + return fp32.as_value; + } + + static inline uint32_t fp32_to_bits(float f) { + union { + float as_value; + uint32_t as_bits; + } fp32; + fp32.as_value = f; + return fp32.as_bits; + } + + static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { + const uint32_t w = (uint32_t) h << 16; + const uint32_t sign = w & UINT32_C(0x80000000); + const uint32_t two_w = w + w; + + const uint32_t exp_offset = UINT32_C(0xE0) << 23; + #if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L) + const float exp_scale = 0x1.0p-112f; + #else + const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); + #endif + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); + } + + static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { + #if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L) + const float scale_to_inf = 0x1.0p+112f; + const float scale_to_zero = 0x1.0p-110f; + #else + const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); + const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); + #endif + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); + } + + #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) + #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) + +#endif // defined(__ARM_NEON) && (!defined(__MSC_VER) + +// precomputed f32 table for f16 (256 KB) +// defined in ggml.c, initialized in ggml_init() +GGML_API float ggml_table_f32_f16[1 << 16]; + +// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, +// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. +// This is also true for POWER9. +#if !defined(GGML_FP16_TO_FP32) +inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { + uint16_t s; + memcpy(&s, &f, sizeof(uint16_t)); + return ggml_table_f32_f16[s]; +} + +#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x) +#endif + +#if !defined(GGML_FP32_TO_FP16) +#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) #endif /** @@ -104,647 +554,14 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) { #define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x) #ifdef __cplusplus -extern "C" { -#endif - -// static_assert should be a #define, but if it's not, -// fall back to the _Static_assert C11 keyword. -// if C99 - static_assert is noop -// ref: https://stackoverflow.com/a/53923785/4039976 -#ifndef __cplusplus -#ifndef static_assert -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L) -#define static_assert(cond, msg) _Static_assert(cond, msg) -#else -#define static_assert(cond, msg) struct global_scope_noop_trick -#endif -#endif -#endif - -// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512 -#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)) -#ifndef __FMA__ -#define __FMA__ -#endif -#ifndef __F16C__ -#define __F16C__ -#endif -#endif - -// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available -#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)) -#ifndef __SSE3__ -#define __SSE3__ -#endif -#ifndef __SSSE3__ -#define __SSSE3__ -#endif -#endif - -#if defined(__ARM_FEATURE_SVE) -#include -#include -#endif - -// 16-bit float -// on Arm, we use __fp16 -// on x86, we use uint16_t -#if defined(__ARM_NEON) - -// if YCM cannot find , make a symbolic link to it, for example: -// -// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ -// -#include - -#ifdef _MSC_VER - -typedef uint16_t ggml_fp16_internal_t; - -#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) } - -#else - -typedef __fp16 ggml_fp16_internal_t; - -#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) } - -#endif // _MSC_VER - -#if !defined(__aarch64__) - -// 32-bit ARM compatibility - -// vaddlvq_s16 -// vpaddq_s16 -// vpaddq_s32 -// vaddvq_s32 -// vaddvq_f32 -// vmaxvq_f32 -// vcvtnq_s32_f32 -// vzip1_u8 -// vzip2_u8 - -inline static int32_t vaddlvq_s16(int16x8_t v) { - int32x4_t v0 = vreinterpretq_s32_s64(vpaddlq_s32(vpaddlq_s16(v))); - return vgetq_lane_s32(v0, 0) + vgetq_lane_s32(v0, 2); -} - -inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { - int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); - int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); - return vcombine_s16(a0, b0); -} - -inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) { - int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a)); - int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b)); - return vcombine_s32(a0, b0); -} - -inline static int32_t vaddvq_s32(int32x4_t v) { - return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); -} - -inline static float vaddvq_f32(float32x4_t v) { - return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); -} - -inline static float vmaxvq_f32(float32x4_t v) { - return - MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), - MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3))); -} - -inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { - int32x4_t res; - - res[0] = roundf(vgetq_lane_f32(v, 0)); - res[1] = roundf(vgetq_lane_f32(v, 1)); - res[2] = roundf(vgetq_lane_f32(v, 2)); - res[3] = roundf(vgetq_lane_f32(v, 3)); - - return res; -} - -inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) { - uint8x8_t res; - - res[0] = a[0]; res[1] = b[0]; - res[2] = a[1]; res[3] = b[1]; - res[4] = a[2]; res[5] = b[2]; - res[6] = a[3]; res[7] = b[3]; - - return res; -} - -inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) { - uint8x8_t res; - - res[0] = a[4]; res[1] = b[4]; - res[2] = a[5]; res[3] = b[5]; - res[4] = a[6]; res[5] = b[6]; - res[6] = a[7]; res[7] = b[7]; - - return res; -} - -// vld1q_s16_x2 -// vld1q_u8_x2 -// vld1q_u8_x4 -// vld1q_s8_x2 -// vld1q_s8_x4 -// TODO: double-check these work correctly - -typedef struct ggml_int16x8x2_t { - int16x8_t val[2]; -} ggml_int16x8x2_t; - -inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) { - ggml_int16x8x2_t res; - - res.val[0] = vld1q_s16(ptr + 0); - res.val[1] = vld1q_s16(ptr + 8); - - return res; -} - -typedef struct ggml_uint8x16x2_t { - uint8x16_t val[2]; -} ggml_uint8x16x2_t; - -inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) { - ggml_uint8x16x2_t res; - - res.val[0] = vld1q_u8(ptr + 0); - res.val[1] = vld1q_u8(ptr + 16); - - return res; -} - -typedef struct ggml_uint8x16x4_t { - uint8x16_t val[4]; -} ggml_uint8x16x4_t; - -inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) { - ggml_uint8x16x4_t res; - - res.val[0] = vld1q_u8(ptr + 0); - res.val[1] = vld1q_u8(ptr + 16); - res.val[2] = vld1q_u8(ptr + 32); - res.val[3] = vld1q_u8(ptr + 48); - - return res; -} - -typedef struct ggml_int8x16x2_t { - int8x16_t val[2]; -} ggml_int8x16x2_t; - -inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) { - ggml_int8x16x2_t res; - - res.val[0] = vld1q_s8(ptr + 0); - res.val[1] = vld1q_s8(ptr + 16); - - return res; -} - -typedef struct ggml_int8x16x4_t { - int8x16_t val[4]; -} ggml_int8x16x4_t; - -inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) { - ggml_int8x16x4_t res; - - res.val[0] = vld1q_s8(ptr + 0); - res.val[1] = vld1q_s8(ptr + 16); - res.val[2] = vld1q_s8(ptr + 32); - res.val[3] = vld1q_s8(ptr + 48); - - return res; -} - -// NOTE: not tested -inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) { - int8x16_t res; - - res[ 0] = a[b[ 0]]; - res[ 1] = a[b[ 1]]; - res[ 2] = a[b[ 2]]; - res[ 3] = a[b[ 3]]; - res[ 4] = a[b[ 4]]; - res[ 5] = a[b[ 5]]; - res[ 6] = a[b[ 6]]; - res[ 7] = a[b[ 7]]; - res[ 8] = a[b[ 8]]; - res[ 9] = a[b[ 9]]; - res[10] = a[b[10]]; - res[11] = a[b[11]]; - res[12] = a[b[12]]; - res[13] = a[b[13]]; - res[14] = a[b[14]]; - res[15] = a[b[15]]; - - return res; -} - -// NOTE: not tested -inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) { - uint8x16_t res; - - res[ 0] = a[b[ 0]]; - res[ 1] = a[b[ 1]]; - res[ 2] = a[b[ 2]]; - res[ 3] = a[b[ 3]]; - res[ 4] = a[b[ 4]]; - res[ 5] = a[b[ 5]]; - res[ 6] = a[b[ 6]]; - res[ 7] = a[b[ 7]]; - res[ 8] = a[b[ 8]]; - res[ 9] = a[b[ 9]]; - res[10] = a[b[10]]; - res[11] = a[b[11]]; - res[12] = a[b[12]]; - res[13] = a[b[13]]; - res[14] = a[b[14]]; - res[15] = a[b[15]]; - - return res; -} - -#else - -#define ggml_int16x8x2_t int16x8x2_t -#define ggml_uint8x16x2_t uint8x16x2_t -#define ggml_uint8x16x4_t uint8x16x4_t -#define ggml_int8x16x2_t int8x16x2_t -#define ggml_int8x16x4_t int8x16x4_t - -#define ggml_vld1q_s16_x2 vld1q_s16_x2 -#define ggml_vld1q_u8_x2 vld1q_u8_x2 -#define ggml_vld1q_u8_x4 vld1q_u8_x4 -#define ggml_vld1q_s8_x2 vld1q_s8_x2 -#define ggml_vld1q_s8_x4 vld1q_s8_x4 -#define ggml_vqtbl1q_s8 vqtbl1q_s8 -#define ggml_vqtbl1q_u8 vqtbl1q_u8 - -#endif // !defined(__aarch64__) - -#if !defined(__ARM_FEATURE_DOTPROD) - -inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) { - const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b)); - const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); - - return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))); -} - -#else - -#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c) - -#endif // !defined(__ARM_FEATURE_DOTPROD) - -#endif // defined(__ARM_NEON) - -#if defined(__ARM_NEON) && !defined(_MSC_VER) - -#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) -#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) - -#define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) - -static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { - ggml_fp16_internal_t tmp; - memcpy(&tmp, &h, sizeof(ggml_fp16_t)); - return (float)tmp; -} - -static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { - ggml_fp16_t res; - ggml_fp16_internal_t tmp = f; - memcpy(&res, &tmp, sizeof(ggml_fp16_t)); - return res; -} - -#else - -#ifdef __wasm_simd128__ -#include -#else -#ifdef __POWER9_VECTOR__ -#include -#undef bool -#define bool _Bool -#else -#if defined(_MSC_VER) || defined(__MINGW32__) -#include -#else -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__) -#if !defined(__riscv) -#include -#endif -#endif -#endif -#endif -#endif - -#ifdef __riscv_v_intrinsic -#include -#endif - -#if defined(__loongarch64) -#if defined(__loongarch_asx) -#include -#endif -#if defined(__loongarch_sx) -#include -#endif -#endif - -#if defined(__loongarch_asx) - -typedef union { - int32_t i; - float f; -} ft_union; - -/* float type data load instructions */ -static __m128 __lsx_vreplfr2vr_s(float val) { - ft_union fi_tmpval = {.f = val}; - return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i); -} - -static __m256 __lasx_xvreplfr2vr_s(float val) { - ft_union fi_tmpval = {.f = val}; - return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i); } #endif -#ifdef __F16C__ - -#ifdef _MSC_VER -#define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x))) -#define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0) -#else -#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x) -#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0) -#endif - -#elif defined(__POWER9_VECTOR__) - -#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) -#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) -/* the inline asm below is about 12% faster than the lookup method */ -#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x) -#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) - -static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { - register float f; - register double d; - __asm__( - "mtfprd %0,%2\n" - "xscvhpdp %0,%0\n" - "frsp %1,%0\n" : - /* temp */ "=d"(d), - /* out */ "=f"(f): - /* in */ "r"(h)); - return f; -} - -static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { - register double d; - register ggml_fp16_t r; - __asm__( /* xscvdphp can work on double or single precision */ - "xscvdphp %0,%2\n" - "mffprd %1,%0\n" : - /* temp */ "=d"(d), - /* out */ "=r"(r): - /* in */ "f"(f)); - return r; -} - -#else - -// FP16 <-> FP32 -// ref: https://github.com/Maratyszcza/FP16 - -static inline float fp32_from_bits(uint32_t w) { - union { - uint32_t as_bits; - float as_value; - } fp32; - fp32.as_bits = w; - return fp32.as_value; -} - -static inline uint32_t fp32_to_bits(float f) { - union { - float as_value; - uint32_t as_bits; - } fp32; - fp32.as_value = f; - return fp32.as_bits; -} - -static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { - const uint32_t w = (uint32_t) h << 16; - const uint32_t sign = w & UINT32_C(0x80000000); - const uint32_t two_w = w + w; - - const uint32_t exp_offset = UINT32_C(0xE0) << 23; -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) - const float exp_scale = 0x1.0p-112f; -#else - const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); -#endif - const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; - - const uint32_t magic_mask = UINT32_C(126) << 23; - const float magic_bias = 0.5f; - const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; - - const uint32_t denormalized_cutoff = UINT32_C(1) << 27; - const uint32_t result = sign | - (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); - return fp32_from_bits(result); -} - -static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) - const float scale_to_inf = 0x1.0p+112f; - const float scale_to_zero = 0x1.0p-110f; -#else - const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); - const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); -#endif - float base = (fabsf(f) * scale_to_inf) * scale_to_zero; - - const uint32_t w = fp32_to_bits(f); - const uint32_t shl1_w = w + w; - const uint32_t sign = w & UINT32_C(0x80000000); - uint32_t bias = shl1_w & UINT32_C(0xFF000000); - if (bias < UINT32_C(0x71000000)) { - bias = UINT32_C(0x71000000); - } - - base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; - const uint32_t bits = fp32_to_bits(base); - const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); - const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); - const uint32_t nonsign = exp_bits + mantissa_bits; - return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); -} - -#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x) -#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x) - -#endif // __F16C__ - -#endif // defined(__ARM_NEON) && (!defined(__MSC_VER) - -#ifdef __ARM_FEATURE_SVE -#include -#endif // __ARM_FEATURE_SVE - -// precomputed f32 table for f16 (256 KB) -// defined in ggml.c, initialized in ggml_init() -extern float ggml_table_f32_f16[1 << 16]; - -// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, -// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON. -// This is also true for POWER9. -#if !defined(GGML_FP16_TO_FP32) -inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { - uint16_t s; - memcpy(&s, &f, sizeof(uint16_t)); - return ggml_table_f32_f16[s]; -} - -#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x) -#endif - -#if !defined(GGML_FP32_TO_FP16) -#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) -#endif - -// bitset - -static_assert(sizeof(ggml_bitset_t) == 4, "bitset_t constants must be updated"); -#define BITSET_SHR 5 // log2(sizeof(ggml_bitset_t)*8) -#define BITSET_MASK (sizeof(ggml_bitset_t)*8 - 1) - -static size_t ggml_bitset_size(size_t n) { - return (n + BITSET_MASK) >> BITSET_SHR; -} - -static inline bool ggml_bitset_get(const ggml_bitset_t * bitset, size_t i) { - return !!(bitset[i >> BITSET_SHR] & (1u << (i & BITSET_MASK))); -} - -static inline void ggml_bitset_set(ggml_bitset_t * bitset, size_t i) { - bitset[i >> BITSET_SHR] |= (1u << (i & BITSET_MASK)); -} - -static inline void ggml_bitset_clear(ggml_bitset_t * bitset, size_t i) { - bitset[i >> BITSET_SHR] &= ~(1u << (i & BITSET_MASK)); -} - -// hash set - -#define GGML_HASHSET_FULL ((size_t)-1) -#define GGML_HASHSET_ALREADY_EXISTS ((size_t)-2) - -struct ggml_hash_set ggml_hash_set_new(size_t size); -void ggml_hash_set_free(struct ggml_hash_set * hash_set); - -// returns the minimum size for a hash set that can hold min_sz elements -size_t ggml_hash_size(size_t min_sz); - -// remove all elements from the hash set -void ggml_hash_set_reset(struct ggml_hash_set * hash_set); - -// returns true if key is in the hash set -static bool ggml_hash_contains(const struct ggml_hash_set * hash_set, struct ggml_tensor * key); - -// returns GGML_HASHSET_FULL if table is full, otherwise the current index of the key or where it should be inserted -static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, struct ggml_tensor * key); - -// returns GGML_HASHSET_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full -static size_t ggml_hash_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key); - -// return index, asserts if table is full -static size_t ggml_hash_find_or_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key); - -// hash function for ggml_tensor -static inline size_t ggml_hash(const struct ggml_tensor * p) { - // the last 4 bits are always zero due to alignment - return (size_t)(uintptr_t)p >> 4; -} - -static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, struct ggml_tensor * key) { - size_t h = ggml_hash(key) % hash_set->size; - - // linear probing - size_t i = h; - while (ggml_bitset_get(hash_set->used, i) && hash_set->keys[i] != key) { - i = (i + 1) % hash_set->size; - if (i == h) { - // visited all hash table entries -> not found - return GGML_HASHSET_FULL; - } - } - return i; -} - -static bool ggml_hash_contains(const struct ggml_hash_set * hash_set, struct ggml_tensor * key) { - size_t i = ggml_hash_find(hash_set, key); - return i != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, i); -} - -static size_t ggml_hash_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key) { - size_t h = ggml_hash(key) % hash_set->size; - - // linear probing - size_t i = h; - do { - if (!ggml_bitset_get(hash_set->used, i)) { - ggml_bitset_set(hash_set->used, i); - hash_set->keys[i] = key; - return i; - } - if (hash_set->keys[i] == key) { - return GGML_HASHSET_ALREADY_EXISTS; - } - i = (i + 1) % hash_set->size; - } while (i != h); - - // visited all hash table entries -> not found - GGML_ABORT("fatal error"); -} - -static size_t ggml_hash_find_or_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key) { - size_t h = ggml_hash(key) % hash_set->size; - - // linear probing - size_t i = h; - do { - if (!ggml_bitset_get(hash_set->used, i)) { - ggml_bitset_set(hash_set->used, i); - hash_set->keys[i] = key; - return i; - } - if (hash_set->keys[i] == key) { - return i; - } - i = (i + 1) % hash_set->size; - } while (i != h); - - // visited all hash table entries -> not found - GGML_ABORT("fatal error"); -} - #ifdef __cplusplus -} -#endif +#include + +// expose GGUF internals for test code +GGML_API size_t gguf_type_size(enum gguf_type type); +GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params); +GGML_API void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & buf, bool only_meta); +#endif // __cplusplus diff --git a/ggml/src/ggml-kompute/CMakeLists.txt b/ggml/src/ggml-kompute/CMakeLists.txt new file mode 100644 index 000000000..c9109d5e8 --- /dev/null +++ b/ggml/src/ggml-kompute/CMakeLists.txt @@ -0,0 +1,166 @@ + +find_package(Vulkan COMPONENTS glslc REQUIRED) +find_program(glslc_executable NAMES glslc HINTS Vulkan::glslc) + +if (NOT glslc_executable) + message(FATAL_ERROR "glslc not found") +endif() + +ggml_add_backend_library(ggml-kompute + ggml-kompute.cpp + ../../include/ggml-kompute.h + ) + +target_link_libraries(ggml-kompute PRIVATE ggml-base kompute) +target_include_directories(ggml-kompute PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) + +add_compile_definitions(VULKAN_HPP_DISPATCH_LOADER_DYNAMIC=1) + +function(compile_shader) + set(options) + set(oneValueArgs) + set(multiValueArgs SOURCES) + cmake_parse_arguments(compile_shader "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + foreach(source ${compile_shader_SOURCES}) + get_filename_component(filename ${source} NAME) + set(spv_file ${filename}.spv) + add_custom_command( + OUTPUT ${spv_file} + DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${source} + ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/common.comp + ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_getrows.comp + ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n_pre.comp + ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n.comp + COMMAND ${glslc_executable} --target-env=vulkan1.2 -o ${spv_file} ${CMAKE_CURRENT_SOURCE_DIR}/${source} + COMMENT "Compiling ${source} to ${spv_file}" + ) + + get_filename_component(RAW_FILE_NAME ${spv_file} NAME) + set(FILE_NAME "shader${RAW_FILE_NAME}") + string(REPLACE ".comp.spv" ".h" HEADER_FILE ${FILE_NAME}) + string(TOUPPER ${HEADER_FILE} HEADER_FILE_DEFINE) + string(REPLACE "." "_" HEADER_FILE_DEFINE "${HEADER_FILE_DEFINE}") + set(OUTPUT_HEADER_FILE "${HEADER_FILE}") + message(STATUS "${HEADER_FILE} generating ${HEADER_FILE_DEFINE}") + if(CMAKE_GENERATOR MATCHES "Visual Studio") + add_custom_command( + OUTPUT ${OUTPUT_HEADER_FILE} + COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE} + COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE} + COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE} + COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE} + COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE} + COMMAND ${CMAKE_BINARY_DIR}/bin/$/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE} + COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE} + COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE} + DEPENDS ${spv_file} xxd + COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/$/xxd" + ) + else() + add_custom_command( + OUTPUT ${OUTPUT_HEADER_FILE} + COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE} + COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE} + COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE} + COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE} + COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE} + COMMAND ${CMAKE_BINARY_DIR}/bin/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE} + COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE} + COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE} + DEPENDS ${spv_file} xxd + COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/xxd" + ) + endif() + endforeach() +endfunction() + +if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt") + message(STATUS "Kompute found") + set(KOMPUTE_OPT_LOG_LEVEL Error CACHE STRING "Kompute log level") + add_subdirectory(kompute) + + # Compile our shaders + compile_shader(SOURCES + kompute-shaders/op_scale.comp + kompute-shaders/op_scale_8.comp + kompute-shaders/op_add.comp + kompute-shaders/op_addrow.comp + kompute-shaders/op_mul.comp + kompute-shaders/op_silu.comp + kompute-shaders/op_relu.comp + kompute-shaders/op_gelu.comp + kompute-shaders/op_softmax.comp + kompute-shaders/op_norm.comp + kompute-shaders/op_rmsnorm.comp + kompute-shaders/op_diagmask.comp + kompute-shaders/op_mul_mat_mat_f32.comp + kompute-shaders/op_mul_mat_f16.comp + kompute-shaders/op_mul_mat_q8_0.comp + kompute-shaders/op_mul_mat_q4_0.comp + kompute-shaders/op_mul_mat_q4_1.comp + kompute-shaders/op_mul_mat_q4_k.comp + kompute-shaders/op_mul_mat_q6_k.comp + kompute-shaders/op_getrows_f32.comp + kompute-shaders/op_getrows_f16.comp + kompute-shaders/op_getrows_q4_0.comp + kompute-shaders/op_getrows_q4_1.comp + kompute-shaders/op_getrows_q6_k.comp + kompute-shaders/op_rope_norm_f16.comp + kompute-shaders/op_rope_norm_f32.comp + kompute-shaders/op_rope_neox_f16.comp + kompute-shaders/op_rope_neox_f32.comp + kompute-shaders/op_cpy_f16_f16.comp + kompute-shaders/op_cpy_f16_f32.comp + kompute-shaders/op_cpy_f32_f16.comp + kompute-shaders/op_cpy_f32_f32.comp + ) + + # Create a custom target for our generated shaders + add_custom_target(generated_shaders DEPENDS + shaderop_scale.h + shaderop_scale_8.h + shaderop_add.h + shaderop_addrow.h + shaderop_mul.h + shaderop_silu.h + shaderop_relu.h + shaderop_gelu.h + shaderop_softmax.h + shaderop_norm.h + shaderop_rmsnorm.h + shaderop_diagmask.h + shaderop_mul_mat_mat_f32.h + shaderop_mul_mat_f16.h + shaderop_mul_mat_q8_0.h + shaderop_mul_mat_q4_0.h + shaderop_mul_mat_q4_1.h + shaderop_mul_mat_q4_k.h + shaderop_mul_mat_q6_k.h + shaderop_getrows_f32.h + shaderop_getrows_f16.h + shaderop_getrows_q4_0.h + shaderop_getrows_q4_1.h + shaderop_getrows_q6_k.h + shaderop_rope_norm_f16.h + shaderop_rope_norm_f32.h + shaderop_rope_neox_f16.h + shaderop_rope_neox_f32.h + shaderop_cpy_f16_f16.h + shaderop_cpy_f16_f32.h + shaderop_cpy_f32_f16.h + shaderop_cpy_f32_f32.h + ) + + # Create a custom command that depends on the generated_shaders + add_custom_command( + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp + COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp + DEPENDS generated_shaders + COMMENT "Ensuring shaders are generated before compiling ggml-kompute.cpp" + ) + + # Add the stamp to the main sources to ensure dependency tracking + target_sources(ggml-kompute PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp) +else() + message(WARNING "Kompute not found") +endif() diff --git a/ggml/src/ggml-kompute.cpp b/ggml/src/ggml-kompute/ggml-kompute.cpp similarity index 82% rename from ggml/src/ggml-kompute.cpp rename to ggml/src/ggml-kompute/ggml-kompute.cpp index 41ac63fa4..505792271 100644 --- a/ggml/src/ggml-kompute.cpp +++ b/ggml/src/ggml-kompute/ggml-kompute.cpp @@ -1,4 +1,4 @@ -#include "ggml.h" +#include "ggml-impl.h" #include "ggml-backend.h" #include "ggml-backend-impl.h" #include "ggml-kompute.h" @@ -20,6 +20,7 @@ #include "shaderop_mul_mat_q8_0.h" #include "shaderop_mul_mat_q4_0.h" #include "shaderop_mul_mat_q4_1.h" +#include "shaderop_mul_mat_q4_k.h" #include "shaderop_mul_mat_q6_k.h" #include "shaderop_mul_mat_mat_f32.h" #include "shaderop_getrows_f32.h" @@ -27,8 +28,10 @@ #include "shaderop_getrows_q4_0.h" #include "shaderop_getrows_q4_1.h" #include "shaderop_getrows_q6_k.h" -#include "shaderop_rope_f16.h" -#include "shaderop_rope_f32.h" +#include "shaderop_rope_norm_f16.h" +#include "shaderop_rope_norm_f32.h" +#include "shaderop_rope_neox_f16.h" +#include "shaderop_rope_neox_f32.h" #include "shaderop_cpy_f16_f16.h" #include "shaderop_cpy_f16_f32.h" #include "shaderop_cpy_f32_f16.h" @@ -42,6 +45,7 @@ #include #include #include +#include #include #include #include @@ -273,18 +277,9 @@ static std::vector ggml_vk_available_devices_internal(size_t mem return results; } -// public API returns a C-style array -ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count) { - auto devices = ggml_vk_available_devices_internal(memoryRequired); - *count = devices.size(); - if (devices.empty()) { - return nullptr; - } - - size_t nbytes = sizeof (ggml_vk_device) * (devices.size()); - auto * arr = static_cast(malloc(nbytes)); - memcpy(arr, devices.data(), nbytes); - return arr; +static std::vector& ggml_vk_available_devices() { + static std::vector devices = ggml_vk_available_devices_internal(0); + return devices; } static void ggml_vk_filterByVendor(std::vector& devices, const std::string& targetVendor) { @@ -341,7 +336,7 @@ ggml_vk_device ggml_vk_current_device() { if (!komputeManager()->hasDevice()) return ggml_vk_device(); - auto devices = ggml_vk_available_devices_internal(0); + auto devices = ggml_vk_available_devices(); ggml_vk_filterByName(devices, komputeManager()->physicalDevice()->getProperties().deviceName.data()); GGML_ASSERT(!devices.empty()); return devices.front(); @@ -352,7 +347,7 @@ void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t std::vector descriptorPoolSizes = { vk::DescriptorPoolSize( vk::DescriptorType::eStorageBuffer, - 3 * size // Descriptor count is number of possible tensors to pass into an algorithm + 4 * size // Descriptor count is number of possible tensors to pass into an algorithm ) }; @@ -795,7 +790,8 @@ static void ggml_vk_soft_max( const std::shared_ptr& out, uint32_t inAOff, uint32_t inBOff, uint32_t outOff, int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03, - float scale + float scale, float max_bias, float m0, float m1, + uint32_t n_head_log2 ) { const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv, kp::shader_data::op_softmax_comp_spv_len); @@ -803,12 +799,14 @@ static void ggml_vk_soft_max( struct PushConstants { uint32_t inAOff, inBOff, outOff; int32_t ne00, ne01, ne02; - float scale; + float scale, max_bias, m0, m1; + uint32_t n_head_log2; int32_t mask; } pushConsts { safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4), ne00, ne01, ne02, - scale, + scale, max_bias, m0, m1, + n_head_log2, bool(inB) }; @@ -918,9 +916,9 @@ static void ggml_vk_mul_mat_f16( const std::shared_ptr& out, uint32_t inAOff, uint32_t inBOff, uint32_t outOff, int32_t ne00, int32_t ne01, int32_t ne02, - uint32_t nb00, uint32_t nb01, uint32_t nb02, + uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03, int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13, - uint32_t nb10, uint32_t nb11, uint32_t nb12, + uint32_t nb10, uint32_t nb11, uint32_t nb12, uint32_t nb13, int32_t ne0, int32_t ne1, uint32_t r2, uint32_t r3 ) { @@ -930,17 +928,17 @@ static void ggml_vk_mul_mat_f16( struct PushConstants { uint32_t inAOff, inBOff, outOff; int32_t ne00, ne01, ne02; - uint32_t nb00, nb01, nb02; + uint32_t nb00, nb01, nb02, nb03; int32_t ne10, ne11, ne12; - uint32_t nb10, nb11, nb12; + uint32_t nb10, nb11, nb12, nb13; int32_t ne0, ne1; uint32_t r2, r3; } pushConsts { safe_divide(inAOff, 2), safe_divide(inBOff, 4), safe_divide(outOff, 4), ne00, ne01, ne02, - nb00, nb01, nb02, + nb00, nb01, nb02, nb03, ne10, ne11, ne12, - nb10, nb11, nb12, + nb10, nb11, nb12, nb13, ne0, ne1, r2, r3 }; @@ -1020,6 +1018,8 @@ static void ggml_vk_mul_mat_impl( int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0, int32_t ne1, + uint32_t nb01, uint32_t nb02, uint32_t nb03, + uint32_t nb11, uint32_t nb12, uint32_t nb13, uint32_t r2, uint32_t r3 ) { struct PushConstants { @@ -1027,19 +1027,23 @@ static void ggml_vk_mul_mat_impl( int32_t ne00, ne01, ne02; int32_t ne10, ne12; int32_t ne0, ne1; + uint32_t nb01, nb02, nb03; + uint32_t nb11, nb12, nb13; uint32_t r2, r3; } pushConsts { safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4), ne00, ne01, ne02, ne10, ne12, ne0, ne1, + nb01, nb02, nb03, + nb11, nb12, nb13, r2, r3 }; auto name = std::string(__func__) + "_" + suffix; std::shared_ptr s_algo = nullptr; if (!komputeManager()->hasAlgorithm(name)) { - const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2; + const uint32_t local_x = (ggml_vk_current_device().subgroupSize * 2) / 8; s_algo = komputeManager()->algorithm(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts}); } else { s_algo = komputeManager()->getAlgorithm(name); @@ -1075,34 +1079,84 @@ static void ggml_vk_mul_mat_q8_0(Args&&... args) { ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward(args)...); } +static void ggml_vk_mul_mat_q4_k( + kp::Sequence& seq, + const std::shared_ptr& inA, + const std::shared_ptr& inB, + const std::shared_ptr& out, + uint32_t inAOff, uint32_t inBOff, uint32_t outOff, + int32_t ne00, int32_t ne01, int32_t ne02, + int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13, + int32_t ne0, int32_t ne1, + uint32_t nb01, uint32_t nb02, uint32_t nb03, + uint32_t nb11, uint32_t nb12, uint32_t nb13, + uint32_t r2, uint32_t r3 +) { + const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv, + kp::shader_data::op_mul_mat_q4_k_comp_spv_len); + + struct PushConstants { + uint32_t inAOff, inBOff, outOff; + int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12; + uint32_t nb01, nb02, nb03, nb11, nb12, nb13; + uint32_t r2, r3; + } pushConsts { + inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4), + ne00, ne10, ne0, ne1, ne01, ne02, ne12, + nb01, nb02, nb03, nb11, nb12, nb13, + r2, r3 + }; + + std::shared_ptr s_algo = nullptr; + if (!komputeManager()->hasAlgorithm(__func__)) { + s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts}); + } else { + s_algo = komputeManager()->getAlgorithm(__func__); + s_algo->setTensors({inA, inB, out}); + s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}); + s_algo->setPushConstants({pushConsts}); + s_algo->updateDescriptors(s_kompute_context->pool.get()); + } + seq.record(s_algo); +} + static void ggml_vk_mul_mat_q6_k( kp::Sequence& seq, const std::shared_ptr& inA, const std::shared_ptr& inB, const std::shared_ptr& out, uint32_t inAOff, uint32_t inBOff, uint32_t outOff, - int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1, - int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02 + int32_t ne00, int32_t ne01, int32_t ne02, + int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13, + int32_t ne0, int32_t ne1, + uint32_t nb01, uint32_t nb02, uint32_t nb03, + uint32_t nb11, uint32_t nb12, uint32_t nb13, + uint32_t r2, uint32_t r3 ) { const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv, kp::shader_data::op_mul_mat_q6_k_comp_spv_len); struct PushConstants { uint32_t inAOff, inBOff, outOff; - int32_t ne00, ne10, ne0, ne1, ne01, gqa; + int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12; + uint32_t nb01, nb02, nb03, nb11, nb12, nb13; + uint32_t r2, r3; } pushConsts { inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4), - ne00, ne10, ne0, ne1, ne01, ne12/ne02 + ne00, ne10, ne0, ne1, ne01, ne02, ne12, + nb01, nb02, nb03, nb11, nb12, nb13, + r2, r3 }; std::shared_ptr s_algo = nullptr; if (!komputeManager()->hasAlgorithm(__func__)) { - const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2; - s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts}); + const uint32_t local_x = 2; + const uint32_t local_y = ggml_vk_current_device().subgroupSize; + s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)}, {local_x, local_y}, {pushConsts}); } else { s_algo = komputeManager()->getAlgorithm(__func__); s_algo->setTensors({inA, inB, out}); - s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}); + s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)}); s_algo->setPushConstants({pushConsts}); s_algo->updateDescriptors(s_kompute_context->pool.get()); } @@ -1190,10 +1244,11 @@ static void ggml_vk_rope( kp::Sequence& seq, const std::shared_ptr& inA, const std::shared_ptr& inB, + const std::shared_ptr& inC, const std::shared_ptr& out, - uint32_t inAOff, uint32_t inBOff, uint32_t outOff, + uint32_t inAOff, uint32_t inBOff, uint32_t inCOff, uint32_t outOff, ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig, - float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow, + float freq_base, float freq_scale, bool has_freq_factors, float ext_factor, float attn_factor, float beta_fast, float beta_slow, int32_t ne01, int32_t ne02, int32_t ne03, uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03, int32_t ne0, @@ -1201,11 +1256,17 @@ static void ggml_vk_rope( ) { GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32); - static const auto spirv_f16 = getSpirvShader( - kp::shader_data::op_rope_f16_comp_spv, kp::shader_data::op_rope_f16_comp_spv_len + static const auto spirv_norm_f16 = getSpirvShader( + kp::shader_data::op_rope_norm_f16_comp_spv, kp::shader_data::op_rope_norm_f16_comp_spv_len ); - static const auto spirv_f32 = getSpirvShader( - kp::shader_data::op_rope_f32_comp_spv, kp::shader_data::op_rope_f32_comp_spv_len + static const auto spirv_norm_f32 = getSpirvShader( + kp::shader_data::op_rope_norm_f32_comp_spv, kp::shader_data::op_rope_norm_f32_comp_spv_len + ); + static const auto spirv_neox_f16 = getSpirvShader( + kp::shader_data::op_rope_neox_f16_comp_spv, kp::shader_data::op_rope_neox_f16_comp_spv_len + ); + static const auto spirv_neox_f32 = getSpirvShader( + kp::shader_data::op_rope_neox_f32_comp_spv, kp::shader_data::op_rope_neox_f32_comp_spv_len ); int type_size = src0t == GGML_TYPE_F16 ? 2 : 4; @@ -1220,32 +1281,40 @@ static void ggml_vk_rope( GGML_ASSERT(nb0 % type_size == 0); struct PushConstants { - uint32_t inAOff, inBOff, outOff; + uint32_t inAOff, inBOff, inCOff, outOff; int32_t n_dims, mode, n_ctx_orig; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + float freq_base, freq_scale; + bool has_freq_factors; + float ext_factor, attn_factor, beta_fast, beta_slow; uint32_t nb00, nb01, nb02, nb03; int32_t ne0; uint32_t nb0, nb1, nb2, nb3; } pushConsts { - safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size), + safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(inCOff, type_size), safe_divide(outOff, type_size), n_dims, mode, n_ctx_orig, - freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, + freq_base, freq_scale, + has_freq_factors, + ext_factor, attn_factor, beta_fast, beta_slow, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3 }; - auto name = std::string(__func__) + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32"); + auto & inC_ = inC ? inC : inA; + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_f16 = src0t == GGML_TYPE_F16; + + auto name = std::string(__func__) + (is_neox ? "_neox" : "_norm") + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32"); std::shared_ptr s_algo = nullptr; if (!komputeManager()->hasAlgorithm(name)) { + auto & spirv = is_neox ? is_f16 ? spirv_neox_f16 : spirv_neox_f32 : is_f16 ? spirv_norm_f16 : spirv_norm_f32; s_algo = komputeManager()->algorithm( - name, s_kompute_context->pool.get(), {inA, inB, out}, - src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32, + name, s_kompute_context->pool.get(), {inA, inB, inC_, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts} ); } else { s_algo = komputeManager()->getAlgorithm(name); - s_algo->setTensors({inA, inB, out}); + s_algo->setTensors({inA, inB, inC_, out}); s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)}); s_algo->setPushConstants({pushConsts}); s_algo->updateDescriptors(s_kompute_context->pool.get()); @@ -1323,22 +1392,16 @@ static void ggml_vk_cpy_f16_f32(Args&&... args) { ggml_vk_cpy(spirv, 2, 4, std::forward(args)...); } -static bool ggml_vk_supports_op(const struct ggml_tensor * op) { - switch (op->type) { - case GGML_TYPE_F16: - case GGML_TYPE_F32: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - break; - default: - return false; - } - +static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + int64_t n = ggml_nelements(op); switch (op->op) { case GGML_OP_UNARY: + if (n % 4 != 0) return false; switch (ggml_get_unary_op(op)) { - case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_GELU: + if (n % 8 != 0) return false; + // fall through + case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_SILU: return ggml_is_contiguous(op->src[0]); default: @@ -1356,8 +1419,18 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) { case GGML_OP_SOFT_MAX: case GGML_OP_RMS_NORM: case GGML_OP_NORM: - case GGML_OP_ROPE: return true; + case GGML_OP_ROPE: + { + const int mode = ((const int32_t *) op->op_params)[2]; + if (mode & GGML_ROPE_TYPE_MROPE) { + return false; + } + if (mode & GGML_ROPE_TYPE_VISION) { + return false; + } + return true; + } case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: @@ -1396,12 +1469,13 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) { switch (op->src[0]->type) { case GGML_TYPE_F32: - case GGML_TYPE_Q6_K: return op->ne[3] == 1; + case GGML_TYPE_Q6_K: case GGML_TYPE_F16: case GGML_TYPE_Q8_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: return true; default: ; @@ -1410,6 +1484,8 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) { ; } return false; + + GGML_UNUSED(dev); } static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) { @@ -1458,11 +1534,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml any_commands_recorded = true; - if (!ggml_vk_supports_op(dst)) { - fprintf(stderr, "%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst)); - GGML_ABORT("unsupported op"); - } - const int32_t ne00 = src0 ? src0->ne[0] : 0; const int32_t ne01 = src0 ? src0->ne[1] : 0; const int32_t ne02 = src0 ? src0->ne[2] : 0; @@ -1500,9 +1571,11 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml const static std::shared_ptr nullTensor = nullptr; uint32_t off_src0 = 0; uint32_t off_src1 = 0; + uint32_t off_src2 = 0; uint32_t off_dst = 0; const std::shared_ptr& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor; const std::shared_ptr& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor; + const std::shared_ptr& id_src2 = src2 ? ggml_vk_get_tensor(src2, &off_src2) : nullTensor; const std::shared_ptr& id_dst = dst ? ggml_vk_get_tensor(dst, &off_dst) : nullTensor; switch (dst->op) { @@ -1578,11 +1651,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32); -#pragma message("TODO: add ALiBi support") -#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192") - GGML_ASSERT(max_bias == 0.0f); + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src0->ne[1]; - ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale); + const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale, max_bias, m0, m1, n_head_log2); } break; case GGML_OP_DIAG_MASK_INF: { @@ -1634,32 +1712,44 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml case GGML_TYPE_F16: ggml_vk_mul_mat_f16( seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, - ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, ne13, nb10, nb11, nb12, + ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, ne0, ne1, r2, r3 ); break; case GGML_TYPE_Q8_0: ggml_vk_mul_mat_q8_0( seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, - ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3 + ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, + nb01, nb02, nb03, nb11, nb12, nb13, r2, r3 ); break; case GGML_TYPE_Q4_0: ggml_vk_mul_mat_q4_0( seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, - ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3 + ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, + nb01, nb02, nb03, nb11, nb12, nb13, r2, r3 ); break; case GGML_TYPE_Q4_1: ggml_vk_mul_mat_q4_1( seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, - ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3 + ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, + nb01, nb02, nb03, nb11, nb12, nb13, r2, r3 + ); + break; + case GGML_TYPE_Q4_K: + ggml_vk_mul_mat_q4_k( + seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, + ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, + nb01, nb02, nb03, nb11, nb12, nb13, r2, r3 ); break; case GGML_TYPE_Q6_K: ggml_vk_mul_mat_q6_k( seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, - ne00, ne10, ne0, ne1, ne01, ne11, ne12, ne02 + ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, + nb01, nb02, nb03, nb11, nb12, nb13, r2, r3 ); break; default: { @@ -1688,13 +1778,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml } break; case GGML_OP_ROPE: { -#pragma message("TODO: implement phi3 frequency factors support") -#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7225") - GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet"); - -#pragma message("TODO: update rope NORM mode to match NEOX mode") -#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7634") - GGML_ASSERT(ne10 == ne02); GGML_ASSERT(src0t == dstt); // const int n_past = ((int32_t *) dst->op_params)[0]; @@ -1703,6 +1786,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + const bool has_freq_factors = dst->src[2] != nullptr; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); @@ -1711,8 +1796,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); ggml_vk_rope( - seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig, - freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, + seq, id_src0, id_src1, id_src2, id_dst, off_src0, off_src1, off_src2, off_dst, src0t, n_dims, mode, n_ctx_orig, + freq_base, freq_scale, has_freq_factors, ext_factor, attn_factor, beta_fast, beta_slow, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3 ); } break; @@ -1820,11 +1905,6 @@ static void ggml_backend_kompute_device_unref(ggml_backend_buffer_type_t buft) { } } -static const char * ggml_backend_kompute_buffer_get_name(ggml_backend_buffer_t buffer) { - auto * ctx = static_cast(buffer->buft->context); - return ctx->name.c_str(); -} - static void ggml_backend_kompute_buffer_free_buffer(ggml_backend_buffer_t buffer) { auto * memory = (ggml_vk_memory *)buffer->context; if (ggml_vk_has_device()) { @@ -1868,10 +1948,10 @@ static void ggml_backend_kompute_buffer_clear(ggml_backend_buffer_t buffer, uint } static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = { - /* .get_name = */ ggml_backend_kompute_buffer_get_name, /* .free_buffer = */ ggml_backend_kompute_buffer_free_buffer, /* .get_base = */ ggml_backend_kompute_buffer_get_base, /* .init_tensor = */ NULL, + /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_kompute_buffer_set_tensor, /* .get_tensor = */ ggml_backend_kompute_buffer_get_tensor, /* .cpy_tensor = */ NULL, @@ -1912,24 +1992,31 @@ static ggml_backend_buffer_type_i ggml_backend_kompute_buffer_type_interface = { }; ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) { - static std::vector bufts = []() { - std::vector vec; - auto devices = ggml_vk_available_devices_internal(0); - vec.reserve(devices.size()); + static std::mutex mutex; + std::lock_guard lock(mutex); - for (const auto & dev : devices) { - vec.push_back({ - /* .iface = */ ggml_backend_kompute_buffer_type_interface, - /* .context = */ new ggml_backend_kompute_buffer_type_context(dev.index, dev.bufferAlignment, dev.maxAlloc) - }); + auto devices = ggml_vk_available_devices(); + int32_t device_count = (int32_t) devices.size(); + GGML_ASSERT(device < device_count); + GGML_ASSERT(devices.size() <= GGML_KOMPUTE_MAX_DEVICES); + + static ggml_backend_buffer_type + ggml_backend_kompute_buffer_types[GGML_KOMPUTE_MAX_DEVICES]; + + static bool ggml_backend_kompute_buffer_type_initialized = false; + + if (!ggml_backend_kompute_buffer_type_initialized) { + for (int32_t i = 0; i < device_count; i++) { + ggml_backend_kompute_buffer_types[i] = { + /* .iface = */ ggml_backend_kompute_buffer_type_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), i), + /* .context = */ new ggml_backend_kompute_buffer_type_context{ i, devices[i].bufferAlignment, devices[i].maxAlloc }, + }; } - return vec; - }(); + ggml_backend_kompute_buffer_type_initialized = true; + } - auto it = std::find_if(bufts.begin(), bufts.end(), [device](const ggml_backend_buffer_type & t) { - return device == static_cast(t.context)->device; - }); - return it < bufts.end() ? &*it : nullptr; + return &ggml_backend_kompute_buffer_types[device]; } // backend @@ -1951,31 +2038,15 @@ static void ggml_backend_kompute_free(ggml_backend_t backend) { delete backend; } -static ggml_backend_buffer_type_t ggml_backend_kompute_get_default_buffer_type(ggml_backend_t backend) { - auto * ctx = static_cast(backend->context); - return ggml_backend_kompute_buffer_type(ctx->device); -} - static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { auto * ctx = static_cast(backend->context); ggml_vk_graph_compute(ctx, cgraph); return GGML_STATUS_SUCCESS; } -static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { - GGML_UNUSED(backend); - return ggml_vk_supports_op(op); -} - -static bool ggml_backend_kompute_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { - GGML_UNUSED(backend); - return buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name; -} - static struct ggml_backend_i kompute_backend_i = { /* .get_name = */ ggml_backend_kompute_name, /* .free = */ ggml_backend_kompute_free, - /* .get_default_buffer_type = */ ggml_backend_kompute_get_default_buffer_type, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, /* .cpy_tensor_async = */ NULL, @@ -1985,14 +2056,8 @@ static struct ggml_backend_i kompute_backend_i = { /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_kompute_graph_compute, - /* .supports_op = */ ggml_backend_kompute_supports_op, - /* .supports_buft = */ ggml_backend_kompute_supports_buft, - /* .offload_op = */ NULL, - /* .event_new = */ NULL, - /* .event_free = */ NULL, /* .event_record = */ NULL, /* .event_wait = */ NULL, - /* .event_synchronize = */ NULL, }; static ggml_guid_t ggml_backend_kompute_guid() { @@ -2007,6 +2072,7 @@ ggml_backend_t ggml_backend_kompute_init(int device) { ggml_backend_t kompute_backend = new ggml_backend { /* .guid = */ ggml_backend_kompute_guid(), /* .interface = */ kompute_backend_i, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), device), /* .context = */ s_kompute_context, }; @@ -2017,22 +2083,169 @@ bool ggml_backend_is_kompute(ggml_backend_t backend) { return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid()); } -static ggml_backend_t ggml_backend_reg_kompute_init(const char * params, void * user_data) { - GGML_UNUSED(params); - return ggml_backend_kompute_init(intptr_t(user_data)); -} - -extern "C" int ggml_backend_kompute_reg_devices(); - -int ggml_backend_kompute_reg_devices() { - auto devices = ggml_vk_available_devices_internal(0); - for (const auto & device : devices) { - ggml_backend_register( - ggml_kompute_format_name(device.index).c_str(), - ggml_backend_reg_kompute_init, - ggml_backend_kompute_buffer_type(device.index), - reinterpret_cast(intptr_t(device.index)) - ); - } +static size_t ggml_backend_kompute_get_device_count() { + auto devices = ggml_vk_available_devices(); return devices.size(); } + +static void ggml_backend_kompute_get_device_description(int device, char * description, size_t description_size) { + auto devices = ggml_vk_available_devices(); + GGML_ASSERT((size_t) device < devices.size()); + snprintf(description, description_size, "%s", devices[device].name); +} + +static void ggml_backend_kompute_get_device_memory(int device, size_t * free, size_t * total) { + auto devices = ggml_vk_available_devices(); + GGML_ASSERT((size_t) device < devices.size()); + *total = devices[device].heapSize; + *free = devices[device].heapSize; +} + +////////////////////////// + +struct ggml_backend_kompute_device_context { + int device; + std::string name; + std::string description; +}; + +static const char * ggml_backend_kompute_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context; + return ctx->name.c_str(); +} + +static const char * ggml_backend_kompute_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context; + return ctx->description.c_str(); +} + +static void ggml_backend_kompute_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context; + ggml_backend_kompute_get_device_memory(ctx->device, free, total); +} + +static ggml_backend_buffer_type_t ggml_backend_kompute_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context; + return ggml_backend_kompute_buffer_type(ctx->device); +} + +static bool ggml_backend_kompute_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + if (buft->iface.get_name != ggml_backend_kompute_buffer_type_get_name) { + return false; + } + + ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context; + ggml_backend_kompute_buffer_type_context * buft_ctx = (ggml_backend_kompute_buffer_type_context *)buft->context; + + return buft_ctx->device == ctx->device; +} + +static enum ggml_backend_dev_type ggml_backend_kompute_device_get_type(ggml_backend_dev_t dev) { + GGML_UNUSED(dev); + return GGML_BACKEND_DEVICE_TYPE_GPU; +} + +static void ggml_backend_kompute_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_kompute_device_get_name(dev); + props->description = ggml_backend_kompute_device_get_description(dev); + props->type = ggml_backend_kompute_device_get_type(dev); + ggml_backend_kompute_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* async = */ false, + /* host_buffer = */ false, + /* .buffer_from_host_ptr = */ false, + /* events = */ false, + }; +} + +static ggml_backend_t ggml_backend_kompute_device_init(ggml_backend_dev_t dev, const char * params) { + GGML_UNUSED(params); + ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context; + return ggml_backend_kompute_init(ctx->device); +} + +static bool ggml_backend_kompute_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + const int min_batch_size = 32; + + return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) || + (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID); + + GGML_UNUSED(dev); +} + +static const struct ggml_backend_device_i ggml_backend_kompute_device_i = { + /* .get_name = */ ggml_backend_kompute_device_get_name, + /* .get_description = */ ggml_backend_kompute_device_get_description, + /* .get_memory = */ ggml_backend_kompute_device_get_memory, + /* .get_type = */ ggml_backend_kompute_device_get_type, + /* .get_props = */ ggml_backend_kompute_device_get_props, + /* .init_backend = */ ggml_backend_kompute_device_init, + /* .get_buffer_type = */ ggml_backend_kompute_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ NULL, + /* .supports_op = */ ggml_backend_kompute_device_supports_op, + /* .supports_buft = */ ggml_backend_kompute_device_supports_buft, + /* .offload_op = */ ggml_backend_kompute_device_offload_op, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +static const char * ggml_backend_kompute_reg_get_name(ggml_backend_reg_t reg) { + GGML_UNUSED(reg); + return "Kompute"; +} + +static size_t ggml_backend_kompute_reg_get_device_count(ggml_backend_reg_t reg) { + GGML_UNUSED(reg); + return ggml_backend_kompute_get_device_count(); +} + +static ggml_backend_dev_t ggml_backend_kompute_reg_get_device(ggml_backend_reg_t reg, size_t device) { + static std::vector devices; + + static bool initialized = false; + + { + static std::mutex mutex; + std::lock_guard lock(mutex); + if (!initialized) { + for (size_t i = 0; i < ggml_backend_kompute_get_device_count(); i++) { + ggml_backend_kompute_device_context * ctx = new ggml_backend_kompute_device_context; + char desc[256]; + ggml_backend_kompute_get_device_description(i, desc, sizeof(desc)); + ctx->device = i; + ctx->name = "Kompute" + std::to_string(i); + ctx->description = desc; + devices.push_back(new ggml_backend_device { + /* .iface = */ ggml_backend_kompute_device_i, + /* .reg = */ reg, + /* .context = */ ctx, + }); + } + initialized = true; + } + } + + GGML_ASSERT(device < devices.size()); + return devices[device]; +} + +static const struct ggml_backend_reg_i ggml_backend_kompute_reg_i = { + /* .get_name = */ ggml_backend_kompute_reg_get_name, + /* .get_device_count = */ ggml_backend_kompute_reg_get_device_count, + /* .get_device = */ ggml_backend_kompute_reg_get_device, + /* .get_proc_address = */ NULL, +}; + +ggml_backend_reg_t ggml_backend_kompute_reg() { + static ggml_backend_reg reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_kompute_reg_i, + /* .context = */ nullptr, + }; + + return ® +} + +GGML_BACKEND_DL_IMPL(ggml_backend_kompute_reg) diff --git a/ggml/src/kompute b/ggml/src/ggml-kompute/kompute similarity index 100% rename from ggml/src/kompute rename to ggml/src/ggml-kompute/kompute diff --git a/ggml/src/kompute-shaders/common.comp b/ggml/src/ggml-kompute/kompute-shaders/common.comp similarity index 93% rename from ggml/src/kompute-shaders/common.comp rename to ggml/src/ggml-kompute/kompute-shaders/common.comp index 62d62b025..dbe4cf804 100644 --- a/ggml/src/kompute-shaders/common.comp +++ b/ggml/src/ggml-kompute/kompute-shaders/common.comp @@ -3,6 +3,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_float16: require #extension GL_EXT_shader_explicit_arithmetic_types_int8: require #extension GL_EXT_shader_explicit_arithmetic_types_int16: require +#extension GL_EXT_shader_explicit_arithmetic_types_int64: require #extension GL_EXT_control_flow_attributes: enable #extension GL_KHR_shader_subgroup_arithmetic : require #extension GL_EXT_debug_printf : enable @@ -15,6 +16,7 @@ #define TWOPI_F 6.283185307179586f #define QK_K 256 +#define K_SCALE_SIZE 12 #define u8BufToU16(buf, idx) (((uint16_t(buf[idx + 1]) << 8)) | buf[idx]) #define u8BufToFloat16(buf, idx) uint16BitsToHalf u8BufToU16(buf, idx) @@ -64,6 +66,14 @@ mat4 dequantize_q4_1(const block_q4_1 xb, uint il) { return reg; } +#define sizeof_block_q4_k 144 +struct block_q4_k { + float16_t d; + float16_t dmin; + uint8_t scales[K_SCALE_SIZE]; + uint8_t qs[QK_K/2]; +}; + #define sizeof_block_q6_k 210 struct block_q6_k { uint8_t ql[QK_K/2]; // quants, lower 4 bits diff --git a/ggml/src/kompute-shaders/op_add.comp b/ggml/src/ggml-kompute/kompute-shaders/op_add.comp similarity index 100% rename from ggml/src/kompute-shaders/op_add.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_add.comp diff --git a/ggml/src/kompute-shaders/op_addrow.comp b/ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp similarity index 100% rename from ggml/src/kompute-shaders/op_addrow.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_addrow.comp diff --git a/ggml/src/kompute-shaders/op_cpy_f16_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp similarity index 100% rename from ggml/src/kompute-shaders/op_cpy_f16_f16.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f16.comp diff --git a/ggml/src/kompute-shaders/op_cpy_f16_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp similarity index 100% rename from ggml/src/kompute-shaders/op_cpy_f16_f32.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_cpy_f16_f32.comp diff --git a/ggml/src/kompute-shaders/op_cpy_f32_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp similarity index 100% rename from ggml/src/kompute-shaders/op_cpy_f32_f16.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f16.comp diff --git a/ggml/src/kompute-shaders/op_cpy_f32_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp similarity index 100% rename from ggml/src/kompute-shaders/op_cpy_f32_f32.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_cpy_f32_f32.comp diff --git a/ggml/src/kompute-shaders/op_diagmask.comp b/ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp similarity index 100% rename from ggml/src/kompute-shaders/op_diagmask.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_diagmask.comp diff --git a/ggml/src/kompute-shaders/op_gelu.comp b/ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp similarity index 100% rename from ggml/src/kompute-shaders/op_gelu.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_gelu.comp diff --git a/ggml/src/kompute-shaders/op_getrows.comp b/ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp similarity index 100% rename from ggml/src/kompute-shaders/op_getrows.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_getrows.comp diff --git a/ggml/src/kompute-shaders/op_getrows_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp similarity index 100% rename from ggml/src/kompute-shaders/op_getrows_f16.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_getrows_f16.comp diff --git a/ggml/src/kompute-shaders/op_getrows_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp similarity index 100% rename from ggml/src/kompute-shaders/op_getrows_f32.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_getrows_f32.comp diff --git a/ggml/src/kompute-shaders/op_getrows_q4_0.comp b/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp similarity index 100% rename from ggml/src/kompute-shaders/op_getrows_q4_0.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_0.comp diff --git a/ggml/src/kompute-shaders/op_getrows_q4_1.comp b/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp similarity index 100% rename from ggml/src/kompute-shaders/op_getrows_q4_1.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_getrows_q4_1.comp diff --git a/ggml/src/kompute-shaders/op_getrows_q6_k.comp b/ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp similarity index 100% rename from ggml/src/kompute-shaders/op_getrows_q6_k.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_getrows_q6_k.comp diff --git a/ggml/src/kompute-shaders/op_mul.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul.comp similarity index 100% rename from ggml/src/kompute-shaders/op_mul.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_mul.comp diff --git a/ggml/src/kompute-shaders/op_mul_mat_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp similarity index 91% rename from ggml/src/kompute-shaders/op_mul_mat_f16.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp index 8f0a9031f..0ab1b2fc2 100644 --- a/ggml/src/kompute-shaders/op_mul_mat_f16.comp +++ b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_f16.comp @@ -20,12 +20,14 @@ layout (push_constant) uniform parameter { uint nb00; uint nb01; uint nb02; + uint nb03; int ne10; int ne11; int ne12; uint nb10; uint nb11; uint nb12; + uint nb13; int ne0; int ne1; uint r2; @@ -42,7 +44,7 @@ void main() { const uint i12 = im%pcs.ne12; const uint i13 = im/pcs.ne12; - const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb02*pcs.ne02; + const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb03; const uint x = offset0 / 2 + pcs.inAOff; // Based from inA @@ -52,7 +54,7 @@ void main() { break; } - const uint y = (r1*pcs.nb11 + im*pcs.nb12) / 4 + pcs.inBOff; // Based from inB + const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff; float sumf = 0; for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) { diff --git a/ggml/src/kompute-shaders/op_mul_mat_mat_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp similarity index 100% rename from ggml/src/kompute-shaders/op_mul_mat_mat_f32.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_mat_f32.comp diff --git a/ggml/src/kompute-shaders/op_mul_mat_q4_0.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp similarity index 100% rename from ggml/src/kompute-shaders/op_mul_mat_q4_0.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_0.comp diff --git a/ggml/src/kompute-shaders/op_mul_mat_q4_1.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp similarity index 100% rename from ggml/src/kompute-shaders/op_mul_mat_q4_1.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_1.comp diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp new file mode 100644 index 000000000..a5752a3a0 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q4_k.comp @@ -0,0 +1,140 @@ +#version 450 + +#include "common.comp" + +#define N_DST 4 +#define SIZE_OF_BLOCK sizeof_block_q4_k + +layout(local_size_x = 4) in; +layout(local_size_y = 8) in; +layout(local_size_z = 1) in; + +layout (binding = 0) readonly buffer tensorInA { block_q4_k inA[]; }; +layout (binding = 1) readonly buffer tensorInB { float inB[]; }; +layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; + +layout (push_constant) uniform parameter { + uint inAOff; + uint inBOff; + uint outOff; + int ne00; + int ne10; + int ne0; + int ne1; + int ne01; + int ne02; + int ne12; + uint nb01; + uint nb02; + uint nb03; + uint nb11; + uint nb12; + uint nb13; + uint r2; + uint r3; +} pcs; + +void main() { + const uint16_t kmask1 = uint16_t(0x3f3f); + const uint16_t kmask2 = uint16_t(0x0f0f); + const uint16_t kmask3 = uint16_t(0xc0c0); + + const uint ix = gl_SubgroupInvocationID/8; // 0...3 + const uint it = gl_SubgroupInvocationID%8; // 0...7 + const uint iq = it/4; // 0 or 1 + const uint ir = it%4; // 0...3 + + const uint nb = pcs.ne00/QK_K; + + const uint r0 = gl_WorkGroupID.x; + const uint r1 = gl_WorkGroupID.y; + const uint im = gl_WorkGroupID.z; + + const uint first_row = r0 * N_DST; + const uint ib_row = first_row * nb; + + const uint i12 = im%pcs.ne12; + const uint i13 = im/pcs.ne12; + + const uint offset0 = first_row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK); + const uint offset1 = r1*pcs.nb11 + (i12 )*pcs.nb12 + (i13 )*pcs.nb13; + + const uint xblk = offset0 + pcs.inAOff; + const uint y = (offset1 / 4) + pcs.inBOff; + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f, 0.f, 0.f, 0.f}; + float all_sum = 0.f; + + uint y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + for (uint ib = ix; ib < nb; ib += 4) { + const uint blk_idx = ib + xblk; + + float sumy[4] = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = inB[y4+i+ 0]; sumy[0] += yl[i+0]; + yl[i+8] = inB[y4+i+ 32]; sumy[1] += yl[i+8]; + yh[i+0] = inB[y4+i+128]; sumy[2] += yh[i+0]; + yh[i+8] = inB[y4+i+160]; sumy[3] += yh[i+8]; + } + + for (int row = 0; row < N_DST; row++) { + uint row_idx = row * (pcs.nb01 / SIZE_OF_BLOCK); + + uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0); + uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2); + uint16_t sc_2 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 4); + uint16_t sc_3 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 6); + uint16_t sc_4 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 8); + + uint16_t sc16[4]; + sc16[0] = sc_0 & kmask1; + sc16[1] = sc_2 & kmask1; + sc16[2] = ((sc_4 >> 0) & kmask2) | ((sc_0 & kmask3) >> 2); + sc16[3] = ((sc_4 >> 4) & kmask2) | ((sc_2 & kmask3) >> 2); + + float acc1[4] = {0.f, 0.f, 0.f, 0.f}; + float acc2[4] = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + uint16_t q1 = u8BufToU16(inA[blk_idx + row_idx].qs, 32 * iq + 8 * ir + i); + uint16_t q2 = u8BufToU16(inA[blk_idx + row_idx].qs, 64 + 32 * iq + 8 * ir + i); + acc1[0] += yl[i+0] * (q1 & 0x000F); + acc1[1] += yl[i+1] * (q1 & 0x0F00); + acc1[2] += yl[i+8] * (q1 & 0x00F0); + acc1[3] += yl[i+9] * (q1 & 0xF000); + acc2[0] += yh[i+0] * (q2 & 0x000F); + acc2[1] += yh[i+1] * (q2 & 0x0F00); + acc2[2] += yh[i+8] * (q2 & 0x00F0); + acc2[3] += yh[i+9] * (q2 & 0xF000); + } + + uint8_t sc8_0 = uint8_t(sc16[0] & 0xFF); + uint8_t sc8_1 = uint8_t(sc16[0] >> 8 ); + uint8_t sc8_2 = uint8_t(sc16[1] & 0xFF); + uint8_t sc8_3 = uint8_t(sc16[1] >> 8 ); + uint8_t sc8_4 = uint8_t(sc16[2] & 0xFF); + uint8_t sc8_5 = uint8_t(sc16[2] >> 8 ); + uint8_t sc8_6 = uint8_t(sc16[3] & 0xFF); + uint8_t sc8_7 = uint8_t(sc16[3] >> 8 ); + + float dall = float(inA[blk_idx + row_idx].d); + float dmin = float(inA[blk_idx + row_idx].dmin); + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8_0 + + (acc1[2] + 1.f/256.f * acc1[3]) * sc8_1 * 1.f/16.f + + (acc2[0] + 1.f/256.f * acc2[1]) * sc8_4 + + (acc2[2] + 1.f/256.f * acc2[3]) * sc8_5 * 1.f/16.f) - + dmin * (sumy[0] * sc8_2 + sumy[1] * sc8_3 + sumy[2] * sc8_6 + sumy[3] * sc8_7); + } + + y4 += 4 * QK_K; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = subgroupAdd(sumf[row]); + if (subgroupElect()) { + out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + first_row + row + pcs.outOff] = all_sum; + } + } +} diff --git a/ggml/src/kompute-shaders/op_mul_mat_q6_k.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp similarity index 86% rename from ggml/src/kompute-shaders/op_mul_mat_q6_k.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp index c9baebdf4..d331d1a70 100644 --- a/ggml/src/kompute-shaders/op_mul_mat_q6_k.comp +++ b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q6_k.comp @@ -21,7 +21,16 @@ layout (push_constant) uniform parameter { int ne0; int ne1; int ne01; - int gqa; + int ne02; + int ne12; + uint nb01; + uint nb02; + uint nb03; + uint nb11; + uint nb12; + uint nb13; + uint r2; + uint r3; } pcs; void main() { @@ -34,12 +43,15 @@ void main() { const uint r0 = gl_WorkGroupID.x; const uint r1 = gl_WorkGroupID.y; - const uint r2 = gl_WorkGroupID.z; + const uint im = gl_WorkGroupID.z; const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID); - const uint offset0 = r2/pcs.gqa*(nb*pcs.ne0); - const uint x = row * nb + offset0; // Based from inA without base offset - const uint yy = r1*pcs.ne10 + r2*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB + + const uint i12 = im%pcs.ne12; + const uint i13 = im/pcs.ne12; + + const uint x = row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK); + const uint yy = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff; float sumf = 0; @@ -89,6 +101,6 @@ void main() { const float tot = subgroupAdd(sumf); if (subgroupElect()) { - out_[r1*pcs.ne0 + r2*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot; + out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot; } } diff --git a/ggml/src/kompute-shaders/op_mul_mat_q8_0.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp similarity index 100% rename from ggml/src/kompute-shaders/op_mul_mat_q8_0.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_mul_mat_q8_0.comp diff --git a/ggml/src/kompute-shaders/op_mul_mv_q_n.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp similarity index 76% rename from ggml/src/kompute-shaders/op_mul_mv_q_n.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp index 440b5ab2c..a6517cc1f 100644 --- a/ggml/src/kompute-shaders/op_mul_mv_q_n.comp +++ b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n.comp @@ -14,10 +14,15 @@ void main() { const uint i12 = im%pcs.ne12; const uint i13 = im/pcs.ne12; - const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02); + // pointers to src0 rows + uint ax[N_ROWS]; + for (int row = 0; row < N_ROWS; ++row) { + const uint offset0 = (first_row + row)*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK); - const uint x = offset0; // Based from inA without base offset - const uint y = r1*uint(pcs.ne10)+im*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB + ax[row] = offset0 + pcs.inAOff; + } + + const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff; float sumf[N_ROWS] = {0.0f, 0.0f, 0.0f, 0.0f}; @@ -32,8 +37,7 @@ void main() { for (uint ib = ix; ib < nb; ib += 16) { for (int row = 0; row < N_ROWS; row++) { - const uint block_index = x + ib + row * nb; - sumf[row] += block_q_n_dot_y(block_index, yb, il); + sumf[row] += block_q_n_dot_y(ax[row] + ib, yb, il); } yb += BLOCKS_IN_QUANT * 16; diff --git a/ggml/src/kompute-shaders/op_mul_mv_q_n_pre.comp b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp similarity index 80% rename from ggml/src/kompute-shaders/op_mul_mv_q_n_pre.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp index 7912b09ac..a9a2f2218 100644 --- a/ggml/src/kompute-shaders/op_mul_mv_q_n_pre.comp +++ b/ggml/src/ggml-kompute/kompute-shaders/op_mul_mv_q_n_pre.comp @@ -1,5 +1,5 @@ layout(local_size_x_id = 0) in; -layout(local_size_y = 1) in; +layout(local_size_y = 8) in; layout(local_size_z = 1) in; layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; }; @@ -17,6 +17,12 @@ layout (push_constant) uniform parameter { int ne12; int ne0; int ne1; + uint nb01; + uint nb02; + uint nb03; + uint nb11; + uint nb12; + uint nb13; uint r2; uint r3; } pcs; diff --git a/ggml/src/kompute-shaders/op_norm.comp b/ggml/src/ggml-kompute/kompute-shaders/op_norm.comp similarity index 100% rename from ggml/src/kompute-shaders/op_norm.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_norm.comp diff --git a/ggml/src/kompute-shaders/op_relu.comp b/ggml/src/ggml-kompute/kompute-shaders/op_relu.comp similarity index 100% rename from ggml/src/kompute-shaders/op_relu.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_relu.comp diff --git a/ggml/src/kompute-shaders/op_rmsnorm.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp similarity index 100% rename from ggml/src/kompute-shaders/op_rmsnorm.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_rmsnorm.comp diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp new file mode 100644 index 000000000..63659cbfe --- /dev/null +++ b/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp @@ -0,0 +1,52 @@ +#version 450 + +#include "rope_common.comp" + +layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; }; +layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; }; +layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; }; +layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; }; + +void main() { + const uint i3 = gl_WorkGroupID.z; + const uint i2 = gl_WorkGroupID.y; + const uint i1 = gl_WorkGroupID.x; + + float corr_dims[2]; + rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims); + + const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims); + + float theta_base = float(inB[pcs.inBOff + i2]); + float inv_ndims = -1.f/pcs.n_dims; + + float cos_theta; + float sin_theta; + + for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) { + if (i0 < pcs.n_dims) { + uint ic = i0/2; + + float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0); + + const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f; + + rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta); + + const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 2) + pcs.inAOff; // Based from in + const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + ic*pcs.nb0) / 2) + pcs.outOff; // Based from out_ + + const float x0 = float(inA[src]); + const float x1 = float(inA[src+pcs.n_dims/2]); + + out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta); + out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta); + } else { + const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in + const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_ + + out_[dst_data] = inA[src]; + out_[dst_data+1] = inA[src+1]; + } + } +} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp new file mode 100644 index 000000000..4df56204d --- /dev/null +++ b/ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp @@ -0,0 +1,52 @@ +#version 450 + +#include "rope_common.comp" + +layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; }; +layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; }; +layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; }; +layout(binding = 3) buffer restrict writeonly tensorOut { float out_[]; }; + +void main() { + const uint i3 = gl_WorkGroupID.z; + const uint i2 = gl_WorkGroupID.y; + const uint i1 = gl_WorkGroupID.x; + + float corr_dims[2]; + rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims); + + const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims); + + float theta_base = float(inB[pcs.inBOff + i2]); + float inv_ndims = -1.f/pcs.n_dims; + + float cos_theta; + float sin_theta; + + for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) { + if (i0 < pcs.n_dims) { + uint ic = i0/2; + + float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0); + + const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f; + + rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta); + + const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 4) + pcs.inAOff; // Based from in + const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + ic*pcs.nb0) / 4) + pcs.outOff; // Based from out_ + + const float x0 = inA[src]; + const float x1 = inA[src+pcs.n_dims/2]; + + out_[dst_data] = x0*cos_theta - x1*sin_theta; + out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in + const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_ + + out_[dst_data] = inA[src]; + out_[dst_data+1] = inA[src+1]; + } + } +} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp new file mode 100644 index 000000000..a3c0eda8b --- /dev/null +++ b/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp @@ -0,0 +1,52 @@ +#version 450 + +#include "rope_common.comp" + +layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; }; +layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; }; +layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; }; +layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; }; + +void main() { + const uint i3 = gl_WorkGroupID.z; + const uint i2 = gl_WorkGroupID.y; + const uint i1 = gl_WorkGroupID.x; + + float corr_dims[2]; + rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims); + + const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims); + + float theta_base = float(inB[pcs.inBOff + i2]); + float inv_ndims = -1.f/pcs.n_dims; + + float cos_theta; + float sin_theta; + + for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) { + if (i0 < pcs.n_dims) { + uint ic = i0/2; + + float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0); + + const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f; + + rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta); + + const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in + const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_ + + const float x0 = float(inA[src]); + const float x1 = float(inA[src+1]); + + out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta); + out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta); + } else { + const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in + const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_ + + out_[dst_data] = inA[src]; + out_[dst_data+1] = inA[src+1]; + } + } +} diff --git a/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp b/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp new file mode 100644 index 000000000..b7963ae72 --- /dev/null +++ b/ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp @@ -0,0 +1,52 @@ +#version 450 + +#include "rope_common.comp" + +layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; }; +layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; }; +layout(binding = 2) buffer restrict readonly tensorInC { float inC[]; }; +layout(binding = 3) buffer restrict writeonly tensorOut { float out_[]; }; + +void main() { + const uint i3 = gl_WorkGroupID.z; + const uint i2 = gl_WorkGroupID.y; + const uint i1 = gl_WorkGroupID.x; + + float corr_dims[2]; + rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims); + + const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims); + + float theta_base = float(inB[pcs.inBOff + i2]); + float inv_ndims = -1.f/pcs.n_dims; + + float cos_theta; + float sin_theta; + + for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) { + if (i0 < pcs.n_dims) { + uint ic = i0/2; + + float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0); + + const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f; + + rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta); + + const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in + const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_ + + const float x0 = inA[src]; + const float x1 = inA[src+1]; + + out_[dst_data] = x0*cos_theta - x1*sin_theta; + out_[dst_data+1] = x0*sin_theta + x1*cos_theta; + } else { + const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in + const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_ + + out_[dst_data] = inA[src]; + out_[dst_data+1] = inA[src+1]; + } + } +} diff --git a/ggml/src/kompute-shaders/op_scale.comp b/ggml/src/ggml-kompute/kompute-shaders/op_scale.comp similarity index 100% rename from ggml/src/kompute-shaders/op_scale.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_scale.comp diff --git a/ggml/src/kompute-shaders/op_scale_8.comp b/ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp similarity index 100% rename from ggml/src/kompute-shaders/op_scale_8.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_scale_8.comp diff --git a/ggml/src/kompute-shaders/op_silu.comp b/ggml/src/ggml-kompute/kompute-shaders/op_silu.comp similarity index 100% rename from ggml/src/kompute-shaders/op_silu.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_silu.comp diff --git a/ggml/src/kompute-shaders/op_softmax.comp b/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp similarity index 78% rename from ggml/src/kompute-shaders/op_softmax.comp rename to ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp index 7bc9176ca..4165295bf 100644 --- a/ggml/src/kompute-shaders/op_softmax.comp +++ b/ggml/src/ggml-kompute/kompute-shaders/op_softmax.comp @@ -18,6 +18,10 @@ layout(push_constant) uniform PushConstants { int ne01; int ne02; float scale; + float max_bias; + float m0; + float m1; + uint n_head_log2; int mask; } pcs; @@ -34,17 +38,29 @@ void main() { const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB const uint pdst = extra_off + pcs.outOff; // Based from out_ + float slope = 1.0f; + + // ALiBi + if (pcs.max_bias > 0.0f) { + int64_t h = i02; + + float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1; + int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1; + + slope = pow(base, float(exp)); + } + // parallel max float localMax = uintBitsToFloat(0xFF800000); for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { - localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f)); + localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f)); } float max_ = subgroupMax(localMax); // parallel sum float localSum = 0.0f; for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { - const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f) - max_); + const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_); localSum += exp_psrc0; out_[pdst + i00] = exp_psrc0; } diff --git a/ggml/src/kompute-shaders/rope_common.comp b/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp similarity index 98% rename from ggml/src/kompute-shaders/rope_common.comp rename to ggml/src/ggml-kompute/kompute-shaders/rope_common.comp index df4702896..0fca640dc 100644 --- a/ggml/src/kompute-shaders/rope_common.comp +++ b/ggml/src/ggml-kompute/kompute-shaders/rope_common.comp @@ -8,12 +8,14 @@ layout(local_size_x = 1) in; layout (push_constant) uniform parameter { uint inAOff; uint inBOff; + uint inCOff; uint outOff; int n_dims; int mode; int n_ctx_orig; float freq_base; float freq_scale; + bool has_freq_factors; float ext_factor; float attn_factor; float beta_fast; diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m deleted file mode 100644 index f04e5af71..000000000 --- a/ggml/src/ggml-metal.m +++ /dev/null @@ -1,3491 +0,0 @@ -#import "ggml-metal.h" - -#import "ggml-backend-impl.h" -#import "ggml.h" - -#import - -#import - -#undef MIN -#undef MAX -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) - -#ifdef GGML_METAL_NDEBUG -#define GGML_METAL_LOG_INFO(...) -#define GGML_METAL_LOG_WARN(...) -#define GGML_METAL_LOG_ERROR(...) -#else -#define GGML_METAL_LOG_INFO(...) ggml_metal_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__) -#define GGML_METAL_LOG_WARN(...) ggml_metal_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__) -#define GGML_METAL_LOG_ERROR(...) ggml_metal_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) -#endif - -#define UNUSED(x) (void)(x) - -struct ggml_metal_kernel { - id pipeline; -}; - -enum ggml_metal_kernel_type { - GGML_METAL_KERNEL_TYPE_ADD, - GGML_METAL_KERNEL_TYPE_ADD_ROW, - GGML_METAL_KERNEL_TYPE_SUB, - GGML_METAL_KERNEL_TYPE_SUB_ROW, - GGML_METAL_KERNEL_TYPE_MUL, - GGML_METAL_KERNEL_TYPE_MUL_ROW, - GGML_METAL_KERNEL_TYPE_DIV, - GGML_METAL_KERNEL_TYPE_DIV_ROW, - GGML_METAL_KERNEL_TYPE_REPEAT_F32, - GGML_METAL_KERNEL_TYPE_REPEAT_F16, - GGML_METAL_KERNEL_TYPE_REPEAT_I32, - GGML_METAL_KERNEL_TYPE_REPEAT_I16, - GGML_METAL_KERNEL_TYPE_SCALE, - GGML_METAL_KERNEL_TYPE_SCALE_4, - GGML_METAL_KERNEL_TYPE_CLAMP, - GGML_METAL_KERNEL_TYPE_TANH, - GGML_METAL_KERNEL_TYPE_RELU, - GGML_METAL_KERNEL_TYPE_SIGMOID, - GGML_METAL_KERNEL_TYPE_GELU, - GGML_METAL_KERNEL_TYPE_GELU_4, - GGML_METAL_KERNEL_TYPE_GELU_QUICK, - GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, - GGML_METAL_KERNEL_TYPE_SILU, - GGML_METAL_KERNEL_TYPE_SILU_4, - GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, - GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, - GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, - GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, - GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, - GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, - GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, - GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, - GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, - GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, - GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, - GGML_METAL_KERNEL_TYPE_RMS_NORM, - GGML_METAL_KERNEL_TYPE_GROUP_NORM, - GGML_METAL_KERNEL_TYPE_NORM, - GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, - GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, - GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, - GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, - //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, - //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, - //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, - GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, - GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, - GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, - GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, - GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, - GGML_METAL_KERNEL_TYPE_IM2COL_F16, - GGML_METAL_KERNEL_TYPE_IM2COL_F32, - GGML_METAL_KERNEL_TYPE_UPSCALE_F32, - GGML_METAL_KERNEL_TYPE_PAD_F32, - GGML_METAL_KERNEL_TYPE_ARANGE_F32, - GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, - GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, - GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, - GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, - //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, - //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 - GGML_METAL_KERNEL_TYPE_CPY_F32_F32, - GGML_METAL_KERNEL_TYPE_CPY_F32_F16, - GGML_METAL_KERNEL_TYPE_CPY_F16_F16, - GGML_METAL_KERNEL_TYPE_CPY_F16_F32, - GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, - GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, - GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, - GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, - GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, - GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, - GGML_METAL_KERNEL_TYPE_CONCAT, - GGML_METAL_KERNEL_TYPE_SQR, - GGML_METAL_KERNEL_TYPE_SQRT, - GGML_METAL_KERNEL_TYPE_SIN, - GGML_METAL_KERNEL_TYPE_COS, - GGML_METAL_KERNEL_TYPE_SUM_ROWS, - - GGML_METAL_KERNEL_TYPE_COUNT -}; - -struct ggml_backend_metal_context { - int n_cb; - - id device; - id queue; - - dispatch_queue_t d_queue; - - struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT]; - - bool support_simdgroup_reduction; - bool support_simdgroup_mm; - - bool should_capture_next_compute; - - // abort ggml_metal_graph_compute if callback returns true - ggml_abort_callback abort_callback; - void * abort_callback_data; -}; - -// MSL code -// TODO: move the contents here when ready -// for now it is easier to work in a separate file -// static NSString * const msl_library_source = @"see metal.metal"; - -// Here to assist with NSBundle Path Hack -@interface GGMLMetalClass : NSObject -@end -@implementation GGMLMetalClass -@end - -static void ggml_metal_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) { - fprintf(stderr, "%s", msg); - - UNUSED(level); - UNUSED(user_data); -} - -ggml_log_callback ggml_metal_log_callback = ggml_metal_default_log_callback; -void * ggml_metal_log_user_data = NULL; - -GGML_ATTRIBUTE_FORMAT(2, 3) -static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ - if (ggml_metal_log_callback != NULL) { - va_list args; - va_start(args, format); - char buffer[128]; - int len = vsnprintf(buffer, 128, format, args); - if (len < 128) { - ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data); - } else { - char* buffer2 = malloc(len+1); - va_end(args); - va_start(args, format); - vsnprintf(buffer2, len+1, format, args); - buffer2[len] = 0; - ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data); - free(buffer2); - } - va_end(args); - } -} - -static void * ggml_metal_host_malloc(size_t n) { - void * data = NULL; - -#if TARGET_OS_OSX - kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE); - if (err != KERN_SUCCESS) { - GGML_METAL_LOG_ERROR("%s: error: vm_allocate failed\n", __func__); - return NULL; - } -#else - const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n); - if (result != 0) { - GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__); - return NULL; - } -#endif - - return data; -} - -static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) { - GGML_METAL_LOG_INFO("%s: allocating\n", __func__); - -#if TARGET_OS_OSX && !GGML_METAL_NDEBUG - // Show all the Metal device instances in the system - NSArray * devices = MTLCopyAllDevices(); - for (id device in devices) { - GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]); - } - [devices release]; // since it was created by a *Copy* C method -#endif - - // Pick and show default Metal device - id device = MTLCreateSystemDefaultDevice(); - GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]); - - // Configure context - struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context)); - ctx->device = device; - ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS); - ctx->queue = [ctx->device newCommandQueue]; - ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); - - id metal_library; - - // load library - // - // - first check if the library is embedded - // - then check if the library is in the bundle - // - if not found, load the source and compile it - // - if that fails, return NULL - { - NSBundle * bundle = nil; -#ifdef SWIFT_PACKAGE - bundle = SWIFTPM_MODULE_BUNDLE; -#else - bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; -#endif - - NSError * error = nil; - -#if GGML_METAL_EMBED_LIBRARY - const bool try_metallib = false; -#else - const bool try_metallib = true; -#endif - - NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"]; - if (try_metallib && path_lib != nil) { - // pre-compiled library found - NSURL * libURL = [NSURL fileURLWithPath:path_lib]; - GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]); - - metal_library = [ctx->device newLibraryWithURL:libURL error:&error]; - if (error) { - GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } - } else { -#if GGML_METAL_EMBED_LIBRARY - GGML_METAL_LOG_INFO("%s: using embedded metal library\n", __func__); - - extern const char ggml_metallib_start[]; - extern const char ggml_metallib_end[]; - - NSString * src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding]; -#else - GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); - - NSString * path_source; - NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"]; - - GGML_METAL_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil"); - - if (path_resource) { - path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"]; - } else { - path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; - } - - if (path_source == nil) { - GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__); - path_source = @"ggml-metal.metal"; - } - - GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]); - - NSString * src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error]; - if (error) { - GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } -#endif // GGML_METAL_EMBED_LIBRARY - - @autoreleasepool { - // dictionary of preprocessor macros - NSMutableDictionary * prep = [NSMutableDictionary dictionary]; - - MTLCompileOptions* options = [MTLCompileOptions new]; - options.preprocessorMacros = prep; - - //[options setFastMathEnabled:false]; - - metal_library = [ctx->device newLibraryWithSource:src options:options error:&error]; - if (error) { - GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } - } - } - } - - // print MTL GPU family: - GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]); - - const NSInteger MTLGPUFamilyMetal3 = 5001; - - // determine max supported GPU family - // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf - // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf - { - for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) { - if ([ctx->device supportsFamily:i]) { - GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i); - break; - } - } - - for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) { - if ([ctx->device supportsFamily:i]) { - GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i); - break; - } - } - - for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) { - if ([ctx->device supportsFamily:i]) { - GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i); - break; - } - } - } - - ctx->support_simdgroup_reduction = [ctx->device supportsFamily:MTLGPUFamilyApple7]; - ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3]; - - ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7]; - - GGML_METAL_LOG_INFO("%s: simdgroup reduction support = %s\n", __func__, ctx->support_simdgroup_reduction ? "true" : "false"); - GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false"); - GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); - - ctx->should_capture_next_compute = false; - -#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) - if (@available(macOS 10.12, iOS 16.0, *)) { - GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6); - } -#elif TARGET_OS_OSX - if (ctx->device.maxTransferRate != 0) { - GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6); - } else { - GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__); - } -#endif - - // load kernels - { - NSError * error = nil; - - for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) { - ctx->kernels[i].pipeline = nil; - } - - /* - GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ - (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ - (int) kernel->pipeline.threadExecutionWidth); \ - */ -#define GGML_METAL_ADD_KERNEL(e, name, supported) \ - if (supported) { \ - struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \ - id metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \ - kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \ - [metal_function release]; \ - if (error) { \ - GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ - [metal_library release]; \ - return NULL; \ - } \ - } else { \ - GGML_METAL_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \ - } - - // simd_sum and simd_max requires MTLGPUFamilyApple7 - - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction); - //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); - } - - [metal_library release]; - return ctx; -} - -static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { - GGML_METAL_LOG_INFO("%s: deallocating\n", __func__); - - for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) { - [ctx->kernels[i].pipeline release]; - } - - [ctx->queue release]; - [ctx->device release]; - - dispatch_release(ctx->d_queue); - - free(ctx); -} - -// temporarily defined here for compatibility between ggml-backend and the old API - -struct ggml_backend_metal_buffer { - void * data; - size_t size; - - id metal; -}; - -struct ggml_backend_metal_buffer_context { - void * all_data; - size_t all_size; - bool owned; - - // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap - int n_buffers; - struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; -}; - -// finds the Metal buffer that contains the tensor data on the GPU device -// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the -// Metal buffer based on the host memory pointer -// -static id ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) { - //GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach); - - const int64_t tsize = ggml_nbytes(t); - - ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer; - - struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context; - - // find the view that contains the tensor fully - for (int i = 0; i < buf_ctx->n_buffers; ++i) { - const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data; - - //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size); - if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) { - *offs = (size_t) ioffs; - - //GGML_METAL_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs); - - return buf_ctx->buffers[i].metal; - } - } - - GGML_METAL_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name); - - return nil; -} - -static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx, const struct ggml_tensor * op) { - for (size_t i = 0, n = 3; i < n; ++i) { - if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) { - return false; - } - } - - switch (op->op) { - case GGML_OP_UNARY: - switch (ggml_get_unary_op(op)) { - case GGML_UNARY_OP_TANH: - case GGML_UNARY_OP_RELU: - case GGML_UNARY_OP_SIGMOID: - case GGML_UNARY_OP_GELU: - case GGML_UNARY_OP_GELU_QUICK: - case GGML_UNARY_OP_SILU: - return ggml_is_contiguous(op->src[0]); - default: - return false; - } - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_TRANSPOSE: - case GGML_OP_PERMUTE: - case GGML_OP_CONCAT: - case GGML_OP_ADD: - case GGML_OP_SUB: - case GGML_OP_ACC: - case GGML_OP_MUL: - case GGML_OP_DIV: - case GGML_OP_REPEAT: - case GGML_OP_SCALE: - case GGML_OP_CLAMP: - return true; - case GGML_OP_SQR: - case GGML_OP_SQRT: - case GGML_OP_SIN: - case GGML_OP_COS: - return ggml_is_contiguous(op->src[0]); - case GGML_OP_SUM_ROWS: - case GGML_OP_SOFT_MAX: - case GGML_OP_RMS_NORM: - case GGML_OP_GROUP_NORM: - return ctx->support_simdgroup_reduction; - case GGML_OP_NORM: - case GGML_OP_ROPE: - return true; - case GGML_OP_IM2COL: - return op->src[0]->type == GGML_TYPE_F16; - case GGML_OP_POOL_1D: - case GGML_OP_POOL_2D: - return false; - case GGML_OP_UPSCALE: - case GGML_OP_PAD: - case GGML_OP_ARANGE: - case GGML_OP_TIMESTEP_EMBEDDING: - case GGML_OP_ARGSORT: - case GGML_OP_LEAKY_RELU: - return true; - case GGML_OP_FLASH_ATTN_EXT: - if (op->src[1]->type != GGML_TYPE_F16) { - return false; - } - if (op->src[2]->type != GGML_TYPE_F16) { - return false; - } - if (op->src[0]->ne[0] == 256) { - return false; - } - return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels - case GGML_OP_SSM_CONV: - case GGML_OP_SSM_SCAN: - return true; - case GGML_OP_MUL_MAT: - case GGML_OP_MUL_MAT_ID: - return ctx->support_simdgroup_reduction && - (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32); - case GGML_OP_CPY: - case GGML_OP_DUP: - case GGML_OP_CONT: - { - switch (op->src[0]->type) { - case GGML_TYPE_F32: - switch (op->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_IQ4_NL: - return true; - default: - return false; - } - case GGML_TYPE_F16: - switch (op->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - return true; - default: - return false; - } - default: - return false; - }; - } - case GGML_OP_DIAG_MASK_INF: - case GGML_OP_GET_ROWS: - { - return op->ne[3] == 1; - } - default: - return false; - } -} - -static enum ggml_status ggml_metal_graph_compute( - struct ggml_backend_metal_context * ctx, - struct ggml_cgraph * gf) { - - @autoreleasepool { - MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor; - edesc.dispatchType = MTLDispatchTypeSerial; - - // create multiple command buffers and enqueue them - // then, we encode the graph into the command buffers in parallel - - const int n_nodes = gf->n_nodes; - const int n_cb = ctx->n_cb; - const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb; - - const bool should_capture = ctx->should_capture_next_compute; - if (should_capture) { - ctx->should_capture_next_compute = false; - - MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; - descriptor.captureObject = ctx->queue; - - NSError * error = nil; - if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { - GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); - GGML_ABORT("capture failed"); - } - } - - id command_buffer_builder[n_cb]; - for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { - id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; - command_buffer_builder[cb_idx] = command_buffer; - - // always enqueue the first two command buffers - // enqueue all of the command buffers if we don't need to abort - if (cb_idx < 2 || ctx->abort_callback == NULL) { - [command_buffer enqueue]; - } - } - - const id *command_buffers = command_buffer_builder; - - dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) { - const int cb_idx = iter; - - size_t offs_src0 = 0; - size_t offs_src1 = 0; - size_t offs_src2 = 0; - size_t offs_dst = 0; - - id command_buffer = command_buffers[cb_idx]; - id encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc]; - - const int node_start = (cb_idx + 0) * n_nodes_per_cb; - const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes); - - for (int i = node_start; i < node_end; ++i) { - if (i == -1) { - [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers]; - continue; - } - - //GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op)); - - struct ggml_tensor * src0 = gf->nodes[i]->src[0]; - struct ggml_tensor * src1 = gf->nodes[i]->src[1]; - struct ggml_tensor * src2 = gf->nodes[i]->src[2]; - struct ggml_tensor * dst = gf->nodes[i]; - - if (ggml_is_empty(dst)) { - continue; - } - - switch (dst->op) { - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_TRANSPOSE: - case GGML_OP_PERMUTE: - { - // noop -> next node - } continue; - default: - { - } break; - } - - if (!ggml_metal_supports_op(ctx, dst)) { - GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst)); - GGML_ABORT("unsupported op"); - } - - if (should_capture) { - [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]]; - } - - const int64_t ne00 = src0 ? src0->ne[0] : 0; - const int64_t ne01 = src0 ? src0->ne[1] : 0; - const int64_t ne02 = src0 ? src0->ne[2] : 0; - const int64_t ne03 = src0 ? src0->ne[3] : 0; - - const uint64_t nb00 = src0 ? src0->nb[0] : 0; - const uint64_t nb01 = src0 ? src0->nb[1] : 0; - const uint64_t nb02 = src0 ? src0->nb[2] : 0; - const uint64_t nb03 = src0 ? src0->nb[3] : 0; - - const int64_t ne10 = src1 ? src1->ne[0] : 0; - const int64_t ne11 = src1 ? src1->ne[1] : 0; - const int64_t ne12 = src1 ? src1->ne[2] : 0; - const int64_t ne13 = src1 ? src1->ne[3] : 0; - - const uint64_t nb10 = src1 ? src1->nb[0] : 0; - const uint64_t nb11 = src1 ? src1->nb[1] : 0; - const uint64_t nb12 = src1 ? src1->nb[2] : 0; - const uint64_t nb13 = src1 ? src1->nb[3] : 0; - - const int64_t ne20 = src2 ? src2->ne[0] : 0; - const int64_t ne21 = src2 ? src2->ne[1] : 0; - const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22); - const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); - - const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); - const uint64_t nb21 = src2 ? src2->nb[1] : 0; - const uint64_t nb22 = src2 ? src2->nb[2] : 0; - const uint64_t nb23 = src2 ? src2->nb[3] : 0; - - const int64_t ne0 = dst ? dst->ne[0] : 0; - const int64_t ne1 = dst ? dst->ne[1] : 0; - const int64_t ne2 = dst ? dst->ne[2] : 0; - const int64_t ne3 = dst ? dst->ne[3] : 0; - - const uint64_t nb0 = dst ? dst->nb[0] : 0; - const uint64_t nb1 = dst ? dst->nb[1] : 0; - const uint64_t nb2 = dst ? dst->nb[2] : 0; - const uint64_t nb3 = dst ? dst->nb[3] : 0; - - const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; - const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; - const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; - - id id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil; - id id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil; - id id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; - id id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; - - //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); - //if (src0) { - // GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, - // ggml_is_contiguous(src0), src0->name); - //} - //if (src1) { - // GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, - // ggml_is_contiguous(src1), src1->name); - //} - //if (dst) { - // GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, - // dst->name); - //} - - switch (dst->op) { - case GGML_OP_CONCAT: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; - - const int32_t dim = ((int32_t *) dst->op_params)[0]; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ADD: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(src1t == GGML_TYPE_F32); - - const size_t offs = 0; - - bool bcast_row = false; - - int64_t nb = ne00; // used by the "row" kernels - - id pipeline = nil; - - if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { - GGML_ASSERT(ggml_is_contiguous(src0)); - - // src1 is a row - GGML_ASSERT(ne11 == 1); - - nb = ne00 / 4; - switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; - case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; - default: GGML_ABORT("fatal error"); - } - - bcast_row = true; - } else { - switch (dst->op) { - case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; - case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break; - case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; - case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; - default: GGML_ABORT("fatal error"); - } - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; - - if (bcast_row) { - const int64_t n = ggml_nelements(dst)/4; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } else { - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - } break; - case GGML_OP_REPEAT: - { - id pipeline; - - switch (src0t) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break; - case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break; - case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break; - default: GGML_ABORT("fatal error"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ACC: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(src1t == GGML_TYPE_F32); - GGML_ASSERT(dstt == GGML_TYPE_F32); - - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - - const size_t pnb1 = ((int32_t *) dst->op_params)[0]; - const size_t pnb2 = ((int32_t *) dst->op_params)[1]; - const size_t pnb3 = ((int32_t *) dst->op_params)[2]; - const size_t offs = ((int32_t *) dst->op_params)[3]; - - const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - - if (!inplace) { - // run a separete kernel to cpy src->dst - // not sure how to avoid this - // TODO: make a simpler cpy_bytes kernel - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); - - [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_SCALE: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - float scale; - memcpy(&scale, dst->op_params, sizeof(scale)); - - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - n /= 4; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_CLAMP: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; - - float min; - float max; - memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float)); - memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float)); - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&min length:sizeof(min) atIndex:2]; - [encoder setBytes:&max length:sizeof(max) atIndex:3]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_UNARY: - switch (ggml_get_unary_op(gf->nodes[i])) { - // we are not taking into account the strides, so for now require contiguous tensors - GGML_ASSERT(ggml_is_contiguous(src0)); - - case GGML_UNARY_OP_TANH: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_RELU: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_SIGMOID: - { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_GELU: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_GELU_QUICK: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_UNARY_OP_SILU: - { - int64_t n = ggml_nelements(dst); - - id pipeline = nil; - - if (n % 4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline; - n /= 4; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - default: - { - GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); - GGML_ABORT("fatal error"); - } - } break; - case GGML_OP_SQR: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SQRT: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SIN: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_COS: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SUM_ROWS: - { - GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SOFT_MAX: - { - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); - - int nth = 32; // SIMD width - - id pipeline = nil; - - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - - if (ne00%4 == 0) { - while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; - } - } else { - while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { - nth *= 2; - } - if (use_f16) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; - } - } - - float scale; - float max_bias; - - memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); - - const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src0->ne[1]; - - const uint32_t n_head = nrows_x/nrows_y; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - if (id_src1) { - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; - [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_DIAG_MASK_INF: - { - const int n_past = ((int32_t *)(dst->op_params))[0]; - - id pipeline = nil; - - if (ne00%8 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; - - if (ne00%8 == 0) { - [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } - else { - [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } - } break; - case GGML_OP_SSM_CONV: - { - GGML_ASSERT(src0t == GGML_TYPE_F32); - GGML_ASSERT(src1t == GGML_TYPE_F32); - - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_SSM_SCAN: - { - struct ggml_tensor * src3 = gf->nodes[i]->src[3]; - struct ggml_tensor * src4 = gf->nodes[i]->src[4]; - struct ggml_tensor * src5 = gf->nodes[i]->src[5]; - - GGML_ASSERT(src3); - GGML_ASSERT(src4); - GGML_ASSERT(src5); - - size_t offs_src3 = 0; - size_t offs_src4 = 0; - size_t offs_src5 = 0; - - id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; - id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; - id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; - - const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30); - const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); - - const uint64_t nb30 = src3->nb[0]; - const uint64_t nb31 = src3->nb[1]; - - const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); - const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41); - const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); - - const uint64_t nb40 = src4->nb[0]; - const uint64_t nb41 = src4->nb[1]; - const uint64_t nb42 = src4->nb[2]; - - const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); - const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); - const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); - - const uint64_t nb50 = src5->nb[0]; - const uint64_t nb51 = src5->nb[1]; - const uint64_t nb52 = src5->nb[2]; - - const int64_t d_state = ne00; - const int64_t d_inner = ne01; - const int64_t n_seq_tokens = ne11; - const int64_t n_seqs = ne02; - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; - [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; - [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - - [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7]; - [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8]; - [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9]; - [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10]; - - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; - [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20]; - [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; - [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; - [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; - - [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_MUL_MAT: - { - GGML_ASSERT(ne00 == ne10); - - GGML_ASSERT(ne12 % ne02 == 0); - GGML_ASSERT(ne13 % ne03 == 0); - - const uint r2 = ne12/ne02; - const uint r3 = ne13/ne03; - - // find the break-even point where the matrix-matrix kernel becomes more efficient compared - // to the matrix-vector kernel - int ne11_mm_min = 1; - -#if 0 - // the numbers below are measured on M2 Ultra for 7B and 13B models - // these numbers do not translate to other devices or model sizes - // TODO: need to find a better approach - if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) { - switch (src0t) { - case GGML_TYPE_F16: ne11_mm_min = 2; break; - case GGML_TYPE_Q8_0: ne11_mm_min = 7; break; - case GGML_TYPE_Q2_K: ne11_mm_min = 15; break; - case GGML_TYPE_Q3_K: ne11_mm_min = 7; break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: ne11_mm_min = 15; break; - case GGML_TYPE_Q4_K: ne11_mm_min = 11; break; - case GGML_TYPE_Q5_0: // not tested yet - case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet - case GGML_TYPE_Q5_K: ne11_mm_min = 7; break; - case GGML_TYPE_Q6_K: ne11_mm_min = 7; break; - default: ne11_mm_min = 1; break; - } - } -#endif - - // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs - // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && - !ggml_is_transposed(src0) && - !ggml_is_transposed(src1) && - src1t == GGML_TYPE_F32 && - ne00 % 32 == 0 && ne00 >= 64 && - (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { - //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - - // some Metal matrix data types require aligned pointers - // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) - switch (src0->type) { - case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; - case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; - default: break; - } - - id pipeline = nil; - - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; - default: GGML_ABORT("MUL MAT-MAT not implemented"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:13]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:14]; - [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; - } else { - int nth0 = 32; - int nth1 = 1; - int nrows = 1; - //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - - id pipeline = nil; - - // use custom matrix x vector kernel - switch (src0t) { - case GGML_TYPE_F32: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; - nrows = 4; - } break; - case GGML_TYPE_F16: - { - nth0 = 32; - nth1 = 1; - if (src1t == GGML_TYPE_F32) { - if (ne11 * ne12 < 4) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; - } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; - nrows = ne11; - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline; - nrows = 4; - } - } else { - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline; - nrows = 4; - } - } break; - case GGML_TYPE_Q4_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline; - } break; - case GGML_TYPE_Q4_1: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline; - } break; - case GGML_TYPE_Q5_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline; - } break; - case GGML_TYPE_Q5_1: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline; - } break; - case GGML_TYPE_Q8_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; - } break; - case GGML_TYPE_Q2_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline; - } break; - case GGML_TYPE_Q3_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline; - } break; - case GGML_TYPE_Q4_K: - { - nth0 = 4; //1; - nth1 = 8; //32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline; - } break; - case GGML_TYPE_Q5_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline; - } break; - case GGML_TYPE_Q6_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline; - } break; - case GGML_TYPE_IQ2_XXS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline; - } break; - case GGML_TYPE_IQ2_XS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_XXS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline; - } break; - case GGML_TYPE_IQ2_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline; - } break; - case GGML_TYPE_IQ1_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline; - } break; - case GGML_TYPE_IQ1_M: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline; - } break; - case GGML_TYPE_IQ4_NL: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline; - } break; - case GGML_TYPE_IQ4_XS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; - } break; - default: - { - GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); - GGML_ABORT("not implemented"); - } - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:17]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; - - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { - const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { - const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { - const int mem_size = 32*sizeof(float); - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q3_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else { - const int64_t ny = (ne11 + nrows - 1)/nrows; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - } - } break; - case GGML_OP_MUL_MAT_ID: - { - const int n_as = src0->ne[2]; - - // src2 = ids - const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); - - GGML_ASSERT(src2t == GGML_TYPE_I32); - - GGML_ASSERT(!ggml_is_transposed(src0)); - GGML_ASSERT(!ggml_is_transposed(src1)); - - GGML_ASSERT(src1t == GGML_TYPE_F32); - - // find the break-even point where the matrix-matrix kernel becomes more efficient compared - // to the matrix-vector kernel - // ne20 = n_used_experts - // ne21 = n_rows - const int dst_rows = ne20*ne21; - const int dst_rows_min = n_as; - const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4; - - // max size of the rowids array in the kernel shared buffer - GGML_ASSERT(dst_rows <= dst_rows_max); - - // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs - // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - // !!! - // TODO: for now, always use mat-vec kernels until we figure out how to improve the - // indirect matrix multiplication - // !!! - if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && - ne00 % 32 == 0 && ne00 >= 64 && - dst_rows > dst_rows_min) { - - // some Metal matrix data types require aligned pointers - // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) - switch (src0->type) { - case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; - case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; - default: break; - } - - id pipeline = nil; - - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; - default: GGML_ABORT("MUL_MAT_ID not implemented"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; - - [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; - } else { - int nth0 = 32; - int nth1 = 1; - int nrows = 1; - //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - - id pipeline = nil; - - // use custom matrix x vector kernel - switch (src0t) { - case GGML_TYPE_F32: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; - } break; - case GGML_TYPE_F16: - { - GGML_ASSERT(src1t == GGML_TYPE_F32); - nth0 = 32; - nth1 = 1; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; - } break; - case GGML_TYPE_Q4_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; - } break; - case GGML_TYPE_Q4_1: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; - } break; - case GGML_TYPE_Q5_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; - } break; - case GGML_TYPE_Q5_1: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; - } break; - case GGML_TYPE_Q8_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; - } break; - case GGML_TYPE_Q2_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; - } break; - case GGML_TYPE_Q3_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; - } break; - case GGML_TYPE_Q4_K: - { - nth0 = 4; //1; - nth1 = 8; //32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; - } break; - case GGML_TYPE_Q5_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; - } break; - case GGML_TYPE_Q6_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; - } break; - case GGML_TYPE_IQ2_XXS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; - } break; - case GGML_TYPE_IQ2_XS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_XXS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; - } break; - case GGML_TYPE_IQ3_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; - } break; - case GGML_TYPE_IQ2_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; - } break; - case GGML_TYPE_IQ1_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; - } break; - case GGML_TYPE_IQ1_M: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; - } break; - case GGML_TYPE_IQ4_NL: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; - } break; - case GGML_TYPE_IQ4_XS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; - } break; - default: - { - GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); - GGML_ABORT("not implemented"); - } - }; - - if (ggml_is_quantized(src0t)) { - GGML_ASSERT(ne00 >= nth0*nth1); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22]; - - const int64_t _ne1 = 1; - const int tgz = dst_rows; - - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || - src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { - const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { - const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { - const int mem_size = 32*sizeof(float); - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q3_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else { - const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1 - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - } - } break; - case GGML_OP_GET_ROWS: - { - id pipeline = nil; - - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; - case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; - default: GGML_ABORT("not implemented"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; - [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; - } break; - case GGML_OP_RMS_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(src0)); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - int nth = 32; // SIMD width - - while (nth < ne00/4 && nth < 1024) { - nth *= 2; - } - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_GROUP_NORM: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous(src0)); - - float eps; - memcpy(&eps, dst->op_params + 1, sizeof(float)); - - const int32_t n_groups = ((int32_t *) dst->op_params)[0]; - - int nth = 32; // SIMD width - - //while (nth < ne00/4 && nth < 1024) { - // nth *= 2; - //} - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8]; - [encoder setBytes:&eps length:sizeof( float) atIndex:9]; - [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_NORM: - { - GGML_ASSERT(ggml_is_contiguous_1(src0)); - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - const int nth = MIN(256, ne00); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0]; - - const int64_t nrows = ggml_nrows(src0); - - [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ROPE: - { - GGML_ASSERT(ne10 == ne02); - - const int nth = MIN(1024, ne00); - - const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal - const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; - - float freq_base; - float freq_scale; - float ext_factor; - float attn_factor; - float beta_fast; - float beta_slow; - - memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - - const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; - - id pipeline = nil; - - if (!is_neox) { - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - } else { - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - if (id_src2 != nil) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; - [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; - [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; - [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; - [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; - [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; - [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; - [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_IM2COL: - { - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); - - const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; - - const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; - - const int32_t N = src1->ne[is_2D ? 3 : 2]; - const int32_t IC = src1->ne[is_2D ? 2 : 1]; - const int32_t IH = is_2D ? src1->ne[1] : 1; - const int32_t IW = src1->ne[0]; - - const int32_t KH = is_2D ? src0->ne[1] : 1; - const int32_t KW = src0->ne[0]; - - const int32_t OH = is_2D ? dst->ne[2] : 1; - const int32_t OW = dst->ne[1]; - - const int32_t CHW = IC * KH * KW; - - const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; - const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; - - id pipeline = nil; - - switch (dst->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2]; - [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3]; - [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4]; - [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5]; - [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6]; - [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7]; - [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8]; - [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9]; - [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10]; - [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11]; - [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12]; - - [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; - } break; - case GGML_OP_UPSCALE: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - const float sf0 = (float)ne0/src0->ne[0]; - const float sf1 = (float)ne1/src0->ne[1]; - const float sf2 = (float)ne2/src0->ne[2]; - const float sf3 = (float)ne3/src0->ne[3]; - - const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18]; - [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19]; - [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20]; - [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21]; - - const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_PAD: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ARANGE: - { - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - float start; - float step; - - memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float)); - memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; - [encoder setBytes:&start length:sizeof(start) atIndex:2]; - [encoder setBytes:&step length:sizeof(step) atIndex:3]; - - const int nth = MIN(1024, ne0); - - [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_TIMESTEP_EMBEDDING: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - const int dim = dst->op_params[0]; - const int max_period = dst->op_params[1]; - - const int half = dim / 2; - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2]; - [encoder setBytes:&dim length:sizeof(dim) atIndex:3]; - [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4]; - - const int nth = MIN(1024, half); - - [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - case GGML_OP_ARGSORT: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_I32); - - const int nrows = ggml_nrows(src0); - - enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; - - // bitonic sort requires the number of elements to be power of 2 - int64_t ne00_padded = 1; - while (ne00_padded < ne00) { - ne00_padded *= 2; - } - - // Metal kernels require the buffer size to be multiple of 16 bytes - // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength - const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16); - - id pipeline = nil; - - switch (order) { - case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; - case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; - default: GGML_ABORT("fatal error"); - }; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3]; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; - } break; - case GGML_OP_LEAKY_RELU: - { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - float slope; - memcpy(&slope, dst->op_params, sizeof(float)); - - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&slope length:sizeof(slope) atIndex:2]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; - case GGML_OP_FLASH_ATTN_EXT: - { - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ne11 % 32 == 0); - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - GGML_ASSERT(ggml_are_same_shape (src1, src2)); - - struct ggml_tensor * src3 = gf->nodes[i]->src[3]; - - size_t offs_src3 = 0; - - id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; - - GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); - GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && - "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); - - const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); - //const int64_t ne31 = src3 ? src3->ne[1] : 0; - const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); - const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); - - const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); - const uint64_t nb31 = src3 ? src3->nb[1] : 0; - const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); - const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); - - const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); - - float scale; - float max_bias; - float logit_softcap; - memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale)); - memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias)); - memcpy(&logit_softcap, ((int32_t *) dst->op_params) + 2, sizeof(logit_softcap)); - - if (logit_softcap != 0.0f) { - scale /= logit_softcap; - } - - const uint32_t n_head = src0->ne[2]; - const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - id pipeline = nil; - - bool use_vec_kernel = false; - - if (ne01 >= 4 || (ne00%128 != 0)) { - switch (ne00) { - case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; - case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; - case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; - case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; - default: - { - GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_METAL_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } else { - use_vec_kernel = true; - - switch (ne00) { - case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; - //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; - default: - { - GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_METAL_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } - } - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; - if (id_src3) { - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; - } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; - } - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22]; - [encoder setBytes:&scale length:sizeof( float) atIndex:23]; - [encoder setBytes:&max_bias length:sizeof( float) atIndex:24]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:25]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:26]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27]; - [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28]; - - if (!use_vec_kernel) { - // half8x8 kernel - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - - GGML_ASSERT(nqptg <= 32); - GGML_ASSERT(nqptg % 8 == 0); - GGML_ASSERT(ncpsg % 32 == 0); - - int64_t nsgmax = 2; - - while (true) { - const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2); - if (smem > ctx->device.maxThreadgroupMemoryLength) { - break; - } - nsgmax *= 2; - } - nsgmax /= 2; - - // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; - - const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2); - - //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); - GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); - - [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } else { - // half1x4 kernel - const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - - GGML_ASSERT(nqptg <= 32); - GGML_ASSERT(nqptg % 1 == 0); - GGML_ASSERT(ncpsg % 32 == 0); - - // simdgroups per threadgroup (a.k.a. warps) - const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)); - - int64_t nsg = 1; - while (nsg <= nsgt) { - nsg *= 2; - } - nsg /= 2; - - const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2); - - //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); - GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); - [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0]; - - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; - } - } break; - case GGML_OP_DUP: - case GGML_OP_CPY: - case GGML_OP_CONT: - { - GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - - int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); - - id pipeline = nil; - - switch (src0t) { - case GGML_TYPE_F32: - { - GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); - - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - case GGML_TYPE_F16: - { - switch (dstt) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; - default: GGML_ABORT("not implemented"); - }; - } break; - default: GGML_ABORT("not implemented"); - } - - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - } break; - default: - { - GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); - GGML_ABORT("fatal error"); - } - } - - if (should_capture) { - [encoder popDebugGroup]; - } - } - - [encoder endEncoding]; - - if (cb_idx < 2 || ctx->abort_callback == NULL) { - [command_buffer commit]; - } - }); - - // Wait for completion and check status of each command buffer - // needed to detect if the device ran out-of-memory for example (#1881) - - for (int i = 0; i < n_cb; ++i) { - id command_buffer = command_buffers[i]; - [command_buffer waitUntilCompleted]; - - MTLCommandBufferStatus status = [command_buffer status]; - if (status != MTLCommandBufferStatusCompleted) { - GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); - if (status == MTLCommandBufferStatusError) { - NSString * error_code = [command_buffer error].localizedDescription; - GGML_METAL_LOG_INFO("error: %s\n", [error_code UTF8String]); - } - - return GGML_STATUS_FAILED; - } - - id next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil); - if (!next_buffer) { - continue; - } - - bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); - if (next_queued) { - continue; - } - - if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { - GGML_METAL_LOG_INFO("%s: command buffer %d aborted", __func__, i); - return GGML_STATUS_ABORTED; - } - - [next_buffer commit]; - } - - if (should_capture) { - [[MTLCaptureManager sharedCaptureManager] stopCapture]; - } - - } - return GGML_STATUS_SUCCESS; -} - -//////////////////////////////////////////////////////////////////////////////// - -// backend interface - -// default buffer -static id g_backend_device = nil; -static int g_backend_device_ref_count = 0; - -static id ggml_backend_metal_get_device(void) { - if (g_backend_device == nil) { - g_backend_device = MTLCreateSystemDefaultDevice(); - } - - g_backend_device_ref_count++; - - return g_backend_device; -} - -static void ggml_backend_metal_free_device(void) { - assert(g_backend_device_ref_count > 0); - - g_backend_device_ref_count--; - - if (g_backend_device_ref_count == 0) { - [g_backend_device release]; - g_backend_device = nil; - } -} - -GGML_CALL static const char * ggml_backend_metal_buffer_get_name(ggml_backend_buffer_t buffer) { - return "Metal"; - - UNUSED(buffer); -} - -GGML_CALL static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) { - struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; - - for (int i = 0; i < ctx->n_buffers; i++) { - [ctx->buffers[i].metal release]; - } - ggml_backend_metal_free_device(); - - if (ctx->owned) { -#if TARGET_OS_OSX - vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size); -#else - free(ctx->all_data); -#endif - } - - free(ctx); -} - -GGML_CALL static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) { - struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; - - return ctx->all_data; -} - -GGML_CALL static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - memcpy((char *)tensor->data + offset, data, size); - - UNUSED(buffer); -} - -GGML_CALL static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { - memcpy(data, (const char *)tensor->data + offset, size); - - UNUSED(buffer); -} - -GGML_CALL static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { - if (ggml_backend_buffer_is_host(src->buffer)) { - memcpy(dst->data, src->data, ggml_nbytes(src)); - return true; - } - return false; - - UNUSED(buffer); -} - -GGML_CALL static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { - struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; - - memset(ctx->all_data, value, ctx->all_size); -} - -static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = { - /* .get_name = */ ggml_backend_metal_buffer_get_name, - /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer, - /* .get_base = */ ggml_backend_metal_buffer_get_base, - /* .init_tensor = */ NULL, - /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor, - /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor, - /* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor, - /* .clear = */ ggml_backend_metal_buffer_clear, - /* .reset = */ NULL, -}; - -// default buffer type - -GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) { - return "Metal"; - - UNUSED(buft); -} - -static void ggml_backend_metal_log_allocated_size(id device, size_t size_aligned) { -#ifndef GGML_METAL_NDEBUG -#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) - if (@available(macOS 10.12, iOS 16.0, *)) { - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)", - __func__, - size_aligned / 1024.0 / 1024.0, - device.currentAllocatedSize / 1024.0 / 1024.0, - device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); - - if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) { - GGML_METAL_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__); - } else { - GGML_METAL_LOG_INFO("\n"); - } - } else { - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n", - __func__, - size_aligned / 1024.0 / 1024.0, - device.currentAllocatedSize / 1024.0 / 1024.0); - } -#endif -#endif - UNUSED(device); - UNUSED(size_aligned); -} - -GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context)); - - const size_t size_page = sysconf(_SC_PAGESIZE); - - size_t size_aligned = size; - if ((size_aligned % size_page) != 0) { - size_aligned += (size_page - (size_aligned % size_page)); - } - - id device = ggml_backend_metal_get_device(); - - ctx->all_data = ggml_metal_host_malloc(size_aligned); - ctx->all_size = size_aligned; - ctx->owned = true; - ctx->n_buffers = 1; - - if (ctx->all_data != NULL) { - ctx->buffers[0].data = ctx->all_data; - ctx->buffers[0].size = size; - ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data - length:size_aligned - options:MTLResourceStorageModeShared - deallocator:nil]; - } - - if (ctx->all_data == NULL || ctx->buffers[0].metal == nil) { - GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); - free(ctx); - ggml_backend_metal_free_device(); - return NULL; - } - - //ggml_backend_metal_log_allocated_size(device, size_aligned); - - return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size); -} - -GGML_CALL static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return 32; - UNUSED(buft); -} - -GGML_CALL static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { - id device = ggml_backend_metal_get_device(); - size_t max_size = device.maxBufferLength; - ggml_backend_metal_free_device(); - - return max_size; - - UNUSED(buft); -} - -GGML_CALL static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) { - return true; - - UNUSED(buft); -} - -GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) { - static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = { - /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_type_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .is_host = */ ggml_backend_metal_buffer_type_is_host, - }, - /* .context = */ NULL, - }; - - return &ggml_backend_buffer_type_metal; -} - -// buffer from ptr - -GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) { - struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context)); - - ctx->all_data = data; - ctx->all_size = size; - ctx->owned = false; - ctx->n_buffers = 0; - - const size_t size_page = sysconf(_SC_PAGESIZE); - - // page-align the data ptr - { - const uintptr_t offs = (uintptr_t) data % size_page; - data = (void *) ((char *) data - offs); - size += offs; - } - - size_t size_aligned = size; - if ((size_aligned % size_page) != 0) { - size_aligned += (size_page - (size_aligned % size_page)); - } - - id device = ggml_backend_metal_get_device(); - - // the buffer fits into the max buffer size allowed by the device - if (size_aligned <= device.maxBufferLength) { - ctx->buffers[ctx->n_buffers].data = data; - ctx->buffers[ctx->n_buffers].size = size; - - ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; - - if (ctx->buffers[ctx->n_buffers].metal == nil) { - GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); - return false; - } - - ggml_backend_metal_log_allocated_size(device, size_aligned); - - ++ctx->n_buffers; - } else { - // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into - // one of the views - const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case - const size_t size_step = device.maxBufferLength - size_ovlp; - const size_t size_view = device.maxBufferLength; - - for (size_t i = 0; i < size; i += size_step) { - const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i); - - ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i); - ctx->buffers[ctx->n_buffers].size = size_step_aligned; - - ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; - - if (ctx->buffers[ctx->n_buffers].metal == nil) { - GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0); - return false; - } - - ggml_backend_metal_log_allocated_size(device, size_step_aligned); - - if (i + size_step < size) { - GGML_METAL_LOG_INFO("\n"); - } - - ++ctx->n_buffers; - } - } - - return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size); -} - -// backend - -GGML_CALL static const char * ggml_backend_metal_name(ggml_backend_t backend) { - return "Metal"; - - UNUSED(backend); -} - -GGML_CALL static void ggml_backend_metal_free(ggml_backend_t backend) { - struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; - ggml_metal_free(ctx); - free(backend); -} - -GGML_CALL static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggml_backend_t backend) { - return ggml_backend_metal_buffer_type(); - - UNUSED(backend); -} - -GGML_CALL static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { - struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)backend->context; - - return ggml_metal_graph_compute(metal_ctx, cgraph); -} - -GGML_CALL static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { - struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)backend->context; - - return ggml_metal_supports_op(metal_ctx, op); -} - -GGML_CALL static bool ggml_backend_metal_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { - return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name; - - UNUSED(backend); -} - -static struct ggml_backend_i ggml_backend_metal_i = { - /* .get_name = */ ggml_backend_metal_name, - /* .free = */ ggml_backend_metal_free, - /* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type, - /* .set_tensor_async = */ NULL, - /* .get_tensor_async = */ NULL, - /* .cpy_tensor_async = */ NULL, - /* .synchronize = */ NULL, - /* .graph_plan_create = */ NULL, - /* .graph_plan_free = */ NULL, - /* .graph_plan_update = */ NULL, - /* .graph_plan_compute = */ NULL, - /* .graph_compute = */ ggml_backend_metal_graph_compute, - /* .supports_op = */ ggml_backend_metal_supports_op, - /* .supports_buft = */ ggml_backend_metal_supports_buft, - /* .offload_op = */ NULL, - /* .event_new = */ NULL, - /* .event_free = */ NULL, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, - /* .event_synchronize = */ NULL, -}; - -void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) { - ggml_metal_log_callback = log_callback; - ggml_metal_log_user_data = user_data; -} - -static ggml_guid_t ggml_backend_metal_guid(void) { - static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 }; - return &guid; -} - -ggml_backend_t ggml_backend_metal_init(void) { - struct ggml_backend_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS); - if (ctx == NULL) { - GGML_METAL_LOG_ERROR("%s: error: failed to allocate context\n", __func__); - return NULL; - } - - ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend)); - - *metal_backend = (struct ggml_backend) { - /* .guid = */ ggml_backend_metal_guid(), - /* .interface = */ ggml_backend_metal_i, - /* .context = */ ctx, - }; - - return metal_backend; -} - -bool ggml_backend_is_metal(ggml_backend_t backend) { - return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid()); -} - -void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { - GGML_ASSERT(ggml_backend_is_metal(backend)); - - struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; - - ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS); -} - -void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) { - GGML_ASSERT(ggml_backend_is_metal(backend)); - - struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; - - ctx->abort_callback = abort_callback; - ctx->abort_callback_data = user_data; -} - -bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) { - GGML_ASSERT(ggml_backend_is_metal(backend)); - - struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; - - return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)]; -} - -void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) { - GGML_ASSERT(ggml_backend_is_metal(backend)); - - struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; - ctx->should_capture_next_compute = true; -} - -GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning - -GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) { - return ggml_backend_metal_init(); - - GGML_UNUSED(params); - GGML_UNUSED(user_data); -} diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal deleted file mode 100644 index f323ab5f4..000000000 --- a/ggml/src/ggml-metal.metal +++ /dev/null @@ -1,6374 +0,0 @@ -#define GGML_COMMON_DECL_METAL -#define GGML_COMMON_IMPL_METAL -#include "ggml-common.h" - -#include - -using namespace metal; - -#define MAX(x, y) ((x) > (y) ? (x) : (y)) -#define MIN(x, y) ((x) < (y) ? (x) : (y)) -#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } - -#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 - -enum ggml_sort_order { - GGML_SORT_ORDER_ASC, - GGML_SORT_ORDER_DESC, -}; - -// general-purpose kernel for addition, subtraction, multiplication and division of two tensors -// pros: works for non-contiguous tensors, supports broadcast across all dims -// cons: not very efficient -kernel void kernel_add( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int64_t & offs, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10)); - } -} - -kernel void kernel_sub( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int64_t & offs, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10)); - } -} - -kernel void kernel_mul( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10)); - } -} - -kernel void kernel_div( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10)); - } -} - -template -kernel void kernel_repeat( - device const char * src0, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; - - const int64_t i03 = i3 % ne03; - const int64_t i02 = i2 % ne02; - const int64_t i01 = i1 % ne01; - - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i00 = i0 % ne00; - *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00)); - } -} - -typedef decltype(kernel_repeat) kernel_repeat_t; - -template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat; -template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat; -template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat; -template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat; - -// assumption: src1 is a row -// broadcast src1 into src0 -kernel void kernel_add_row( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - constant uint64_t & nb [[buffer(28)]], - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] + src1[tpig % nb]; -} - -kernel void kernel_sub_row( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - constant uint64_t & nb [[buffer(28)]], - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] - src1[tpig % nb]; -} - -kernel void kernel_mul_row( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - constant uint64_t & nb [[buffer(28)]], - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src1[tpig % nb]; -} - -kernel void kernel_div_row( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - constant uint64_t & nb [[buffer(28)]], - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] / src1[tpig % nb]; -} - -kernel void kernel_scale( - device const float * src0, - device float * dst, - constant float & scale, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale; -} - -kernel void kernel_scale_4( - device const float4 * src0, - device float4 * dst, - constant float & scale, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * scale; -} - -kernel void kernel_clamp( - device const float * src0, - device float * dst, - constant float & min, - constant float & max, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]); -} - -kernel void kernel_relu( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = max(0.0f, src0[tpig]); -} - -kernel void kernel_sigmoid( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); -} - -kernel void kernel_tanh( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = precise::tanh(x); -} - -constant float GELU_COEF_A = 0.044715f; -constant float GELU_QUICK_COEF = -1.702f; -constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - -kernel void kernel_gelu( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - - dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_gelu_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - // BEWARE !!! - // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! - // This was observed with Falcon 7B and 40B models - // - dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_gelu_quick( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - - dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} - -kernel void kernel_gelu_quick_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} - -kernel void kernel_silu( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = x / (1.0f + exp(-x)); -} - -kernel void kernel_silu_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - dst[tpig] = x / (1.0f + exp(-x)); -} - -kernel void kernel_sqr( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src0[tpig]; -} - -kernel void kernel_sqrt( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sqrt(src0[tpig]); -} - -kernel void kernel_sin( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sin(src0[tpig]); -} - -kernel void kernel_cos( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = cos(src0[tpig]); -} - -kernel void kernel_sum_rows( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tpig[[thread_position_in_grid]]) { - int64_t i3 = tpig.z; - int64_t i2 = tpig.y; - int64_t i1 = tpig.x; - - if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { - return; - } - - device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); - device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3); - - float row_sum = 0; - - for (int64_t i0 = 0; i0 < ne00; i0++) { - row_sum += src_row[i0]; - } - - dst_row[0] = row_sum; -} - -template -kernel void kernel_soft_max( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint32_t & n_head_log2, - threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - const int64_t i03 = (tgpig) / (ne02*ne01); - const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; - const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - - device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr; - device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - const int64_t h = i02; - - const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float lmax = -INFINITY; - - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); - } - - // find the max value in the block - float max_val = simd_max(lmax); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = -INFINITY; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = max_val; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - max_val = buf[tiisg]; - max_val = simd_max(max_val); - } - - // parallel sum - float lsum = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); - lsum += exp_psrc0; - pdst[i00] = exp_psrc0; - } - - // This barrier fixes a failing test - // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 - threadgroup_barrier(mem_flags::mem_none); - - float sum = simd_sum(lsum); - - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = sum; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - sum = buf[tiisg]; - sum = simd_sum(sum); - } - - const float inv_sum = 1.0f/sum; - - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - pdst[i00] *= inv_sum; - } -} - -template -kernel void kernel_soft_max_4( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint32_t & n_head_log2, - threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - const int64_t i03 = (tgpig) / (ne02*ne01); - const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; - const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - - device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; - device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr; - device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; - - float slope = 1.0f; - - if (max_bias > 0.0f) { - const int64_t h = i02; - - const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float4 lmax4 = -INFINITY; - - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); - } - - const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); - - float max_val = simd_max(lmax); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = -INFINITY; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = max_val; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - max_val = buf[tiisg]; - max_val = simd_max(max_val); - } - - // parallel sum - float4 lsum4 = 0.0f; - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); - lsum4 += exp_psrc4; - pdst4[i00] = exp_psrc4; - } - - const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; - - // This barrier fixes a failing test - // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 - threadgroup_barrier(mem_flags::mem_none); - - float sum = simd_sum(lsum); - - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = sum; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - sum = buf[tiisg]; - sum = simd_sum(sum); - } - - const float inv_sum = 1.0f/sum; - - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - pdst4[i00] *= inv_sum; - } -} - -typedef decltype(kernel_soft_max) kernel_soft_max_t; -typedef decltype(kernel_soft_max_4) kernel_soft_max_4_t; - -template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max; -template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max; -template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; -template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; - -kernel void kernel_diag_mask_inf( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int & n_past, - uint3 tpig[[thread_position_in_grid]]) { - const int64_t i02 = tpig[2]; - const int64_t i01 = tpig[1]; - const int64_t i00 = tpig[0]; - - if (i00 > n_past + i01) { - dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; - } else { - dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; - } -} - -kernel void kernel_diag_mask_inf_8( - device const float4 * src0, - device float4 * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int & n_past, - uint3 tpig[[thread_position_in_grid]]) { - - const int64_t i = 2*tpig[0]; - - dst[i+0] = src0[i+0]; - dst[i+1] = src0[i+1]; - int64_t i4 = 4*i; - const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; - const int64_t i01 = i4/(ne00); i4 -= i01*ne00; - const int64_t i00 = i4; - for (int k = 3; k >= 0; --k) { - if (i00 + 4 + k <= n_past + i01) { - break; - } - dst[i+1][k] = -INFINITY; - if (i00 + k > n_past + i01) { - dst[i][k] = -INFINITY; - } - } -} - -// ref: ggml.c:ggml_compute_forward_ssm_conv_f32 -// TODO: optimize -kernel void kernel_ssm_conv_f32( - device const void * src0, - device const void * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t ir = tgpig.x; - const int64_t i2 = tgpig.y; - const int64_t i3 = tgpig.z; - - const int64_t nc = ne10; - const int64_t ncs = ne00; - const int64_t nr = ne01; - const int64_t n_t = ne1; - const int64_t n_s = ne2; - - device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02); - device const float * c = (device const float *) ((device const char *) src1 + ir*nb11); - device float * x = (device float *) ((device char *) dst + ir*nb0 + i2*nb1 + i3*nb2); - - float sumf = 0.0f; - - for (int64_t i0 = 0; i0 < nc; ++i0) { - sumf += s[i0] * c[i0]; - } - - x[0] = sumf; -} - -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32 -// TODO: optimize -kernel void kernel_ssm_scan_f32( - device const void * src0, - device const void * src1, - device const void * src2, - device const void * src3, - device const void * src4, - device const void * src5, - device float * dst, - constant int64_t & d_state, - constant int64_t & d_inner, - constant int64_t & n_seq_tokens, - constant int64_t & n_seqs, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant uint64_t & nb20, - constant uint64_t & nb21, - constant uint64_t & nb22, - constant uint64_t & nb30, - constant uint64_t & nb31, - constant uint64_t & nb40, - constant uint64_t & nb41, - constant uint64_t & nb42, - constant uint64_t & nb50, - constant uint64_t & nb51, - constant uint64_t & nb52, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t ir = tgpig.x; - const int64_t i3 = tgpig.y; - - const int64_t nc = d_state; - const int64_t nr = d_inner; - const int64_t n_t = n_seq_tokens; - const int64_t n_s = n_seqs; - - for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02); - device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12); - device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); - device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); - device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42); - device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52); - device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides - device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13); - - if (i2 > 0) { - s0 = s; - } - - // i1 == 0 - float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - float x_dt = x[0] * dt_soft_plus; - float sumf = 0.0f; - - for (int64_t i0 = 0; i0 < nc; ++i0) { - int64_t i = i0; - float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt); - sumf += state * C[i0]; - s[i] = state; - } - - y[0] = sumf; - } -} - -kernel void kernel_norm( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant float & eps, - threadgroup float * sum [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint ntg[[threads_per_threadgroup]]) { - device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); - // MEAN - // parallel sum - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - sum[tpitg] += x[i00]; - } - // reduce - threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - const float mean = sum[0] / ne00; - - // recenter and VARIANCE - threadgroup_barrier(mem_flags::mem_threadgroup); - device float * y = dst + tgpig*ne00; - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - y[i00] = x[i00] - mean; - sum[tpitg] += y[i00] * y[i00]; - } - - // reduce - threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - const float variance = sum[0] / ne00; - - const float scale = 1.0f/sqrt(variance + eps); - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - y[i00] = y[i00] * scale; - } -} - -kernel void kernel_rms_norm( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant float & eps, - threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); - - float4 sumf = 0; - float all_sum = 0; - - // parallel sum - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - sumf += x[i00] * x[i00]; - } - all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; - all_sum = simd_sum(all_sum); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = all_sum; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - all_sum = buf[tiisg]; - all_sum = simd_sum(all_sum); - } - - const float mean = all_sum/ne00; - const float scale = 1.0f/sqrt(mean + eps); - - device float4 * y = (device float4 *) (dst + tgpig*ne00); - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - y[i00] = x[i00] * scale; - } -} - -kernel void kernel_group_norm( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int32_t & n_groups, - constant float & eps, - threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - const int64_t ne = ne00*ne01*ne02; - const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups); - - int start = tgpig * gs; - int end = start + gs; - - start += tpitg; - - if (end >= ne) { - end = ne; - } - - float tmp = 0.0f; // partial sum for thread in warp - - for (int j = start; j < end; j += ntg) { - tmp += src0[j]; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - tmp = simd_sum(tmp); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = tmp; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - tmp = buf[tiisg]; - tmp = simd_sum(tmp); - } - - const float mean = tmp / gs; - tmp = 0.0f; - - for (int j = start; j < end; j += ntg) { - float xi = src0[j] - mean; - dst[j] = xi; - tmp += xi * xi; - } - - tmp = simd_sum(tmp); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (tiisg == 0) { - buf[sgitg] = tmp; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - tmp = buf[tiisg]; - tmp = simd_sum(tmp); - } - - const float variance = tmp / gs; - const float scale = 1.0f/sqrt(variance + eps); - for (int j = start; j < end; j += ntg) { - dst[j] *= scale; - } -} - -// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q4 quants begin (0 or QK4_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - - float2 acc = 0.f; - - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); - - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); - } - return d * (sumy * -8.f + acc[0] + acc[1]); -} - -// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q4 quants begin (0 or QK4_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - float m = qb_curr->m; - - float2 acc = 0.f; - - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); - - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); - } - return d * (acc[0] + acc[1]) + sumy * m; -} - -// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q5 quants begin (0 or QK5_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - - float2 acc = 0.f; - - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); - const uint32_t qh = *((device const uint32_t *)qb_curr->qh); - - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) - + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); - acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) - + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); - } - return d * (sumy * -16.f + acc[0] + acc[1]); -} - -// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q5 quants begin (0 or QK5_1/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - float m = qb_curr->m; - - float2 acc = 0.f; - - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); - const uint32_t qh = *((device const uint32_t *)qb_curr->qh); - - for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) - + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); - acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) - + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); - } - return d * (acc[0] + acc[1]) + sumy * m; -} - -// putting them in the kernel cause a significant performance penalty -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 2 // number of SIMD groups in a thread group -//Note: This is a template, but strictly speaking it only applies to -// quantizations where the block size is 32. It also does not -// guard against the number of rows not being divisible by -// N_DST, so this is another explicit assumption of the implementation. -template -void mul_vec_q_n_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, uint tiisg, uint sgitg) { - const int nb = ne00/QK4_0; - - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * nsg + sgitg) * nr; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q_type * x = (device const block_q_type *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[16]; // src1 vector cache - float sumf[nr] = {0.f}; - - const int ix = (tiisg/2); - const int il = (tiisg%2)*8; - - device const float * yb = y + ix * QK4_0 + il; - - // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += nw/2) { - float sumy = 0; - for (int i = 0; i < 8; i += 2) { - sumy += yb[i] + yb[i+1]; - yl[i+0] = yb[i+ 0]; - yl[i+1] = yb[i+ 1]/256.f; - - sumy += yb[i+16] + yb[i+17]; - yl[i+8] = yb[i+16]/16.f; - yl[i+9] = yb[i+17]/4096.f; - } - - for (int row = 0; row < nr; row++) { - sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il); - } - - yb += QK4_0 * 16; - } - - for (int row = 0; row < nr; ++row) { - const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot; - } - } -} - -kernel void kernel_mul_mv_q4_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); -} - -kernel void kernel_mul_mv_q4_1_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); -} - -kernel void kernel_mul_mv_q5_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); -} - -kernel void kernel_mul_mv_q5_1_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); -} - - -#define NB_Q8_0 8 - -void kernel_mul_mv_q8_0_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - const int nr = N_DST; - const int nsg = N_SIMDGROUP; - const int nw = N_SIMDWIDTH; - - const int nb = ne00/QK8_0; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * nsg + sgitg) * nr; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[NB_Q8_0]; - float sumf[nr]={0.f}; - - const int ix = tiisg/4; - const int il = tiisg%4; - - device const float * yb = y + ix * QK8_0 + NB_Q8_0*il; - - // each thread in a SIMD group deals with NB_Q8_0 quants at a time - for (int ib = ix; ib < nb; ib += nw/4) { - for (int i = 0; i < NB_Q8_0; ++i) { - yl[i] = yb[i]; - } - - for (int row = 0; row < nr; row++) { - device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il; - float sumq = 0.f; - for (int iq = 0; iq < NB_Q8_0; ++iq) { - sumq += qs[iq] * yl[iq]; - } - sumf[row] += sumq*x[ib+row*nb].d; - } - - yb += NB_Q8_0 * nw; - } - - for (int row = 0; row < nr; ++row) { - const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; - } - } -} - -[[host_name("kernel_mul_mv_q8_0_f32")]] -kernel void kernel_mul_mv_q8_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); -} - -#define N_MV_T_T 4 - -template -void kernel_mul_mv_impl( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - int64_t ne10, - int64_t ne11, - int64_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, - uint tiisg) { - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_MV_T_T; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const T0 * x = (device const T0 *) (src0 + offset0); - - if (ne00 < 128) { - for (int row = 0; row < N_MV_T_T; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (T0) x[i] * (T1) y[i]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - device const T04 * x4 = (device const T04 *) x; - for (int row = 0; row < N_MV_T_T; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12); - device const T14 * y4 = (device const T14 *) y; - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -template -kernel void kernel_mul_mv( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_impl( - src0, - src1, - dst, - ne00, - ne01, - ne02, - nb00, - nb01, - nb02, - ne10, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - ne1, - r2, - r3, - tgpig, - tiisg); -} - -typedef decltype(kernel_mul_mv) mul_mv_t; - -template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv; -template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv; -template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv; - -template -kernel void kernel_mul_mv_1row( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const T * x = (device const T *) (src0 + offset0); - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - if (ne00 < 128) { - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } else { - device const T4 * x4 = (device const T4 *) x; - device const float4 * y4 = (device const float4 *) y; - - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); - } - - float all_sum = simd_sum(sumf); - - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } -} - -typedef decltype(kernel_mul_mv_1row) mul_mv_1row_t; - -template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row; - -// Assumes row size (ne00) is a multiple of 4 -template -kernel void kernel_mul_mv_l4( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - - const int nrows = ne11; - const int64_t r0 = tgpig.x; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const T4 * x4 = (device const T4 *) (src0 + offset0); - - for (int r1 = 0; r1 < nrows; ++r1) { - device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } -} - -typedef decltype(kernel_mul_mv_l4) mul_mv_l4_t; - -template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4; - -static float rope_yarn_ramp(const float low, const float high, const int i0) { - const float y = (i0 / 2 - low) / max(0.001f, high - low); - return 1.0f - min(1.0f, max(0.0f, y)); -} - -// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn -// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. -static void rope_yarn( - float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, - thread float * cos_theta, thread float * sin_theta) { - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = freq_scale * theta_extrap; - float theta = theta_interp; - if (ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; - theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - - // Get n-d magnitude scaling corrected for interpolation - mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); - } - *cos_theta = cos(theta) * mscale; - *sin_theta = sin(theta) * mscale; -} - -// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get -// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { - return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base)); -} - -static void rope_yarn_corr_dims( - int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2] -) { - // start and end correction dims - dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))); - dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))); -} - -template -kernel void kernel_rope_norm( - device const void * src0, - device const int32_t * src1, - device const float * src2, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & n_ctx_orig, - constant float & freq_base, - constant float & freq_scale, - constant float & ext_factor, - constant float & attn_factor, - constant float & beta_fast, - constant float & beta_slow, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]) { - const int64_t i3 = tgpig[2]; - const int64_t i2 = tgpig[1]; - const int64_t i1 = tgpig[0]; - - float corr_dims[2]; - rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - - device const int32_t * pos = src1; - - const float theta_base = (float) pos[i2]; - const float inv_ndims = -1.f/n_dims; - - float cos_theta; - float sin_theta; - - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { - if (i0 < n_dims) { - const int64_t ic = i0/2; - - const float theta = theta_base * pow(freq_base, inv_ndims*i0); - - const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; - - rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[1]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[1] = x0*sin_theta + x1*cos_theta; - } else { - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -template -kernel void kernel_rope_neox( - device const void * src0, - device const int32_t * src1, - device const float * src2, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & n_ctx_orig, - constant float & freq_base, - constant float & freq_scale, - constant float & ext_factor, - constant float & attn_factor, - constant float & beta_fast, - constant float & beta_slow, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]) { - const int64_t i3 = tgpig[2]; - const int64_t i2 = tgpig[1]; - const int64_t i1 = tgpig[0]; - - float corr_dims[2]; - rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - - device const int32_t * pos = src1; - - const float theta_base = (float) pos[i2]; - const float inv_ndims = -1.f/n_dims; - - float cos_theta; - float sin_theta; - - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { - if (i0 < n_dims) { - const int64_t ic = i0/2; - - const float theta = theta_base * pow(freq_base, inv_ndims*i0); - - const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; - - rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - } else { - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -typedef decltype(kernel_rope_norm) kernel_rope_norm_t; -typedef decltype(kernel_rope_neox) kernel_rope_neox_t; - -template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm; -template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm; - -template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox; -template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox; - -typedef void (im2col_t)( - device const float * x, - device char * dst, - constant int32_t & ofs0, - constant int32_t & ofs1, - constant int32_t & IW, - constant int32_t & IH, - constant int32_t & CHW, - constant int32_t & s0, - constant int32_t & s1, - constant int32_t & p0, - constant int32_t & p1, - constant int32_t & d0, - constant int32_t & d1, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tgpg[[threadgroups_per_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]); - -template -kernel void kernel_im2col( - device const float * x, - device char * dst, - constant int32_t & ofs0, - constant int32_t & ofs1, - constant int32_t & IW, - constant int32_t & IH, - constant int32_t & CHW, - constant int32_t & s0, - constant int32_t & s1, - constant int32_t & p0, - constant int32_t & p1, - constant int32_t & d0, - constant int32_t & d1, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tgpg[[threadgroups_per_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0; - const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1; - - const int32_t offset_dst = - (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + - (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]); - - device T * pdst = (device T *) (dst); - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - pdst[offset_dst] = 0.0f; - } else { - const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1; - pdst[offset_dst] = x[offset_src + iih * IW + iiw]; - } -} - -template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; -template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; - -kernel void kernel_upscale_f32( - device const char * src0, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant float & sf0, - constant float & sf1, - constant float & sf2, - constant float & sf3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; - - const int64_t i03 = i3/sf3; - const int64_t i02 = i2/sf2; - const int64_t i01 = i1/sf1; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int64_t i00 = i0/sf0; - - device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_ptr[0] = src0_ptr[0]; - } -} - -kernel void kernel_pad_f32( - device const char * src0, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; - - const int64_t i03 = i3; - const int64_t i02 = i2; - const int64_t i01 = i1; - - device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); - device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); - - if (i1 < ne01 && i2 < ne02 && i3 < ne03) { - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i0 < ne00) { - dst_ptr[i0] = src0_ptr[i0]; - } else { - dst_ptr[i0] = 0.0f; - } - } - - return; - } - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - dst_ptr[i0] = 0.0f; - } -} - -kernel void kernel_arange_f32( - device char * dst, - constant int64_t & ne0, - constant float & start, - constant float & step, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - - device float * dst_ptr = (device float *) dst; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - dst_ptr[i0] = start + step * i0; - } -} - -kernel void kernel_timestep_embedding_f32( - device const char * src0, - device char * dst, - constant uint64_t & nb1, - constant int & dim, - constant int & max_period, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - - int i = tgpig.x; - device float * embed_data = (device float *)(dst + i*nb1); - - int half_ = dim / 2; - for (int j = tpitg.x; j < half_; j += ntg.x) { - float timestep = ((device float *)src0)[i]; - float freq = (float)exp(-log((float)max_period) * j / half_); - float arg = timestep * freq; - embed_data[j ] = cos(arg); - embed_data[j + half_] = sin(arg); - } - - if (dim % 2 != 0 && tpitg.x == 0) { - embed_data[dim] = 0.f; - } -} - -// bitonic sort implementation following the CUDA kernels as reference -typedef void (argsort_t)( - device const float * x, - device int32_t * dst, - constant int64_t & ncols, - constant int64_t & ncols_pad, - threadgroup int32_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]]); - -template -kernel void kernel_argsort_f32_i32( - device const float * x, - device int32_t * dst, - constant int64_t & ncols, - constant int64_t & ncols_pad, - threadgroup int32_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]]) { - // bitonic sort - int col = tpitg[0]; - int row = tgpig[1]; - - if (col >= ncols_pad) return; - - device const float * x_row = x + row * ncols; - threadgroup int32_t * dst_row = shared_values; - - // initialize indices - dst_row[col] = col; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - for (int k = 2; k <= ncols_pad; k *= 2) { - for (int j = k / 2; j > 0; j /= 2) { - int ixj = col ^ j; - if (ixj > col) { - if ((col & k) == 0) { - if (dst_row[col] >= ncols || - (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? - x_row[dst_row[col]] > x_row[dst_row[ixj]] : - x_row[dst_row[col]] < x_row[dst_row[ixj]])) - ) { - SWAP(dst_row[col], dst_row[ixj]); - } - } else { - if (dst_row[ixj] >= ncols || - (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? - x_row[dst_row[col]] < x_row[dst_row[ixj]] : - x_row[dst_row[col]] > x_row[dst_row[ixj]])) - ) { - SWAP(dst_row[col], dst_row[ixj]); - } - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - } - - // copy the result to dst without the padding - if (col < ncols) { - dst[row * ncols + col] = dst_row[col]; - } -} - -template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; -template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; - -kernel void kernel_leaky_relu_f32( - device const float * src0, - device float * dst, - constant float & slope, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; -} - -typedef void (flash_attn_ext_f16_t)( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant uint64_t & nb21, - constant uint64_t & nb22, - constant uint64_t & nb23, - constant uint64_t & nb31, - constant int64_t & ne1, - constant int64_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint32_t & n_head_log2, - constant float & logit_softcap, - threadgroup half * shared, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]); - -// ref: https://arxiv.org/pdf/2307.08691.pdf -template // head size, queries per threadgroup, cache items per threadgroup -kernel void kernel_flash_attn_ext_f16( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant uint64_t & nb21, - constant uint64_t & nb22, - constant uint64_t & nb23, - constant uint64_t & nb31, - constant int64_t & ne1, - constant int64_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint32_t & n_head_log2, - constant float & logit_softcap, - threadgroup half * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - const short nsg = ntg.y; // number of simdgroups - - const short iq3 = tgpig[2]; - const short iq2 = tgpig[1]; - const short iq1 = tgpig[0]*Q; - - const short D4 = D/4; - const short D8 = D/8; - //const short Q8 = Q/8; - const short NW = N_SIMDWIDTH; - const short SH = (C + Q); // shared memory per simdgroup in (half) - - const short T = D + 2*nsg*SH; // shared memory size per query in (half) - const short TF = T/2; // shared memory size per query in (float) - const short T4 = T/4; // shared memory size per query in (half4) - - threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 - threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix - - // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - simdgroup_half8x8 lo[D8]; - - // load heads from Q to shared memory - for (short j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); - - for (short i = tiisg; i < D4; i += NW) { - if (iq1 + j < ne01) { - sq4[j*T4 + i] = (half4) q4[i]; - } else { - sq4[j*T4 + i] = 0.0h; - } - } - } - - // zero out lo - for (short i = 0; i < D8; ++i) { - lo[i] = make_filled_simdgroup_matrix(0.0h); - } - - // zero out shared memory SH - for (short j = 0; j < Q; ++j) { - for (short i = tiisg; i < SH; i += NW) { - ss[j*TF + i] = 0.0f; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - { - float S[Q] = { [0 ... Q-1] = 0.0h }; - float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; - - // assume K and V are same shape - const short ne22 = ne12; - const short ne23 = ne13; - - // broadcast - const short rk2 = ne02/ne12; - const short rk3 = ne03/ne13; - - const short rv2 = ne02/ne22; - const short rv3 = ne03/ne23; - - // k indices - const short ik2 = iq2/rk2; - const short ik3 = iq3/rk3; - - // v indices - const short iv2 = iq2/rv2; - const short iv3 = iq3/rv3; - - // load the queries from shared memory into local memory - simdgroup_half8x8 mq[D8]; - - for (short i = 0; i < D8; ++i) { - simdgroup_load(mq[i], sq + i*8, T); - } - - // pointer to the mask - device const half * mp = (device const half *) (mask + iq1*nb31); - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - const uint32_t h = iq2; - - const float base = h < n_head_log2 ? m0 : m1; - const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exph); - } - - // loop over the KV cache - // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { - const int ic = ic0 + C*sgitg; - if (ic >= ne11) { - break; - } - - // Q*K^T - { - for (short cc = 0; cc < C/8; ++cc) { - simdgroup_float8x8 mqk = make_filled_simdgroup_matrix(0.h); - - device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); - - for (short i = 0; i < D8; ++i) { - simdgroup_half8x8 mk; - simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose - - simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); - } - - simdgroup_store(mqk, ss + 8*cc, TF, 0, false); - } - } - - // used to detect blocks full of -INF - float smax = -INFINITY; - - // online softmax - { - float ms[Q]; - - for (short j = 0; j < Q; ++j) { - const float m = M[j]; - - // scale and apply the logitcap / mask - float s = ss[j*TF + tiisg]*scale; - - if (logit_softcap != 0.0f) { - s = logit_softcap*precise::tanh(s); - } - - if (mask != q) { - // mqk = mqk + mask*slope - s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; - } - - smax = simd_max(max(smax, s)); - M[j] = simd_max(max(M[j], s)); - - ms[j] = exp(m - M[j]); - const float vs = exp(s - M[j]); - - S[j] = S[j]*ms[j] + simd_sum(vs); - - // the P matrix from the paper (Q rows, C columns) - ss[j*TF + tiisg] = vs; - } - - // create a QxQ diagonal matrix for rescaling the output - if (tiisg < Q) { - ss[tiisg*TF + C + tiisg] = ms[tiisg]; - } - } - - // skip -INF blocks - if (smax == -INFINITY) { - continue; - } - - // O = diag(ms)*O - { - simdgroup_float8x8 mm; - simdgroup_load(mm, ss + C, TF, 0, false); - - for (short i = 0; i < D8; ++i) { - simdgroup_multiply(lo[i], mm, lo[i]); - } - } - - // O = O + (Q*K^T)*V - { - for (short cc = 0; cc < C/8; ++cc) { - device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); - - for (short i = 0; i < D8; ++i) { - simdgroup_half8x8 mk; - simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); - - simdgroup_float8x8 mv; - simdgroup_load(mv, ss + 8*cc, TF, 0, false); - - simdgroup_multiply_accumulate(lo[i], mv, mk, lo[i]); - } - } - } - } - - // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (short j = 0; j < Q; ++j) { - if (tiisg == 0) { - ss[j*TF + 0] = S[j]; - ss[j*TF + 1] = M[j]; - } - } - } - - // reduce the warps sequentially - for (short sg = 1; sg < nsg; ++sg) { - float S = { 0.0h }; - float M = { -FLT_MAX/2 }; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // each simdgroup stores its output to shared memory, reusing sq - if (sgitg == sg) { - for (short i = 0; i < D8; ++i) { - simdgroup_store(lo[i], sq + i*8, T, 0, false); - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // the first simdgroup accumulates the results from the other simdgroups - if (sgitg == 0) { - for (short j = 0; j < Q; ++j) { - const float S0 = ss[j*TF + 0]; - const float S1 = ss[j*TF + sg*SH + 0]; - - const float M0 = ss[j*TF + 1]; - const float M1 = ss[j*TF + sg*SH + 1]; - - M = max(M0, M1); - - const float ms0 = exp(M0 - M); - const float ms1 = exp(M1 - M); - - S = S0*ms0 + S1*ms1; - - if (tiisg == 0) { - ss[j*TF + 0] = S; - ss[j*TF + 1] = M; - - ss[j*TF + C + j ] = ms0; - ss[j*TF + C + j + sg*SH] = ms1; - } - } - - // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - { - simdgroup_half8x8 t; - simdgroup_float8x8 ms0; - simdgroup_float8x8 ms1; - - simdgroup_load(ms0, ss + C, TF, 0, false); - simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false); - - for (short i = 0; i < D8; ++i) { - simdgroup_load (t, sq + i*8, T, 0, false); - simdgroup_multiply(t, ms1, t); - - simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); - } - } - } - } - - // store result to shared memory (reuse sq) - if (sgitg == 0) { - for (short i = 0; i < D8; ++i) { - simdgroup_store(lo[i], sq + i*8, T, 0, false); - } - } - - device float4 * dst4 = (device float4 *) dst; - - // final rescale with 1/S and store to global memory - if (sgitg == 0) { - for (short j = 0; j < Q && iq1 + j < ne01; ++j) { - const float S = ss[j*TF + 0]; - - for (short i = tiisg; i < D4; i += NW) { - dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; - } - } - } -} - -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80>; -template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>; -template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; -//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; - -template // head size, queries per threadgroup, cache items per threadgroup -kernel void kernel_flash_attn_ext_vec_f16( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant uint64_t & nb21, - constant uint64_t & nb22, - constant uint64_t & nb23, - constant uint64_t & nb31, - constant int64_t & ne1, - constant int64_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint32_t & n_head_log2, - constant float & logit_softcap, - threadgroup half * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - const short nsg = ntg.y; // number of simdgroups - - const short iq3 = tgpig[2]; - const short iq2 = tgpig[1]; - const short iq1 = tgpig[0]; - - const short D4 = D/4; - const short NW = N_SIMDWIDTH; - const short SH = (C + Q); // shared memory per simdgroup in (half) - - const short T = D + 2*nsg*SH; // shared memory size per query in (half) - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - const uint32_t h = iq2; - - const float base = h < n_head_log2 ? m0 : m1; - const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 - threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix - threadgroup float4 * ss4 = (threadgroup float4 *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4 - threadgroup half4 * sr4 = (threadgroup half4 *) (shared + sgitg*D + 1*T); // scratch buffer for the results - - // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) - half4 lo[D4/NW]; - - // load heads from Q to shared memory - device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); - - for (short i = tiisg; i < D4; i += NW) { - if (iq1 < ne01) { - sq4[i] = (half4) q4[i]; - } else { - sq4[i] = 0.0h; - } - } - - // zero out lo - for (short i = tiisg; i < D4; i += NW) { - lo[i/NW] = 0.0h; - } - - // zero out shared memory SH - for (short i = tiisg; i < SH/4; i += NW) { - ss4[i] = 0.0h; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - { - float S = { 0.0h }; - float M = { -FLT_MAX/2 }; - - // assume K and V are same shape - const short ne22 = ne12; - const short ne23 = ne13; - - // broadcast - const short rk2 = ne02/ne12; - const short rk3 = ne03/ne13; - - const short rv2 = ne02/ne22; - const short rv3 = ne03/ne23; - - // k indices - const short ik2 = iq2 / rk2; - const short ik3 = iq3 / rk3; - - // v indices - const short iv2 = iq2 / rv2; - const short iv3 = iq3 / rv3; - - // load the queries from shared memory into local memory - half4 mq[D4]; - - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - mq[i] = sq4[i]; - } - - // pointer to the mask - device const half4 * mp4 = (device const half4 *) (mask + iq1*nb31); - - // loop over the KV cache - // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { - const int ic = ic0 + C*sgitg; - if (ic >= ne11) { - break; - } - - // Q*K^T - { -#pragma unroll - for (short cc = 0; cc < C/4; ++cc) { - float4 mqk = { 0.0h }; - - device const half4 * pk4 = (device const half4 *) ((device const char *) k + ((ic + 4*cc)*nb11 + ik2*nb12 + ik3*nb13)); - -#pragma unroll - for (short ii = 0; ii < D4; ii += NW) { - const short i = ii + tiisg; - - half4x4 mk; - mk[0] = pk4[i + 0*(nb11/8)]; - mk[1] = pk4[i + 1*(nb11/8)]; - mk[2] = pk4[i + 2*(nb11/8)]; - mk[3] = pk4[i + 3*(nb11/8)]; - - mqk += (float4) (mq[i] * mk); - } - - // reduce the results from the threads in the simdgroup - mqk += simd_shuffle_down(mqk, 16); - mqk += simd_shuffle_down(mqk, 8); - mqk += simd_shuffle_down(mqk, 4); - mqk += simd_shuffle_down(mqk, 2); - mqk += simd_shuffle_down(mqk, 1); - - // mqk = mqk*scale + mask*slope - if (tiisg == 0) { - mqk *= scale; - - if (logit_softcap != 0.0f) { - mqk = logit_softcap*precise::tanh(mqk); - } - - mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f; - - ss4[cc] = mqk; - } - } - } - - // online softmax - { - const short p = tiisg; - - const float m = M; - const float s = ss[p]; - - M = simd_max(max(M, s)); - - const float ms = exp(m - M); - const float vs = exp(s - M); - - S = S*ms + simd_sum(vs); - - // the P matrix from the paper (Q rows, C columns) - ss[p] = vs; - - // O = diag(ms)*O -#pragma unroll - for (short ii = 0; ii < D4; ii += NW) { - const short i = ii + tiisg; - lo[i/NW] *= ms; - } - } - - // O = O + (Q*K^T)*V - { -#pragma unroll - for (short cc = 0; cc < C/4; ++cc) { - device const half4 * pv4 = (device const half4 *) ((device const char *) v + ((ic + 4*cc)*nb21 + iv2*nb22 + iv3*nb23)); - -#pragma unroll - for (short ii = 0; ii < D4; ii += NW) { - const short i = ii + tiisg; - - lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; - lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; - lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; - lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; - } - } - } - - } - - // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - if (tiisg == 0) { - ss[0] = S; - ss[1] = M; - } - } - - // store results to shared memory - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - sr4[i] = lo[ii/NW]; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // parallel reduce - for (short r = nsg/2; r > 0; r >>= 1) { - if (sgitg < r) { - const float S0 = ss[ 0]; - const float S1 = ss[r*SH + 0]; - - const float M0 = ss[ 1]; - const float M1 = ss[r*SH + 1]; - - const float M = max(M0, M1); - - const float ms0 = exp(M0 - M); - const float ms1 = exp(M1 - M); - - const float S = S0*ms0 + S1*ms1; - - if (tiisg == 0) { - ss[0] = S; - ss[1] = M; - } - - // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - sr4[i] = sr4[i]*ms0 + sr4[i + r*D4]*ms1; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - device float4 * dst4 = (device float4 *) dst; - - // final rescale with 1/S and store to global memory - if (sgitg == 0) { - const float S = ss[0]; - - for (short ii = 0; ii < D4; ii += NW) { - short i = ii + tiisg; - dst4[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D4 + i] = (float4) sr4[i]/S; - } - } -} - -template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; -//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; - -template -kernel void kernel_cpy( - device const void * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = (T1) src[0]; - } -} - -typedef decltype(kernel_cpy) kernel_cpy_t; - -template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy; - -kernel void kernel_cpy_f32_q8_0( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0; - - device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - const float v = src[j]; - amax = MAX(amax, fabs(v)); - } - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK8_0].d = d; - - for (int j = 0; j < QK8_0; ++j) { - const float x0 = src[j]*id; - - dst_data[i00/QK8_0].qs[j] = round(x0); - } - } -} - -kernel void kernel_cpy_f32_q4_0( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0; - - device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < QK4_0; j++) { - const float v = src[j]; - if (amax < fabs(v)) { - amax = fabs(v); - max = v; - } - } - - const float d = max / -8; - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK4_0].d = d; - - for (int j = 0; j < QK4_0/2; ++j) { - const float x0 = src[0 + j]*id; - const float x1 = src[QK4_0/2 + j]*id; - - const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); - const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); - - dst_data[i00/QK4_0].qs[j] = xi0; - dst_data[i00/QK4_0].qs[j] |= xi1 << 4; - } - } -} - -kernel void kernel_cpy_f32_q4_1( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1; - - device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float min = FLT_MAX; - float max = -FLT_MAX; - - for (int j = 0; j < QK4_1; j++) { - const float v = src[j]; - if (min > v) min = v; - if (max < v) max = v; - } - - const float d = (max - min) / ((1 << 4) - 1); - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK4_1].d = d; - dst_data[i00/QK4_1].m = min; - - for (int j = 0; j < QK4_1/2; ++j) { - const float x0 = (src[0 + j] - min)*id; - const float x1 = (src[QK4_1/2 + j] - min)*id; - - const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); - const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); - - dst_data[i00/QK4_1].qs[j] = xi0; - dst_data[i00/QK4_1].qs[j] |= xi1 << 4; - } - } -} - -kernel void kernel_cpy_f32_q5_0( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0; - - device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < QK5_0; j++) { - const float v = src[j]; - if (amax < fabs(v)) { - amax = fabs(v); - max = v; - } - } - - const float d = max / -16; - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK5_0].d = d; - - uint32_t qh = 0; - for (int j = 0; j < QK5_0/2; ++j) { - const float x0 = src[0 + j]*id; - const float x1 = src[QK5_0/2 + j]*id; - - const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); - const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); - - dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); - qh |= ((xi0 & 0x10u) >> 4) << (j + 0); - qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); - } - thread const uint8_t * qh8 = (thread const uint8_t *)&qh; - for (int j = 0; j < 4; ++j) { - dst_data[i00/QK5_0].qh[j] = qh8[j]; - } - } -} - -kernel void kernel_cpy_f32_q5_1( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1; - - device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float max = src[0]; - float min = src[0]; - - for (int j = 1; j < QK5_1; j++) { - const float v = src[j]; - min = v < min ? v : min; - max = v > max ? v : max; - } - - const float d = (max - min) / 31; - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK5_1].d = d; - dst_data[i00/QK5_1].m = min; - - uint32_t qh = 0; - for (int j = 0; j < QK5_1/2; ++j) { - const float x0 = (src[0 + j] - min)*id; - const float x1 = (src[QK5_1/2 + j] - min)*id; - - const uint8_t xi0 = (uint8_t)(x0 + 0.5f); - const uint8_t xi1 = (uint8_t)(x1 + 0.5f); - - dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); - qh |= ((xi0 & 0x10u) >> 4) << (j + 0); - qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); - } - thread const uint8_t * qh8 = (thread const uint8_t *)&qh; - for (int j = 0; j < 4; ++j) { - dst_data[i00/QK5_1].qh[j] = qh8[j]; - } - } -} - -static inline int best_index_int8(int n, constant float * val, float x) { - if (x <= val[0]) return 0; - if (x >= val[n-1]) return n-1; - int ml = 0, mu = n-1; - while (mu-ml > 1) { - int mav = (ml+mu)/2; - if (x < val[mav]) mu = mav; else ml = mav; - } - return x - val[mu-1] < val[mu] - x ? mu-1 : mu; -} - -constexpr constant static float kvalues_iq4nl_f[16] = { - -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f -}; - -kernel void kernel_cpy_f32_iq4_nl( - device const float * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL; - - device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < QK4_0; j++) { - const float v = src[j]; - if (amax < fabs(v)) { - amax = fabs(v); - max = v; - } - } - - const float d = max / kvalues_iq4nl_f[0]; - const float id = d ? 1.0f/d : 0.0f; - - float sumqx = 0, sumq2 = 0; - for (int j = 0; j < QK4_NL/2; ++j) { - const float x0 = src[0 + j]*id; - const float x1 = src[QK4_NL/2 + j]*id; - - const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0); - const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1); - - dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4); - - const float v0 = kvalues_iq4nl_f[xi0]; - const float v1 = kvalues_iq4nl_f[xi1]; - const float w0 = src[0 + j]*src[0 + j]; - const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j]; - sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j]; - sumq2 += w0*v0*v0 + w1*v1*v1; - - } - - dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d; - - } -} - -kernel void kernel_concat( - device const char * src0, - device const char * src1, - device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int32_t & dim, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; - - int64_t o[4] = {0, 0, 0, 0}; - o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); - - device const float * x; - - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); - } else { - x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); - } - - device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - *y = *x; - } -} - -void kernel_mul_mv_q2_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_q2_K) * nb; - - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int iq = it/4; // 0 or 1 - const int ir = it%4; // 0...3 - const int is = (8*ir)/16;// 0 or 1 - - device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; - - for (int ib = ix; ib < nb; ib += 4) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; - yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; - yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; - } - - device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; - device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); - acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); - acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); - acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); - acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); - acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); - acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); - acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); - } - float dall = dh[0]; - float dmin = dh[1] * 1.f/16.f; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + - (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f + - (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f + - (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - - dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); - - qs += step/2; - sc += step; - dh += step/2; - } - - y4 += 4 * QK_K; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} - -[[host_name("kernel_mul_mv_q2_K_f32")]] -kernel void kernel_mul_mv_q2_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_q3_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - - //const uint16_t kmask1 = 0x3030; - //const uint16_t kmask2 = 0x0f0f; - - const int tid = tiisg/4; - const int ix = tiisg%4; - const int ip = tid/4; // 0 or 1 - const int il = 2*((tid%4)/2); // 0 or 2 - const int ir = tid%2; - const int n = 8; - const int l0 = n*ir; - - // One would think that the Metal compiler would figure out that ip and il can only have - // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it - // with these two tales. - // - // Possible masks for the high bit - const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0 - {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2 - {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0 - {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2 - - // Possible masks for the low 2 bits - const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}}; - - const ushort4 hm = mm[2*ip + il/2]; - - const int shift = 2*il; - const float v1 = il == 0 ? 4.f : 64.f; - const float v2 = 4.f * v1; - - const uint16_t s_shift1 = 4*ip; - const uint16_t s_shift2 = s_shift1 + il; - - const int q_offset = 32*ip + l0; - const int y_offset = 128*ip + 32*il + l0; - - const int step = sizeof(block_q3_K) * nb / 2; - - device const float * y1 = yy + ix*QK_K + y_offset; - - uint32_t scales32, aux32; - thread uint16_t * scales16 = (thread uint16_t *)&scales32; - thread const int8_t * scales = (thread const int8_t *)&scales32; - - float sumf1[2] = {0.f}; - float sumf2[2] = {0.f}; - for (int i = ix; i < nb; i += 4) { - - for (int l = 0; l < 8; ++l) { - yl[l+ 0] = y1[l+ 0]; - yl[l+ 8] = y1[l+16]; - yl[l+16] = y1[l+32]; - yl[l+24] = y1[l+48]; - } - - device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); - device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0); - device const uint16_t * a = (device const uint16_t *)(x[i].scales); - device const half * dh = &x[i].d; - - for (int row = 0; row < 2; ++row) { - - const float d_all = (float)dh[0]; - - scales16[0] = a[4]; - scales16[1] = a[5]; - aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030; - scales16[0] = a[il+0]; - scales16[1] = a[il+1]; - scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; - - float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; - for (int l = 0; l < n; l += 2) { - const int32_t qs = q[l/2]; - s1 += yl[l+0] * (qs & qm[il/2][0]); - s2 += yl[l+1] * (qs & qm[il/2][1]); - s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]); - s4 += yl[l+16] * (qs & qm[il/2][2]); - s5 += yl[l+17] * (qs & qm[il/2][3]); - s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]); - } - float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); - float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); - sumf1[row] += d1 * (scales[0] - 32); - sumf2[row] += d2 * (scales[2] - 32); - - s1 = s2 = s3 = s4 = s5 = s6 = 0; - for (int l = 0; l < n; l += 2) { - const int32_t qs = q[l/2+8]; - s1 += yl[l+8] * (qs & qm[il/2][0]); - s2 += yl[l+9] * (qs & qm[il/2][1]); - s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]); - s4 += yl[l+24] * (qs & qm[il/2][2]); - s5 += yl[l+25] * (qs & qm[il/2][3]); - s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]); - } - d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); - d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); - sumf1[row] += d1 * (scales[1] - 32); - sumf2[row] += d2 * (scales[3] - 32); - - q += step; - h += step; - a += step; - dh += step; - - } - - y1 += 4 * QK_K; - - } - - for (int row = 0; row < 2; ++row) { - const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); - sumf1[row] = simd_sum(sumf); - } - if (tiisg == 0) { - for (int row = 0; row < 2; ++row) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row]; - } - } -} - -[[host_name("kernel_mul_mv_q3_K_f32")]] -kernel void kernel_mul_mv_q3_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_q4_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int iq = it/4; // 0 or 1 - const int ir = it%4; // 0...3 - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int first_row = r0 * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[16]; - float yh[16]; - float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_q4_K) * nb / 2; - - device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; - - uint16_t sc16[4]; - thread const uint8_t * sc8 = (thread const uint8_t *)sc16; - - for (int ib = ix; ib < nb; ib += 4) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; - yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; - yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; - yh[i+8] = y4[i+160]; sumy[3] += yh[i+8]; - } - - device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq; - device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; - device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - sc16[0] = sc[0] & kmask1; - sc16[1] = sc[2] & kmask1; - sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); - sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); - - device const uint16_t * q2 = q1 + 32; - - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+0] * (q1[i/2] & 0x000F); - acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00); - acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0); - acc1[3] += yl[i+9] * (q1[i/2] & 0xF000); - acc2[0] += yh[i+0] * (q2[i/2] & 0x000F); - acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00); - acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0); - acc2[3] += yh[i+9] * (q2[i/2] & 0xF000); - } - - float dall = dh[0]; - float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + - (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + - (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + - (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - - q1 += step; - sc += step; - dh += step; - } - - y4 += 4 * QK_K; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} - -[[host_name("kernel_mul_mv_q4_K_f32")]] -kernel void kernel_mul_mv_q4_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_q5_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float sumf[2]={0.f}; - - const int step = sizeof(block_q5_K) * nb; - - float yl[16], yh[16]; - - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int tid = tiisg/4; - const int ix = tiisg%4; - const int iq = tid/4; - const int ir = tid%4; - const int n = 8; - - const int l0 = n*ir; - const int q_offset = 32*iq + l0; - const int y_offset = 64*iq + l0; - - const uint8_t hm1 = 1u << (2*iq); - const uint8_t hm2 = hm1 << 1; - const uint8_t hm3 = hm1 << 4; - const uint8_t hm4 = hm2 << 4; - - uint16_t sc16[4]; - thread const uint8_t * sc8 = (thread const uint8_t *)sc16; - - device const float * y1 = yy + ix*QK_K + y_offset; - - for (int i = ix; i < nb; i += 4) { - - device const uint8_t * q1 = x[i].qs + q_offset; - device const uint8_t * qh = x[i].qh + l0; - device const half * dh = &x[i].d; - device const uint16_t * a = (device const uint16_t *)x[i].scales + iq; - - device const float * y2 = y1 + 128; - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < 8; ++l) { - yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; - yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; - yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; - yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; - } - - for (int row = 0; row < 2; ++row) { - - device const uint8_t * q2 = q1 + 64; - - sc16[0] = a[0] & kmask1; - sc16[1] = a[2] & kmask1; - sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2); - sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2); - - float4 acc1 = {0.f}; - float4 acc2 = {0.f}; - for (int l = 0; l < n; ++l) { - uint8_t h = qh[l]; - acc1[0] += yl[l+0] * (q1[l] & 0x0F); - acc1[1] += yl[l+8] * (q1[l] & 0xF0); - acc1[2] += yh[l+0] * (q2[l] & 0x0F); - acc1[3] += yh[l+8] * (q2[l] & 0xF0); - acc2[0] += h & hm1 ? yl[l+0] : 0.f; - acc2[1] += h & hm2 ? yl[l+8] : 0.f; - acc2[2] += h & hm3 ? yh[l+0] : 0.f; - acc2[3] += h & hm4 ? yh[l+8] : 0.f; - } - const float dall = dh[0]; - const float dmin = dh[1]; - sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + - sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + - sc8[4] * (acc1[2] + 16.f*acc2[2]) + - sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - - q1 += step; - qh += step; - dh += step/2; - a += step/2; - - } - - y1 += 4 * QK_K; - - } - - for (int row = 0; row < 2; ++row) { - const float tot = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; - } - } -} - -[[host_name("kernel_mul_mv_q5_K_f32")]] -kernel void kernel_mul_mv_q5_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_q6_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const uint8_t kmask1 = 0x03; - const uint8_t kmask2 = 0x0C; - const uint8_t kmask3 = 0x30; - const uint8_t kmask4 = 0xC0; - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int im = tgpig.z; - - const int row = 2 * r0 + sgitg; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float sumf = 0; - - const int tid = tiisg/2; - const int ix = tiisg%2; - const int ip = tid/8; // 0 or 1 - const int il = tid%8; - const int n = 4; - const int l0 = n*il; - const int is = 8*ip + l0/16; - - const int y_offset = 128*ip + l0; - const int q_offset_l = 64*ip + l0; - const int q_offset_h = 32*ip + l0; - - for (int i = ix; i < nb; i += 2) { - - device const uint8_t * q1 = x[i].ql + q_offset_l; - device const uint8_t * q2 = q1 + 32; - device const uint8_t * qh = x[i].qh + q_offset_h; - device const int8_t * sc = x[i].scales + is; - - device const float * y = yy + i * QK_K + y_offset; - - const float dall = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < n; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); - sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); - } - - sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); - - } - - const float tot = simd_sum(sumf); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + row] = tot; - } -} - -[[host_name("kernel_mul_mv_q6_K_f32")]] -kernel void kernel_mul_mv_q6_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -// ======================= "True" 2-bit - -void kernel_mul_mv_iq2_xxs_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); - { - int nval = 4; - int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i]; - nval = 2; - pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { - yl[i] = y4[i]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq2_xxs * xr = x + ibl; - device const uint16_t * q2 = xr->qs + 4 * ib; - device const half * dh = &xr->d; - - for (int row = 0; row < N_DST; row++) { - - const float db = dh[0]; - device const uint8_t * aux8 = (device const uint8_t *)q2; - const uint32_t aux32 = q2[2] | (q2[3] << 16); - const float d = db * (0.5f + (aux32 >> 28)); - - float sum = 0; - for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]); - const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; - for (int j = 0; j < 8; ++j) { - sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); - } - } - sumf[row] += d * sum; - - dh += nb*sizeof(block_iq2_xxs)/2; - q2 += nb*sizeof(block_iq2_xxs)/2; - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; - } - } -} - -[[host_name("kernel_mul_mv_iq2_xxs_f32")]] -kernel void kernel_mul_mv_iq2_xxs_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_iq2_xs_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512); - { - int nval = 8; - int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i]; - nval = 2; - pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { - yl[i] = y4[i]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq2_xs * xr = x + ibl; - device const uint16_t * q2 = xr->qs + 4 * ib; - device const uint8_t * sc = xr->scales + ib; - device const half * dh = &xr->d; - - for (int row = 0; row < N_DST; row++) { - - const float db = dh[0]; - const uint8_t ls1 = sc[0] & 0xf; - const uint8_t ls2 = sc[0] >> 4; - const float d1 = db * (0.5f + ls1); - const float d2 = db * (0.5f + ls2); - - float sum1 = 0, sum2 = 0; - for (int l = 0; l < 2; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); - const uint8_t signs = shared_signs[(q2[l] >> 9)]; - for (int j = 0; j < 8; ++j) { - sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); - } - } - for (int l = 2; l < 4; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); - const uint8_t signs = shared_signs[(q2[l] >> 9)]; - for (int j = 0; j < 8; ++j) { - sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); - } - } - sumf[row] += d1 * sum1 + d2 * sum2; - - dh += nb*sizeof(block_iq2_xs)/2; - q2 += nb*sizeof(block_iq2_xs)/2; - sc += nb*sizeof(block_iq2_xs); - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; - } - } -} - -[[host_name("kernel_mul_mv_iq2_xs_f32")]] -kernel void kernel_mul_mv_iq2_xs_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_iq3_xxs_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); - { - int nval = 4; - int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i]; - nval = 2; - pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { - yl[i] = y4[i]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq3_xxs * xr = x + ibl; - device const uint8_t * q3 = xr->qs + 8 * ib; - device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib; - device const half * dh = &xr->d; - - for (int row = 0; row < N_DST; row++) { - - const float db = dh[0]; - const uint32_t aux32 = gas[0] | (gas[1] << 16); - const float d = db * (0.5f + (aux32 >> 28)); - - float2 sum = {0}; - for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]); - const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]); - const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; - for (int j = 0; j < 4; ++j) { - sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); - sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); - } - } - sumf[row] += d * (sum[0] + sum[1]); - - dh += nb*sizeof(block_iq3_xxs)/2; - q3 += nb*sizeof(block_iq3_xxs); - gas += nb*sizeof(block_iq3_xxs)/2; - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f; - } - } -} - -[[host_name("kernel_mul_mv_iq3_xxs_f32")]] -kernel void kernel_mul_mv_iq3_xxs_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_iq3_s_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; - { - int nval = 8; - int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i]; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { - yl[i] = y4[i]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq3_s * xr = x + ibl; - device const uint8_t * qs = xr->qs + 8 * ib; - device const uint8_t * qh = xr->qh + ib; - device const uint8_t * sc = xr->scales + (ib/2); - device const uint8_t * signs = xr->signs + 4 * ib; - device const half * dh = &xr->d; - - for (int row = 0; row < N_DST; row++) { - - const float db = dh[0]; - const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); - - float2 sum = {0}; - for (int l = 0; l < 4; ++l) { - const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values; - const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values; - const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); - const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); - for (int j = 0; j < 4; ++j) { - sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); - sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); - } - } - sumf[row] += d * (sum[0] + sum[1]); - - dh += nb*sizeof(block_iq3_s)/2; - qs += nb*sizeof(block_iq3_s); - qh += nb*sizeof(block_iq3_s); - sc += nb*sizeof(block_iq3_s); - signs += nb*sizeof(block_iq3_s); - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} - -[[host_name("kernel_mul_mv_iq3_s_f32")]] -kernel void kernel_mul_mv_iq3_s_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_iq2_s_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; - //{ - // int nval = 32; - // int pos = (32*sgitg + tiisg)*nval; - // for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i]; - // threadgroup_barrier(mem_flags::mem_threadgroup); - //} - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { - yl[i] = y4[i]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq2_s * xr = x + ibl; - device const uint8_t * qs = xr->qs + 4 * ib; - device const uint8_t * qh = xr->qh + ib; - device const uint8_t * sc = xr->scales + ib; - device const uint8_t * signs = qs + QK_K/8; - device const half * dh = &xr->d; - - for (int row = 0; row < N_DST; row++) { - - const float db = dh[0]; - const float d1 = db * (0.5f + (sc[0] & 0xf)); - const float d2 = db * (0.5f + (sc[0] >> 4)); - - float2 sum = {0}; - for (int l = 0; l < 2; ++l) { - //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); - //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); - constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); - constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); - for (int j = 0; j < 8; ++j) { - sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]); - sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]); - } - } - sumf[row] += d1 * sum[0] + d2 * sum[1]; - - dh += nb*sizeof(block_iq2_s)/2; - qs += nb*sizeof(block_iq2_s); - qh += nb*sizeof(block_iq2_s); - sc += nb*sizeof(block_iq2_s); - signs += nb*sizeof(block_iq2_s); - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; - } - } -} - -[[host_name("kernel_mul_mv_iq2_s_f32")]] -kernel void kernel_mul_mv_iq2_s_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -} - -void kernel_mul_mv_iq1_s_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_value, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - float sumy = 0; - for (int i = 0; i < 32; ++i) { - yl[i] = y4[i]; - sumy += yl[i]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq1_s * xr = x + ibl; - device const uint8_t * qs = xr->qs + 4 * ib; - device const uint16_t * qh = xr->qh + ib; - device const half * dh = &xr->d; - - for (int row = 0; row < N_DST; row++) { - - constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); - constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700))); - constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700))); - constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700))); - - float sum = 0; - for (int j = 0; j < 4; ++j) { - sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) - + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4) - + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) - + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); - } - sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1); - - dh += nb*sizeof(block_iq1_s)/2; - qs += nb*sizeof(block_iq1_s); - qh += nb*sizeof(block_iq1_s)/2; - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} - -void kernel_mul_mv_iq1_m_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_value, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - - const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[32]; - float sumf[N_DST]={0.f}, all_sum; - - const int nb32 = nb * (QK_K / 32); - - const int ix = tiisg; - - device const float * y4 = y + 32 * ix; - - iq1m_scale_t scale; - - for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - float4 sumy = {0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8]; - yl[i+16] = y4[i+16]; sumy[2] += yl[i+16]; - yl[i+24] = y4[i+24]; sumy[3] += yl[i+24]; - } - - const int ibl = ib32 / (QK_K / 32); - const int ib = ib32 % (QK_K / 32); - - device const block_iq1_m * xr = x + ibl; - device const uint8_t * qs = xr->qs + 4 * ib; - device const uint8_t * qh = xr->qh + 2 * ib; - device const uint16_t * sc = (device const uint16_t *)xr->scales; - - for (int row = 0; row < N_DST; row++) { - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); - - constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); - constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); - constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700))); - constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700))); - - float2 sum = {0.f}; - for (int j = 0; j < 4; ++j) { - sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) - + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4); - sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) - + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); - } - const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); - const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); - - sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) + - (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1)); - - sc += nb*sizeof(block_iq1_m)/2; - qs += nb*sizeof(block_iq1_m); - qh += nb*sizeof(block_iq1_m); - } - - y4 += 32 * 32; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} - -void kernel_mul_mv_iq4_nl_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values_i8, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - threadgroup float * shared_values = (threadgroup float *)shared_values_i8; - const int nb = ne00/QK4_NL; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - const int first_row = (r0 * 2 + sgitg) * 2; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - const int ix = tiisg/2; // 0...15 - const int it = tiisg%2; // 0 or 1 - - shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; - threadgroup_barrier(mem_flags::mem_threadgroup); - - float4 yl[4]; - float sumf[2]={0.f}, all_sum; - - device const float * yb = y + ix * QK4_NL + it * 8; - - uint32_t aux32[2]; - thread const uint8_t * q8 = (thread const uint8_t *)aux32; - - float4 qf1, qf2; - - for (int ib = ix; ib < nb; ib += 16) { - - device const float4 * y4 = (device const float4 *)yb; - yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - - for (int row = 0; row < 2 && first_row + row < ne01; ++row) { - - device const block_iq4_nl & xb = x[row*nb + ib]; - device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); - - float4 acc1 = {0.f}, acc2 = {0.f}; - - aux32[0] = q4[0] | (q4[1] << 16); - aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; - aux32[0] &= 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; - acc1 += yl[0] * qf1; - acc2 += yl[1] * qf2; - - aux32[0] = q4[2] | (q4[3] << 16); - aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; - aux32[0] &= 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; - acc1 += yl[2] * qf1; - acc2 += yl[3] * qf2; - - acc1 += acc2; - - sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); - - } - - yb += 16 * QK4_NL; - } - - for (int row = 0; row < 2 && first_row + row < ne01; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} - -void kernel_mul_mv_iq4_xs_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values_i8, - uint3 tgpig, - uint tiisg, - uint sgitg) { - - threadgroup float * shared_values = (threadgroup float *)shared_values_i8; - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - const int first_row = (r0 * 2 + sgitg) * 2; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - const int ix = tiisg/16; // 0 or 1 - const int it = tiisg%16; // 0...15 - const int ib = it/2; - const int il = it%2; - - shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; - threadgroup_barrier(mem_flags::mem_threadgroup); - - float4 yl[4]; - float sumf[2]={0.f}, all_sum; - - device const float * yb = y + ix * QK_K + ib * 32 + il * 8; - - uint32_t aux32[2]; - thread const uint8_t * q8 = (thread const uint8_t *)aux32; - - float4 qf1, qf2; - - for (int ibl = ix; ibl < nb; ibl += 2) { - - device const float4 * y4 = (device const float4 *)yb; - yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - - for (int row = 0; row < 2; ++row) { - - device const block_iq4_xs & xb = x[row*nb + ibl]; - device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); - - float4 acc1 = {0.f}, acc2 = {0.f}; - - aux32[0] = q4[0] & 0x0f0f0f0f; - aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; - acc1 += yl[0] * qf1; - acc2 += yl[1] * qf2; - - aux32[0] = q4[1] & 0x0f0f0f0f; - aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; - acc1 += yl[2] * qf1; - acc2 += yl[3] * qf2; - - acc1 += acc2; - - const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32; - sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); - - } - - yb += 2 * QK_K; - } - - for (int row = 0; row < 2; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} - -[[host_name("kernel_mul_mv_iq1_s_f32")]] -kernel void kernel_mul_mv_iq1_s_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -[[host_name("kernel_mul_mv_iq1_m_f32")]] -kernel void kernel_mul_mv_iq1_m_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); -} - -[[host_name("kernel_mul_mv_iq4_nl_f32")]] -kernel void kernel_mul_mv_iq4_nl_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -} - -[[host_name("kernel_mul_mv_iq4_xs_f32")]] -kernel void kernel_mul_mv_iq4_xs_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -} - -//============================= templates and their specializations ============================= - -// NOTE: this is not dequantizing - we are simply fitting the template -template -void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { - float4x4 temp = *(((device float4x4 *)src)); - for (int i = 0; i < 16; i++){ - reg[i/4][i%4] = temp[i/4][i%4]; - } -} - -template -void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { - half4x4 temp = *(((device half4x4 *)src)); - for (int i = 0; i < 16; i++){ - reg[i/4][i%4] = temp[i/4][i%4]; - } -} - -template -void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 1); - const float d1 = il ? (xb->d / 16.h) : xb->d; - const float d2 = d1 / 256.f; - const float md = -8.h * xb->d; - const ushort mask0 = il ? 0x00F0 : 0x000F; - const ushort mask1 = mask0 << 8; - - for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; - reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; - } -} - -template -void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 2); - const float d1 = il ? (xb->d / 16.h) : xb->d; - const float d2 = d1 / 256.f; - const float m = xb->m; - const ushort mask0 = il ? 0x00F0 : 0x000F; - const ushort mask1 = mask0 << 8; - - for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m; - reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m; - } -} - -template -void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 3); - const float d = xb->d; - const float md = -16.h * xb->d; - const ushort mask = il ? 0x00F0 : 0x000F; - - const uint32_t qh = *((device const uint32_t *)xb->qh); - - const int x_mv = il ? 4 : 0; - - const int gh_mv = il ? 12 : 0; - const int gh_bk = il ? 0 : 4; - - for (int i = 0; i < 8; i++) { - // extract the 5-th bits for x0 and x1 - const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; - const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; - - // combine the 4-bits from qs with the 5th bit - const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); - const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); - - reg[i/2][2*(i%2)+0] = d * x0 + md; - reg[i/2][2*(i%2)+1] = d * x1 + md; - } -} - -template -void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 4); - const float d = xb->d; - const float m = xb->m; - const ushort mask = il ? 0x00F0 : 0x000F; - - const uint32_t qh = *((device const uint32_t *)xb->qh); - - const int x_mv = il ? 4 : 0; - - const int gh_mv = il ? 12 : 0; - const int gh_bk = il ? 0 : 4; - - for (int i = 0; i < 8; i++) { - // extract the 5-th bits for x0 and x1 - const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; - const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; - - // combine the 4-bits from qs with the 5th bit - const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); - const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); - - reg[i/2][2*(i%2)+0] = d * x0 + m; - reg[i/2][2*(i%2)+1] = d * x1 + m; - } -} - -template -void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { - device const int8_t * qs = ((device const int8_t *)xb->qs); - const half d = xb->d; - - for (int i = 0; i < 16; i++) { - reg[i/4][i%4] = (qs[i + 16*il] * d); - } -} - -template -void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { - const float d = xb->d; - const float min = xb->dmin; - device const uint8_t * q = (device const uint8_t *)xb->qs; - float dl, ml; - uint8_t sc = xb->scales[il]; - - q = q + 32*(il/8) + 16*(il&1); - il = (il/2)%4; - - half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4); - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - ml; - } -} - -template -void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { - const half d_all = xb->d; - device const uint8_t * q = (device const uint8_t *)xb->qs; - device const uint8_t * h = (device const uint8_t *)xb->hmask; - device const int8_t * scales = (device const int8_t *)xb->scales; - - q = q + 32 * (il/8) + 16 * (il&1); - h = h + 16 * (il&1); - uint8_t m = 1 << (il/2); - uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \ - ((il/4)>0 ? 12 : 3); - uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; - uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; - int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) - : (scale_2&kmask2) | ((scale_1&kmask1) << 4); - float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); - const float ml = 4.f * dl; - - il = (il/2) & 3; - const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - dl *= coef; - - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); - } -} - -static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) { - return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)} - : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))}; -} - -template -void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) { - device const uchar * q = xb->qs; - - short is = (il/4) * 2; - q = q + (il/4) * 32 + 16 * (il&1); - il = il & 3; - const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); - const float d = il < 2 ? xb->d : xb->d / 16.h; - const float min = xb->dmin; - const float dl = d * sc[0]; - const float ml = min * sc[1]; - - const ushort mask = il<2 ? 0x0F : 0xF0; - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * (q[i] & mask) - ml; - } -} - -template -void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) { - device const uint8_t * q = xb->qs; - device const uint8_t * qh = xb->qh; - - short is = (il/4) * 2; - q = q + 32 * (il/4) + 16 * (il&1); - qh = qh + 16 * (il&1); - uint8_t ul = 1 << (il/2); - il = il & 3; - const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); - const float d = il < 2 ? xb->d : xb->d / 16.f; - const float min = xb->dmin; - const float dl = d * sc[0]; - const float ml = min * sc[1]; - - const ushort mask = il<2 ? 0x0F : 0xF0; - const float qh_val = il<2 ? 16.f : 256.f; - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; - } -} - -template -void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { - const half d_all = xb->d; - device const uint8_t * ql = (device const uint8_t *)xb->ql; - device const uint8_t * qh = (device const uint8_t *)xb->qh; - device const int8_t * scales = (device const int8_t *)xb->scales; - - ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); - qh = qh + 32*(il/8) + 16*(il&1); - float sc = scales[(il%2) + 2 * ((il/2))]; - il = (il/2) & 3; - - const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; - const float coef = il>1 ? 1.f/16.f : 1.f; - const float ml = d_all * sc * 32.f; - const float dl = d_all * sc * coef; - for (int i = 0; i < 16; ++i) { - const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2)) - : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4)); - reg[i/4][i%4] = dl * q - ml; - } -} - -template -void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const float d = xb->d; - const int ib32 = il/2; - il = il%2; - // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 - // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's. - device const uint16_t * q2 = xb->qs + 4*ib32; - const uint32_t aux32_g = q2[0] | (q2[1] << 16); - const uint32_t aux32_s = q2[2] | (q2[3] << 16); - thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g; - const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f; - constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]); - uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127]; - for (int i = 0; i < 8; ++i) { - reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); - } - grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]); - signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127]; - for (int i = 0; i < 8; ++i) { - reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); - } -} - -template -void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const float d = xb->d; - const int ib32 = il/2; - il = il%2; - // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 - device const uint16_t * q2 = xb->qs + 4*ib32; - const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; - constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511)); - uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9]; - for (int i = 0; i < 8; ++i) { - reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); - } - grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511)); - signs = ksigns_iq2xs[q2[2*il+1] >> 9]; - for (int i = 0; i < 8; ++i) { - reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); - } -} - -template -void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const float d = xb->d; - const int ib32 = il/2; - il = il%2; - // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 - device const uint8_t * q3 = xb->qs + 8*ib32; - device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32; - const uint32_t aux32 = gas[0] | (gas[1] << 16); - const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f; - constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]); - constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]); - uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127]; - for (int i = 0; i < 4; ++i) { - reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); - reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); - } - grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]); - grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]); - signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127]; - for (int i = 0; i < 4; ++i) { - reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); - reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); - } -} - -template -void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const float d = xb->d; - const int ib32 = il/2; - il = il%2; - // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 - device const uint8_t * qs = xb->qs + 8*ib32; - device const uint8_t * signs = xb->signs + 4*ib32 + 2*il; - const uint8_t qh = xb->qh[ib32] >> 4*il; - const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); - constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256))); - constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256))); - for (int i = 0; i < 4; ++i) { - reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); - reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); - } - grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256))); - grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256))); - for (int i = 0; i < 4; ++i) { - reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); - reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); - } -} - -template -void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const float d = xb->d; - const int ib32 = il/2; - il = il%2; - // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 - device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; - device const uint8_t * signs = qs + QK_K/8; - const uint8_t qh = xb->qh[ib32] >> 4*il; - const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; - constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300))); - constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300))); - for (int i = 0; i < 8; ++i) { - reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]); - reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]); - } -} - -template -void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const int ib32 = il/2; - il = il%2; - const float d = xb->d; - device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; - device const uint16_t * qh = xb->qh; - const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1); - const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA); - const uint16_t h = qh[ib32] >> 6*il; - constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700))); - constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700))); - for (int i = 0; i < 4; ++i) { - reg[0][i] = dl * (grid1[i] & 0xf) + ml; - reg[1][i] = dl * (grid1[i] >> 4) + ml; - reg[2][i] = dl * (grid2[i] & 0xf) + ml; - reg[3][i] = dl * (grid2[i] >> 4) + ml; - } -} - -template -void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const int ib32 = il/2; - il = il%2; - device const uint16_t * sc = (device const uint16_t *)xb->scales; - - iq1m_scale_t scale; - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); - const float d = scale.f16; - - device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; - device const uint8_t * qh = xb->qh + 2*ib32 + il; - - const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1); - const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); - const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); - constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); - constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); - for (int i = 0; i < 4; ++i) { - reg[0][i] = dl * (grid1[i] & 0xf) + ml1; - reg[1][i] = dl * (grid1[i] >> 4) + ml1; - reg[2][i] = dl * (grid2[i] & 0xf) + ml2; - reg[3][i] = dl * (grid2[i] >> 4) + ml2; - } -} - -template -void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) { - device const uint16_t * q4 = (device const uint16_t *)xb->qs; - const float d = xb->d; - uint32_t aux32; - thread const uint8_t * q8 = (thread const uint8_t *)&aux32; - for (int i = 0; i < 4; ++i) { - aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f; - reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; - reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; - reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; - reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; - } -} - -template -void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) { - // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 - const int ib32 = il/2; - il = il%2; - // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 - device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32; - const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4); - const float d = (float)xb->d * (ls - 32); - uint32_t aux32; - thread const uint8_t * q8 = (thread const uint8_t *)&aux32; - for (int i = 0; i < 4; ++i) { - aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f; - reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; - reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; - reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; - reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; - } -} - -template -kernel void kernel_get_rows_q( - device const void * src0, - device const void * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; - - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; - - const int64_t i02 = i11; - - for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) { - float4x4 temp; - dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp); - *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; - } -} - -template -kernel void kernel_get_rows_f( - device const void * src0, - device const void * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; - - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < ne00; ind += tptg.x) { - (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind]; - } -} - -kernel void kernel_get_rows_i32( - device const void * src0, - device const void * src1, - device int32_t * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; - - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < ne00; ind += tptg.x) { - (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind]; - } -} - - -#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A -#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B -#define BLOCK_SIZE_K 32 -#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A -#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B -#define THREAD_PER_BLOCK 128 -#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers -#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers -#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8 -#define SG_MAT_ROW 8 - -// each block_q contains 16*nl weights -template -kernel void kernel_mul_mm(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - threadgroup T * sa = (threadgroup T *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); - - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; - const uint im = tgpig.z; - - // if this block is of 64x32 shape or smaller - short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; - short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; - - // a thread shouldn't load data outside of the matrix - short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; - short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - - simdgroup_T8x8 ma[4]; - simdgroup_float8x8 mb[2]; - simdgroup_float8x8 c_res[8]; - for (int i = 0; i < 8; i++){ - c_res[i] = make_filled_simdgroup_matrix(0.f); - } - - short il = (tiitg % THREAD_PER_ROW); - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); - ushort offset1 = il/nl; - - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; - device const float * y = (device const float *)(src1 - + nb12 * im - + nb11 * (r1 * BLOCK_SIZE_N + thread_col) - + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); - - for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { - // load data and store to threadgroup memory - T4x4 temp_a; - dequantize_func(x, il, temp_a); - threadgroup_barrier(mem_flags::mem_threadgroup); - - #pragma unroll(16) - for (int i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ - + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ - + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; - } - - *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); - - il = (il + 2 < nl) ? il + 2 : il % 2; - x = (il < 2) ? x + (2+nl-1)/nl : x; - y += BLOCK_SIZE_K; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // load matrices from threadgroup memory and conduct outer products - threadgroup T * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); - threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); - - #pragma unroll(4) - for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { - #pragma unroll(4) - for (int i = 0; i < 4; i++) { - simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); - } - simdgroup_barrier(mem_flags::mem_none); - #pragma unroll(2) - for (int i = 0; i < 2; i++) { - simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); - } - - lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; - lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; - - #pragma unroll(8) - for (int i = 0; i < 8; i++){ - simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); - } - } - } - - if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { - device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ - + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; - for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); - } - } else { - // block is smaller than 64x32, we should avoid writing data outside of the matrix - threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ - + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; - for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; - if (sgitg == 0) { - for (int i = 0; i < n_rows; i++) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); - } - } - } - } -} - -// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids -template -void kernel_mul_mm_id_impl( - device const uchar * src0, - device const uchar * src1, - threadgroup ushort2 * rowids, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - int64_t ne1, - int64_t ne0ne1, - threadgroup uchar * shared_memory, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - threadgroup half * sa = (threadgroup half *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); - - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; - - if (r1 * BLOCK_SIZE_N >= ne1) return; - - // if this block is of 64x32 shape or smaller - short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; - short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; - - // a thread shouldn't load data outside of the matrix - short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; - short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - - simdgroup_half8x8 ma[4]; - simdgroup_float8x8 mb[2]; - simdgroup_float8x8 c_res[8]; - for (int i = 0; i < 8; i++){ - c_res[i] = make_filled_simdgroup_matrix(0.f); - } - short il = (tiitg % THREAD_PER_ROW); - - ushort offset1 = il/nl; - - threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col]; - - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1; - device const float * y = (device const float *)(src1 - + nb12 * id[1] - + nb11 * (id[0] % ne11) - + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); - - for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { - // load data and store to threadgroup memory - half4x4 temp_a; - dequantize_func(x, il, temp_a); - threadgroup_barrier(mem_flags::mem_threadgroup); - - for (int i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ - + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ - + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; - } - - *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); - - il = (il + 2 < nl) ? il + 2 : il % 2; - x = (il < 2) ? x + (2+nl-1)/nl : x; - y += BLOCK_SIZE_K; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // load matrices from threadgroup memory and conduct outer products - threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); - threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); - - for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { - for (int i = 0; i < 4; i++) { - simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); - } - simdgroup_barrier(mem_flags::mem_none); - for (int i = 0; i < 2; i++) { - simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); - } - - lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; - lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; - - for (int i = 0; i < 8; i++){ - simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); - } - } - } - - { - threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ - + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; - for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - device float * C = dst + (BLOCK_SIZE_M * r0); - if (sgitg == 0) { - for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j]; - int joff = jid[0] * ne0 + jid[1] * ne0ne1; - for (int i = 0; i < n_rows; i++) { - *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M); - } - } - } - } -} - -template -kernel void kernel_mul_mm_id( - device const uchar * src0s, - device const uchar * src1, - device float * dst, - device const uchar * ids, - constant int64_t & nei0, - constant int64_t & nei1, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int32_t i02 = tgpig.z; - tgpig.z = 0; - - device const uchar * src0 = src0s + i02*nb02; - - // row indices - threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); - - // TODO: parallelize this loop - int64_t _ne1 = 0; - for (ushort ii1 = 0; ii1 < nei1; ii1++) { - for (ushort ii0 = 0; ii0 < nei0; ii0++) { - int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; - if (id == i02) { - //if (tiitg == 0) { - rowids[_ne1] = ushort2(ii0, ii1); - //} - _ne1++; - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - kernel_mul_mm_id_impl( - src0, - src1, - rowids, - dst, - ne00, - ne02, - nb01, - nb02, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - _ne1, - ne0*ne1, - shared_memory, - tgpig, - tiitg, - sgitg); -} - -#define QK_NL 16 - -// -// get rows -// - -typedef decltype(kernel_get_rows_f) get_rows_f_t; - -template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; -template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; - -typedef decltype(kernel_get_rows_q) get_rows_q_t; - -template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q; -template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q; -template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q; -template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q; -template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q; -template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q; -template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q; -template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q; -template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q; -template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q; -template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q; -template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q; -template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q; -template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q; -template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q; -template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q; -template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q; -template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q; -template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q; - -// -// matrix-matrix multiplication -// - -typedef decltype(kernel_mul_mm) mat_mm_t; - -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; - -// -// indirect matrix-matrix multiplication -// - -typedef decltype(kernel_mul_mm_id) mat_mm_id_t; - -template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; - -// -// matrix-vector multiplication -// - -typedef void (kernel_mul_mv_impl_t)( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - int64_t ne10, - int64_t ne11, - int64_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, - uint tiisg); - -typedef void (kernel_mul_mv2_impl_t)( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne10, - int64_t ne12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg); - -template -void mmv_fn( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - int64_t ne10, - int64_t ne11, - int64_t ne12, - int64_t ne13, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - int64_t ne0, - int64_t ne1, - uint64_t nb1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { - impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg); -} - -template -void mmv_fn( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - int64_t ne10, - int64_t ne11, - int64_t ne12, - int64_t ne13, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - int64_t ne0, - int64_t ne1, - uint64_t nb1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { - impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); -} - -typedef decltype(mmv_fn>) mul_mv_impl_fn_t; - -template -kernel void kernel_mul_mv_id( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant int64_t & nei0, - constant int64_t & nei1, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int iid1 = tgpig.z/nei0; - const int idx = tgpig.z%nei0; - - tgpig.z = 0; - - const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx]; - - const int64_t i11 = idx % ne11; - const int64_t i12 = iid1; - - const int64_t i1 = idx; - const int64_t i2 = i12; - - device const char * src0_cur = src0s + i02*nb02; - device const char * src1_cur = src1 + i11*nb11 + i12*nb12; - device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; - - impl_fn( - /* src0 */ src0_cur, - /* src1 */ src1_cur, - /* dst */ dst_cur, - /* ne00 */ ne00, - /* ne01 */ ne01, - /* ne02 */ 1,//ne02, - /* nb00 */ nb00, - /* nb01 */ nb01, - /* nb02 */ nb02, - /* ne10 */ ne10, - /* ne11 */ 1,//ne11, - /* ne12 */ 1,//ne12, - /* ne13 */ 1,//ne13, - /* nb10 */ nb10, - /* nb11 */ nb11, - /* nb12 */ nb12, - /* ne0 */ ne0, - /* ne1 */ 1,//ne1, - /* nb1 */ nb1, - /* r2 */ 1, - /* r3 */ 1, - shared_values, - tgpig, - tiitg, - tiisg, - sgitg); -} - -typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; - -template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; diff --git a/ggml/src/ggml-metal/CMakeLists.txt b/ggml/src/ggml-metal/CMakeLists.txt new file mode 100644 index 000000000..89fcde2fa --- /dev/null +++ b/ggml/src/ggml-metal/CMakeLists.txt @@ -0,0 +1,121 @@ +find_library(FOUNDATION_LIBRARY Foundation REQUIRED) +find_library(METAL_FRAMEWORK Metal REQUIRED) +find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) + +message(STATUS "Metal framework found") + +ggml_add_backend_library(ggml-metal + ggml-metal.m + ) + +target_link_libraries(ggml-metal PRIVATE + ${FOUNDATION_LIBRARY} + ${METAL_FRAMEWORK} + ${METALKIT_FRAMEWORK} + ) + +if (GGML_METAL_NDEBUG) + add_compile_definitions(GGML_METAL_NDEBUG) +endif() + +if (GGML_METAL_USE_BF16) + add_compile_definitions(GGML_METAL_USE_BF16) +endif() + +# copy metal files to bin directory +configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) +configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) +configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY) + +if (GGML_METAL_EMBED_LIBRARY) + enable_language(ASM) + + add_compile_definitions(GGML_METAL_EMBED_LIBRARY) + + set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h") + set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal") + set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h") + + file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated") + + # merge ggml-common.h and ggml-metal.metal into a single file + set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s") + set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") + set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp") + + add_custom_command( + OUTPUT ${METALLIB_EMBED_ASM} + COMMAND echo "Embedding Metal library" + COMMAND sed -e '/__embed_ggml-common.h__/r ${METALLIB_COMMON}' -e '/__embed_ggml-common.h__/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED_TMP} + COMMAND sed -e '/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}' -e '/\#include \"ggml-metal-impl.h\"/d' < ${METALLIB_SOURCE_EMBED_TMP} > ${METALLIB_SOURCE_EMBED} + COMMAND echo ".section __DATA,__ggml_metallib" > ${METALLIB_EMBED_ASM} + COMMAND echo ".globl _ggml_metallib_start" >> ${METALLIB_EMBED_ASM} + COMMAND echo "_ggml_metallib_start:" >> ${METALLIB_EMBED_ASM} + COMMAND echo ".incbin \\\"${METALLIB_SOURCE_EMBED}\\\"" >> ${METALLIB_EMBED_ASM} + COMMAND echo ".globl _ggml_metallib_end" >> ${METALLIB_EMBED_ASM} + COMMAND echo "_ggml_metallib_end:" >> ${METALLIB_EMBED_ASM} + DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h + COMMENT "Generate assembly for embedded Metal library" + ) + + target_sources(ggml-metal PRIVATE ${METALLIB_EMBED_ASM}) +else() + if (GGML_METAL_SHADER_DEBUG) + # custom command to do the following: + # xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air + # xcrun -sdk macosx metallib ggml-metal.air -o default.metallib + # + # note: this is the only way I found to disable fast-math in Metal. it's ugly, but at least it works + # disabling fast math is needed in order to pass tests/test-backend-ops + # note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1 + # note: unfortunately, we have to call it default.metallib instead of ggml.metallib + # ref: https://github.com/ggerganov/whisper.cpp/issues/1720 + set(XC_FLAGS -fno-fast-math -fno-inline -g) + else() + set(XC_FLAGS -O3) + endif() + + # Append macOS metal versioning flags + if (GGML_METAL_MACOSX_VERSION_MIN) + message(STATUS "Adding -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN} flag to metal compilation") + list (APPEND XC_FLAGS -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN}) + endif() + + if (GGML_METAL_STD) + message(STATUS "Adding -std=${GGML_METAL_STD} flag to metal compilation") + list (APPEND XC_FLAGS -std=${GGML_METAL_STD}) + endif() + + add_custom_command( + OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib + COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air + COMMAND xcrun -sdk macosx metallib ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib + COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air + COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h + COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal + DEPENDS ggml-metal.metal ggml-common.h + COMMENT "Compiling Metal kernels" + ) + + # FIXME: only add to the ggml-metal target? + add_custom_target( + ggml-metal-lib ALL + DEPENDS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib + ) +endif() # GGML_METAL_EMBED_LIBRARY + +if (NOT GGML_METAL_EMBED_LIBRARY) + install( + FILES src/ggml-metal/ggml-metal.metal + PERMISSIONS + OWNER_READ + OWNER_WRITE + GROUP_READ + WORLD_READ + DESTINATION ${CMAKE_INSTALL_BINDIR}) + + install( + FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib + DESTINATION ${CMAKE_INSTALL_BINDIR} + ) +endif() diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h new file mode 100644 index 000000000..e3dc25f16 --- /dev/null +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -0,0 +1,288 @@ +#ifndef GGML_METAL_IMPL +#define GGML_METAL_IMPL + +// kernel argument structs +// +// - element counters (e.g. ne00) typically use int32_t to reduce register usage +// however, be careful from int overflows when using those in the kernel implementation +// +// - strides (e.g. nb00) use uint64_t + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t dim; +} ggml_metal_kargs_concat; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + uint64_t offs; +} ggml_metal_kargs_bin; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_repeat; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_cpy; + +typedef struct { + int64_t ne10; + int64_t ne11; + int64_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + uint64_t offs; + bool inplace; +} ggml_metal_kargs_set; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t n_past; + int32_t n_dims; + int32_t n_ctx_orig; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; +} ggml_metal_kargs_rope; + +typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + uint64_t nb_12_1; + uint64_t nb_12_2; + uint64_t nb_12_3; + uint64_t nb31; + int32_t ne1; + int32_t ne2; + float scale; + float max_bias; + float m0; + float m1; + uint16_t n_head_log2; + float logit_softcap; +} ggml_metal_kargs_flash_attn_ext; + +typedef struct { + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; +} ggml_metal_kargs_mul_mm; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; +} ggml_metal_kargs_mul_mv; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; + int16_t nsg; + int16_t nxpsg; + int16_t r1ptg; +} ggml_metal_kargs_mul_mv_ext; + +typedef struct { + int32_t nei0; + int32_t nei1; + uint64_t nbi1; + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int32_t ne0; + int32_t ne1; +} ggml_metal_kargs_mul_mm_id; + +typedef struct { + int32_t nei0; + int32_t nei1; + uint64_t nbi1; + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int32_t ne0; + int32_t ne1; + uint64_t nb1; +} ggml_metal_kargs_mul_mv_id; + +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_norm; + +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_rms_norm; + +#endif // GGML_METAL_IMPL diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m new file mode 100644 index 000000000..76f8e4291 --- /dev/null +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -0,0 +1,4990 @@ +#import "ggml-metal.h" + +#import "ggml-impl.h" +#import "ggml-backend-impl.h" +#import "ggml-metal-impl.h" + +#import + +#import + +#undef MIN +#undef MAX +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +// max memory buffers that can be mapped to the device +#define GGML_METAL_MAX_BUFFERS 64 + +// max number of MTLCommandBuffer used to submit a graph for processing +#define GGML_METAL_MAX_COMMAND_BUFFERS 8 + +// create residency sets only on macOS >= 15.0 +#if TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 +#define GGML_METAL_HAS_RESIDENCY_SETS 1 +#endif + +// globals + +// overload of MTLGPUFamilyMetal3 (not available in some environments) +static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; + +// initialized in ggml_backend_metal_reg +static struct ggml_backend_reg g_ggml_backend_metal_reg; +static struct ggml_backend_device g_ggml_backend_metal_device; + +// information about a Metal device +// note: assumes single GPU device - the default one +// TODO: support multiple GPU devices +static struct ggml_backend_metal_device_context { + id mtl_device; + int mtl_device_ref_count; + + bool has_simdgroup_reduction; + bool has_simdgroup_mm; + bool has_residency_sets; + bool has_bfloat; + bool use_bfloat; + + char name[128]; +} g_ggml_ctx_dev_main = { + /*.mtl_device =*/ nil, + /*.mtl_device_ref_count =*/ 0, + /*.has_simdgroup_reduction =*/ false, + /*.has_simdgroup_mm =*/ false, + /*.has_residency_sets =*/ false, + /*.has_bfloat =*/ false, + /*.use_bfloat =*/ false, + /*.name =*/ "", +}; + +// acquire +static id ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) { + assert(ctx != NULL); + + if (ctx->mtl_device == nil) { + ctx->mtl_device = MTLCreateSystemDefaultDevice(); + } + + if (ctx->mtl_device) { + ctx->has_simdgroup_reduction = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; + ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; + + ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; + +#if defined(GGML_METAL_HAS_RESIDENCY_SETS) + ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL; +#endif + + ctx->has_bfloat = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; + ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6]; + +#if defined(GGML_METAL_USE_BF16) + ctx->use_bfloat = ctx->has_bfloat; +#else + ctx->use_bfloat = false; +#endif + + strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1); + } + + ctx->mtl_device_ref_count++; + + return ctx->mtl_device; +} + +// release +static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_context * ctx) { + assert(ctx != NULL); + assert(ctx->mtl_device_ref_count > 0); + + ctx->mtl_device_ref_count--; + + if (ctx->mtl_device_ref_count == 0) { + if (ctx->mtl_device) { + [ctx->mtl_device release]; + ctx->mtl_device = nil; + } + } +} + +// kernels + +struct ggml_metal_kernel { + id pipeline; +}; + +enum ggml_metal_kernel_type { + GGML_METAL_KERNEL_TYPE_ADD, + GGML_METAL_KERNEL_TYPE_ADD_ROW, + GGML_METAL_KERNEL_TYPE_SUB, + GGML_METAL_KERNEL_TYPE_SUB_ROW, + GGML_METAL_KERNEL_TYPE_MUL, + GGML_METAL_KERNEL_TYPE_MUL_ROW, + GGML_METAL_KERNEL_TYPE_DIV, + GGML_METAL_KERNEL_TYPE_DIV_ROW, + GGML_METAL_KERNEL_TYPE_REPEAT_F32, + GGML_METAL_KERNEL_TYPE_REPEAT_F16, + GGML_METAL_KERNEL_TYPE_REPEAT_I32, + GGML_METAL_KERNEL_TYPE_REPEAT_I16, + GGML_METAL_KERNEL_TYPE_SCALE, + GGML_METAL_KERNEL_TYPE_SCALE_4, + GGML_METAL_KERNEL_TYPE_CLAMP, + GGML_METAL_KERNEL_TYPE_TANH, + GGML_METAL_KERNEL_TYPE_RELU, + GGML_METAL_KERNEL_TYPE_SIGMOID, + GGML_METAL_KERNEL_TYPE_GELU, + GGML_METAL_KERNEL_TYPE_GELU_4, + GGML_METAL_KERNEL_TYPE_GELU_QUICK, + GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, + GGML_METAL_KERNEL_TYPE_SILU, + GGML_METAL_KERNEL_TYPE_SILU_4, + GGML_METAL_KERNEL_TYPE_ELU, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, + GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, + GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, + GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, + GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, + GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, + GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, + GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, + GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, + GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, + GGML_METAL_KERNEL_TYPE_RMS_NORM, + GGML_METAL_KERNEL_TYPE_GROUP_NORM, + GGML_METAL_KERNEL_TYPE_NORM, + GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, + GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, + GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, + GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, + GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, + GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, + GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, + //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, + //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, + //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, + GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, + GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, + GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, + GGML_METAL_KERNEL_TYPE_IM2COL_F16, + GGML_METAL_KERNEL_TYPE_IM2COL_F32, + GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, + GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, + GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, + GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, + GGML_METAL_KERNEL_TYPE_UPSCALE_F32, + GGML_METAL_KERNEL_TYPE_PAD_F32, + GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, + GGML_METAL_KERNEL_TYPE_ARANGE_F32, + GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, + GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, + GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, + GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, + GGML_METAL_KERNEL_TYPE_SET_I32, + GGML_METAL_KERNEL_TYPE_SET_F32, + GGML_METAL_KERNEL_TYPE_CPY_F32_F32, + GGML_METAL_KERNEL_TYPE_CPY_F32_F16, + GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, + GGML_METAL_KERNEL_TYPE_CPY_F16_F16, + GGML_METAL_KERNEL_TYPE_CPY_F16_F32, + GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, + GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, + GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, + GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, + GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, + GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, + GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, + GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, + GGML_METAL_KERNEL_TYPE_CONCAT, + GGML_METAL_KERNEL_TYPE_SQR, + GGML_METAL_KERNEL_TYPE_SQRT, + GGML_METAL_KERNEL_TYPE_SIN, + GGML_METAL_KERNEL_TYPE_COS, + GGML_METAL_KERNEL_TYPE_SUM_ROWS, + GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, + GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, + GGML_METAL_KERNEL_TYPE_ARGMAX, + + GGML_METAL_KERNEL_TYPE_COUNT +}; + +struct ggml_backend_metal_context { + id queue; + + dispatch_queue_t d_queue; + + struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT]; + + // capture state + bool capture_next_compute; + bool capture_started; + + id capture_scope; + + // command buffer state + int n_cb; // number of extra threads used to submit the command buffers + int n_nodes_0; // number of nodes submitted by the main thread + int n_nodes_1; // remaining number of nodes submitted by the n_cb threads + int n_nodes_per_cb; + + struct ggml_cgraph * gf; + + // the callback given to the thread pool + void (^encode_async)(size_t ith); + + // n_cb command buffers + 1 used by the main thread + id command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1]; + + // abort ggml_metal_graph_compute if callback returns true + ggml_abort_callback abort_callback; + void * abort_callback_data; +}; + +// MSL code +// TODO: move the contents here when ready +// for now it is easier to work in a separate file +// static NSString * const msl_library_source = @"see metal.metal"; + +// Here to assist with NSBundle Path Hack +@interface GGMLMetalClass : NSObject +@end +@implementation GGMLMetalClass +@end + +static void * ggml_metal_host_malloc(size_t n) { + void * data = NULL; + +#if TARGET_OS_OSX + kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE); + if (err != KERN_SUCCESS) { + GGML_LOG_ERROR("%s: error: vm_allocate failed\n", __func__); + return NULL; + } +#else + const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n); + if (result != 0) { + GGML_LOG_ERROR("%s: error: posix_memalign failed\n", __func__); + return NULL; + } +#endif + + return data; +} + +static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) { + GGML_LOG_INFO("%s: allocating\n", __func__); + +#if TARGET_OS_OSX && !GGML_METAL_NDEBUG + // Show all the Metal device instances in the system + NSArray * devices = MTLCopyAllDevices(); + for (id device in devices) { + GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]); + } + [devices release]; // since it was created by a *Copy* C method +#endif + + // init context + struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context)); + struct ggml_backend_metal_device_context * ctx_dev = dev->context; + + id device = ggml_backend_metal_device_acq(ctx_dev); + GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]); + + ctx->queue = [device newCommandQueue]; + if (ctx->queue == nil) { + GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); + return NULL; + } + + ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); + + id metal_library; + + // load library + // + // - first check if the library is embedded + // - then check if the library is in the bundle + // - if not found, load the source and compile it + // - if that fails, return NULL + { + NSBundle * bundle = nil; +#ifdef SWIFT_PACKAGE + bundle = SWIFTPM_MODULE_BUNDLE; +#else + bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; +#endif + + NSError * error = nil; + +#if GGML_METAL_EMBED_LIBRARY + const bool try_metallib = false; +#else + const bool try_metallib = true; +#endif + + NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"]; + if (path_lib == nil) { + // Try to find the resource in the directory where the current binary located. + NSString * current_binary = [[NSProcessInfo processInfo] arguments][0]; + NSString * bin_dir = [current_binary stringByDeletingLastPathComponent]; + NSString * default_metallib_path = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]]; + if ([[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) { + GGML_LOG_INFO("%s: found '%s'\n", __func__, [default_metallib_path UTF8String]); + NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:default_metallib_path error:&error]; + if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) { + // Optionally, if this is a symlink, try to resolve it. + default_metallib_path = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:default_metallib_path error:&error]; + if (default_metallib_path && [default_metallib_path length] > 0 && ![[default_metallib_path substringToIndex:1] isEqualToString:@"/"]) { + // It is a relative path, adding the binary directory as directory prefix. + default_metallib_path = [NSString pathWithComponents:@[bin_dir, default_metallib_path]]; + } + if (!default_metallib_path || ![[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) { + // Link to the resource could not be resolved. + default_metallib_path = nil; + } else { + GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [default_metallib_path UTF8String]); + } + } + } else { + // The resource couldn't be found in the binary's directory. + default_metallib_path = nil; + } + path_lib = default_metallib_path; + } + + if (try_metallib && path_lib != nil) { + // pre-compiled library found + NSURL * libURL = [NSURL fileURLWithPath:path_lib]; + GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]); + + metal_library = [device newLibraryWithURL:libURL error:&error]; + if (error) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } + } else { +#if GGML_METAL_EMBED_LIBRARY + GGML_LOG_INFO("%s: using embedded metal library\n", __func__); + + extern const char ggml_metallib_start[]; + extern const char ggml_metallib_end[]; + + NSString * src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding]; +#else + GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); + + NSString * path_source; + NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"]; + + GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil"); + + if (path_resource) { + path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"]; + } else { + path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; + } + + if (path_source == nil) { + GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__); + path_source = @"ggml-metal.metal"; + } + + GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]); + + NSString * src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error]; + if (error) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } +#endif // GGML_METAL_EMBED_LIBRARY + + @autoreleasepool { + // dictionary of preprocessor macros + NSMutableDictionary * prep = [NSMutableDictionary dictionary]; + + if (ctx_dev->use_bfloat) { + [prep setObject:@"1" forKey:@"GGML_METAL_USE_BF16"]; + } + +#if GGML_METAL_EMBED_LIBRARY + [prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"]; +#endif + + MTLCompileOptions * options = [MTLCompileOptions new]; + options.preprocessorMacros = prep; + + //[options setFastMathEnabled:false]; + + metal_library = [device newLibraryWithSource:src options:options error:&error]; + if (error) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } + +#if !__has_feature(objc_arc) + [options release]; +#endif + } +#if GGML_METAL_EMBED_LIBRARY + [src release]; +#endif // GGML_METAL_EMBED_LIBRARY + } + } + + // print MTL GPU family: + GGML_LOG_INFO("%s: GPU name: %s\n", __func__, [[device name] UTF8String]); + + // determine max supported GPU family + // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf + // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf + { + for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) { + if ([device supportsFamily:i]) { + GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i); + break; + } + } + + for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) { + if ([device supportsFamily:i]) { + GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i); + break; + } + } + + for (int i = MTLGPUFamilyMetal3_GGML + 5; i >= MTLGPUFamilyMetal3_GGML; --i) { + if ([device supportsFamily:i]) { + GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3_GGML + 3, i); + break; + } + } + } + + GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false"); + GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm ? "true" : "false"); + GGML_LOG_INFO("%s: has residency sets = %s\n", __func__, ctx_dev->has_residency_sets ? "true" : "false"); + GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false"); + GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false"); + GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false"); + + ctx->capture_next_compute = false; + ctx->capture_started = false; + ctx->capture_scope = nil; + + ctx->gf = nil; + ctx->encode_async = nil; + for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { + ctx->command_buffers[i] = nil; + } + +#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) + if (@available(macOS 10.12, iOS 16.0, *)) { + GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6); + } +#endif + + // load kernels + { + NSError * error = nil; + + for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) { + ctx->kernels[i].pipeline = nil; + } + +#define GGML_METAL_ADD_KERNEL(e, name, supported) \ + if (supported) { \ + struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \ + id metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \ + kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \ + GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ + (int) kernel->pipeline.threadExecutionWidth); \ + [metal_function release]; \ + if (error) { \ + GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ + [metal_library release]; \ + return NULL; \ + } \ + } else { \ + GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \ + } + + const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm; + const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction; + const bool use_bfloat = ctx_dev->use_bfloat; + + // simd_sum and simd_max requires MTLGPUFamilyApple7 + + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M, get_rows_iq1_m, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32, mul_mv_iq1_m_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, has_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, has_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, has_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, mul_mv_id_bf16_f32, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112, flash_attn_ext_q4_0_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128, flash_attn_ext_q4_0_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112, flash_attn_ext_q4_1_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128, flash_attn_ext_q4_1_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112, flash_attn_ext_q5_0_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128, flash_attn_ext_q5_0_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112, flash_attn_ext_q5_1_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128, flash_attn_ext_q5_1_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112, flash_attn_ext_q8_0_h112, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true); + } + + [metal_library release]; + + return ctx; +} + +static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { + GGML_LOG_INFO("%s: deallocating\n", __func__); + + for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) { + [ctx->kernels[i].pipeline release]; + } + + Block_release(ctx->encode_async); + + [ctx->queue release]; + + dispatch_release(ctx->d_queue); + + free(ctx); +} + +// temporarily defined here for compatibility between ggml-backend and the old API + +struct ggml_backend_metal_buffer { + void * data; + size_t size; + + id metal; +}; + +struct ggml_backend_metal_buffer_context { + void * all_data; + size_t all_size; + bool owned; + + // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap + int n_buffers; + struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS]; + + // optional MTLResidencySet + id rset; +}; + +// rset init +static bool ggml_backend_metal_buffer_rset_init( + struct ggml_backend_metal_buffer_context * ctx, + struct ggml_backend_metal_device_context * ctx_dev, + id device) { + ctx->rset = nil; + + if (!ctx_dev->has_residency_sets) { + return true; + } + +#if defined(GGML_METAL_HAS_RESIDENCY_SETS) + if (@available(macOS 15.0, *)) { + MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init]; + desc.label = @"ggml_backend_metal"; + desc.initialCapacity = ctx->n_buffers; + + NSError * error; + ctx->rset = [device newResidencySetWithDescriptor:desc error:&error]; + if (error) { + GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + [desc release]; + return false; + } + + [desc release]; + + for (int i = 0; i < ctx->n_buffers; i++) { + [ctx->rset addAllocation:ctx->buffers[i].metal]; + } + + [ctx->rset commit]; + [ctx->rset requestResidency]; + + return true; + } +#else + GGML_UNUSED(ctx_dev); + GGML_UNUSED(device); +#endif + + return true; +} + +// rset free +static void ggml_backend_metal_buffer_rset_free(struct ggml_backend_metal_buffer_context * ctx) { +#if defined(GGML_METAL_HAS_RESIDENCY_SETS) + if (@available(macOS 15.0, *)) { + if (ctx->rset) { + [ctx->rset endResidency]; + [ctx->rset removeAllAllocations]; + [ctx->rset release]; + } + } +#else + GGML_UNUSED(ctx); +#endif +} + +// finds the Metal buffer that contains the tensor data on the GPU device +// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the +// Metal buffer based on the host memory pointer +// +static id ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs) { + //GGML_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach); + + const int64_t tsize = ggml_nbytes(t); + + ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer; + + struct ggml_backend_metal_buffer_context * buf_ctx = (struct ggml_backend_metal_buffer_context *) buffer->context; + + // find the view that contains the tensor fully + for (int i = 0; i < buf_ctx->n_buffers; ++i) { + const int64_t ioffs = (int64_t) t->data - (int64_t) buf_ctx->buffers[i].data; + + //GGML_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf_ctx->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf_ctx->buffers[i].size); + if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf_ctx->buffers[i].size) { + *offs = (size_t) ioffs; + + //GGML_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs); + + return buf_ctx->buffers[i].metal; + } + } + + GGML_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name); + + return nil; +} + +static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) { + const bool has_simdgroup_mm = ctx_dev->has_simdgroup_mm; + const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction; + const bool use_bfloat = ctx_dev->use_bfloat; + + if (!use_bfloat) { + for (size_t i = 0, n = 3; i < n; ++i) { + if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) { + return false; + } + } + } + + switch (op->op) { + case GGML_OP_UNARY: + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_ELU: + return ggml_is_contiguous(op->src[0]); + default: + return false; + } + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_TRANSPOSE: + case GGML_OP_PERMUTE: + case GGML_OP_CONCAT: + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_ACC: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_REPEAT: + case GGML_OP_SCALE: + case GGML_OP_CLAMP: + case GGML_OP_CONV_TRANSPOSE_1D: + return true; + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + return ggml_is_contiguous(op->src[0]); + case GGML_OP_SUM_ROWS: + case GGML_OP_SOFT_MAX: + case GGML_OP_GROUP_NORM: + return has_simdgroup_reduction; + case GGML_OP_RMS_NORM: + return has_simdgroup_reduction && (op->ne[0] % 4 == 0); + case GGML_OP_ARGMAX: + case GGML_OP_NORM: + return true; + case GGML_OP_ROPE: + { + const int mode = ((const int32_t *) op->op_params)[2]; + if (mode & GGML_ROPE_TYPE_MROPE) { + return false; + } + if (mode & GGML_ROPE_TYPE_VISION) { + return false; + } + return true; + } + case GGML_OP_IM2COL: + return op->src[0]->type == GGML_TYPE_F16; + case GGML_OP_POOL_1D: + return false; + case GGML_OP_POOL_2D: + case GGML_OP_UPSCALE: + case GGML_OP_PAD: + case GGML_OP_PAD_REFLECT_1D: + case GGML_OP_ARANGE: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_ARGSORT: + case GGML_OP_LEAKY_RELU: + return true; + case GGML_OP_FLASH_ATTN_EXT: + if (op->src[1]->type != op->src[2]->type) { + return false; + } + return has_simdgroup_mm; // TODO: over-restricted for vec-kernels + case GGML_OP_SSM_CONV: + case GGML_OP_SSM_SCAN: + return true; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + return has_simdgroup_reduction && + (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32); + case GGML_OP_CPY: + case GGML_OP_DUP: + case GGML_OP_CONT: + { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_IQ4_NL: + return true; + default: + return false; + } + case GGML_TYPE_F16: + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + return true; + default: + return false; + } + case GGML_TYPE_BF16: + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_BF16: + return true; + default: + return false; + } + default: + return false; + }; + } + case GGML_OP_SET: + { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_I32: + return true; + default: + return false; + }; + } + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_GET_ROWS: + { + return op->ne[3] == 1; + } + default: + return false; + } +} + +static void ggml_metal_encode_node( + ggml_backend_t backend, + int idx, + id encoder) { + struct ggml_backend_metal_context * ctx = backend->context; + struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; + + struct ggml_cgraph * gf = ctx->gf; + + struct ggml_tensor * node = ggml_graph_node(gf, idx); + + //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op)); + + struct ggml_tensor * src0 = node->src[0]; + struct ggml_tensor * src1 = node->src[1]; + struct ggml_tensor * src2 = node->src[2]; + struct ggml_tensor * dst = node; + + if (ggml_is_empty(dst)) { + return; + } + + switch (dst->op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_TRANSPOSE: + case GGML_OP_PERMUTE: + { + // noop -> next node + } return; + default: + { + } break; + } + + if (!ggml_metal_supports_op(ctx_dev, dst)) { + GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst)); + GGML_ABORT("unsupported op"); + } + + const int64_t ne00 = src0 ? src0->ne[0] : 0; + const int64_t ne01 = src0 ? src0->ne[1] : 0; + const int64_t ne02 = src0 ? src0->ne[2] : 0; + const int64_t ne03 = src0 ? src0->ne[3] : 0; + + const uint64_t nb00 = src0 ? src0->nb[0] : 0; + const uint64_t nb01 = src0 ? src0->nb[1] : 0; + const uint64_t nb02 = src0 ? src0->nb[2] : 0; + const uint64_t nb03 = src0 ? src0->nb[3] : 0; + + const int64_t ne10 = src1 ? src1->ne[0] : 0; + const int64_t ne11 = src1 ? src1->ne[1] : 0; + const int64_t ne12 = src1 ? src1->ne[2] : 0; + const int64_t ne13 = src1 ? src1->ne[3] : 0; + + const uint64_t nb10 = src1 ? src1->nb[0] : 0; + const uint64_t nb11 = src1 ? src1->nb[1] : 0; + const uint64_t nb12 = src1 ? src1->nb[2] : 0; + const uint64_t nb13 = src1 ? src1->nb[3] : 0; + + const int64_t ne20 = src2 ? src2->ne[0] : 0; + const int64_t ne21 = src2 ? src2->ne[1] : 0; + const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22); + const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); + + const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); + const uint64_t nb21 = src2 ? src2->nb[1] : 0; + const uint64_t nb22 = src2 ? src2->nb[2] : 0; + const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23); + + const int64_t ne0 = dst ? dst->ne[0] : 0; + const int64_t ne1 = dst ? dst->ne[1] : 0; + const int64_t ne2 = dst ? dst->ne[2] : 0; + const int64_t ne3 = dst ? dst->ne[3] : 0; + + const uint64_t nb0 = dst ? dst->nb[0] : 0; + const uint64_t nb1 = dst ? dst->nb[1] : 0; + const uint64_t nb2 = dst ? dst->nb[2] : 0; + const uint64_t nb3 = dst ? dst->nb[3] : 0; + + const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; + const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT; + + size_t offs_src0 = 0; + size_t offs_src1 = 0; + size_t offs_src2 = 0; + size_t offs_dst = 0; + + id id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil; + id id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil; + id id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; + id id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; + +#if 0 + GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); + if (src0) { + GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, + ggml_is_contiguous(src0), src0->name); + } + if (src1) { + GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, + ggml_is_contiguous(src1), src1->name); + } + if (dst) { + GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, + dst->name); + } +#endif + + id device = ctx_dev->mtl_device; + + switch (dst->op) { + case GGML_OP_CONCAT: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; + + const int32_t dim = ((const int32_t *) dst->op_params)[0]; + + ggml_metal_kargs_concat args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.dim =*/ dim, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + + const size_t offs = 0; + + bool bcast_row = false; + + id pipeline = nil; + + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + switch (dst->op) { + case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; + case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; + default: GGML_ABORT("fatal error"); + } + + bcast_row = true; + } else { + switch (dst->op) { + case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; + case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break; + case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; + case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; + default: GGML_ABORT("fatal error"); + } + } + + ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.offs =*/ offs, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + if (bcast_row) { + const int64_t n = ggml_nelements(dst)/4; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + } break; + case GGML_OP_REPEAT: + { + id pipeline; + + switch (src0t) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break; + case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break; + case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break; + default: GGML_ABORT("fatal error"); + } + + ggml_metal_kargs_repeat args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ACC: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + GGML_ASSERT(dstt == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + const size_t pnb1 = ((const int32_t *) dst->op_params)[0]; + const size_t pnb2 = ((const int32_t *) dst->op_params)[1]; + const size_t pnb3 = ((const int32_t *) dst->op_params)[2]; + const size_t offs = ((const int32_t *) dst->op_params)[3]; + + const bool inplace = (bool) ((const int32_t *) dst->op_params)[4]; + + if (!inplace) { + // run a separete kernel to cpy src->dst + // not sure how to avoid this + // TODO: make a simpler cpy_bytes kernel + + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; + + ggml_metal_kargs_cpy args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; + + ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ pnb1, + /*.nb02 =*/ pnb2, + /*.nb03 =*/ pnb3, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ pnb1, + /*.nb2 =*/ pnb2, + /*.nb3 =*/ pnb3, + /*.offs =*/ offs, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); + + [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_SCALE: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + float scale; + memcpy(&scale, dst->op_params, sizeof(scale)); + + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + n /= 4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_CLAMP: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline; + + float min; + float max; + memcpy(&min, ((const int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float)); + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&min length:sizeof(min) atIndex:2]; + [encoder setBytes:&max length:sizeof(max) atIndex:3]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(node)) { + // we are not taking into account the strides, so for now require contiguous tensors + GGML_ASSERT(ggml_is_contiguous(src0)); + + case GGML_UNARY_OP_TANH: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_RELU: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_SIGMOID: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_GELU: + { + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_GELU_QUICK: + { + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_SILU: + { + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_ELU: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ELU].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + default: + { + GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); + GGML_ABORT("fatal error"); + } + } break; + case GGML_OP_SQR: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SQRT: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQRT].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SIN: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_COS: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SUM_ROWS: + { + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; + + // TODO: add ggml_metal_kargs struct + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SOFT_MAX: + { + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); + + int nth = 32; // SIMD width + + id pipeline = nil; + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + + if (ne00%4 == 0) { + while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline; + } + } else { + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + if (use_f16) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline; + } + } + + float scale; + float max_bias; + + memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); + + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src0->ne[1]; + + const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + // TODO: add ggml_metal_kargs struct + // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6) + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + if (id_src1) { + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&scale length:sizeof(scale) atIndex:6]; + [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; + + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_DIAG_MASK_INF: + { + const int n_past = ((const int32_t *)(dst->op_params))[0]; + + id pipeline = nil; + + if (ne00%8 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; + } + + // TODO: add ggml_metal_kargs struct + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; + + if (ne00%8 == 0) { + [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } + else { + [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } + } break; + case GGML_OP_SSM_CONV: + { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; + + // TODO: add ggml_metal_kargs struct + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_SSM_SCAN: + { + struct ggml_tensor * src3 = node->src[3]; + struct ggml_tensor * src4 = node->src[4]; + struct ggml_tensor * src5 = node->src[5]; + + GGML_ASSERT(src3); + GGML_ASSERT(src4); + GGML_ASSERT(src5); + + size_t offs_src3 = 0; + size_t offs_src4 = 0; + size_t offs_src5 = 0; + + id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; + id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; + + const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30); + const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); + + const uint64_t nb30 = src3->nb[0]; + const uint64_t nb31 = src3->nb[1]; + + const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); + const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41); + const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); + + const uint64_t nb40 = src4->nb[0]; + const uint64_t nb41 = src4->nb[1]; + const uint64_t nb42 = src4->nb[2]; + + const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); + const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); + const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); + + const uint64_t nb50 = src5->nb[0]; + const uint64_t nb51 = src5->nb[1]; + const uint64_t nb52 = src5->nb[2]; + + const int64_t d_state = ne00; + const int64_t d_inner = ne01; + const int64_t n_seq_tokens = ne11; + const int64_t n_seqs = ne02; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + + // TODO: add ggml_metal_kargs struct + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; + [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; + + [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7]; + [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8]; + [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9]; + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10]; + + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; + [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20]; + [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; + [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; + [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; + + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_MUL_MAT: + { + GGML_ASSERT(ne00 == ne10); + + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); + + const uint32_t r2 = ne12/ne02; + const uint32_t r3 = ne13/ne03; + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + const int ne11_mm_min = 4; + + // first try to use small-batch mat-mv kernels + // these should be efficient for BS [2, ~8] + if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) && + ( + ( + ( + src0t == GGML_TYPE_F16 || // TODO: helper function + src0t == GGML_TYPE_Q4_0 || + src0t == GGML_TYPE_Q4_1 || + src0t == GGML_TYPE_Q5_0 || + src0t == GGML_TYPE_Q5_1 || + src0t == GGML_TYPE_Q8_0 || + src0t == GGML_TYPE_IQ4_NL || + false) && (ne11 >= 2 && ne11 <= 8) + ) || + ( + ( + src0t == GGML_TYPE_Q4_K || + src0t == GGML_TYPE_Q5_K || + src0t == GGML_TYPE_Q6_K || + false) && (ne11 >= 4 && ne11 <= 8) + ) + ) + ) { + // TODO: determine the optimal parameters based on grid utilization + // I still don't know why we should not always use the maximum available threads: + // + // nsg = pipeline.maxTotalThreadsPerThreadgroup / 32 + // + // my current hypothesis is that the work grid is not evenly divisible for different nsg + // values and there can be some tail effects when nsg is high. need to confirm this + // + const int nsg = 2; // num simdgroups per threadgroup + const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup + const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time) + const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup + int r1ptg = 4; // num src1 rows per threadgroup + + // note: not sure how optimal are those across all different hardware. there might be someting cleverer + switch (ne11) { + case 2: + r1ptg = 2; break; + case 3: + case 6: + r1ptg = 3; break; + case 4: + case 7: + case 8: + r1ptg = 4; break; + case 5: + r1ptg = 5; break; + }; + + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F16: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q4_0: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q4_1: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q5_0: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q5_1: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q8_0: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q4_K: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q5_K: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_Q6_K: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + case GGML_TYPE_IQ4_NL: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5].pipeline; break; + default: GGML_ABORT("not implemented"); + } break; + default: GGML_ABORT("not implemented"); + } + + ggml_metal_kargs_mul_mv_ext args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + /*.nsg =*/ nsg, + /*.nxpsg =*/ nxpsg, + /*.r1ptg =*/ r1ptg, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + //printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg); + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs + // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel + if ([device supportsFamily:MTLGPUFamilyApple7] && + !ggml_is_transposed(src0) && + !ggml_is_transposed(src1) && + src1t == GGML_TYPE_F32 && + ne00 % 32 == 0 && ne00 >= 64 && + (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { + //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + + // some Metal matrix data types require aligned pointers + // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) + switch (src0->type) { + case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; + case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; + case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break; + default: break; + } + + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; + default: GGML_ABORT("MUL MAT-MAT not implemented"); + } + + ggml_metal_kargs_mul_mm args = { + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + [encoder setThreadgroupMemoryLength:8192 atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + } else { + int nth0 = 32; + int nth1 = 1; + int nrows = 1; + //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + + id pipeline = nil; + + // use custom matrix x vector kernel + switch (src0t) { + case GGML_TYPE_F32: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; + nrows = 4; + } break; + case GGML_TYPE_F16: + { + nth0 = 32; + nth1 = 1; + if (src1t == GGML_TYPE_F32) { + if (ne11 * ne12 < 4) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; + } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; + nrows = ne11; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline; + nrows = 4; + } + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline; + nrows = 4; + } + } break; + case GGML_TYPE_BF16: + { + nth0 = 32; + nth1 = 1; + if (src1t == GGML_TYPE_F32) { + if (ne11 * ne12 < 4) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline; + } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline; + nrows = ne11; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline; + nrows = 4; + } + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline; + nrows = 4; + } + } break; + case GGML_TYPE_Q4_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline; + } break; + case GGML_TYPE_Q4_1: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline; + } break; + case GGML_TYPE_Q5_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline; + } break; + case GGML_TYPE_Q5_1: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline; + } break; + case GGML_TYPE_Q8_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; + } break; + case GGML_TYPE_Q2_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline; + } break; + case GGML_TYPE_Q3_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline; + } break; + case GGML_TYPE_Q4_K: + { + nth0 = 4; //1; + nth1 = 8; //32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline; + } break; + case GGML_TYPE_Q5_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline; + } break; + case GGML_TYPE_Q6_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline; + } break; + case GGML_TYPE_IQ2_XXS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline; + } break; + case GGML_TYPE_IQ2_XS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline; + } break; + case GGML_TYPE_IQ3_XXS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline; + } break; + case GGML_TYPE_IQ3_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline; + } break; + case GGML_TYPE_IQ2_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline; + } break; + case GGML_TYPE_IQ1_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline; + } break; + case GGML_TYPE_IQ1_M: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline; + } break; + case GGML_TYPE_IQ4_NL: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline; + } break; + case GGML_TYPE_IQ4_XS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; + } break; + default: + { + GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t); + GGML_ABORT("not implemented"); + } + }; + + ggml_metal_kargs_mul_mv args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || + src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || + src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { + const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { + const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { + const int mem_size = 32*sizeof(float); + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q3_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q5_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q6_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else { + const int64_t ny = (ne11 + nrows - 1)/nrows; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + } + } break; + case GGML_OP_MUL_MAT_ID: + { + const int n_as = src0->ne[2]; + + // src2 = ids + const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); + + GGML_ASSERT(src2t == GGML_TYPE_I32); + + GGML_ASSERT(!ggml_is_transposed(src0)); + GGML_ASSERT(!ggml_is_transposed(src1)); + + GGML_ASSERT(src1t == GGML_TYPE_F32); + + GGML_ASSERT(ne03 == 1); + GGML_ASSERT(ne13 == 1); + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + // ne20 = n_used_experts + // ne21 = n_rows + const int dst_rows = ne20*ne21; + const int dst_rows_min = n_as; + const int dst_rows_max = (device.maxThreadgroupMemoryLength - 32 - 8192)/4; + + // max size of the rowids array in the kernel shared buffer + GGML_ASSERT(dst_rows <= dst_rows_max); + + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs + // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel + // !!! + // TODO: for now, always use mat-vec kernels until we figure out how to improve the + // indirect matrix multiplication + // !!! + if ([device supportsFamily:MTLGPUFamilyApple7] && + ne00 % 32 == 0 && ne00 >= 64 && + dst_rows > dst_rows_min) { + // some Metal matrix data types require aligned pointers + // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) + switch (src0->type) { + case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; + case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; + case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break; + default: break; + } + + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; + default: GGML_ABORT("MUL_MAT_ID not implemented"); + } + + ggml_metal_kargs_mul_mm_id args = { + /*.nei0 =*/ ne20, + /*.nei1 =*/ ne21, + /*.nbi1 =*/ nb21, + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; + + [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + } else { + int nth0 = 32; + int nth1 = 1; + int nrows = 1; + //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + + id pipeline = nil; + + // use custom matrix x vector kernel + switch (src0t) { + case GGML_TYPE_F32: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; + } break; + case GGML_TYPE_F16: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + nth0 = 32; + nth1 = 1; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; + } break; + case GGML_TYPE_BF16: + { + GGML_ASSERT(src1t == GGML_TYPE_F32); + nth0 = 32; + nth1 = 1; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline; + } break; + case GGML_TYPE_Q4_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; + } break; + case GGML_TYPE_Q4_1: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; + } break; + case GGML_TYPE_Q5_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; + } break; + case GGML_TYPE_Q5_1: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; + } break; + case GGML_TYPE_Q8_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; + } break; + case GGML_TYPE_Q2_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; + } break; + case GGML_TYPE_Q3_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; + } break; + case GGML_TYPE_Q4_K: + { + nth0 = 4; //1; + nth1 = 8; //32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; + } break; + case GGML_TYPE_Q5_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; + } break; + case GGML_TYPE_Q6_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; + } break; + case GGML_TYPE_IQ2_XXS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; + } break; + case GGML_TYPE_IQ2_XS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; + } break; + case GGML_TYPE_IQ3_XXS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; + } break; + case GGML_TYPE_IQ3_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; + } break; + case GGML_TYPE_IQ2_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; + } break; + case GGML_TYPE_IQ1_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; + } break; + case GGML_TYPE_IQ1_M: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; + } break; + case GGML_TYPE_IQ4_NL: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; + } break; + case GGML_TYPE_IQ4_XS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; + } break; + default: + { + GGML_LOG_ERROR("Asserting on type %d\n", (int)src2t); + GGML_ABORT("not implemented"); + } + }; + + if (ggml_is_quantized(src0t)) { + GGML_ASSERT(ne00 >= nth0*nth1); + } + + ggml_metal_kargs_mul_mv_id args = { + /*.nei0 =*/ ne20, + /*.nei1 =*/ ne21, + /*.nbi1 =*/ nb21, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.nb1 =*/ nb1, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; + + const int64_t _ne1 = 1; + const int tgz = dst_rows; + + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || + src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || + src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { + const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { + const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { + const int mem_size = 32*sizeof(float); + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q3_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q5_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q6_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else { + const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1 + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + } + } break; + case GGML_OP_GET_ROWS: + { + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_M ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; + case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; + default: GGML_ABORT("not implemented"); + } + + // TODO: add ggml_metal_kargs struct + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + } break; + case GGML_OP_RMS_NORM: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { + nth *= 2; + } + + nth = MIN(nth, ne00/4); + + ggml_metal_kargs_rms_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb01 =*/ nb01, + /*.eps =*/ eps, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_GROUP_NORM: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + + float eps; + memcpy(&eps, dst->op_params + 1, sizeof(float)); + + const int32_t n_groups = ((const int32_t *) dst->op_params)[0]; + + int nth = 32; // SIMD width + + //while (nth < ne00/4 && nth < 1024) { + // nth *= 2; + //} + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; + + // TODO: add ggml_metal_kargs struct + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8]; + [encoder setBytes:&eps length:sizeof( float) atIndex:9]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_NORM: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; + + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { + nth *= 2; + } + + nth = MIN(nth, ne00/4); + + ggml_metal_kargs_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb01 =*/ nb01, + /*.eps =*/ eps, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ROPE: + { + // make sure we have one or more position id(ne10) per token(ne02) + GGML_ASSERT(ne10 % ne02 == 0); + GGML_ASSERT(ne10 >= ne02); + + const int nth = MIN(1024, ne00); + + const int n_past = ((const int32_t *) dst->op_params)[0]; + const int n_dims = ((const int32_t *) dst->op_params)[1]; + const int mode = ((const int32_t *) dst->op_params)[2]; + // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal + const int n_ctx_orig = ((const int32_t *) dst->op_params)[4]; + + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + + memcpy(&freq_base, (const int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (const int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (const int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (const int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float)); + + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + + id pipeline = nil; + + if (!is_neox) { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } else { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } + + ggml_metal_kargs_rope args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.n_past =*/ n_past, + /*.n_dims =*/ n_dims, + /*.n_ctx_orig =*/ n_ctx_orig, + /*.freq_base =*/ freq_base, + /*.freq_scale =*/ freq_scale, + /*.ext_factor =*/ ext_factor, + /*.attn_factor =*/ attn_factor, + /*.beta_fast =*/ beta_fast, + /*.beta_slow =*/ beta_slow, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + if (id_src2 != nil) { + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_IM2COL: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; + + const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; + + const int32_t N = src1->ne[is_2D ? 3 : 2]; + const int32_t IC = src1->ne[is_2D ? 2 : 1]; + const int32_t IH = is_2D ? src1->ne[1] : 1; + const int32_t IW = src1->ne[0]; + + const int32_t KH = is_2D ? src0->ne[1] : 1; + const int32_t KW = src0->ne[0]; + + const int32_t OH = is_2D ? dst->ne[2] : 1; + const int32_t OW = dst->ne[1]; + + const int32_t CHW = IC * KH * KW; + + const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; + const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; + + const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup; + + switch (dst->type) { + case GGML_TYPE_F32: { + pipeline = (is_gt_mttpt ? + ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline + : + ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline); + } break; + case GGML_TYPE_F16: { + pipeline = (is_gt_mttpt ? + ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline + : + ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline); + } break; + default: GGML_ABORT("fatal error"); + }; + + // TODO: add ggml_metal_kargs struct + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2]; + [encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3]; + [encoder setBytes:&IW length:sizeof(int32_t) atIndex:4]; + [encoder setBytes:&IH length:sizeof(int32_t) atIndex:5]; + [encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6]; + [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7]; + [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8]; + [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9]; + [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10]; + [encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11]; + [encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12]; + + if (is_gt_mttpt) { + [encoder setBytes:&N length:sizeof(int32_t) atIndex:13]; + [encoder setBytes:&KH length:sizeof(int32_t) atIndex:14]; + [encoder setBytes:&KW length:sizeof(int32_t) atIndex:15]; + + const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N); + + const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0); + + [encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; + } else { + [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; + } + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + + const int32_t IC = src1->ne[1]; + const int32_t IL = src1->ne[0]; + + const int32_t K = src0->ne[0]; + + const int32_t OL = dst->ne[0]; + const int32_t OC = dst->ne[1]; + + id pipeline; + + switch (src0->type) { + case GGML_TYPE_F32: { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32].pipeline; + } break; + case GGML_TYPE_F16: { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32].pipeline; + } break; + default: GGML_ABORT("fatal error"); + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&IC length:sizeof( int32_t) atIndex:3]; + [encoder setBytes:&IL length:sizeof( int32_t) atIndex:4]; + [encoder setBytes:&K length:sizeof( int32_t) atIndex:5]; + [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8]; + + [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_UPSCALE: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + const float sf0 = (float)ne0/src0->ne[0]; + const float sf1 = (float)ne1/src0->ne[1]; + const float sf2 = (float)ne2/src0->ne[2]; + const float sf3 = (float)ne3/src0->ne[3]; + + const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; + + // TODO: add ggml_metal_kargs struct + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18]; + [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19]; + [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20]; + [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_PAD: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; + + // TODO: add ggml_metal_kargs struct + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_PAD_REFLECT_1D: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + const int32_t p0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[1]; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:11]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:12]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:13]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:14]; + [encoder setBytes:&p0 length:sizeof(p0) atIndex:15]; + [encoder setBytes:&p1 length:sizeof(p1) atIndex:16]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ARANGE: + { + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + float start; + float step; + + memcpy(&start, ((const int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&step, ((const int32_t *) dst->op_params) + 2, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; + + // TODO: add ggml_metal_kargs struct + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; + [encoder setBytes:&start length:sizeof(start) atIndex:2]; + [encoder setBytes:&step length:sizeof(step) atIndex:3]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + const int dim = dst->op_params[0]; + const int max_period = dst->op_params[1]; + + const int half = dim / 2; + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; + + // TODO: add ggml_metal_kargs struct + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2]; + [encoder setBytes:&dim length:sizeof(dim) atIndex:3]; + [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4]; + + const int nth = MIN(1024, half); + + [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_ARGSORT: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + + const int nrows = ggml_nrows(src0); + + enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; + + // bitonic sort requires the number of elements to be power of 2 + int64_t ne00_padded = 1; + while (ne00_padded < ne00) { + ne00_padded *= 2; + } + + // Metal kernels require the buffer size to be multiple of 16 bytes + // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength + const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16); + + id pipeline = nil; + + switch (order) { + case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; + case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + + // TODO: add ggml_metal_kargs struct + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3]; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; + } break; + case GGML_OP_LEAKY_RELU: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + float slope; + memcpy(&slope, dst->op_params, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; + + // TODO: add ggml_metal_kargs struct + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&slope length:sizeof(slope) atIndex:2]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_OP_FLASH_ATTN_EXT: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ne11 % 32 == 0); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == src2->type); + + GGML_ASSERT(ggml_are_same_shape (src1, src2)); + + struct ggml_tensor * src3 = node->src[3]; + + size_t offs_src3 = 0; + + id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + + GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); + GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && + "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); + //const int64_t ne31 = src3 ? src3->ne[1] : 0; + const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); + const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); + + const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); + const uint64_t nb31 = src3 ? src3->nb[1] : 0; + const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); + const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); + + const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); + + float scale; + float max_bias; + float logit_softcap; + memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); + memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); + memcpy(&logit_softcap, ((const int32_t *) dst->op_params) + 2, sizeof(logit_softcap)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + + const uint32_t n_head = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + id pipeline = nil; + + bool use_vec_kernel = false; + + // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0) + // for now avoiding mainly to keep the number of templates/kernels a bit lower + if (ne01 >= 4 || (ne00%128 != 0)) { + switch (src1->type) { + case GGML_TYPE_F16: + { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } break; + case GGML_TYPE_BF16: + { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } break; + case GGML_TYPE_Q4_0: + { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } break; + case GGML_TYPE_Q4_1: + { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } break; + case GGML_TYPE_Q5_0: + { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } break; + case GGML_TYPE_Q5_1: + { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } break; + case GGML_TYPE_Q8_0: + { + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } else { + use_vec_kernel = true; + + switch (ne00) { + case 128: + { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } break; + case 256: + { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } break; + default: + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } + } + + ggml_metal_kargs_flash_attn_ext args = { + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.nb_12_1 =*/ nb11, + /*.nb_12_2 =*/ nb12, + /*.nb_12_3 =*/ nb13, + /*.nb31 =*/ nb31, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ logit_softcap, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; + if (id_src3) { + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:5]; + + if (!use_vec_kernel) { + // half8x8 kernel + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + // 2*(2*ncpsg + nqptg)*(nsg) + // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float) + // + // 16*32*(nsg) + // the shared memory needed for the simdgroups to load the KV cache + // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG + // +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16)) + + int64_t nsgmax = 2; + + while (true) { + const size_t smem = FATTN_SMEM(nsgmax); + if (smem > device.maxThreadgroupMemoryLength) { + break; + } + nsgmax *= 2; + } + nsgmax /= 2; + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; + + const size_t smem = FATTN_SMEM(nsg); + + //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg); + GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:smem atIndex:0]; +#undef FATTN_SMEM + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else { + // half4x4 kernel + const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 1 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + // ne00 + 2*ncpsg*(nsg) + // for each query, we load it as f16 in shared memory (ne00) + // and store the soft_max values and the mask + // + // ne00*(nsg) + // each simdgroup has a full f16 head vector in shared mem to accumulate results + // +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16)) + + int64_t nsgmax = 2; + + while (true) { + const size_t smem = FATTN_SMEM(nsgmax); + if (smem > device.maxThreadgroupMemoryLength) { + break; + } + nsgmax *= 2; + } + nsgmax /= 2; + + // simdgroups per threadgroup (a.k.a. warps) + const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))); + + int64_t nsg = 1; + while (nsg <= nsgt) { + nsg *= 2; + } + nsg /= 2; + + const size_t smem = FATTN_SMEM(nsg); + + //printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg); + GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:smem atIndex:0]; +#undef FATTN_SMEM + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } + } break; + case GGML_OP_DUP: + case GGML_OP_CPY: + case GGML_OP_CONT: + { + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); + + id pipeline = nil; + + switch (src0t) { + case GGML_TYPE_F32: + { + GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); + + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_F16: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; + default: GGML_ABORT("not implemented"); + }; + } break; + case GGML_TYPE_BF16: + { + switch (dstt) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break; + default: GGML_ASSERT(false && "not implemented"); + }; + } break; + default: GGML_ABORT("not implemented"); + } + + ggml_metal_kargs_cpy args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_SET: + { + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); + + // src0 and dst as viewed during set + const size_t dst_nb0 = ggml_element_size(src0); + + const size_t dst_nb1 = ((int32_t *) dst->op_params)[0]; + const size_t dst_nb2 = ((int32_t *) dst->op_params)[1]; + const size_t dst_nb3 = ((int32_t *) dst->op_params)[2]; + const size_t offset = ((int32_t *) dst->op_params)[3]; + const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace) { + memcpy(((char *) dst->data), ((char *) src0->data), ggml_nbytes(dst)); + } + + const int im0 = (ne10 == 0 ? 0 : ne10-1); + const int im1 = (ne11 == 0 ? 0 : ne11-1); + const int im2 = (ne12 == 0 ? 0 : ne12-1); + const int im3 = (ne13 == 0 ? 0 : ne13-1); + + GGML_ASSERT(offset + im0*dst_nb0 + im1*dst_nb1 + im2*dst_nb2 + im3*dst_nb3 <= ggml_nbytes(dst)); + + id pipeline = nil; + + switch (src0t) { + case GGML_TYPE_F32: + GGML_ASSERT(nb10 == sizeof(float)); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_F32].pipeline; break; + case GGML_TYPE_I32: + GGML_ASSERT(nb10 == sizeof(int32_t)); + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_I32].pipeline; break; + default: GGML_ABORT("fatal error"); + } + + ggml_metal_kargs_set args = { + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb1 =*/ dst_nb1, + /*.nb2 =*/ dst_nb2, + /*.nb3 =*/ dst_nb3, + /*.offs =*/ offset, + /*.inplace =*/ inplace, + }; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne10); + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_POOL_2D: + { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt); + + const int32_t * opts = dst->op_params; + enum ggml_op_pool op = opts[0]; + + id pipeline = nil; + switch (src0t) { + case GGML_TYPE_F32: { + switch(op) { + case GGML_OP_POOL_AVG: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break; + case GGML_OP_POOL_MAX: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break; + default: GGML_ASSERT(false && "not implemented"); + } + } break; + default: GGML_ASSERT(false && "not implemented"); + } + + const int32_t k0 = opts[1]; + const int32_t k1 = opts[2]; + const int32_t s0 = opts[3]; + const int32_t s1 = opts[4]; + const int32_t p0 = opts[5]; + const int32_t p1 = opts[6]; + + const int64_t IH = src0->ne[1]; + const int64_t IW = src0->ne[0]; + + const int64_t N = dst->ne[3]; + const int64_t OC = dst->ne[2]; + const int64_t OH = dst->ne[1]; + const int64_t OW = dst->ne[0]; + + const int64_t parallel_elements = N * OC * OH * OW; + const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements); + const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads; + + // TODO: add ggml_metal_kargs struct + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2]; + [encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3]; + [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4]; + [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5]; + [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6]; + [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7]; + [encoder setBytes:&IH length:sizeof(int64_t) atIndex:8]; + [encoder setBytes:&IW length:sizeof(int64_t) atIndex:9]; + [encoder setBytes:&OH length:sizeof(int64_t) atIndex:10]; + [encoder setBytes:&OW length:sizeof(int64_t) atIndex:11]; + [encoder setBytes:¶llel_elements length:sizeof(int64_t) atIndex:12]; + + [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; + } break; + case GGML_OP_ARGMAX: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(nb00 == ggml_type_size(src0->type)); + + const int64_t nrows = ggml_nrows(src0); + + int nth = 32; // SIMD width + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + [encoder setThreadgroupMemoryLength:32*sizeof(int32_t) atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + default: + { + GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); + GGML_ABORT("fatal error"); + } + } +} + +static enum ggml_status ggml_metal_graph_compute( + ggml_backend_t backend, + struct ggml_cgraph * gf) { + struct ggml_backend_metal_context * ctx = backend->context; + struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; + + // number of nodes encoded by the main thread (empirically determined) + const int n_main = 128; + + // number of threads in addition to the main thread + const int n_cb = ctx->n_cb; + + // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them + // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread + // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes + // each thread creates it's own command buffer and enqueues the ops in parallel + // + // tests on M1 Pro and M2 Ultra using LLaMA models, show that optimal values for n_cb are 1 or 2 + + @autoreleasepool { + ctx->gf = gf; + + ctx->n_nodes_0 = MIN(n_main, gf->n_nodes); + ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0; + + ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb; + + const bool should_capture = ctx->capture_next_compute; + if (should_capture) { + ctx->capture_next_compute = false; + + if (!ctx->capture_started) { + // create capture scope + ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device]; + + MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; + descriptor.captureObject = ctx->capture_scope; + descriptor.destination = MTLCaptureDestinationGPUTraceDocument; + descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]]; + + NSError * error = nil; + if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { + GGML_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); + } else { + [ctx->capture_scope beginScope]; + ctx->capture_started = true; + } + } + } + + // the main thread commits the first few commands immediately + // command_buffer[n_cb] + { + id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; + ctx->command_buffers[n_cb] = command_buffer; + + [command_buffer enqueue]; + ctx->encode_async(n_cb); + } + + // prepare the rest of the command buffers asynchronously + // command_buffer[0.. n_cb) + for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { + id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; + ctx->command_buffers[cb_idx] = command_buffer; + + // always enqueue the first two command buffers + // enqueue all of the command buffers if we don't need to abort + if (cb_idx < 2 || ctx->abort_callback == NULL) { + [command_buffer enqueue]; + } + } + + dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async); + + // wait for completion and check status of each command buffer + // needed to detect if the device ran out-of-memory for example (#1881) + { + id command_buffer = ctx->command_buffers[n_cb]; + [command_buffer waitUntilCompleted]; + + MTLCommandBufferStatus status = [command_buffer status]; + if (status != MTLCommandBufferStatusCompleted) { + GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status); + if (status == MTLCommandBufferStatusError) { + GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]); + } + + return GGML_STATUS_FAILED; + } + } + + for (int i = 0; i < n_cb; ++i) { + id command_buffer = ctx->command_buffers[i]; + [command_buffer waitUntilCompleted]; + + MTLCommandBufferStatus status = [command_buffer status]; + if (status != MTLCommandBufferStatusCompleted) { + GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); + if (status == MTLCommandBufferStatusError) { + GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]); + } + + return GGML_STATUS_FAILED; + } + + id next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil); + if (!next_buffer) { + continue; + } + + const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); + if (next_queued) { + continue; + } + + if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { + GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i); + return GGML_STATUS_ABORTED; + } + + [next_buffer commit]; + } + + if (!should_capture && ctx->capture_started) { + [ctx->capture_scope endScope]; + [[MTLCaptureManager sharedCaptureManager] stopCapture]; + } + } + + return GGML_STATUS_SUCCESS; +} + +//////////////////////////////////////////////////////////////////////////////// + +// backend interface + +static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + for (int i = 0; i < ctx->n_buffers; i++) { + [ctx->buffers[i].metal release]; + } + + ggml_backend_metal_buffer_rset_free(ctx); + ggml_backend_metal_device_rel(buffer->buft->device->context); + + if (ctx->owned) { +#if TARGET_OS_OSX + vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size); +#else + free(ctx->all_data); +#endif + } + + free(ctx); +} + +static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + return ctx->all_data; +} + +static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + memset((char *)tensor->data + offset, value, size); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + memcpy((char *)tensor->data + offset, data, size); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + memcpy(data, (const char *)tensor->data + offset, size); + + GGML_UNUSED(buffer); +} + +static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { + if (ggml_backend_buffer_is_host(src->buffer)) { + memcpy(dst->data, src->data, ggml_nbytes(src)); + return true; + } + return false; + + GGML_UNUSED(buffer); +} + +static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context; + + memset(ctx->all_data, value, ctx->all_size); +} + +static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = { + /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer, + /* .get_base = */ ggml_backend_metal_buffer_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_metal_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor, + /* .clear = */ ggml_backend_metal_buffer_clear, + /* .reset = */ NULL, +}; + +// default buffer type + +static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "Metal"; + + GGML_UNUSED(buft); +} + +static void ggml_backend_metal_log_allocated_size(id device, size_t size_aligned) { +#ifndef GGML_METAL_NDEBUG +#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) + if (@available(macOS 10.12, iOS 16.0, *)) { + GGML_LOG_DEBUG("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)\n", + __func__, + size_aligned / 1024.0 / 1024.0, + device.currentAllocatedSize / 1024.0 / 1024.0, + device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); + + if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) { + GGML_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__); + } + } else { + GGML_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n", + __func__, + size_aligned / 1024.0 / 1024.0, + device.currentAllocatedSize / 1024.0 / 1024.0); + } +#endif +#endif + GGML_UNUSED(device); + GGML_UNUSED(size_aligned); +} + +static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context)); + + const size_t size_page = sysconf(_SC_PAGESIZE); + + size_t size_aligned = size; + if ((size_aligned % size_page) != 0) { + size_aligned += (size_page - (size_aligned % size_page)); + } + + struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context; + id device = ggml_backend_metal_device_acq(ctx_dev); + + ctx->all_data = ggml_metal_host_malloc(size_aligned); + ctx->all_size = size_aligned; + ctx->owned = true; + ctx->n_buffers = 1; + + if (ctx->all_data != NULL) { + ctx->buffers[0].data = ctx->all_data; + ctx->buffers[0].size = size; + ctx->buffers[0].metal = nil; + + if (size_aligned > 0) { + ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data + length:size_aligned + options:MTLResourceStorageModeShared + deallocator:nil]; + } + } + + if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) { + GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); + free(ctx); + ggml_backend_metal_device_rel(ctx_dev); + return NULL; + } + + if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { + GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); + free(ctx); + ggml_backend_metal_device_rel(ctx_dev); + return NULL; + } + + //ggml_backend_metal_log_allocated_size(device, size_aligned); + + return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size); +} + +static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return 32; + GGML_UNUSED(buft); +} + +static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + id device = ggml_backend_metal_device_acq(buft->device->context); + const size_t max_size = device.maxBufferLength; + ggml_backend_metal_device_rel(buft->device->context); + + return max_size; + + GGML_UNUSED(buft); +} + +static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + return true; + + GGML_UNUSED(buft); +} + +ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) { + static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size, + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .is_host = */ ggml_backend_metal_buffer_type_is_host, + }, + /* .device = */ &g_ggml_backend_metal_device, + /* .context = */ NULL, + }; + + return &ggml_backend_buffer_type_metal; +} + +static const char * ggml_backend_metal_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) { + return "Metal_Mapped"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_from_ptr_type(void) { + static struct ggml_backend_buffer_type ggml_backend_buffer_from_ptr_type_metal = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_from_ptr_type_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size, + /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes + /* .is_host = */ ggml_backend_metal_buffer_type_is_host, + }, + /* .device = */ &g_ggml_backend_metal_device, + /* .context = */ NULL, + }; + + return &ggml_backend_buffer_from_ptr_type_metal; +} + +// TODO: obsoleted by ggml_backend_metal_device_buffer_from_ptr +ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) { + struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context)); + + ctx->all_data = data; + ctx->all_size = size; + ctx->owned = false; + ctx->n_buffers = 0; + + const size_t size_page = sysconf(_SC_PAGESIZE); + + // page-align the data ptr + { + const uintptr_t offs = (uintptr_t) data % size_page; + data = (void *) ((char *) data - offs); + size += offs; + } + + size_t size_aligned = size; + if ((size_aligned % size_page) != 0) { + size_aligned += (size_page - (size_aligned % size_page)); + } + + struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main; + id device = ggml_backend_metal_device_acq(ctx_dev); + + // the buffer fits into the max buffer size allowed by the device + if (size_aligned <= device.maxBufferLength) { + ctx->buffers[ctx->n_buffers].data = data; + ctx->buffers[ctx->n_buffers].size = size; + ctx->buffers[ctx->n_buffers].metal = nil; + + if (size_aligned > 0) { + ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; + + if (ctx->buffers[ctx->n_buffers].metal == nil) { + GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); + return false; + } + } + + ggml_backend_metal_log_allocated_size(device, size_aligned); + + ++ctx->n_buffers; + } else { + // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into + // one of the views + const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case + const size_t size_step = device.maxBufferLength - size_ovlp; + const size_t size_view = device.maxBufferLength; + + for (size_t i = 0; i < size; i += size_step) { + const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i); + + ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i); + ctx->buffers[ctx->n_buffers].size = size_step_aligned; + ctx->buffers[ctx->n_buffers].metal = nil; + + if (size_step_aligned > 0) { + ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; + + if (ctx->buffers[ctx->n_buffers].metal == nil) { + GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0); + return false; + } + } + + ggml_backend_metal_log_allocated_size(device, size_step_aligned); + + if (i + size_step < size) { + GGML_LOG_INFO("\n"); + } + + ++ctx->n_buffers; + } + } + + if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { + GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); + free(ctx); + ggml_backend_metal_device_rel(ctx_dev); + return NULL; + } + + return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size); +} + +// backend + +static const char * ggml_backend_metal_name(ggml_backend_t backend) { + return "Metal"; + + GGML_UNUSED(backend); +} + +static void ggml_backend_metal_free(ggml_backend_t backend) { + struct ggml_backend_metal_context * ctx = backend->context; + struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; + + ggml_backend_metal_device_rel(ctx_dev); + ggml_metal_free(ctx); + + free(backend); +} + +static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + return ggml_metal_graph_compute(backend, cgraph); +} + +static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { + GGML_ASSERT(ggml_backend_is_metal(backend)); + + struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; + + if (ctx->n_cb != n_cb) { + ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS); + + if (ctx->n_cb > 2) { + GGML_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb); + } + } + + if (ctx->encode_async) { + Block_release(ctx->encode_async); + } + + ctx->encode_async = Block_copy(^(size_t iter) { + const int cb_idx = iter; + const int n_cb_l = ctx->n_cb; + + const int n_nodes_0 = ctx->n_nodes_0; + const int n_nodes_1 = ctx->n_nodes_1; + + const int n_nodes_per_cb = ctx->n_nodes_per_cb; + + id command_buffer = ctx->command_buffers[cb_idx]; + id encoder = [command_buffer computeCommandEncoder]; + + int node_start = 0; + int node_end = n_nodes_0; + + if (cb_idx < n_cb_l) { + node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb); + node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1)); + } + + const bool should_capture = ctx->capture_next_compute; + + for (int idx = node_start; idx < node_end; ++idx) { + if (should_capture) { + [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]]; + } + + ggml_metal_encode_node(backend, idx, encoder); + + if (should_capture) { + [encoder popDebugGroup]; + } + } + + [encoder endEncoding]; + + if (cb_idx < 2 || ctx->abort_callback == NULL) { + [command_buffer commit]; + } + }); +} + +static struct ggml_backend_i ggml_backend_metal_i = { + /* .get_name = */ ggml_backend_metal_name, + /* .free = */ ggml_backend_metal_free, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_metal_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, +}; + +static ggml_guid_t ggml_backend_metal_guid(void) { + static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 }; + return &guid; +} + +// TODO: remove in the future +ggml_backend_t ggml_backend_metal_init(void) { + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0); + + struct ggml_backend_metal_context * ctx = ggml_metal_init(dev); + if (ctx == NULL) { + GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); + return NULL; + } + + ggml_backend_t backend = malloc(sizeof(struct ggml_backend)); + + *backend = (struct ggml_backend) { + /* .guid = */ ggml_backend_metal_guid(), + /* .interface = */ ggml_backend_metal_i, + /* .device = */ dev, + /* .context = */ ctx, + }; + + ggml_backend_metal_set_n_cb(backend, 1); + + return backend; +} + +bool ggml_backend_is_metal(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid()); +} + +void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) { + GGML_ASSERT(ggml_backend_is_metal(backend)); + + struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; + + ctx->abort_callback = abort_callback; + ctx->abort_callback_data = user_data; +} + +bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) { + GGML_ASSERT(ggml_backend_is_metal(backend)); + + struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; + + return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)]; +} + +void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) { + GGML_ASSERT(ggml_backend_is_metal(backend)); + + struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context; + ctx->capture_next_compute = true; +} + +// backend device + +static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) { + return "Metal"; + + GGML_UNUSED(dev); +} + +static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) { + // acq/rel just to populate ctx->name in case it hasn't been done yet + struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context; + ggml_backend_metal_device_acq(ctx_dev); + ggml_backend_metal_device_rel(ctx_dev); + + return ctx_dev->name; +} + +static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + if (@available(macOS 10.12, iOS 16.0, *)) { + struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context; + id device = ggml_backend_metal_device_acq(ctx_dev); + + *total = device.recommendedMaxWorkingSetSize; + *free = *total - device.currentAllocatedSize; + + ggml_backend_metal_device_rel(ctx_dev); + } else { + *free = 1; + *total = 1; + } +} + +static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) { + return GGML_BACKEND_DEVICE_TYPE_GPU; + + GGML_UNUSED(dev); +} + +static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_metal_device_get_name(dev); + props->description = ggml_backend_metal_device_get_description(dev); + props->type = ggml_backend_metal_device_get_type(dev); + ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = (struct ggml_backend_dev_caps) { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ true, + /* .events = */ false, + }; +} + +static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) { + struct ggml_backend_metal_context * ctx = ggml_metal_init(dev); + if (ctx == NULL) { + GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); + return NULL; + } + + ggml_backend_t backend = malloc(sizeof(struct ggml_backend)); + + *backend = (struct ggml_backend) { + /* .guid = */ ggml_backend_metal_guid(), + /* .interface = */ ggml_backend_metal_i, + /* .device = */ dev, + /* .context = */ ctx, + }; + + ggml_backend_metal_set_n_cb(backend, 1); + + return backend; + + GGML_UNUSED(params); +} + +static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) { + return ggml_backend_metal_buffer_type(); + + GGML_UNUSED(dev); +} + +static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context)); + + ctx->all_data = ptr; + ctx->all_size = size; + ctx->owned = false; + ctx->n_buffers = 0; + + const size_t size_page = sysconf(_SC_PAGESIZE); + + // page-align the data ptr + { + const uintptr_t offs = (uintptr_t) ptr % size_page; + ptr = (void *) ((char *) ptr - offs); + size += offs; + } + + size_t size_aligned = size; + if ((size_aligned % size_page) != 0) { + size_aligned += (size_page - (size_aligned % size_page)); + } + + struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context; + id device = ggml_backend_metal_device_acq(ctx_dev); + + // the buffer fits into the max buffer size allowed by the device + if (size_aligned <= device.maxBufferLength) { + ctx->buffers[ctx->n_buffers].data = ptr; + ctx->buffers[ctx->n_buffers].size = size; + ctx->buffers[ctx->n_buffers].metal = nil; + + if (size_aligned > 0) { + ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; + + if (ctx->buffers[ctx->n_buffers].metal == nil) { + GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); + return false; + } + } + + ggml_backend_metal_log_allocated_size(device, size_aligned); + + ++ctx->n_buffers; + } else { + // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into + // one of the views + const size_t size_ovlp = ((max_tensor_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case + const size_t size_step = device.maxBufferLength - size_ovlp; + const size_t size_view = device.maxBufferLength; + + for (size_t i = 0; i < size; i += size_step) { + const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i); + + ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) ptr + i); + ctx->buffers[ctx->n_buffers].size = size_step_aligned; + ctx->buffers[ctx->n_buffers].metal = nil; + + if (size_step_aligned > 0) { + ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; + + if (ctx->buffers[ctx->n_buffers].metal == nil) { + GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0); + return false; + } + } + + ggml_backend_metal_log_allocated_size(device, size_step_aligned); + + if (i + size_step < size) { + GGML_LOG_INFO("\n"); + } + + ++ctx->n_buffers; + } + } + + if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) { + GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); + free(ctx); + ggml_backend_metal_device_rel(ctx_dev); + return NULL; + } + + return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size); +} + +static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + struct ggml_backend_metal_device_context * ctx_dev = dev->context; + + return ggml_metal_supports_op(ctx_dev, op); +} + +static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name || + buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name; + + GGML_UNUSED(dev); +} + +static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + return false; + + GGML_UNUSED(dev); + GGML_UNUSED(op); +} + +static struct ggml_backend_device_i ggml_backend_metal_device_i = { + /* .get_name = */ ggml_backend_metal_device_get_name, + /* .get_description = */ ggml_backend_metal_device_get_description, + /* .get_memory = */ ggml_backend_metal_device_get_memory, + /* .get_type = */ ggml_backend_metal_device_get_type, + /* .get_props = */ ggml_backend_metal_device_get_props, + /* .init_backend = */ ggml_backend_metal_device_init, + /* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_from_ptr, + /* .supports_op = */ ggml_backend_metal_device_supports_op, + /* .supports_buft = */ ggml_backend_metal_device_supports_buft, + /* .offload_op = */ ggml_backend_metal_device_offload_op, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +// backend registry + +static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) { + return "Metal"; + + GGML_UNUSED(reg); +} + +static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) { + return 1; + + GGML_UNUSED(reg); +} + +static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) { + GGML_ASSERT(index == 0); + + return &g_ggml_backend_metal_device; + + GGML_UNUSED(reg); + GGML_UNUSED(index); +} + +static struct ggml_backend_feature g_ggml_backend_metal_features[] = { +#if defined(GGML_METAL_EMBED_LIBRARY) + { "EMBED_LIBRARY", "1" }, +#endif +#if defined(GGML_METAL_USE_BF16) + { "BF16", "1" }, +#endif + { nil, nil }, +}; + +static struct ggml_backend_feature * ggml_backend_metal_get_features(ggml_backend_reg_t reg) { + return g_ggml_backend_metal_features; + + GGML_UNUSED(reg); +} + +static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const char * name) { + if (strcmp(name, "ggml_backend_get_features") == 0) { + return (void *)ggml_backend_metal_get_features; + } + + return NULL; + + GGML_UNUSED(reg); +} +static struct ggml_backend_reg_i ggml_backend_metal_reg_i = { + /* .get_name = */ ggml_backend_metal_reg_get_name, + /* .device_count = */ ggml_backend_metal_reg_device_count, + /* .device_get = */ ggml_backend_metal_reg_device_get, + /* .get_proc_address = */ ggml_backend_metal_get_proc_address, +}; + +ggml_backend_reg_t ggml_backend_metal_reg(void) { + // TODO: make this thread-safe somehow? + { + g_ggml_backend_metal_reg = (struct ggml_backend_reg) { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_metal_reg_i, + /* .context = */ NULL, + }; + + g_ggml_backend_metal_device = (struct ggml_backend_device) { + /* .iface = */ ggml_backend_metal_device_i, + /* .reg = */ &g_ggml_backend_metal_reg, + /* .context = */ &g_ggml_ctx_dev_main, + }; + } + + return &g_ggml_backend_metal_reg; +} + +GGML_BACKEND_DL_IMPL(ggml_backend_metal_reg) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal new file mode 100644 index 000000000..44f04c909 --- /dev/null +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -0,0 +1,6735 @@ +#define GGML_COMMON_DECL_METAL +#define GGML_COMMON_IMPL_METAL +#if defined(GGML_METAL_EMBED_LIBRARY) +__embed_ggml-common.h__ +#else +// TODO: this should not be a relative path, but can't figure out how to set Metal include paths in Package.swift +#include "../ggml-common.h" +#endif +#include "ggml-metal-impl.h" + +#include + +using namespace metal; + +#define MAX(x, y) ((x) > (y) ? (x) : (y)) +#define MIN(x, y) ((x) < (y) ? (x) : (y)) +#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } + +#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 + +// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf +// +// cmd: +// .../usr/bin/metal -dM -E -c ggml/src/ggml-metal/ggml-metal.metal +// .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/ggml-metal.metal +// +#if __METAL_VERSION__ < 310 && defined(GGML_METAL_USE_BF16) +#undef GGML_METAL_USE_BF16 +#endif + +#if defined(GGML_METAL_USE_BF16) +typedef matrix bfloat4x4; +#endif + +constexpr constant static float kvalues_iq4nl_f[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +// NOTE: this is not dequantizing - we are simply fitting the template +template +void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { + reg = (type4x4)(*src); +} + +template +void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { + reg = (type4x4)(*src); +} + +template +void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) { + reg = (type4)(*(src + il)); +} + +#if defined(GGML_METAL_USE_BF16) +template +void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) { + reg = (type4x4)(*src); +} +#endif + +template +void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 1); + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float md = -8.h * xb->d; + const ushort mask0 = il ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + float4x4 reg_f; + + for (int i = 0; i < 8; i++) { + reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md; + reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md; + } + + reg = (type4x4) reg_f; +} + +template +void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 1); + const float d1 = (il/4) ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float md = -8.h * xb->d; + const ushort mask0 = (il/4) ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + for (int i = 0; i < 2; i++) { + reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md; + reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md; + } +} + +template +void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 2); + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float m = xb->m; + const ushort mask0 = il ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + float4x4 reg_f; + + for (int i = 0; i < 8; i++) { + reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m; + reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m; + } + + reg = (type4x4) reg_f; +} + +template +void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 2); + const float d1 = (il/4) ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float m = xb->m; + const ushort mask0 = (il/4) ? 0x00F0 : 0x000F; + const ushort mask1 = mask0 << 8; + + for (int i = 0; i < 2; i++) { + reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m; + reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m; + } +} + +template +void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 3); + const float d = xb->d; + const float md = -16.h * xb->d; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + float4x4 reg_f; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg_f[i/2][2*(i%2) + 0] = d * x0 + md; + reg_f[i/2][2*(i%2) + 1] = d * x1 + md; + } + + reg = (type4x4) reg_f; +} + +template +void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 3); + const float d = xb->d; + const float md = -16.h * xb->d; + const ushort mask = (il/4) ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = (il/4) ? 4 : 0; + + const int gh_mv = (il/4) ? 12 : 0; + const int gh_bk = (il/4) ? 0 : 4; + + for (int ii = 0; ii < 2; ii++) { + int i = 2*(il%4) + ii; + + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[2*ii + 0] = d * x0 + md; + reg[2*ii + 1] = d * x1 + md; + } +} + +template +void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 4); + const float d = xb->d; + const float m = xb->m; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + float4x4 reg_f; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg_f[i/2][2*(i%2) + 0] = d * x0 + m; + reg_f[i/2][2*(i%2) + 1] = d * x1 + m; + } + + reg = (type4x4) reg_f; +} + +template +void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 4); + const float d = xb->d; + const float m = xb->m; + const ushort mask = (il/4) ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = (il/4) ? 4 : 0; + + const int gh_mv = (il/4) ? 12 : 0; + const int gh_bk = (il/4) ? 0 : 4; + + for (int ii = 0; ii < 2; ii++) { + int i = 2*(il%4) + ii; + + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[2*ii + 0] = d * x0 + m; + reg[2*ii + 1] = d * x1 + m; + } +} + +template +void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { + device const int8_t * qs = ((device const int8_t *)xb->qs); + const float d = xb->d; + + float4x4 reg_f; + + for (int i = 0; i < 16; i++) { + reg_f[i/4][i%4] = (qs[i + 16*il] * d); + } + + reg = (type4x4) reg_f; +} + +template +void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) { + device const int8_t * qs = ((device const int8_t *)xb->qs); + const float d = xb->d; + + for (int i = 0; i < 4; i++) { + reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d); + } +} + +template +void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { + const float d = xb->d; + const float min = xb->dmin; + device const uint8_t * q = (device const uint8_t *)xb->qs; + float dl, ml; + uint8_t sc = xb->scales[il]; + + q = q + 32*(il/8) + 16*(il&1); + il = (il/2)%4; + + half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4); + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - ml; + } +} + +template +void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { + const half d_all = xb->d; + device const uint8_t * q = (device const uint8_t *)xb->qs; + device const uint8_t * h = (device const uint8_t *)xb->hmask; + device const int8_t * scales = (device const int8_t *)xb->scales; + + q = q + 32 * (il/8) + 16 * (il&1); + h = h + 16 * (il&1); + uint8_t m = 1 << (il/2); + uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \ + ((il/4)>0 ? 12 : 3); + uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; + uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; + int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) + : (scale_2&kmask2) | ((scale_1&kmask1) << 4); + float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); + const float ml = 4.f * dl; + + il = (il/2) & 3; + const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + dl *= coef; + + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); + } +} + +static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) { + return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)} + : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))}; +} + +template +void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) { + device const uchar * q = xb->qs; + + short is = (il/4) * 2; + q = q + (il/4) * 32 + 16 * (il&1); + il = il & 3; + const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); + const float d = il < 2 ? xb->d : xb->d / 16.h; + const float min = xb->dmin; + const float dl = d * sc[0]; + const float ml = min * sc[1]; + + const ushort mask = il < 2 ? 0x0F : 0xF0; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * (q[i] & mask) - ml; + } +} + +template +void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) { + device const uint8_t * q = xb->qs; + device const uint8_t * qh = xb->qh; + + short is = (il/4) * 2; + q = q + 32 * (il/4) + 16 * (il&1); + qh = qh + 16 * (il&1); + uint8_t ul = 1 << (il/2); + il = il & 3; + const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); + const float d = il < 2 ? xb->d : xb->d / 16.f; + const float min = xb->dmin; + const float dl = d * sc[0]; + const float ml = min * sc[1]; + + const ushort mask = il<2 ? 0x0F : 0xF0; + const float qh_val = il<2 ? 16.f : 256.f; + for (int i = 0; i < 16; ++i) { + reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; + } +} + +template +void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { + const half d_all = xb->d; + device const uint8_t * ql = (device const uint8_t *)xb->ql; + device const uint8_t * qh = (device const uint8_t *)xb->qh; + device const int8_t * scales = (device const int8_t *)xb->scales; + + ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); + qh = qh + 32*(il/8) + 16*(il&1); + float sc = scales[(il%2) + 2 * ((il/2))]; + il = (il/2) & 3; + + const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; + const float coef = il>1 ? 1.f/16.f : 1.f; + const float ml = d_all * sc * 32.f; + const float dl = d_all * sc * coef; + for (int i = 0; i < 16; ++i) { + const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2)) + : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4)); + reg[i/4][i%4] = dl * q - ml; + } +} + +template +void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's. + device const uint16_t * q2 = xb->qs + 4*ib32; + const uint32_t aux32_g = q2[0] | (q2[1] << 16); + const uint32_t aux32_s = q2[2] | (q2[3] << 16); + thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g; + const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f; + constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]); + uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127]; + for (int i = 0; i < 8; ++i) { + reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } + grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]); + signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127]; + for (int i = 0; i < 8; ++i) { + reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } +} + +template +void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint16_t * q2 = xb->qs + 4*ib32; + const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; + constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511)); + uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9]; + for (int i = 0; i < 8; ++i) { + reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } + grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511)); + signs = ksigns_iq2xs[q2[2*il+1] >> 9]; + for (int i = 0; i < 8; ++i) { + reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); + } +} + +template +void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * q3 = xb->qs + 8*ib32; + device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32; + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f; + constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]); + constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]); + uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127]; + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); + reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); + } + grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]); + grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]); + signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127]; + for (int i = 0; i < 4; ++i) { + reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); + reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); + } +} + +template +void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * qs = xb->qs + 8*ib32; + device const uint8_t * signs = xb->signs + 4*ib32 + 2*il; + const uint8_t qh = xb->qh[ib32] >> 4*il; + const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); + constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256))); + constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); + reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); + } + grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256))); + grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256))); + for (int i = 0; i < 4; ++i) { + reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); + reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); + } +} + +template +void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const float d = xb->d; + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint8_t * signs = qs + QK_K/8; + const uint8_t qh = xb->qh[ib32] >> 4*il; + const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; + constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300))); + constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300))); + for (int i = 0; i < 8; ++i) { + reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]); + reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]); + } +} + +template +void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + const float d = xb->d; + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint16_t * qh = xb->qh; + const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1); + const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA); + const uint16_t h = qh[ib32] >> 6*il; + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * (grid1[i] & 0xf) + ml; + reg[1][i] = dl * (grid1[i] >> 4) + ml; + reg[2][i] = dl * (grid2[i] & 0xf) + ml; + reg[3][i] = dl * (grid2[i] >> 4) + ml; + } +} + +template +void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + device const uint16_t * sc = (device const uint16_t *)xb->scales; + + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + const float d = scale.f16; + + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint8_t * qh = xb->qh + 2*ib32 + il; + + const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1); + const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * (grid1[i] & 0xf) + ml1; + reg[1][i] = dl * (grid1[i] >> 4) + ml1; + reg[2][i] = dl * (grid2[i] & 0xf) + ml2; + reg[3][i] = dl * (grid2[i] >> 4) + ml2; + } +} + +template +void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) { + device const uint16_t * q4 = (device const uint16_t *)xb->qs; + const float d = xb->d; + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f; + reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; + reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; + reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; + reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; + } +} + +template +void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) { + device const uint16_t * q4 = (device const uint16_t *)xb->qs; + const float d = xb->d; + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f; + reg[0] = d * kvalues_iq4nl_f[q8[0]]; + reg[1] = d * kvalues_iq4nl_f[q8[1]]; + reg[2] = d * kvalues_iq4nl_f[q8[2]]; + reg[3] = d * kvalues_iq4nl_f[q8[3]]; +} + +template +void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) { + // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; + // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 + device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32; + const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4); + const float d = (float)xb->d * (ls - 32); + uint32_t aux32; + thread const uint8_t * q8 = (thread const uint8_t *)&aux32; + for (int i = 0; i < 4; ++i) { + aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f; + reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; + reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; + reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; + reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; + } +} + +enum ggml_sort_order { + GGML_SORT_ORDER_ASC, + GGML_SORT_ORDER_DESC, +}; + +// general-purpose kernel for addition, subtraction, multiplication and division of two tensors +// pros: works for non-contiguous tensors, supports broadcast across all dims +// cons: not very efficient +kernel void kernel_add( + constant ggml_metal_kargs_bin & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; + + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; + + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10)); + } +} + +kernel void kernel_sub( + constant ggml_metal_kargs_bin & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; + + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; + + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10)); + } +} + +kernel void kernel_mul( + constant ggml_metal_kargs_bin & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; + + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; + + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); + } +} + +kernel void kernel_div( + constant ggml_metal_kargs_bin & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; + + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; + + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); + } +} + +template +kernel void kernel_repeat( + constant ggml_metal_kargs_repeat & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; + + const int i03 = i3%args.ne03; + const int i02 = i2%args.ne02; + const int i01 = i1%args.ne01; + + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i00 = i0%args.ne00; + *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00)); + } +} + +typedef decltype(kernel_repeat) kernel_repeat_t; + +template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat; + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_add_row( + constant ggml_metal_kargs_bin & args, + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; + dst[tpig] = src0[tpig] + src1[tpig % nb]; +} + +kernel void kernel_sub_row( + constant ggml_metal_kargs_bin & args, + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; + dst[tpig] = src0[tpig] - src1[tpig % nb]; +} + +kernel void kernel_mul_row( + constant ggml_metal_kargs_bin & args, + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; + dst[tpig] = src0[tpig] * src1[tpig % nb]; +} + +kernel void kernel_div_row( + constant ggml_metal_kargs_bin & args, + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; + dst[tpig] = src0[tpig] / src1[tpig % nb]; +} + +kernel void kernel_scale( + device const float * src0, + device float * dst, + constant float & scale, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * scale; +} + +kernel void kernel_scale_4( + device const float4 * src0, + device float4 * dst, + constant float & scale, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * scale; +} + +kernel void kernel_clamp( + device const float * src0, + device float * dst, + constant float & min, + constant float & max, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]); +} + +kernel void kernel_relu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = max(0.0f, src0[tpig]); +} + +kernel void kernel_sigmoid( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); +} + +kernel void kernel_tanh( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = precise::tanh(x); +} + +constant float GELU_COEF_A = 0.044715f; +constant float GELU_QUICK_COEF = -1.702f; +constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +kernel void kernel_gelu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + // BEWARE !!! + // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! + // This was observed with Falcon 7B and 40B models + // + dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_quick( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_gelu_quick_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_silu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); +} + +kernel void kernel_silu_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + dst[tpig] = x / (1.0f + exp(-x)); +} + +kernel void kernel_elu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f); +} + +kernel void kernel_sqr( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src0[tpig]; +} + +kernel void kernel_sqrt( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sqrt(src0[tpig]); +} + +kernel void kernel_sin( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = sin(src0[tpig]); +} + +kernel void kernel_cos( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = cos(src0[tpig]); +} + +kernel void kernel_sum_rows( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tpig[[thread_position_in_grid]]) { + int64_t i3 = tpig.z; + int64_t i2 = tpig.y; + int64_t i1 = tpig.x; + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); + device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + + float row_sum = 0; + + for (int64_t i0 = 0; i0 < ne00; i0++) { + row_sum += src_row[i0]; + } + + dst_row[0] = row_sum; +} + +template +kernel void kernel_soft_max( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr; + device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + const int64_t h = i02; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float lmax = -INFINITY; + + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + + // find the max value in the block + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float lsum = 0.0f; + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); + lsum += exp_psrc0; + pdst[i00] = exp_psrc0; + } + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + pdst[i00] *= inv_sum; + } +} + +template +kernel void kernel_soft_max_4( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant float & scale, + constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); + + device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr; + device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + + float slope = 1.0f; + + if (max_bias > 0.0f) { + const int64_t h = i02; + + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float4 lmax4 = -INFINITY; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); + } + + const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); + + float max_val = simd_max(lmax); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = -INFINITY; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = max_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = buf[tiisg]; + max_val = simd_max(max_val); + } + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + + const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + + // This barrier fixes a failing test + // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335 + threadgroup_barrier(mem_flags::mem_none); + + float sum = simd_sum(lsum); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[tiisg]; + sum = simd_sum(sum); + } + + const float inv_sum = 1.0f/sum; + + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + pdst4[i00] *= inv_sum; + } +} + +typedef decltype(kernel_soft_max) kernel_soft_max_t; +typedef decltype(kernel_soft_max_4) kernel_soft_max_4_t; + +template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max; +template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; +template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4; + +kernel void kernel_diag_mask_inf( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int & n_past, + uint3 tpig[[thread_position_in_grid]]) { + const int64_t i02 = tpig[2]; + const int64_t i01 = tpig[1]; + const int64_t i00 = tpig[0]; + + if (i00 > n_past + i01) { + dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; + } else { + dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; + } +} + +kernel void kernel_diag_mask_inf_8( + device const float4 * src0, + device float4 * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int & n_past, + uint3 tpig[[thread_position_in_grid]]) { + + const int64_t i = 2*tpig[0]; + + dst[i+0] = src0[i+0]; + dst[i+1] = src0[i+1]; + int64_t i4 = 4*i; + const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; + const int64_t i01 = i4/(ne00); i4 -= i01*ne00; + const int64_t i00 = i4; + for (int k = 3; k >= 0; --k) { + if (i00 + 4 + k <= n_past + i01) { + break; + } + dst[i+1][k] = -INFINITY; + if (i00 + k > n_past + i01) { + dst[i][k] = -INFINITY; + } + } +} + +// ref: ggml.c:ggml_compute_forward_ssm_conv_f32 +// TODO: optimize +kernel void kernel_ssm_conv_f32( + device const void * src0, + device const void * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i2 = tgpig.y; + const int64_t i3 = tgpig.z; + + const int64_t nc = ne10; + //const int64_t ncs = ne00; + //const int64_t nr = ne01; + //const int64_t n_t = ne1; + //const int64_t n_s = ne2; + + device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02); + device const float * c = (device const float *) ((device const char *) src1 + ir*nb11); + device float * x = (device float *) ((device char *) dst + ir*nb0 + i2*nb1 + i3*nb2); + + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + sumf += s[i0] * c[i0]; + } + + x[0] = sumf; +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32 +// TODO: optimize +kernel void kernel_ssm_scan_f32( + device const void * src0, + device const void * src1, + device const void * src2, + device const void * src3, + device const void * src4, + device const void * src5, + device float * dst, + constant int64_t & d_state, + constant int64_t & d_inner, + constant int64_t & n_seq_tokens, + constant int64_t & n_seqs, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb20, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb30, + constant uint64_t & nb31, + constant uint64_t & nb40, + constant uint64_t & nb41, + constant uint64_t & nb42, + constant uint64_t & nb50, + constant uint64_t & nb51, + constant uint64_t & nb52, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i3 = tgpig.y; + + const int64_t nc = d_state; + //const int64_t nr = d_inner; + const int64_t n_t = n_seq_tokens; + //const int64_t n_s = n_seqs; + + for (int64_t i2 = 0; i2 < n_t; ++i2) { + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02); + device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12); + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); + device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); + device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42); + device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52); + device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides + device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13); + + if (i2 > 0) { + s0 = s; + } + + // i1 == 0 + float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; + float x_dt = x[0] * dt_soft_plus; + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + int64_t i = i0; + float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt); + sumf += state * C[i0]; + s[i] = state; + } + + y[0] = sumf; + } +} + +kernel void kernel_argmax( + device const void * x, + device int32_t * dst, + constant int64_t & ncols, + constant uint64_t & nb01, + threadgroup float * shared_maxval [[threadgroup(0)]], + threadgroup int32_t * shared_argmax [[threadgroup(1)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + device const float * x_row = (device const float *) ((device const char *) x + tgpig * nb01); + + float lmax = -INFINITY; + int32_t larg = -1; + + for (int i00 = tpitg; i00 < ncols; i00 += ntg) { + if (x_row[i00] > lmax) { + lmax = x_row[i00]; + larg = i00; + } + } + + // find the argmax value in the block + float max_val = simd_max(lmax); + int32_t arg_val = simd_max(select(-1, larg, lmax == max_val)); + + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + shared_maxval[tiisg] = -INFINITY; + shared_argmax[tiisg] = -1; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shared_maxval[sgitg] = max_val; + shared_argmax[sgitg] = arg_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_val = shared_maxval[tiisg]; + arg_val = shared_argmax[tiisg]; + + float max_val_reduced = simd_max(max_val); + int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced)); + + dst[tgpig] = arg_val_reduced; + + return; + } + + dst[tgpig] = arg_val; +} + +kernel void kernel_norm( + constant ggml_metal_kargs_norm & args, + device const char * src0, + device char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort ntg[[threads_per_threadgroup]]) { + if (sgitg == 0) { + shmem_f32[tiisg] = 0.0f; + } + + device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + + float4 sumf4(0.0f); + + float sumf = 0.0f; + + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + sumf4 += x[i00]; + } + sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3]; + sumf = simd_sum(sumf); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float mean = sumf/args.ne00; + + device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + + sumf = 0.0f; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + y[i00] = x[i00] - mean; + sumf += dot(y[i00], y[i00]); + } + sumf = simd_sum(sumf); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float variance = sumf/args.ne00; + + const float scale = 1.0f/sqrt(variance + args.eps); + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + y[i00] = y[i00] * scale; + } +} + +kernel void kernel_rms_norm( + constant ggml_metal_kargs_rms_norm & args, + device const char * src0, + device char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort ntg[[threads_per_threadgroup]]) { + if (sgitg == 0) { + shmem_f32[tiisg] = 0.0f; + } + + device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + + float sumf = 0.0f; + + // parallel sum + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + sumf += dot(x[i00], x[i00]); + } + sumf = simd_sum(sumf); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float mean = sumf/args.ne00; + const float scale = 1.0f/sqrt(mean + args.eps); + + device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + y[i00] = x[i00] * scale; + } +} + +kernel void kernel_group_norm( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int32_t & n_groups, + constant float & eps, + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t ne = ne00*ne01*ne02; + const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups); + + int start = tgpig * gs; + int end = start + gs; + + start += tpitg; + + if (end >= ne) { + end = ne; + } + + float tmp = 0.0f; // partial sum for thread in warp + + for (int j = start; j < end; j += ntg) { + tmp += src0[j]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + tmp = simd_sum(tmp); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = tmp; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + tmp = buf[tiisg]; + tmp = simd_sum(tmp); + } + + const float mean = tmp / gs; + tmp = 0.0f; + + for (int j = start; j < end; j += ntg) { + float xi = src0[j] - mean; + dst[j] = xi; + tmp += xi * xi; + } + + tmp = simd_sum(tmp); + if (ntg > N_SIMDWIDTH) { + if (sgitg == 0) { + buf[tiisg] = 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + buf[sgitg] = tmp; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + tmp = buf[tiisg]; + tmp = simd_sum(tmp); + } + + const float variance = tmp / gs; + const float scale = 1.0f/sqrt(variance + eps); + for (int j = start; j < end; j += ntg) { + dst[j] *= scale; + } +} + +// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; + + device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2); + + for (int i = 0; i < 8; i += 2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0); + acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000); + } + + return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]); +} + +// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + float m = qb_curr->m; + + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; + + device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0); + acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000); + } + + return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; +} + +// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)); + acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)); + acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + + return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]); +} + +// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_1/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + float m = qb_curr->m; + + float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)); + acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)); + acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + + return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; +} + +// putting them in the kernel cause a significant performance penalty +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 2 // number of SIMD groups in a thread group +//Note: This is a template, but strictly speaking it only applies to +// quantizations where the block size is 32. It also does not +// guard against the number of rows not being divisible by +// N_DST, so this is another explicit assumption of the implementation. +template +void mul_vec_q_n_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const int nb = args.ne00/QK4_0; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + // pointers to src0 rows + device const block_q_type * ax[nr]; + for (int row = 0; row < nr; ++row) { + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); + } + + float yl[16]; // src1 vector cache + float sumf[nr] = {0.f}; + + const short ix = (tiisg/2); + const short il = (tiisg%2)*8; + + device const float * yb = y + ix*QK4_0 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += nw/2) { + float sumy[2] = { 0.f, 0.f }; + +#pragma unroll + for (int i = 0; i < 8; i += 2) { + sumy[0] += yb[i + 0] + yb[i + 1]; + yl[i + 0] = yb[i + 0]; + yl[i + 1] = yb[i + 1]/256.f; + + sumy[1] += yb[i + 16] + yb[i + 17]; + yl[i + 8] = yb[i + 16]/16.f; + yl[i + 9] = yb[i + 17]/4096.f; + } + +#pragma unroll + for (int row = 0; row < nr; row++) { + sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); + } + + yb += QK4_0 * 16; + } + + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + + for (int row = 0; row < nr; ++row) { + const float tot = simd_sum(sumf[row]); + + if (tiisg == 0 && first_row + row < args.ne01) { + dst_f32[first_row + row] = tot; + } + } +} + +kernel void kernel_mul_mv_q4_0_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +kernel void kernel_mul_mv_q4_1_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +kernel void kernel_mul_mv_q5_0_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +kernel void kernel_mul_mv_q5_1_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +#define NB_Q8_0 8 + +template +void kernel_mul_mv_q8_0_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const int nr = N_DST; + const int nsg = N_SIMDGROUP; + const int nw = N_SIMDWIDTH; + + const int nb = args.ne00/QK8_0; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0*nsg + sgitg)*nr; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + // pointers to src0 rows + device const block_q8_0 * ax[nr]; + for (int row = 0; row < nr; ++row) { + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); + } + + float yl[NB_Q8_0]; + float sumf[nr] = { 0.f }; + + const short ix = tiisg/4; + const short il = tiisg%4; + + device const float * yb = y + ix*QK8_0 + il*NB_Q8_0; + + // each thread in a SIMD group deals with NB_Q8_0 quants at a time + for (int ib = ix; ib < nb; ib += nw/4) { + for (short i = 0; i < NB_Q8_0; ++i) { + yl[i] = yb[i]; + } + + for (int row = 0; row < nr; row++) { + device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; + float sumq = 0.f; + for (short iq = 0; iq < NB_Q8_0; ++iq) { + sumq += qs[iq] * yl[iq]; + } + sumf[row] += sumq*ax[row][ib].d; + } + + yb += nw*NB_Q8_0; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr; ++row) { + const float tot = simd_sum(sumf[row]); + + if (tiisg == 0 && first_row + row < args.ne01) { + dst_f32[first_row + row] = tot; + } + } +} + +[[host_name("kernel_mul_mv_q8_0_f32")]] +kernel void kernel_mul_mv_q8_0_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +// mat-vec kernel processing in chunks of float4 +// chpb - chunks per quantization block +template +void kernel_mul_mv_ext_q4_f32_impl( + constant ggml_metal_kargs_mul_mv_ext & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short chpt = 4; // chunks per thread + + //const short nxpsg = (32); + const short nypsg = (32/nxpsg); + + const short tx = tiisg%nxpsg; + const short ty = tiisg/nxpsg; + + const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty; + const int i11 = tgpig.y*r1ptg; + const int i1m = tgpig.z; + + const int i12 = i1m%args.ne12; + const int i13 = i1m/args.ne12; + + const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; + + device const float4 * y4[r1ptg]; + + for (int ir1 = 0; ir1 < r1ptg; ++ir1) { + y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1; + } + + float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f }; + + short cch = tx%chpb; // current chunk index + + for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) { + float4 lx[chpt]; + +#pragma unroll(chpt) + for (short ch = 0; ch < chpt; ++ch) { + deq_t4(xq, cch, lx[ch]); + + cch += nxpsg; + if (cch >= chpb) { + xq += cch/chpb; + cch %= chpb; + } + } + +#pragma unroll(chpt) + for (short ch = 0; ch < chpt; ++ch) { +#pragma unroll(r1ptg) + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]); + + } + } + +#pragma unroll(r1ptg) + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + y4[ir1] += chpt*nxpsg; + } + } + + // reduce only the threads in each row + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + if (nxpsg >= 32) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 16); + } + if (nxpsg >= 16) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 8); + } + if (nxpsg >= 8) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 4); + } + if (nxpsg >= 4) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 2); + } + if (nxpsg >= 2) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 1); + } + + //sumf[ir1] = simd_sum(sumf[ir1]); + } + + if (tx == 0) { + for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) { + device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0; + + if (i01 < args.ne01) { + dst_f32[i01] = sumf[ir1]; + } + } + } +} + +// mat-vec kernel processing in chunks of float4x4 +template +void kernel_mul_mv_ext_q4x4_f32_impl( + constant ggml_metal_kargs_mul_mv_ext & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short chpt = 1; + + //const short nxpsg = (32); + const short nypsg = (32/nxpsg); + + const short tx = tiisg%nxpsg; + const short ty = tiisg/nxpsg; + + const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty; + const int i11 = tgpig.y*r1ptg; + const int i1m = tgpig.z; + + const int i12 = i1m%args.ne12; + const int i13 = i1m/args.ne12; + + const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; + + device const float4x4 * y4x4[r1ptg]; + + for (int ir1 = 0; ir1 < r1ptg; ++ir1) { + y4x4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4x4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4x4 *) src1; + } + + float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f }; + + short cch = tx%chpb; + + for (int ich = tx; 16*ich < args.ne00; ich += chpt*nxpsg) { + float4x4 lx[chpt]; + +#pragma unroll(chpt) + for (short ch = 0; ch < chpt; ++ch) { + deq_t4x4(xq, cch, lx[ch]); + + cch += nxpsg; + if (cch >= chpb) { + xq += cch/chpb; + cch %= chpb; + } + } + +#pragma unroll(chpt) + for (short ch = 0; ch < chpt; ++ch) { +#pragma unroll(r1ptg) + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + sumf[ir1] += + dot(lx[ch][0], y4x4[ir1][ch*nxpsg][0]) + + dot(lx[ch][1], y4x4[ir1][ch*nxpsg][1]) + + dot(lx[ch][2], y4x4[ir1][ch*nxpsg][2]) + + dot(lx[ch][3], y4x4[ir1][ch*nxpsg][3]); + + } + } + +#pragma unroll(r1ptg) + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + y4x4[ir1] += chpt*nxpsg; + } + } + + for (short ir1 = 0; ir1 < r1ptg; ++ir1) { + if (nxpsg >= 32) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 16); + } + if (nxpsg >= 16) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 8); + } + if (nxpsg >= 8) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 4); + } + if (nxpsg >= 4) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 2); + } + if (nxpsg >= 2) { + sumf[ir1] += simd_shuffle_down(sumf[ir1], 1); + } + + //sumf[ir1] = simd_sum(sumf[ir1]); + } + + if (tx == 0) { + for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) { + device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0; + + if (i01 < args.ne01) { + dst_f32[i01] = sumf[ir1]; + } + } + } +} + +// dispatchers needed for compile-time nxpsg +// epb - elements per quantization block +template +kernel void kernel_mul_mv_ext_q4_f32_disp( + constant ggml_metal_kargs_mul_mv_ext & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + switch (args.nxpsg) { + case 4: kernel_mul_mv_ext_q4_f32_impl<4, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 8: kernel_mul_mv_ext_q4_f32_impl<8, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + } +} + +template +kernel void kernel_mul_mv_ext_q4x4_f32_disp( + constant ggml_metal_kargs_mul_mv_ext & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + switch (args.nxpsg) { + case 4: kernel_mul_mv_ext_q4x4_f32_impl<4, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 8: kernel_mul_mv_ext_q4x4_f32_impl<8, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; + } +} + +typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t; +typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t; + +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>; +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>; +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>; +template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>; + +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>; +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>; +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>; +template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_0, 32, dequantize_q4_0_t4>; + +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_1, 32, dequantize_q4_1_t4>; +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_1, 32, dequantize_q4_1_t4>; +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_1, 32, dequantize_q4_1_t4>; +template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_1, 32, dequantize_q4_1_t4>; + +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_0, 32, dequantize_q5_0_t4>; +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_0, 32, dequantize_q5_0_t4>; +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_0, 32, dequantize_q5_0_t4>; +template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_0, 32, dequantize_q5_0_t4>; + +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_1, 32, dequantize_q5_1_t4>; +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_1, 32, dequantize_q5_1_t4>; +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_1, 32, dequantize_q5_1_t4>; +template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_1, 32, dequantize_q5_1_t4>; + +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q8_0, 32, dequantize_q8_0_t4>; +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0, 32, dequantize_q8_0_t4>; +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, dequantize_q8_0_t4>; +template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 32, dequantize_q8_0_t4>; + +template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>; +template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>; +template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>; +template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_iq4_nl, 32, dequantize_iq4_nl_t4>; + +template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>; +template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q4_K, 256, dequantize_q4_K>; +template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q4_K, 256, dequantize_q4_K>; +template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q4_K, 256, dequantize_q4_K>; + +template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_K, 256, dequantize_q5_K>; +template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q5_K, 256, dequantize_q5_K>; +template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_K, 256, dequantize_q5_K>; +template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_K, 256, dequantize_q5_K>; + +template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_K, 256, dequantize_q6_K>; +template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_K, 256, dequantize_q6_K>; +template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>; +template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>; + +#define N_MV_T_T 4 + +template +void kernel_mul_mv_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig, + ushort tiisg) { + const int r0 = tgpig.x; + const int rb = tgpig.y*N_MV_T_T; + const int im = tgpig.z; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + device const T0 * x = (device const T0 *) (src0 + offset0); + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; + + if (args.ne00 < 128) { + for (int row = 0; row < N_MV_T_T; ++row) { + int r1 = rb + row; + if (r1 >= args.ne11) { + break; + } + + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const T1 * y = (device const T1 *) (src1 + offset1); + + float sumf = 0; + for (int i = tiisg; i < args.ne00; i += 32) { + sumf += (T0) x[i] * (T1) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; + } + } + } else { + device const T04 * x4 = (device const T04 *) x; + for (int row = 0; row < N_MV_T_T; ++row) { + int r1 = rb + row; + if (r1 >= args.ne11) { + break; + } + + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const T1 * y = (device const T1 *) (src1 + offset1); + device const T14 * y4 = (device const T14 *) y; + + float sumf = 0; + for (int i = tiisg; i < args.ne00/4; i += 32) { + sumf += dot((float4) x4[i], (float4) y4[i]); + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; + } + } + } +} + +template +kernel void kernel_mul_mv( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_impl( + args, + src0, + src1, + dst, + tgpig, + tiisg); +} + +typedef decltype(kernel_mul_mv) mul_mv_t; + +template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv; +template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv; +template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv; +template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv; +#endif + +template +kernel void kernel_mul_mv_1row( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const T * x = (device const T *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + float sumf = 0; + if (args.ne00 < 128) { + for (int i = tiisg; i < args.ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst_f32[r0] = all_sum; + } + } else { + device const T4 * x4 = (device const T4 *) x; + device const float4 * y4 = (device const float4 *) y; + + for (int i = tiisg; i < args.ne00/4; i += 32) { + sumf += dot((float4) x4[i], y4[i]); + } + + float all_sum = simd_sum(sumf); + + if (tiisg == 0) { + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); + dst_f32[r0] = all_sum; + } + } +} + +typedef decltype(kernel_mul_mv_1row) mul_mv_1row_t; + +template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row; +#endif + +// Assumes row size (ne00) is a multiple of 4 +template +kernel void kernel_mul_mv_l4( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { + + const int nrows = args.ne11; + const int r0 = tgpig.x; + const int im = tgpig.z; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + device const T4 * x4 = (device const T4 *) (src0 + offset0); + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; + + for (int r1 = 0; r1 < nrows; ++r1) { + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const float4 * y4 = (device const float4 *) (src1 + offset1); + + float sumf = 0; + for (int i = tiisg; i < args.ne00/4; i += 32) { + sumf += dot((float4) x4[i], y4[i]); + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; + } + } +} + +typedef decltype(kernel_mul_mv_l4) mul_mv_l4_t; + +template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4; +#endif + +static float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +static void rope_yarn( + float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale, + thread float * cos_theta, thread float * sin_theta) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); + } + *cos_theta = cos(theta) * mscale; + *sin_theta = sin(theta) * mscale; +} + +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base)); +} + +static void rope_yarn_corr_dims( + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2] +) { + // start and end correction dims + dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))); + dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))); +} + +template +kernel void kernel_rope_norm( + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); + + device const int32_t * pos = (device const int32_t *) src1; + + const float theta_base = (float) pos[i2]; + const float inv_ndims = -1.f/args.n_dims; + + float cos_theta; + float sin_theta; + + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; + + const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + const float x0 = src[0]; + const float x1 = src[1]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; + } else { + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +template +kernel void kernel_rope_neox( + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); + + device const int32_t * pos = (device const int32_t *) src1; + + const float theta_base = (float) pos[i2]; + const float inv_ndims = -1.f/args.n_dims; + + float cos_theta; + float sin_theta; + + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; + + const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); + + const float x0 = src[0]; + const float x1 = src[args.n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +typedef decltype(kernel_rope_norm) kernel_rope_norm_t; +typedef decltype(kernel_rope_neox) kernel_rope_neox_t; + +template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm; +template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm; + +template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox; +template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox; + +typedef void (im2col_t)( + device const float * x, + device char * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template +kernel void kernel_im2col( + device const float * x, + device char * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { +// const int64_t IC = tgpg[0]; + const int64_t OH = tgpg[1]; + const int64_t OW = tgpg[2]; + +// const int64_t N = ntg[0]; + const int64_t KH = ntg[1]; + const int64_t KW = ntg[2]; + + const int64_t in = tpitg[0]; + const int64_t ikh = tpitg[1]; + const int64_t ikw = tpitg[2]; + + const int64_t iic = tgpig[0]; + const int64_t ioh = tgpig[1]; + const int64_t iow = tgpig[2]; + + const int64_t iiw = iow*s0 + ikw*d0 - p0; + const int64_t iih = ioh*s1 + ikh*d1 - p1; + + const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*(KH*KW) + ikh*KW + ikw); + + device T * pdst = (device T *) (dst); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + pdst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = in*ofs0 + iic*ofs1 + iih*IW + iiw; + pdst[offset_dst] = x[offset_src]; + } +} + +template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; +template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; + +typedef void (im2col_ext_t)( + device const float * x, + device char * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + constant int32_t & N, + constant int32_t & KH, + constant int32_t & KW, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template +kernel void kernel_im2col_ext( + device const float * x, + device char * dst, + constant int32_t & ofs0, + constant int32_t & ofs1, + constant int32_t & IW, + constant int32_t & IH, + constant int32_t & CHW, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int32_t & d0, + constant int32_t & d1, + constant int32_t & N, + constant int32_t & KH, + constant int32_t & KW, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] + const int64_t KHW = KH * KW; // KHW == ntg[1] * ntg[2], KW == ntg[2] + + const int64_t d = tgpig[0] / CHW; + const int64_t chw = tgpig[0] % CHW; + const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) + const int64_t HW = tgpig[0] % KHW; + + const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; + if (tpitg_0 >= N) { + return; + } + + const int64_t tpitg_1 = HW / KW; + const int64_t tpitg_2 = HW % KW; + + const int64_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0; + const int64_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1; + + const int64_t offset_dst = + (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW + + (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2); + + device T * pdst = (device T *) (dst); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + pdst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1; + pdst[offset_dst] = x[offset_src + iih * IW + iiw]; + } +} + +template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext; +template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext; + +typedef void (conv_transpose_1d_t)( + device const float * src0, + device const float * src1, + device char * dst, + constant int32_t & IC, + constant int32_t & IL, + constant int32_t & K, + constant int32_t & s0, + constant uint64_t & nb0, + constant uint64_t & nb1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]); + +template +kernel void kernel_conv_transpose_1d( + device const T * src0, + device const float * src1, + device char * dst, + constant int32_t & IC, + constant int32_t & IL, + constant int32_t & K, + constant int32_t & s0, + constant uint64_t & nb0, + constant uint64_t & nb1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]) { + + float v = 0.0f; + + for (int64_t c = 0; c < IC; c++) { + const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1]; + const int32_t input_offset = c * IL; + + for (int64_t i = 0; i < IL; i++) { + if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) { + v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i]; + } + } + } + + device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1); + + dst_ptr[0] = v; +} + +template [[host_name("kernel_conv_transpose_1d_f32_f32")]] +kernel void kernel_conv_transpose_1d( + device const float * src0, + device const float * src1, + device char * dst, + constant int32_t & IC, + constant int32_t & IL, + constant int32_t & K, + constant int32_t & s0, + constant uint64_t & nb0, + constant uint64_t & nb1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]); + +template [[host_name("kernel_conv_transpose_1d_f16_f32")]] +kernel void kernel_conv_transpose_1d( + device const half * src0, + device const float * src1, + device char * dst, + constant int32_t & IC, + constant int32_t & IL, + constant int32_t & K, + constant int32_t & s0, + constant uint64_t & nb0, + constant uint64_t & nb1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]]); + +kernel void kernel_upscale_f32( + device const char * src0, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant float & sf0, + constant float & sf1, + constant float & sf2, + constant float & sf3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3/sf3; + const int64_t i02 = i2/sf2; + const int64_t i01 = i1/sf1; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int64_t i00 = i0/sf0; + + device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_ptr[0] = src0_ptr[0]; + } +} + +kernel void kernel_pad_f32( + device const char * src0, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3; + const int64_t i02 = i2; + const int64_t i01 = i1; + + device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); + device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); + + if (i1 < ne01 && i2 < ne02 && i3 < ne03) { + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + if (i0 < ne00) { + dst_ptr[i0] = src0_ptr[i0]; + } else { + dst_ptr[i0] = 0.0f; + } + } + + return; + } + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + dst_ptr[i0] = 0.0f; + } +} + +kernel void kernel_pad_reflect_1d_f32( + device const char * src0, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant int64_t & ne0, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int32_t & p0, + constant int32_t & p1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3; + const int64_t i02 = i2; + const int64_t i01 = i1; + + device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); + device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1); + + if (i1 < ne01 && i2 < ne02 && i3 < ne03) { + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + if (i0 < p0) { + dst_ptr[i0] = src0_ptr[p0 - i0]; + } else if (i0 < ne0 - p1) { + dst_ptr[i0] = src0_ptr[i0 - p0]; + } else { + dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1]; + } + } + } +} + +kernel void kernel_arange_f32( + device char * dst, + constant int64_t & ne0, + constant float & start, + constant float & step, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + device float * dst_ptr = (device float *) dst; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + dst_ptr[i0] = start + step * i0; + } +} + +kernel void kernel_timestep_embedding_f32( + device const char * src0, + device char * dst, + constant uint64_t & nb1, + constant int & dim, + constant int & max_period, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + int i = tgpig.x; + device float * embed_data = (device float *)(dst + i*nb1); + + int half_ = dim / 2; + for (int j = tpitg.x; j < half_; j += ntg.x) { + float timestep = ((device float *)src0)[i]; + float freq = (float)exp(-log((float)max_period) * j / half_); + float arg = timestep * freq; + embed_data[j ] = cos(arg); + embed_data[j + half_] = sin(arg); + } + + if (dim % 2 != 0 && tpitg.x == 0) { + embed_data[dim] = 0.f; + } +} + +// bitonic sort implementation following the CUDA kernels as reference +typedef void (argsort_t)( + device const float * x, + device int32_t * dst, + constant int64_t & ncols, + constant int64_t & ncols_pad, + threadgroup int32_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]); + +template +kernel void kernel_argsort_f32_i32( + device const float * x, + device int32_t * dst, + constant int64_t & ncols, + constant int64_t & ncols_pad, + threadgroup int32_t * shared_values [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]) { + // bitonic sort + int col = tpitg[0]; + int row = tgpig[1]; + + if (col >= ncols_pad) return; + + device const float * x_row = x + row * ncols; + threadgroup int32_t * dst_row = shared_values; + + // initialize indices + dst_row[col] = col; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int k = 2; k <= ncols_pad; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= ncols || + (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj]); + } + } else { + if (dst_row[ixj] >= ncols || + (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj]); + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } + + // copy the result to dst without the padding + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } +} + +template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; + +kernel void kernel_leaky_relu_f32( + device const float * src0, + device float * dst, + constant float & slope, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; +} + +// ref: https://arxiv.org/pdf/2307.08691.pdf +template< + typename q_t, // query types in shared memory + typename q4_t, + typename q8x8_t, + typename k_t, // key types in shared memory + typename k4x4_t, + typename k8x8_t, + typename v_t, // value types in shared memory + typename v4x4_t, + typename v8x8_t, + typename qk_t, // Q*K types + typename qk8x8_t, + typename s_t, // soft-max types + typename s8x8_t, + typename o_t, // attention accumulation types + typename o4_t, + typename o8x8_t, + typename kd4x4_t, // key type in device memory + short nl_k, + void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), + typename vd4x4_t, // key type in device memory + short nl_v, + void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), + short D, // head size + short Q = 8, // queries per threadgroup + short KV = 8, // key/value processed per each simdgroup + short C = 32> // cache items per threadgroup +kernel void kernel_flash_attn_ext( + constant ggml_metal_kargs_flash_attn_ext & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const int iq3 = tgpig[2]; + const int iq2 = tgpig[1]; + const int iq1 = tgpig[0]*Q; + + const short D4 = D/4; + const short D8 = D/8; + const short D16 = D/16; + const short NW = N_SIMDWIDTH; + const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) + + const short TS = nsg*SH; // shared memory size per query in (s_t == float) + const short T = D + 2*TS; // shared memory size per query in (half) + + threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t + threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*D); // reuse query data for accumulation + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*D); // same as above but in o4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix + + threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory + threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t + + threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory + threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + o8x8_t lo[D8]; + + // load heads from Q to shared memory + for (short j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 + j < args.ne01) { + sq4[j*D4 + i] = (q4_t) q4[i]; + } else { + sq4[j*D4 + i] = (q4_t) 0.0f; + } + } + } + + // zero out lo + for (short i = 0; i < D8; ++i) { + lo[i] = make_filled_simdgroup_matrix((o_t) 0.0f); + } + + // zero out shared memory SH + for (short j = 0; j < Q; ++j) { + for (short i = tiisg; i < SH; i += NW) { + ss[j*TS + i] = 0.0f; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + half S[Q] = { [0 ... Q-1] = 0.0f }; + half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 }; + + // thread indices inside the simdgroup + // TODO: see if we can utilize quad-group functions for better performance + // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3) + const short tx = tiisg%4; + const short ty = tiisg/4; + + // broadcast kv + //const short rk2 = args.ne02/args.ne12; + //const short rk3 = args.ne03/args.ne13; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + // load the queries from shared memory into local memory + q8x8_t mq[D8]; + + for (short i = 0; i < D8; ++i) { + simdgroup_load(mq[i], sq + i*8, D); + } + + const bool has_mask = mask != q; + + half slope = 1.0f; + + // ALiBi + if (args.max_bias > 0.0f) { + const short h = iq2; + + const half base = h < args.n_head_log2 ? args.m0 : args.m1; + const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; + + slope = pow(base, exph); + } + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= args.ne11) { + break; + } + + if (has_mask) { + // used to detect blocks full of -INF + half smax = -INFINITY; + + // load the mask in shared memory + #pragma unroll(Q) + for (short j = 0; j < Q; ++j) { + device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31); + + const half m = pm[ic + tiisg]; + + ss[j*TS + C + tiisg] = m; + smax = max(smax, m); + } + + smax = simd_max(smax); + + if (smax == -INFINITY) { + continue; + } + } + + // Q*K^T + { + for (short cc = 0; cc < C/8; ++cc) { + qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); + + // this is compile-time check, so it does not have runtime overhead + if (is_same::value) { + // we can read directly from global memory + device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); + + #pragma unroll(D8) + for (short i = 0; i < D8; ++i) { + k8x8_t mk; + simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10 + + simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); + } + } else { + for (short ii = 0; ii < D16; ii += 4) { + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); + + if (D16%4 == 0) { + // the head is evenly divisible by 4*16 = 64, so no need for bound checks + { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(4) + for (short k = 0; k < 4; ++k) { + k8x8_t mk; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk); + } + } else { + if (ii + tx < D16) { + k4x4_t tmp; + deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); + sk4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + for (short k = 0; k < 4 && ii + k < D16; ++k) { + k8x8_t mk; + + simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose + simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk); + + simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose + simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk); + } + } + } + } + + // cast qk_t -> s_t + //s8x8_t mqks(1.0f); + //simdgroup_multiply(mqks, mqk, mqks); + //simdgroup_store(mqks, ss + 8*cc, TS, 0, false); + + simdgroup_store(mqk, ss + 8*cc, TS, 0, false); + } + } + + // online softmax + { + for (ushort j = 0; j < Q; ++j) { + const half m = M[j]; + + // scale and apply the logitcap / mask + half s = ss[j*TS + tiisg]*args.scale; + + if (args.logit_softcap != 0.0f) { + s = args.logit_softcap*precise::tanh(s); + } + + // mqk = mqk + mask*slope + s += slope*ss[j*TS + C + tiisg]; + + M[j] = simd_max(max(M[j], s)); + + const half ms = exp(m - M[j]); + const half vs = exp(s - M[j]); + + S[j] = S[j]*ms + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*TS + tiisg] = vs; + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg == j) { + ss[j*TS + 2*C + j] = ms; + } + } + } + + // O = diag(ms)*O + { + s8x8_t mm; + simdgroup_load(mm, ss + 2*C, TS, 0, false); + + #pragma unroll(D8) + for (short i = 0; i < D8; ++i) { + simdgroup_multiply(lo[i], mm, lo[i]); + } + } + + // O = O + (Q*K^T)*V + { + for (short cc = 0; cc < C/8; ++cc) { + s8x8_t ms; + simdgroup_load(ms, ss + 8*cc, TS, 0, false); + + if (is_same::value) { + // we can read directly from global memory + device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); + + #pragma unroll(D8) + for (short i = 0; i < D8; ++i) { + v8x8_t mv; + simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20 + + simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]); + } + } else { + for (short ii = 0; ii < D16; ii += 4) { + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); + + if (D16%4 == 0) { + // no need for bound checks + { + v4x4_t tmp; + deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); + sv4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(4) + for (short k = 0; k < 4; ++k) { + v8x8_t mv; + + simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]); + + simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); + } + } else { + if (ii + tx < D16) { + v4x4_t tmp; + deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); + sv4x4[4*ty + tx] = tmp; + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + for (short k = 0; k < 4 && ii + k < D16; ++k) { + v8x8_t mv; + + simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]); + + simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); + } + } + } + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + for (short j = 0; j < Q; ++j) { + if (tiisg == 0) { + ss[j*TS + 0] = S[j]; + ss[j*TS + 1] = M[j]; + } + } + } + + // reduce the warps sequentially + for (ushort sg = 1; sg < nsg; ++sg) { + half S = { 0.0f }; + half M = { -__FLT16_MAX__/2 }; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // each simdgroup stores its output to shared memory, reusing sq + if (sgitg == sg) { + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], so + i*8, D, 0, false); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // the first simdgroup accumulates the results from the other simdgroups + if (sgitg == 0) { + for (short j = 0; j < Q; ++j) { + const half S0 = ss[j*TS + 0]; + const half S1 = ss[j*TS + sg*SH + 0]; + + const half M0 = ss[j*TS + 1]; + const half M1 = ss[j*TS + sg*SH + 1]; + + M = max(M0, M1); + + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[j*TS + 0] = S; + ss[j*TS + 1] = M; + + ss[j*TS + 2*C + j ] = ms0; + ss[j*TS + 2*C + j + sg*SH] = ms1; + } + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + { + s8x8_t ms0; + s8x8_t ms1; + + simdgroup_load(ms0, ss + 2*C, TS, 0, false); + simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false); + + #pragma unroll(D8) + for (short i = 0; i < D8; ++i) { + o8x8_t t; + + simdgroup_load (t, so + i*8, D, 0, false); + simdgroup_multiply(t, ms1, t); + + simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + } + } + } + } + + // store result to shared memory (reuse sq) + if (sgitg == 0) { + for (short i = 0; i < D8; ++i) { + simdgroup_store(lo[i], so + i*8, D, 0, false); + } + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) { + const float S = ss[j*TS + 0]; + + for (short i = tiisg; i < D4; i += NW) { + dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; + } + } + } +} + +// TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as +// template to be able to explore different combinations +// +#define FA_TYPES \ + half, half4, simdgroup_half8x8, \ + half, half4x4, simdgroup_half8x8, \ + half, half4x4, simdgroup_half8x8, \ + float, simdgroup_float8x8, \ + float, simdgroup_float8x8, \ + half, half4, simdgroup_half8x8 + +typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; + +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +#endif + +template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; + +#undef FA_TYPES + +template< + typename q4_t, // query types in shared memory + typename q4x4_t, + typename k4x4_t, // key types in shared memory + typename v4x4_t, // value types in shared memory + typename qk_t, // Q*K types + typename s_t, // soft-max types + typename s4_t, + typename s4x4_t, + typename o4x4_t, // attention accumulation types + typename kd4x4_t, // key type in device memory + short nl_k, + void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), + typename vd4x4_t, // key type in device memory + short nl_v, + void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), + short D, // head size + short Q = 1, // queries per threadgroup + short C = 32> // cache items per threadgroup +kernel void kernel_flash_attn_ext_vec( + constant ggml_metal_kargs_flash_attn_ext & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short nsg = ntg.y; // number of simdgroups + + const int iq3 = tgpig[2]; + const int iq2 = tgpig[1]; + const int iq1 = tgpig[0]; + + const short D4 = D/4; + const short D16 = D/16; + const short NW = N_SIMDWIDTH; + const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0 + const short SH = 2*C; // shared memory per simdgroup + + const short T = D + nsg*SH; // shared memory size per query in (half) + + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t + threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 + 0*D); // same as above but in q4x4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*D); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*D); // same as above but in s4_t + threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask + threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D + Q*T); // scratch buffer for the results + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + o4x4_t lo[D16/NL]; + + // load heads from Q to shared memory + device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); + + for (short i = tiisg; i < D4; i += NW) { + if (iq1 < args.ne01) { + sq4[i] = (q4_t) q4[i]; + } else { + sq4[i] = (q4_t) 0.0f; + } + } + + // zero out lo + for (short i = 0; i < D16/NL; ++i) { + lo[i] = (o4x4_t) 0.0f; + } + + // zero out shared memory SH + for (short i = tiisg; i < SH/4; i += NW) { + ss4[i] = (s4_t) 0.0f; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + half S = 0.0f; + half M = -__FLT16_MAX__/2; + + // thread indices inside the simdgroup + const short tx = tiisg%NL; + const short ty = tiisg/NL; + + // broadcast kv + //const short rk2 = args.ne02/args.ne12; + //const short rk3 = args.ne03/args.ne13; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + // load the queries from shared memory into local memory + q4x4_t mq[D16/NL]; + + #pragma unroll(D16/NL) + for (short ii = 0; ii < D16; ii += NL) { + mq[ii/NL] = sq4x4[ii + tx]; + } + + const bool has_mask = mask != q; + + // pointer to the mask + device const half * pm = (device const half *) (mask + iq1*args.nb31); + + half slope = 1.0f; + + // ALiBi + if (args.max_bias > 0.0f) { + const short h = iq2; + + const half base = h < args.n_head_log2 ? args.m0 : args.m1; + const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; + + slope = pow(base, exph); + } + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { + const int ic = ic0 + C*sgitg; + if (ic >= args.ne11) { + break; + } + + if (has_mask) { + sm[tiisg] = pm[ic + tiisg]; + } + + // Q*K^T + { + // each simdgroup processes 1 query and 4 (NW/NL) keys + for (short cc = 0; cc < C/4; ++cc) { + qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 }; + + device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); + + #pragma unroll(D16/NL) + for (short ii = 0; ii < D16; ii += NL) { + const short i = ii + tx; + + k4x4_t mk; + deq_k(pk + i/nl_k, i%nl_k, mk); + + // note: this is less precise than the version below + //mqka[0] += dot(mq[ii/NL][0], mk[0]); + //mqka[1] += dot(mq[ii/NL][1], mk[1]); + //mqka[2] += dot(mq[ii/NL][2], mk[2]); + //mqka[3] += dot(mq[ii/NL][3], mk[3]); + + mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]); + mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]); + mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]); + mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]); + } + + qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3]; + + // simdgroup reduce + // [ 0 .. 7] -> [ 0] + // [ 8 .. 15] -> [ 8] + // [16 .. 23] -> [16] + // [24 .. 31] -> [24] + //mqk += simd_shuffle_down(mqk, 16); + //mqk += simd_shuffle_down(mqk, 8); + mqk += simd_shuffle_down(mqk, 4); + mqk += simd_shuffle_down(mqk, 2); + mqk += simd_shuffle_down(mqk, 1); + + // mqk = mqk*scale + mask*slope + if (tx == 0) { + mqk *= args.scale; + + if (args.logit_softcap != 0.0f) { + mqk = args.logit_softcap*precise::tanh(mqk); + } + + mqk += sm[4*cc + ty]*slope; + + ss[4*cc + ty] = mqk; + } + } + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + // online softmax + { + const half m = M; + const half s = ss[tiisg]; + + M = simd_max(max(M, s)); + + const half ms = exp(m - M); + const half vs = exp(s - M); + + S = S*ms + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[tiisg] = vs; + + // O = diag(ms)*O + #pragma unroll(D16/NL) + for (short ii = 0; ii < D16; ii += NL) { + lo[ii/NL] *= ms; + } + } + + simdgroup_barrier(mem_flags::mem_threadgroup); + + // O = O + (Q*K^T)*V + { + for (short cc = 0; cc < C/4; ++cc) { + device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); + + const s4x4_t ms(ss[4*cc + ty]); + + #pragma unroll(D16/NL) + for (short ii = 0; ii < D16; ii += NL) { + const short i = ii + tx; + + v4x4_t mv; + deq_v(pv4 + i/nl_v, i%nl_v, mv); + + lo[ii/NL] += mv*ms; + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + if (tiisg == 0) { + ss[0] = (s_t) S; + ss[1] = (s_t) M; + } + } + + // simdgroup reduce + // [ 0, 8, 16, 24] -> [ 0] + // [ 1, 9, 17, 25] -> [ 1] + // [ 2, 10, 18, 26] -> [ 2] + // [ 3, 11, 19, 27] -> [ 3] + // [ 4, 12, 20, 28] -> [ 4] + // [ 5, 13, 21, 29] -> [ 5] + // [ 6, 14, 22, 30] -> [ 6] + // [ 7, 15, 23, 31] -> [ 7] + for (short ii = 0; ii < D16; ii += NL) { + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16); + lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8); + //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4); + //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2); + //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1); + + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16); + lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8); + //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4); + //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2); + //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1); + + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16); + lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8); + //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4); + //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2); + //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1); + + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16); + lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8); + //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4); + //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2); + //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // store results to shared memory + for (short i = tiisg; i < D16; i += NL) { + sr4x4[i] = lo[i/NL]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // parallel reduce + for (short r = nsg/2; r > 0; r >>= 1) { + if (sgitg < r) { + const half S0 = ss[ 0]; + const half S1 = ss[r*SH + 0]; + + const half M0 = ss[ 1]; + const half M1 = ss[r*SH + 1]; + + const half M = max(M0, M1); + + const half ms0 = exp(M0 - M); + const half ms1 = exp(M1 - M); + + const half S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[0] = S; + ss[1] = M; + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (short i = tiisg; i < D16; i += NW) { + sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + device float4x4 * dst44 = (device float4x4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + const float S = ss[0]; + + for (short i = tiisg; i < D16; i += NW) { + dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; + } + } +} + +// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem +// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max +// +#define FA_TYPES \ + half4, half4x4, \ + half4x4, \ + half4x4, \ + float, \ + half, half4, half4x4, \ + half4x4 + +typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +#undef FA_TYPES + +template +kernel void kernel_set( + constant ggml_metal_kargs_set & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i13 = tgpig[2]; + const int i12 = tgpig[1]; + const int i11 = tgpig[0]; + + const int64_t n = i13*args.ne12*args.ne11*args.ne10 + i12*args.ne11*args.ne10 + i11*args.ne10; + + const int64_t i3 = n / (args.ne12*args.ne11*args.ne10); + const int64_t i2 = (n - i3*args.ne12*args.ne11*args.ne10) / (args.ne11*args.ne10); + const int64_t i1 = (n - i3*args.ne12*args.ne11*args.ne10 - i2*args.ne11*args.ne10) / args.ne10; + + device T * dst_data = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + args.offs); + + for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) { + device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10); + dst_data[i10] = (T) src[0]; + } +} + +typedef decltype(kernel_set) kernel_set_t; + +template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set; +template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set; + +template +kernel void kernel_cpy( + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; + + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; + + const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); + + device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) { + device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + dst_data[i00] = (T1) src[0]; + } +} + +typedef decltype(kernel_cpy) kernel_cpy_t; + +template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy; +#endif +template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy; +#endif + +kernel void kernel_cpy_f32_q8_0( + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; + + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; + + const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0; + + device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) { + device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = src[j]; + amax = MAX(amax, fabs(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK8_0].d = d; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = src[j]*id; + + dst_data[i00/QK8_0].qs[j] = round(x0); + } + } +} + +kernel void kernel_cpy_f32_q4_0( + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; + + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; + + const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0; + + device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) { + device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK4_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / -8; + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK4_0].d = d; + + for (int j = 0; j < QK4_0/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK4_0/2 + j]*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); + + dst_data[i00/QK4_0].qs[j] = xi0; + dst_data[i00/QK4_0].qs[j] |= xi1 << 4; + } + } +} + +kernel void kernel_cpy_f32_q4_1( + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; + + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; + + const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1; + + device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) { + device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int j = 0; j < QK4_1; j++) { + const float v = src[j]; + if (min > v) min = v; + if (max < v) max = v; + } + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK4_1].d = d; + dst_data[i00/QK4_1].m = min; + + for (int j = 0; j < QK4_1/2; ++j) { + const float x0 = (src[0 + j] - min)*id; + const float x1 = (src[QK4_1/2 + j] - min)*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); + + dst_data[i00/QK4_1].qs[j] = xi0; + dst_data[i00/QK4_1].qs[j] |= xi1 << 4; + } + } +} + +kernel void kernel_cpy_f32_q5_0( + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; + + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; + + const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0; + + device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) { + device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK5_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / -16; + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK5_0].d = d; + + uint32_t qh = 0; + for (int j = 0; j < QK5_0/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK5_0/2 + j]*id; + + const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); + const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); + + dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + thread const uint8_t * qh8 = (thread const uint8_t *)&qh; + for (int j = 0; j < 4; ++j) { + dst_data[i00/QK5_0].qh[j] = qh8[j]; + } + } +} + +kernel void kernel_cpy_f32_q5_1( + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; + + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; + + const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1; + + device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) { + device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + + float max = src[0]; + float min = src[0]; + + for (int j = 1; j < QK5_1; j++) { + const float v = src[j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = d ? 1.0f/d : 0.0f; + + dst_data[i00/QK5_1].d = d; + dst_data[i00/QK5_1].m = min; + + uint32_t qh = 0; + for (int j = 0; j < QK5_1/2; ++j) { + const float x0 = (src[0 + j] - min)*id; + const float x1 = (src[QK5_1/2 + j] - min)*id; + + const uint8_t xi0 = (uint8_t)(x0 + 0.5f); + const uint8_t xi1 = (uint8_t)(x1 + 0.5f); + + dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); + } + thread const uint8_t * qh8 = (thread const uint8_t *)&qh; + for (int j = 0; j < 4; ++j) { + dst_data[i00/QK5_1].qh[j] = qh8[j]; + } + } +} + +static inline int best_index_int8(int n, constant float * val, float x) { + if (x <= val[0]) return 0; + if (x >= val[n-1]) return n-1; + int ml = 0, mu = n-1; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < val[mav]) mu = mav; else ml = mav; + } + return x - val[mu-1] < val[mu] - x ? mu-1 : mu; +} + +kernel void kernel_cpy_f32_iq4_nl( + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; + + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; + + const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL; + + device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) { + device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK4_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / kvalues_iq4nl_f[0]; + const float id = d ? 1.0f/d : 0.0f; + + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK4_NL/2 + j]*id; + + const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0); + const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1); + + dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4); + + const float v0 = kvalues_iq4nl_f[xi0]; + const float v1 = kvalues_iq4nl_f[xi1]; + const float w0 = src[0 + j]*src[0 + j]; + const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j]; + sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + + } + + dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d; + } +} + +kernel void kernel_concat( + constant ggml_metal_kargs_concat & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; + + int o[4] = {0, 0, 0, 0}; + o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03)); + + device const float * x; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { + x = (device const float *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00); + } else { + x = (device const float *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10); + } + + device float * y = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + *y = *x; + } +} + +template +void kernel_mul_mv_q2_K_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int iq = it/4; // 0 or 1 + const int ir = it%4; // 0...3 + const int is = (8*ir)/16;// 0 or 1 + + device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; + + for (int ib = ix; ib < nb; ib += 4) { + + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; + } + + device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; + device const half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); + acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); + acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); + acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); + acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); + acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); + acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); + acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + } + float dall = dh[0]; + float dmin = dh[1] * 1.f/16.f; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - + dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); + + qs += args.nb01/2; + sc += args.nb01; + dh += args.nb01/2; + } + + y4 += 4 * QK_K; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_q2_K_f32")]] +kernel void kernel_mul_mv_q2_K_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_q3_K_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0); + device const float * yy = (device const float *) (src1 + offset1); + + float yl[32]; + + //const uint16_t kmask1 = 0x3030; + //const uint16_t kmask2 = 0x0f0f; + + const int tid = tiisg/4; + const int ix = tiisg%4; + const int ip = tid/4; // 0 or 1 + const int il = 2*((tid%4)/2); // 0 or 2 + const int ir = tid%2; + const int n = 8; + const int l0 = n*ir; + + // One would think that the Metal compiler would figure out that ip and il can only have + // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it + // with these two tales. + // + // Possible masks for the high bit + const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0 + {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2 + {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0 + {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2 + + // Possible masks for the low 2 bits + const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}}; + + const ushort4 hm = mm[2*ip + il/2]; + + const short shift = 2*il; + + const float v1 = il == 0 ? 4.f : 64.f; + const float v2 = 4.f * v1; + + const uint16_t s_shift1 = 4*ip; + const uint16_t s_shift2 = s_shift1 + il; + + const int q_offset = 32*ip + l0; + const int y_offset = 128*ip + 32*il + l0; + + device const float * y1 = yy + ix*QK_K + y_offset; + + uint32_t scales32, aux32; + thread uint16_t * scales16 = (thread uint16_t *)&scales32; + thread const int8_t * scales = (thread const int8_t *)&scales32; + + float sumf1[2] = {0.f}; + float sumf2[2] = {0.f}; + for (int i = ix; i < nb; i += 4) { + for (int l = 0; l < 8; ++l) { + yl[l+ 0] = y1[l+ 0]; + yl[l+ 8] = y1[l+16]; + yl[l+16] = y1[l+32]; + yl[l+24] = y1[l+48]; + } + + device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); + device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0); + device const uint16_t * a = (device const uint16_t *)(x[i].scales); + device const half * dh = &x[i].d; + + for (int row = 0; row < 2; ++row) { + const float d_all = (float)dh[0]; + + scales16[0] = a[4]; + scales16[1] = a[5]; + aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030; + scales16[0] = a[il+0]; + scales16[1] = a[il+1]; + scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; + + float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; + for (int l = 0; l < n; l += 2) { + const int32_t qs = q[l/2]; + s1 += yl[l+0] * (qs & qm[il/2][0]); + s2 += yl[l+1] * (qs & qm[il/2][1]); + s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]); + s4 += yl[l+16] * (qs & qm[il/2][2]); + s5 += yl[l+17] * (qs & qm[il/2][3]); + s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]); + } + float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); + float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); + sumf1[row] += d1 * (scales[0] - 32); + sumf2[row] += d2 * (scales[2] - 32); + + s1 = s2 = s3 = s4 = s5 = s6 = 0; + for (int l = 0; l < n; l += 2) { + const int32_t qs = q[l/2+8]; + s1 += yl[l+8] * (qs & qm[il/2][0]); + s2 += yl[l+9] * (qs & qm[il/2][1]); + s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]); + s4 += yl[l+24] * (qs & qm[il/2][2]); + s5 += yl[l+25] * (qs & qm[il/2][3]); + s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]); + } + d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); + d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); + sumf1[row] += d1 * (scales[1] - 32); + sumf2[row] += d2 * (scales[3] - 32); + + q += args.nb01/2; + h += args.nb01/2; + a += args.nb01/2; + dh += args.nb01/2; + } + + y1 += 4 * QK_K; + } + + for (int row = 0; row < 2; ++row) { + const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); + sumf1[row] = simd_sum(sumf); + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + if (tiisg == 0) { + for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { + dst_f32[first_row + row] = sumf1[row]; + } + } +} + +[[host_name("kernel_mul_mv_q3_K_f32")]] +kernel void kernel_mul_mv_q3_K_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_q4_K_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int iq = it/4; // 0 or 1 + const int ir = it%4; // 0...3 + + const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int first_row = r0 * N_DST; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + float yl[16]; + float yh[16]; + float sumf[N_DST]={0.f}, all_sum; + + device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; + + for (int ib = ix; ib < nb; ib += 4) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; + yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; + yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; + yh[i+8] = y4[i+160]; sumy[3] += yh[i+8]; + } + + device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq; + device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; + device const half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + device const uint16_t * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+0] * (q1[i/2] & 0x000F); + acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00); + acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0); + acc1[3] += yl[i+9] * (q1[i/2] & 0xF000); + acc2[0] += yh[i+0] * (q2[i/2] & 0x000F); + acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00); + acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0); + acc2[3] += yh[i+9] * (q2[i/2] & 0xF000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + + (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + + (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + + (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - + dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + q1 += args.nb01/2; + sc += args.nb01/2; + dh += args.nb01/2; + } + + y4 += 4 * QK_K; + } + + device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; + + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_q4_K_f32")]] +kernel void kernel_mul_mv_q4_K_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_q5_K_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); + device const float * yy = (device const float *) (src1 + offset1); + + float sumf[2]={0.f}; + + float yl[16], yh[16]; + + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int tid = tiisg/4; + const int ix = tiisg%4; + const int iq = tid/4; + const int ir = tid%4; + const int n = 8; + + const int l0 = n*ir; + const int q_offset = 32*iq + l0; + const int y_offset = 64*iq + l0; + + const uint8_t hm1 = 1u << (2*iq); + const uint8_t hm2 = hm1 << 1; + const uint8_t hm3 = hm1 << 4; + const uint8_t hm4 = hm2 << 4; + + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; + + device const float * y1 = yy + ix*QK_K + y_offset; + + for (int i = ix; i < nb; i += 4) { + device const uint8_t * q1 = x[i].qs + q_offset; + device const uint8_t * qh = x[i].qh + l0; + device const half * dh = &x[i].d; + device const uint16_t * a = (device const uint16_t *)x[i].scales + iq; + + device const float * y2 = y1 + 128; + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < 8; ++l) { + yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; + yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; + yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; + yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; + } + + for (int row = 0; row < 2; ++row) { + device const uint8_t * q2 = q1 + 64; + + sc16[0] = a[0] & kmask1; + sc16[1] = a[2] & kmask1; + sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2); + sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2); + + float4 acc1 = {0.f}; + float4 acc2 = {0.f}; + for (int l = 0; l < n; ++l) { + uint8_t h = qh[l]; + acc1[0] += yl[l+0] * (q1[l] & 0x0F); + acc1[1] += yl[l+8] * (q1[l] & 0xF0); + acc1[2] += yh[l+0] * (q2[l] & 0x0F); + acc1[3] += yh[l+8] * (q2[l] & 0xF0); + acc2[0] += h & hm1 ? yl[l+0] : 0.f; + acc2[1] += h & hm2 ? yl[l+8] : 0.f; + acc2[2] += h & hm3 ? yh[l+0] : 0.f; + acc2[3] += h & hm4 ? yh[l+8] : 0.f; + } + const float dall = dh[0]; + const float dmin = dh[1]; + sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + + sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + + sc8[4] * (acc1[2] + 16.f*acc2[2]) + + sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - + dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + q1 += args.nb01; + qh += args.nb01; + dh += args.nb01/2; + a += args.nb01/2; + } + + y1 += 4 * QK_K; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = tot; + } + } +} + +[[host_name("kernel_mul_mv_q5_K_f32")]] +kernel void kernel_mul_mv_q5_K_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_q6_K_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const uint8_t kmask1 = 0x03; + const uint8_t kmask2 = 0x0C; + const uint8_t kmask3 = 0x30; + const uint8_t kmask4 = 0xC0; + + const int nb = args.ne00/QK_K; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int row = 2*r0 + sgitg; + + if (row >= args.ne0) { + return; + } + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); + device const float * yy = (device const float *) (src1 + offset1); + + float sumf = 0; + + const int tid = tiisg/2; + const int ix = tiisg%2; + const int ip = tid/8; // 0 or 1 + const int il = tid%8; + const int n = 4; + const int l0 = n*il; + const int is = 8*ip + l0/16; + + const int y_offset = 128*ip + l0; + const int q_offset_l = 64*ip + l0; + const int q_offset_h = 32*ip + l0; + + for (int i = ix; i < nb; i += 2) { + device const uint8_t * q1 = x[i].ql + q_offset_l; + device const uint8_t * q2 = q1 + 32; + device const uint8_t * qh = x[i].qh + q_offset_h; + device const int8_t * sc = x[i].scales + is; + + device const float * y = yy + i * QK_K + y_offset; + + const float dall = x[i].d; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < n; ++l) { + sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); + sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + } + + sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); + + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + const float tot = simd_sum(sumf); + if (tiisg == 0) { + dst_f32[row] = tot; + } +} + +[[host_name("kernel_mul_mv_q6_K_f32")]] +kernel void kernel_mul_mv_q6_K_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +// ======================= "True" 2-bit + +template +void kernel_mul_mv_iq2_xxs_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256); + { + int nval = 4; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xxs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_xxs * xr = x + ibl; + device const uint16_t * q2 = xr->qs + 4 * ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + device const uint8_t * aux8 = (device const uint8_t *)q2; + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d = db * (0.5f + (aux32 >> 28)); + + float sum = 0; + for (int l = 0; l < 4; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]); + const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + sumf[row] += d * sum; + + dh += args.nb01/2; + q2 += args.nb01/2; + } + + y4 += 32 * 32; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = all_sum * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_xxs_f32")]] +kernel void kernel_mul_mv_iq2_xxs_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_iq2_xs_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 512); + { + int nval = 8; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_xs * xr = x + ibl; + device const uint16_t * q2 = xr->qs + 4 * ib; + device const uint8_t * sc = xr->scales + ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const uint8_t ls1 = sc[0] & 0xf; + const uint8_t ls2 = sc[0] >> 4; + const float d1 = db * (0.5f + ls1); + const float d2 = db * (0.5f + ls2); + + float sum1 = 0, sum2 = 0; + for (int l = 0; l < 2; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); + const uint8_t signs = ssigns[(q2[l] >> 9)]; + for (int j = 0; j < 8; ++j) { + sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + for (int l = 2; l < 4; ++l) { + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); + const uint8_t signs = ssigns[(q2[l] >> 9)]; + for (int j = 0; j < 8; ++j) { + sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); + } + } + sumf[row] += d1 * sum1 + d2 * sum2; + + dh += args.nb01/2; + q2 += args.nb01/2; + sc += args.nb01; + } + + y4 += 32 * 32; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = all_sum * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_xs_f32")]] +kernel void kernel_mul_mv_iq2_xs_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_iq3_xxs_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint32_t * svalues = (threadgroup uint32_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256); + { + int nval = 4; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3xxs_grid[pos + i]; + nval = 2; + pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq3_xxs * xr = x + ibl; + device const uint8_t * q3 = xr->qs + 8 * ib; + device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + const float db = dh[0]; + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float d = db * (0.5f + (aux32 >> 28)); + + float2 sum = {0}; + for (int l = 0; l < 4; ++l) { + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]); + const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 4; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); + sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); + } + } + sumf[row] += d * (sum[0] + sum[1]); + + dh += args.nb01/2; + q3 += args.nb01; + gas += args.nb01/2; + } + + y4 += 32 * 32; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = all_sum * 0.5f; + } + } +} + +[[host_name("kernel_mul_mv_iq3_xxs_f32")]] +kernel void kernel_mul_mv_iq3_xxs_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_iq3_s_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + threadgroup uint32_t * svalues = (threadgroup uint32_t *) shmem; + { + int nval = 8; + int pos = (32*sgitg + tiisg)*nval; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3s_grid[pos + i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq3_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 8 * ib; + device const uint8_t * qh = xr->qh + ib; + device const uint8_t * sc = xr->scales + (ib/2); + device const uint8_t * signs = xr->signs + 4 * ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); + + float2 sum = {0}; + for (int l = 0; l < 4; ++l) { + const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues; + const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues; + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); + for (int j = 0; j < 4; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); + sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); + } + } + sumf[row] += d * (sum[0] + sum[1]); + + dh += args.nb01/2; + qs += args.nb01; + qh += args.nb01; + sc += args.nb01; + signs += args.nb01; + } + + y4 += 32 * 32; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_iq3_s_f32")]] +kernel void kernel_mul_mv_iq3_s_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_iq2_s_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + //threadgroup uint64_t * svalues = (threadgroup uint64_t *) shmem; + //{ + // int nval = 32; + // int pos = (32*sgitg + tiisg)*nval; + // for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2s_grid[pos + i]; + // threadgroup_barrier(mem_flags::mem_threadgroup); + //} + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq2_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint8_t * qh = xr->qh + ib; + device const uint8_t * sc = xr->scales + ib; + device const uint8_t * signs = qs + QK_K/8; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + const float db = dh[0]; + const float d1 = db * (0.5f + (sc[0] & 0xf)); + const float d2 = db * (0.5f + (sc[0] >> 4)); + + float2 sum = {0}; + for (int l = 0; l < 2; ++l) { + //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + for (int j = 0; j < 8; ++j) { + sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]); + sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]); + } + } + sumf[row] += d1 * sum[0] + d2 * sum[1]; + + dh += args.nb01/2; + qs += args.nb01; + qh += args.nb01; + sc += args.nb01; + signs += args.nb01; + } + + y4 += 32 * 32; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = all_sum * 0.25f; + } + } +} + +[[host_name("kernel_mul_mv_iq2_s_f32")]] +kernel void kernel_mul_mv_iq2_s_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +template +void kernel_mul_mv_iq1_s_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + float sumy = 0; + for (int i = 0; i < 32; ++i) { + yl[i] = y4[i]; + sumy += yl[i]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq1_s * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint16_t * qh = xr->qh + ib; + device const half * dh = &xr->d; + + for (int row = 0; row < N_DST; row++) { + + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700))); + constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700))); + constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700))); + + float sum = 0; + for (int j = 0; j < 4; ++j) { + sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4) + + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) + + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); + } + sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1); + + dh += args.nb01/2; + qs += args.nb01; + qh += args.nb01/2; + } + + y4 += 32 * 32; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = all_sum; + } + } +} + +template +void kernel_mul_mv_iq1_m_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; + + const int nb32 = nb * (QK_K / 32); + + const int ix = tiisg; + + device const float * y4 = y + 32 * ix; + + iq1m_scale_t scale; + + for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + + float4 sumy = {0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+16]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+24]; sumy[3] += yl[i+24]; + } + + const int ibl = ib32 / (QK_K / 32); + const int ib = ib32 % (QK_K / 32); + + device const block_iq1_m * xr = x + ibl; + device const uint8_t * qs = xr->qs + 4 * ib; + device const uint8_t * qh = xr->qh + 2 * ib; + device const uint16_t * sc = (device const uint16_t *)xr->scales; + + for (int row = 0; row < N_DST; row++) { + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); + constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700))); + constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700))); + + float2 sum = {0.f}; + for (int j = 0; j < 4; ++j) { + sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4); + sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) + + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); + } + const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); + + sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) + + (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1)); + + sc += args.nb01/2; + qs += args.nb01; + qh += args.nb01; + } + + y4 += 32 * 32; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = all_sum; + } + } +} + +template +void kernel_mul_mv_iq4_nl_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + threadgroup float * shmem_f32 = (threadgroup float *) shmem; + const int nb = args.ne00/QK4_NL; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + const int ix = tiisg/2; // 0...15 + const int it = tiisg%2; // 0 or 1 + + shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK4_NL + it * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ib = ix; ib < nb; ib += 16) { + + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + + for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { + + device const block_iq4_nl & xb = x[row*nb + ib]; + device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = q4[0] | (q4[1] << 16); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + aux32[0] = q4[2] | (q4[3] << 16); + aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; + aux32[0] &= 0x0f0f0f0f; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + acc1 += acc2; + + sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + + } + + yb += 16 * QK4_NL; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = all_sum; + } + } +} + +template +void kernel_mul_mv_iq4_xs_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + + threadgroup float * shmem_f32 = (threadgroup float *) shmem; + const int nb = args.ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * 2 + sgitg) * 2; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + const int ix = tiisg/16; // 0 or 1 + const int it = tiisg%16; // 0...15 + const int ib = it/2; + const int il = it%2; + + shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; + threadgroup_barrier(mem_flags::mem_threadgroup); + + float4 yl[4]; + float sumf[2]={0.f}, all_sum; + + device const float * yb = y + ix * QK_K + ib * 32 + il * 8; + + uint32_t aux32[2]; + thread const uint8_t * q8 = (thread const uint8_t *)aux32; + + float4 qf1, qf2; + + for (int ibl = ix; ibl < nb; ibl += 2) { + device const float4 * y4 = (device const float4 *)yb; + yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + + for (int row = 0; row < 2; ++row) { + device const block_iq4_xs & xb = x[row*nb + ibl]; + device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); + + float4 acc1 = {0.f}, acc2 = {0.f}; + + aux32[0] = (q4[0] ) & 0x0f0f0f0f; + aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; + acc1 += yl[0] * qf1; + acc2 += yl[1] * qf2; + + aux32[0] = (q4[1] ) & 0x0f0f0f0f; + aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; + acc1 += yl[2] * qf1; + acc2 += yl[3] * qf2; + + acc1 += acc2; + + const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32; + sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + + } + + yb += 2 * QK_K; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = all_sum; + } + } +} + +[[host_name("kernel_mul_mv_iq1_s_f32")]] +kernel void kernel_mul_mv_iq1_s_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq1_m_f32")]] +kernel void kernel_mul_mv_iq1_m_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq4_nl_f32")]] +kernel void kernel_mul_mv_iq4_nl_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +[[host_name("kernel_mul_mv_iq4_xs_f32")]] +kernel void kernel_mul_mv_iq4_xs_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +template +kernel void kernel_get_rows_q( + device const void * src0, + device const void * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) { + float4x4 temp; + dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp); + *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + } +} + +template +kernel void kernel_get_rows_f( + device const void * src0, + device const void * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind]; + } +} + +kernel void kernel_get_rows_i32( + device const void * src0, + device const void * src1, + device int32_t * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb1, + constant uint64_t & nb2, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int64_t i10 = tgpig.x; + const int64_t i11 = tgpig.y; + + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; + + const int64_t i02 = i11; + + for (int ind = tiitg; ind < ne00; ind += tptg.x) { + (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind]; + } +} + + +#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A +#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B +#define BLOCK_SIZE_K 32 +#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A +#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B +#define THREAD_PER_BLOCK 128 +#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers +#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers +#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8 +#define SG_MAT_ROW 8 + +// each block_q contains 16*nl weights +template +kernel void kernel_mul_mm( + constant ggml_metal_kargs_mul_mm & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup T * sa = (threadgroup T *)(shmem); + threadgroup float * sb = (threadgroup float *)(shmem + 4096); + + const int r0 = tgpig.y; + const int r1 = tgpig.x; + const int im = tgpig.z; + + // if this block is of 64x32 shape or smaller + const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; + const short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; + + // a thread shouldn't load data outside of the matrix + const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; + + simdgroup_T8x8 ma[4]; + simdgroup_float8x8 mb[2]; + simdgroup_float8x8 mc[8]; + + for (short i = 0; i < 8; i++){ + mc[i] = make_filled_simdgroup_matrix(0.f); + } + + short il = (tiitg % THREAD_PER_ROW); + + const int i12 = im%args.ne12; + const int i13 = im/args.ne12; + + const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const short offset1 = il/nl; + + device const block_q * x = (device const block_q *)(src0 + + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1; + + device const float * y = (device const float *)(src1 + + args.nb13*i13 + + args.nb12*i12 + + args.nb11*(r1*BLOCK_SIZE_N + thread_col) + + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { + // load data and store to threadgroup memory + T4x4 temp_a; + dequantize_func(x, il, temp_a); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(16) + for (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; + } + + *(threadgroup float2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y); + + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? x + (2 + nl - 1)/nl : x; + y += BLOCK_SIZE_K; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // load matrices from threadgroup memory and conduct outer products + threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); + threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); + + #pragma unroll(4) + for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) { + #pragma unroll(4) + for (short i = 0; i < 4; i++) { + simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); + } + + simdgroup_barrier(mem_flags::mem_none); + + #pragma unroll(2) + for (short i = 0; i < 2; i++) { + simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); + } + + #pragma unroll(8) + for (short i = 0; i < 8; i++){ + simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); + } + + lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE; + lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE; + } + } + + if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) { + device float * C = (device float *) dst + + (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \ + (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; + + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0); + } + } else { + // block is smaller than 64x32, we should avoid writing data outside of the matrix + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float * temp_str = ((threadgroup float *) shmem) \ + + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M; + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sgitg == 0) { + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.ne0 + im*args.ne1*args.ne0; + device float4 * D4 = (device float4 *) D; + + threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); + threadgroup float4 * C4 = (threadgroup float4 *) C; + + int i = 0; + for (; i < n_rows/4; i++) { + *(D4 + i) = *(C4 + i); + } + + i *= 4; + for (; i < n_rows; i++) { + *(D + i) = *(C + i); + } + } + } + } +} + +// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids +// TODO: this kernel needs to be reimplemented from scratch for better performance +template +void kernel_mul_mm_id_impl( + int32_t ne00, + int32_t ne02, + uint64_t nb01, + uint64_t nb02, + int32_t ne11, + int32_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int32_t ne0, + int32_t ne1, + int64_t ne0ne1, + device const char * src0, + device const char * src1, + threadgroup ushort2 * rowids, + device char * dst, + threadgroup char * shmem, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup half * sa = (threadgroup half *)(shmem); + threadgroup float * sb = (threadgroup float *)(shmem + 4096); + + const int r0 = tgpig.y; + const int r1 = tgpig.x; + + if (r1*BLOCK_SIZE_N >= ne1) return; + + // if this block is of 64x32 shape or smaller + short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; + short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; + + // a thread shouldn't load data outside of the matrix + short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; + + simdgroup_half8x8 ma[4]; + simdgroup_float8x8 mb[2]; + simdgroup_float8x8 mc[8]; + for (int i = 0; i < 8; i++){ + mc[i] = make_filled_simdgroup_matrix(0.f); + } + short il = (tiitg % THREAD_PER_ROW); + + ushort offset1 = il/nl; + + threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col]; + + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1; + device const float * y = (device const float *)(src1 + + nb12 * id[1] + + nb11 * (id[0] % ne11) + + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + // load data and store to threadgroup memory + half4x4 temp_a; + dequantize_func(x, il, temp_a); + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ + + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ + + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + } + + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); + + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? x + (2+nl-1)/nl : x; + y += BLOCK_SIZE_K; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // load matrices from threadgroup memory and conduct outer products + threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); + threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + + #pragma unroll(BLOCK_SIZE_K/8) + for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + #pragma unroll(4) + for (int i = 0; i < 4; i++) { + simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); + } + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(2) + for (int i = 0; i < 2; i++) { + simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); + } + + lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; + lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; + + #pragma unroll(8) + for (int i = 0; i < 8; i++){ + simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); + } + } + } + + { + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float * temp_str = ((threadgroup float *) shmem) \ + + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; + for (int i = 0; i < 8; i++) { + simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sgitg == 0) { + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { + threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j]; + int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1; + + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff; + device float4 * D4 = (device float4 *) D; + + threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); + threadgroup float4 * C4 = (threadgroup float4 *) C; + + int i = 0; + for (; i < n_rows/4; i++) { + *(D4 + i) = *(C4 + i); + } + + i *= 4; + for (; i < n_rows; i++) { + *(D + i) = *(C + i); + } + } + } + } +} + +template +kernel void kernel_mul_mm_id( + constant ggml_metal_kargs_mul_mm_id & args, + device const char * src0s, + device const char * src1, + device char * dst, + device const char * ids, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + const int32_t i02 = tgpig.z; + + tgpig.z = 0; + + device const char * src0 = src0s + i02*args.nb02; + + // row indices + threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192); + + // TODO: parallelize this loop + int32_t _ne1 = 0; + for (ushort ii1 = 0; ii1 < args.nei1; ii1++) { + for (ushort ii0 = 0; ii0 < args.nei0; ii0++) { + int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0]; + if (id == i02) { + if (tiitg == 0) { + rowids[_ne1] = ushort2(ii0, ii1); + } + _ne1++; + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + kernel_mul_mm_id_impl( + args.ne00, + args.ne02, + args.nb01, + args.nb02, + args.ne11, + args.ne12, + args.nb10, + args.nb11, + args.nb12, + args.ne0, + _ne1, + (int64_t)args.ne0*args.ne1, + src0, + src1, + rowids, + dst, + shmem, + tgpig, + tiitg, + sgitg); +} + +#define QK_NL 16 + +// +// get rows +// + +typedef decltype(kernel_get_rows_f) get_rows_f_t; + +template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; +#endif + +typedef decltype(kernel_get_rows_q) get_rows_q_t; + +template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q; + +// +// matrix-matrix multiplication +// + +typedef decltype(kernel_mul_mm) mat_mm_t; + +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm; +#endif +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; + +// +// indirect matrix-matrix multiplication +// + +typedef decltype(kernel_mul_mm_id) mat_mm_id_t; + +template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +#endif +template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; + +// +// matrix-vector multiplication +// + +typedef void (kernel_mul_mv_impl_t)( + ggml_metal_kargs_mul_mv args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig, + ushort tiisg); + +typedef void (kernel_mul_mv2_impl_t)( + ggml_metal_kargs_mul_mv args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg); + +template +void mmv_fn( + ggml_metal_kargs_mul_mv args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiitg, + ushort tiisg, + ushort sgitg) { + impl_fn(args, src0, src1, dst, tgpig, tiisg); +} + +template +void mmv_fn( + ggml_metal_kargs_mul_mv args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiitg, + ushort tiisg, + ushort sgitg) { + impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + +typedef decltype(mmv_fn>) mul_mv_impl_fn_t; + +template +kernel void kernel_mul_mv_id( + constant ggml_metal_kargs_mul_mv_id & args, + device const char * src0s, + device const char * src1, + device char * dst, + device const char * ids, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const int iid1 = tgpig.z/args.nei0; + const int idx = tgpig.z%args.nei0; + + tgpig.z = 0; + + const int32_t i02 = ((device const int32_t *) (ids + iid1*args.nbi1))[idx]; + + const int64_t i11 = idx % args.ne11; + const int64_t i12 = iid1; + + const int64_t i1 = idx; + const int64_t i2 = i12; + + device const char * src0_cur = src0s + i02*args.nb02; + device const char * src1_cur = src1 + i11*args.nb11 + i12*args.nb12; + + device char * dst_cur = dst + (i1*args.ne0 + i2*args.ne1*args.ne0)*sizeof(float); + + ggml_metal_kargs_mul_mv args0 = { + /*.ne00 =*/ args.ne00, + /*.ne01 =*/ args.ne01, + /*.ne02 =*/ 1, // args.ne02, + /*.nb00 =*/ args.nb00, + /*.nb01 =*/ args.nb01, + /*.nb02 =*/ args.nb02, + /*.nb03 =*/ args.nb02, // args.ne02 == 1 + /*.ne10 =*/ args.ne10, + /*.ne11 =*/ 1, // args.ne11, + /*.ne12 =*/ 1, // args.ne12, + /*.nb10 =*/ args.nb10, + /*.nb11 =*/ args.nb11, + /*.nb12 =*/ args.nb12, + /*.nb13 =*/ args.nb12, // ne12 == 1 + /*.ne0 =*/ args.ne0, + /*.ne1 =*/ 1, // args.ne1, + /*.r2 =*/ 1, + /*.r3 =*/ 1, + }; + + impl_fn( + args0, + /* src0 */ src0_cur, + /* src1 */ src1_cur, + /* dst */ dst_cur, + shmem, + tgpig, + tiitg, + tiisg, + sgitg); +} + +typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; + +template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +#endif +template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; + +kernel void kernel_pool_2d_max_f32( + device const float * src0, + device float * dst, + constant int32_t & k0, + constant int32_t & k1, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int64_t & IH, + constant int64_t & IW, + constant int64_t & OH, + constant int64_t & OW, + constant int64_t & parallel_elements, + uint gid[[thread_position_in_grid]]) { + + if (gid >= parallel_elements) { + return; + } + + const int idx = gid; + const int I_HW = IH * IW; + const int O_HW = OH * OW; + const int nc = idx / O_HW; + const int cur_oh = idx % O_HW / OW; + const int cur_ow = idx % O_HW % OW; + + device const float * i_ptr = src0 + nc * I_HW; + device float * o_ptr = dst + nc * O_HW; + + const int start_h = cur_oh * s1 - p1; + const int bh = MAX(0, start_h); + const int eh = MIN(IH, start_h + k1); + const int start_w = cur_ow * s0 - p0; + const int bw = MAX(0, start_w); + const int ew = MIN(IW, start_w + k0); + + float res = -INFINITY; + + for (int i = bh; i < eh; i += 1) { + for (int j = bw; j < ew; j += 1) { + res = MAX(res, i_ptr[i * IW + j]); + } + } + + o_ptr[cur_oh * OW + cur_ow] = res; +} + +kernel void kernel_pool_2d_avg_f32( + device const float * src0, + device float * dst, + constant int32_t & k0, + constant int32_t & k1, + constant int32_t & s0, + constant int32_t & s1, + constant int32_t & p0, + constant int32_t & p1, + constant int64_t & IH, + constant int64_t & IW, + constant int64_t & OH, + constant int64_t & OW, + constant int64_t & parallel_elements, + uint gid[[thread_position_in_grid]]) { + + if (gid >= parallel_elements) { + return; + } + + const int idx = gid; + const int I_HW = IH * IW; + const int O_HW = OH * OW; + const int nc = idx / O_HW; + const int cur_oh = idx % O_HW / OW; + const int cur_ow = idx % O_HW % OW; + + device const float * i_ptr = src0 + nc * I_HW; + device float * o_ptr = dst + nc * O_HW; + + const int start_h = cur_oh * s1 - p1; + const int bh = MAX(0, start_h); + const int eh = MIN(IH, start_h + k1); + const int start_w = cur_ow * s0 - p0; + const int bw = MAX(0, start_w); + const int ew = MIN(IW, start_w + k0); + // const float scale = 1. / ((eh - bh) * (ew - bw)); + const float scale = 1. / (k0 * k1); + + float res = 0; + + for (int i = bh; i < eh; i += 1) { + for (int j = bw; j < ew; j += 1) { + float cur = i_ptr[i * IW + j]; + res += cur * scale; + } + } + + o_ptr[cur_oh * OW + cur_ow] = res; +} diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt new file mode 100644 index 000000000..415b2b2e0 --- /dev/null +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -0,0 +1,107 @@ +if (NOT EXISTS $ENV{MUSA_PATH}) + if (NOT EXISTS /opt/musa) + set(MUSA_PATH /usr/local/musa) + else() + set(MUSA_PATH /opt/musa) + endif() +else() + set(MUSA_PATH $ENV{MUSA_PATH}) +endif() + +set(CMAKE_C_COMPILER "${MUSA_PATH}/bin/clang") +set(CMAKE_C_EXTENSIONS OFF) +set(CMAKE_CXX_COMPILER "${MUSA_PATH}/bin/clang++") +set(CMAKE_CXX_EXTENSIONS OFF) + +list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake") + +find_package(MUSAToolkit) + +if (MUSAToolkit_FOUND) + message(STATUS "MUSA Toolkit found") + + if (NOT DEFINED MUSA_ARCHITECTURES) + set(MUSA_ARCHITECTURES "21;22") + endif() + message(STATUS "Using MUSA architectures: ${MUSA_ARCHITECTURES}") + + file(GLOB GGML_HEADERS_MUSA "../ggml-cuda/*.cuh") + list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h") + + file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu") + file(GLOB SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu") + list(APPEND GGML_SOURCES_MUSA ${SRCS}) + file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") + list(APPEND GGML_SOURCES_MUSA ${SRCS}) + + if (GGML_CUDA_FA_ALL_QUANTS) + file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu") + list(APPEND GGML_SOURCES_MUSA ${SRCS}) + add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) + else() + file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") + list(APPEND GGML_SOURCES_MUSA ${SRCS}) + file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") + list(APPEND GGML_SOURCES_MUSA ${SRCS}) + file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu") + list(APPEND GGML_SOURCES_MUSA ${SRCS}) + endif() + + set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX) + foreach(SOURCE ${GGML_SOURCES_MUSA}) + set(COMPILE_FLAGS "-x musa -mtgpu") + foreach(ARCH ${MUSA_ARCHITECTURES}) + set(COMPILE_FLAGS "${COMPILE_FLAGS} --cuda-gpu-arch=mp_${ARCH}") + endforeach() + set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS ${COMPILE_FLAGS}) + endforeach() + + ggml_add_backend_library(ggml-musa + ${GGML_HEADERS_MUSA} + ${GGML_SOURCES_MUSA} + ) + + # TODO: do not use CUDA definitions for MUSA + target_compile_definitions(ggml PUBLIC GGML_USE_CUDA) + + add_compile_definitions(GGML_USE_MUSA) + add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE}) + + if (GGML_CUDA_GRAPHS) + add_compile_definitions(GGML_CUDA_USE_GRAPHS) + endif() + + if (GGML_CUDA_FORCE_MMQ) + add_compile_definitions(GGML_CUDA_FORCE_MMQ) + endif() + + if (GGML_CUDA_FORCE_CUBLAS) + add_compile_definitions(GGML_CUDA_FORCE_CUBLAS) + endif() + + if (GGML_CUDA_NO_VMM) + add_compile_definitions(GGML_CUDA_NO_VMM) + endif() + + if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) + add_compile_definitions(GGML_CUDA_F16) + endif() + + if (GGML_CUDA_NO_PEER_COPY) + add_compile_definitions(GGML_CUDA_NO_PEER_COPY) + endif() + + if (GGML_STATIC) + target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static) + else() + target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas) + endif() + + if (GGML_CUDA_NO_VMM) + # No VMM requested, no need to link directly with the musa driver lib (libmusa.so) + else() + target_link_libraries(ggml-musa PRIVATE MUSA::musa_driver) + endif() +else() + message(FATAL_ERROR "MUSA Toolkit not found") +endif() diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt new file mode 100644 index 000000000..45328a657 --- /dev/null +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -0,0 +1,147 @@ +find_package(OpenCL REQUIRED) +find_package(Python3 REQUIRED) + +set(TARGET_NAME ggml-opencl) + +ggml_add_backend_library(${TARGET_NAME} + ggml-opencl.cpp + ../../include/ggml-opencl.h) +target_link_libraries(${TARGET_NAME} PRIVATE ${OpenCL_LIBRARIES}) +target_include_directories(${TARGET_NAME} PRIVATE ${OpenCL_INCLUDE_DIRS}) + +if (GGML_OPENCL_PROFILING) + message(STATUS "OpenCL profiling enabled (increases CPU overhead)") + add_compile_definitions(GGML_OPENCL_PROFILING) +endif () + +add_compile_definitions(GGML_OPENCL_SOA_Q) + +if (GGML_OPENCL_USE_ADRENO_KERNELS) + message(STATUS "OpenCL will use matmul kernels optimized for Adreno") + add_compile_definitions(GGML_OPENCL_USE_ADRENO_KERNELS) +endif () + +if (GGML_OPENCL_EMBED_KERNELS) + add_compile_definitions(GGML_OPENCL_EMBED_KERNELS) + + set(OPENCL_CL_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl.cl.h") + set(OPENCL_MM_CL_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_mm.cl.h") + set(OPENCL_CVT_CL_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_cvt.cl.h") + + set(OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_gemv_noshuffle.cl.h") + set(OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_gemv_noshuffle_general.cl.h") + set(OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_mul_mat_Ab_Bi_8x4.cl.h") + set(OPENCL_TRANSPOSE_16_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_transpose_16.cl.h") + set(OPENCL_TRANSPOSE_32_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_transpose_32.cl.h") + set(OPENCL_TRANSPOSE_32_16_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_transpose_32_16.cl.h") + + set(EMBED_KERNEL_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/kernels/embed_kernel.py") + file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated") + + include_directories("${CMAKE_BINARY_DIR}/autogenerated") + + # Python must be accessible from command line + add_custom_command( + OUTPUT ${OPENCL_CL_SOURCE_EMBED} + COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} + ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl.cl + ${OPENCL_CL_SOURCE_EMBED} + DEPENDS kernels/ggml-opencl.cl ${EMBED_KERNEL_SCRIPT} + COMMENT "Generate ggml-opencl.cl.h" + ) + + add_custom_command( + OUTPUT ${OPENCL_MM_CL_SOURCE_EMBED} + COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} + ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_mm.cl + ${OPENCL_MM_CL_SOURCE_EMBED} + DEPENDS kernels/ggml-opencl_mm.cl ${EMBED_KERNEL_SCRIPT} + COMMENT "Generate ggml-opencl_mm.cl.h" + ) + + add_custom_command( + OUTPUT ${OPENCL_CVT_CL_SOURCE_EMBED} + COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} + ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_cvt.cl + ${OPENCL_CVT_CL_SOURCE_EMBED} + DEPENDS kernels/ggml-opencl_cvt.cl ${EMBED_KERNEL_SCRIPT} + COMMENT "Generate ggml-opencl_cvt.cl.h" + ) + + add_custom_command( + OUTPUT ${OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED} + COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} + ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_gemv_noshuffle.cl + ${OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED} + DEPENDS kernels/ggml-opencl_gemv_noshuffle.cl ${EMBED_KERNEL_SCRIPT} + COMMENT "Generate ggml-opencl_gemv_noshuffle.cl.h" + ) + + add_custom_command( + OUTPUT ${OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED} + COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} + ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_gemv_noshuffle_general.cl + ${OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED} + DEPENDS kernels/ggml-opencl_gemv_noshuffle_general.cl ${EMBED_KERNEL_SCRIPT} + COMMENT "Generate ggml-opencl_gemv_noshuffle_general.cl.h" + ) + + add_custom_command( + OUTPUT ${OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED} + COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} + ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl + ${OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED} + DEPENDS kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl ${EMBED_KERNEL_SCRIPT} + COMMENT "Generate ggml-opencl_mul_mat_Ab_Bi_8x4.cl.cl.h" + ) + + add_custom_command( + OUTPUT ${OPENCL_TRANSPOSE_16_SOURCE_EMBED} + COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} + ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_transpose_16.cl + ${OPENCL_TRANSPOSE_16_SOURCE_EMBED} + DEPENDS kernels/ggml-opencl_transpose_16.cl ${EMBED_KERNEL_SCRIPT} + COMMENT "Generate ggml-opencl_transpose_16.cl.h" + ) + + add_custom_command( + OUTPUT ${OPENCL_TRANSPOSE_32_SOURCE_EMBED} + COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} + ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_transpose_32.cl + ${OPENCL_TRANSPOSE_32_SOURCE_EMBED} + DEPENDS kernels/ggml-opencl_transpose_32.cl ${EMBED_KERNEL_SCRIPT} + COMMENT "Generate ggml-opencl_transpose_32.cl.h" + ) + + add_custom_command( + OUTPUT ${OPENCL_TRANSPOSE_32_16_SOURCE_EMBED} + COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT} + ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_transpose_32_16.cl + ${OPENCL_TRANSPOSE_32_16_SOURCE_EMBED} + DEPENDS kernels/ggml-opencl_transpose_32_16.cl ${EMBED_KERNEL_SCRIPT} + COMMENT "Generate ggml-opencl_transpose_32_16.cl.h" + ) + + target_sources(${TARGET_NAME} PRIVATE + ${OPENCL_CL_SOURCE_EMBED} + ${OPENCL_MM_CL_SOURCE_EMBED} + ${OPENCL_CVT_CL_SOURCE_EMBED} + ${OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED} + ${OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED} + ${OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED} + ${OPENCL_TRANSPOSE_16_SOURCE_EMBED} + ${OPENCL_TRANSPOSE_32_SOURCE_EMBED} + ${OPENCL_TRANSPOSE_32_16_SOURCE_EMBED}) +else () + # copy ggml-opencl.cl to bin directory + configure_file(kernels/ggml-opencl.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl.cl COPYONLY) + configure_file(kernels/ggml-opencl_mm.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_mm.cl COPYONLY) + configure_file(kernels/ggml-opencl_cvt.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_cvt.cl COPYONLY) + + configure_file(kernels/ggml-opencl_gemv_noshuffle.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_gemv_noshuffle.cl COPYONLY) + configure_file(kernels/ggml-opencl_gemv_noshuffle_general.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_gemv_noshuffle_general.cl COPYONLY) + configure_file(kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_mul_mat_Ab_Bi_8x4.cl COPYONLY) + configure_file(kernels/ggml-opencl_transpose_16.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_transpose_16.cl COPYONLY) + configure_file(kernels/ggml-opencl_transpose_32.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_transpose_32.cl COPYONLY) + configure_file(kernels/ggml-opencl_transpose_32_16.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_transpose_32_16.cl COPYONLY) +endif () diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp new file mode 100644 index 000000000..ed90e471a --- /dev/null +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -0,0 +1,4004 @@ +#define CL_TARGET_OPENCL_VERSION 220 +#define CL_USE_DEPRECATED_OPENCL_1_2_APIS + +// suppress warnings in CL headers for GCC and Clang +#pragma GCC diagnostic ignored "-Woverlength-strings" +#ifdef __clang__ +#pragma GCC diagnostic ignored "-Wgnu-anonymous-struct" +#endif + +#include "ggml-opencl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include "ggml-backend-impl.h" +#include "ggml.h" + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#undef MIN +#undef MAX +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +#define UNUSED(x) (void)(x) + +#define CL_CHECK(err) \ + do { \ + cl_int err_ = (err); \ + if (err_ != CL_SUCCESS) { \ + GGML_LOG_ERROR("ggml_opencl: %s error %d at %s:%d\n", \ + #err, err_, __FILE__, __LINE__); \ + GGML_ASSERT(0); \ + } \ + } while (0) + +//------------------------------------------------------------------------------ +// OpenCL +//------------------------------------------------------------------------------ + +bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor); + +enum GPU_FAMILY { + ADRENO, + INTEL, + UNKNOWN, +}; + +enum ADRENO_GPU_GEN { + ADRENO_UNKNOWN, + A7X, + A8X, + X1E, +}; + +static ADRENO_GPU_GEN get_adreno_gpu_gen(const char *device_name) { + if (strstr(device_name, "730") || + strstr(device_name, "740") || + strstr(device_name, "750")) { + return ADRENO_GPU_GEN::A7X; + } + + if (strstr(device_name, "830")) { + return ADRENO_GPU_GEN::A8X; + } + + if (strstr(device_name, "X1")) { + return ADRENO_GPU_GEN::X1E; + } + + return ADRENO_GPU_GEN::ADRENO_UNKNOWN; +} + +static int get_adreno_cl_compiler_version(const char *driver_version) { + std::string driver_ver_str(driver_version); + size_t compiler_ver_pos = driver_ver_str.find("E031"); + size_t compiler_ver_len = 13; + size_t compiler_ver_offset = 5; + + if (compiler_ver_pos == std::string::npos) { + compiler_ver_pos = driver_ver_str.find("DX"); + if (compiler_ver_pos == std::string::npos) { + return -1; + } + compiler_ver_len = 11; + compiler_ver_offset = 3; + } + + std::string compiler_ver_str = driver_ver_str.substr(compiler_ver_pos, compiler_ver_len); + std::string major_ver_str = compiler_ver_str.substr(compiler_ver_offset, 2); + return std::atoi(major_ver_str.c_str()); +} + +// backend device context +struct ggml_backend_opencl_device_context { + cl_platform_id platform; + std::string platform_name; + + cl_device_id device; + std::string device_name; +}; + +// backend context +struct ggml_backend_opencl_context { + cl_device_id device; + std::string device_name; + + std::string driver_version; + + GPU_FAMILY gpu_family; + ADRENO_GPU_GEN adreno_gen; + + cl_int alignment; + size_t max_alloc_size; + bool fp16_support; + + int adreno_wave_size; + + cl_context context; + cl_command_queue queue; + + cl_program program; + cl_program program_1; + cl_program program_2; + + cl_kernel kernel_add, kernel_add_row; + cl_kernel kernel_mul, kernel_mul_row; + cl_kernel kernel_scale; + cl_kernel kernel_silu, kernel_silu_4; + cl_kernel kernel_gelu, kernel_gelu_4; + cl_kernel kernel_relu; + cl_kernel kernel_clamp; + cl_kernel kernel_norm; + cl_kernel kernel_rms_norm; + cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8; + cl_kernel kernel_soft_max, kernel_soft_max_4; + cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0; + cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; + cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32; + cl_kernel kernel_mul_mat_f32_f32; + cl_kernel kernel_mul_mat_f16_f16; + cl_kernel kernel_mul_mat_f16_f32_1row; + cl_kernel kernel_mul_mat_f16_f32; + cl_kernel kernel_mul_mat_f16_f32_l4; + cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; + cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0, kernel_mul_mat_q4_0_f32_flat; + cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; + cl_kernel kernel_convert_block_q4_0_noshuffle, kernel_mul_mat_q4_0_f32_flat_v0, + kernel_mul_mat_q4_0_f32_flat_img_v0; + cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; + cl_kernel kernel_mul_mv_q6_K_f32; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Transpose kernels + cl_program program_transpose_32; + cl_program program_transpose_32_16; + cl_program program_transpose_16; + cl_kernel kernel_transpose_32; + cl_kernel kernel_transpose_32_16; + cl_kernel kernel_transpose_16; + + cl_mem A_s_d_max; // max scale buffer size for transpose + cl_mem A_q_d_max; // max weight buffer size for transpose + cl_mem B_d_max; // max activation buffer size for transpose + + // Gemm and Gemv related programs, kernels, etc + cl_program program_CL_gemm; + cl_program program_CL_gemv_general; + cl_program program_CL_gemv_4096_1_11008; + cl_program program_CL_gemv_4096_1_4096; + cl_program program_CL_gemv_11008_1_4096; + cl_program program_CL_gemv_32000_1_4096; + cl_kernel CL_mul_mat_Ab_Bi_8x4; + cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general; + cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008; + cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096; + cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096; + cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096; +#endif // GGML_OPENCL_USE_ADRENO_KERNELS +}; + +static ggml_backend_device g_ggml_backend_opencl_device; +static ggml_backend_opencl_device_context g_ggml_ctx_dev_main { + /*.platform =*/ nullptr, + /*.platform_nane =*/ "", + /*.device =*/ nullptr, + /*.device_name =*/ "", +}; + +static int ggml_backend_opencl_n_devices = 0; + +// Profiling +#ifdef GGML_OPENCL_PROFILING +struct ProfilingInfo { + std::string op_name; + std::string kernel_name; + // Kernel execution time in nanoseconds. + cl_ulong duration_ns; + // Global and local work sizes. + size_t global_size[3]; + size_t local_size[3]; + // Op output size. + size_t output_size[4]; +}; + +std::vector g_profiling_info; +#endif + +inline std::string read_file(const std::string &path) { + std::ifstream ifs(path); + if (!ifs) { + return ""; + } + std::string text; + ifs.seekg(0, std::ios::end); + text.resize(ifs.tellg()); + ifs.seekg(0, std::ios::beg); + ifs.read(&text[0], text.size()); + return text; +} + +static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer, const std::string &compile_opts) { + cl_program p; + char *program_log; + size_t program_size; + size_t log_size; + int err; + + program_size = strlen(program_buffer); + + p = clCreateProgramWithSource(ctx, 1, (const char**)&program_buffer, &program_size, &err); + if(err < 0) { + GGML_LOG_ERROR("OpenCL error creating program"); + exit(1); + } + + err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL); + if(err < 0) { + clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size); + program_log = (char*) malloc(log_size + 1); + program_log[log_size] = '\0'; + clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, log_size + 1, program_log, NULL); + GGML_LOG_ERROR("ggml_opencl: kernel compile error:\n\n%s\n", program_log); + free(program_log); + exit(1); + } + + return p; +} + +static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { + static bool initialized = false; + static ggml_backend_opencl_context *backend_ctx = nullptr; + + if (initialized) { + return backend_ctx; + } + + ggml_backend_opencl_device_context *dev_ctx = (ggml_backend_opencl_device_context *)dev->context; + GGML_ASSERT(dev_ctx); + GGML_ASSERT(dev_ctx->platform == nullptr); + GGML_ASSERT(dev_ctx->device == nullptr); + GGML_ASSERT(backend_ctx == nullptr); + + initialized = true; + backend_ctx = new ggml_backend_opencl_context(); + backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN; + + cl_int err; + +#ifdef GGML_PROFILE_OPENCL + GGML_LOG_INFO("ggml_opencl: OpenCL profiling enabled\n"); +#endif + + struct cl_device; + struct cl_platform { + cl_platform_id id; + unsigned number; + char name[128]; + char vendor[128]; + struct cl_device * devices; + unsigned n_devices; + struct cl_device * default_device; + }; + + struct cl_device { + struct cl_platform * platform; + cl_device_id id; + unsigned number; + cl_device_type type; + char name[128]; + }; + + enum { NPLAT = 16, NDEV = 16 }; + + struct cl_platform platforms[NPLAT]; + unsigned n_platforms = 0; + struct cl_device devices[NDEV]; + unsigned n_devices = 0; + struct cl_device * default_device = NULL; + + cl_platform_id platform_ids[NPLAT]; + if (clGetPlatformIDs(NPLAT, platform_ids, &n_platforms) != CL_SUCCESS) { + GGML_LOG_ERROR("ggml_opencl: plaform IDs not available.\n"); + return backend_ctx; + } + + for (unsigned i = 0; i < n_platforms; i++) { + struct cl_platform * p = &platforms[i]; + p->number = i; + p->id = platform_ids[i]; + CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_NAME, sizeof(p->name), &p->name, NULL)); + CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_VENDOR, sizeof(p->vendor), &p->vendor, NULL)); + + cl_device_id device_ids[NDEV]; + cl_int clGetDeviceIDsError = clGetDeviceIDs(p->id, CL_DEVICE_TYPE_ALL, NDEV, device_ids, &p->n_devices); + if (clGetDeviceIDsError == CL_DEVICE_NOT_FOUND) { + p->n_devices = 0; + } else { + CL_CHECK(clGetDeviceIDsError); + } + p->devices = p->n_devices > 0 ? &devices[n_devices] : NULL; + p->default_device = NULL; + + for (unsigned j = 0; j < p->n_devices; j++) { + struct cl_device * d = &devices[n_devices]; + d->number = n_devices++; + d->id = device_ids[j]; + d->platform = p; + CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_NAME, sizeof(d->name), &d->name, NULL)); + CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_TYPE, sizeof(d->type), &d->type, NULL)); + + if (p->default_device == NULL && d->type == CL_DEVICE_TYPE_GPU) { + p->default_device = d; + } + } + + if (default_device == NULL && p->default_device != NULL) { + default_device = p->default_device; + } + } + + if (n_devices == 0) { + GGML_LOG_ERROR("ggml_opencl: could find any OpenCL devices.\n"); + return backend_ctx; + } + + char * user_platform_string = getenv("GGML_OPENCL_PLATFORM"); + char * user_device_string = getenv("GGML_OPENCL_DEVICE"); + int user_platform_number = -1; + int user_device_number = -1; + + unsigned n; + if (user_platform_string != NULL && sscanf(user_platform_string, " %u", &n) == 1 && n < n_platforms) { + user_platform_number = (int)n; + } + if (user_device_string != NULL && sscanf(user_device_string, " %u", &n) == 1 && n < n_devices) { + user_device_number = (int)n; + } + if (user_platform_number != -1 && user_device_number != -1) { + cl_platform* platform = &platforms[user_platform_number]; + if ((unsigned)user_device_number >= platform->n_devices) { + GGML_LOG_ERROR("ggml_opencl: invalid device number %d\n", user_device_number); + exit(1); + } + default_device = &platform->devices[user_device_number]; + } else { + + struct cl_device * selected_devices = devices; + unsigned n_selected_devices = n_devices; + + if (user_platform_number == -1 && user_platform_string != NULL && user_platform_string[0] != 0) { + for (unsigned i = 0; i < n_platforms; i++) { + struct cl_platform * p = &platforms[i]; + if (strstr(p->name, user_platform_string) != NULL || + strstr(p->vendor, user_platform_string) != NULL) { + user_platform_number = (int)i; + break; + } + } + if (user_platform_number == -1) { + GGML_LOG_ERROR("ggml_opencl: no platform matching '%s' was found.\n", user_platform_string); + exit(1); + } + } + if (user_platform_number != -1) { + struct cl_platform * p = &platforms[user_platform_number]; + selected_devices = p->devices; + n_selected_devices = p->n_devices; + default_device = p->default_device; + if (n_selected_devices == 0) { + GGML_LOG_ERROR("ggml_opencl: selected platform '%s' does not have any devices.\n", p->name); + exit(1); + } + } + + if (user_device_number == -1 && user_device_string != NULL && user_device_string[0] != 0) { + for (unsigned i = 0; i < n_selected_devices; i++) { + struct cl_device * d = &selected_devices[i]; + if (strstr(d->name, user_device_string) != NULL) { + user_device_number = d->number; + break; + } + } + if (user_device_number == -1) { + GGML_LOG_ERROR("ggml_opencl: no device matching '%s' was found.\n", user_device_string); + exit(1); + } + } + if (user_device_number != -1) { + selected_devices = &devices[user_device_number]; + n_selected_devices = 1; + default_device = &selected_devices[0]; + } + + GGML_ASSERT(n_selected_devices > 0); + + if (default_device == NULL) { + default_device = &selected_devices[0]; + } + } + + GGML_LOG_INFO("ggml_opencl: selecting platform: '%s'\n", default_device->platform->name); + GGML_LOG_INFO("ggml_opencl: selecting device: '%s'\n", default_device->name); + if (default_device->type != CL_DEVICE_TYPE_GPU) { + GGML_LOG_WARN("ggml_opencl: warning, not a GPU: '%s'.\n", default_device->name); + } + + dev_ctx->platform = default_device->platform->id; + dev_ctx->device = default_device->id; + backend_ctx->device = default_device->id; + + if (strstr(default_device->name, "Adreno")) { + backend_ctx->gpu_family = GPU_FAMILY::ADRENO; + backend_ctx->adreno_gen = get_adreno_gpu_gen(default_device->name); + + // Default wave size is 128, A8x uses 64. + if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::A8X) { + backend_ctx->adreno_wave_size = 64; + } else if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::A7X || + backend_ctx->adreno_gen == ADRENO_GPU_GEN::X1E) { + backend_ctx->adreno_wave_size = 128; + } else { + backend_ctx->adreno_wave_size = 128; + GGML_LOG_WARN("ggml_opencl: Unsupported Adreno GPU: %s, " + "using wave size %d, " + "may not work as expected\n", + backend_ctx->device_name.c_str(), backend_ctx->adreno_wave_size); + } + } else if (strstr(default_device->name, "Intel")) { + backend_ctx->gpu_family = GPU_FAMILY::INTEL; + } else { + GGML_LOG_ERROR("Unsupported GPU: %s\n", default_device->name); + backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN; + return backend_ctx; + } + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) { + GGML_LOG_ERROR("ggml_opencl: Adreno-specific kernels should not be enabled for non-Adreno GPUs; " + "run on an Adreno GPU or recompile with CMake option `-DGGML_OPENCL_USE_ADRENO_KERNELS=OFF`\n"); + return backend_ctx; + } +#endif + + // Populate backend device name + dev_ctx->platform_name = default_device->platform->name; + dev_ctx->device_name = default_device->name; + backend_ctx->device_name = default_device->name; + + // A local ref of cl_device_id for convenience + cl_device_id device = backend_ctx->device; + + // Check device OpenCL version, OpenCL 2.0 or above is required + size_t device_ver_str_size; + clGetDeviceInfo(device, CL_DEVICE_VERSION, 0, NULL, &device_ver_str_size); + char *device_ver_buffer = (char *)alloca(device_ver_str_size + 1); + clGetDeviceInfo(device, CL_DEVICE_VERSION, device_ver_str_size, device_ver_buffer, NULL); + device_ver_buffer[device_ver_str_size] = '\0'; + GGML_LOG_INFO("ggml_opencl: device OpenCL version: %s\n", device_ver_buffer); + + if (strstr(device_ver_buffer, "OpenCL 2") == NULL && + strstr(device_ver_buffer, "OpenCL 3") == NULL) { + GGML_LOG_ERROR("ggml_opencl: OpenCL 2.0 or above is required\n"); + return backend_ctx; + } + + // Check driver version + size_t driver_version_str_size; + clGetDeviceInfo(device, CL_DRIVER_VERSION, 0, NULL, &driver_version_str_size); + char *driver_version = (char *)alloca(driver_version_str_size + 1); + clGetDeviceInfo(device, CL_DRIVER_VERSION, driver_version_str_size, driver_version, NULL); + driver_version[driver_version_str_size] = '\0'; + GGML_LOG_INFO("ggml_opencl: OpenCL driver: %s\n", driver_version); + backend_ctx->driver_version = driver_version; + + int adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version); + bool has_vector_subgroup_broadcast = + adreno_cl_compiler_version >= 47 || adreno_cl_compiler_version == 17; + GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n", + has_vector_subgroup_broadcast ? "true" : "false"); + + size_t ext_str_size; + clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size); + char *ext_buffer = (char *)alloca(ext_str_size + 1); + clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL); + ext_buffer[ext_str_size] = '\0'; // ensure it is null terminated + // Check if ext_buffer contains cl_khr_fp16 + backend_ctx->fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; + GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", backend_ctx->fp16_support ? "true" : "false"); + + // fp16 is required + if (!backend_ctx->fp16_support) { + GGML_LOG_ERROR("ggml_opencl: device does not support FP16\n"); + return backend_ctx; + } + + // If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes + // optional in OpenCL 3.0 (cl_khr_subgroup is mandatory in OpenCL 2.x) + if (strstr(device_ver_buffer, "OpenCL 3") && + strstr(ext_buffer, "cl_khr_subgroups") == NULL && + strstr(ext_buffer, "cl_intel_subgroups") == NULL) { + GGML_LOG_ERROR("ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) " + "(note that subgroups is an optional feature in OpenCL 3.0)\n"); + return backend_ctx; + } + + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &backend_ctx->alignment, NULL)); + GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", backend_ctx->alignment); + + clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL); + GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024); + + // Check SVM. + cl_device_svm_capabilities svm_caps; + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_SVM_CAPABILITIES, sizeof(cl_device_svm_capabilities), &svm_caps, 0)); + GGML_LOG_INFO("ggml_opencl: SVM coarse grain buffer support: %s\n", + svm_caps & CL_DEVICE_SVM_COARSE_GRAIN_BUFFER ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM fine grain buffer support: %s\n", + svm_caps & CL_DEVICE_SVM_FINE_GRAIN_BUFFER ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM fine grain system support: %s\n", + svm_caps & CL_DEVICE_SVM_FINE_GRAIN_SYSTEM ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM atomics support: %s\n", + svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false"); + + // Print out configurations +#ifdef GGML_OPENCL_SOA_Q + GGML_LOG_INFO("ggml_opencl: flattening quantized weights representation as struct of arrays (GGML_OPENCL_SOA_Q)\n"); +#endif // GGML_OPENCL_SOA_Q + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n"); +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + cl_context_properties properties[] = { + (intptr_t)CL_CONTEXT_PLATFORM, (intptr_t)dev_ctx->platform, 0 + }; + + CL_CHECK((backend_ctx->context = clCreateContext(properties, 1, &device, NULL, NULL, &err), err)); + + // A local ref of cl_context for convenience + cl_context context = backend_ctx->context; + + //CL_CHECK((queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err), + // (err != CL_INVALID_QUEUE_PROPERTIES && err != CL_INVALID_VALUE ? err : + // (queue = clCreateCommandQueue(context, device, 0, &err), err) + //))); + cl_command_queue_properties command_queue_props = 0; +#ifdef GGML_OPENCL_PROFILING + command_queue_props |= CL_QUEUE_PROFILING_ENABLE; +#endif + CL_CHECK((backend_ctx->queue = clCreateCommandQueue(context, device, command_queue_props, &err), err)); + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "ggml-opencl.cl.h" + }; +#else + const std::string kernel_src = read_file("ggml-opencl.cl"); +#endif + + std::string compile_opts = + "-cl-std=CL2.0 -cl-mad-enable -cl-unsafe-math-optimizations " + "-cl-finite-math-only -cl-fast-relaxed-math "; + backend_ctx->program = build_program_from_source(context, device, kernel_src.c_str(), compile_opts); + + // Non matmul kernels. + CL_CHECK((backend_ctx->kernel_get_rows_f32 = clCreateKernel(backend_ctx->program, "kernel_get_rows_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_get_rows_f16 = clCreateKernel(backend_ctx->program, "kernel_get_rows_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_get_rows_q4_0 = clCreateKernel(backend_ctx->program, "kernel_get_rows_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_add = clCreateKernel(backend_ctx->program, "kernel_add", &err), err)); + CL_CHECK((backend_ctx->kernel_add_row = clCreateKernel(backend_ctx->program, "kernel_add_row", &err), err)); + CL_CHECK((backend_ctx->kernel_mul = clCreateKernel(backend_ctx->program, "kernel_mul", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_row = clCreateKernel(backend_ctx->program, "kernel_mul_row", &err), err)); + CL_CHECK((backend_ctx->kernel_scale = clCreateKernel(backend_ctx->program, "kernel_scale", &err), err)); + CL_CHECK((backend_ctx->kernel_silu = clCreateKernel(backend_ctx->program, "kernel_silu", &err), err)); + CL_CHECK((backend_ctx->kernel_silu_4 = clCreateKernel(backend_ctx->program, "kernel_silu_4", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program, "kernel_gelu", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program, "kernel_gelu_4", &err), err)); + CL_CHECK((backend_ctx->kernel_relu = clCreateKernel(backend_ctx->program, "kernel_relu", &err), err)); + CL_CHECK((backend_ctx->kernel_clamp = clCreateKernel(backend_ctx->program, "kernel_clamp", &err), err)); + CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program, "kernel_norm", &err), err)); + CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program, "kernel_rms_norm", &err), err)); + CL_CHECK((backend_ctx->kernel_diag_mask_inf = clCreateKernel(backend_ctx->program, "kernel_diag_mask_inf", &err), err)); + CL_CHECK((backend_ctx->kernel_diag_mask_inf_8 = clCreateKernel(backend_ctx->program, "kernel_diag_mask_inf_8", &err), err)); + CL_CHECK((backend_ctx->kernel_soft_max = clCreateKernel(backend_ctx->program, "kernel_soft_max", &err), err)); + CL_CHECK((backend_ctx->kernel_soft_max_4 = clCreateKernel(backend_ctx->program, "kernel_soft_max_4", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_norm_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_neox_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f32", &err), err)); + + // Matmul kernels. + CL_CHECK((backend_ctx->kernel_mul_mat_f32_f32 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f32_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f16 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_1row = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32_1row", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_l4 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32_l4", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_v = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_v", &err), err)); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_flat", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program, "kernel_convert_block_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program, "kernel_restore_block_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_8x_flat", &err), err)); + + // Load additional mulmat kernels. +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_1 { + #include "ggml-opencl_mm.cl.h" + }; +#else + const std::string kernel_src_1 = read_file("ggml-opencl_mm.cl"); +#endif + backend_ctx->program_1 = build_program_from_source(context, device, kernel_src_1.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_8x_flat = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_1d_8x_flat", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_16x_flat = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_1d_16x_flat", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32 = clCreateKernel(backend_ctx->program_1, "kernel_mul_mv_q6_K_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat_v0 = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_flat_v0", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat_img_v0 = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_flat_img_v0", &err), err)); + + // Load additional data conversion kernels. +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_2 { + #include "ggml-opencl_cvt.cl.h" + }; +#else + const std::string kernel_src_2 = read_file("ggml-opencl_cvt.cl"); +#endif + backend_ctx->program_2 = build_program_from_source(context, device, kernel_src_2.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_2, "kernel_convert_block_q4_0_noshuffle", &err), err)); + + // Kernels for Adreno +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string transpose_32_src { + #include "ggml-opencl_transpose_32.cl.h" + }; +#else + const std::string transpose_32_src = read_file("ggml-opencl_transpose_32.cl"); +#endif + backend_ctx->program_transpose_32 = build_program_from_source(context, device, transpose_32_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose_32, "kernel_transpose_32", &err), err)); + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string transpose_32_16_src { + #include "ggml-opencl_transpose_32_16.cl.h" + }; +#else + const std::string transpose_32_16_src = read_file("ggml-opencl_transpose_32_16.cl"); +#endif + backend_ctx->program_transpose_32_16 = build_program_from_source(context, device, transpose_32_16_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose_32_16, "kernel_transpose_32_16", &err), err)); + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string transpose_16_src { + #include "ggml-opencl_transpose_16.cl.h" + }; +#else + const std::string transpose_16_src = read_file("ggml-opencl_transpose_16.cl"); +#endif + backend_ctx->program_transpose_16 = build_program_from_source(context, device, transpose_16_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose_16, "kernel_transpose_16", &err), err)); + + // Gemv general + std::string CL_gemv_compile_opts = + " -cl-std=CL2.0 " + " -cl-mad-enable " + " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size); + if (has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemv_general { + #include "ggml-opencl_gemv_noshuffle_general.cl.h" + }; +#else + const std::string kernel_src_CL_gemv_general = read_file("ggml-opencl_gemv_noshuffle_general.cl"); +#endif + + backend_ctx->program_CL_gemv_general = build_program_from_source( + context, device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general = clCreateKernel(backend_ctx->program_CL_gemv_general, "kernel_gemv_noshuffle", &err), err)); + + // Gemv 2048, 16384 + CL_gemv_compile_opts = + " -cl-std=CL2.0 " + " -cl-mad-enable " + " -DLINE_STRIDE_A=2048 " + " -DBLOCK_STRIDE_A=16384 " + " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size); + if (has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemv { + #include "ggml-opencl_gemv_noshuffle.cl.h" + }; +#else + const std::string kernel_src_CL_gemv = read_file("ggml-opencl_gemv_noshuffle.cl"); +#endif + + backend_ctx->program_CL_gemv_4096_1_4096 = build_program_from_source( + context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_4096, "kernel_gemv_noshuffle", &err), err)); + + // Gemv 2048, 16384 + CL_gemv_compile_opts = + " -cl-std=CL2.0 " + " -cl-mad-enable " + " -DLINE_STRIDE_A=2048 " + " -DBLOCK_STRIDE_A=16384 " + " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size); + if (has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + backend_ctx->program_CL_gemv_4096_1_11008 = build_program_from_source( + context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_11008, "kernel_gemv_noshuffle", &err), err)); + + // Gemv 5504, 44032 + CL_gemv_compile_opts = + " -cl-std=CL2.0 " + " -cl-mad-enable " + " -DLINE_STRIDE_A=5504 " + " -DBLOCK_STRIDE_A=44032 " + " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size); + if (has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + backend_ctx->program_CL_gemv_11008_1_4096 = build_program_from_source( + context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_11008_1_4096, "kernel_gemv_noshuffle", &err), err)); + + // Gemv 16000, 128000 + CL_gemv_compile_opts = + " -cl-std=CL2.0 " + " -cl-mad-enable " + " -DLINE_STRIDE_A=16000 " + " -DBLOCK_STRIDE_A=128000 " + " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size); + if (has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + backend_ctx->program_CL_gemv_32000_1_4096 = build_program_from_source(context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_32000_1_4096, "kernel_gemv_noshuffle", &err), err)); + + // Gemm +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemm { + #include "ggml-opencl_mul_mat_Ab_Bi_8x4.cl.h" + }; +#else + const std::string kernel_src_CL_gemm = read_file("ggml-opencl_mul_mat_Ab_Bi_8x4.cl"); +#endif + backend_ctx->program_CL_gemm = build_program_from_source(context, device, kernel_src_CL_gemm.c_str(), compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err)); + + // Allocate intermediate buffers and images + size_t max_A_q_d_bytes = 311164928; + size_t max_A_s_d_bytes = 38895616; + size_t max_B_d_bytes = 45088768; + + CL_CHECK((backend_ctx->A_q_d_max = clCreateBuffer(context, 0, max_A_q_d_bytes, NULL, &err), err)); + CL_CHECK((backend_ctx->A_s_d_max = clCreateBuffer(context, 0, max_A_s_d_bytes, NULL, &err), err)); + CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err)); +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + // For now we support a single devices + ggml_backend_opencl_n_devices = 1; + + return backend_ctx; +} + +static void ggml_cl2_free(void) { +#ifdef GGML_OPENCL_PROFILING + FILE * fperf = fopen("cl_profiling.csv", "w"); + if (!fperf) { + GGML_LOG_ERROR("Failed to open cl_profiling.csv\n"); + return; + } + + float total_kernel_time = 0; + fprintf(fperf, "op name, kernel name, duration (ms), global size, local size, output size\n"); + for (const ProfilingInfo & info : g_profiling_info) { + total_kernel_time += info.duration_ns/1.e6f; + fprintf(fperf, "%s,%s,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n", + info.op_name.c_str(), info.kernel_name.c_str(), info.duration_ns/1.e6f, + info.global_size[0], info.global_size[1], info.global_size[2], + info.local_size[0], info.local_size[2], info.local_size[2], + info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]); + } + fclose(fperf); + + GGML_LOG_INFO("ggml_opencl: total kernel time: %f\n", total_kernel_time); +#endif +} + +//------------------------------------------------------------------------------ +// Tensor extra management +//------------------------------------------------------------------------------ +struct ggml_tensor_extra_cl { + // The buffer object that holds the data. + cl_mem data_device; + // The offset into the buffer object. This is primarily for scratch buffer + // and view operation. + // NB: this offset no longer includes view offset (view_offs). Whenever this + // offset is used, view_offs should be considered. + cl_ulong offset; + // The actual size of the cl_mem object. This is needed when returning the + // block to the pool. + size_t actual_size; + + void reset() { + data_device = nullptr; + offset = 0; + actual_size = 0; + } +}; + +// Additional tensor extra structs for quantized tensors. +// These tensors are loaded from files and should not be allocated in scratch -- +// they should always be allocated from the pool. Hence, they do not have an +// `offset`, which indicate their locations in the scratch buffer. +struct ggml_tensor_extra_cl_q4_0 { + // Quantized values. + cl_mem q = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem q_img = nullptr; + // Scales. + cl_mem d = nullptr; + // Scales in image1d_buffer_t. + cl_mem d_img = nullptr; + // Size of quantized values. + size_t size_q = 0; + // Size of scales. + size_t size_d = 0; + + ~ggml_tensor_extra_cl_q4_0() { + reset(); + } + + void reset() { + // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer. + // They must be properly released so that the original buffer can be + // properly released to avoid memory leak. + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q)); + q = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + // Currently, q_img and d_img are only initialized when SMALL_ALLOC is + // enabled. They point to the images in ggml_backend_opencl_buffer_context. + // So, there is no need to release them here. + // TODO: initialize them for non SMALL_PATH path, or remove them. + q_img = nullptr; + d_img = nullptr; + size_q = 0; + size_d = 0; + } +}; + +//------------------------------------------------------------------------------ +// Backend API +//------------------------------------------------------------------------------ + +// +// backend +// +static const char * ggml_backend_opencl_name(ggml_backend_t backend) { + return "OpenCL"; + + UNUSED(backend); +} + +static void ggml_backend_opencl_free(ggml_backend_t backend) { + ggml_cl2_free(); + + GGML_UNUSED(backend); +} + +static void ggml_backend_opencl_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_UNUSED(backend); + GGML_UNUSED(tensor); + GGML_UNUSED(data); + GGML_UNUSED(offset); + GGML_UNUSED(size); +} + +static void ggml_backend_opencl_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_UNUSED(backend); + GGML_UNUSED(tensor); + GGML_UNUSED(data); + GGML_UNUSED(offset); + GGML_UNUSED(size); +} + +static bool ggml_backend_opencl_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { + GGML_UNUSED(backend); + GGML_UNUSED(src); + GGML_UNUSED(dst); + return false; +} + +static void ggml_backend_opencl_synchronize(ggml_backend_t backend) { + GGML_UNUSED(backend); +} + +static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + + if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } + + bool ok = ggml_cl_compute_forward(backend, node); + if (!ok) { + GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + } + GGML_ASSERT(ok); + } + + return GGML_STATUS_SUCCESS; +} + +static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + GGML_UNUSED(dev); + + switch (op->op) { + case GGML_OP_NONE: + return true; + case GGML_OP_GET_ROWS: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + return true; + case GGML_TYPE_Q4_0: +#ifdef GGML_OPENCL_SOA_Q + // We do not support flattened Q4_0 (and possibly other Q's) + return false; +#else // GGML_OPENCL_SOA_Q + return true; +#endif // GGML_OPENCL_SOA_Q + default: + return false; + } + case GGML_OP_CPY: + case GGML_OP_DUP: + case GGML_OP_CONT: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + switch (op->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return true; + default: + return false; + } + case GGML_TYPE_F16: + switch (op->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return true; + default: + return false; + } + default: + return false; + } + case GGML_OP_ADD: + case GGML_OP_SCALE: + case GGML_OP_MUL: + return true; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_RELU: + return ggml_is_contiguous(op->src[0]); + default: + return false; + } + case GGML_OP_CLAMP: + case GGML_OP_SOFT_MAX: + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + return true; + case GGML_OP_MUL_MAT: + if (op->src[0]->type == GGML_TYPE_F16) { + return true; + } else if (op->src[0]->type == GGML_TYPE_F32) { + return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); + } else if (op->src[0]->type == GGML_TYPE_Q4_0 || + op->src[0]->type == GGML_TYPE_Q6_K) { + return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); + } + return false; + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + return true; + case GGML_OP_DIAG_MASK_INF: + return op->ne[3] == 1; + case GGML_OP_ROPE: + return true; + default: + return false; + } +} + +// Forward declaration - implementation appears later in the file. +static const char * ggml_backend_opencl_buffer_type_get_name(ggml_backend_buffer_type_t buffer_type); + +static ggml_guid_t ggml_backend_opencl_guid() { + static ggml_guid guid = { 0xde, 0xe0, 0x70, 0xa2, 0x73, 0x4e, 0x4d, 0xbc, 0xb0, 0xc7, 0x4f, 0xd4, 0x6d, 0x4e, 0x90, 0xfe }; + return &guid; +} + +static ggml_backend_i ggml_backend_opencl_i = { + /* .get_name = */ ggml_backend_opencl_name, + /* .free = */ ggml_backend_opencl_free, + /* .set_tensor_async = */ NULL, /* ggml_backend_opencl_set_tensor_async */ + /* .get_tensor_async = */ NULL, /* ggml_backend_opencl_get_tensor_async */ + /* .cpy_tensor_async = */ NULL, /* ggml_backend_opencl_cpy_tensor_async */ + /* .synchronize = */ NULL, /* ggml_backend_opencl_synchronize */ + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_opencl_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, +}; + +ggml_backend_t ggml_backend_opencl_init(void) { + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_opencl_reg(), 0); + ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev); + + ggml_backend_t backend = new ggml_backend { + /* .guid = */ ggml_backend_opencl_guid(), + /* .interface = */ ggml_backend_opencl_i, + /* .device = */ dev, + /* .context = */ backend_ctx + }; + + return backend; +} + +bool ggml_backend_is_opencl(ggml_backend_t backend) { + return backend && backend->iface.get_name == ggml_backend_opencl_name; +} + +// +// buffer +// +struct ggml_backend_opencl_buffer_context { + // A buffer context can hold multiple cl_mem objects. This is for flattening + // quantized weights and should be used with GGML_OPENCL_SMALL_ALLOC where + // each tensor is allocated a separate buffer. When flattening is enabled + // with small allocation, each tensor is backed by two cl_mem objects (for + // quants and scales) packed into a backend_opencl_buffer. + ggml_backend_opencl_buffer_context(cl_mem buf) + : name("OpenCL") { + buffer.push_back(buf); + } + + ~ggml_backend_opencl_buffer_context() { + for (cl_mem buf : buffer) { + CL_CHECK(clReleaseMemObject(buf)); + } + for (cl_mem im : img) { + CL_CHECK(clReleaseMemObject(im)); + } + + // Delete all extras to trigger their destructors + for (ggml_tensor_extra_cl * e : temp_tensor_extras) { + delete e; + } + for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) { + delete e; + } + for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0) { + delete e; + } + for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) { + delete e; + } + } + + ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() { + ggml_tensor_extra_cl * extra; + if (temp_tensor_extras.empty()) { + extra = new ggml_tensor_extra_cl(); + } else { + extra = temp_tensor_extras.back(); + temp_tensor_extras.pop_back(); + } + + temp_tensor_extras_in_use.push_back(extra); + + extra->reset(); + return extra; + } + + ggml_tensor_extra_cl_q4_0 * ggml_opencl_alloc_temp_tensor_extra_q4_0() { + ggml_tensor_extra_cl_q4_0 * extra; + if (temp_tensor_extras_q4_0.empty()) { + extra = new ggml_tensor_extra_cl_q4_0(); + } else { + extra = temp_tensor_extras_q4_0.back(); + temp_tensor_extras_q4_0.pop_back(); + } + + temp_tensor_extras_q4_0_in_use.push_back(extra); + + extra->reset(); + return extra; + } + + void reset() { + for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) { + temp_tensor_extras.push_back(e); + } + temp_tensor_extras_in_use.clear(); + + for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) { + temp_tensor_extras_q4_0.push_back(e); + } + temp_tensor_extras_q4_0_in_use.clear(); + } + + // Pools for extras. Available extras are in `temp_tensor_extras`. Extras + // being used are in `temp_tensor_extras_in_use`. At the first run, new + // extras get created and put in `in_use`. When the buffer is reset via + // the `reset` callback, all extras in `in_use` get moved to available extras + // for reuse. + std::vector temp_tensor_extras; + std::vector temp_tensor_extras_in_use; + std::vector temp_tensor_extras_q4_0; + std::vector temp_tensor_extras_q4_0_in_use; + + // The buffer_context is initially created by ggml_backend_buft_alloc_buffer + // before any tensor is initialized (at the beginning of alloc_tensor_range). + // Hence, there is alway a buffer object in this vector. When each tensor is + // being initialized, this original buffer object will be released if both + // flattening and small allocation are enabled, and additional buffer + // objects will be created in init_tensor to represent flattened quantized + // weights. + std::vector buffer; + // These are image1d_buffer_t objects that wrap around the quants and scales. + // For Q4_0 quantization, there should be two of them - one for quants and + // one for scales. They should be populated only when flattening and small + // allocation are enabled. + std::vector img; + std::string name; +}; + +static void * const cl_ptr_base = (void *)(uintptr_t) 0x1000; + +static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + delete ctx; +} + +static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) { + return cl_ptr_base; + + GGML_UNUSED(buffer); +} + +static void ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + + ggml_cl2_init(buffer->buft->device); + + if (tensor->view_src != nullptr) { + GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); + + ggml_tensor_extra_cl * view_extra = (ggml_tensor_extra_cl *) tensor->view_src->extra; + GGML_ASSERT(view_extra && "view_extra is nullptr?"); + + // Reuse extra of the parent tensor. The offset of this view tensor + // becomes `extra->offset + view_offs` and needs to be calculated when + // it is used. This changes is needed because of the change to + // ggml_alloc.c in https://github.com/ggerganov/llama.cpp/pull/7640. + // `buffer` passed in here will always be `tensor->buffer`. It is OK + // to allocate extras from the same buffer context for ordinary + // intermediate tensors. But for views into kv cache tensors, doing so + // would mess up the extras used by kv cache. + // Before #7640, `buffer` is for intermediate tensors, which is always + // different from that of kv cache tensors. + // + // NB: now extra->offset no longer accounts for view_offs. + // NB: this should not apply to weight tensors (for end-to-end runs, but + // may apply for test-backend-ops). + // FIXME: if any unexpected results are seen, double check the offset - + // there could be other places that need fix. + tensor->extra = view_extra; + } else { + { + size_t offset = (char *)tensor->data - (char *)cl_ptr_base; + + ggml_tensor_extra_cl * extra = ctx->ggml_opencl_alloc_temp_tensor_extra(); + extra->offset = offset; + extra->data_device = ctx->buffer[0]; + extra->actual_size = ggml_nbytes(tensor); + + tensor->extra = extra; + } + } +} + +// The optimized gemm and gemv kernels are used for large matrices without batch. +// tensor is the quantized weights matrix. +inline bool use_adreno_kernels(const ggml_tensor *tensor) { + return tensor->ne[0] >= 512 && tensor->ne[1] >= 512 && + tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); + + cl_context context = backend_ctx->context; + cl_command_queue queue = backend_ctx->queue; + +#ifdef GGML_OPENCL_SOA_Q + // We separate the quantized bits and scale from block_q4_0 by using an + // additional kernel, where each thread handles a block. We first read the + // original weights into a temporary buffer, then create two separate + // buffers for quantized bits and scales, which are then populated by the + // conversion kernel. + if (tensor->type == GGML_TYPE_Q4_0) { + // Tensors should have been preallocated, therefore they should + // already have ggml_tensor_extra_cl as extra. + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q4_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_0(); + + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + // We consider the specified offset arg as always, although For weights + // the offset arg should be 0 (we do not assert this). + //GGML_ASSERT(offset == 0); + + // We create subbuffers from the original tensor buffer for scales and + // quants - i.e., scales and quants are aliases into the buffer obejct + // that backs the original tensor. This is a cleaner way to adapt to the + // new memory management. + // In the old code, we allocate new buffers for scales and quants + // respectively, which could still be done but would result in double + // allocation; properly deallocating the preallocated buffer that backs + // the tensors is tricky and would leak the backend specific information + // into the general backend code. + // Does this create misaligned subbuffers (alignment is 1024) in certain + // cases ? + cl_buffer_region region; + + // The original tensor memory is divided into scales and quants, i.e., + // we first store scales, then quants. + // Create subbuffer for scales. + region.origin = extra_orig->offset + tensor->view_offs + offset; + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + // Create subbuffer for quants. + region.origin = extra_orig->offset + tensor->view_offs + offset + size_d; + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + //cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; + #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; + + // The optimized kernels need weights in natural order, so unshuffle. + if (use_adreno_kernels(tensor)) { + kernel = backend_ctx->kernel_convert_block_q4_0_noshuffle; + } + #else + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; + #endif // GGML_OPENCL_USE_ADRENO_KERNELS + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + + // transpose the weights and scales + #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Only do transpose for large, non batched matrix + // TODO: use preallocated images instead of sub-buffer then image + if (use_adreno_kernels(tensor)) { + // <----------------------------------------------------------------------------------> // + // start transpose + // <----------------------------------------------------------------------------------> // + int M = tensor->ne[1]; // ne01 + int K = tensor->ne[0]; // ne00 + + // transpose is out of place, so we need to allocate transposed buffers + // <----------------------------------------------------------------------------------> // + // use sub_buffer of max buffer size instead + + size_t q_size_bytes = K * M / 8 * sizeof(float); + cl_buffer_region region; + region.origin = 0; + region.size = q_size_bytes; + cl_mem qT_d = clCreateSubBuffer( + backend_ctx->A_q_d_max, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &err); + // cl_mem qT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, q_size_bytes, NULL, &err); + CL_CHECK(err); + + // size_t d_size_bytes = M * (K / 32) / 2 * sizeof(float); + size_t d_size_bytes = M * (K / 32) * 2; + region.origin = 0; + region.size = d_size_bytes; + cl_mem dT_d = clCreateSubBuffer( + backend_ctx->A_s_d_max, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &err); + // cl_mem dT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, d_size_bytes, NULL, &err); + CL_CHECK(err); + + // <----------------------------------------------------------------------------------> // + + + // create images from the buffers + // <----------------------------------------------------------------------------------> // + cl_mem q_d_image1D; + cl_mem d_d_image1D; + cl_mem qT_d_image1D; + cl_mem dT_d_image1D; + + cl_image_format img_fmt_1d = { CL_RGBA, CL_FLOAT }; + cl_image_desc img_desc_1d; + + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 8 / 4; + img_desc_1d.buffer = extra->q; + q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); + CL_CHECK(err); + + img_fmt_1d = { CL_RGBA, CL_FLOAT }; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 8 / 4; + img_desc_1d.buffer = qT_d; + qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); + CL_CHECK(err); + + img_fmt_1d = { CL_RGBA, CL_FLOAT }; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 32 / 4 / 2; + img_desc_1d.buffer = extra->d; + d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); + CL_CHECK(err); + + img_fmt_1d = { CL_RGBA, CL_FLOAT }; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 32 / 4 / 2; + img_desc_1d.buffer = dT_d; + dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); + CL_CHECK(err); + // <----------------------------------------------------------------------------------> // + + // set up and call the transpose kernels + // <----------------------------------------------------------------------------------> // + // weights + int height_q = M / 8; + int width_q = K / 8 / 4; + kernel = backend_ctx->kernel_transpose_16; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qT_d_image1D)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_q)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_q)); + + size_t local_size_q[3] = {4, 16, 1}; + size_t global_size_q[3] = {static_cast(width_q), static_cast(height_q), 1}; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_q, local_size_q, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + + // scales + int height_s = M / 8; + int width_s = K / 32 / 8; + + kernel = backend_ctx->kernel_transpose_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_s)); + + size_t local_size_s[3] = {4, 16, 1}; + size_t global_size_s[3] = {static_cast(width_s), static_cast(height_s), 1}; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_s, local_size_s, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + // <----------------------------------------------------------------------------------> // + + // copy transposed buffer contents to original buffers + // <----------------------------------------------------------------------------------> // + // weights + CL_CHECK(clEnqueueCopyBuffer(queue, qT_d, extra->q, 0, 0, q_size_bytes, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + + // scales + CL_CHECK(clEnqueueCopyBuffer(queue, dT_d, extra->d, 0, 0, d_size_bytes, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + // <----------------------------------------------------------------------------------> // + + // deallocate transpose buffers + // <----------------------------------------------------------------------------------> // + CL_CHECK(clReleaseMemObject(qT_d)); + CL_CHECK(clReleaseMemObject(dT_d)); + + // deallocate temporary images + CL_CHECK(clReleaseMemObject(q_d_image1D)); + CL_CHECK(clReleaseMemObject(d_d_image1D)); + CL_CHECK(clReleaseMemObject(qT_d_image1D)); + CL_CHECK(clReleaseMemObject(dT_d_image1D)); + // <----------------------------------------------------------------------------------> // + // end transpose + // <----------------------------------------------------------------------------------> // + } + #endif // GGML_OPENCL_USE_ADRENO_KERNELS + + return; + } +#endif // GGML_OPENCL_SOA_Q + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + GGML_ASSERT(extra); + + CL_CHECK(clEnqueueWriteBuffer( + queue, extra->data_device, CL_TRUE, extra->offset + offset, + size, data, 0, NULL, NULL)); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor->extra); + + ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); + + cl_context context = backend_ctx->context; + cl_command_queue queue = backend_ctx->queue; + + // Make sure all previously submitted commands are finished. + CL_CHECK(clFinish(queue)); + +#ifdef GGML_OPENCL_SOA_Q + // In end-to-end runs, get_tensor is usually used to get back the logits, + // where we can simply do clEnqueueReadBuffer since they are f32. + // However, in test-backend-ops, the GPU graph is copied to the CPU backend, + // which requires reading back quantized weight tensors. + // To properly support this, we need to restore block_q4_0 struct arrays + // from the flattened buffers. + if (tensor->type == GGML_TYPE_Q4_0) { + ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *)tensor->extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } +#endif // GGML_OPENCL_SOA_Q + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + + CL_CHECK(clEnqueueReadBuffer( + queue, extra->data_device, CL_TRUE, extra->offset + tensor->view_offs + offset, + size, data, 0, NULL, NULL)); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_opencl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_dev_t dev = buffer->buft->device; + ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev); + cl_command_queue queue = backend_ctx->queue; + + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + for (cl_mem buf : ctx->buffer) { + CL_CHECK(clEnqueueFillBuffer(queue, buf, &value, sizeof(value), 0, buffer->size, 0, NULL, NULL)); + } + CL_CHECK(clFinish(queue)); +} + +static void ggml_backend_opencl_buffer_reset(ggml_backend_buffer_t buffer) { + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ctx->reset(); +} + +static ggml_backend_buffer_i ggml_backend_opencl_buffer_interface = { + /* .free_buffer = */ ggml_backend_opencl_buffer_free_buffer, + /* .get_base = */ ggml_backend_opencl_buffer_get_base, + /* .init_tensor = */ ggml_backend_opencl_buffer_init_tensor, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ ggml_backend_opencl_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_opencl_buffer_get_tensor, + /* .cpy_tensor = */ NULL, + /* .clear = */ ggml_backend_opencl_buffer_clear, + /* .reset = */ ggml_backend_opencl_buffer_reset, +}; + +// +// buffer type +// + +static const char * ggml_backend_opencl_buffer_type_get_name(ggml_backend_buffer_type_t buffer_type) { + return "OpenCL"; + + GGML_UNUSED(buffer_type); +} + +static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buffer_type, size_t size) { + ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer_type->device); + + // clCreateBuffer returns -61 for size 0 + size = std::max(size, (size_t)1); + + cl_int err; + cl_mem mem = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, size, NULL, &err); + if (err != CL_SUCCESS) { + GGML_LOG_INFO("%s: failed to allocate %.2f MiB\n", __func__, size / 1024.0 / 1024.0); + return nullptr; + } + + ggml_backend_opencl_buffer_context * ctx = new ggml_backend_opencl_buffer_context(mem); + + return ggml_backend_buffer_init(buffer_type, ggml_backend_opencl_buffer_interface, ctx, size); +} + +static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) { + // FIXME: not thread safe, device may not be initialized yet + static cl_uint alignment = -1; + if (alignment == (cl_uint)-1) { + ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device); + alignment = backend_ctx->alignment; + } + return alignment; +} + +static size_t ggml_backend_opencl_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { + static size_t max_size = -1; + if (max_size == (size_t)-1) { + ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device); + max_size = backend_ctx->max_alloc_size; + } + return max_size; +} + +static bool ggml_backend_opencl_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { + return ggml_backend_is_opencl(backend); + + UNUSED(buft); +} + +static ggml_backend_buffer_type_i ggml_backend_opencl_buffer_type_interface = { + /* .get_name = */ ggml_backend_opencl_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_opencl_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_opencl_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_opencl_buffer_type_get_max_size, + /* .get_alloc_size = */ NULL, + /* .is_host = */ NULL, +}; + +ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type() { + static ggml_backend_buffer_type buffer_type = { + /* .iface = */ ggml_backend_opencl_buffer_type_interface, + /* .device = */ &g_ggml_backend_opencl_device, + /* .context = */ nullptr, + }; + + return &buffer_type; +} + +// +// backend device +// + +static const char * ggml_backend_opencl_device_get_name(ggml_backend_dev_t dev) { + return "GPUOpenCL"; + + GGML_UNUSED(dev); +} + +static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_opencl_device_context *dev_ctx = (ggml_backend_opencl_device_context *) dev->context; + return dev_ctx->device_name.c_str(); +} + +static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + *free = 1; + *total = 1; + + GGML_UNUSED(dev); +} + +static enum ggml_backend_dev_type ggml_backend_opencl_device_get_type(ggml_backend_dev_t dev) { + return GGML_BACKEND_DEVICE_TYPE_GPU; + + GGML_UNUSED(dev); +} + +static void ggml_backend_opencl_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_opencl_device_get_name(dev); + props->description = ggml_backend_opencl_device_get_description(dev); + props->type = ggml_backend_opencl_device_get_type(dev); + ggml_backend_opencl_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = ggml_backend_dev_caps { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; +} + +static ggml_backend_t ggml_backend_opencl_device_init(ggml_backend_dev_t dev, const char * params) { + ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(dev); + + ggml_backend_t backend = new ggml_backend { + /* .guid = */ ggml_backend_opencl_guid(), + /* .interface = */ ggml_backend_opencl_i, + /* .device = */ dev, + /* .context = */ backend_ctx, + }; + + return backend; + + GGML_UNUSED(params); +} + +static ggml_backend_buffer_type_t ggml_backend_opencl_device_get_buffer_type(ggml_backend_dev_t dev) { + return ggml_backend_opencl_buffer_type(); + + GGML_UNUSED(dev); +} + +static ggml_backend_buffer_t ggml_backend_opencl_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + GGML_UNUSED(dev); + GGML_UNUSED(ptr); + GGML_UNUSED(size); + GGML_UNUSED(max_tensor_size); + return nullptr; +} + +static bool ggml_backend_opencl_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + return ggml_opencl_supports_op(dev, op); +} + +static bool ggml_backend_opencl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_opencl_buffer_type_get_name; + + GGML_UNUSED(dev); +} + +static struct ggml_backend_device_i ggml_backend_opencl_device_i = { + /* .get_name = */ ggml_backend_opencl_device_get_name, + /* .get_description = */ ggml_backend_opencl_device_get_description, + /* .get_memory = */ ggml_backend_opencl_device_get_memory, + /* .get_type = */ ggml_backend_opencl_device_get_type, + /* .get_props = */ ggml_backend_opencl_device_get_props, + /* .init_backend = */ ggml_backend_opencl_device_init, + /* .get_buffer_type = */ ggml_backend_opencl_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_opencl_device_buffer_from_ptr, + /* .supports_op = */ ggml_backend_opencl_device_supports_op, + /* .supports_buft = */ ggml_backend_opencl_device_supports_buft, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +// Backend registry + +static const char * ggml_backend_opencl_reg_get_name(ggml_backend_reg_t reg) { + return "OpenCL"; + + GGML_UNUSED(reg); +} + +static size_t ggml_backend_opencl_reg_device_count(ggml_backend_reg_t reg) { + return ggml_backend_opencl_n_devices; + + GGML_UNUSED(reg); +} + +static ggml_backend_dev_t ggml_backend_opencl_reg_device_get(ggml_backend_reg_t reg, size_t index) { + GGML_ASSERT(index == 0); + + return &g_ggml_backend_opencl_device; + + GGML_UNUSED(reg); + GGML_UNUSED(index); +} + +static struct ggml_backend_reg_i ggml_backend_opencl_reg_i = { + /* .get_name = */ ggml_backend_opencl_reg_get_name, + /* .device_count = */ ggml_backend_opencl_reg_device_count, + /* .device_get = */ ggml_backend_opencl_reg_device_get, + /* .get_proc_address = */ NULL, +}; + +ggml_backend_reg_t ggml_backend_opencl_reg(void) { + // TODO: make this thread-safe somehow? + static ggml_backend_reg reg; + static bool initialized = false; + + if (!initialized) { + reg = ggml_backend_reg { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_opencl_reg_i, + /* .context = */ NULL, + }; + + g_ggml_backend_opencl_device = ggml_backend_device { + /* .iface = */ ggml_backend_opencl_device_i, + /* .reg = */ ®, + /* .context = */ &g_ggml_ctx_dev_main, + }; + + ggml_cl2_init(&g_ggml_backend_opencl_device); + + initialized = true; + } + + return ® +} + +GGML_BACKEND_DL_IMPL(ggml_backend_opencl_reg) + +//------------------------------------------------------------------------------ +// Debugging utils +//------------------------------------------------------------------------------ +#if 0 +#define QK4_0 32 +typedef struct { + ggml_fp16_t d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, + "wrong q4_0 block size/padding"); + +#include +#ifdef __cplusplus +#include "half.hpp" +#endif + +static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tensor) { + void * buf = malloc(ggml_nbytes(tensor)); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; +#ifdef GGML_OPENCL_SOA_Q + void * buf_q; + void * buf_d; +#endif + +#ifdef GGML_USE_OPENCL + // Make sure everything is done. + CL_CHECK(clFinish(queue)); + +#ifdef GGML_OPENCL_SOA_Q + if (tensor->type == GGML_TYPE_Q4_0) { + ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *) tensor->extra; + GGML_ASSERT(extra); + + size_t size_q = ggml_nelements(tensor)/QK4_0 * QK4_0/2; + size_t size_d = ggml_nelements(tensor)/QK4_0 * sizeof(ggml_fp16_t); + GGML_ASSERT(size_q + size_d == ggml_nbytes(tensor)); + buf_q = malloc(size_q); + buf_d = malloc(size_d); + + CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_d, buf_d, 0, NULL, NULL)); + CL_CHECK(clFinish(queue)); + } else { + // Read out the tensor from GPU memory. + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + GGML_ASSERT(extra); + + CL_CHECK(clEnqueueReadBuffer(queue, extra->data_device, CL_TRUE, + extra->offset, ggml_nbytes(tensor), buf, 0, NULL, NULL)); + CL_CHECK(clFinish(queue)); + } +#else + // Read out the tensor from GPU memory. + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + GGML_ASSERT(extra); + + CL_CHECK(clEnqueueReadBuffer(queue, extra->data_device, CL_TRUE, + extra->offset, ggml_nbytes(tensor), buf, 0, NULL, NULL)); + CL_CHECK(clFinish(queue)); +#endif // GGML_OPENCL_SOA_Q +#endif // GGML_USE_OPENCL + + // Open file and dump. + char fname[512]; + sprintf(fname, "./tensor-dumps/%s.txt", tensor->name); + FILE * f = fopen(fname, "w"); + if (!f) { + printf("Failed to open %s\n", fname); + return; + } + + if (tensor->type == GGML_TYPE_F32) { + float * data = (float *) buf; + for (int i = 0; i < ggml_nelements(tensor); ++i) { + if (isnan(data[i])) { + printf("NaN found: %s\n", tensor->name); + break; + } + fprintf(f, "%f\n", data[i]); + } + } else if (tensor->type == GGML_TYPE_I32) { + int * data = (int *) buf; + for (int i = 0; i < ggml_nelements(tensor); ++i) { + if (isnan(data[i])) { + printf("NaN found: %s\n", tensor->name); + break; + } + fprintf(f, "%d\n", data[i]); + } + } else if (tensor->type == GGML_TYPE_F16) { +#ifdef __cplusplus + half_float::half * data = (half_float::half *) buf; + for (int i = 0; i < ggml_nelements(tensor); ++i) { + if (std::isnan(data[i])) { + printf("NaN found: %s\n", tensor->name); + break; + } + fprintf(f, "%f\n", float(data[i])); + } +#endif + } else if (tensor->type == GGML_TYPE_Q4_0) { +#ifdef GGML_OPENCL_SOA_Q + ggml_fp16_t * data_d = (ggml_fp16_t *)buf_d; + unsigned char * data_q = (unsigned char *)buf_q; + + for (int i = 0; i < ggml_nelements(tensor)/QK4_0; ++i) { + fprintf(f, "%04x, ", data_d[i]); + for (int k = 0; k < QK4_0/2; ++k) { + fprintf(f, "%02x, ", data_q[k]); + } + fprintf(f, "\n"); + data_q += QK4_0/2; + } + free(buf_d); + free(buf_q); +#else + block_q4_0 * data = (block_q4_0 *) buf; + for (int i = 0; i < ggml_nelements(tensor)/QK4_0; ++i) { + fprintf(f, "%04x, ", data[i].d); + for (int k = 0; k < QK4_0/2; ++k) { + fprintf(f, "%02x, ", data[i].qs[k]); + } + fprintf(f, "\n"); + } +#endif // GGML_OPENCL_SOA_Q + } + free(buf); + fflush(f); + fclose(f); +} +#else +#define dump_tensor(tensor) +#endif + +//------------------------------------------------------------------------------ +// Profiling utility +//------------------------------------------------------------------------------ +#ifdef GGML_OPENCL_PROFILING +void populateProfilingInfo( + ProfilingInfo& info, cl_event evt, cl_kernel kernel, + size_t global_size[3], size_t local_size[3], + const ggml_tensor * tensor) { + cl_ulong start; + cl_ulong end; + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clGetEventProfilingInfo( + evt, CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &start, NULL)); + CL_CHECK(clGetEventProfilingInfo( + evt, CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &end, NULL)); + + char kernel_name[512]; + CL_CHECK(clGetKernelInfo(kernel, CL_KERNEL_FUNCTION_NAME, + sizeof(kernel_name), kernel_name, NULL)); + + info.duration_ns = end - start; + info.op_name = tensor->name; + info.kernel_name = kernel_name; + info.local_size[0] = local_size[0]; + info.local_size[1] = local_size[1]; + info.local_size[2] = local_size[2]; + info.global_size[0] = global_size[0]; + info.global_size[1] = global_size[1]; + info.global_size[2] = global_size[2]; + info.output_size[0] = tensor->ne[0]; + info.output_size[1] = tensor->ne[1]; + info.output_size[2] = tensor->ne[2]; + info.output_size[3] = tensor->ne[3]; +} +#endif + +//------------------------------------------------------------------------------ +// Ops +//------------------------------------------------------------------------------ + +static bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { + const int64_t ne10 = src1->ne[0]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + + // TODO: find the optimal values for these + return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && + src1->type == GGML_TYPE_F32 && + dst->type == GGML_TYPE_F32 && + (ne0 >= 32 && ne1 >= 32 && ne10 >= 32); +} + +static void ggml_cl_nop(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + UNUSED(backend); + UNUSED(src0); + UNUSED(src1); + UNUSED(dst); +} + +static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const int ne00 = src0 ? src0->ne[0] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const int ne10 = src1 ? src1->ne[0] : 0; + const cl_ulong nb10 = src1 ? src1->nb[0] : 0; + const int ne11 = src1 ? src1->ne[1] : 0; + const cl_ulong nb11 = src1 ? src1->nb[1] : 0; + const cl_ulong nb1 = dst ? dst->nb[1] : 0; + const cl_ulong nb2 = dst ? dst->nb[2] : 0; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + switch (src0->type) { + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_get_rows_f32; + break; + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_get_rows_f16; + break; + case GGML_TYPE_Q4_0: + kernel = backend_ctx->kernel_get_rows_q4_0; + break; + default: + GGML_ASSERT(false && "not implemented"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2)); + + size_t global_work_size[] = {(size_t)ne10, (size_t)ne11, 1}; + size_t local_work_size[] = {1, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + const cl_ulong nb00 = src0 ? src0->nb[0] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + + const int ne10 = src1 ? src1->ne[0] : 0; + const int ne11 = src1 ? src1->ne[1] : 0; + const int ne12 = src1 ? src1->ne[2] : 0; + const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13); + + const cl_ulong nb10 = src1 ? src1->nb[0] : 0; + const cl_ulong nb11 = src1 ? src1->nb[1] : 0; + const cl_ulong nb12 = src1 ? src1->nb[2] : 0; + const cl_ulong nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13); + + const int ne0 = dst ? dst->ne[0] : 0; + const int ne1 = dst ? dst->ne[1] : 0; + const int ne2 = dst ? dst->ne[2] : 0; + const int ne3 = dst ? dst->ne[3] : 0; + + const cl_ulong nb0 = dst ? dst->nb[0] : 0; + const cl_ulong nb1 = dst ? dst->nb[1] : 0; + const cl_ulong nb2 = dst ? dst->nb[2] : 0; + const cl_ulong nb3 = dst ? dst->nb[3] : 0; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + bool bcast_row = false; + cl_kernel kernel; + + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + bcast_row = true; + int ne = ne00 / 4; + kernel = backend_ctx->kernel_add_row; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); + } else { + kernel = backend_ctx->kernel_add; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3)); + } + + if (bcast_row) { + int n = ggml_nelements(dst)/4; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } else { + unsigned int nth = MIN(64, ne0); + size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } +} + +static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + const cl_ulong nb00 = src0 ? src0->nb[0] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + + const int ne10 = src1 ? src1->ne[0] : 0; + const int ne11 = src1 ? src1->ne[1] : 0; + const int ne12 = src1 ? src1->ne[2] : 0; + const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13); + + const cl_ulong nb10 = src1 ? src1->nb[0] : 0; + const cl_ulong nb11 = src1 ? src1->nb[1] : 0; + const cl_ulong nb12 = src1 ? src1->nb[2] : 0; + const cl_ulong nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13); + + const int ne0 = dst ? dst->ne[0] : 0; + const int ne1 = dst ? dst->ne[1] : 0; + const int ne2 = dst ? dst->ne[2] : 0; + const int ne3 = dst ? dst->ne[3] : 0; + + const cl_ulong nb0 = dst ? dst->nb[0] : 0; + const cl_ulong nb1 = dst ? dst->nb[1] : 0; + const cl_ulong nb2 = dst ? dst->nb[2] : 0; + const cl_ulong nb3 = dst ? dst->nb[3] : 0; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + bool bcast_row = false; + cl_kernel kernel; + + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + bcast_row = true; + int ne = ne00 / 4; + kernel = backend_ctx->kernel_mul_row; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); + } else { + kernel = backend_ctx->kernel_mul; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3)); + } + + if (bcast_row) { + int n = ggml_nelements(dst)/4; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } else { + unsigned int nth = MIN(64, ne0); + size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } +} + +static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + int n = ggml_nelements(dst); + + if (n % 4 == 0) { + kernel = backend_ctx->kernel_gelu_4; + n /= 4; + } else { + kernel = backend_ctx->kernel_gelu; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL); +#endif +} + +static void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + int n = ggml_nelements(dst); + + if (n % 4 == 0) { + kernel = backend_ctx->kernel_silu_4; + n /= 4; + } else { + kernel = backend_ctx->kernel_silu; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_relu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel = backend_ctx->kernel_relu; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + const int64_t n = ggml_nelements(dst); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_clamp(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + float min; + float max; + memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float)); + + cl_kernel kernel = backend_ctx->kernel_clamp; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &min)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &max)); + + const int64_t n = ggml_nelements(dst); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + const int ne00 = src0 ? src0->ne[0] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + const int nth = MIN(64, ne00); + + cl_kernel kernel = backend_ctx->kernel_norm; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(float)*nth, NULL)); + + const int64_t nrows = ggml_nrows(src0); + + size_t global_work_size[] = {(size_t)nrows*nth, 1, 1}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_backend_opencl_device_context * dev_ctx = + (ggml_backend_opencl_device_context *)backend->device->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + const int ne00 = src0 ? src0->ne[0] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + const int nth = MIN(64, ne00); + + const int64_t nrows = ggml_nrows(src0); + + size_t global_work_size[] = {(size_t)nrows*nth, 1, 1}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + cl_kernel kernel = backend_ctx->kernel_rms_norm; + + // Note, this kernel declares local memory in kernel args and the size + // depends on subgroup size. + // Retrieve subgroup size. + // Note, this requires OpenCL 2.1 and above + size_t sgs; + CL_CHECK(clGetKernelSubGroupInfo(kernel, dev_ctx->device, + CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE, + sizeof(local_work_size), local_work_size, + sizeof(size_t), &sgs, NULL)); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps)); + // This is local memory - the size depends on subgroup size. + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(float)*nth/sgs, NULL)); + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; + const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + +#ifdef GGML_OPENCL_SOA_Q + ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; +#endif + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + const cl_ulong nb00 = src0 ? src0->nb[0] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + + const int ne10 = src1 ? src1->ne[0] : 0; + const int ne11 = src1 ? src1->ne[1] : 0; + const int ne12 = src1 ? src1->ne[2] : 0; + const int ne13 = src1 ? src1->ne[3] : 0; + + const cl_ulong nb10 = src1 ? src1->nb[0] : 0; + const cl_ulong nb11 = src1 ? src1->nb[1] : 0; + const cl_ulong nb12 = src1 ? src1->nb[2] : 0; + const cl_ulong nb13 = src1 ? src1->nb[3] : 0; + + const int ne0 = dst ? dst->ne[0] : 0; + const int ne1 = dst ? dst->ne[1] : 0; + + int r2 = ne12/ne02; + int r3 = ne13/ne03; + + GGML_ASSERT(ne00 == ne10); + + int nth0 = 32; + int nth1 = 1; + int nrows = 1; + // The number of values produced by each subgroup + int ndst = 4; + + cl_kernel kernel; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_context context = backend_ctx->context; + + if (ne01 && ne1 && use_adreno_kernels(src0)) { + + // init CL objects + // <--------------------------------------------> // + cl_int status; + cl_image_format img_fmt_1d; + cl_image_desc img_desc_1d; + cl_buffer_region region; + cl_mem A_image1d = nullptr; + cl_mem B_image1d = nullptr; + cl_mem B_sub_buffer = nullptr; + cl_mem C_d = nullptr; + // for B transpose + cl_mem B_d = nullptr; + cl_mem B_d_input_image = nullptr; + // <--------------------------------------------> // + + // define matrix dimensions + // <--------------------------------------------> // + int M = ne01; + int N = ne1; + int K = ne00; + int padding; + // <--------------------------------------------> // + + // q4_0 x fp32 + if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) { + // TODO: remove duplicate definitions of image description + format -- move to top + + // create an image for A + // <--------------------------------------------> // + if (N == 1) { + img_fmt_1d = { CL_R, CL_UNSIGNED_INT32}; + } else { + img_fmt_1d = { CL_R, CL_FLOAT}; + } + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 2 / 4; // Divide by 4 for char -> float + img_desc_1d.buffer = extra0_q4_0->q; + A_image1d = clCreateImage( + context, + CL_MEM_READ_ONLY, + &img_fmt_1d, + &img_desc_1d, + NULL, + &status); + CL_CHECK(status); + // <--------------------------------------------> // + + + // create a sub_buffer for B + // <--------------------------------------------> // + region.origin = (extra1->offset); + region.size = K * N * sizeof(float); + B_sub_buffer = clCreateSubBuffer( + extra1->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // <--------------------------------------------> // + + // transpose activation for Skyler's gemm + if (N != 1) { + //how many extra elements beyond multiple of 8 + int extra_elements = N % 8; + + //how much padding to add + padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // Specify the starting offset (in bytes) + region.origin = 0; + // Specify the size of the sub-buffer (divide by 2 for FP16) + region.size = K * (N + padding) * sizeof(float)/2; + B_d = clCreateSubBuffer( + backend_ctx->B_d_max, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + + cl_image_format image_format_B_d_input = { CL_RGBA, CL_FLOAT }; + cl_image_desc image_desc_B_d_input = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(K * N / 4), + 0, 0, 0, 0, 0, 0, 0, { B_sub_buffer } + }; + B_d_input_image = clCreateImage( + context, + 0, + &image_format_B_d_input, + &image_desc_B_d_input, + NULL, + &status); + CL_CHECK(status); + + cl_image_format image_format_B_d_output = { CL_RGBA, CL_HALF_FLOAT }; //(CL_HALF_FLOAT for FP16) + cl_image_desc image_desc_B_d_output = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(K * (N + padding)/4), + 0, 0, 0, 0, 0, 0, 0, { B_d } + }; + B_image1d = clCreateImage( + context, + 0, + &image_format_B_d_output, + &image_desc_B_d_output, + NULL, + &status); + CL_CHECK(status); + + int height_B = N/4; + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &B_d_input_image)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &B_image1d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_size_t[2] = { 1, 16 }; + //WGS tuning + if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) { + local_size_t[0]=4; + local_size_t[1]=8; + } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) { + local_size_t[0]=2; + local_size_t[1]=8; + } else if(ne0 == 4096 && ne1 == 128 && ne10 == 11008) { + local_size_t[0]=1; + local_size_t[1]=8; + } else if(ne0 == 32000 && ne1 == 128 && ne10 == 4096) { + local_size_t[0]=2; + local_size_t[1]=8; + } + + size_t global_size_t[2] = { + static_cast(width_B), + static_cast(padded_height_B) + }; + + #ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_size_t, local_size_t, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_size_t, local_size_t, dst); + #else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_size_t, local_size_t, 0, NULL, NULL)); + #endif + } else { + // no need to transpose B in other cases + // create an image for B from sub_buffer + // <--------------------------------------------> // + img_fmt_1d = {CL_RGBA, CL_FLOAT}; + + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_width = K * N / 4; + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.buffer = B_sub_buffer; + B_image1d = clCreateImage( + context, + CL_MEM_READ_ONLY, + &img_fmt_1d, + &img_desc_1d, + NULL, + &status); + CL_CHECK(status); + // <--------------------------------------------> // + } + + // choose gemm or gemv kernel + // <--------------------------------------------> // + if (N == 1) { + kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general; + if (M == 4096 && K == 4096) { + kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096; + } else if (M == 4096 && K == 11008) { + kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008; + } else if (M == 11008 && K == 4096) { + kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096; + } else if (M == 32000 && K == 4096) { + kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096; + } + } else { + kernel = backend_ctx->CL_mul_mat_Ab_Bi_8x4; + } + // <--------------------------------------------> // + + // set kernel args + // <--------------------------------------------> // + cl_uint k_arg = 0; + + if (N == 1) { + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_image1d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extra1->offset)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extrad->offset)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r3)); + } else { + region.origin = extrad->offset; // Specify the starting offset (in bytes) + region.size = M * N * sizeof(float); // Specify the size of the sub-buffer + C_d = clCreateSubBuffer(extrad->data_device, CL_MEM_WRITE_ONLY, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + int padded_N = ne1 + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); //A_q_dextra0_q4_0->q + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); //A_s_d + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &B_image1d)); //B_d + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &C_d)); //C_d + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01)); //M + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &padded_N)); //N with padding + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); //K + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne1)); //N without padding + } + // <--------------------------------------------> // + + // choose workgroup size + // <--------------------------------------------> // + size_t global_work_size[3] = { + 64, static_cast((M+63)/64), static_cast((N+31)/32)}; + size_t local_work_size[3] = {64, 2, 4}; + + global_work_size[0] = (size_t)(ceil((float)ne1/8)); + global_work_size[1] = (size_t)(ne01/4); + global_work_size[2] = (size_t)(1); + + local_work_size[0] = (size_t)(1); //4x32 for FP32 + local_work_size[1] = (size_t)(128); + local_work_size[2] = (size_t)(1); + + //WGS tuning + if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) { + local_work_size[0] = 1; + local_work_size[1] = 128; + } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) { + local_work_size[0] = 2; + local_work_size[1] = 64; + } else if (ne0 == 4096 && ne1 == 128 && ne10 == 11008) { + local_work_size[0] = 2; + local_work_size[1] = 64; + } else if (ne0 == 32000 && ne1 == 128 && ne10 == 4096) { + local_work_size[0] = 2; + local_work_size[1] = 64; + } + + if (N == 1) { + local_work_size[0] = backend_ctx->adreno_wave_size; // localsize + local_work_size[1] = 4; // reduce factor + local_work_size[2] = 1; + + global_work_size[0] = M / 2; + global_work_size[1] = 4; // reduce factor + global_work_size[2] = 1; + } + // <--------------------------------------------> // + + // enqueue kernel with profiling + // <--------------------------------------------> // + #ifdef GGML_OPENCL_PROFILING + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); + // enqueue kernel without profiling + #else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + #endif + // <--------------------------------------------> // + + // deallocate sub buffers and images + // <--------------------------------------------> // + CL_CHECK(clReleaseMemObject(A_image1d)); + CL_CHECK(clReleaseMemObject(B_sub_buffer)); + CL_CHECK(clReleaseMemObject(B_image1d)); + + if (N != 1) { + CL_CHECK(clReleaseMemObject(B_d)); + CL_CHECK(clReleaseMemObject(B_d_input_image)); + CL_CHECK(clReleaseMemObject(C_d)); + } + // <--------------------------------------------> // + + return; + } + } // if (ne01 && ne1) +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + if (!ggml_is_transposed(src0) && + !ggml_is_transposed(src1) && + src1t == GGML_TYPE_F32 && + ne00%32 == 0 && + ne11 > 2) { +#ifdef GGML_OPENCL_SOA_Q + // Set up kernel. + switch(src0t) { + case GGML_TYPE_Q4_0: + // This should have been satisfied. + GGML_ASSERT(ne11 == ne1); + GGML_ASSERT(ne01 == ne0); + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32_1d_16x_flat; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32_1d_8x_flat; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); + break; + default: + break; + } + + // Launch kernel. + if (src0t == GGML_TYPE_Q4_0) { + size_t global_work_size[] = {(size_t)(ne01 + 7)/8*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + + if (backend_ctx->gpu_family == INTEL) { + // Set global size for Intel. It uses 16x output values. + global_work_size[0] = (size_t)(ne01 + 15)/16*nth0; + global_work_size[1] = (size_t)ne11*nth1; + global_work_size[2] = (size_t)ne12*ne13; + } + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + return; + } +#else // GGML_OPENCL_SOA_Q + // TODO: add block_q4_0 variant. +#endif // GGML_OPENCL_SOA_Q + } + + // use custom matrix x vector kernel + switch (src0t) { + case GGML_TYPE_F32: + //GGML_ASSERT(ne02 == ne12); + GGML_ASSERT(src1t == GGML_TYPE_F32); + kernel = backend_ctx->kernel_mul_mat_f32_f32; + nrows = 4; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 32; + nth1 = 1; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r3)); + break; + case GGML_TYPE_F16: + //GGML_ASSERT(ne02 == ne12); + if (backend_ctx->gpu_family == INTEL) { + nth0 = 32; + nth1 = 1; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + if (src1t == GGML_TYPE_F32) { + if (ne11 * ne12 < 4) { + kernel = backend_ctx->kernel_mul_mat_f16_f32_1row; + } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { + kernel = backend_ctx->kernel_mul_mat_f16_f32_l4; + nrows = ne11; + } else { + kernel = backend_ctx->kernel_mul_mat_f16_f32; + nrows = 4; + } + } else { + kernel = backend_ctx->kernel_mul_mat_f16_f16; + nrows = 4; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r3)); + break; + case GGML_TYPE_Q4_0: + // This should have been satisfied. + GGML_ASSERT(ne11 == ne1); + GGML_ASSERT(ne01 == ne0); + +#ifdef GGML_OPENCL_SOA_Q + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat; + ndst = 8; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat; + ndst =8; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#else // GGML_OPENCL_SOA_Q + if (backend_ctx->gpu_family == INTEL) { + // Use 1D local size. Each workgroup is a SIMD group. Each SIMD + // group produces N_DST (4 for Q4_0 kernel) values in the result. + // The number of workgroups on dim 0 (the leading dimension) is + // the nearest multiple of 4 that covers ne0 (equals ne01). + nth0 = 16; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32_v; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + kernel = backend_ctx->kernel_mul_mv_q6_K_f32; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 2; + nth1 = 16; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 2; + nth1 = 64; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); + break; + default: + GGML_ASSERT(false && "not implemented"); + } + + if (src0t == GGML_TYPE_Q4_0 || + src0t == GGML_TYPE_Q4_1 || + src0t == GGML_TYPE_Q8_0 || + src0t == GGML_TYPE_Q2_K) { + // Each SIMD group produces N_DST values in the result. Assuming each + // workgroup has N_SIMDGROUP SIMD groups, then each workgroup will + // produce N_DST*N_SIMDGROUP values in the result. Hence, the grid size + // (number of workgroups) will be a nearest multiple of + // N_DST*N_SIMDGROUP to cover the size of the dimension. Below, 4 is + // N_DST*N_SIMDGROUP (see the kernel for Q4_0 matmul). + size_t global_work_size[] = {(size_t)(ne01 + ndst-1)/ndst*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } else if (src0t == GGML_TYPE_Q4_K) { + GGML_ASSERT(false && "not implemented"); + } else if (src0t == GGML_TYPE_Q3_K) { + GGML_ASSERT(false && "not implemented"); + } else if (src0t == GGML_TYPE_Q5_K) { + GGML_ASSERT(false && "not implemented"); + } else if (src0t == GGML_TYPE_Q6_K) { + size_t global_work_size[] = {(size_t)(ne01+1)/2*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } else { + int64_t ny = (ne11 + nrows - 1)/nrows; + + size_t global_work_size[] = {(size_t)ne01*nth0, (size_t)ny*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } +} + +static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_UNUSED(src1); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + float scale; + memcpy(&scale, dst->op_params, sizeof(scale)); + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel = backend_ctx->kernel_scale; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &scale)); + + int n = ggml_nelements(dst)/4; + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + + // GGML_OP_CPY happens between src0 and src1. + // GGML_OP_DUP and GGML_OP_CONT happen between src0 and dst. + UNUSED(dst); + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + const cl_ulong nb00 = src0 ? src0->nb[0] : 0; + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + + const int ne10 = src1 ? src1->ne[0] : 0; + const int ne11 = src1 ? src1->ne[1] : 0; + const int ne12 = src1 ? src1->ne[2] : 0; + const int ne13 = src1 ? src1->ne[3] : 0; + + const cl_ulong nb10 = src1 ? src1->nb[0] : 0; + const cl_ulong nb11 = src1 ? src1->nb[1] : 0; + const cl_ulong nb12 = src1 ? src1->nb[2] : 0; + const cl_ulong nb13 = src1 ? src1->nb[3] : 0; + + const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; + const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + + cl_kernel kernel; + + switch (src0t) { + case GGML_TYPE_F32: + switch (src1t) { + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_cpy_f32_f16; + break; + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_cpy_f32_f32; + break; + default: + GGML_ASSERT(false && "not implemented"); + } + break; + case GGML_TYPE_F16: + switch (src1t) { + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_cpy_f16_f16; + break; + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_cpy_f16_f32; + break; + default: + GGML_ASSERT(false && "not implemented"); + } + break; + default: + GGML_ASSERT(false && "not implemented"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13)); + + const int nth = MIN(64, ne00); + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, src1); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_dup(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cl_cpy(backend, src0, dst, nullptr); + UNUSED(src1); +} + +static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + int n_past = ((int32_t *)(dst->op_params))[0]; + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + if (ne00%8 == 0) { + kernel = backend_ctx->kernel_diag_mask_inf_8; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &n_past)); + + size_t global_work_size[] = {(size_t)ne00*ne01*ne02/8, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } else { + kernel = backend_ctx->kernel_diag_mask_inf; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &n_past)); + + size_t global_work_size[] = {(size_t)ne00, (size_t)ne01, (size_t)ne02}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } +} + +static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + // Softmax can now fuse KQ mask and KQ scale, which used to be two additional + // ops before softmax. It now also fuses alibi if `max_bias > 0`. For llama, + // alibi is not used; however, for some other models, it is used. + // KQ_mask + if (src1) { + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + } + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0; + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + float scale, max_bias; + memcpy(&scale, dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, dst->op_params + 1, sizeof(float)); + + const int nrows_x = ggml_nrows(src0); + const int nrows_y = src0->ne[1]; + + const int n_head = nrows_x/nrows_y; + const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + // Local size must be wave size. Each workgroup is a wave, working on a row, + // where a row corresponds to leading dimension. + int nth = MIN(32, ne00); + + if (backend_ctx->gpu_family == INTEL) { + // This is the same as the initial value. + nth = MIN(32, ne00); + } + else if (backend_ctx->gpu_family == ADRENO) { + nth = 64; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + cl_kernel kernel; + + if (ne00%4 == 0) { + kernel = backend_ctx->kernel_soft_max_4; + } else { + kernel = backend_ctx->kernel_soft_max; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), extra1 ? &extra1->data_device : &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(float), &scale)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &max_bias)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &m0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &m1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &n_head_log2)); + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + ggml_tensor * src2 = dst->src[2]; + ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr; + + cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0; + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + const int nb00 = src0 ? src0->nb[0] : 0; + const int nb01 = src0 ? src0->nb[1] : 0; + const int nb02 = src0 ? src0->nb[2] : 0; + const int nb03 = src0 ? src0->nb[3] : 0; + + const int ne10 = src1 ? src1->ne[0] : 0; + const int ne11 = src1 ? src1->ne[1] : 0; UNUSED(ne11); + const int ne12 = src1 ? src1->ne[2] : 0; UNUSED(ne12); + const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13); + + const int ne0 = dst ? dst->ne[0] : 0; + const int ne1 = dst ? dst->ne[1] : 0; + const int ne2 = dst ? dst->ne[2] : 0; + const int ne3 = dst ? dst->ne[3] : 0; + + const int nb0 = dst ? dst->nb[0] : 0; + const int nb1 = dst ? dst->nb[1] : 0; + const int nb2 = dst ? dst->nb[2] : 0; + const int nb3 = dst ? dst->nb[3] : 0; + + GGML_ASSERT(ne10 == ne02); + + int nth = MIN(64, ne00); + + const int n_past = ((int *) dst->op_params)[0]; + const int n_dims = ((int *) dst->op_params)[1]; + const int mode = ((int *) dst->op_params)[2]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + + const bool is_neox = mode & 2; + + cl_kernel kernel; + + if (!is_neox) { + switch (src0->type) { + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_rope_norm_f32; + break; + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_rope_norm_f16; + break; + default: + GGML_ASSERT(false); + }; + } else { + switch (src0->type) { + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_rope_neox_f32; + break; + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_rope_neox_f16; + break; + default: + GGML_ASSERT(false); + }; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), extra2 ? &extra2->data_device : &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &ne3)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &nb3)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &n_past)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &n_dims)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &n_ctx_orig)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float), &freq_base)); + CL_CHECK(clSetKernelArg(kernel, 28, sizeof(float), &freq_scale)); + CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float), &ext_factor)); + CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float), &attn_factor)); + CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float), &beta_fast)); + CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &beta_slow)); + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +//------------------------------------------------------------------------------ +// Op offloading +//------------------------------------------------------------------------------ + +typedef void (*ggml_cl_func_t)(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor) { + ggml_cl_func_t func = nullptr; + + ggml_tensor * src0 = tensor->src[0]; + ggml_tensor * src1 = tensor->src[1]; + + const bool any_on_device = tensor->extra + || (src0 != nullptr && src0->extra) + || (src1 != nullptr && src1->extra); + + switch (tensor->op) { + case GGML_OP_GET_ROWS: + if (!any_on_device) { + return false; + } + func = ggml_cl_get_rows; + break; + case GGML_OP_CPY: + if (!any_on_device) { + return false; + } + func = ggml_cl_cpy; + break; + case GGML_OP_DUP: + case GGML_OP_CONT: + if (!any_on_device) { + return false; + } + func = ggml_cl_dup; + break; + case GGML_OP_ADD: + if (!any_on_device) { + return false; + } + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + func = ggml_cl_add; + break; + case GGML_OP_MUL: + if (!any_on_device) { + return false; + } + func = ggml_cl_mul; + break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_GELU: + if (!any_on_device) { + return false; + } + func = ggml_cl_gelu; + break; + case GGML_UNARY_OP_SILU: + if (!any_on_device) { + return false; + } + func = ggml_cl_silu; + break; + case GGML_UNARY_OP_RELU: + if (!any_on_device) { + return false; + } + func = ggml_cl_relu; + break; + default: + return false; + } break; + case GGML_OP_CLAMP: + if (!any_on_device) { + return false; + } + func = ggml_cl_clamp; + break; + case GGML_OP_NORM: + if (!any_on_device) { + return false; + } + func = ggml_cl_norm; + break; + case GGML_OP_RMS_NORM: + if (!any_on_device) { + return false; + } + func = ggml_cl_rms_norm; + break; + case GGML_OP_MUL_MAT: + if (!any_on_device && !ggml_cl_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) { + return false; + } + func = ggml_cl_mul_mat; + break; + case GGML_OP_SCALE: + if (!any_on_device) { + return false; + } + func = ggml_cl_scale; + break; + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + if (!any_on_device) { + return false; + } + func = ggml_cl_nop; + break; + case GGML_OP_DIAG_MASK_INF: + if (!any_on_device) { + return false; + } + func = ggml_cl_diag_mask_inf; + break; + case GGML_OP_SOFT_MAX: + if (!any_on_device) { + return false; + } + func = ggml_cl_soft_max; + break; + case GGML_OP_ROPE: + if (!any_on_device) { + return false; + } + func = ggml_cl_rope; + break; + default: + return false; + } + + func(backend, tensor->src[0], tensor->src[1], tensor); + return true; +} diff --git a/ggml/src/ggml-opencl/kernels/embed_kernel.py b/ggml/src/ggml-opencl/kernels/embed_kernel.py new file mode 100644 index 000000000..b5d1d7242 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/embed_kernel.py @@ -0,0 +1,26 @@ +# + +import sys +import logging +logger = logging.getLogger("opencl-embed-kernel") + + +def main(): + logging.basicConfig(level=logging.INFO) + + if len(sys.argv) != 3: + logger.info("Usage: python embed_kernel.py ") + sys.exit(1) + + ifile = open(sys.argv[1], "r") + ofile = open(sys.argv[2], "w") + + for i in ifile: + ofile.write('R"({})"\n'.format(i)) + + ifile.close() + ofile.close() + + +if __name__ == "__main__": + main() diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl new file mode 100644 index 000000000..d1cdf709b --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl @@ -0,0 +1,2683 @@ +#ifdef cl_khr_fp16 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#elif defined(cl_amd_fp16) +#pragma OPENCL EXTENSION cl_amd_fp16 : enable +#else +#error "Half precision floating point not supportedby OpenCL implementation on your device." +#endif + +#ifdef cl_khr_subgroups +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#elif defined(cl_intel_subgroups) +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#error "Subgroup not supported on your device." +#endif + +#ifdef cl_intel_required_subgroup_size +// Always use subgroup size of 32 on Intel. +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +// Always use subgroups size of 64 on Adreno. +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#else +// TODO: do not know how to choose subgroup size on other GPUs. +#error "Selecting subgroup size is not supported on your device." +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +//------------------------------------------------------------------------------ +// block_q4_1 +//------------------------------------------------------------------------------ +struct block_q4_1 +{ + half d; + half m; + uint8_t qs[QK4_1 / 2]; +}; + +//------------------------------------------------------------------------------ +// block_q5_0 +//------------------------------------------------------------------------------ +struct block_q5_0 +{ + half d; + uint32_t qh; + uint8_t qs[QK5_0 / 2]; +}; + +//------------------------------------------------------------------------------ +// block_q5_1 +//------------------------------------------------------------------------------ +struct block_q5_1 +{ + half d; + half m; + uint32_t qh; + uint8_t qs[QK5_1 / 2]; +}; + +//------------------------------------------------------------------------------ +// block_q8_0 +//------------------------------------------------------------------------------ +struct block_q8_0 +{ + half d; + int8_t qs[QK8_0]; +}; + +//------------------------------------------------------------------------------ +// block_q2_K +//------------------------------------------------------------------------------ +struct block_q2_K +{ + uint8_t scales[16]; + uint8_t qs[64]; + half d; + half dmin; +}; + +//------------------------------------------------------------------------------ +// block_q3_K +//------------------------------------------------------------------------------ +struct block_q3_K +{ + uint8_t hmask[32]; + uint8_t qs[64]; + uint8_t scales[12]; + half d; +}; + +//------------------------------------------------------------------------------ +// block_q4_K +//------------------------------------------------------------------------------ +struct block_q4_K +{ + half d; + half dmin; + uint8_t scales[12]; + uint8_t qs[128]; +}; + +//------------------------------------------------------------------------------ +// block_q5_K +//------------------------------------------------------------------------------ +struct block_q5_K +{ + half d; + half dmin; + uint8_t scales[12]; + uint8_t qh[32]; + uint8_t qs[128]; +}; + +//------------------------------------------------------------------------------ +// block_q6_K +//------------------------------------------------------------------------------ +struct block_q6_K +{ + uint8_t ql[128]; + uint8_t qh[64]; + int8_t scales[16]; + half d; +}; + +//------------------------------------------------------------------------------ +// dequantize_q4_0_f32, dequantize_q4_0_f16 +//------------------------------------------------------------------------------ +void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) { + global ushort * qs = ((global ushort *)xb + 1); + float d1 = il ? (xb->d / 16.h) : xb->d; + float d2 = d1 / 256.f; + float md = -8.h * xb->d; + ushort mask0 = il ? 0x00F0 : 0x000F; + ushort mask1 = mask0 << 8; + + reg->s0 = d1 * (qs[0] & mask0) + md; + reg->s1 = d2 * (qs[0] & mask1) + md; + + reg->s2 = d1 * (qs[1] & mask0) + md; + reg->s3 = d2 * (qs[1] & mask1) + md; + + reg->s4 = d1 * (qs[2] & mask0) + md; + reg->s5 = d2 * (qs[2] & mask1) + md; + + reg->s6 = d1 * (qs[3] & mask0) + md; + reg->s7 = d2 * (qs[3] & mask1) + md; + + reg->s8 = d1 * (qs[4] & mask0) + md; + reg->s9 = d2 * (qs[4] & mask1) + md; + + reg->sa = d1 * (qs[5] & mask0) + md; + reg->sb = d2 * (qs[5] & mask1) + md; + + reg->sc = d1 * (qs[6] & mask0) + md; + reg->sd = d2 * (qs[6] & mask1) + md; + + reg->se = d1 * (qs[7] & mask0) + md; + reg->sf = d2 * (qs[7] & mask1) + md; +} + +void dequantize_q4_0_f16(global struct block_q4_0 * xb, short il, half16 * reg) { + global ushort * qs = ((global ushort *)xb + 1); + half d1 = il ? (xb->d / 16.h) : xb->d; + half d2 = d1 / 256.h; + half md = -8.h * xb->d; + ushort mask0 = il ? 0x00F0 : 0x000F; + ushort mask1 = mask0 << 8; + + reg->s0 = d1 * (qs[0] & mask0) + md; + reg->s1 = d2 * (qs[0] & mask1) + md; + + reg->s2 = d1 * (qs[1] & mask0) + md; + reg->s3 = d2 * (qs[1] & mask1) + md; + + reg->s4 = d1 * (qs[2] & mask0) + md; + reg->s5 = d2 * (qs[2] & mask1) + md; + + reg->s6 = d1 * (qs[3] & mask0) + md; + reg->s7 = d2 * (qs[3] & mask1) + md; + + reg->s8 = d1 * (qs[4] & mask0) + md; + reg->s9 = d2 * (qs[4] & mask1) + md; + + reg->sa = d1 * (qs[5] & mask0) + md; + reg->sb = d2 * (qs[5] & mask1) + md; + + reg->sc = d1 * (qs[6] & mask0) + md; + reg->sd = d2 * (qs[6] & mask1) + md; + + reg->se = d1 * (qs[7] & mask0) + md; + reg->sf = d2 * (qs[7] & mask1) + md; +} + +//------------------------------------------------------------------------------ +// add +//------------------------------------------------------------------------------ + +// general-purpose kernel for addition of two tensors +// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 +// cons: not very efficient +kernel void kernel_add( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) + *((global float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_add_row( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float4 * dst, + ulong offsetd, + int ne +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float4*)((global char*)dst + offsetd); + + // This performs better than using %. + uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] + src1[idx1]; +} + +//------------------------------------------------------------------------------ +// mul +//------------------------------------------------------------------------------ +kernel void kernel_mul( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) * *((global float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_mul_row( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float4 * dst, + ulong offsetd, + int ne +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float4*)((global char*)dst + offsetd); + + // This performs better than using %. + uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] * src1[idx1]; +} + +//------------------------------------------------------------------------------ +// scale +//------------------------------------------------------------------------------ +kernel void kernel_scale( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd, + float scale +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + dst[get_global_id(0)] = src0[get_global_id(0)] * scale; +} + +//------------------------------------------------------------------------------ +// gelu +//------------------------------------------------------------------------------ +#define GELU_COEF_A 0.044715f +#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f + +kernel void kernel_gelu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + + dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + + dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +//------------------------------------------------------------------------------ +// silu +//------------------------------------------------------------------------------ +kernel void kernel_silu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x / (1.0f + exp(-x)); +} + +kernel void kernel_silu_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x / (1.0f + exp(-x)); +} + +//------------------------------------------------------------------------------ +// relu +//------------------------------------------------------------------------------ +kernel void kernel_relu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = fmax(0.0f, src0[get_global_id(0)]); +} + +//------------------------------------------------------------------------------ +// clamp +//------------------------------------------------------------------------------ +kernel void kernel_clamp( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + float min, + float max +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = src0[get_global_id(0)] < min ? + min : + (src0[get_global_id(0)] > max ? max : src0[get_global_id(0)]); +} + +//------------------------------------------------------------------------------ +// norm +//------------------------------------------------------------------------------ +kernel void kernel_norm( + global void * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + float eps, + local float * sum +) { + src0 = (global void*)((global char*)src0 + offset0); + dst = (global void*)((global char*)dst + offsetd); + + global float * x = (global float *) ((global char *) src0 + get_group_id(0)*nb01); + + // MEAN + // parallel sum + sum[get_local_id(0)] = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + sum[get_local_id(0)] += x[i00]; + } + // reduce + barrier(CLK_LOCAL_MEM_FENCE); + for (uint i = get_local_size(0)/2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + float mean = sum[0] / ne00; + + // recenter and VARIANCE + barrier(CLK_LOCAL_MEM_FENCE); + global float * y = dst + get_group_id(0)*ne00; + sum[get_local_id(0)] = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + y[i00] = x[i00] - mean; + sum[get_local_id(0)] += y[i00] * y[i00]; + } + + // reduce + barrier(CLK_LOCAL_MEM_FENCE); + for (uint i = get_local_size(0)/2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + float variance = sum[0] / ne00; + + float scale = 1.0f/sqrt(variance + eps); + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + y[i00] = y[i00] * scale; + } +} + +//------------------------------------------------------------------------------ +// rms_norm +//------------------------------------------------------------------------------ +// This kernel depends on subgroup size. +kernel void kernel_rms_norm( + global void * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + float eps, + local float * sum // Note, the size depends on number of subgroups +) { + src0 = (global void*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + global float4 * x = (global float4 *) ((global char *) src0 + get_group_id(0)*nb01); + global float * x_scalar = (global float *) x; + float4 sumf = 0; + float all_sum = 0; + + // parallel sum + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + sumf += x[i00] * x[i00]; + } + all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3; + all_sum = sub_group_reduce_add(all_sum); + if (get_sub_group_local_id() == 0) { + sum[get_sub_group_id()] = all_sum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + // broadcast + for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + } + if (get_local_id(0) == 0) { + for (int i = 4 * (ne00 / 4); i < ne00; i++) { + sum[0] += x_scalar[i]; + } + sum[0] /= ne00; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + const float mean = sum[0]; + const float scale = 1.0f/sqrt(mean + eps); + + global float4 * y = (global float4 *) (dst + get_group_id(0)*ne00); + global float * y_scalar = (global float *) y; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + y[i00] = x[i00] * scale; + } + if (get_local_id(0) == 0) { + for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) { + y_scalar[i00] = x_scalar[i00] * scale; + } + } +} + +//------------------------------------------------------------------------------ +// diag_mask_inf kernels +//------------------------------------------------------------------------------ +kernel void kernel_diag_mask_inf( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int n_past +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i02 = get_global_id(2); + int i01 = get_global_id(1); + int i00 = get_global_id(0); + + if (i00 > n_past + i01) { + dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; + } else { + dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; + } +} + +kernel void kernel_diag_mask_inf_8( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd, + int ne00, + int ne01, + int n_past +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + int i = 2*get_global_id(0); + + dst[i+0] = src0[i+0]; + dst[i+1] = src0[i+1]; + int i4 = 4*i; + int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; + int i01 = i4/(ne00); i4 -= i01*ne00; + int i00 = i4; + for (int k = 3; k >= 0; --k) { + if (i00 + 4 + k <= n_past + i01) { + break; + } + (&dst[i+1])[k] = -INFINITY; + if (i00 + k > n_past + i01) { + (&dst[i])[k] = -INFINITY; + } + } +} + +//------------------------------------------------------------------------------ +// softmax +//------------------------------------------------------------------------------ +kernel void kernel_soft_max( + global float * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0; + global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float lmax = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + float max = sub_group_reduce_max(lmax); + + // parallel sum + float lsum = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); + lsum += exp_psrc0; + // Remember the result of exp here. exp is expensive, so we really do not + // wish to compute it twice. + pdst[i00] = exp_psrc0; + } + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + pdst[i00] /= sum; + } +} + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max_4( + global float * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0; + global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float4 lmax4 = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); + + const float max = sub_group_reduce_max(lmax); + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + pdst4[i00] /= sum; + } +} + +//------------------------------------------------------------------------------ +// kernel_rope +//------------------------------------------------------------------------------ +float rope_yarn_ramp(float low, float high, int i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +float2 rope_yarn( + float theta_extrap, float freq_scale, float2 corr_dims, int i0, float ext_factor, float mscale +) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims.s0, corr_dims.s1, i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); + } + return (float2)(cos(theta) * mscale, sin(theta) * mscale); +} + +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base)); +} + +float2 rope_yarn_corr_dims( + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow +) { + // start and end correction dims + return (float2)( + max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))), + min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))) + ); +} + +kernel void kernel_rope_norm_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + float theta = theta_base * pow(freq_base, inv_ndims*i0); + + float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + float x0 = src[0]; + float x1 = src[1]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_norm_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + float theta = theta_base * pow(freq_base, inv_ndims*i0); + + float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + float x0 = src[0]; + float x1 = src[1]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_neox_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_neox_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +//------------------------------------------------------------------------------ +// cpy +//------------------------------------------------------------------------------ + +kernel void kernel_cpy_f16_f16( + global half * src0, + ulong offset0, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f16_f32( + global half * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + + src0 = (global half*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f16( + global float * src0, + ulong offset0, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} + +//------------------------------------------------------------------------------ +// get_rows +//------------------------------------------------------------------------------ +kernel void kernel_get_rows_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { + ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = + ((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +kernel void kernel_get_rows_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { + ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = + ((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +kernel void kernel_get_rows_q4_0( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + const int NL = 2; + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) { + float16 temp; + dequantize_q4_0_f32( + ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp); + *(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + } +} + +//------------------------------------------------------------------------------ +// mul_mat_f32_f32 +//------------------------------------------------------------------------------ +#define N_F32_F32 4 + +kernel void kernel_mul_mat_f32_f32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F32_F32; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global float * x = (global float *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global float4 * x4 = (global float4 *)x; + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + global float4 * y4 = (global float4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (float) x4[i].s0 * y4[i].s0; + sumf += (float) x4[i].s1 * y4[i].s1; + sumf += (float) x4[i].s2 * y4[i].s2; + sumf += (float) x4[i].s3 * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + +//------------------------------------------------------------------------------ +// mul_mat_f16_f16 +//------------------------------------------------------------------------------ +#define N_F16_F16 4 + +kernel void kernel_mul_mat_f16_f16( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3) +{ + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F16_F16; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half * x = (global half *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * y = (global half *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (half) x[i] * (half) y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global half4 * x4 = (global half4 *)x; + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * y = (global half *) (src1 + offset_src1); + global half4 * y4 = (global half4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (half) x4[i].s0 * y4[i].s0; + sumf += (half) x4[i].s1 * y4[i].s1; + sumf += (half) x4[i].s2 * y4[i].s2; + sumf += (half) x4[i].s3 * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (half) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + +//------------------------------------------------------------------------------ +// mul_mat_f16_f32_1row +//------------------------------------------------------------------------------ +kernel void kernel_mul_mat_f16_f32_1row( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * x = (global half *) (src0 + offset_src0); + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + if (ne00 < 128) { + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (float) x[i] * (float) y[i]; + } + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } else { + global half4 * x4 = (global half4 *) x; + global float4 * y4 = (global float4 *) y; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (float) x4[i].s0 * y4[i].s0; + sumf += (float) x4[i].s1 * y4[i].s1; + sumf += (float) x4[i].s2 * y4[i].s2; + sumf += (float) x4[i].s3 * y4[i].s3; + } + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + +} + +//------------------------------------------------------------------------------ +// mul_mat_f16_f32 +//------------------------------------------------------------------------------ +#define N_F16_F32 4 + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F16_F32; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half * x = (global half *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += convert_float(x[i]) * y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global half4 * x4 = (global half4 *)x; + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + global float4 * y4 = (global float4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += convert_float(x4[i].s0) * y4[i].s0; + sumf += convert_float(x4[i].s1) * y4[i].s1; + sumf += convert_float(x4[i].s2) * y4[i].s2; + sumf += convert_float(x4[i].s3) * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + +//------------------------------------------------------------------------------ +// mul_mat_f16_f32_l4 +//------------------------------------------------------------------------------ +// Assumes row size (ne00) is a multiple of 4 +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f32_l4( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int nrows = ne11; + int r0 = get_group_id(0); + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half4 * x4 = (global half4 *) (src0 + offset_src0); + + for (int r1 = 0; r1 < nrows; ++r1) { + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float4 * y4 = (global float4 *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += convert_float(x4[i].s0) * y4[i].s0; + sumf += convert_float(x4[i].s1) * y4[i].s1; + sumf += convert_float(x4[i].s2) * y4[i].s2; + sumf += convert_float(x4[i].s3) * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} + +//------------------------------------------------------------------------------ +// mul_vec_q_n_f32 +//------------------------------------------------------------------------------ +// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_4_0_dot_y( + global struct block_q4_0 * qb_curr, + float sumy, + private float * yl, + int il +) { + float d = qb_curr->d; + float2 acc = 0.f; + global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); + for (int i = 0; i < 8; i+=2) { + acc.s0 += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc.s1 += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); + } + return d * (sumy * -8.f + acc.s0 + acc.s1); +} + +#ifdef INTEL_GPU +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is essenatially the linear global + // id of a SIMD group in the grid. + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; // src1 vector cache + float sumf[N_DST]={0.f}; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_0 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + for (int i = 0; i < 8; i += 2) { + sumy += yb[i] + yb[i+1]; + yl[i+0] = yb[i+ 0]; + yl[i+1] = yb[i+ 1]/256.f; + sumy += yb[i+16] + yb[i+17]; + yl[i+8] = yb[i+16]/16.f; + yl[i+9] = yb[i+17]/4096.f; + } + + for (int row = 0; row < N_DST; row++) { + sumf[row] += block_q_4_0_dot_y(x+ib+row*nb, sumy, yl, il); + } + + // One thread in a SIMD group (i.e., subgroup) handles a half block, + // hence then entire SIMD group handles SIMDWIDTH/2 blocks. + // y points to the activation matrix (of type float). Therefore for + // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because + // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of + // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size. + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + // The above does not work for Adreno - it produces incorrect results for + // row = 1, 2, 3 and only row = 0 gives the correct result. + // If N_DST is changed, the below array must be initialized accordingly. + // This also seems to perform better on Intel. + float tot[N_DST] = { + sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]), + sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])}; + for (int row = 0; row < N_DST; ++row) { + if (get_sub_group_local_id() == 0 && first_row + row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row]; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} + +// +// This variant unrolls the loops and uses vector types instead of pointers. +// It improves performance on Adreno but not so much on Intel. +// +inline float block_q_4_0_dot_y_v( + global struct block_q4_0 * qb_curr, + float sumy, + float16 yl, + int il +) { + float d = qb_curr->d; + float acc = 0.f; + global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_v( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is essenatially the linear global + // id of a SIMD group in the grid. + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; // src1 vector cache + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_0 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q_4_0_dot_y_v(x+ib+0*nb, sumy, yl, il); + sumf.s1 += block_q_4_0_dot_y_v(x+ib+1*nb, sumy, yl, il); + sumf.s2 += block_q_4_0_dot_y_v(x+ib+2*nb, sumy, yl, il); + sumf.s3 += block_q_4_0_dot_y_v(x+ib+3*nb, sumy, yl, il); + + // One thread in a SIMD group (i.e., subgroup) handles a half block, + // hence then entire SIMD group handles SIMDWIDTH/2 blocks. + // y points to the activation matrix (of type float). Therefore for + // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because + // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of + // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size. + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + // The above does not work for Adreno - it produces incorrect results for + // row = 1, 2, 3 and only row = 0 gives the correct result. + // If N_DST is changed, the below array must be initialized accordingly. + // This also seems to perform better on Intel. + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_v( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_v(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} + +//------------------------------------------------------------------------------ +// kernel_convert_block_q4_0 +// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q4_0( + global struct block_q4_0 * src0, + global uchar * dst_q, + global half * dst_d +) { + global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + + for (int i = 0; i < QK4_0/2; ++i) { + q[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q4_0( + global uchar * src_q, + global half * src_d, + global struct block_q4_0 * dst +) { + global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + for (int i = 0; i < QK4_0/2; ++i) { + b->qs[i] = q[i]; + } +} + +//------------------------------------------------------------------------------ +// mul_vec_q_n_f32_flat +// +// This variation uses flat arrays (struct of arrays, SOA) representation for +// quant tensors. +//------------------------------------------------------------------------------ + +// This function requires the original shuffled weights. +// As a reminder, the original weights are shuffled so that (q[0], q[16]) are +// packed together in a byte, so are (q[1], q[17]) and so on. +inline float block_q_4_0_dot_y_flat( + global uchar * x, + global half * dh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + global ushort * qs = ((global ushort *)x + il/2); + float acc = 0.f; + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 32 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of + // a SIMD group in the grid. Each SIMD group produces N_DST values in the + // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. + // Currently with llama2 7B, im is always 0. + // TODO: how to handle im/gqa*(nb*ne0)? + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} + +// +// This variant outputs 8 values. +// +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 8 // each SIMD group works on 8 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 32 +#elif defined (ADRENO_GPU) +#define N_DST 8 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of + // a SIMD group in the grid. Each SIMD group produces N_DST values in the + // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. + // Currently with llama2 7B, im is always 0. + // TODO: how to handle im/gqa*(nb*ne0)? + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float8 sumf = 0.f; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float8 tot = (float8)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl new file mode 100644 index 000000000..e2024332f --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl @@ -0,0 +1,106 @@ +//------------------------------------------------------------------------------ +// This file is contains additional kernels for data conversion. +// These kernels are used when loading the model, so its performance is less +// important. +//------------------------------------------------------------------------------ +#ifdef cl_khr_fp16 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#elif defined(cl_amd_fp16) +#pragma OPENCL EXTENSION cl_amd_fp16 : enable +#else +#error "Half precision floating point not supportedby OpenCL implementation on your device." +#endif + +#ifdef cl_khr_subgroups +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#elif defined(cl_intel_subgroups) +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#error "Subgroup not supported on your device." +#endif + +#ifdef cl_intel_required_subgroup_size +// Always use subgroup size of 32 on Intel. +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +// Always use subgroups size of 64 on Adreno. +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#else +// TODO: do not know how to choose subgroup size on other GPUs. +#error "Selecting subgroup size is not supported on your device." +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +//------------------------------------------------------------------------------ +// mul_vec_q_n_f32_flat_noshuffle +// +// This variation uses flat arrays (struct of arrays, SOA) representation for +// quant tensors. It also uses non shuffled bit order for weights. +// +// The shuffled version is kept in the original file because moving it here +// seems to result in worse performance for adreno. +//------------------------------------------------------------------------------ + +kernel void kernel_convert_block_q4_0_noshuffle( + global struct block_q4_0 * src0, + global uchar * dst_q, + global half * dst_d +) { + global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + for (int i = 0; i < QK4_0/4; ++i) { + uchar x0 = b->qs[2*i + 0]; + uchar x1 = b->qs[2*i + 1]; + + q[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + q[i + QK4_0/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + +#ifdef ADRENO_GPU + // Workaround for adreno - must have the following printf statement for + // the kernel to work properly. Otherwise it produces incorrect result. + // convert_uchar above also seems necessary. + // Compare against a large number so that it does not print anything. + // get_sub_group_local_id() also works. + if (get_global_id(0) == 65536*4096) { + printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0)); + } +#endif + } +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl new file mode 100644 index 000000000..5e195411d --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl @@ -0,0 +1,265 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +// assume +#define QK4_0 32 +#define N_SIMDGROUP 4 + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + + +__attribute__((qcom_reqd_sub_group_size("full"))) +__kernel void kernel_gemv_noshuffle( + __read_only image1d_buffer_t src0_q, // quantized A + global half2 * src0_d, // A scales + __read_only image1d_buffer_t src1, // B + ulong offset1, // offset to B (0) + global float * dst, // C + ulong offsetd, // offset to C (0) + uint K, // K + int ne01, // M + int ne02, // 1 + int ne10, // K + int ne12, // 1 + int ne0, // M + int ne1, // N + int r2, // 1 + int r3) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + __private uint4 regA; + __private half2 regS; + __private float8 regB; + + __private float2 totalSum = (float2)(0.0f); + + // loop along K in block granularity, skip 4 blocks every iter + for (uint k = groupId; k < (K / QK4_0); k += N_SIMDGROUP) { + regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows + // first 4 fibers in each wave load 8 B values to its private scope + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + // load half weights for two blocks in consecutive rows + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAT + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAT + } + + // reduction in local memory, assumes #wave=4 + __local float2 reduceLM[SIMDGROUP_WIDTH * 3]; + if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum; + if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum; + if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum; + barrier(CLK_LOCAL_MEM_FENCE); + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl new file mode 100644 index 000000000..5bdd4d067 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl @@ -0,0 +1,271 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +// assume +#define QK4_0 32 +#define N_SIMDGROUP 4 + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \ + + +__attribute__((qcom_reqd_sub_group_size("full"))) +__kernel void kernel_gemv_noshuffle( + __read_only image1d_buffer_t src0_q, // quantized A + global half2 * src0_d, // A scales + __read_only image1d_buffer_t src1, // B + ulong offset1, // offset to B (0) + global float * dst, // C + ulong offsetd, // offset to C (0) + int ne00, // K + int ne01, // M + int ne02, // 1 + int ne10, // K + int ne12, // 1 + int ne0, // M + int ne1, // N + int r2, // 1 + int r3) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = N_SIMDGROUP * M; + + __private uint4 regA; + __private half2 regS; + __private float8 regB; + + __private float2 totalSum = (float2)(0.0f); + + // loop along K in block granularity, skip 4 blocks every iter + for (uint k = groupId; k < (K / QK4_0); k += N_SIMDGROUP) { + regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows + // first 4 fibers in each wave load 8 B values to its private scope + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + // load half weights for two blocks in consecutive rows + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAT + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAT + } + + // reduction in local memory, assumes #wave=4 + __local float2 reduceLM[SIMDGROUP_WIDTH * 3]; + if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum; + if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum; + if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum; + barrier(CLK_LOCAL_MEM_FENCE); + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl new file mode 100644 index 000000000..e19e9a2f4 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl @@ -0,0 +1,1225 @@ +//------------------------------------------------------------------------------ +// This file is contains additional mulmat kernels +// (and potentially other kernels). +//------------------------------------------------------------------------------ +#ifdef cl_khr_fp16 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#elif defined(cl_amd_fp16) +#pragma OPENCL EXTENSION cl_amd_fp16 : enable +#else +#error "Half precision floating point not supportedby OpenCL implementation on your device." +#endif + +#ifdef cl_khr_subgroups +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#elif defined(cl_intel_subgroups) +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#error "Subgroup not supported on your device." +#endif + +#ifdef cl_intel_required_subgroup_size +// Always use subgroup size of 32 on Intel. +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +// Always use subgroups size of 64 on Adreno. +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#else +// TODO: do not know how to choose subgroup size on other GPUs. +#error "Selecting subgroup size is not supported on your device." +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +//------------------------------------------------------------------------------ +// block_q6_K +//------------------------------------------------------------------------------ +// 6-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elements each +// Effectively 6.5625 bits per weight +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + half d; // super-block scale +} block_q6_K; + +//------------------------------------------------------------------------------ +// These are the variant for matmatmul, based on the matvecmul kernel with +// flattened block_q4_0. +//------------------------------------------------------------------------------ + +// Common dot prod. +inline float mm_block_q_4_0_dot_y_flat( + global uchar * x, + global half * dh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + global ushort * qs = ((global ushort *)x + il/2); + float acc = 0.f; + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 8 // each SIMD group works on 8 rows (in weights matrix) +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 8 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif +// +// This variant performs 1d blocking with 8x output. +// Eeach simdgroup outputs 8 values on `n0` dim (row in the output matrix). +// +inline void mul_mat_q_n_f32_1d_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const int nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of + // a SIMD group in the grid. Each SIMD group produces N_DST values in the + // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. + // Currently with llama2 7B, im is always 0. + // TODO: how to handle im/gqa*(nb*ne0)? + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float8 sumf = (float8)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float8 tot = (float8)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_1d_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_mat_q_n_f32_1d_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 16 // each SIMD group works on 8 rows (in weights matrix) +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 16 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif +// +// This variant performs 1d blocking with 16x output. +// Eeach simdgroup outputs 16 values on `n0` dim (row in the output matrix). +// +inline void mul_mat_q_n_f32_1d_16x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const int nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of + // a SIMD group in the grid. Each SIMD group produces N_DST values in the + // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. + // Currently with llama2 7B, im is always 0. + // TODO: how to handle im/gqa*(nb*ne0)? + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float16 sumf = (float16)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + sumf.s8 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 8*nb*QK4_0/2, d + ib + 8*nb, sumy, yl, il); + sumf.s9 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 9*nb*QK4_0/2, d + ib + 9*nb, sumy, yl, il); + sumf.sa += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 10*nb*QK4_0/2, d + ib + 10*nb, sumy, yl, il); + sumf.sb += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 11*nb*QK4_0/2, d + ib + 11*nb, sumy, yl, il); + + sumf.sc += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 12*nb*QK4_0/2, d + ib + 12*nb, sumy, yl, il); + sumf.sd += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 13*nb*QK4_0/2, d + ib + 13*nb, sumy, yl, il); + sumf.se += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 14*nb*QK4_0/2, d + ib + 14*nb, sumy, yl, il); + sumf.sf += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 15*nb*QK4_0/2, d + ib + 15*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float16 tot = (float16)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7), + + sub_group_reduce_add(sumf.s8), sub_group_reduce_add(sumf.s9), + sub_group_reduce_add(sumf.sa), sub_group_reduce_add(sumf.sb), + sub_group_reduce_add(sumf.sc), sub_group_reduce_add(sumf.sd), + sub_group_reduce_add(sumf.se), sub_group_reduce_add(sumf.sf) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + + if (first_row + 8 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 8] = tot.s8; + } + if (first_row + 9 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 9] = tot.s9; + } + if (first_row + 10 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 10] = tot.sa; + } + if (first_row + 11 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 11] = tot.sb; + } + + if (first_row + 12 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 12] = tot.sc; + } + if (first_row + 13 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 13] = tot.sd; + } + if (first_row + 14 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 14] = tot.se; + } + if (first_row + 15 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 15] = tot.sf; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_1d_16x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_mat_q_n_f32_1d_16x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} + +//------------------------------------------------------------------------------ +// kernel_mul_mat_q4_0_f32_flat_v0 +//------------------------------------------------------------------------------ +inline float block_q_4_0_dot_y_flat_v2( + half x, + half d, + float sumy, + float4 yl +) { + uchar2 q = as_uchar2(x); + float acc = 0.0f; + + acc += (q.s0 & 0x0F) * yl.s0; + acc += (q.s1 & 0x0F) * yl.s1; + + acc += (q.s0 & 0xF0) * yl.s2; + acc += (q.s1 & 0xF0) * yl.s3; + + return d * (sumy * -8.f + acc);; +} + +inline float block_q_4_0_dot_y_flat_v4( + float x, + half d, + float sumy, + float8 yl +) { + uchar4 q = as_uchar4(x); + float acc = 0.0f; + + acc += (q.s0 & 0x0F) * yl.s0; + acc += (q.s1 & 0x0F) * yl.s1; + acc += (q.s2 & 0x0F) * yl.s2; + acc += (q.s3 & 0x0F) * yl.s3; + + acc += (q.s0 & 0xF0) * yl.s4; + acc += (q.s1 & 0xF0) * yl.s5; + acc += (q.s2 & 0xF0) * yl.s6; + acc += (q.s3 & 0xF0) * yl.s7; + + return d * (sumy * -8.f + acc);; +} + +inline float block_q_4_0_dot_y_flat_v8( + float2 x, + half d, + float sumy, + float16 yl +) { + uchar8 q = as_uchar8(x); + float acc = 0.0f; + + acc += (q.s0 & 0x0F) * yl.s0; + acc += (q.s1 & 0x0F) * yl.s1; + acc += (q.s2 & 0x0F) * yl.s2; + acc += (q.s3 & 0x0F) * yl.s3; + acc += (q.s4 & 0x0F) * yl.s4; + acc += (q.s5 & 0x0F) * yl.s5; + acc += (q.s6 & 0x0F) * yl.s6; + acc += (q.s7 & 0x0F) * yl.s7; + + acc += (q.s0 & 0xF0) * yl.s8; + acc += (q.s1 & 0xF0) * yl.s9; + acc += (q.s2 & 0xF0) * yl.sa; + acc += (q.s3 & 0xF0) * yl.sb; + acc += (q.s4 & 0xF0) * yl.sc; + acc += (q.s5 & 0xF0) * yl.sd; + acc += (q.s6 & 0xF0) * yl.se; + acc += (q.s7 & 0xF0) * yl.sf; + + return d * (sumy * -8.f + acc);; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define THREADS_PER_BLK 4 // Number of threads per block, or each thread process 1/THREADS_PER_BLK of a block +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 16 +#elif defined (ADRENO_GPU) +#define THREADS_PER_BLK 4 +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block +# define ACT_TY float16 +# define Q_BLK_LD_TY float2 +# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v8 +#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block +# define ACT_TY float8 +# define Q_BLK_LD_TY float +# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v4 +#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block +# define ACT_TY float4 +# define Q_BLK_LD_TY half +# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v2 +#endif + +#define BTYES_PER_THREAD_IN_BLK (QK4_0/2/THREADS_PER_BLK) + +#if N_DST == 2 +# define SUM_TY float2 +#elif N_DST == 4 +# define SUM_TY float4 +#elif N_DST == 8 +# define SUM_TY float8 +#elif N_DST == 16 +# define SUM_TY float16 +#endif + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_flat_v0( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + const int nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + int ix = get_sub_group_local_id()/THREADS_PER_BLK; + int il = get_sub_group_local_id()%THREADS_PER_BLK; + + global float * yb = y + ix*QK4_0 + BTYES_PER_THREAD_IN_BLK*il; + + // Registers for caching activation + ACT_TY yl = 0.f; + + // Registers for caching quants + Q_BLK_LD_TY q_blk_0 = 0, q_blk_1 = 0; +#if N_DST == 4 || N_DST == 8 || N_DST == 16 + Q_BLK_LD_TY q_blk_2 = 0, q_blk_3 = 0; +#endif +#if N_DST == 8 || N_DST == 16 + Q_BLK_LD_TY q_blk_4 = 0, q_blk_5 = 0, q_blk_6 = 0, q_blk_7 = 0; +#endif + + // Partial sum + SUM_TY sumf = 0.f; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/THREADS_PER_BLK) { + float sumy = 0.f; + + q_blk_0 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 0*nb*QK4_0/2); + q_blk_1 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 1*nb*QK4_0/2); +#if N_DST == 4 || N_DST == 8 || N_DST == 16 + q_blk_2 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 2*nb*QK4_0/2); + q_blk_3 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 3*nb*QK4_0/2); +#endif +#if N_DST == 8 || N_DST == 16 + q_blk_4 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 4*nb*QK4_0/2)); + q_blk_5 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 5*nb*QK4_0/2)); + q_blk_6 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 6*nb*QK4_0/2)); + q_blk_7 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 7*nb*QK4_0/2)); +#endif + + // Load activation +#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block + yl.s01234567 = *(global float8 *)(yb); + yl.s89abcdef = *(global float8 *)(yb + 16); + + sumy += yl.s0; + sumy += yl.s1; + sumy += yl.s2; + sumy += yl.s3; + sumy += yl.s4; + sumy += yl.s5; + sumy += yl.s6; + sumy += yl.s7; + sumy += yl.s8; yl.s8 /= 16.f; + sumy += yl.s9; yl.s9 /= 16.f; + sumy += yl.sa; yl.sa /= 16.f; + sumy += yl.sb; yl.sb /= 16.f; + sumy += yl.sc; yl.sc /= 16.f; + sumy += yl.sd; yl.sd /= 16.f; + sumy += yl.se; yl.se /= 16.f; + sumy += yl.sf; yl.sf /= 16.f; +#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block + yl.s0123 = *(global float4 *)(yb); + yl.s4567 = *(global float4 *)(yb + 16); + + sumy += yl.s0; + sumy += yl.s1; + sumy += yl.s2; + sumy += yl.s3; + sumy += yl.s4; yl.s4 /= 16.f; + sumy += yl.s5; yl.s5 /= 16.f; + sumy += yl.s6; yl.s6 /= 16.f; + sumy += yl.s7; yl.s7 /= 16.f; +#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block + yl.s01 = *(global float2 *)(yb); + yl.s23 = *(global float2 *)(yb + 16); + + sumy += yl.s0; + sumy += yl.s1; + sumy += yl.s2; yl.s2 /= 16.f; + sumy += yl.s3; yl.s3 /= 16.f; +#endif + + sumf.s0 += block_q_4_0_dot_y_flat(q_blk_0, *(d + ib + 0*nb), sumy, yl); + sumf.s1 += block_q_4_0_dot_y_flat(q_blk_1, *(d + ib + 1*nb), sumy, yl); +#if N_DST == 4 || N_DST == 8 || N_DST == 16 + sumf.s2 += block_q_4_0_dot_y_flat(q_blk_2, *(d + ib + 2*nb), sumy, yl); + sumf.s3 += block_q_4_0_dot_y_flat(q_blk_3, *(d + ib + 3*nb), sumy, yl); +#endif +#if N_DST == 8 || N_DST == 16 + sumf.s4 += block_q_4_0_dot_y_flat(q_blk_4, *(d + ib + 4*nb), sumy, yl); + sumf.s5 += block_q_4_0_dot_y_flat(q_blk_5, *(d + ib + 5*nb), sumy, yl); + sumf.s6 += block_q_4_0_dot_y_flat(q_blk_6, *(d + ib + 6*nb), sumy, yl); + sumf.s7 += block_q_4_0_dot_y_flat(q_blk_7, *(d + ib + 7*nb), sumy, yl); +#endif + + yb += QK4_0 * (N_SIMDWIDTH/THREADS_PER_BLK); + } + + SUM_TY tot = (SUM_TY)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1) +#if N_DST == 4 || N_DST == 8 || N_DST == 16 + , sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) +#endif +#if N_DST == 8 || N_DST == 16 + , sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5) + , sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) +#endif + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } +#if N_DST == 4 || N_DST == 8 || N_DST == 16 + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } +#endif +#if N_DST == 8 || N_DST == 16 + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } +#endif + } +} + +//------------------------------------------------------------------------------ +// Using image1d_buffer_t + +#if defined(cl_qcom_subgroup_shuffle) +#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable +float qcom_sub_group_reduce_add(float sum) { + sum += qcom_sub_group_shuffle_down(sum, 32, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); + sum += qcom_sub_group_shuffle_down(sum, 16, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); + sum += qcom_sub_group_shuffle_down(sum, 8, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); + sum += qcom_sub_group_shuffle_down(sum, 4, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); + sum += qcom_sub_group_shuffle_down(sum, 2, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); + sum += qcom_sub_group_shuffle_down(sum, 1, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); + return sum; +} +#define sub_group_reduce_add qcom_sub_group_reduce_add +#else +#define sub_group_reduce_add sub_group_reduce_add +#endif + +#undef THREADS_PER_BLK +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define THREADS_PER_BLK 4 // Number of threads per block, or each thread process 1/THREADS_PER_BLK of a block +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 16 +#elif defined (ADRENO_GPU) +#define THREADS_PER_BLK 4 +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block +# define ACT_TY float16 +# define Q_BLK_LD_TY float2 +# define EXTRACT_BLK_DATA(tmp, part) *((float2*)&tmp + part) +# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v8 +#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block +# define ACT_TY float8 +# define Q_BLK_LD_TY float +# define EXTRACT_BLK_DATA(tmp, part) *((float*)&tmp + part) +# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v4 +#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block +# define ACT_TY float4 +# define Q_BLK_LD_TY half +# define EXTRACT_BLK_DATA(tmp, part) *((half*)&tmp + part) +# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v2 +#endif + +#define BTYES_PER_THREAD_IN_BLK (QK4_0/2/THREADS_PER_BLK) + +#if N_DST == 2 +# define SUM_TY float2 +#elif N_DST == 4 +# define SUM_TY float4 +#elif N_DST == 8 +# define SUM_TY float8 +#elif N_DST == 16 +# define SUM_TY float16 +#endif + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_flat_img_v0( + read_only image1d_buffer_t src0_q, + read_only image1d_buffer_t src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + const int nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + int ix = get_sub_group_local_id()/THREADS_PER_BLK; + int il = get_sub_group_local_id()%THREADS_PER_BLK; + + global float * yb = y + ix*QK4_0 + BTYES_PER_THREAD_IN_BLK*il; + + // Registers for caching activation + ACT_TY yl = 0.f; + + // Registers for caching quants + Q_BLK_LD_TY q_blk_0 = 0, q_blk_1 = 0; +#if N_DST == 4 || N_DST == 8 || N_DST == 16 + Q_BLK_LD_TY q_blk_2 = 0, q_blk_3 = 0; +#endif +#if N_DST == 8 || N_DST == 16 + Q_BLK_LD_TY q_blk_4 = 0, q_blk_5 = 0, q_blk_6 = 0, q_blk_7 = 0; +#endif + + // Partial sum + SUM_TY sumf = 0.f; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/THREADS_PER_BLK) { + float sumy = 0.f;; + + float4 tmp; + tmp = read_imagef(src0_q, offset0_q + ib + 0*nb); + q_blk_0 = EXTRACT_BLK_DATA(tmp, il); + tmp = read_imagef(src0_q, offset0_q + ib + 1*nb); + q_blk_1 = EXTRACT_BLK_DATA(tmp, il); +#if N_DST == 4 || N_DST == 8 || N_DST == 16 + tmp = read_imagef(src0_q, offset0_q + ib + 2*nb); + q_blk_2 = EXTRACT_BLK_DATA(tmp, il); + tmp = read_imagef(src0_q, offset0_q + ib + 3*nb); + q_blk_3 = EXTRACT_BLK_DATA(tmp, il); +#endif +#if N_DST == 8 || N_DST == 16 + tmp = read_imagef(src0_q, offset0_q + ib + 4*nb); + q_blk_4 = EXTRACT_BLK_DATA(tmp, il); + tmp = read_imagef(src0_q, offset0_q + ib + 5*nb); + q_blk_5 = EXTRACT_BLK_DATA(tmp, il); + tmp = read_imagef(src0_q, offset0_q + ib + 6*nb); + q_blk_6 = EXTRACT_BLK_DATA(tmp, il); + tmp = read_imagef(src0_q, offset0_q + ib + 7*nb); + q_blk_7 = EXTRACT_BLK_DATA(tmp, il); +#endif + + // Load activation +#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block + yl.s01234567 = *(global float8 *)(yb); + yl.s89abcdef = *(global float8 *)(yb + 16); + + sumy += yl.s0; + sumy += yl.s1; + sumy += yl.s2; + sumy += yl.s3; + sumy += yl.s4; + sumy += yl.s5; + sumy += yl.s6; + sumy += yl.s7; + sumy += yl.s8; yl.s8 /= 16.f; + sumy += yl.s9; yl.s9 /= 16.f; + sumy += yl.sa; yl.sa /= 16.f; + sumy += yl.sb; yl.sb /= 16.f; + sumy += yl.sc; yl.sc /= 16.f; + sumy += yl.sd; yl.sd /= 16.f; + sumy += yl.se; yl.se /= 16.f; + sumy += yl.sf; yl.sf /= 16.f; +#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block + yl.s0123 = *(global float4 *)(yb); + yl.s4567 = *(global float4 *)(yb + 16); + + sumy += yl.s0; + sumy += yl.s1; + sumy += yl.s2; + sumy += yl.s3; + sumy += yl.s4; yl.s4 /= 16.f; + sumy += yl.s5; yl.s5 /= 16.f; + sumy += yl.s6; yl.s6 /= 16.f; + sumy += yl.s7; yl.s7 /= 16.f; +#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block + yl.s01 = *(global float2 *)(yb); + yl.s23 = *(global float2 *)(yb + 16); + + sumy += yl.s0; + sumy += yl.s1; + sumy += yl.s2; yl.s2 /= 16.f; + sumy += yl.s3; yl.s3 /= 16.f; +#endif + + sumf.s0 += block_q_4_0_dot_y_flat(q_blk_0, read_imageh(src0_d, offset0_d + ib + 0*nb).s0, sumy, yl); + sumf.s1 += block_q_4_0_dot_y_flat(q_blk_1, read_imageh(src0_d, offset0_d + ib + 1*nb).s0, sumy, yl); +#if N_DST == 4 || N_DST == 8 || N_DST == 16 + sumf.s2 += block_q_4_0_dot_y_flat(q_blk_2, read_imageh(src0_d, offset0_d + ib + 2*nb).s0, sumy, yl); + sumf.s3 += block_q_4_0_dot_y_flat(q_blk_3, read_imageh(src0_d, offset0_d + ib + 3*nb).s0, sumy, yl); +#endif +#if N_DST == 8 || N_DST == 16 + sumf.s4 += block_q_4_0_dot_y_flat(q_blk_4, read_imageh(src0_d, offset0_d + ib + 4*nb).s0, sumy, yl); + sumf.s5 += block_q_4_0_dot_y_flat(q_blk_5, read_imageh(src0_d, offset0_d + ib + 5*nb).s0, sumy, yl); + sumf.s6 += block_q_4_0_dot_y_flat(q_blk_6, read_imageh(src0_d, offset0_d + ib + 6*nb).s0, sumy, yl); + sumf.s7 += block_q_4_0_dot_y_flat(q_blk_7, read_imageh(src0_d, offset0_d + ib + 7*nb).s0, sumy, yl); +#endif + + yb += QK4_0 * (N_SIMDWIDTH/THREADS_PER_BLK); + } + + SUM_TY tot = (SUM_TY)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1) +#if N_DST == 4 || N_DST == 8 || N_DST == 16 + , sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) +#endif +#if N_DST == 8 || N_DST == 16 + , sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5) + , sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) +#endif + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } +#if N_DST == 4 || N_DST == 8 || N_DST == 16 + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } +#endif +#if N_DST == 8 || N_DST == 16 + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } +#endif + } +} + +//------------------------------------------------------------------------------ +// kernel_mul_mv_q6_K_f32 +//------------------------------------------------------------------------------ + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 1 // number of rows each SIMD group works on +#define N_SIMDGROUP 2 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // SIMD group size +#elif defined (ADRENO_GPU) +#define N_DST 1 +#define N_SIMDGROUP 2 +#define N_SIMDWIDTH 64 +#endif + +#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q6_K_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + uchar kmask1 = 0x03; + uchar kmask2 = 0x0C; + uchar kmask3 = 0x30; + uchar kmask4 = 0xC0; + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int row = N_SIMDGROUP * r0 + get_sub_group_id(); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global block_q6_K * x = (global block_q6_K *) src0 + row*nb + offset_src0; + global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float sumf = 0; + + // For Q6_K quantization, 16 values forms a subblock, 16 subblock forms a + // block. Values in a subblock shares a scale that is quantized with 8 bits; + // the entire block shares a single floating point scale. + // For work distribution, each thread processes a subblock (16 weights), hence + // 16 threads process a (super) block -- a subgroup thus handles SIMDWIDTH/16 + // (super) blocks -- this is the block stride. + // The 16 threads that process a (super) block are split into 2 portions, each has + // 8 threads; each portion works on 8 subblocks. + // For subgroup of 16 threads, the entire subgroup works on a single (super) block + // before moving to the next (super) block. Thread0 - thread7 work on the + // first 8 subblocks; thread8 - thread15 works on the last 8 subblocks. + // Thread0 - thread3 work on subblocks 0, 2, 4, 6; thread4 - thread7 work on + // subblocks 1, 3, 5, 7. Each thread does not work on an entire subblock, but + // works on a total of 16 weight values. + int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0 + int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1 + int ip = tid/8; // first or second half of (super) block (0 or 1) + int il = tid%8; // each half has 8 parts, one per scale + int n = 4; // 4 scales at a time (and 4 sums) + int l0 = n*il; // offset into half-block, 0..28 + int is = 8*ip + l0/16; // 0, 1, 8, 9 + + int y_offset = 128*ip + l0; + int q_offset_l = 64*ip + l0; + int q_offset_h = 32*ip + l0; + + for (int i = ix; i < nb; i += BLOCK_STRIDE) { + + global uint8_t * q1 = x[i].ql + q_offset_l; + global uint8_t * q2 = q1 + QK_K/8; + global uint8_t * qh = x[i].qh + q_offset_h; + global int8_t * sc = x[i].scales + is; + + global float * y = yy + i * QK_K + y_offset; + + float dall = x[i].d; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + + sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & kmask1) << 4)) - 32.f); + sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & kmask2) << 2)) - 32.f); + sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & kmask3) << 0)) - 32.f); + sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & kmask1) << 4)) - 32.f); + sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & kmask2) << 2)) - 32.f); + sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & kmask3) << 0)) - 32.f); + sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & kmask1) << 4)) - 32.f); + sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & kmask2) << 2)) - 32.f); + sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & kmask3) << 0)) - 32.f); + sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & kmask1) << 4)) - 32.f); + sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & kmask2) << 2)) - 32.f); + sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & kmask3) << 0)) - 32.f); + sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & kmask4) >> 2)) - 32.f); + + sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]); + } + + float tot = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[r1*ne0 + im*ne0*ne1 + row] = tot; + } +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl new file mode 100644 index 000000000..57768c803 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl @@ -0,0 +1,130 @@ +// src0_q, src0_d, src1 are transposed as a preprocessing step +// 4-bit weights are transposed in groups of 4 (unsigned short int) +// consider weights originally "next to each other", now "on top of each other" +// each fiber computes a 8x4 tile of output elements +// using unshuffled weights + +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +__attribute__((qcom_reqd_sub_group_size("full"))) +kernel void kernel_mul_mat_Ab_Bi_8x4( + global const ushort * src0_q, // quantized A + global const half * src0_d, // A scales + __read_only image1d_buffer_t src1, // B (1d image) + global float * dst, // C + int m, // M + int n, // N with padding + int k, // K + int n_no_padding // N without padding +) { + + int m_4 = m >> 2; + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; // 8x4 output elements + half8 B; // registers for activations + half4 dequantized_weights; // registers for dequantized weights + __global const ushort* weight_ptr = src0_q + gx_2; // pointer for weights + __global const half* scale_ptr = src0_d + gx_2; // pointer for scales + + for(int i=0; i> 4) - 8) * scale.s0; // dequantize a row of the 16 weights + dequantized_weights.s1 = (((bits4.s1 & (0x00F0)) >> 4) - 8) * scale.s1; + dequantized_weights.s2 = (((bits4.s2 & (0x00F0)) >> 4) - 8) * scale.s2; + dequantized_weights.s3 = (((bits4.s3 & (0x00F0)) >> 4) - 8) * scale.s3; + c0 += B * dequantized_weights.s0; //vector-scalar multiplication to accumulate + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1); + dequantized_weights.s0 = (((bits4.s0 & (0x0F00)) >> 8) - 8) * scale.s0; // dequantize a row of the 16 weights + dequantized_weights.s1 = (((bits4.s1 & (0x0F00)) >> 8) - 8) * scale.s1; + dequantized_weights.s2 = (((bits4.s2 & (0x0F00)) >> 8) - 8) * scale.s2; + dequantized_weights.s3 = (((bits4.s3 & (0x0F00)) >> 8) - 8) * scale.s3; + c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1); + dequantized_weights.s0 = (((bits4.s0 & (0xF000)) >> 12) - 8) * scale.s0; // dequantize a row of the 16 weights + dequantized_weights.s1 = (((bits4.s1 & (0xF000)) >> 12) - 8) * scale.s1; + dequantized_weights.s2 = (((bits4.s2 & (0xF000)) >> 12) - 8) * scale.s2; + dequantized_weights.s3 = (((bits4.s3 & (0xF000)) >> 12) - 8) * scale.s3; + c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); // vectorized store 16 elements + + // conditional check if store is to a valid location. Required when N is not a multiple of 8 + // if statements allow registers to be reused for each store + // provides a performance boost due to reduced register footprint, which increases number of concurrent waves + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl new file mode 100644 index 000000000..d59a0c05d --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl @@ -0,0 +1,32 @@ +// 16-bit transpose, loading/storing an 8x8 tile of elements + +kernel void kernel_transpose_16( + __read_only image1d_buffer_t input, + __write_only image1d_buffer_t output, + const uint rows, + const uint cols +) { + + const int i = get_global_id(0); + const int j = get_global_id(1); + const int i_3 = i<<3; + const int j_3 = j<<3; + + ushort8 temp0 = as_ushort8(read_imagef(input, (j_3+0)*cols+i)); + ushort8 temp1 = as_ushort8(read_imagef(input, (j_3+1)*cols+i)); + ushort8 temp2 = as_ushort8(read_imagef(input, (j_3+2)*cols+i)); + ushort8 temp3 = as_ushort8(read_imagef(input, (j_3+3)*cols+i)); + ushort8 temp4 = as_ushort8(read_imagef(input, (j_3+4)*cols+i)); + ushort8 temp5 = as_ushort8(read_imagef(input, (j_3+5)*cols+i)); + ushort8 temp6 = as_ushort8(read_imagef(input, (j_3+6)*cols+i)); + ushort8 temp7 = as_ushort8(read_imagef(input, (j_3+7)*cols+i)); + + write_imagef(output, (i_3+0)*rows+j, as_float4((ushort8)(temp0.s0, temp1.s0, temp2.s0, temp3.s0, temp4.s0, temp5.s0, temp6.s0, temp7.s0))); + write_imagef(output, (i_3+1)*rows+j, as_float4((ushort8)(temp0.s1, temp1.s1, temp2.s1, temp3.s1, temp4.s1, temp5.s1, temp6.s1, temp7.s1))); + write_imagef(output, (i_3+2)*rows+j, as_float4((ushort8)(temp0.s2, temp1.s2, temp2.s2, temp3.s2, temp4.s2, temp5.s2, temp6.s2, temp7.s2))); + write_imagef(output, (i_3+3)*rows+j, as_float4((ushort8)(temp0.s3, temp1.s3, temp2.s3, temp3.s3, temp4.s3, temp5.s3, temp6.s3, temp7.s3))); + write_imagef(output, (i_3+4)*rows+j, as_float4((ushort8)(temp0.s4, temp1.s4, temp2.s4, temp3.s4, temp4.s4, temp5.s4, temp6.s4, temp7.s4))); + write_imagef(output, (i_3+5)*rows+j, as_float4((ushort8)(temp0.s5, temp1.s5, temp2.s5, temp3.s5, temp4.s5, temp5.s5, temp6.s5, temp7.s5))); + write_imagef(output, (i_3+6)*rows+j, as_float4((ushort8)(temp0.s6, temp1.s6, temp2.s6, temp3.s6, temp4.s6, temp5.s6, temp6.s6, temp7.s6))); + write_imagef(output, (i_3+7)*rows+j, as_float4((ushort8)(temp0.s7, temp1.s7, temp2.s7, temp3.s7, temp4.s7, temp5.s7, temp6.s7, temp7.s7))); +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl new file mode 100644 index 000000000..914ec0193 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl @@ -0,0 +1,25 @@ +// 32-bit transpose, loading/storing a 4x4 tile of elements + +kernel void kernel_transpose_32( + __read_only image1d_buffer_t input, + __write_only image1d_buffer_t output, + const uint rows, + const uint cols +) { + + const int i = get_global_id(0); + const int j = get_global_id(1); + const int i_2 = i<<2; + const int j_2 = j<<2; + + float4 temp0 = read_imagef(input, (j_2+0)*cols+i); + float4 temp1 = read_imagef(input, (j_2+1)*cols+i); + float4 temp2 = read_imagef(input, (j_2+2)*cols+i); + float4 temp3 = read_imagef(input, (j_2+3)*cols+i); + + write_imagef(output, (i_2+0)*rows+j, (float4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); + write_imagef(output, (i_2+1)*rows+j, (float4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); + write_imagef(output, (i_2+2)*rows+j, (float4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); + write_imagef(output, (i_2+3)*rows+j, (float4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); + +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl new file mode 100644 index 000000000..d3bd1fabb --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl @@ -0,0 +1,35 @@ +// 32-bit transpose, loading/storing a 4x4 tile of elements +// Only used for activations +// converts to FP16 +// also adds zero padding for non multiple of 8 prompt lengths +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_transpose_32_16(__read_only image1d_buffer_t input, __write_only image1d_buffer_t output, const uint rows, const uint cols, const uint padded_rows) { + + const int i = get_global_id(0); + const int j = get_global_id(1); + const int i_2 = i<<2; + const int j_2 = j<<2; + half4 temp0 = {0,0,0,0}; // initialize outputs to 0 + half4 temp1 = {0,0,0,0}; + half4 temp2 = {0,0,0,0}; + half4 temp3 = {0,0,0,0}; + + if((j_2+0)*cols+i*4+3 < rows*cols*16){ // only load from a valid location. Otherwise keep register data as 0 + temp0 = read_imageh(input, (j_2+0)*cols+i); + } + if((j_2+1)*cols+i*4+3 < rows*cols*16){ + temp1 = read_imageh(input, (j_2+1)*cols+i); + } + if((j_2+2)*cols+i*4+3 < rows*cols*16){ + temp2 = read_imageh(input, (j_2+2)*cols+i); + } + if((j_2+3)*cols+i*4+3 < rows*cols*16){ + temp3 = read_imageh(input, (j_2+3)*cols+i); + } + + write_imageh(output, (i_2+0)*padded_rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); // no conditionals for output, includes zero padding + write_imageh(output, (i_2+1)*padded_rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); + write_imageh(output, (i_2+2)*padded_rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); + write_imageh(output, (i_2+3)*padded_rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); +} diff --git a/ggml/src/ggml-opt.cpp b/ggml/src/ggml-opt.cpp new file mode 100644 index 000000000..7c3e24103 --- /dev/null +++ b/ggml/src/ggml-opt.cpp @@ -0,0 +1,854 @@ +#include "ggml-opt.h" + +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "ggml-impl.h" + +#include +#include +#include +#include +#include +#include +#include + +struct ggml_opt_dataset { + struct ggml_context * ctx = nullptr; + ggml_backend_buffer_t buf = nullptr; + struct ggml_tensor * data = nullptr; + struct ggml_tensor * labels = nullptr; + + int64_t ndata = -1; + int64_t ndata_shard = -1; + size_t nbs_data = -1; + size_t nbs_labels = -1; + + std::vector permutation; +}; + +struct ggml_opt_context { + ggml_backend_sched_t backend_sched = nullptr; + ggml_cgraph * allocated_graph = nullptr; + ggml_cgraph * allocated_graph_copy = nullptr; + struct ggml_context * ctx_static = nullptr; + struct ggml_context * ctx_static_cpu = nullptr; + struct ggml_context * ctx_compute = nullptr; + struct ggml_context * ctx_copy = nullptr; + ggml_backend_buffer_t buf_static = nullptr; + ggml_backend_buffer_t buf_static_cpu = nullptr; + std::mt19937 rng; + + struct ggml_tensor * inputs = nullptr; + struct ggml_tensor * outputs = nullptr; + struct ggml_tensor * labels = nullptr; + + struct ggml_tensor * loss = nullptr; + struct ggml_tensor * pred = nullptr; + struct ggml_tensor * ncorrect = nullptr; + + struct ggml_cgraph * gf = nullptr; + struct ggml_cgraph * gb_grad = nullptr; + struct ggml_cgraph * gb_opt = nullptr; + + int64_t iter = 1; + int32_t opt_period = 1; + int32_t opt_i = 0; + bool loss_per_datapoint = false; + + ggml_opt_get_optimizer_params get_opt_pars = nullptr; + void * get_opt_pars_ud = nullptr; + struct ggml_tensor * adamw_params = nullptr; +}; + +struct ggml_opt_result { + int64_t ndata = 0; + std::vector loss; + std::vector pred; + int64_t ncorrect = 0; + + int64_t opt_period = -1; + bool loss_per_datapoint = false; +}; + +// ====== Dataset ====== + +ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) { + GGML_ASSERT(ne_datapoint > 0); + GGML_ASSERT(ne_label >= 0); + GGML_ASSERT(ndata > 0); + GGML_ASSERT(ndata_shard > 0); + + ggml_opt_dataset_t result = new ggml_opt_dataset; + result->ndata = ndata; + result->ndata_shard = ndata_shard; + + { + struct ggml_init_params params = { + /*.mem_size =*/ 2*ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + result->ctx = ggml_init(params); + } + + result->data = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_datapoint, ndata); + result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata; + + if (ne_label > 0) { + result->labels = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_label, ndata); + result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata; + } else { + result->labels = nullptr; + result->nbs_labels = 0; + } + + result->buf = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx, ggml_backend_cpu_buffer_type()); + + const int64_t nshards = ndata/ndata_shard; + result->permutation.resize(nshards); + for (int64_t i = 0; i < nshards; ++i) { + result->permutation[i] = i; + } + return result; +} + +void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) { + ggml_backend_buffer_free(dataset->buf); + ggml_free(dataset->ctx); + delete dataset; +} + +struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) { + return dataset->data; +} + +struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset) { + return dataset->labels; +} + +void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata) { + GGML_ASSERT(idata <= dataset->ndata); + + if (idata < 0) { + std::shuffle(dataset->permutation.begin(), dataset->permutation.end(), opt_ctx->rng); + return; + } + + GGML_ASSERT(idata % dataset->ndata_shard == 0); + const int64_t ishard_max = idata / dataset->ndata_shard; + std::shuffle(dataset->permutation.begin(), dataset->permutation.begin() + ishard_max, opt_ctx->rng); +} + +void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor * data_batch, struct ggml_tensor * labels_batch, int64_t ibatch) { + GGML_ASSERT( data_batch && ggml_is_contiguous(data_batch)); + GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch)); + GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr)); + + const size_t nb_data_batch = ggml_nbytes(data_batch); + GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0); + const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data; + + if (labels_batch) { + const size_t nb_labels_batch = ggml_nbytes(labels_batch); + GGML_ASSERT(nb_labels_batch == shards_per_batch*dataset->nbs_labels); + } + + GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size())); + + for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) { + const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch]; + + const char * ptr_data = (const char *) dataset->data->data + ishard*dataset->nbs_data; + ggml_backend_tensor_set(data_batch, ptr_data, ishard_batch*dataset->nbs_data, dataset->nbs_data); + + if (!labels_batch) { + continue; + } + + const char * ptr_labels = (const char *) dataset->labels->data + ishard*dataset->nbs_labels; + ggml_backend_tensor_set(labels_batch, ptr_labels, ishard_batch*dataset->nbs_labels, dataset->nbs_labels); + } +} + +// ====== Model / Context ====== + +struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) { + GGML_UNUSED(userdata); + + ggml_opt_optimizer_params result; + + result.adamw.alpha = 0.001f; + result.adamw.beta1 = 0.9f; + result.adamw.beta2 = 0.999f; + result.adamw.eps = 1e-8f; + result.adamw.wd = 0.0f; + + return result; +} + +struct ggml_opt_params ggml_opt_default_params( + ggml_backend_sched_t backend_sched, + struct ggml_context * ctx_compute, + struct ggml_tensor * inputs, + struct ggml_tensor * outputs, + enum ggml_opt_loss_type loss_type) { + return { + /*backend_sched =*/ backend_sched, + /*ctx_compute =*/ ctx_compute, + /*inputs =*/ inputs, + /*logits =*/ outputs, + /*loss_type =*/ loss_type, + /*build_type =*/ GGML_OPT_BUILD_TYPE_OPT, + /*opt_period =*/ 1, + /*get_opt_pars =*/ ggml_opt_get_default_optimizer_params, + /*get_opt_pars_ud =*/ nullptr, + }; +} + +static ggml_tensor * map_tensor(std::map & tensor_map, ggml_context * ctx, ggml_tensor * tensor) { + if (!tensor) { + return nullptr; + } + + if (tensor_map.find(tensor) != tensor_map.end()) { + return tensor_map[tensor]; + } + + ggml_tensor * new_tensor = ggml_dup_tensor(ctx, tensor); + tensor_map[tensor] = new_tensor; + + new_tensor->op = tensor->op; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + new_tensor->nb[i] = tensor->nb[i]; + } + new_tensor->flags = tensor->flags; + memcpy(new_tensor->op_params, tensor->op_params, sizeof(tensor->op_params)); + strcpy(new_tensor->name, tensor->name); + new_tensor->data = tensor->data; + new_tensor->buffer = tensor->buffer; + new_tensor->extra = tensor->extra; + new_tensor->view_offs = tensor->view_offs; + new_tensor->view_src = map_tensor(tensor_map, ctx, tensor->view_src); + for (int i = 0; i < GGML_MAX_SRC; i++) { + new_tensor->src[i] = map_tensor(tensor_map, ctx, tensor->src[i]); + } + + return new_tensor; +} + +static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) { + std::map tensor_map; + + ggml_cgraph * dst = ggml_new_graph_custom(ctx, src->size, /*grads =*/ true); + + for (int i = 0; i < src->n_leafs; i++) { + ggml_build_forward_expand(dst, map_tensor(tensor_map, ctx, src->leafs[i])); + } + GGML_ASSERT(dst->n_leafs == src->n_leafs); + for (int i = 0; i < src->n_nodes; i++) { + ggml_build_forward_expand(dst, map_tensor(tensor_map, ctx, src->nodes[i])); + } + GGML_ASSERT(dst->n_nodes == src->n_nodes); + for (int i = 0; i < src->n_nodes; ++i) { + const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]); + const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]); + + GGML_ASSERT(igrad_src != GGML_HASHSET_FULL); + GGML_ASSERT(ggml_bitset_get(src->visited_hash_set.used, igrad_src)); + GGML_ASSERT(igrad_dst != GGML_HASHSET_FULL); + GGML_ASSERT(ggml_bitset_get(dst->visited_hash_set.used, igrad_dst)); + + dst->grads[igrad_dst] = src->grads[igrad_src]; + dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src]; + } + + return dst; +} + +static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) { + GGML_ASSERT(graph); + if (opt_ctx->allocated_graph == graph) { + return; + } + + ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph + + { + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ggml_free(opt_ctx->ctx_copy); + opt_ctx->ctx_copy = ggml_init(params); + } + + opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph); + + ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); + opt_ctx->allocated_graph = graph; +} + +ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) { + ggml_opt_context_t result = new struct ggml_opt_context; + result->backend_sched = params.backend_sched; + result->ctx_compute = params.ctx_compute; + result->inputs = params.inputs; + result->outputs = params.outputs; + result->opt_period = params.opt_period; + result->get_opt_pars = params.get_opt_pars; + result->get_opt_pars_ud = params.get_opt_pars_ud; + + GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically"); + GGML_ASSERT(result->opt_period >= 1); + + const bool accumulate = params.build_type == GGML_OPT_BUILD_TYPE_GRAD || + (params.build_type == GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1); + + ggml_set_input(result->inputs); + ggml_set_output(result->outputs); + + result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass. + ggml_build_forward_expand(result->gf, result->outputs); + + int n_param = 0; + for (int i = 0; i < result->gf->n_nodes; ++i) { + if (result->gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) { + n_param++; + } + } + + { + // The static context is used for: + // - gradients (1 tensor per param if using gradient accumulation) + // - optimizer momenta (2 tensors per param) + // - labels + // - loss + its gradient (up to 5 tensors) + // - pred + // - ncorrect (2 tensors). + const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0); + const size_t size_meta = (tensors_per_param*n_param + 9) * ggml_tensor_overhead(); + struct ggml_init_params params = { + /*.mem_size =*/ size_meta, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + result->ctx_static = ggml_init(params); + } + { + // The static cpu context is used for: + // - optimizer parameters (1 for the entire context) + const size_t size_meta = 1 * ggml_tensor_overhead(); + struct ggml_init_params params = { + /*.mem_size =*/ size_meta, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + result->ctx_static_cpu = ggml_init(params); + } + + + switch (params.loss_type) { + case GGML_OPT_LOSS_TYPE_MEAN: { + result->loss = ggml_sum(result->ctx_static, result->outputs); + ggml_set_name(result->loss, "loss_sum"); + const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs)); + result->loss = ggml_scale(result->ctx_static, result->loss, scale); + ggml_set_name(result->loss, "loss_mean"); + result->loss_per_datapoint = true; + break; + } + case GGML_OPT_LOSS_TYPE_SUM: { + result->loss = ggml_sum(result->ctx_static, result->outputs); + ggml_set_name(result->loss, "loss_sum"); + result->loss_per_datapoint = false; + break; + } + case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: { + result->labels = ggml_dup_tensor(result->ctx_static, result->outputs); + ggml_set_input(result->labels); + ggml_set_name(result->labels, "labels"); + result->loss = ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels); + ggml_set_name(result->loss, "loss_cross_entropy"); + if (result->opt_period > 1) { + result->loss = ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period); + ggml_set_name(result->loss, "loss_cross_entropy_scaled"); + } + result->loss_per_datapoint = true; + break; + } + case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: { + result->labels = ggml_dup_tensor(result->ctx_static, result->outputs); + ggml_set_input(result->labels); + ggml_set_name(result->labels, "labels"); + result->loss = ggml_sub(result->ctx_static, result->outputs, result->labels); + ggml_set_name(result->loss, "loss_error"); + result->loss = ggml_sqr(result->ctx_static, result->loss); + ggml_set_name(result->loss, "loss_squared_error"); + result->loss = ggml_sum(result->ctx_static, result->loss); + ggml_set_name(result->loss, "loss_sum_squared_error"); + const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs)); + result->loss = ggml_scale(result->ctx_static, result->loss, scale); + ggml_set_name(result->loss, "loss_mean_squared_error"); + result->loss_per_datapoint = true; + break; + } + } + ggml_set_output(result->loss); + ggml_set_loss(result->loss); + ggml_build_forward_expand(result->gf, result->loss); + + result->pred = ggml_argmax(result->ctx_static, result->outputs); + ggml_set_name(result->pred, "pred"); + ggml_set_output(result->pred); + ggml_build_forward_expand(result->gf, result->pred); + + if (result->labels) { + result->ncorrect = ggml_count_equal(result->ctx_static, result->pred, ggml_argmax(result->ctx_static, result->labels)); + ggml_set_name(result->ncorrect, "ncorrect"); + ggml_set_output(result->ncorrect); + ggml_build_forward_expand(result->gf, result->ncorrect); + } else { + result->ncorrect = nullptr; + } + + if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) { + result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); + return result; + } + + // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients. + result->gb_grad = ggml_graph_dup(result->ctx_compute, result->gf); + ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate); + + if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) { + result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); + ggml_graph_reset(result->gb_grad); + return result; + } + + GGML_ASSERT(params.build_type == GGML_OPT_BUILD_TYPE_OPT); + + // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step. + result->gb_opt = ggml_graph_dup(result->ctx_compute, result->gb_grad); + + result->adamw_params = ggml_new_tensor_1d(result->ctx_static_cpu, GGML_TYPE_F32, 7); + ggml_set_input(result->adamw_params); + ggml_set_name(result->adamw_params, "adamw_params"); + + for (int i = result->gf->n_nodes-1; i >= 0; --i) { + struct ggml_tensor * node = result->gb_opt->nodes[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node); + + if (node->flags & GGML_TENSOR_FLAG_PARAM) { + struct ggml_tensor * m = ggml_dup_tensor(result->ctx_static, node); + struct ggml_tensor * v = ggml_dup_tensor(result->ctx_static, node); + struct ggml_tensor * opt_step = ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params); + ggml_build_forward_expand(result->gb_opt, opt_step); + } + } + + result->buf_static = ggml_backend_alloc_ctx_tensors( + result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); + + result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type()); + + ggml_graph_reset(result->gb_opt); + + return result; +} + +void ggml_opt_free(ggml_opt_context_t opt_ctx) { + if (opt_ctx == nullptr) { + return; + } + ggml_backend_buffer_free(opt_ctx->buf_static); + ggml_backend_buffer_free(opt_ctx->buf_static_cpu); + ggml_free(opt_ctx->ctx_static); + ggml_free(opt_ctx->ctx_static_cpu); + delete opt_ctx; +} + +void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) { + if (optimizer) { + ggml_graph_reset(opt_ctx->gb_opt); + opt_ctx->iter = 1; + } else { + ggml_graph_reset(opt_ctx->gb_grad); + } +} + +struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) { + return opt_ctx->inputs; +} + +struct ggml_tensor * ggml_opt_outputs(ggml_opt_context_t opt_ctx) { + return opt_ctx->outputs; +} + +struct ggml_tensor * ggml_opt_labels(ggml_opt_context_t opt_ctx) { + return opt_ctx->labels; +} + +struct ggml_tensor * ggml_opt_loss(ggml_opt_context_t opt_ctx) { + return opt_ctx->loss; +} + +struct ggml_tensor * ggml_opt_pred(ggml_opt_context_t opt_ctx) { + return opt_ctx->pred; +} + +struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx) { + return opt_ctx->ncorrect; +} + +struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node) { + return ggml_graph_get_grad_acc(opt_ctx->gb_opt, node); +} + +// ====== Optimization Result ====== + +ggml_opt_result_t ggml_opt_result_init() { + return new ggml_opt_result; +} + +void ggml_opt_result_free(ggml_opt_result_t result) { + delete result; +} + +void ggml_opt_result_reset(ggml_opt_result_t result) { + result->ndata = 0; + result->loss.clear(); + result->pred.clear(); + result->ncorrect = 0; +} + +void ggml_opt_result_ndata(ggml_opt_result_t result, int64_t * ndata) { + *ndata = result->ndata; +} + +void ggml_opt_result_loss(ggml_opt_result_t result, double * loss, double * unc) { + const int64_t nbatches = result->loss.size(); // Number of physical batches. + + if (nbatches == 0) { + *loss = 0.0; + *unc = NAN; + return; + } + + double sum = 0.0; + double sum_squared = 0.0; + + for (const float & loss : result->loss) { + // If the loss is per datapoint it was scaled by 1.0f/opt_period for each physical batch. + const float loss_scaled = result->loss_per_datapoint ? loss*result->opt_period : loss; + sum += loss_scaled; + sum_squared += loss_scaled*loss_scaled; + } + + const double mean = sum/nbatches; + *loss = result->loss_per_datapoint ? mean : sum; + + if (!unc) { + return; + } + + if (nbatches < 2) { + *unc = NAN; + return; + } + + const double var_sum = sum_squared/nbatches - mean*mean; // variance without Bessel's correction, i.e. nbatches/(nbatches-1) + *unc = result->loss_per_datapoint ? sqrt(var_sum / (nbatches - 1)) : sqrt(var_sum * nbatches/(nbatches - 1)); +} + +void ggml_opt_result_pred(ggml_opt_result_t result, int32_t * pred) { + for (size_t i = 0; i < result->pred.size(); ++i) { + pred[i] = result->pred[i]; + } +} + +void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc) { + *accuracy = result->ncorrect >= 0 ? double(result->ncorrect) / double(result->ndata) : NAN; + + if (!unc) { + return; + } + + *unc = result->ncorrect >= 0 && result->ndata >= 2 ? + sqrt((*accuracy) * (1.0 - (*accuracy)) / double(result->ndata - 1)) : NAN; +} + +// ====== Computation ====== + +static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_opt_result * result) { + if (graph != opt_ctx->gf) { + struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud); + + GGML_ASSERT(opt_pars.adamw.alpha > 0.0f); + GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f); + GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f); + GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f); + GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f); + GGML_ASSERT(opt_pars.adamw.eps >= 0.0f); + GGML_ASSERT(opt_pars.adamw.wd >= 0.0f); + GGML_ASSERT(opt_pars.adamw.wd <= 1.0f); + + // beta1, beta2 after applying warmup + const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter)); + const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter)); + + float * adamw_par_data = ggml_get_data_f32(opt_ctx->adamw_params); + adamw_par_data[0] = opt_pars.adamw.alpha; + adamw_par_data[1] = opt_pars.adamw.beta1; + adamw_par_data[2] = opt_pars.adamw.beta2; + adamw_par_data[3] = opt_pars.adamw.eps; + adamw_par_data[4] = opt_pars.adamw.wd; + adamw_par_data[5] = beta1h; + adamw_par_data[6] = beta2h; + } + + ggml_opt_alloc_graph(opt_ctx, graph); + ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); + opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt; + + if (!result) { + return; + } + + if (result->ndata == 0) { + result->loss_per_datapoint = opt_ctx->loss_per_datapoint; + result->opt_period = opt_ctx->opt_period; + } else { + GGML_ASSERT(result->loss_per_datapoint == opt_ctx->loss_per_datapoint); + GGML_ASSERT(result->opt_period == opt_ctx->opt_period); + } + + const int64_t ndata = opt_ctx->outputs->ne[1]; + GGML_ASSERT(result->ndata == ndata*int64_t(result->loss.size()) && "varying batch size not supported"); + result->ndata += ndata; + + GGML_ASSERT(ggml_is_scalar(opt_ctx->loss)); + GGML_ASSERT(opt_ctx->loss->type == GGML_TYPE_F32); + float loss; + ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss)); + result->loss.push_back(loss); + + GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32); + std::vector pred(ndata); + ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred)); + result->pred.insert(result->pred.end(), pred.begin(), pred.end()); + + if (!opt_ctx->labels || result->ncorrect < 0) { + result->ncorrect = -1; + return; + } + + GGML_ASSERT(ggml_is_scalar(opt_ctx->ncorrect)); + GGML_ASSERT(opt_ctx->ncorrect->type == GGML_TYPE_I64); + int64_t ncorrect; + ggml_backend_tensor_get(opt_ctx->ncorrect, &ncorrect, 0, ggml_nbytes(opt_ctx->ncorrect)); + result->ncorrect += ncorrect; +} + +void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) { + ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result); +} + +void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) { + if (opt_ctx->opt_period == 1) { + ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result); + return; + } + + const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period; + if (opt_i_next == 0) { + ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result); + ggml_opt_reset(opt_ctx, /*optimizer =*/ false); + } else { + ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result); + } + opt_ctx->opt_i = opt_i_next; +} + +// ====== High-Level Functions ====== + +void ggml_opt_epoch( + ggml_opt_context_t opt_ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, + ggml_opt_result_t result_eval, + int64_t idata_split, + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval) { + struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx); + struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); + struct ggml_tensor * data = ggml_opt_dataset_data(dataset); + GGML_ASSERT(data->ne[0] == inputs->ne[0]); + + const int64_t ndata = data->ne[1]; + const int64_t ndata_batch = inputs->ne[1]; + + GGML_ASSERT(data->ne[1] % inputs->ne[1] == 0); + const int64_t nbatches = ndata/ndata_batch; + + idata_split = idata_split < 0 ? ndata : idata_split; + GGML_ASSERT(idata_split % ndata_batch == 0); + const int64_t ibatch_split = idata_split / ndata_batch; + + int64_t ibatch = 0; + int64_t t_loop_start = ggml_time_us(); + for (; ibatch < ibatch_split; ++ibatch) { + ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch); + ggml_opt_forward_backward(opt_ctx, result_train); + if (callback_train) { + callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start); + } + } + t_loop_start = ggml_time_us(); + for (; ibatch < nbatches; ++ibatch) { + ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch); + ggml_opt_forward(opt_ctx, result_eval); + if (callback_eval) { + callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start); + } + } +} + +void ggml_opt_epoch_callback_progress_bar( + bool train, + ggml_opt_context_t opt_ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result, + int64_t ibatch, + int64_t ibatch_max, + int64_t t_start_us) { + fprintf(stderr, "%s[", train ? "train: " : "val: "); + + constexpr int64_t bar_length = 25; + for (int64_t j = 0; j < bar_length; ++j) { + const int64_t ibatch_j = ibatch_max * j/bar_length; + if (ibatch_j < ibatch) { + fprintf(stderr, "="); + } else if (ibatch_max * (j - 1)/bar_length < ibatch) { + fprintf(stderr, ">"); + } else { + fprintf(stderr, " "); + } + } + + const int64_t batch_size = ggml_opt_inputs(opt_ctx)->ne[1]; + const int64_t idata = ibatch*batch_size; + const int64_t idata_max = ibatch_max*batch_size; + + double loss; + double loss_unc; + ggml_opt_result_loss(result, &loss, &loss_unc); + + double accuracy; + double accuracy_unc; + ggml_opt_result_accuracy(result, &accuracy, &accuracy_unc); + + const int64_t t_ibatch_us = ggml_time_us() - t_start_us; + int64_t t_ibatch_s = t_ibatch_us / 1000000; + const int64_t t_ibatch_h = t_ibatch_s / 3600; + t_ibatch_s -= t_ibatch_h * 3600; + const int64_t t_ibatch_m = t_ibatch_s / 60; + t_ibatch_s -= t_ibatch_m * 60; + + const int64_t t_eta_us = t_ibatch_us * (ibatch_max - ibatch)/ibatch; + int64_t t_eta_s = t_eta_us / 1000000; + const int64_t t_eta_h = t_eta_s / 3600; + t_eta_s -= t_eta_h * 3600; + const int64_t t_eta_m = t_eta_s / 60; + t_eta_s -= t_eta_m * 60; + + fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, " + "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r", + idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc, + t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s); + if (ibatch == ibatch_max) { + fprintf(stderr, "\n"); + } + fflush(stderr); + + GGML_UNUSED(dataset); +} + +void ggml_opt_fit( + ggml_backend_sched_t backend_sched, + ggml_context * ctx_compute, + ggml_tensor * inputs, + ggml_tensor * outputs, + ggml_opt_dataset_t dataset, + enum ggml_opt_loss_type loss_type, + ggml_opt_get_optimizer_params get_opt_pars, + int64_t nepoch, + int64_t nbatch_logical, + float val_split, + bool silent) { + ggml_time_init(); + const int64_t t_start_us = ggml_time_us(); + + const int64_t ndata = ggml_opt_dataset_data(dataset)->ne[1]; + const int64_t nbatch_physical = inputs->ne[1]; + GGML_ASSERT(ndata % nbatch_logical == 0); + GGML_ASSERT(nbatch_logical % nbatch_physical == 0); + + const int64_t opt_period = nbatch_logical / nbatch_physical; + const int64_t nbatches_logical = ndata / nbatch_logical; + + GGML_ASSERT(val_split >= 0.0f); + GGML_ASSERT(val_split < 1.0f); + const int64_t ibatch_split = int64_t(((1.0f - val_split) * nbatches_logical)) * opt_period; // train <-> val split index (physical) + const int64_t idata_split = ibatch_split * nbatch_physical; + + int64_t epoch = 1; + + ggml_opt_params params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type); + params.opt_period = opt_period; + params.get_opt_pars = get_opt_pars; + params.get_opt_pars_ud = &epoch; + ggml_opt_context_t opt_ctx = ggml_opt_init(params); + + // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch. + if (nbatch_logical < ndata) { + ggml_opt_dataset_shuffle(opt_ctx, dataset, -1); // Shuffle all data (train + validation). + } + + ggml_opt_result_t result_train = ggml_opt_result_init(); + ggml_opt_result_t result_val = ggml_opt_result_init(); + + ggml_opt_epoch_callback epoch_callback = silent ? nullptr : ggml_opt_epoch_callback_progress_bar; + + for (; epoch <= nepoch; ++epoch) { + if (nbatch_logical < idata_split) { + ggml_opt_dataset_shuffle(opt_ctx, dataset, idata_split); + } + + ggml_opt_result_reset(result_train); + ggml_opt_result_reset(result_val); + + if (!silent) { + fprintf(stderr, "%s: epoch %04" PRId64 "/%04" PRId64 ":\n", __func__, epoch, nepoch); + } + ggml_opt_epoch(opt_ctx, dataset, result_train, result_val, idata_split, epoch_callback, epoch_callback); + if (!silent) { + fprintf(stderr, "\n"); + } + } + + if (!silent) { + int64_t t_total_s = (ggml_time_us() - t_start_us) / 1000000; + const int64_t t_total_h = t_total_s / 3600; + t_total_s -= t_total_h * 3600; + const int64_t t_total_m = t_total_s / 60; + t_total_s -= t_total_m * 60; + fprintf(stderr, "%s: training took %02" PRId64 ":%02" PRId64 ":%02" PRId64 "\n", __func__, t_total_h, t_total_m, t_total_s); + } + + ggml_opt_free(opt_ctx); + ggml_opt_result_free(result_train); + ggml_opt_result_free(result_val); +} diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 8c31e2cca..7918388ae 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -3,7 +3,8 @@ #include "ggml-quants.h" #include "ggml-impl.h" - +#include "ggml-cpu/ggml-cpu-impl.h" +#include "ggml-cpu.h" #include #include @@ -26,637 +27,6 @@ #define UNUSED GGML_UNUSED -// some compilers don't provide _mm256_set_m128i, e.g. gcc 7 -#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) - -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) -// multiply int8_t, add results pairwise twice -static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { - // Get absolute values of x vectors - const __m128i ax = _mm_sign_epi8(x, x); - // Sign the values of the y vectors - const __m128i sy = _mm_sign_epi8(y, x); - // Perform multiplication and create 16-bit values - const __m128i dot = _mm_maddubs_epi16(ax, sy); - const __m128i ones = _mm_set1_epi16(1); - return _mm_madd_epi16(ones, dot); -} - -#if __AVX__ || __AVX2__ || __AVX512F__ -// horizontally add 8 floats -static inline float hsum_float_8(const __m256 x) { - __m128 res = _mm256_extractf128_ps(x, 1); - res = _mm_add_ps(res, _mm256_castps256_ps128(x)); - res = _mm_add_ps(res, _mm_movehl_ps(res, res)); - res = _mm_add_ss(res, _mm_movehdup_ps(res)); - return _mm_cvtss_f32(res); -} - -// horizontally add 8 int32_t -static inline int hsum_i32_8(const __m256i a) { - const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); - const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); - const __m128i sum64 = _mm_add_epi32(hi64, sum128); - const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); - return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); -} - -// horizontally add 4 int32_t -static inline int hsum_i32_4(const __m128i a) { - const __m128i hi64 = _mm_unpackhi_epi64(a, a); - const __m128i sum64 = _mm_add_epi32(hi64, a); - const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); - return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); -} - -#if defined(__AVX2__) || defined(__AVX512F__) -// spread 32 bits to 32 bytes { 0x00, 0xFF } -static inline __m256i bytes_from_bits_32(const uint8_t * x) { - uint32_t x32; - memcpy(&x32, x, sizeof(uint32_t)); - const __m256i shuf_mask = _mm256_set_epi64x( - 0x0303030303030303, 0x0202020202020202, - 0x0101010101010101, 0x0000000000000000); - __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask); - const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); - bytes = _mm256_or_si256(bytes, bit_mask); - return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1)); -} - -// Unpack 32 4-bit fields into 32 bytes -// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval -static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) -{ - const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); - const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); - const __m256i lowMask = _mm256_set1_epi8( 0xF ); - return _mm256_and_si256(lowMask, bytes); -} - -// add int16_t pairwise and return as float vector -static inline __m256 sum_i16_pairs_float(const __m256i x) { - const __m256i ones = _mm256_set1_epi16(1); - const __m256i summed_pairs = _mm256_madd_epi16(ones, x); - return _mm256_cvtepi32_ps(summed_pairs); -} - -static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { -#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) - const __m256i zero = _mm256_setzero_si256(); - const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); - return _mm256_cvtepi32_ps(summed_pairs); -#else - // Perform multiplication and create 16-bit values - const __m256i dot = _mm256_maddubs_epi16(ax, sy); - return sum_i16_pairs_float(dot); -#endif -} - -// multiply int8_t, add results pairwise twice and return as float vector -static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { -#if __AVXVNNIINT8__ - const __m256i zero = _mm256_setzero_si256(); - const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y); - return _mm256_cvtepi32_ps(summed_pairs); -#else - // Get absolute values of x vectors - const __m256i ax = _mm256_sign_epi8(x, x); - // Sign the values of the y vectors - const __m256i sy = _mm256_sign_epi8(y, x); - return mul_sum_us8_pairs_float(ax, sy); -#endif -} - -static inline __m128i packNibbles( __m256i bytes ) -{ - // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh -#if __AVX512F__ - const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000 - bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh - return _mm256_cvtepi16_epi8(bytes); // abcd_efgh -#else - const __m256i lowByte = _mm256_set1_epi16( 0xFF ); - __m256i high = _mm256_andnot_si256( lowByte, bytes ); - __m256i low = _mm256_and_si256( lowByte, bytes ); - high = _mm256_srli_epi16( high, 4 ); - bytes = _mm256_or_si256( low, high ); - - // Compress uint16_t lanes into bytes - __m128i r0 = _mm256_castsi256_si128( bytes ); - __m128i r1 = _mm256_extracti128_si256( bytes, 1 ); - return _mm_packus_epi16( r0, r1 ); -#endif -} -#elif defined(__AVX__) -// spread 32 bits to 32 bytes { 0x00, 0xFF } -static inline __m256i bytes_from_bits_32(const uint8_t * x) { - uint32_t x32; - memcpy(&x32, x, sizeof(uint32_t)); - const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); - const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202); - __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl); - __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh); - const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe); - bytesl = _mm_or_si128(bytesl, bit_mask); - bytesh = _mm_or_si128(bytesh, bit_mask); - bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1)); - bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1)); - return MM256_SET_M128I(bytesh, bytesl); -} - -// Unpack 32 4-bit fields into 32 bytes -// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval -static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) -{ - // Load 16 bytes from memory - __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi); - __m128i tmph = _mm_srli_epi16(tmpl, 4); - const __m128i lowMask = _mm_set1_epi8(0xF); - tmpl = _mm_and_si128(lowMask, tmpl); - tmph = _mm_and_si128(lowMask, tmph); - return MM256_SET_M128I(tmph, tmpl); -} - -// add int16_t pairwise and return as float vector -static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) { - const __m128i ones = _mm_set1_epi16(1); - const __m128i summed_pairsl = _mm_madd_epi16(ones, xl); - const __m128i summed_pairsh = _mm_madd_epi16(ones, xh); - const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl); - return _mm256_cvtepi32_ps(summed_pairs); -} - -static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { - const __m128i axl = _mm256_castsi256_si128(ax); - const __m128i axh = _mm256_extractf128_si256(ax, 1); - const __m128i syl = _mm256_castsi256_si128(sy); - const __m128i syh = _mm256_extractf128_si256(sy, 1); - // Perform multiplication and create 16-bit values - const __m128i dotl = _mm_maddubs_epi16(axl, syl); - const __m128i doth = _mm_maddubs_epi16(axh, syh); - return sum_i16_pairs_float(doth, dotl); -} - -// multiply int8_t, add results pairwise twice and return as float vector -static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { - const __m128i xl = _mm256_castsi256_si128(x); - const __m128i xh = _mm256_extractf128_si256(x, 1); - const __m128i yl = _mm256_castsi256_si128(y); - const __m128i yh = _mm256_extractf128_si256(y, 1); - // Get absolute values of x vectors - const __m128i axl = _mm_sign_epi8(xl, xl); - const __m128i axh = _mm_sign_epi8(xh, xh); - // Sign the values of the y vectors - const __m128i syl = _mm_sign_epi8(yl, xl); - const __m128i syh = _mm_sign_epi8(yh, xh); - // Perform multiplication and create 16-bit values - const __m128i dotl = _mm_maddubs_epi16(axl, syl); - const __m128i doth = _mm_maddubs_epi16(axh, syh); - return sum_i16_pairs_float(doth, dotl); -} - -static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) -{ - // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh - const __m128i lowByte = _mm_set1_epi16( 0xFF ); - __m128i high = _mm_andnot_si128( lowByte, bytes1 ); - __m128i low = _mm_and_si128( lowByte, bytes1 ); - high = _mm_srli_epi16( high, 4 ); - bytes1 = _mm_or_si128( low, high ); - high = _mm_andnot_si128( lowByte, bytes2 ); - low = _mm_and_si128( lowByte, bytes2 ); - high = _mm_srli_epi16( high, 4 ); - bytes2 = _mm_or_si128( low, high ); - - return _mm_packus_epi16( bytes1, bytes2); -} -#endif -#elif defined(__SSSE3__) -// horizontally add 4x4 floats -static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { - __m128 res_0 =_mm_hadd_ps(a, b); - __m128 res_1 =_mm_hadd_ps(c, d); - __m128 res =_mm_hadd_ps(res_0, res_1); - res =_mm_hadd_ps(res, res); - res =_mm_hadd_ps(res, res); - - return _mm_cvtss_f32(res); -} -#endif // __AVX__ || __AVX2__ || __AVX512F__ -#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) - -#if defined(__ARM_NEON) || defined(__wasm_simd128__) || defined(__POWER9_VECTOR__) -#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s -#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) -#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) -#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) -#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) -#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) -#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s) -#define B8(c,s ) B7(c,s, c), B7(c,s, s) - -// precomputed tables for expanding 8bits to 8 bytes: -static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4 -static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 -#endif - -#if defined(__loongarch_asx) - -#ifdef __clang__ -#define VREGS_PREFIX "$vr" -#define XREGS_PREFIX "$xr" -#else // GCC -#define VREGS_PREFIX "$f" -#define XREGS_PREFIX "$f" -#endif -#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31" -// Convert __m128i to __m256i -static inline __m256i ____m256i(__m128i in) { - __m256i out = __lasx_xvldi(0); - __asm__ volatile ( - ".irp i," __ALL_REGS "\n\t" - " .ifc %[out], " XREGS_PREFIX"\\i \n\t" - " .irp j," __ALL_REGS "\n\t" - " .ifc %[in], " VREGS_PREFIX "\\j \n\t" - " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t" - " .endif \n\t" - " .endr \n\t" - " .endif \n\t" - ".endr \n\t" - : [out] "+f" (out) : [in] "f" (in) - ); - return out; -} -// Convert two __m128i to __m256i -static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) { - __m256i out; - __asm__ volatile ( - ".irp i," __ALL_REGS "\n\t" - " .ifc %[hi], " VREGS_PREFIX "\\i \n\t" - " .irp j," __ALL_REGS "\n\t" - " .ifc %[lo], " VREGS_PREFIX "\\j \n\t" - " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t" - " .endif \n\t" - " .endr \n\t" - " .endif \n\t" - ".endr \n\t" - ".ifnc %[out], %[hi] \n\t" - ".irp i," __ALL_REGS "\n\t" - " .ifc %[out], " XREGS_PREFIX "\\i \n\t" - " .irp j," __ALL_REGS "\n\t" - " .ifc %[hi], " VREGS_PREFIX "\\j \n\t" - " xvori.b $xr\\i, $xr\\j, 0 \n\t" - " .endif \n\t" - " .endr \n\t" - " .endif \n\t" - ".endr \n\t" - ".endif \n\t" - : [out] "=f" (out), [hi] "+f" (inhi) - : [lo] "f" (inlo) - ); - return out; -} -// Convert __m256i low part to __m128i -static inline __m128i lasx_extracti128_lo(__m256i in) { - __m128i out; - __asm__ volatile ( - ".ifnc %[out], %[in] \n\t" - ".irp i," __ALL_REGS "\n\t" - " .ifc %[out], " VREGS_PREFIX "\\i \n\t" - " .irp j," __ALL_REGS "\n\t" - " .ifc %[in], " XREGS_PREFIX "\\j \n\t" - " vori.b $vr\\i, $vr\\j, 0 \n\t" - " .endif \n\t" - " .endr \n\t" - " .endif \n\t" - ".endr \n\t" - ".endif \n\t" - : [out] "=f" (out) : [in] "f" (in) - ); - return out; -} -// Convert __m256i high part to __m128i -static inline __m128i lasx_extracti128_hi(__m256i in) { - __m128i out; - __asm__ volatile ( - ".irp i," __ALL_REGS "\n\t" - " .ifc %[out], " VREGS_PREFIX "\\i \n\t" - " .irp j," __ALL_REGS "\n\t" - " .ifc %[in], " XREGS_PREFIX "\\j \n\t" - " xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t" - " .endif \n\t" - " .endr \n\t" - " .endif \n\t" - ".endr \n\t" - : [out] "=f" (out) : [in] "f" (in) - ); - return out; -} - -static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) { - v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7}; - return (__m256i)__ret; -} - -static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) { - v4i32 __ret = {d, c, b, a}; - return (__m128i)__ret; -} - -static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) { - v4i64 __ret = {d, c, b, a}; - return (__m256i)__ret; -} - -static __m256i lasx_insertf128( __m128i x, __m128i y) { - return lasx_set_q(x, y); -} - -static __m128i lsx_shuffle_b(__m128i a, __m128i b) { - __m128i mask_f, zero, tmp0, tmp2, mask; - int f = 0x8f; - mask_f = __lsx_vreplgr2vr_b(f); - zero = __lsx_vldi(0); - tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits - tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive - mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask - tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones - return __lsx_vshuf_b(a, zero, tmp2); -} - -static __m256i lasx_shuffle_b(__m256i a, __m256i b) { - __m256i mask_f, zero, tmp0, tmp2, mask; - int f = 0x8f; - mask_f = __lasx_xvreplgr2vr_b(f); - zero = __lasx_xvldi(0); - tmp0 = __lasx_xvand_v(b, mask_f); // get mask with low 4 bit and sign bits - tmp0 = __lasx_xvori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive - mask = __lasx_xvsle_b(zero, tmp0); // if mask >= 0, set mask - tmp2 = __lasx_xvand_v(tmp0, mask); // maskout the in2 < ones - return __lasx_xvshuf_b(a, zero, tmp2); -} - -static __m256i lasx_extu8_16(__m128i a) { - __m128i zero = __lsx_vldi(0); - __m128i vlo = __lsx_vilvl_b(zero, a); - __m128i vhi = __lsx_vilvh_b(zero, a); - return lasx_set_q(vhi, vlo); -} - -static __m256i lasx_ext8_16(__m128i a) { - __m128i sign = __lsx_vslti_b(a, 0); - __m128i vlo = __lsx_vilvl_b(sign, a); - __m128i vhi = __lsx_vilvh_b(sign, a); - return lasx_set_q(vhi, vlo); -} - -static __m256i lasx_ext16_32(__m128i a) { - __m256i tmp1; - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6); - tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7); - return tmp1; -} - -static __m128i lasx_extracti128( __m256i a, int pos) { - __m128i ret; - if( pos == 0) - { - ret = lasx_extracti128_lo(a); - } else { - ret = lasx_extracti128_hi(a); - } - return ret; -} - -static __m128 lasx_extractf128( __m256 a, int pos) { - __m128 ret; - if( pos == 0) - { - ret = (__m128)lasx_extracti128_lo((__m256i)a); - } else { - ret = (__m128)lasx_extracti128_hi((__m256i)a); - } - return ret; -} - -static __m128i lsx_hadd_h(__m128i a, __m128i b) { - __m128i tmp1 = __lsx_vpickev_h(b, a); - __m128i tmp2 = __lsx_vpickod_h(b, a); - return __lsx_vadd_h(tmp1, tmp2); -} - -static __m128i lsx_hadd_w(__m128i a, __m128i b) { - __m128i tmp1 = __lsx_vpickev_w(b, a); - __m128i tmp2 = __lsx_vpickod_w(b, a); - return __lsx_vadd_w(tmp1, tmp2); -} - -static __m128 lsx_hadd_s(__m128 a, __m128 b) { - __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a); - __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a); - - return __lsx_vfadd_s(tmp1, tmp2); -} - -static __m256i lasx_maddubs_h(__m256i a, __m256i b) { - __m256i tmp1, tmp2; - tmp1 = __lasx_xvmulwev_h_b(a, b); - tmp2 = __lasx_xvmulwod_h_b(a, b); - return __lasx_xvsadd_h(tmp1, tmp2); -} - -static __m256i lasx_madd_h(__m256i a, __m256i b) { - __m256i tmp1, tmp2; - tmp1 = __lasx_xvmulwev_w_h(a, b); - tmp2 = __lasx_xvmulwod_w_h(a, b); - return __lasx_xvadd_w(tmp1, tmp2); -} - -static __m256i lasx_packs_w(__m256i a, __m256i b) { - __m256i tmp, tmp1; - tmp = __lasx_xvsat_w(a, 15); - tmp1 = __lasx_xvsat_w(b, 15); - return __lasx_xvpickev_h(tmp1, tmp); -} - -static __m256i lasx_packs_h(__m256i a, __m256i b) { - __m256i tmp, tmp1; - tmp = __lasx_xvsat_h(a, 7); - tmp1 = __lasx_xvsat_h(b, 7); - return __lasx_xvpickev_b(tmp1, tmp); -} - -static __m128i lsx_packs_w(__m128i a, __m128i b) { - __m128i tmp, tmp1; - tmp = __lsx_vsat_w(a, 15); - tmp1 = __lsx_vsat_w(b, 15); - return __lsx_vpickev_h(tmp1, tmp); -} - -static __m128i lsx_packs_h(__m128i a, __m128i b) { - __m128i tmp, tmp1; - tmp = __lsx_vsat_h(a, 7); - tmp1 = __lsx_vsat_h(b, 7); - return __lsx_vpickev_b(tmp1, tmp); -} - -static __m128i lsx_packus_h(__m128i a, __m128i b) { - __m128i tmp, tmp1; - tmp = __lsx_vsat_hu(a, 7); - tmp1 = __lsx_vsat_hu(b, 7); - return __lsx_vpickev_b(tmp1, tmp); -} - - -static __m128i lsx_maddubs_h(__m128i a, __m128i b) { - __m128i tmp1, tmp2; - tmp1 = __lsx_vmulwev_h_b(a, b); - tmp2 = __lsx_vmulwod_h_b(a, b); - return __lsx_vsadd_h(tmp1, tmp2); -} - -static __m128i lsx_madd_h(__m128i a, __m128i b) { - __m128i tmp1, tmp2; - tmp1 = __lsx_vmulwev_w_h(a, b); - tmp2 = __lsx_vmulwod_w_h(a, b); - return __lsx_vadd_w(tmp1, tmp2); -} - -// multiply int8_t, add results pairwise twice -static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { - // Get absolute values of x vectors - const __m128i ax = __lsx_vsigncov_b(x, x); - // Sign the values of the y vectors - const __m128i sy = __lsx_vsigncov_b(x, y); - // Perform multiplication and create 16-bit values - const __m128i dot = lsx_maddubs_h(ax, sy); - const __m128i ones = __lsx_vreplgr2vr_h(1); - return lsx_madd_h(ones, dot); -} - -// horizontally add 8 floats -static inline float hsum_float_8(const __m256 x) { - __m128 res = lasx_extractf128(x, 1); - ft_union tmp; - res = __lsx_vfadd_s(res, lasx_extractf128(x, 0)); - res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res)); - res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0)); - tmp.i = __lsx_vpickve2gr_w(res, 0); - return tmp.f; -} - -// horizontally add 8 int32_t -static inline int hsum_i32_8(const __m256i a) { - - __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11); - __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00); - - __m128i tmp1_128 = lasx_extracti128_lo(tmp1); - __m128i tmp2_128 = lasx_extracti128_lo(tmp2); - - __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128); - - __m128i ev = __lsx_vpickev_w(sum128, sum128); - __m128i od = __lsx_vpickod_w(sum128, sum128); - __m128i sum64 = __lsx_vadd_w(ev, od); - - int sum64_1, sum64_2; - sum64_1 = __lsx_vpickve2gr_w(sum64, 0); - sum64_2 = __lsx_vpickve2gr_w(sum64, 1); - - return sum64_1 + sum64_2; -} - -// horizontally add 4 int32_t -static inline int hsum_i32_4(const __m128i a) { - __m128i ev = __lsx_vpickev_w(a, a); - __m128i od = __lsx_vpickod_w(a, a); - __m128i sum64 = __lsx_vadd_w(ev, od); - - int sum64_1, sum64_2; - sum64_1 = __lsx_vpickve2gr_w(sum64, 0); - sum64_2 = __lsx_vpickve2gr_w(sum64, 1); - - return sum64_1 + sum64_2; -} - -// spread 32 bits to 32 bytes { 0x00, 0xFF } -static inline __m256i bytes_from_bits_32(const uint8_t * x) { - - uint32_t x32; - memcpy(&x32, x, sizeof(uint32_t)); - const __m256i shuf_mask = lasx_set_d( - 0x0303030303030303, 0x0202020202020202, - 0x0101010101010101, 0x0000000000000000); - - __m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask); - const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe); - bytes = __lasx_xvor_v(bytes, bit_mask); - return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1)); -} - -// Unpack 32 4-bit fields into 32 bytes -// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval -static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) { - const __m128i lo = __lsx_vld((const __m128i *)rsi, 0); - __m128i hi = __lsx_vsrli_h(lo, 4); - return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf); -} - -// add int16_t pairwise and return as float vector -static inline __m256 sum_i16_pairs_float(const __m256i x) { - __m256i v = __lasx_xvpackod_h(x, x); - __m256i summed_pairs = __lasx_xvaddwev_w_h(x, v); - return __lasx_xvffint_s_w(summed_pairs); -} - -static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { - // Perform multiplication and create 16-bit values - const __m256i dot = lasx_maddubs_h(ax, sy); - return sum_i16_pairs_float(dot); -} - -// multiply int8_t, add results pairwise twice and return as float vector -static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { - - // Get absolute values of x vectors - const __m256i ax = __lasx_xvsigncov_b(x, x); - // Sign the values of the y vectors - const __m256i sy = __lasx_xvsigncov_b(x, y); - - return mul_sum_us8_pairs_float(ax, sy); -} - -static inline __m128i packNibbles( __m256i bytes ) { - // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh - const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF); - __m256i high = __lasx_xvandn_v(lowByte, bytes); - __m256i low = __lasx_xvand_v(lowByte, bytes); - high = __lasx_xvsrli_h(high, 4); - bytes = __lasx_xvor_v(low, high); - // Compress uint16_t lanes into bytes - __m128i *r0 = (__m128i *)&bytes; - __m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11); - __m128i *r1 = (__m128i *)&tmp_h128; - - __m128i zero = __lsx_vldi(0); - __m128i tmp, tmp2, tmp3; - - tmp = __lsx_vmax_h(zero, *r0); - tmp2 = __lsx_vsat_hu(tmp, 7); - - tmp = __lsx_vmax_h(zero, *r1); - tmp3 = __lsx_vsat_hu(tmp, 7); - return __lsx_vpickev_b(tmp3, tmp2); -} -#endif //__loongarch_asx - // reference implementation for deterministic creation of model files void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; @@ -695,11 +65,6 @@ void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, in } } -void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q4_0_ref(x, y, k); -} - - void quantize_row_q4_1_ref(const float * restrict x, block_q4_1 * restrict y, int64_t k) { const int qk = QK4_1; @@ -737,10 +102,6 @@ void quantize_row_q4_1_ref(const float * restrict x, block_q4_1 * restrict y, in } } -void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q4_1_ref(x, y, k); -} - void quantize_row_q5_0_ref(const float * restrict x, block_q5_0 * restrict y, int64_t k) { static const int qk = QK5_0; @@ -785,10 +146,6 @@ void quantize_row_q5_0_ref(const float * restrict x, block_q5_0 * restrict y, in } } -void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q5_0_ref(x, y, k); -} - void quantize_row_q5_1_ref(const float * restrict x, block_q5_1 * restrict y, int64_t k) { const int qk = QK5_1; @@ -833,10 +190,6 @@ void quantize_row_q5_1_ref(const float * restrict x, block_q5_1 * restrict y, in } } -void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q5_1_ref(x, y, k); -} - // reference implementation for deterministic creation of model files void quantize_row_q8_0_ref(const float * restrict x, block_q8_0 * restrict y, int64_t k) { assert(k % QK8_0 == 0); @@ -863,291 +216,6 @@ void quantize_row_q8_0_ref(const float * restrict x, block_q8_0 * restrict y, in } } -void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) { - assert(QK8_0 == 32); - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; - - block_q8_0 * restrict y = vy; - -#if defined(__ARM_NEON) - for (int i = 0; i < nb; i++) { - float32x4_t srcv [8]; - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - for (int j = 0; j < 8; j++) { - const float32x4_t v = vmulq_n_f32(srcv[j], id); - const int32x4_t vi = vcvtnq_s32_f32(v); - - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); - } - } -#elif defined(__wasm_simd128__) - for (int i = 0; i < nb; i++) { - v128_t srcv [8]; - v128_t asrcv[8]; - v128_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); - - const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), - wasm_f32x4_extract_lane(amaxv[0], 1)), - MAX(wasm_f32x4_extract_lane(amaxv[0], 2), - wasm_f32x4_extract_lane(amaxv[0], 3))); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - for (int j = 0; j < 8; j++) { - const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); - const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); - - y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); - y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); - y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); - y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); - } - } -#elif defined(__AVX2__) || defined(__AVX__) - for (int i = 0; i < nb; i++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x ); - __m256 v1 = _mm256_loadu_ps( x + 8 ); - __m256 v2 = _mm256_loadu_ps( x + 16 ); - __m256 v3 = _mm256_loadu_ps( x + 24 ); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float maxScalar = _mm_cvtss_f32( max4 ); - - // Quantize these floats - const float d = maxScalar / 127.f; - y[i].d = GGML_FP32_TO_FP16(d); - const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; - const __m256 mul = _mm256_set1_ps( id ); - - // Apply the multiplier - v0 = _mm256_mul_ps( v0, mul ); - v1 = _mm256_mul_ps( v1, mul ); - v2 = _mm256_mul_ps( v2, mul ); - v3 = _mm256_mul_ps( v3, mul ); - - // Round to nearest integer - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - -#if defined(__AVX2__) - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 - i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - - // We got our precious signed bytes, but the order is now wrong - // These AVX2 pack instructions process 16-byte pieces independently - // The following instruction is fixing the order - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - _mm256_storeu_si256((__m256i *)y[i].qs, i0); -#else - // Since we don't have in AVX some necessary functions, - // we split the registers in half and call AVX2 analogs from SSE - __m128i ni0 = _mm256_castsi256_si128( i0 ); - __m128i ni1 = _mm256_extractf128_si256( i0, 1); - __m128i ni2 = _mm256_castsi256_si128( i1 ); - __m128i ni3 = _mm256_extractf128_si256( i1, 1); - __m128i ni4 = _mm256_castsi256_si128( i2 ); - __m128i ni5 = _mm256_extractf128_si256( i2, 1); - __m128i ni6 = _mm256_castsi256_si128( i3 ); - __m128i ni7 = _mm256_extractf128_si256( i3, 1); - - // Convert int32 to int16 - ni0 = _mm_packs_epi32( ni0, ni1 ); - ni2 = _mm_packs_epi32( ni2, ni3 ); - ni4 = _mm_packs_epi32( ni4, ni5 ); - ni6 = _mm_packs_epi32( ni6, ni7 ); - // Convert int16 to int8 - ni0 = _mm_packs_epi16( ni0, ni2 ); - ni4 = _mm_packs_epi16( ni4, ni6 ); - - _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); - _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); -#endif - } -#elif defined(__riscv_v_intrinsic) - - size_t vl = __riscv_vsetvl_e32m4(QK8_0); - - for (int i = 0; i < nb; i++) { - // load elements - vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl); - - vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); - vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); - vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); - float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); - - // convert to integer - vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); - vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); - - // store result - __riscv_vse8_v_i8m1(y[i].qs , vs, vl); - } - -#elif defined(__POWER9_VECTOR__) - for (int i = 0; i < nb; i++) { - vector float srcv [8]; - vector float asrcv[8]; - vector float amaxv[8]; - vector signed int vi[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); - - const float amax = MAX(MAX(vec_extract(amaxv[0], 0), - vec_extract(amaxv[0], 1)), - MAX(vec_extract(amaxv[0], 2), - vec_extract(amaxv[0], 3))); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - const vector float vid = vec_splats(id); - - y[i].d = GGML_FP32_TO_FP16(d); - - for (int j = 0; j < 8; j++) { - const vector float v = vec_round(vec_mul(srcv[j], vid)); - vi[j] = vec_cts(v, 0); - } - vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]); - vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]); - } - -#elif defined(__loongarch_asx) - for (int i = 0; i < nb; i++) { - ft_union fi; - __m256 v0 = (__m256)__lasx_xvld( x , 0); - __m256 v1 = (__m256)__lasx_xvld( x , 32); - __m256 v2 = (__m256)__lasx_xvld( x , 64); - __m256 v3 = (__m256)__lasx_xvld( x , 96); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f ); - __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 ); - max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) ); - max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) ); - max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) ); - - __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs , 0) ); - max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) ); - __m128 tmp = max4; - max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 )); - fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 ); - const float max_scalar = fi.f; - - // Quantize these floats - const float d = max_scalar / 127.f; - y[i].d = GGML_FP32_TO_FP16(d); - const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; - const __m256 mul = (__m256)__lasx_xvreplfr2vr_s( id ); - - // Apply the multiplier - v0 = __lasx_xvfmul_s( v0, mul ); - v1 = __lasx_xvfmul_s( v1, mul ); - v2 = __lasx_xvfmul_s( v2, mul ); - v3 = __lasx_xvfmul_s( v3, mul ); - - // Round to nearest integer - __m256i i0 = __lasx_xvftintrne_w_s( v0 ); - __m256i i1 = __lasx_xvftintrne_w_s( v1 ); - __m256i i2 = __lasx_xvftintrne_w_s( v2 ); - __m256i i3 = __lasx_xvftintrne_w_s( v3 ); - - __m128i ni0 = lasx_extracti128( i0, 0 ); - __m128i ni1 = lasx_extracti128( i0, 1); - __m128i ni2 = lasx_extracti128( i1, 0); - __m128i ni3 = lasx_extracti128( i1, 1); - __m128i ni4 = lasx_extracti128( i2, 0); - __m128i ni5 = lasx_extracti128( i2, 1); - __m128i ni6 = lasx_extracti128( i3, 0); - __m128i ni7 = lasx_extracti128( i3, 1); - - // Convert int32 to int16 - ni0 = lsx_packs_w( ni0, ni1 ); - ni2 = lsx_packs_w( ni2, ni3 ); - ni4 = lsx_packs_w( ni4, ni5 ); - ni6 = lsx_packs_w( ni6, ni7 ); - // Convert int16 to int8 - ni0 = lsx_packs_h( ni0, ni2 ); - ni4 = lsx_packs_h( ni4, ni6 ); - - __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0); - __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0); - - } -#else - GGML_UNUSED(nb); - // scalar - quantize_row_q8_0_ref(x, y, k); -#endif -} - // reference implementation for deterministic creation of model files void quantize_row_q8_1_ref(const float * restrict x, block_q8_1 * restrict y, int64_t k) { assert(QK8_1 == 32); @@ -1184,334 +252,6 @@ void quantize_row_q8_1_ref(const float * restrict x, block_q8_1 * restrict y, in } } -void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK8_1 == 0); - const int nb = k / QK8_1; - - block_q8_1 * restrict y = vy; - -#if defined(__ARM_NEON) - for (int i = 0; i < nb; i++) { - float32x4_t srcv [8]; - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - int32x4_t accv = vdupq_n_s32(0); - - for (int j = 0; j < 8; j++) { - const float32x4_t v = vmulq_n_f32(srcv[j], id); - const int32x4_t vi = vcvtnq_s32_f32(v); - - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); - - accv = vaddq_s32(accv, vi); - } - - y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); - } -#elif defined(__wasm_simd128__) - for (int i = 0; i < nb; i++) { - v128_t srcv [8]; - v128_t asrcv[8]; - v128_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); - - const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), - wasm_f32x4_extract_lane(amaxv[0], 1)), - MAX(wasm_f32x4_extract_lane(amaxv[0], 2), - wasm_f32x4_extract_lane(amaxv[0], 3))); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - v128_t accv = wasm_i32x4_splat(0); - - for (int j = 0; j < 8; j++) { - const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); - const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); - - y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); - y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); - y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); - y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); - - accv = wasm_i32x4_add(accv, vi); - } - - y[i].s = GGML_FP32_TO_FP16( - d * (wasm_i32x4_extract_lane(accv, 0) + - wasm_i32x4_extract_lane(accv, 1) + - wasm_i32x4_extract_lane(accv, 2) + - wasm_i32x4_extract_lane(accv, 3))); - } -#elif defined(__AVX2__) || defined(__AVX__) - for (int i = 0; i < nb; i++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x ); - __m256 v1 = _mm256_loadu_ps( x + 8 ); - __m256 v2 = _mm256_loadu_ps( x + 16 ); - __m256 v3 = _mm256_loadu_ps( x + 24 ); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float max_scalar = _mm_cvtss_f32( max4 ); - - // Quantize these floats - const float d = max_scalar / 127.f; - y[i].d = GGML_FP32_TO_FP16(d); - const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; - const __m256 mul = _mm256_set1_ps( id ); - - // Apply the multiplier - v0 = _mm256_mul_ps( v0, mul ); - v1 = _mm256_mul_ps( v1, mul ); - v2 = _mm256_mul_ps( v2, mul ); - v3 = _mm256_mul_ps( v3, mul ); - - // Round to nearest integer - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - -#if defined(__AVX2__) - // Compute the sum of the quants and set y[i].s - y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); - - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 - i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - - // We got our precious signed bytes, but the order is now wrong - // These AVX2 pack instructions process 16-byte pieces independently - // The following instruction is fixing the order - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - _mm256_storeu_si256((__m256i *)y[i].qs, i0); -#else - // Since we don't have in AVX some necessary functions, - // we split the registers in half and call AVX2 analogs from SSE - __m128i ni0 = _mm256_castsi256_si128( i0 ); - __m128i ni1 = _mm256_extractf128_si256( i0, 1); - __m128i ni2 = _mm256_castsi256_si128( i1 ); - __m128i ni3 = _mm256_extractf128_si256( i1, 1); - __m128i ni4 = _mm256_castsi256_si128( i2 ); - __m128i ni5 = _mm256_extractf128_si256( i2, 1); - __m128i ni6 = _mm256_castsi256_si128( i3 ); - __m128i ni7 = _mm256_extractf128_si256( i3, 1); - - // Compute the sum of the quants and set y[i].s - const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3)); - const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7)); - y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_4(_mm_add_epi32(s0, s1))); - - // Convert int32 to int16 - ni0 = _mm_packs_epi32( ni0, ni1 ); - ni2 = _mm_packs_epi32( ni2, ni3 ); - ni4 = _mm_packs_epi32( ni4, ni5 ); - ni6 = _mm_packs_epi32( ni6, ni7 ); - // Convert int16 to int8 - ni0 = _mm_packs_epi16( ni0, ni2 ); - ni4 = _mm_packs_epi16( ni4, ni6 ); - - _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); - _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); -#endif - } -#elif defined(__riscv_v_intrinsic) - - size_t vl = __riscv_vsetvl_e32m4(QK8_1); - - for (int i = 0; i < nb; i++) { - // load elements - vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl); - - vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); - vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl); - vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); - float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); - - // convert to integer - vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); - vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); - - // store result - __riscv_vse8_v_i8m1(y[i].qs , vs, vl); - - // compute sum for y[i].s - vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl); - vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl); - - // set y[i].s - int sum = __riscv_vmv_x_s_i16m1_i16(vwrs); - y[i].s = GGML_FP32_TO_FP16(sum*d); - } - -#elif defined(__POWER9_VECTOR__) - for (int i = 0; i < nb; i++) { - vector float srcv [8]; - vector float asrcv[8]; - vector float amaxv[8]; - vector signed int vi[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); - - const float amax = MAX(MAX(vec_extract(amaxv[0], 0), - vec_extract(amaxv[0], 1)), - MAX(vec_extract(amaxv[0], 2), - vec_extract(amaxv[0], 3))); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - const vector float vid = vec_splats(id); - - y[i].d = GGML_FP32_TO_FP16(d); - - vector int accv = vec_splats(0); - - for (int j = 0; j < 8; j++) { - const vector float v = vec_round(vec_mul(srcv[j], vid)); - vi[j] = vec_cts(v, 0); - - accv = vec_add(accv, vi[j]); - } - vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]); - vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]); - - accv = vec_add(accv, vec_sld(accv, accv, 4)); - accv = vec_add(accv, vec_sld(accv, accv, 8)); - y[i].s = GGML_FP32_TO_FP16(d * vec_extract(accv, 0)); - } - -#elif defined(__loongarch_asx) - for (int i = 0; i < nb; i++) { - ft_union ft; - __m256 v0 = (__m256)__lasx_xvld( x , 0 ); - __m256 v1 = (__m256)__lasx_xvld( x , 32 ); - __m256 v2 = (__m256)__lasx_xvld( x , 64 ); - __m256 v3 = (__m256)__lasx_xvld( x , 96 ); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f ); - __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 ); - max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) ); - max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) ); - max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) ); - - __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) ); - max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) ); - __m128 tmp = max4; - max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 )); - ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 ); - const float max_scalar = ft.f; - - // Quantize these floats - const float d = max_scalar / 127.f; - y[i].d = GGML_FP32_TO_FP16(d); - const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; - const __m256 mul = __lasx_xvreplfr2vr_s( id ); - - // Apply the multiplier - v0 = __lasx_xvfmul_s( v0, mul ); - v1 = __lasx_xvfmul_s( v1, mul ); - v2 = __lasx_xvfmul_s( v2, mul ); - v3 = __lasx_xvfmul_s( v3, mul ); - - // Round to nearest integer - __m256i i0 = __lasx_xvftintrne_w_s( v0 ); - __m256i i1 = __lasx_xvftintrne_w_s( v1 ); - __m256i i2 = __lasx_xvftintrne_w_s( v2 ); - __m256i i3 = __lasx_xvftintrne_w_s( v3 ); - - __m128i ni0 = lasx_extracti128(i0, 0); - __m128i ni1 = lasx_extracti128( i0, 1); - __m128i ni2 = lasx_extracti128( i1, 0); - __m128i ni3 = lasx_extracti128( i1, 1); - __m128i ni4 = lasx_extracti128( i2, 0 ); - __m128i ni5 = lasx_extracti128( i2, 1); - __m128i ni6 = lasx_extracti128( i3, 0); - __m128i ni7 = lasx_extracti128( i3, 1); - - // Compute the sum of the quants and set y[i].s - const __m128i s0 = __lsx_vadd_w(__lsx_vadd_w(ni0, ni1), __lsx_vadd_w(ni2, ni3)); - const __m128i s1 = __lsx_vadd_w(__lsx_vadd_w(ni4, ni5), __lsx_vadd_w(ni6, ni7)); - y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_4(__lsx_vadd_w(s0, s1))); - - // Convert int32 to int16 - ni0 = lsx_packs_w( ni0, ni1 ); - ni2 = lsx_packs_w( ni2, ni3 ); - ni4 = lsx_packs_w( ni4, ni5 ); - ni6 = lsx_packs_w( ni6, ni7 ); - // Convert int16 to int8 - ni0 = lsx_packs_h( ni0, ni2 ); - ni4 = lsx_packs_h( ni4, ni6 ); - - __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0); - __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0); - } -#else - GGML_UNUSED(nb); - // scalar - quantize_row_q8_1_ref(x, y, k); -#endif -} - void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int64_t k) { static const int qk = QK4_0; @@ -2001,10 +741,6 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int6 } } -void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) { - quantize_row_q2_K_ref(x, vy, k); -} - static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights, uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux, float rmin, float rdelta, int nstep, bool use_mad) { @@ -2367,10 +1103,6 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int6 } } -void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) { - quantize_row_q3_K_ref(x, vy, k); -} - static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) { assert(n_per_row % QK_K == 0); const int nb = n_per_row / QK_K; @@ -2569,12 +1301,6 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int6 } } -void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK_K == 0); - block_q4_K * restrict y = vy; - quantize_row_q4_K_ref(x, y, k); -} - static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) { assert(n_per_row % QK_K == 0); const int64_t nb = n_per_row / QK_K; @@ -2780,12 +1506,6 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int6 } } -void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK_K == 0); - block_q5_K * restrict y = vy; - quantize_row_q5_K_ref(x, y, k); -} - static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) { assert(n_per_row % QK_K == 0); const int64_t nb = n_per_row / QK_K; @@ -2998,12 +1718,6 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int6 } } -void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK_K == 0); - block_q6_K * restrict y = vy; - quantize_row_q6_K_ref(x, y, k); -} - static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) { assert(n_per_row % QK_K == 0); const int64_t nb = n_per_row / QK_K; @@ -3406,33 +2120,20 @@ void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y, } } -void quantize_row_tq1_0(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK_K == 0); - block_tq1_0 * restrict y = vy; - quantize_row_tq1_0_ref(x, y, k); -} - -void quantize_row_tq2_0(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK_K == 0); - block_tq2_0 * restrict y = vy; - quantize_row_tq2_0_ref(x, y, k); -} - size_t quantize_tq1_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { (void)quant_weights; // not used const size_t row_size = ggml_row_size(GGML_TYPE_TQ1_0, n_per_row); - quantize_row_tq1_0(src, dst, (int64_t)nrow*n_per_row); + quantize_row_tq1_0_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * row_size; } size_t quantize_tq2_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { (void)quant_weights; // not used const size_t row_size = ggml_row_size(GGML_TYPE_TQ2_0, n_per_row); - quantize_row_tq2_0(src, dst, (int64_t)nrow*n_per_row); + quantize_row_tq2_0_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * row_size; } - void dequantize_row_tq1_0(const block_tq1_0 * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); const int64_t nb = k / QK_K; @@ -3825,8994 +2526,6 @@ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int6 } } -void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q8_K_ref(x, y, k); -} - -//===================================== Dot products ================================= - -// -// Helper functions -// -#if __AVX__ || __AVX2__ || __AVX512F__ - -// shuffles to pick the required scales in dot products -static inline __m256i get_scale_shuffle_q3k(int i) { - static const uint8_t k_shuffle[128] = { - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, - 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, - 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, - 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, - }; - return _mm256_loadu_si256((const __m256i*)k_shuffle + i); -} -static inline __m256i get_scale_shuffle_k4(int i) { - static const uint8_t k_shuffle[256] = { - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, - 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, - 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, - 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, - 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, - 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, - 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, - 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 - }; - return _mm256_loadu_si256((const __m256i*)k_shuffle + i); -} -static inline __m128i get_scale_shuffle(int i) { - static const uint8_t k_shuffle[128] = { - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, - 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, - 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, - 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, - 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, - 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, - 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 - }; - return _mm_loadu_si128((const __m128i*)k_shuffle + i); -} -#elif defined(__loongarch_asx) -// shuffles to pick the required scales in dot products -static inline __m256i get_scale_shuffle_q3k(int i) { - static const uint8_t k_shuffle[128] = { - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, - 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, - 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, - 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, - }; - return __lasx_xvld((const __m256i*)k_shuffle + i, 0); -} -static inline __m256i get_scale_shuffle_k4(int i) { - static const uint8_t k_shuffle[256] = { - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, - 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, - 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, - 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, - 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, - 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, - 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, - 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 - }; - return __lasx_xvld((const __m256i*)k_shuffle + i, 0); -} -static inline __m128i get_scale_shuffle(int i) { - static const uint8_t k_shuffle[128] = { - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, - 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, - 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, - 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, - 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, - 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, - 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 - }; - return __lsx_vld((const __m128i*)k_shuffle + i, 0); -} -#endif - -void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - const int qk = QK8_0; - const int nb = n / qk; - - assert(n % qk == 0); -#if defined(__ARM_FEATURE_MATMUL_INT8) - assert((nrc == 2) || (nrc == 1)); -#else - assert(nrc == 1); -#endif - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q4_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; - -#if defined(__ARM_FEATURE_MATMUL_INT8) - if (nrc == 2) { - const block_q4_0 * restrict vx0 = vx; - const block_q4_0 * restrict vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx); - const block_q8_0 * restrict vy0 = vy; - const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by); - - float32x4_t sumv0 = vdupq_n_f32(0.0f); - - for (int i = 0; i < nb; i++) { - const block_q4_0 * restrict b_x0 = &vx0[i]; - const block_q4_0 * restrict b_x1 = &vx1[i]; - const block_q8_0 * restrict b_y0 = &vy0[i]; - const block_q8_0 * restrict b_y1 = &vy1[i]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - const int8x16_t s8b = vdupq_n_s8(0x8); - - const uint8x16_t v0_0 = vld1q_u8(b_x0->qs); - const uint8x16_t v0_1 = vld1q_u8(b_x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // sub 8 - const int8x16_t x0_l = vsubq_s8(v0_0l, s8b); - const int8x16_t x0_h = vsubq_s8(v0_0h, s8b); - const int8x16_t x1_l = vsubq_s8(v0_1l, s8b); - const int8x16_t x1_h = vsubq_s8(v0_1h, s8b); - - // load y - const int8x16_t y0_l = vld1q_s8(b_y0->qs); - const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); - const int8x16_t y1_l = vld1q_s8(b_y1->qs); - const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); - - float32_t _scale[4] = { GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), - GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), - GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), - GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)}; - - float32x4_t scale = vld1q_f32(_scale); - - int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); - int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); - - int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - - int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - - int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); - int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); - - sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), - l1, r1)), l2, r2)), l3, r3))), scale); - } - float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); - float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); - - vst1_f32(s, vget_low_f32(sumv2)); - vst1_f32(s + bs, vget_high_f32(sumv2)); - return; - } -#endif - - int ib = 0; - float sumf = 0; - -#if defined(__ARM_FEATURE_SVE) - if (ggml_sve_cnt_b == QK8_0) { - const svbool_t ptrueh = svptrue_pat_b8(SV_VL16); - const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh); - - svfloat32_t sumv0 = svdup_n_f32(0.0f); - svfloat32_t sumv1 = svdup_n_f32(0.0f); - - for (; ib + 1 < nb; ib += 2) { - const block_q4_0 * restrict x0 = &x[ib + 0]; - const block_q4_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; - - // load x - const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); - const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); - - // 4-bit -> 8-bit - const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04)); - const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04)); - - // sub 8 - const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); - const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); - - // load y - const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); - - // dot product - sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } - - sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); - } -#elif defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - for (; ib + 1 < nb; ib += 2) { - const block_q4_0 * restrict x0 = &x[ib + 0]; - const block_q4_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - const int8x16_t s8b = vdupq_n_s8(0x8); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // sub 8 - const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); - const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); - const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); - const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - - // dot product into int32x4_t - const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); - const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } - - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (; ib < nb; ++ib) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); - - __m256i qx = bytes_from_nibbles_32(x[ib].qs); - - // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. - const __m256i off = _mm256_set1_epi8( 8 ); - qx = _mm256_sub_epi8( qx, off ); - - __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); - - const __m256 q = mul_sum_i8_pairs_float(qx, qy); - - /* Multiply q with scale and accumulate */ - acc = _mm256_fmadd_ps( d, q, acc ); - } - - sumf = hsum_float_8(acc); -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (; ib < nb; ++ib) { - // Compute combined scale for the block - const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); - - const __m128i lowMask = _mm_set1_epi8(0xF); - const __m128i off = _mm_set1_epi8(8); - - const __m128i tmp = _mm_loadu_si128((const __m128i *)x[ib].qs); - - __m128i bx_0 = _mm_and_si128(lowMask, tmp); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs); - bx_0 = _mm_sub_epi8(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - - bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); - by_0 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16)); - bx_0 = _mm_sub_epi8(bx_0, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0); - - // Convert int32_t to float - __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1)); - - // Apply the scale, and accumulate - acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); - } - - sumf = hsum_float_8(acc); -#elif defined(__SSSE3__) - // set constants - const __m128i lowMask = _mm_set1_epi8(0xF); - const __m128i off = _mm_set1_epi8(8); - - // Initialize accumulator with zeros - __m128 acc_0 = _mm_setzero_ps(); - __m128 acc_1 = _mm_setzero_ps(); - __m128 acc_2 = _mm_setzero_ps(); - __m128 acc_3 = _mm_setzero_ps(); - - for (; ib + 1 < nb; ib += 2) { - _mm_prefetch(&x[ib] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[ib] + sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); - - const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[ib].qs); - - __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs); - bx_0 = _mm_sub_epi8(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - - __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); - __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16)); - bx_1 = _mm_sub_epi8(bx_1, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - - _mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) ); - - const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); - - __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); - __m128i by_2 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); - bx_2 = _mm_sub_epi8(bx_2, off); - const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); - - __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); - __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[ib + 1].qs + 16)); - bx_3 = _mm_sub_epi8(bx_3, off); - const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); - - // Convert int32_t to float - __m128 p0 = _mm_cvtepi32_ps(i32_0); - __m128 p1 = _mm_cvtepi32_ps(i32_1); - __m128 p2 = _mm_cvtepi32_ps(i32_2); - __m128 p3 = _mm_cvtepi32_ps(i32_3); - - // Apply the scale - __m128 p0_d = _mm_mul_ps( d_0_1, p0 ); - __m128 p1_d = _mm_mul_ps( d_0_1, p1 ); - __m128 p2_d = _mm_mul_ps( d_2_3, p2 ); - __m128 p3_d = _mm_mul_ps( d_2_3, p3 ); - - // Acummulate - acc_0 = _mm_add_ps(p0_d, acc_0); - acc_1 = _mm_add_ps(p1_d, acc_1); - acc_2 = _mm_add_ps(p2_d, acc_2); - acc_3 = _mm_add_ps(p3_d, acc_3); - } - - sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); -#elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - for (; ib < nb; ++ib) { - // load elements - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); - - // mask and store lower part of x, and then upper part - vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - - vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - - // subtract offset - vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl); - vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl); - - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d); - } - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector signed int v0 = vec_splats((int32_t)0); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - const vector signed char v8 = vec_splats((signed char)0x8); - - vector float vsumf0 = vec_splats(0.0f); - -#pragma GCC unroll 8 - for (; ib < nb; ++ib) { - __builtin_prefetch(x[ib].qs, 0, 1); - __builtin_prefetch(y[ib].qs, 0, 1); - - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); - vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); - vector float vd = vec_mul(vxd, vyd); - - vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); - vector signed char q8y0 = vec_xl( 0, y[ib].qs); - vector signed char q8y1 = vec_xl(16, y[ib].qs); - - vector signed char q4x0 = vec_and(qxs, lowMask); - vector signed char q4x1 = vec_sr(qxs, v4); - - q4x0 = vec_sub(q4x0, v8); - q4x1 = vec_sub(q4x1, v8); - - vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); - - vector signed int vsumi0 = v0; - - vsumi0 = vec_sum4s(qv0, vsumi0); - vsumi0 = vec_sum4s(qv1, vsumi0); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - } - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - sumf = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - // Initialize accumulator with zeros - __m256 acc = (__m256)__lasx_xvldi(0); - - // Main loop - for (; ib < nb; ++ib) { - /* Compute combined scale for the block */ - const __m256 d = __lasx_xvreplfr2vr_s( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); - - __m256i qx = bytes_from_nibbles_32(x[ib].qs); - - // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. - const __m256i off = __lasx_xvreplgr2vr_b( 8 ); - qx = __lasx_xvsub_b( qx, off ); - - __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); - - const __m256 q = mul_sum_i8_pairs_float(qx, qy); - - /* Multiply q with scale and accumulate */ - acc = __lasx_xvfmadd_s( d, q, acc ); - } - - sumf = hsum_float_8(acc); -#elif defined(__loongarch_sx) - // set constants - const __m128i low_mask = __lsx_vreplgr2vr_b(0xF); - const __m128i off = __lsx_vreplgr2vr_b(8); - - // Initialize accumulator with zeros - __m128 acc_0 = __lsx_vldi(0); - __m128 acc_1 = __lsx_vldi(0); - __m128 acc_2 = __lsx_vldi(0); - __m128 acc_3 = __lsx_vldi(0); - - for (; ib + 1 < nb; ib += 2) { - - // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); - - const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0); - - __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1); - __m128i by_0 = __lsx_vld((const __m128i *)y[ib].qs, 0); - bx_0 = __lsx_vsub_b(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - - __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4)); - __m128i by_1 = __lsx_vld((const __m128i *)(y[ib].qs + 16), 0); - bx_1 = __lsx_vsub_b(bx_1, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - - //_mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0); - //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) ); - - const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0); - - __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3); - __m128i by_2 = __lsx_vld((const __m128i *)y[ib + 1].qs, 0); - bx_2 = __lsx_vsub_b(bx_2, off); - const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); - - __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4)); - __m128i by_3 = __lsx_vld((const __m128i *)(y[ib + 1].qs + 16), 0); - bx_3 = __lsx_vsub_b(bx_3, off); - const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); - - // Convert int32_t to float - __m128 p0 = __lsx_vffint_s_w(i32_0); - __m128 p1 = __lsx_vffint_s_w(i32_1); - __m128 p2 = __lsx_vffint_s_w(i32_2); - __m128 p3 = __lsx_vffint_s_w(i32_3); - - // Apply the scale - __m128 p0_d = __lsx_vfmul_s( d_0_1, p0 ); - __m128 p1_d = __lsx_vfmul_s( d_0_1, p1 ); - __m128 p2_d = __lsx_vfmul_s( d_2_3, p2 ); - __m128 p3_d = __lsx_vfmul_s( d_2_3, p3 ); - - // Acummulate - acc_0 = __lsx_vfadd_s(p0_d, acc_0); - acc_1 = __lsx_vfadd_s(p1_d, acc_1); - acc_2 = __lsx_vfadd_s(p2_d, acc_2); - acc_3 = __lsx_vfadd_s(p3_d, acc_3); - } - - sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); -#endif - for (; ib < nb; ++ib) { - int sumi0 = 0; - int sumi1 = 0; - - for (int j = 0; j < qk/2; ++j) { - const int v0 = (x[ib].qs[j] & 0x0F) - 8; - const int v1 = (x[ib].qs[j] >> 4) - 8; - - sumi0 += (v0 * y[ib].qs[j]); - sumi1 += (v1 * y[ib].qs[j + qk/2]); - } - - int sumi = sumi0 + sumi1; - sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d); - } - - *s = sumf; -} - -void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - const int qk = QK8_1; - const int nb = n / qk; - - assert(n % qk == 0); -#if defined(__ARM_FEATURE_MATMUL_INT8) - assert((nrc == 2) || (nrc == 1)); -#else - assert(nrc == 1); -#endif - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q4_1 * restrict x = vx; - const block_q8_1 * restrict y = vy; - -#if defined(__ARM_FEATURE_MATMUL_INT8) - if (nrc == 2) { - const block_q4_1 * restrict vx0 = vx; - const block_q4_1 * restrict vx1 = (const block_q4_1 *) ((const uint8_t*)vx + bx); - const block_q8_1 * restrict vy0 = vy; - const block_q8_1 * restrict vy1 = (const block_q8_1 *) ((const uint8_t*)vy + by); - - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t summs0 = vdupq_n_f32(0.0f); - - for (int i = 0; i < nb; i++) { - const block_q4_1 * restrict b_x0 = &vx0[i]; - const block_q4_1 * restrict b_x1 = &vx1[i]; - const block_q8_1 * restrict b_y0 = &vy0[i]; - const block_q8_1 * restrict b_y1 = &vy1[i]; - - float32_t summs_t[4] = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s), - GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s), - GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s), - GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)}; - summs0 = vaddq_f32(summs0, vld1q_f32(summs_t)); - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - const uint8x16_t v0_0 = vld1q_u8(b_x0->qs); - const uint8x16_t v0_1 = vld1q_u8(b_x1->qs); - - // 4-bit -> 8-bit - const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // load y - const int8x16_t y0_l = vld1q_s8(b_y0->qs); - const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); - const int8x16_t y1_l = vld1q_s8(b_y1->qs); - const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); - - // mmla into int32x4_t - float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d, - GGML_FP16_TO_FP32(b_x0->d)*b_y1->d, - GGML_FP16_TO_FP32(b_x1->d)*b_y0->d, - GGML_FP16_TO_FP32(b_x1->d)*b_y1->d}; - float32x4_t scale = vld1q_f32(_scale); - - int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); - int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); - - int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - - int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - - int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); - int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); - sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), - l1, r1)), l2, r2)), l3, r3))), scale); - } - - float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); - float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); - sumv2 = vaddq_f32(sumv2, summs0); - - vst1_f32(s, vget_low_f32 (sumv2)); - vst1_f32(s + bs, vget_high_f32(sumv2)); - return; - } -#endif - - int ib = 0; - float sumf = 0; - - // TODO: add WASM SIMD -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - float summs = 0; - - for (; ib + 1 < nb; ib += 2) { - const block_q4_1 * restrict x0 = &x[ib + 0]; - const block_q4_1 * restrict x1 = &x[ib + 1]; - const block_q8_1 * restrict y0 = &y[ib + 0]; - const block_q8_1 * restrict y1 = &y[ib + 1]; - - summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s) + GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s); - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - - // dot product into int32x4_t - const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); - const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } - - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; -#elif defined(__AVX2__) || defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - float summs = 0; - - // Main loop - for (; ib < nb; ++ib) { - const float d0 = GGML_FP16_TO_FP32(x[ib].d); - const float d1 = GGML_FP16_TO_FP32(y[ib].d); - - summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); - - const __m256 d0v = _mm256_set1_ps( d0 ); - const __m256 d1v = _mm256_set1_ps( d1 ); - - // Compute combined scales - const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); - - // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - const __m256i qx = bytes_from_nibbles_32(x[ib].qs); - const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[ib].qs ); - - const __m256 xy = mul_sum_us8_pairs_float(qx, qy); - - // Accumulate d0*d1*x*y -#if defined(__AVX2__) - acc = _mm256_fmadd_ps( d0d1, xy, acc ); -#else - acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc ); -#endif - } - - sumf = hsum_float_8(acc) + summs; -#elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - for (; ib < nb; ++ib) { - // load elements - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); - - // mask and store lower part of x, and then upper part - vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - - vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); - } - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector signed int v0 = vec_splats((int32_t)0); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - -#pragma GCC unroll 4 - for (; ib < nb; ++ib) { - __builtin_prefetch(x[ib].qs, 0, 1); - __builtin_prefetch(y[ib].qs, 0, 1); - - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); - vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); - vector float vd = vec_mul(vxd, vyd); - - vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[ib].m)); - vector float vys = {GGML_FP16_TO_FP32(y[ib].s), 0.0f, 0.0f, 0.0f}; - vsumf0 = vec_madd(vxmin, vys, vsumf0); - - vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); - vector signed char q8y0 = vec_xl( 0, y[ib].qs); - vector signed char q8y1 = vec_xl(16, y[ib].qs); - - vector unsigned char q4x0 = (vector unsigned char)vec_and(qxs, lowMask); - vector unsigned char q4x1 = (vector unsigned char)vec_sr(qxs, v4); - - vector signed int vsumi0 = v0; - - vsumi0 = vec_msum(q8y0, q4x0, vsumi0); - vsumi0 = vec_msum(q8y1, q4x1, vsumi0); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - } - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - sumf = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - // Initialize accumulator with zeros - __m256 acc = (__m256)__lasx_xvldi(0); - - float summs = 0; - - // Main loop - for (; ib < nb; ++ib) { - const float d0 = GGML_FP16_TO_FP32(x[ib].d); - const float d1 = GGML_FP16_TO_FP32(y[ib].d); - - summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); - - const __m256 d0v = __lasx_xvreplfr2vr_s( d0 ); - const __m256 d1v = __lasx_xvreplfr2vr_s( d1 ); - - // Compute combined scales - const __m256 d0d1 = __lasx_xvfmul_s( d0v, d1v ); - - // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - const __m256i qx = bytes_from_nibbles_32(x[ib].qs); - const __m256i qy = __lasx_xvld( (const __m256i *)y[ib].qs, 0); - - const __m256 xy = mul_sum_us8_pairs_float(qx, qy); - - // Accumulate d0*d1*x*y - acc = __lasx_xvfmadd_s( d0d1, xy, acc ); - } - - sumf = hsum_float_8(acc) + summs; -#endif - for (; ib < nb; ++ib) { - int sumi0 = 0; - int sumi1 = 0; - - for (int j = 0; j < qk/2; ++j) { - const int v0 = (x[ib].qs[j] & 0x0F); - const int v1 = (x[ib].qs[j] >> 4); - - sumi0 += (v0 * y[ib].qs[j]); - sumi1 += (v1 * y[ib].qs[j + qk/2]); - } - - int sumi = sumi0 + sumi1; - sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); - } - - *s = sumf; -} - -void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - const int qk = QK8_0; - const int nb = n / qk; - - int ib = 0; - float sumf = 0; - - assert(n % qk == 0); - assert(qk == QK5_0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q5_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; - -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - uint32_t qh0; - uint32_t qh1; - - uint64_t tmp0[4]; - uint64_t tmp1[4]; - - for (; ib + 1 < nb; ib += 2) { - const block_q5_0 * restrict x0 = &x[ib]; - const block_q5_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib]; - const block_q8_0 * restrict y1 = &y[ib + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - // extract the 5th bit via lookup table ((!b) << 4) - memcpy(&qh0, x0->qh, sizeof(qh0)); - memcpy(&qh1, x1->qh, sizeof(qh1)); - - tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF]; - tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF]; - tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF]; - tmp0[3] = table_b2b_1[(qh0 >> 24) ]; - - tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF]; - tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF]; - tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF]; - tmp1[3] = table_b2b_1[(qh1 >> 24) ]; - - const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); - const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); - const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); - const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) - const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0); - const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0); - const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1); - const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } - - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__wasm_simd128__) - v128_t sumv = wasm_f32x4_splat(0.0f); - - uint32_t qh; - uint64_t tmp[4]; - - // TODO: check if unrolling this is better - for (; ib < nb; ++ib) { - const block_q5_0 * restrict x0 = &x[ib]; - const block_q8_0 * restrict y0 = &y[ib]; - - const v128_t m4b = wasm_i8x16_splat(0x0F); - - // extract the 5th bit - memcpy(&qh, x0->qh, sizeof(qh)); - - tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_1[(qh >> 24) ]; - - const v128_t qhl = wasm_v128_load(tmp + 0); - const v128_t qhh = wasm_v128_load(tmp + 2); - - const v128_t v0 = wasm_v128_load(x0->qs); - - // 4-bit -> 8-bit - const v128_t v0l = wasm_v128_and (v0, m4b); - const v128_t v0h = wasm_u8x16_shr(v0, 4); - - // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) - const v128_t v0lf = wasm_i8x16_sub(v0l, qhl); - const v128_t v0hf = wasm_i8x16_sub(v0h, qhh); - - // load y - const v128_t v1l = wasm_v128_load(y0->qs); - const v128_t v1h = wasm_v128_load(y0->qs + 16); - - // int8x16 -> int16x8 - const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); - const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); - const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); - const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); - - const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); - const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); - const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); - const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); - - // dot product - sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4( - wasm_i32x4_add( - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), - wasm_i32x4_dot_i16x8(v0lfh, v1lh)), - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), - wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), - wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); - } - - sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (; ib < nb; ++ib) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); - - __m256i qx = bytes_from_nibbles_32(x[ib].qs); - __m256i bxhi = bytes_from_bits_32(x[ib].qh); - bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); - qx = _mm256_or_si256(qx, bxhi); - - __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); - - const __m256 q = mul_sum_i8_pairs_float(qx, qy); - - /* Multiply q with scale and accumulate */ - acc = _mm256_fmadd_ps(d, q, acc); - } - - sumf = hsum_float_8(acc); -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - __m128i mask = _mm_set1_epi8((char)0xF0); - - // Main loop - for (; ib < nb; ++ib) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); - - __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs); - const __m256i bxhi = bytes_from_bits_32(x[ib].qh); - __m128i bxhil = _mm256_castsi256_si128(bxhi); - __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); - bxhil = _mm_andnot_si128(bxhil, mask); - bxhih = _mm_andnot_si128(bxhih, mask); - __m128i bxl = _mm256_castsi256_si128(bx_0); - __m128i bxh = _mm256_extractf128_si256(bx_0, 1); - bxl = _mm_or_si128(bxl, bxhil); - bxh = _mm_or_si128(bxh, bxhih); - bx_0 = MM256_SET_M128I(bxh, bxl); - - const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs); - - const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0); - - /* Multiply q with scale and accumulate */ - acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); - } - - sumf = hsum_float_8(acc); -#elif defined(__riscv_v_intrinsic) - uint32_t qh; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - // These temporary registers are for masking and shift operations - vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); - vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl); - - vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl); - vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); - - for (; ib < nb; ++ib) { - memcpy(&qh, x[ib].qh, sizeof(uint32_t)); - - // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl); - vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl); - vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); - - // ((qh & (1u << (j + 16))) >> (j + 12)); - vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl); - vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl); - - // narrowing - vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl); - vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); - - vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl); - vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); - - // load - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); - - vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - - vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); - vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); - - vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - - vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl); - vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl); - - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi; - } - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector unsigned char v4 = vec_splats((unsigned char)4); - - vector float vsumf0 = vec_splats(0.0f); - -#pragma GCC unroll 4 - for (; ib < nb; ++ib) { - __builtin_prefetch(x[ib].qs, 0, 1); - __builtin_prefetch(y[ib].qs, 0, 1); - - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); - vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); - vector float vd = vec_mul(vxd, vyd); - - vector signed long long aux64x2_0 = {(uint64_t)(table_b2b_1[x[ib].qh[0]]), (uint64_t)(table_b2b_1[x[ib].qh[1]])}; - vector signed long long aux64x2_1 = {(uint64_t)(table_b2b_1[x[ib].qh[2]]), (uint64_t)(table_b2b_1[x[ib].qh[3]])}; - - vector signed char qh0 = (vector signed char)aux64x2_0; - vector signed char qh1 = (vector signed char)aux64x2_1; - - vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); - - vector signed char q5x0 = vec_sub(vec_and (qxs, lowMask), qh0); - vector signed char q5x1 = vec_sub(vec_sr(qxs, v4), qh1); - - vector signed char q8y0 = vec_xl( 0, y[ib].qs); - vector signed char q8y1 = vec_xl( 16, y[ib].qs); - - vector signed short qv0 = vec_add(vec_mule(q5x0, q8y0), vec_mulo(q5x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q5x1, q8y1), vec_mulo(q5x1, q8y1)); - - qv0 = vec_add(qv0, qv1); - - vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0)); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - } - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - sumf = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - // Initialize accumulator with zeros - __m256 acc = (__m256)__lasx_xvldi(0); - - // Main loop - for (; ib < nb; ++ib) { - /* Compute combined scale for the block */ - const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); //FIXME - - __m256i qx = bytes_from_nibbles_32(x[ib].qs); - __m256i bxhi = bytes_from_bits_32(x[ib].qh); - bxhi = __lasx_xvandn_v(bxhi, __lasx_xvreplgr2vr_b((char)0xF0)); - qx = __lasx_xvor_v(qx, bxhi); - - __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); - - const __m256 q = mul_sum_i8_pairs_float(qx, qy); - - /* Multiply q with scale and accumulate */ - acc = __lasx_xvfmadd_s(d, q, acc); - } - - sumf = hsum_float_8(acc); -#endif - for (; ib < nb; ++ib) { - uint32_t qh; - memcpy(&qh, x[ib].qh, sizeof(qh)); - - int sumi0 = 0; - int sumi1 = 0; - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); - - const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16); - const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16); - - sumi0 += (x0 * y[ib].qs[j]); - sumi1 += (x1 * y[ib].qs[j + qk/2]); - } - - int sumi = sumi0 + sumi1; - sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi; - } - - *s = sumf; -} - -void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - const int qk = QK8_1; - const int nb = n / qk; - - int ib = 0; - float sumf = 0; - - assert(n % qk == 0); - assert(qk == QK5_1); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q5_1 * restrict x = vx; - const block_q8_1 * restrict y = vy; - -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - float summs0 = 0.0f; - float summs1 = 0.0f; - - uint32_t qh0; - uint32_t qh1; - - uint64_t tmp0[4]; - uint64_t tmp1[4]; - - for (; ib + 1 < nb; ib += 2) { - const block_q5_1 * restrict x0 = &x[ib]; - const block_q5_1 * restrict x1 = &x[ib + 1]; - const block_q8_1 * restrict y0 = &y[ib]; - const block_q8_1 * restrict y1 = &y[ib + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - summs0 += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s); - summs1 += GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s); - - // extract the 5th bit via lookup table ((b) << 4) - memcpy(&qh0, x0->qh, sizeof(qh0)); - memcpy(&qh1, x1->qh, sizeof(qh1)); - - tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF]; - tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF]; - tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF]; - tmp0[3] = table_b2b_0[(qh0 >> 24) ]; - - tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF]; - tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF]; - tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF]; - tmp1[3] = table_b2b_0[(qh1 >> 24) ]; - - const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); - const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); - const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); - const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // add high bit - const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0); - const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0); - const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1); - const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } - - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; -#elif defined(__wasm_simd128__) - v128_t sumv = wasm_f32x4_splat(0.0f); - - float summs = 0.0f; - - uint32_t qh; - uint64_t tmp[4]; - - // TODO: check if unrolling this is better - for (; ib < nb; ++ib) { - const block_q5_1 * restrict x0 = &x[ib]; - const block_q8_1 * restrict y0 = &y[ib]; - - summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s); - - const v128_t m4b = wasm_i8x16_splat(0x0F); - - // extract the 5th bit - memcpy(&qh, x0->qh, sizeof(qh)); - - tmp[0] = table_b2b_0[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_0[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_0[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_0[(qh >> 24) ]; - - const v128_t qhl = wasm_v128_load(tmp + 0); - const v128_t qhh = wasm_v128_load(tmp + 2); - - const v128_t v0 = wasm_v128_load(x0->qs); - - // 4-bit -> 8-bit - const v128_t v0l = wasm_v128_and (v0, m4b); - const v128_t v0h = wasm_u8x16_shr(v0, 4); - - // add high bit - const v128_t v0lf = wasm_v128_or(v0l, qhl); - const v128_t v0hf = wasm_v128_or(v0h, qhh); - - // load y - const v128_t v1l = wasm_v128_load(y0->qs); - const v128_t v1h = wasm_v128_load(y0->qs + 16); - - // int8x16 -> int16x8 - const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); - const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); - const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); - const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); - - const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); - const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); - const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); - const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); - - // dot product - sumv = wasm_f32x4_add(sumv, - wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add( - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), - wasm_i32x4_dot_i16x8(v0lfh, v1lh)), - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), - wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), - wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); - } - - sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - float summs = 0.0f; - - // Main loop - for (; ib < nb; ++ib) { - const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d)); - - summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); - - __m256i qx = bytes_from_nibbles_32(x[ib].qs); - __m256i bxhi = bytes_from_bits_32(x[ib].qh); - bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); - qx = _mm256_or_si256(qx, bxhi); - - const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[ib].d)); - const __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); - - const __m256 q = mul_sum_us8_pairs_float(qx, qy); - - acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); - } - - sumf = hsum_float_8(acc) + summs; -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - __m128i mask = _mm_set1_epi8(0x10); - - float summs = 0.0f; - - // Main loop - for (; ib < nb; ++ib) { - const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d)); - - summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); - - __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs); - const __m256i bxhi = bytes_from_bits_32(x[ib].qh); - __m128i bxhil = _mm256_castsi256_si128(bxhi); - __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); - bxhil = _mm_and_si128(bxhil, mask); - bxhih = _mm_and_si128(bxhih, mask); - __m128i bxl = _mm256_castsi256_si128(bx_0); - __m128i bxh = _mm256_extractf128_si256(bx_0, 1); - bxl = _mm_or_si128(bxl, bxhil); - bxh = _mm_or_si128(bxh, bxhih); - bx_0 = MM256_SET_M128I(bxh, bxl); - - const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[ib].d)); - const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs); - - const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0); - - acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); - } - - sumf = hsum_float_8(acc) + summs; -#elif defined(__riscv_v_intrinsic) - uint32_t qh; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - // temporary registers for shift operations - vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); - vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); - - for (; ib < nb; ++ib) { - memcpy(&qh, x[ib].qh, sizeof(uint32_t)); - - // load qh - vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl); - - // ((qh >> (j + 0)) << 4) & 0x10; - vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl); - vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); - vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl); - - // ((qh >> (j + 12)) ) & 0x10; - vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl); - vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl); - - // narrowing - vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl); - vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); - - vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl); - vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); - - // load - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); - - vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); - vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); - - vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); - vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); - - vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); - vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); - - vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); - vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); - } - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector signed int v0 = vec_splats((int32_t)0); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - -#pragma GCC unroll 4 - for (; ib < nb; ++ib) { - __builtin_prefetch(x[ib].qs, 0, 1); - __builtin_prefetch(y[ib].qs, 0, 1); - - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); - vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); - vector float vd = vec_mul(vxd, vyd); - - vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[ib].m)); - vector float vys = {GGML_FP16_TO_FP32(y[ib].s), 0.f, 0.f, 0.f}; - vsumf0 = vec_madd(vxmin, vys, vsumf0); - - vector unsigned long long aux64x2_0 = {(uint64_t)(table_b2b_0[x[ib].qh[0]]), (uint64_t)(table_b2b_0[x[ib].qh[1]])}; - vector unsigned long long aux64x2_1 = {(uint64_t)(table_b2b_0[x[ib].qh[2]]), (uint64_t)(table_b2b_0[x[ib].qh[3]])}; - - vector signed char qh0 = (vector signed char)aux64x2_0; - vector signed char qh1 = (vector signed char)aux64x2_1; - - vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); - - vector unsigned char q5x0 = (vector unsigned char)vec_or(vec_and(qxs, lowMask), qh0); - vector unsigned char q5x1 = (vector unsigned char)vec_or(vec_sr(qxs, v4), qh1); - - vector signed char q8y0 = vec_xl( 0, y[ib].qs); - vector signed char q8y1 = vec_xl( 16, y[ib].qs); - - vector signed int vsumi0 = v0; - - vsumi0 = vec_msum(q8y0, q5x0, vsumi0); - vsumi0 = vec_msum(q8y1, q5x1, vsumi0); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - } - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - sumf = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - // Initialize accumulator with zeros - __m256 acc = (__m256)__lasx_xvldi(0); - - float summs = 0.0f; - - // Main loop - for (; ib < nb; ++ib) { - const __m256 dx = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d)); - - summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); - - __m256i qx = bytes_from_nibbles_32(x[ib].qs); - __m256i bxhi = bytes_from_bits_32(x[ib].qh); - bxhi = __lasx_xvand_v(bxhi, __lasx_xvreplgr2vr_b(0x10)); - qx = __lasx_xvor_v(qx, bxhi); - - const __m256 dy = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib].d)); - const __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); - - const __m256 q = mul_sum_us8_pairs_float(qx, qy); - - acc = __lasx_xvfmadd_s(q, __lasx_xvfmul_s(dx, dy), acc); - } - - sumf = hsum_float_8(acc) + summs; -#endif - for (; ib < nb; ++ib) { - uint32_t qh; - memcpy(&qh, x[ib].qh, sizeof(qh)); - - int sumi0 = 0; - int sumi1 = 0; - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; - const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; - - const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0; - const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1; - - sumi0 += (x0 * y[ib].qs[j]); - sumi1 += (x1 * y[ib].qs[j + qk/2]); - } - - int sumi = sumi0 + sumi1; - sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); - } - - *s = sumf; -} - -void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - const int qk = QK8_0; - const int nb = n / qk; - - assert(n % qk == 0); -#if defined(__ARM_FEATURE_MATMUL_INT8) - assert((nrc == 2) || (nrc == 1)); -#else - assert(nrc == 1); -#endif - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q8_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; - -#if defined(__ARM_FEATURE_MATMUL_INT8) - if (nrc == 2) { - const block_q8_0 * restrict vx0 = vx; - const block_q8_0 * restrict vx1 = (const block_q8_0 *) ((const uint8_t*)vx + bx); - const block_q8_0 * restrict vy0 = vy; - const block_q8_0 * restrict vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by); - - float32x4_t sumv0 = vdupq_n_f32(0.0f); - - for (int i = 0; i < nb; i++) { - const block_q8_0 * restrict b_x0 = &vx0[i]; - const block_q8_0 * restrict b_y0 = &vy0[i]; - - const block_q8_0 * restrict b_x1 = &vx1[i]; - const block_q8_0 * restrict b_y1 = &vy1[i]; - - const int8x16_t x0_l = vld1q_s8(b_x0->qs); - const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16); - const int8x16_t x1_l = vld1q_s8(b_x1->qs); - const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16); - - // load y - const int8x16_t y0_l = vld1q_s8(b_y0->qs); - const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); - const int8x16_t y1_l = vld1q_s8(b_y1->qs); - const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); - - float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), - GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), - GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), - GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)}; - float32x4_t scale = vld1q_f32(_scale); - - int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); - int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); - - int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - - int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - - int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); - int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); - - sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), - l1, r1)), l2, r2)), l3, r3))), scale); - } - float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); - float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); - - vst1_f32(s, vget_low_f32(sumv2)); - vst1_f32(s + bs, vget_high_f32(sumv2)); - return; - } -#endif - - int ib = 0; - float sumf = 0; - -#if defined(__ARM_FEATURE_SVE) - if (ggml_sve_cnt_b == QK8_0) { - svfloat32_t sumv0 = svdup_n_f32(0.0f); - svfloat32_t sumv1 = svdup_n_f32(0.0f); - - for (; ib + 1 < nb; ib += 2) { - const block_q8_0 * restrict x0 = &x[ib + 0]; - const block_q8_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; - - // load x - const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); - const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); - - // load y - const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); - - sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } - - sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); - } -#elif defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - for (; ib + 1 < nb; ib += 2) { - const block_q8_0 * restrict x0 = &x[ib + 0]; - const block_q8_0 * restrict x1 = &x[ib + 1]; - const block_q8_0 * restrict y0 = &y[ib + 0]; - const block_q8_0 * restrict y1 = &y[ib + 1]; - - const int8x16_t x0_0 = vld1q_s8(x0->qs); - const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); - const int8x16_t x1_0 = vld1q_s8(x1->qs); - const int8x16_t x1_1 = vld1q_s8(x1->qs + 16); - - // load y - const int8x16_t y0_0 = vld1q_s8(y0->qs); - const int8x16_t y0_1 = vld1q_s8(y0->qs + 16); - const int8x16_t y1_0 = vld1q_s8(y1->qs); - const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), - ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), - ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } - - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__AVX2__) || defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (; ib < nb; ++ib) { - // Compute combined scale for the block - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); - __m256i qx = _mm256_loadu_si256((const __m256i *)x[ib].qs); - __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); - - const __m256 q = mul_sum_i8_pairs_float(qx, qy); - - // Multiply q with scale and accumulate -#if defined(__AVX2__) - acc = _mm256_fmadd_ps( d, q, acc ); -#else - acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc ); -#endif - } - - sumf = hsum_float_8(acc); -#elif defined(__riscv_v_intrinsic) - size_t vl = __riscv_vsetvl_e8m1(qk); - - for (; ib < nb; ++ib) { - // load elements - vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[ib].qs, vl); - vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[ib].qs, vl); - - vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl); - - vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); - vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); - - sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)); - } -#elif defined(__POWER9_VECTOR__) - const vector signed int v0 = vec_splats((int32_t)0); - vector float vsumf0 = vec_splats(0.0f); - -#pragma GCC unroll 8 - for (; ib < nb; ++ib) { - __builtin_prefetch(x[ib].qs, 0, 1); - __builtin_prefetch(y[ib].qs, 0, 1); - - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); - vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); - vector float vd = vec_mul(vxd, vyd); - - vector signed char q8x0 = vec_xl( 0, x[ib].qs); - vector signed char q8x1 = vec_xl(16, x[ib].qs); - vector signed char q8y0 = vec_xl( 0, y[ib].qs); - vector signed char q8y1 = vec_xl(16, y[ib].qs); - - vector signed short qv0 = vec_mule(q8x0, q8y0); - vector signed short qv1 = vec_mulo(q8x0, q8y0); - vector signed short qv2 = vec_mule(q8x1, q8y1); - vector signed short qv3 = vec_mulo(q8x1, q8y1); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - - vsumi0 = vec_sum4s(qv0, vsumi0); - vsumi1 = vec_sum4s(qv1, vsumi1); - vsumi0 = vec_sum4s(qv2, vsumi0); - vsumi1 = vec_sum4s(qv3, vsumi1); - - vsumi0 = vec_add(vsumi0, vsumi1); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - } - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - sumf = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - // Initialize accumulator with zeros - __m256 acc = (__m256)__lasx_xvldi(0); - - // Main loop - for (; ib < nb; ++ib) { - // Compute combined scale for the block - const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); - __m256i qx = __lasx_xvld((const __m256i *)x[ib].qs, 0); - __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); - - const __m256 q = mul_sum_i8_pairs_float(qx, qy); - - // Multiply q with scale and accumulate - acc = __lasx_xvfmadd_s( d, q, acc ); - } - - sumf = hsum_float_8(acc); -#endif - for (; ib < nb; ++ib) { - int sumi = 0; - - for (int j = 0; j < qk; j++) { - sumi += x[ib].qs[j]*y[ib].qs[j]; - } - - sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)); - } - - *s = sumf; -} - -void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_tq1_0 * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - float sumf = 0.0f; - - uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; - - const uint8x16_t shift = vld1q_u8(k_shift); - - for (int i = 0; i < nb; ++i) { -#if defined(__ARM_FEATURE_DOTPROD) - int32x4_t sumi0 = vdupq_n_s32(0); - int32x4_t sumi1 = vdupq_n_s32(0); -#else - int16x8_t sumi0 = vdupq_n_s16(0); - int16x8_t sumi1 = vdupq_n_s16(0); -#endif - - // first 32 bytes of 5 elements - { - uint8x16_t qx0 = vld1q_u8(x[i].qs + 0); - uint8x16_t qx1 = vld1q_u8(x[i].qs + 16); - uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3)); - uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3)); - uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9)); - uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9)); - uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27)); - uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27)); - uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81)); - uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81)); - - // multiply by 3 and keep the 2 bits above 8 bits - int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6)); - int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6)); - int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6)); - int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6)); - int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6)); - int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6)); - int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6)); - int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6)); - int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6)); - int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6)); - - const int8x16_t qy0 = vld1q_s8(y[i].qs + 0); - const int8x16_t qy1 = vld1q_s8(y[i].qs + 16); - const int8x16_t qy2 = vld1q_s8(y[i].qs + 32); - const int8x16_t qy3 = vld1q_s8(y[i].qs + 48); - const int8x16_t qy4 = vld1q_s8(y[i].qs + 64); - const int8x16_t qy5 = vld1q_s8(y[i].qs + 80); - const int8x16_t qy6 = vld1q_s8(y[i].qs + 96); - const int8x16_t qy7 = vld1q_s8(y[i].qs + 112); - const int8x16_t qy8 = vld1q_s8(y[i].qs + 128); - const int8x16_t qy9 = vld1q_s8(y[i].qs + 144); - -#if defined(__ARM_FEATURE_DOTPROD) - sumi0 = vdotq_s32(sumi0, sqx0, qy0); - sumi1 = vdotq_s32(sumi1, sqx1, qy1); - sumi0 = vdotq_s32(sumi0, sqx2, qy2); - sumi1 = vdotq_s32(sumi1, sqx3, qy3); - sumi0 = vdotq_s32(sumi0, sqx4, qy4); - sumi1 = vdotq_s32(sumi1, sqx5, qy5); - sumi0 = vdotq_s32(sumi0, sqx6, qy6); - sumi1 = vdotq_s32(sumi1, sqx7, qy7); - sumi0 = vdotq_s32(sumi0, sqx8, qy8); - sumi1 = vdotq_s32(sumi1, sqx9, qy9); -#else - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9)); -#endif - } - - // last 16 bytes of 5-element, along with the 4 bytes of 4 elements - { - uint8x16_t qx0 = vld1q_u8(x[i].qs + 32); - uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3)); - uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9)); - uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27)); - uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81)); - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned - uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh)); - qx5 = vmulq_u8(qx5, shift); - - // multiply by 3 and keep the 2 bits above 8 bits - int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6)); - int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6)); - int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6)); - int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6)); - int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6)); - int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6)); - - const int8x16_t qy0 = vld1q_s8(y[i].qs + 160); - const int8x16_t qy1 = vld1q_s8(y[i].qs + 176); - const int8x16_t qy2 = vld1q_s8(y[i].qs + 192); - const int8x16_t qy3 = vld1q_s8(y[i].qs + 208); - const int8x16_t qy4 = vld1q_s8(y[i].qs + 224); - const int8x16_t qy5 = vld1q_s8(y[i].qs + 240); - -#if defined(__ARM_FEATURE_DOTPROD) - sumi0 = vdotq_s32(sumi0, sqx0, qy0); - sumi1 = vdotq_s32(sumi1, sqx1, qy1); - sumi0 = vdotq_s32(sumi0, sqx2, qy2); - sumi1 = vdotq_s32(sumi1, sqx3, qy3); - sumi0 = vdotq_s32(sumi0, sqx4, qy4); - sumi1 = vdotq_s32(sumi1, sqx5, qy5); -#else - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); -#endif - } - - const int16x8_t ysum0 = vld1q_s16(y[i].bsums); - const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - -#if defined(__ARM_FEATURE_DOTPROD) - sumi0 = vaddq_s32(sumi0, sumi1); - sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); - - sumf += d * (float) vaddvq_s32(sumi0); -#else - sumi0 = vaddq_s16(sumi0, sumi1); - sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1)); - - sumf += d * (float) vaddlvq_s16(sumi0); -#endif - } - - *s = sumf; - -#elif defined(__AVX2__) - __m256 sumf = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - // 16-bit sums - __m256i sumi0 = _mm256_setzero_si256(); - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - - // first 32 bytes of 5 elements - { - __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs)); - // 8-bit multiplies with shifts, masks and adds - __m256i qx1 = _mm256_add_epi8(qx0, _mm256_add_epi8(qx0, qx0)); // 1 * 3 - __m256i qx2 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx0, 3), _mm256_set1_epi8(-8)), qx0); // 1 * 9 - __m256i qx3 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx1, 3), _mm256_set1_epi8(-8)), qx1); // 3 * 9 - __m256i qx4 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx2, 3), _mm256_set1_epi8(-8)), qx2); // 9 * 9 - - // TODO: can _mm256_mulhi_epu16 be faster even if 16-bits? - - // Cancel the +1 from avg so that it behaves like a halving add - qx0 = _mm256_subs_epu8(qx0, _mm256_set1_epi8(1)); - qx1 = _mm256_subs_epu8(qx1, _mm256_set1_epi8(1)); - qx2 = _mm256_subs_epu8(qx2, _mm256_set1_epi8(1)); - qx3 = _mm256_subs_epu8(qx3, _mm256_set1_epi8(1)); - qx4 = _mm256_subs_epu8(qx4, _mm256_set1_epi8(1)); - // Multiply by 3 and get the top 2 bits - qx0 = _mm256_avg_epu8(qx0, _mm256_avg_epu8(qx0, _mm256_setzero_si256())); - qx1 = _mm256_avg_epu8(qx1, _mm256_avg_epu8(qx1, _mm256_setzero_si256())); - qx2 = _mm256_avg_epu8(qx2, _mm256_avg_epu8(qx2, _mm256_setzero_si256())); - qx3 = _mm256_avg_epu8(qx3, _mm256_avg_epu8(qx3, _mm256_setzero_si256())); - qx4 = _mm256_avg_epu8(qx4, _mm256_avg_epu8(qx4, _mm256_setzero_si256())); - qx0 = _mm256_and_si256(_mm256_srli_epi16(qx0, 6), _mm256_set1_epi8(3)); - qx1 = _mm256_and_si256(_mm256_srli_epi16(qx1, 6), _mm256_set1_epi8(3)); - qx2 = _mm256_and_si256(_mm256_srli_epi16(qx2, 6), _mm256_set1_epi8(3)); - qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3)); - qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3)); - - const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 0)); - const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 32)); - const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 64)); - const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 96)); - const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128)); - - qx0 = _mm256_maddubs_epi16(qx0, qy0); - qx1 = _mm256_maddubs_epi16(qx1, qy1); - qx2 = _mm256_maddubs_epi16(qx2, qy2); - qx3 = _mm256_maddubs_epi16(qx3, qy3); - qx4 = _mm256_maddubs_epi16(qx4, qy4); - - sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1)); - sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3)); - sumi2 = _mm256_add_epi16(sumi2, qx4); - } - - // last 16 bytes of 5-element, along with the 4 bytes of 4 elements - { - __m128i qx0 = _mm_loadu_si128((const __m128i *) (x[i].qs + 32)); - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned - __m256i qx5_l = _mm256_cvtepu8_epi16(_mm_set1_epi32(qh)); - __m128i qx1 = _mm_add_epi8(qx0, _mm_add_epi8(qx0, qx0)); // 1 * 3 - __m128i qx2 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx0, 3), _mm_set1_epi8(-8)), qx0); // 1 * 9 - __m128i qx3 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx1, 3), _mm_set1_epi8(-8)), qx1); // 3 * 9 - __m128i qx4 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx2, 3), _mm_set1_epi8(-8)), qx2); // 9 * 9 - __m256i qx01 = MM256_SET_M128I(qx1, qx0); - __m256i qx23 = MM256_SET_M128I(qx3, qx2); - - // avx2 does not have 8-bit multiplies, so 16-bit it is. - qx5_l = _mm256_mullo_epi16(qx5_l, _mm256_set_epi16(27, 27, 27, 27, 9, 9, 9, 9, 3, 3, 3, 3, 1, 1, 1, 1)); - qx5_l = _mm256_and_si256(qx5_l, _mm256_set1_epi16(0xFF)); - __m128i qx5 = _mm_packus_epi16(_mm256_castsi256_si128(qx5_l), _mm256_extracti128_si256(qx5_l, 1)); - - __m256i qx45 = MM256_SET_M128I(qx5, qx4); - - // Cancel the +1 from avg so that it behaves like a halving add - qx01 = _mm256_subs_epu8(qx01, _mm256_set1_epi8(1)); - qx23 = _mm256_subs_epu8(qx23, _mm256_set1_epi8(1)); - qx45 = _mm256_subs_epu8(qx45, _mm256_set1_epi8(1)); - // Multiply by 3 and get the top 2 bits - qx01 = _mm256_avg_epu8(qx01, _mm256_avg_epu8(qx01, _mm256_setzero_si256())); - qx23 = _mm256_avg_epu8(qx23, _mm256_avg_epu8(qx23, _mm256_setzero_si256())); - qx45 = _mm256_avg_epu8(qx45, _mm256_avg_epu8(qx45, _mm256_setzero_si256())); - qx01 = _mm256_and_si256(_mm256_srli_epi16(qx01, 6), _mm256_set1_epi8(3)); - qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3)); - qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3)); - - const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160)); - const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192)); - const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224)); - - qx01 = _mm256_maddubs_epi16(qx01, qy01); - qx23 = _mm256_maddubs_epi16(qx23, qy23); - qx45 = _mm256_maddubs_epi16(qx45, qy45); - - sumi0 = _mm256_add_epi16(sumi0, qx01); - sumi1 = _mm256_add_epi16(sumi1, qx23); - sumi2 = _mm256_add_epi16(sumi2, qx45); - } - - const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums); - const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)); - - sumi0 = _mm256_sub_epi16(sumi0, ysum); - sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2)); - sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1)); - - sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf); - } - - *s = hsum_float_8(sumf); - -#else - const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243}; - - float sumf = 0.0f; - - for (int i = 0; i < nb; ++i) { - int sum = 0; - - for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) { - for (size_t l = 0; l < 5; ++l) { - for (size_t m = 0; m < 32; ++m) { - uint8_t q = x[i].qs[j + m] * pow3[l]; - uint16_t xi = ((uint16_t) q * 3) >> 8; - sum += (xi - 1) * y[i].qs[j*5 + l*32 + m]; - } - } - } - for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) { - for (size_t l = 0; l < 5; ++l) { - for (size_t m = 0; m < 16; ++m) { - uint8_t q = x[i].qs[j + m] * pow3[l]; - uint16_t xi = ((uint16_t) q * 3) >> 8; - sum += (xi - 1) * y[i].qs[j*5 + l*16 + m]; - } - } - } - - for (size_t l = 0; l < 4; ++l) { - for (size_t j = 0; j < sizeof(x->qh); ++j) { - uint8_t q = x[i].qh[j] * pow3[l]; - uint16_t xi = ((uint16_t) q * 3) >> 8; - sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j]; - } - } - - sumf += (float) sum * (GGML_FP16_TO_FP32(x[i].d) * y[i].d); - } - - *s = sumf; -#endif -} - -void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_tq2_0 * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - float sumf = 0.0f; - - const uint8x16_t m3 = vdupq_n_u8(3); - - for (int i = 0; i < nb; ++i) { -#if defined(__ARM_FEATURE_DOTPROD) - int32x4_t sumi0 = vdupq_n_s32(0); - int32x4_t sumi1 = vdupq_n_s32(0); -#else - int16x8_t sumi0 = vdupq_n_s16(0); - int16x8_t sumi1 = vdupq_n_s16(0); -#endif - - for (size_t j = 0; j < sizeof(x->qs); j += 32) { - uint8x16_t qx0 = vld1q_u8(x[i].qs + j); - uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16); - uint8x16_t qx2 = vshrq_n_u8(qx0, 2); - uint8x16_t qx3 = vshrq_n_u8(qx1, 2); - uint8x16_t qx4 = vshrq_n_u8(qx0, 4); - uint8x16_t qx5 = vshrq_n_u8(qx1, 4); - uint8x16_t qx6 = vshrq_n_u8(qx0, 6); - uint8x16_t qx7 = vshrq_n_u8(qx1, 6); - - int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3)); - int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3)); - int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3)); - int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3)); - int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3)); - int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3)); - int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3)); - int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3)); - - const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0); - const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16); - const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32); - const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48); - const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64); - const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80); - const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96); - const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112); - -#if defined(__ARM_FEATURE_DOTPROD) - sumi0 = vdotq_s32(sumi0, sqx0, qy0); - sumi1 = vdotq_s32(sumi1, sqx1, qy1); - sumi0 = vdotq_s32(sumi0, sqx2, qy2); - sumi1 = vdotq_s32(sumi1, sqx3, qy3); - sumi0 = vdotq_s32(sumi0, sqx4, qy4); - sumi1 = vdotq_s32(sumi1, sqx5, qy5); - sumi0 = vdotq_s32(sumi0, sqx6, qy6); - sumi1 = vdotq_s32(sumi1, sqx7, qy7); -#else - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7)); -#endif - } - - const int16x8_t ysum0 = vld1q_s16(y[i].bsums); - const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - -#if defined(__ARM_FEATURE_DOTPROD) - sumi0 = vaddq_s32(sumi0, sumi1); - sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); - - sumf += d * (float) vaddvq_s32(sumi0); -#else - sumi0 = vaddq_s16(sumi0, sumi1); - sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1)); - - sumf += d * (float) vaddlvq_s16(sumi0); -#endif - } - - *s = sumf; - -#elif defined(__AVX2__) - __m256 sumf = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - // 16-bit sums, because 256*127 still fits - __m256i sumi0 = _mm256_setzero_si256(); - __m256i sumi1 = _mm256_setzero_si256(); - - for (size_t j = 0; j < sizeof(x->qs); j += 32) { - __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs + j)); - __m256i qx1 = _mm256_srli_epi16(qx0, 2); - __m256i qx2 = _mm256_srli_epi16(qx0, 4); - __m256i qx3 = _mm256_srli_epi16(qx0, 6); - - // 0, 1, 2 (should not be 3) - qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3)); - qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3)); - qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3)); - qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3)); - - const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 0)); - const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32)); - const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64)); - const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96)); - - qx0 = _mm256_maddubs_epi16(qx0, qy0); - qx1 = _mm256_maddubs_epi16(qx1, qy1); - qx2 = _mm256_maddubs_epi16(qx2, qy2); - qx3 = _mm256_maddubs_epi16(qx3, qy3); - - sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1)); - sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3)); - } - - const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums); - const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)); - - sumi0 = _mm256_add_epi16(sumi0, sumi1); - sumi0 = _mm256_sub_epi16(sumi0, ysum); - sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1)); - - sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf); - } - - *s = hsum_float_8(sumf); - -#else - float sumf = 0.0f; - - for (int i = 0; i < nb; ++i) { - int32_t sumi = 0; - - for (size_t j = 0; j < sizeof(x->qs); j += 32) { - for (size_t l = 0; l < 4; ++l) { - for (size_t k = 0; k < 32; ++k) { - sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1); - } - } - } - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - sumf += (float) sumi * d; - } - - *s = sumf; -#endif -} - -void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q2_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_NEON - const uint8x16_t m3 = vdupq_n_u8(0x3); - const uint8x16_t m4 = vdupq_n_u8(0xF); - - const int32x4_t vzero = vdupq_n_s32(0); - - ggml_int8x16x2_t q2bytes; - uint8_t aux[16]; - - float sum = 0; - - for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - const uint8_t * restrict sc = x[i].scales; - - const uint8x16_t mins_and_scales = vld1q_u8(sc); - const uint8x16_t scales = vandq_u8(mins_and_scales, m4); - vst1q_u8(aux, scales); - - const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); - const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums); - const ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}}; - const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), - vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); - const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), - vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1]))); - sum += dmin * vaddvq_s32(vaddq_s32(s0, s1)); - - int isum = 0; - int is = 0; - -// We use this macro instead of a function call because for some reason -// the code runs 2-3% slower, even if the function is declared inline -#define MULTIPLY_ACCUM_WITH_SCALE(index)\ - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; - -#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ - q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\ - q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ - q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ - MULTIPLY_ACCUM_WITH_SCALE((index)); - - for (int j = 0; j < QK_K/128; ++j) { - const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32; - - ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; - q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); - q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); - - MULTIPLY_ACCUM_WITH_SCALE(0); - - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2); - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4); - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6); - - is += 8; - } - - sum += d * isum; - } - - *s = sum; - -#elif defined __AVX2__ - - const __m256i m3 = _mm256_set1_epi8(3); - const __m128i m4 = _mm_set1_epi8(0xF); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); - const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); - const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); - const __m256i mins = _mm256_cvtepi8_epi16(mins8); - const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums)); - - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc); - - const __m256i all_scales = _mm256_cvtepi8_epi16(scales8); - const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); - const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); - const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; - - __m256i sumi = _mm256_setzero_si256(); - - for (int j = 0; j < QK_K/128; ++j) { - - const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32; - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - const __m256i q2_0 = _mm256_and_si256(q2bits, m3); - const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3); - const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3); - const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3); - - __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); - __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); - __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2); - __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3); - - p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0); - p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1); - p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2); - p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3); - - p0 = _mm256_add_epi32(p0, p1); - p2 = _mm256_add_epi32(p2, p3); - - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2)); - } - - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); - - } - - *s = hsum_float_8(acc); - -#elif defined __AVX__ - - const __m128i m3 = _mm_set1_epi8(0x3); - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i m2 = _mm_set1_epi8(0x2); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - // load mins and scales from block_q2_K.scales[QK_K/16] - const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); - const __m128i scales16 = _mm_and_si128(mins_and_scales, m4); - const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); - const __m128i mins_0 = _mm_cvtepi8_epi16(mins16); - const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16)); - - // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2 - const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0])); - const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8])); - - // sumf += -dmin * summs in 32bits*8 - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc); - - const __m128i scales_0 = _mm_cvtepi8_epi16(scales16); - const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16)); - const __m128i scales[2] = { scales_0, scales_1 }; - - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - for (int j = 0; j < QK_K/128; ++j) { - - // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K] - const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - - // load 2bits*16*8 from block_q2_K.qs[QK_K/4] - __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; - const __m128i q2_0 = _mm_and_si128(q2bits, m3); - const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); - const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); - const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); - q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; - const __m128i q2_1 = _mm_and_si128(q2bits, m3); - const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); - const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); - const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); - - // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8 - __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0); - __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1); - __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2); - __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3); - __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4); - __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5); - __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6); - __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7); - - // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8 - __m128i shuffle = _mm_set1_epi16(0x0100); - p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0); - shuffle = _mm_add_epi16(shuffle, m2); - p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1); - shuffle = _mm_add_epi16(shuffle, m2); - p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2); - shuffle = _mm_add_epi16(shuffle, m2); - p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3); - shuffle = _mm_add_epi16(shuffle, m2); - p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4); - shuffle = _mm_add_epi16(shuffle, m2); - p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5); - shuffle = _mm_add_epi16(shuffle, m2); - p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6); - shuffle = _mm_add_epi16(shuffle, m2); - p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7); - - p0 = _mm_add_epi32(p0, p1); - p2 = _mm_add_epi32(p2, p3); - p4 = _mm_add_epi32(p4, p5); - p6 = _mm_add_epi32(p6, p7); - - // isum in 32bits*4*2 - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6)); - } - - // sumf += dall * isum - dmin * summs in 32bits - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc); - } - - *s = hsum_float_8(acc); - -#elif defined __riscv_v_intrinsic - - float sumf = 0; - uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; - - const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - size_t vl = 16; - - vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); - vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); - - vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); - - vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); - vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); - vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); - vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); - vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - - sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); - - vl = 32; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); - - uint8_t is=0; - int isum=0; - - for (int j = 0; j < QK_K/128; ++j) { - // load Q2 - vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); - - vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); - vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl); - vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl); - vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl); - - // duplicate scale elements for product - vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl); - vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl); - vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl); - vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl); - - vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); - vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); - vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); - vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); - - // load Q8 - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl); - vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl); - - vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); - vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); - vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); - vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); - - vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); - - isum += __riscv_vmv_x_s_i32m1_i32(isum1); - - q2+=32; q8+=128; is=8; - - } - - sumf += dall * isum; - - } - - *s = sumf; - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0x3); - const vector signed char lowScaleMask = vec_splats((signed char)0xF); - const vector int v0 = vec_splats((int32_t)0); - const vector unsigned char v2 = vec_splats((unsigned char)0x2); - const vector unsigned char v6 = vec_splats((unsigned char)0x6); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin)); - vector float vdmin = vec_mul(vxmin, vyd); - - vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); - vector signed short q8ysums1 = vec_xl(16, y[i].bsums); - - vector signed char q2xmins = (vector signed char)vec_xl( 0, x[i].scales); - vector signed char vscales = vec_and(q2xmins, lowScaleMask); - - q2xmins = vec_sr(q2xmins, v4); - vector signed short q2xmins0 = vec_unpackh(q2xmins); - vector signed short q2xmins1 = vec_unpackl(q2xmins); - - vector signed int prod0 = vec_mule(q2xmins0, q8ysums0); - vector signed int prod1 = vec_mulo(q2xmins0, q8ysums0); - vector signed int prod2 = vec_mule(q2xmins1, q8ysums1); - vector signed int prod3 = vec_mulo(q2xmins1, q8ysums1); - - vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); - vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); - vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); - vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - vector signed int vsumi4 = v0; - vector signed int vsumi5 = v0; - vector signed int vsumi6 = v0; - vector signed int vsumi7 = v0; - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - for (int j = 0; j < QK_K/128; ++j) { - __builtin_prefetch(q2, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, q2); - vector signed char qxs1 = (vector signed char)vec_xl(16, q2); - q2 += 32; - - vector unsigned char q2x00 = (vector unsigned char)vec_and(qxs0, lowMask); - vector unsigned char q2x01 = (vector unsigned char)vec_and(vec_sr(qxs0, v2), lowMask); - vector unsigned char q2x02 = (vector unsigned char)vec_and(vec_sr(qxs0, v4), lowMask); - vector unsigned char q2x03 = (vector unsigned char)vec_and(vec_sr(qxs0, v6), lowMask); - vector unsigned char q2x10 = (vector unsigned char)vec_and(qxs1, lowMask); - vector unsigned char q2x11 = (vector unsigned char)vec_and(vec_sr(qxs1, v2), lowMask); - vector unsigned char q2x12 = (vector unsigned char)vec_and(vec_sr(qxs1, v4), lowMask); - vector unsigned char q2x13 = (vector unsigned char)vec_and(vec_sr(qxs1, v6), lowMask); - - vector signed char q8y00 = vec_xl( 0, q8); - vector signed char q8y10 = vec_xl( 16, q8); - vector signed char q8y01 = vec_xl( 32, q8); - vector signed char q8y11 = vec_xl( 48, q8); - vector signed char q8y02 = vec_xl( 64, q8); - vector signed char q8y12 = vec_xl( 80, q8); - vector signed char q8y03 = vec_xl( 96, q8); - vector signed char q8y13 = vec_xl(112, q8); - q8 += 128; - - vector signed int qv0 = vec_msum(q8y00, q2x00, v0); - vector signed int qv1 = vec_msum(q8y01, q2x01, v0); - vector signed int qv2 = vec_msum(q8y02, q2x02, v0); - vector signed int qv3 = vec_msum(q8y03, q2x03, v0); - vector signed int qv4 = vec_msum(q8y10, q2x10, v0); - vector signed int qv5 = vec_msum(q8y11, q2x11, v0); - vector signed int qv6 = vec_msum(q8y12, q2x12, v0); - vector signed int qv7 = vec_msum(q8y13, q2x13, v0); - - vector signed short vscales_07 = vec_unpackh(vscales); - vector signed int vscales_03 = vec_unpackh(vscales_07); - vector signed int vscales_47 = vec_unpackl(vscales_07); - vector signed int vs0 = vec_splat(vscales_03, 0); - vector signed int vs1 = vec_splat(vscales_03, 1); - vector signed int vs2 = vec_splat(vscales_03, 2); - vector signed int vs3 = vec_splat(vscales_03, 3); - vector signed int vs4 = vec_splat(vscales_47, 0); - vector signed int vs5 = vec_splat(vscales_47, 1); - vector signed int vs6 = vec_splat(vscales_47, 2); - vector signed int vs7 = vec_splat(vscales_47, 3); - vscales = vec_sld(vscales, vscales, 8); - - vsumi0 = vec_add(vec_mul(qv0, vs0), vsumi0); - vsumi1 = vec_add(vec_mul(qv1, vs2), vsumi1); - vsumi2 = vec_add(vec_mul(qv2, vs4), vsumi2); - vsumi3 = vec_add(vec_mul(qv3, vs6), vsumi3); - vsumi4 = vec_add(vec_mul(qv4, vs1), vsumi4); - vsumi5 = vec_add(vec_mul(qv5, vs3), vsumi5); - vsumi6 = vec_add(vec_mul(qv6, vs5), vsumi6); - vsumi7 = vec_add(vec_mul(qv7, vs7), vsumi7); - } - - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined __loongarch_asx - - const __m256i m3 = __lasx_xvreplgr2vr_b(3); - const __m128i m4 = __lsx_vreplgr2vr_b(0xF); - - __m256 acc = (__m256)__lasx_xvldi(0); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0); - const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4); - const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4); - const __m256i mins = lasx_ext8_16(mins8); - const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0)); - - acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc); - - const __m256i all_scales = lasx_ext8_16(scales8); - const __m128i l_scales = lasx_extracti128(all_scales, 0); - const __m128i h_scales = lasx_extracti128(all_scales, 1); - const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)}; - - __m256i sumi = __lasx_xvldi(0); - - for (int j = 0; j < QK_K/128; ++j) { - - const __m256i q2bits = __lasx_xvld((const __m256i*)q2, 0); q2 += 32; - - const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - - const __m256i q2_0 = __lasx_xvand_v(q2bits, m3); - const __m256i q2_1 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 2), m3); - const __m256i q2_2 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 4), m3); - const __m256i q2_3 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 6), m3); - - __m256i p0 = lasx_maddubs_h(q2_0, q8_0); - __m256i p1 = lasx_maddubs_h(q2_1, q8_1); - __m256i p2 = lasx_maddubs_h(q2_2, q8_2); - __m256i p3 = lasx_maddubs_h(q2_3, q8_3); - - p0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(0)), p0); - p1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(1)), p1); - p2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(2)), p2); - p3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(3)), p3); - - p0 = __lasx_xvadd_w(p0, p1); - p2 = __lasx_xvadd_w(p2, p3); - - sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p0, p2)); - } - - acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc); - - } - - *s = hsum_float_8(acc); - -#else - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; - - int summs = 0; - for (int j = 0; j < 16; ++j) { - summs += y[i].bsums[j] * (sc[j] >> 4); - } - - const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - int isum = 0; - int is = 0; - int d; - for (int k = 0; k < QK_K/128; ++k) { - int shift = 0; - for (int j = 0; j < 4; ++j) { - d = sc[is++] & 0xF; - int isuml = 0; - for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); - isum += d * isuml; - d = sc[is++] & 0xF; - isuml = 0; - for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); - isum += d * isuml; - shift += 2; - q8 += 32; - } - q2 += 32; - } - sumf += dall * isum - dmin * summs; - } - *s = sumf; -#endif -} - -void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const uint32_t kmask1 = 0x03030303; - const uint32_t kmask2 = 0x0f0f0f0f; - - const block_q3_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_NEON - - uint32_t aux[3]; - uint32_t utmp[4]; - - const uint8x16_t m3b = vdupq_n_u8(0x3); - const int32x4_t vzero = vdupq_n_s32(0); - - const uint8x16_t m0 = vdupq_n_u8(1); - const uint8x16_t m1 = vshlq_n_u8(m0, 1); - const uint8x16_t m2 = vshlq_n_u8(m0, 2); - const uint8x16_t m3 = vshlq_n_u8(m0, 3); - const int8_t m32 = 32; - - ggml_int8x16x4_t q3bytes; - - float sum = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict qh = x[i].hmask; - const int8_t * restrict q8 = y[i].qs; - - ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); - - ggml_uint8x16x4_t q3h; - - int32_t isum = 0; - - // Set up scales - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); - - int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= m32; - - for (int j = 0; j < QK_K/128; ++j) { - - const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32; - const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64; - const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64; - - q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); - q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); - q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); - q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); - - q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0])); - q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1])); - q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2])); - q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3])); - - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1]; - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2]; - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3]; - - scale += 4; - - q3h.val[0] = vbicq_u8(m2, qhbits.val[0]); - q3h.val[1] = vbicq_u8(m2, qhbits.val[1]); - q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); - q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); - - q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0])); - q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1])); - q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2])); - q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3])); - - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1]; - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2]; - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3]; - - scale += 4; - - if (j == 0) { - qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4); - qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4); - } - - } - sum += d * isum; - - } - - *s = sum; - -#elif defined __AVX2__ - - const __m256i m3 = _mm256_set1_epi8(3); - const __m256i mone = _mm256_set1_epi8(1); - const __m128i m32 = _mm_set1_epi8(32); - - __m256 acc = _mm256_setzero_ps(); - - uint32_t aux[3]; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q3 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - // Set up scales - memcpy(aux, x[i].scales, 12); - __m128i scales128 = _mm_set_epi32( - ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), - ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), - (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), - (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); - scales128 = _mm_sub_epi8(scales128, m32); - const __m256i all_scales = _mm256_cvtepi8_epi16(scales128); - const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); - const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); - const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; - - // high bit - const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask); - - // integer accumulator - __m256i sumi = _mm256_setzero_si256(); - - int bit = 0; - int is = 0; - - for (int j = 0; j < QK_K/128; ++j) { - // load low 2 bits - const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32; - - // prepare low and high bits - const __m256i q3l_0 = _mm256_and_si256(q3bits, m3); - const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; - - const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3); - const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; - - const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3); - const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; - - const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3); - const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; - - // load Q8 quants - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); - __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2); - __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3); - - __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); - __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2); - __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3); - - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - p16_2 = _mm256_sub_epi16(p16_2, q8s_2); - p16_3 = _mm256_sub_epi16(p16_3, q8s_3); - - // multiply with scales - p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); - p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); - p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); - p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); - - // accumulate - p16_0 = _mm256_add_epi32(p16_0, p16_1); - p16_2 = _mm256_add_epi32(p16_2, p16_3); - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2)); - - } - - // multiply with block scale and accumulate - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); - - } - - *s = hsum_float_8(acc); - -#elif defined __AVX__ - - const __m128i m3 = _mm_set1_epi8(3); - const __m128i mone = _mm_set1_epi8(1); - const __m128i m32 = _mm_set1_epi8(32); - const __m128i m2 = _mm_set1_epi8(2); - - __m256 acc = _mm256_setzero_ps(); - - const uint32_t *aux; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q3 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - // Set up scales - aux = (const uint32_t *)x[i].scales; - __m128i scales128 = _mm_set_epi32( - ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), - ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), - (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), - (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); - scales128 = _mm_sub_epi8(scales128, m32); - const __m128i scales_0 = _mm_cvtepi8_epi16(scales128); - const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128)); - const __m128i scales[2] = { scales_0, scales_1 }; - - // high bit *128*2 from block_q3_K.hmask[QK_K/8] - const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]); - const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]); - - // integer accumulator - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - for (int j = 0; j < QK_K/128; ++j) { - // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4] - const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; - const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; - - // prepare low and high bits - const int bit = j << 2; - - const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3); - const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3); - const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2); - const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2); - - const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3); - const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3); - const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2); - const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2); - - const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3); - const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3); - const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2); - const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2); - - const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3); - const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3); - const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2); - const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2); - - // load Q8 quants from block_q8_K.qs[QK_K] - const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0); - __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1); - __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2); - __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3); - __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4); - __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5); - __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6); - __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7); - - __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0); - __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1); - __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2); - __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3); - __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4); - __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5); - __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6); - __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7); - - p16_0 = _mm_sub_epi16(p16_0, q8s_0); - p16_1 = _mm_sub_epi16(p16_1, q8s_1); - p16_2 = _mm_sub_epi16(p16_2, q8s_2); - p16_3 = _mm_sub_epi16(p16_3, q8s_3); - p16_4 = _mm_sub_epi16(p16_4, q8s_4); - p16_5 = _mm_sub_epi16(p16_5, q8s_5); - p16_6 = _mm_sub_epi16(p16_6, q8s_6); - p16_7 = _mm_sub_epi16(p16_7, q8s_7); - - // multiply with scales - __m128i shuffle = _mm_set1_epi16(0x0100); - p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0); - shuffle = _mm_add_epi16(shuffle, m2); - p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1); - shuffle = _mm_add_epi16(shuffle, m2); - p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2); - shuffle = _mm_add_epi16(shuffle, m2); - p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3); - shuffle = _mm_add_epi16(shuffle, m2); - p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4); - shuffle = _mm_add_epi16(shuffle, m2); - p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5); - shuffle = _mm_add_epi16(shuffle, m2); - p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6); - shuffle = _mm_add_epi16(shuffle, m2); - p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7); - - // accumulate - p16_0 = _mm_add_epi32(p16_0, p16_1); - p16_2 = _mm_add_epi32(p16_2, p16_3); - p16_4 = _mm_add_epi32(p16_4, p16_5); - p16_6 = _mm_add_epi32(p16_6, p16_7); - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6)); - - } - - // multiply with block scale and accumulate - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); - - } - - *s = hsum_float_8(acc); - -#elif defined __riscv_v_intrinsic - - uint32_t aux[3]; - uint32_t utmp[4]; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict qh = x[i].hmask; - const int8_t * restrict q8 = y[i].qs; - - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); - - int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= 32; - - - size_t vl = 32; - uint8_t m = 1; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); - - int sum_t = 0; - - for (int j = 0; j < QK_K; j += 128) { - - vl = 32; - - // load Q3 - vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); - - vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); - vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); - vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); - vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); - - // compute mask for subtraction - vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); - vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); - m <<= 1; - - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); - m <<= 1; - - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); - m <<= 1; - - vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); - vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); - m <<= 1; - - // load Q8 and take product with Q3 - vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - - vl = 16; - - // retrieve lane to multiply with scale - vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); - vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); - vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); - vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); - vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); - vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); - vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); - vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); - - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); - - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - - q3 += 32; q8 += 128; scale += 8; - - } - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - - sumf += d*sum_t; - - } - - *s = sumf; - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0x3); - const vector signed char lowMask1 = vec_splats((int8_t)0xf); - const vector signed char lowMask2 = vec_splats((int8_t)0x30); - const vector int v0 = vec_splats((int32_t)0); - const vector signed char v1 = vec_splats((signed char)0x1); - const vector unsigned char v2 = vec_splats((unsigned char)0x2); - const vector unsigned char v3 = vec_splats((unsigned char)0x3); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - const vector unsigned char v6 = vec_splats((unsigned char)0x6); - const vector signed char off = vec_splats((signed char)0x20); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - UNUSED(kmask1); - UNUSED(kmask2); - - vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); - vector signed char u1 = vec_and(u0, lowMask1); - vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); - vector signed char u3 = (vector signed char)vec_mergeh((vector signed int)u2, (vector signed int)vec_sr(u2, v2)); - vector signed char u30 = vec_sl(vec_and(u3, lowMask), v4); - vector signed char u31 = vec_and(u3, lowMask2); - - u1 = vec_or(u1, u30); - u2 = vec_or(vec_sr(u0, v4), u31); - - vector signed char vscales = (vector signed char)vec_mergeh((vector signed long long)u1, (vector signed long long)u2); - vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].hmask); - vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].hmask); - - vscales = vec_sub(vscales, off); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - vector signed int vsumi4 = v0; - vector signed int vsumi5 = v0; - vector signed int vsumi6 = v0; - vector signed int vsumi7 = v0; - - const uint8_t * restrict q3 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - for (int j = 0; j < QK_K/128; ++j) { - __builtin_prefetch(q3, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, q3); - vector signed char qxs1 = (vector signed char)vec_xl(16, q3); - q3 += 32; - - //the low 2 bits - vector signed char qxs00 = vec_and(qxs0, lowMask); - vector signed char qxs01 = vec_and(vec_sr(qxs0, v2), lowMask); - vector signed char qxs02 = vec_and(vec_sr(qxs0, v4), lowMask); - vector signed char qxs03 = vec_and(vec_sr(qxs0, v6), lowMask); - vector signed char qxs10 = vec_and(qxs1, lowMask); - vector signed char qxs11 = vec_and(vec_sr(qxs1, v2), lowMask); - vector signed char qxs12 = vec_and(vec_sr(qxs1, v4), lowMask); - vector signed char qxs13 = vec_and(vec_sr(qxs1, v6), lowMask); - - //the 3rd bit - vector signed char qxh00 = vec_sl(vec_andc(v1, qxhs0), v2); - vector signed char qxh01 = vec_sl(vec_andc(v1, vec_sr(qxhs0, (vector unsigned char)v1)), v2); - vector signed char qxh02 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v2)), v2); - vector signed char qxh03 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v3)), v2); - vector signed char qxh10 = vec_sl(vec_andc(v1, qxhs1), v2); - vector signed char qxh11 = vec_sl(vec_andc(v1, vec_sr(qxhs1, (vector unsigned char)v1)), v2); - vector signed char qxh12 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v2)), v2); - vector signed char qxh13 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v3)), v2); - qxhs0 = vec_sr(qxhs0, v4); - qxhs1 = vec_sr(qxhs1, v4); - - vector signed char q3x00 = vec_sub(qxs00, qxh00); - vector signed char q3x01 = vec_sub(qxs01, qxh01); - vector signed char q3x02 = vec_sub(qxs02, qxh02); - vector signed char q3x03 = vec_sub(qxs03, qxh03); - vector signed char q3x10 = vec_sub(qxs10, qxh10); - vector signed char q3x11 = vec_sub(qxs11, qxh11); - vector signed char q3x12 = vec_sub(qxs12, qxh12); - vector signed char q3x13 = vec_sub(qxs13, qxh13); - - vector signed char q8y00 = vec_xl( 0, q8); - vector signed char q8y10 = vec_xl( 16, q8); - vector signed char q8y01 = vec_xl( 32, q8); - vector signed char q8y11 = vec_xl( 48, q8); - vector signed char q8y02 = vec_xl( 64, q8); - vector signed char q8y12 = vec_xl( 80, q8); - vector signed char q8y03 = vec_xl( 96, q8); - vector signed char q8y13 = vec_xl(112, q8); - q8 += 128; - - vector signed short vscales_h = vec_unpackh(vscales); - vector signed short vs0 = vec_splat(vscales_h, 0); - vector signed short vs1 = vec_splat(vscales_h, 1); - vector signed short vs2 = vec_splat(vscales_h, 2); - vector signed short vs3 = vec_splat(vscales_h, 3); - vector signed short vs4 = vec_splat(vscales_h, 4); - vector signed short vs5 = vec_splat(vscales_h, 5); - vector signed short vs6 = vec_splat(vscales_h, 6); - vector signed short vs7 = vec_splat(vscales_h, 7); - vscales = vec_sld(vscales, vscales, 8); - - vector signed short qv00 = vec_add(vec_mule(q3x00, q8y00), vec_mulo(q3x00, q8y00)); - vector signed short qv01 = vec_add(vec_mule(q3x01, q8y01), vec_mulo(q3x01, q8y01)); - vector signed short qv02 = vec_add(vec_mule(q3x02, q8y02), vec_mulo(q3x02, q8y02)); - vector signed short qv03 = vec_add(vec_mule(q3x03, q8y03), vec_mulo(q3x03, q8y03)); - vector signed short qv10 = vec_add(vec_mule(q3x10, q8y10), vec_mulo(q3x10, q8y10)); - vector signed short qv11 = vec_add(vec_mule(q3x11, q8y11), vec_mulo(q3x11, q8y11)); - vector signed short qv12 = vec_add(vec_mule(q3x12, q8y12), vec_mulo(q3x12, q8y12)); - vector signed short qv13 = vec_add(vec_mule(q3x13, q8y13), vec_mulo(q3x13, q8y13)); - - vsumi0 = vec_msum(qv00, vs0, vsumi0); - vsumi1 = vec_msum(qv01, vs2, vsumi1); - vsumi2 = vec_msum(qv02, vs4, vsumi2); - vsumi3 = vec_msum(qv03, vs6, vsumi3); - vsumi4 = vec_msum(qv10, vs1, vsumi4); - vsumi5 = vec_msum(qv11, vs3, vsumi5); - vsumi6 = vec_msum(qv12, vs5, vsumi6); - vsumi7 = vec_msum(qv13, vs7, vsumi7); - } - - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined __loongarch_asx - - const __m256i m3 = __lasx_xvreplgr2vr_b(3); - const __m256i mone = __lasx_xvreplgr2vr_b(1); - const __m128i m32 = __lsx_vreplgr2vr_b(32); - - __m256 acc = (__m256)__lasx_xvldi(0); - - uint32_t aux[3]; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const uint8_t * restrict q3 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - // Set up scales - memcpy(aux, x[i].scales, 12); - __m128i scales128 = lsx_set_w( - ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), - ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), - (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), - (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); - scales128 = __lsx_vsub_b(scales128, m32); - const __m256i all_scales = lasx_ext8_16(scales128); - const __m128i l_scales = lasx_extracti128(all_scales, 0); - const __m128i h_scales = lasx_extracti128(all_scales, 1); - const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)}; - - // high bit - const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0); - - // integer accumulator - __m256i sumi = __lasx_xvldi(0); - - int bit = 0; - int is = 0; - __m256i xvbit; - - - for (int j = 0; j < QK_K/128; ++j) { - // load low 2 bits - const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32; - - xvbit = __lasx_xvreplgr2vr_h(bit); - // prepare low and high bits - const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3); - const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); - ++bit; - - xvbit = __lasx_xvreplgr2vr_h(bit); - const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3); - const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); - ++bit; - - xvbit = __lasx_xvreplgr2vr_h(bit); - const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3); - const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); - ++bit; - - xvbit = __lasx_xvreplgr2vr_h(bit); - const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3); - const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); - ++bit; - - // load Q8 quants - const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use lasx_maddubs_h, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - __m256i q8s_0 = lasx_maddubs_h(q3h_0, q8_0); - __m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1); - __m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2); - __m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3); - - __m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0); - __m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1); - __m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2); - __m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3); - - p16_0 = __lasx_xvsub_h(p16_0, q8s_0); - p16_1 = __lasx_xvsub_h(p16_1, q8s_1); - p16_2 = __lasx_xvsub_h(p16_2, q8s_2); - p16_3 = __lasx_xvsub_h(p16_3, q8s_3); - - // multiply with scales - p16_0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); - p16_1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); - p16_2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); - p16_3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); - - // accumulate - p16_0 = __lasx_xvadd_w(p16_0, p16_1); - p16_2 = __lasx_xvadd_w(p16_2, p16_3); - sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2)); - } - // multiply with block scale and accumulate - acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);//FIXME - } - - *s = hsum_float_8(acc); - -#else - // scalar version - // This function is written like this so the compiler can manage to vectorize most of it - // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the - // manually vectorized version above. Every other version I tried would run at least 4 times slower. - // The ideal situation would be if we could just write the code once, and the compiler would - // automatically produce the best possible set of machine instructions, instead of us having to manually - // write vectorized versions for AVX, ARM_NEON, etc. - - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); - - uint32_t auxs[4]; - const int8_t * scales = (const int8_t*)auxs; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict hm = x[i].hmask; - const int8_t * restrict q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * restrict a = aux8; - uint8_t m = 1; - for (int j = 0; j < QK_K; j += 128) { - for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); - a += 32; m <<= 1; - q3 += 32; - } - a = aux8; - - memcpy(auxs, x[i].scales, 12); - uint32_t tmp = auxs[2]; - auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); - auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); - auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); - auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - for (int j = 0; j < QK_K/16; ++j) { - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; - q8 += 8; a += 8; - } - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; - -#endif - -} - -void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q4_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; - - uint32_t utmp[4]; - -#ifdef __ARM_NEON - const uint8x16_t m4b = vdupq_n_u8(0xf); - const int32x4_t mzero = vdupq_n_s32(0); - - ggml_int8x16x2_t q4bytes; - ggml_int8x16x2_t q8bytes; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); - - memcpy(utmp, x[i].scales, 12); - - uint32x2_t mins8 = { 0 }; - mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); - mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); - - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[0] &= kmask1; - - const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); - const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), - vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); - sumf -= dmin * vaddvq_s32(prod); - - const uint8_t * scales = (const uint8_t *)utmp; - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - int32_t sumi1 = 0; - int32_t sumi2 = 0; - - for (int j = 0; j < QK_K/64; ++j) { - const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32; - - q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); - q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); - - const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); - sumi1 += vaddvq_s32(p1) * scales[2*j+0]; - - q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); - - const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); - - sumi2 += vaddvq_s32(p2) * scales[2*j+1]; - } - - sumf += d * (sumi1 + sumi2); - - } - - *s = sumf; - -#elif defined __AVX2__ - - const __m256i m4 = _mm256_set1_epi8(0xF); - - __m256 acc = _mm256_setzero_ps(); - __m128 acc_m = _mm_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); - - const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); - const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); - const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); - acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m); - - const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); - const __m256i scales = MM256_SET_M128I(sc128, sc128); - - __m256i sumi = _mm256_setzero_si256(); - - for (int j = 0; j < QK_K/64; ++j) { - - const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); - const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); - - const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; - const __m256i q4l = _mm256_and_si256(q4bits, m4); - const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); - - const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); - p16l = _mm256_madd_epi16(scale_l, p16l); - - const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); - p16h = _mm256_madd_epi16(scale_h, p16h); - const __m256i sumj = _mm256_add_epi32(p16l, p16h); - - sumi = _mm256_add_epi32(sumi, sumj); - } - - __m256 vd = _mm256_set1_ps(d); - acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); - - } - - acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); - acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); - - *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); - -#elif defined __AVX__ - - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i m2 = _mm_set1_epi8(0x2); - - __m256 acc = _mm256_setzero_ps(); - __m128 acc_m = _mm_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); - const __m128i scales = _mm_cvtepu8_epi16(utmps); - const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); - - const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); - const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); - const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); - const __m128i prod = _mm_madd_epi16(mins, q8s); - acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m); - - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - __m128i shuffle = _mm_set1_epi16(0x0100); - for (int j = 0; j < QK_K/64; ++j) { - - const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); - const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); - - __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4l_0 = _mm_and_si128(q4bits, m4); - const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); - q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4l_1 = _mm_and_si128(q4bits, m4); - const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); - - const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0); - p16l = _mm_madd_epi16(scale_l, p16l); - sumi_0 = _mm_add_epi32(sumi_0, p16l); - const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - p16l = _mm_maddubs_epi16(q4l_1, q8l_1); - p16l = _mm_madd_epi16(scale_l, p16l); - sumi_1 = _mm_add_epi32(sumi_1, p16l); - - const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0); - p16h = _mm_madd_epi16(scale_h, p16h); - sumi_0 = _mm_add_epi32(sumi_0, p16h); - const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - p16h = _mm_maddubs_epi16(q4h_1, q8h_1); - p16h = _mm_madd_epi16(scale_h, p16h); - sumi_1 = _mm_add_epi32(sumi_1, p16h); - - } - - __m256 vd = _mm256_set1_ps(d); - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); - - } - - acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); - acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); - - *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); - -#elif defined __riscv_v_intrinsic - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - size_t vl = 8; - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); - - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - vl = 32; - - int32_t sum_1 = 0; - int32_t sum_2 = 0; - - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); - - for (int j = 0; j < QK_K/64; ++j) { - // load Q4 - vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); - - // load Q8 and multiply it with lower Q4 nibble - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); - vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); - vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); - - sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; - - // load Q8 and multiply it with upper Q4 nibble - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); - vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); - vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); - - sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; - - q4 += 32; q8 += 64; - - } - - sumf += d*(sum_1 + sum_2); - - } - - *s = sumf; - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector signed char lowMask1 = vec_splats((int8_t)0x3f); - const vector signed char lowMask2 = vec_splats((int8_t)0x30); - const vector int v0 = vec_splats((int32_t)0); - const vector unsigned char v2 = vec_splats((uint8_t)2); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin)); - vector float vdmin = vec_mul(vxmin, vyd); - - vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); - vector signed short q8ysums1 = vec_xl(16, y[i].bsums); - - UNUSED(kmask1); - UNUSED(kmask2); - UNUSED(kmask3); - UNUSED(utmp); - - vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); - vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2); - vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); - vector signed char u3 = vec_sr(u2, v4); - - vector signed char u30 = u1; - vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3); - - u1 = vec_and(u0, lowMask1); - u2 = vec_or(u30, u31); - - vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2); - - vector signed short vscales = vec_unpackh(utmps); - vector signed short q4xmins = vec_unpackl(utmps); - vector signed short q4xmins0 = vec_mergeh(q4xmins, q4xmins); - vector signed short q4xmins1 = vec_mergel(q4xmins, q4xmins); - - vector signed int prod0 = vec_mule(q4xmins0, q8ysums0); - vector signed int prod1 = vec_mule(q4xmins1, q8ysums1); - vector signed int prod2 = vec_mulo(q4xmins0, q8ysums0); - vector signed int prod3 = vec_mulo(q4xmins1, q8ysums1); - - vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); - vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); - vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); - vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - for (int j = 0; j < QK_K/64; j+=2) { - __builtin_prefetch(q4, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, q4); - vector signed char qxs1 = (vector signed char)vec_xl(16, q4); - vector signed char qxs2 = (vector signed char)vec_xl(32, q4); - vector signed char qxs3 = (vector signed char)vec_xl(48, q4); - q4 += 64; - - vector unsigned char q4x00 = (vector unsigned char)vec_and(qxs0, lowMask); - vector unsigned char q4x01 = (vector unsigned char)vec_sr(qxs0, v4); - vector unsigned char q4x10 = (vector unsigned char)vec_and(qxs1, lowMask); - vector unsigned char q4x11 = (vector unsigned char)vec_sr(qxs1, v4); - vector unsigned char q4x20 = (vector unsigned char)vec_and(qxs2, lowMask); - vector unsigned char q4x21 = (vector unsigned char)vec_sr(qxs2, v4); - vector unsigned char q4x30 = (vector unsigned char)vec_and(qxs3, lowMask); - vector unsigned char q4x31 = (vector unsigned char)vec_sr(qxs3, v4); - - vector signed char q8y00 = vec_xl( 0, q8); - vector signed char q8y10 = vec_xl( 16, q8); - vector signed char q8y01 = vec_xl( 32, q8); - vector signed char q8y11 = vec_xl( 48, q8); - vector signed char q8y20 = vec_xl( 64, q8); - vector signed char q8y30 = vec_xl( 80, q8); - vector signed char q8y21 = vec_xl( 96, q8); - vector signed char q8y31 = vec_xl(112, q8); - q8 += 128; - - vector signed int qv00 = vec_msum(q8y00, q4x00, v0); - vector signed int qv01 = vec_msum(q8y01, q4x01, v0); - vector signed int qv10 = vec_msum(q8y10, q4x10, v0); - vector signed int qv11 = vec_msum(q8y11, q4x11, v0); - vector signed int qv20 = vec_msum(q8y20, q4x20, v0); - vector signed int qv21 = vec_msum(q8y21, q4x21, v0); - vector signed int qv30 = vec_msum(q8y30, q4x30, v0); - vector signed int qv31 = vec_msum(q8y31, q4x31, v0); - - vector signed int vscales_h = vec_unpackh(vscales); - vector signed int vs0 = vec_splat(vscales_h, 0); - vector signed int vs1 = vec_splat(vscales_h, 1); - vector signed int vs2 = vec_splat(vscales_h, 2); - vector signed int vs3 = vec_splat(vscales_h, 3); - vscales = vec_sld(vscales, vscales, 8); - - vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0); - vsumi1 = vec_add(vec_mul(qv01, vs1), vsumi1); - vsumi2 = vec_add(vec_mul(qv20, vs2), vsumi2); - vsumi3 = vec_add(vec_mul(qv21, vs3), vsumi3); - - vsumi0 = vec_add(vec_mul(qv10, vs0), vsumi0); - vsumi1 = vec_add(vec_mul(qv11, vs1), vsumi1); - vsumi2 = vec_add(vec_mul(qv30, vs2), vsumi2); - vsumi3 = vec_add(vec_mul(qv31, vs3), vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined __loongarch_asx - GGML_UNUSED(kmask1); - GGML_UNUSED(kmask2); - GGML_UNUSED(kmask3); - - const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); - - __m256 acc = (__m256)__lasx_xvldi(0); - __m128 acc_m = (__m128)__lsx_vldi(0); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0])); - - const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0); - const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1)); - const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s); - acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m); - - const __m128i sc128 = lasx_extracti128(mins_and_scales, 0); - const __m256i scales = lasx_insertf128(sc128, sc128); - - __m256i sumi = __lasx_xvldi(0); - - for (int j = 0; j < QK_K/64; ++j) { - - const __m256i scale_l = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0)); - const __m256i scale_h = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1)); - - const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; - const __m256i q4l = __lasx_xvand_v(q4bits, m4); - const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4); - - const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - __m256i p16l = lasx_maddubs_h(q4l, q8l); - p16l = lasx_madd_h(scale_l, p16l); - - const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - __m256i p16h = lasx_maddubs_h(q4h, q8h); - p16h = lasx_madd_h(scale_h, p16h); - const __m256i sumj = __lasx_xvadd_w(p16l, p16h); - - sumi = __lasx_xvadd_w(sumi, sumj); - } - - __m256 vd = __lasx_xvreplfr2vr_s(d); - acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc); - - } - - acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee)); - __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0); - acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1); - - - ft_union fi; - fi.i = __lsx_vpickve2gr_w(acc_m, 0); - *s = hsum_float_8(acc) + fi.f ; -#else - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * restrict a = aux8; - for (int j = 0; j < QK_K/64; ++j) { - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); - a += 32; - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); - a += 32; q4 += 32; - } - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - int sumi = 0; - for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; - a = aux8; - int is = 0; - for (int j = 0; j < QK_K/32; ++j) { - int32_t scale = scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - } - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; - sumf -= dmin * sumi; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} - -void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q5_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; - - uint32_t utmp[4]; - -#ifdef __ARM_NEON - const uint8x16_t m4b = vdupq_n_u8(0xf); - const uint8x16_t mone = vdupq_n_u8(1); - const uint8x16_t mtwo = vdupq_n_u8(2); - const int32x4_t mzero = vdupq_n_s32(0); - - ggml_int8x16x4_t q5bytes; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8); - const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); - const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), - vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); - int32_t sumi_mins = vaddvq_s32(prod); - - const uint8_t * scales = (const uint8_t *)utmp; - - const uint8_t * restrict q5 = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); - - ggml_uint8x16x4_t q5h; - - int32_t sumi = 0; - - for (int j = 0; j < QK_K/64; ++j) { - - const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32; - const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; - - q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); - q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); - q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3); - q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3); - qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2); - qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2); - - q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); - q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); - q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); - q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); - - sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; - sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; - } - - sumf += d * sumi - dmin * sumi_mins; - } - - *s = sumf; - -#elif defined __AVX2__ - - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m128i mzero = _mm_setzero_si128(); - const __m256i mone = _mm256_set1_epi8(1); - - __m256 acc = _mm256_setzero_ps(); - - float summs = 0.f; - - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q5 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); - - const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); - const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); - const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); - const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); - summs += dmin * _mm_extract_epi32(hsum, 0); - - const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); - const __m256i scales = MM256_SET_M128I(sc128, sc128); - - const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh); - __m256i hmask = mone; - - __m256i sumi = _mm256_setzero_si256(); - - int bit = 0; - - for (int j = 0; j < QK_K/64; ++j) { - - const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); - const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); - - const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32; - - const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); - const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); - const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); - hmask = _mm256_slli_epi16(hmask, 1); - - const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); - const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); - const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); - hmask = _mm256_slli_epi16(hmask, 1); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1); - - p16_0 = _mm256_madd_epi16(scale_0, p16_0); - p16_1 = _mm256_madd_epi16(scale_1, p16_1); - - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); - - } - - __m256 vd = _mm256_set1_ps(d); - acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); - - } - - *s = hsum_float_8(acc) + summs; - -#elif defined __AVX__ - - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i mzero = _mm_setzero_si128(); - const __m128i mone = _mm_set1_epi8(1); - const __m128i m2 = _mm_set1_epi8(2); - - __m256 acc = _mm256_setzero_ps(); - - float summs = 0.f; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * restrict q5 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); - const __m128i scales = _mm_cvtepu8_epi16(utmps); - const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); - - const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); - const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); - const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); - const __m128i prod = _mm_madd_epi16(mins, q8s); - const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); - summs += dmin * _mm_extract_epi32(hsum, 0); - - const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]); - const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]); - __m128i hmask = mone; - - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - int bit = 0; - - __m128i shuffle = _mm_set1_epi16(0x0100); - for (int j = 0; j < QK_K/64; ++j) { - - const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); - const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); - - const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; - const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; - - __m128i q5l_0 = _mm_and_si128(q5bits_0, m4); - __m128i q5l_1 = _mm_and_si128(q5bits_1, m4); - __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); - __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); - __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0); - __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1); - hmask = _mm_slli_epi16(hmask, 1); - - __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0); - __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1); - p16_0 = _mm_madd_epi16(scale_0, p16_0); - p16_1 = _mm_madd_epi16(scale_0, p16_1); - - q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4); - q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4); - q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); - q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); - q5_0 = _mm_add_epi8(q5l_0, q5h_0); - q5_1 = _mm_add_epi8(q5l_1, q5h_1); - hmask = _mm_slli_epi16(hmask, 1); - - q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0); - __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1); - p16_2 = _mm_madd_epi16(scale_1, p16_2); - p16_3 = _mm_madd_epi16(scale_1, p16_3); - - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); - - } - - __m256 vd = _mm256_set1_ps(d); - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); - - } - - *s = hsum_float_8(acc) + summs; - -#elif defined __riscv_v_intrinsic - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - - float sumf = 0; - float sums = 0.0; - - size_t vl; - - for (int i = 0; i < nb; ++i) { - - vl = 8; - - const uint8_t * restrict q5 = x[i].qs; - const uint8_t * restrict hm = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; - - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); - - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); - - vl = 32; - int32_t aux32 = 0; - int is = 0; - - uint8_t m = 1; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl); - - for (int j = 0; j < QK_K/64; ++j) { - // load Q5 and Q8 - vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl); - vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl); - - // compute mask for addition - vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl)); - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_mu(vmask_1, q5_a, q5_a, 16, vl); - m <<= 1; - - vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl)); - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_mu(vmask_2, q5_l, q5_l, 16, vl); - m <<= 1; - - vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl); - vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl); - - vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl); - vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl); - - vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl); - vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl); - - aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2); - q5 += 32; q8 += 64; - - } - - vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1); - sums += __riscv_vfmv_f_s_f32m1_f32(vaux); - - } - - *s = sumf+sums; - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector signed char lowMask1 = vec_splats((int8_t)0x3f); - const vector signed char lowMask2 = vec_splats((int8_t)0x30); - const vector int v0 = vec_splats((int32_t)0); - const vector unsigned char v1 = vec_splats((unsigned char)0x1); - const vector unsigned char v2 = vec_splats((unsigned char)0x2); - const vector unsigned char v3 = vec_splats((unsigned char)0x3); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin)); - vector float vdmin = vec_mul(vxmin, vyd); - - UNUSED(kmask1); - UNUSED(kmask2); - UNUSED(kmask3); - UNUSED(utmp); - - vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); - vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2); - vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); - vector signed char u3 = vec_sr(u2, v4); - - vector signed char u30 = u1; - vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3); - - u1 = vec_and(u0, lowMask1); - u2 = vec_or(u30, u31); - - vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2); - - vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); - vector signed short q8ysums1 = vec_xl(16, y[i].bsums); - - vector signed short vscales = vec_unpackh(utmps); - - vector signed short q5xmins = vec_unpackl(utmps); - vector signed short q5xmins0 = vec_mergeh(q5xmins, q5xmins); - vector signed short q5xmins1 = vec_mergel(q5xmins, q5xmins); - - vector signed int prod0 = vec_mule(q5xmins0, q8ysums0); - vector signed int prod1 = vec_mule(q5xmins1, q8ysums1); - vector signed int prod2 = vec_mulo(q5xmins0, q8ysums0); - vector signed int prod3 = vec_mulo(q5xmins1, q8ysums1); - - vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); - vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); - vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); - vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); - - vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh); - vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].qh); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - const uint8_t * restrict q5 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - for (int j = 0; j < QK_K/64; ++j) { - __builtin_prefetch(q5, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, q5); - vector signed char qxs1 = (vector signed char)vec_xl(16, q5); - q5 += 32; - - vector signed char qxs00 = vec_and(qxs0, lowMask); - vector signed char qxs01 = vec_sr(qxs0, v4); - vector signed char qxs10 = vec_and(qxs1, lowMask); - vector signed char qxs11 = vec_sr(qxs1, v4); - - vector signed char q5h00 = vec_sl(vec_and((vector signed char)v1, qxhs0), v4); - vector signed char q5h01 = vec_sl(vec_and((vector signed char)v2, qxhs0), v3); - vector signed char q5h10 = vec_sl(vec_and((vector signed char)v1, qxhs1), v4); - vector signed char q5h11 = vec_sl(vec_and((vector signed char)v2, qxhs1), v3); - qxhs0 = vec_sr(qxhs0, v2); - qxhs1 = vec_sr(qxhs1, v2); - - vector unsigned char q5x00 = (vector unsigned char)vec_or(q5h00, qxs00); - vector unsigned char q5x01 = (vector unsigned char)vec_or(q5h01, qxs01); - vector unsigned char q5x10 = (vector unsigned char)vec_or(q5h10, qxs10); - vector unsigned char q5x11 = (vector unsigned char)vec_or(q5h11, qxs11); - - vector signed char q8y00 = vec_xl( 0, q8); - vector signed char q8y10 = vec_xl(16, q8); - vector signed char q8y01 = vec_xl(32, q8); - vector signed char q8y11 = vec_xl(48, q8); - q8 += 64; - - vector signed int qv00 = vec_msum(q8y00, q5x00, v0); - vector signed int qv01 = vec_msum(q8y01, q5x01, v0); - vector signed int qv10 = vec_msum(q8y10, q5x10, v0); - vector signed int qv11 = vec_msum(q8y11, q5x11, v0); - - vector signed int vscales_h = vec_unpackh(vscales); - vector signed int vs0 = vec_splat(vscales_h, 0); - vector signed int vs1 = vec_splat(vscales_h, 1); - vscales = vec_sld(vscales, vscales, 12); - - vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0); - vsumi1 = vec_add(vec_mul(qv10, vs0), vsumi1); - vsumi2 = vec_add(vec_mul(qv01, vs1), vsumi2); - vsumi3 = vec_add(vec_mul(qv11, vs1), vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined __loongarch_asx - GGML_UNUSED(kmask1); - GGML_UNUSED(kmask2); - GGML_UNUSED(kmask3); - - const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); - const __m128i mzero = __lsx_vldi(0); - const __m256i mone = __lasx_xvreplgr2vr_b(1); - - __m256 acc = (__m256)__lasx_xvldi(0); - - float summs = 0.f; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * restrict q5 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0])); - - const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0); - const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1)); - const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s); - const __m128i hsum = lsx_hadd_w(lsx_hadd_w(prod, mzero), mzero); - summs += dmin * __lsx_vpickve2gr_w(hsum, 0); //TODO check - - const __m128i sc128 = lasx_extracti128(mins_and_scales, 0); - const __m256i scales = lasx_insertf128(sc128, sc128); - - const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0); - __m256i hmask = mone; - - __m256i sumi = __lasx_xvldi(0); - - int bit = 0; - __m256i xvbit; - - for (int j = 0; j < QK_K/64; ++j) { - - const __m256i scale_0 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0)); - const __m256i scale_1 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1)); - - const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32; - - xvbit = __lasx_xvreplgr2vr_h(bit++); - const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4); - const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4); - const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0); - hmask = __lasx_xvslli_h(hmask, 1); - - xvbit = __lasx_xvreplgr2vr_h(bit++); - const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4); - const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4); - const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1); - hmask = __lasx_xvslli_h(hmask, 1); - - const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - - __m256i p16_0 = lasx_maddubs_h(q5_0, q8_0); - __m256i p16_1 = lasx_maddubs_h(q5_1, q8_1); - - p16_0 = lasx_madd_h(scale_0, p16_0); - p16_1 = lasx_madd_h(scale_1, p16_1); - - sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1)); - - } - - __m256 vd = __lasx_xvreplfr2vr_s(d); - acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc); - - } - - *s = hsum_float_8(acc) + summs; - -#else - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].qs; - const uint8_t * restrict hm = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * restrict a = aux8; - uint8_t m = 1; - for (int j = 0; j < QK_K/64; ++j) { - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); - for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); - for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); - a += 32; m <<= 1; - q4 += 32; - } - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - int sumi = 0; - for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; - a = aux8; - int is = 0; - for (int j = 0; j < QK_K/32; ++j) { - int32_t scale = scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - } - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; - sumf -= dmin * sumi; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} - -void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q6_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_NEON - float sum = 0; - - const uint8x16_t m4b = vdupq_n_u8(0xF); - const int32x4_t vzero = vdupq_n_s32(0); - //const int8x16_t m32s = vdupq_n_s8(32); - - const uint8x16_t mone = vdupq_n_u8(3); - - ggml_int8x16x4_t q6bytes; - ggml_uint8x16x4_t q6h; - - for (int i = 0; i < nb; ++i) { - - const float d_all = GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const int8_t * restrict scale = x[i].scales; - - const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums); - const int8x16_t scales = vld1q_s8(scale); - const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}}; - - const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), - vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), - vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])), - vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1])))); - int32_t isum_mins = vaddvq_s32(prod); - - int32_t isum = 0; - - for (int j = 0; j < QK_K/128; ++j) { - - ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32; - ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64; - ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; - - q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); - q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); - uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2); - q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 2); - q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - - //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); - //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); - //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s); - //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s); - q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])); - q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])); - q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); - q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); - - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; - - scale += 4; - - q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; - - shifted = vshrq_n_u8(qhbits.val[0], 4); - q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 4); - q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[0], 6); - q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 6); - q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - - //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s); - //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s); - //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s); - //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s); - q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])); - q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])); - q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); - q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); - - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; - scale += 4; - } - //sum += isum * d_all * y[i].d; - sum += d_all * y[i].d * (isum - 32 * isum_mins); - - } - *s = sum; - -#elif defined __AVX2__ - - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m256i m2 = _mm256_set1_epi8(3); - const __m256i m32s = _mm256_set1_epi8(32); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); - - __m256i sumi = _mm256_setzero_si256(); - - int is = 0; - - for (int j = 0; j < QK_K/128; ++j) { - - const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); - const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); - const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); - const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); - is += 4; - - const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; - const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; - const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; - - const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); - const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); - const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); - const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); - - const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); - const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); - const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); - const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); - __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); - __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); - - __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); - __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); - __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); - - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - p16_2 = _mm256_sub_epi16(p16_2, q8s_2); - p16_3 = _mm256_sub_epi16(p16_3, q8s_3); - - p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); - p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); - p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2); - p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3); - - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3)); - - } - - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); - } - - *s = hsum_float_8(acc); - -#elif defined __AVX__ - - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i m3 = _mm_set1_epi8(3); - const __m128i m32s = _mm_set1_epi8(32); - const __m128i m2 = _mm_set1_epi8(2); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); - - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); - for (int j = 0; j < QK_K/128; ++j) { - - const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16; - const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16; - - const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4); - const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4); - const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4); - const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4); - const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4); - const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4); - const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4); - const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4); - - const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - - const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0); - const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1); - const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2); - const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3); - const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4); - const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5); - const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6); - const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7); - - const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - - __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0); - __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1); - __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2); - __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3); - __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4); - __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5); - __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6); - __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7); - - __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0); - __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1); - __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2); - __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3); - __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4); - __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5); - __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6); - __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7); - - p16_0 = _mm_sub_epi16(p16_0, q8s_0); - p16_1 = _mm_sub_epi16(p16_1, q8s_1); - p16_2 = _mm_sub_epi16(p16_2, q8s_2); - p16_3 = _mm_sub_epi16(p16_3, q8s_3); - p16_4 = _mm_sub_epi16(p16_4, q8s_4); - p16_5 = _mm_sub_epi16(p16_5, q8s_5); - p16_6 = _mm_sub_epi16(p16_6, q8s_6); - p16_7 = _mm_sub_epi16(p16_7, q8s_7); - - const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi8(shuffle, m2); - const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi8(shuffle, m2); - const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi8(shuffle, m2); - const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi8(shuffle, m2); - - p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); - p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); - p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); - p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); - p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4); - p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5); - p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6); - p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7); - - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7)); - - } - - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); - } - - *s = hsum_float_8(acc); - -#elif defined __riscv_v_intrinsic - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const int8_t * restrict scale = x[i].scales; - - size_t vl; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - - int sum_t = 0; - int is = 0; - - for (int j = 0; j < QK_K/128; ++j) { - - vl = 32; - - // load qh - vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); - - // load Q6 - vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); - vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); - - vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); - vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); - vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); - vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); - - vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); - vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); - vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); - vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); - - vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); - vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); - vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); - vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); - - vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); - vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); - vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); - vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); - - // load Q8 and take product - vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - - vl = 16; - - vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); - vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); - vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); - vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); - vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); - vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); - vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); - vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); - - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); - - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - - q6 += 64; qh += 32; q8 += 128; is=8; - - } - - sumf += d * sum_t; - - } - - *s = sumf; - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector int v0 = vec_splats((int32_t)0); - const vector unsigned char v2 = vec_splats((unsigned char)0x2); - const vector unsigned char v3 = vec_splats((unsigned char)0x3); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - const vector unsigned char v6 = vec_splats((unsigned char)0x6); - const vector signed char off = vec_splats((signed char)0x20); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - vector signed int vsumi4 = v0; - vector signed int vsumi5 = v0; - vector signed int vsumi6 = v0; - vector signed int vsumi7 = v0; - - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict qs = x[i].scales; - const int8_t * restrict q8 = y[i].qs; - - for (int j = 0; j < QK_K/128; ++j) { - __builtin_prefetch(q6, 0, 0); - __builtin_prefetch(qh, 0, 0); - __builtin_prefetch(q8, 0, 0); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, q6); - vector signed char qxs1 = (vector signed char)vec_xl(16, q6); - vector signed char qxs2 = (vector signed char)vec_xl(32, q6); - vector signed char qxs3 = (vector signed char)vec_xl(48, q6); - q6 += 64; - - vector signed char qxs00 = vec_and(qxs0, lowMask); - vector signed char qxs01 = vec_sr(qxs0, v4); - vector signed char qxs10 = vec_and(qxs1, lowMask); - vector signed char qxs11 = vec_sr(qxs1, v4); - vector signed char qxs20 = vec_and(qxs2, lowMask); - vector signed char qxs21 = vec_sr(qxs2, v4); - vector signed char qxs30 = vec_and(qxs3, lowMask); - vector signed char qxs31 = vec_sr(qxs3, v4); - - vector signed char qxhs0 = (vector signed char)vec_xl( 0, qh); - vector signed char qxhs1 = (vector signed char)vec_xl(16, qh); - qh += 32; - - vector signed char qxh00 = vec_sl(vec_and((vector signed char)v3, qxhs0), v4); - vector signed char qxh01 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v4)), v4); - vector signed char qxh10 = vec_sl(vec_and((vector signed char)v3, qxhs1), v4); - vector signed char qxh11 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v4)), v4); - vector signed char qxh20 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v2)), v4); - vector signed char qxh21 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v6)), v4); - vector signed char qxh30 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v2)), v4); - vector signed char qxh31 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v6)), v4); - - vector signed char q6x00 = vec_sub(vec_or(qxh00, qxs00), off); - vector signed char q6x01 = vec_sub(vec_or(qxh01, qxs01), off); - vector signed char q6x10 = vec_sub(vec_or(qxh10, qxs10), off); - vector signed char q6x11 = vec_sub(vec_or(qxh11, qxs11), off); - vector signed char q6x20 = vec_sub(vec_or(qxh20, qxs20), off); - vector signed char q6x21 = vec_sub(vec_or(qxh21, qxs21), off); - vector signed char q6x30 = vec_sub(vec_or(qxh30, qxs30), off); - vector signed char q6x31 = vec_sub(vec_or(qxh31, qxs31), off); - - vector signed char q8y00 = vec_xl( 0, q8); - vector signed char q8y10 = vec_xl( 16, q8); - vector signed char q8y20 = vec_xl( 32, q8); - vector signed char q8y30 = vec_xl( 48, q8); - vector signed char q8y01 = vec_xl( 64, q8); - vector signed char q8y11 = vec_xl( 80, q8); - vector signed char q8y21 = vec_xl( 96, q8); - vector signed char q8y31 = vec_xl(112, q8); - q8 += 128; - - vector signed short qv00 = vec_add(vec_mule(q6x00, q8y00), vec_mulo(q6x00, q8y00)); - vector signed short qv10 = vec_add(vec_mule(q6x10, q8y10), vec_mulo(q6x10, q8y10)); - vector signed short qv20 = vec_add(vec_mule(q6x20, q8y20), vec_mulo(q6x20, q8y20)); - vector signed short qv30 = vec_add(vec_mule(q6x30, q8y30), vec_mulo(q6x30, q8y30)); - vector signed short qv01 = vec_add(vec_mule(q6x01, q8y01), vec_mulo(q6x01, q8y01)); - vector signed short qv11 = vec_add(vec_mule(q6x11, q8y11), vec_mulo(q6x11, q8y11)); - vector signed short qv21 = vec_add(vec_mule(q6x21, q8y21), vec_mulo(q6x21, q8y21)); - vector signed short qv31 = vec_add(vec_mule(q6x31, q8y31), vec_mulo(q6x31, q8y31)); - - vector signed short vscales = vec_unpackh(vec_xl_len(qs, 8)); - qs += 8; - - vector signed short vs0 = vec_splat(vscales, 0); - vector signed short vs1 = vec_splat(vscales, 1); - vector signed short vs2 = vec_splat(vscales, 2); - vector signed short vs3 = vec_splat(vscales, 3); - vector signed short vs4 = vec_splat(vscales, 4); - vector signed short vs5 = vec_splat(vscales, 5); - vector signed short vs6 = vec_splat(vscales, 6); - vector signed short vs7 = vec_splat(vscales, 7); - - vsumi0 = vec_msum(qv00, vs0, vsumi0); - vsumi1 = vec_msum(qv01, vs4, vsumi1); - vsumi2 = vec_msum(qv10, vs1, vsumi2); - vsumi3 = vec_msum(qv11, vs5, vsumi3); - vsumi4 = vec_msum(qv20, vs2, vsumi4); - vsumi5 = vec_msum(qv21, vs6, vsumi5); - vsumi6 = vec_msum(qv30, vs3, vsumi6); - vsumi7 = vec_msum(qv31, vs7, vsumi7); - } - - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined __loongarch_asx - - const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); - const __m256i m2 = __lasx_xvreplgr2vr_b(3); - const __m256i m32s = __lasx_xvreplgr2vr_b(32); - - __m256 acc = (__m256)__lasx_xvldi(0); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const __m128i scales = __lsx_vld((const __m128i*)x[i].scales, 0); - - __m256i sumi = __lasx_xvldi(0); - - int is = 0; - - for (int j = 0; j < QK_K/128; ++j) { - - const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0)); - const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1)); - const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2)); - const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3)); - is += 4; - - const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; - const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; - const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32; - - const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(q4bitsH, m2), 4); - const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 2), m2), 4); - const __m256i q4h_2 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 4), m2), 4); - const __m256i q4h_3 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 6), m2), 4); - - const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0); - const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(q4bits2, m4), q4h_1); - const __m256i q4_2 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_2); - const __m256i q4_3 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits2, 4), m4), q4h_3); - - const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - - __m256i q8s_0 = lasx_maddubs_h(m32s, q8_0); - __m256i q8s_1 = lasx_maddubs_h(m32s, q8_1); - __m256i q8s_2 = lasx_maddubs_h(m32s, q8_2); - __m256i q8s_3 = lasx_maddubs_h(m32s, q8_3); - - __m256i p16_0 = lasx_maddubs_h(q4_0, q8_0); - __m256i p16_1 = lasx_maddubs_h(q4_1, q8_1); - __m256i p16_2 = lasx_maddubs_h(q4_2, q8_2); - __m256i p16_3 = lasx_maddubs_h(q4_3, q8_3); - - p16_0 = __lasx_xvsub_h(p16_0, q8s_0); - p16_1 = __lasx_xvsub_h(p16_1, q8s_1); - p16_2 = __lasx_xvsub_h(p16_2, q8s_2); - p16_3 = __lasx_xvsub_h(p16_3, q8s_3); - - p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0); - p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1); - p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2); - p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3); - - sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1)); - sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3)); - } - - acc = __lasx_xvfmadd_s((__m256)__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc); - } - - *s = hsum_float_8(acc); - -#else - - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * restrict a = aux8; - for (int j = 0; j < QK_K; j += 128) { - for (int l = 0; l < 32; ++l) { - a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; - } - a += 128; - q4 += 64; - qh += 32; - } - a = aux8; - int is = 0; - for (int j = 0; j < QK_K/16; ++j) { - int scale = x[i].scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - } - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} - -#if defined (__AVX__) || defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx) -static const int8_t keven_signs_q2xs[1024] = { - 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, - 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, - 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1, - 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, - 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, - 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1, - 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1, - 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, - 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1, - 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1, - 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1, - 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1, - 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1, - 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1, - 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1, - 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, - 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1, - 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1, - 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1, - 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1, - 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1, - 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, - 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, - 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, - 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1, - 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, - 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, - 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1, - 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, - 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, - 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, - 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -}; -#endif - -void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq2_xxs * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; - - ggml_int8x16x4_t q2u; - ggml_int8x16x4_t q2s; - ggml_int8x16x4_t q8b; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - float sumf1 = 0, sumf2 = 0; - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; - q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1]))); - q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3]))); - q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9]))); - q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11]))); - q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); - q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); - q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127)))); - q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127)))); - q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); - q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); - q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); - q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); - const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]); - const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]); - sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28)); - sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28)); - } - sumf += d*(sumf1 + sumf2); - } - *s = 0.25f * sumf; - -#elif defined(__AVX2__) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; - const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); - const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); - const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], - signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); - const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127], - signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); - const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); - const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); - const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); - const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); - const uint16_t ls1 = aux32[1] >> 28; - const uint16_t ls2 = aux32[3] >> 28; - const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); - const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); - sumi1 = _mm256_add_epi32(sumi1, p1); - sumi2 = _mm256_add_epi32(sumi2, p2); - } - - accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); - -#elif defined(__AVX__) - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; - const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); - const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]); - const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); - const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]); - const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); - const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]); - const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); - const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]); - const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0); - const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1); - const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0); - const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1); - const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); - const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); - const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); - const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); - const uint16_t ls1 = aux32[1] >> 28; - const uint16_t ls2 = aux32[3] >> 28; - const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); - const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); - const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); - const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); - sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); - sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); - sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); - sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); - } - - accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); - -#elif defined(__POWER9_VECTOR__) - const vector int v0 = vec_splats((int32_t)0); - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - for (int j = 0; j < QK_K/32; j += 2) { - __builtin_prefetch(q2, 0, 1); - __builtin_prefetch(q8, 0, 1); - - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; - - memcpy(aux32, q2, 4*sizeof(uint32_t)); - q2 += 8; - - vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1])}; - vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3])}; - vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9])}; - vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11])}; - - vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((aux32[1] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 7) & 127))}; - vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127))}; - vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((aux32[3] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 7) & 127))}; - vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127))}; - - vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0); - vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1); - vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2); - vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3); - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); - - const uint16_t ls0 = aux32[1] >> 28; - const uint16_t ls1 = aux32[3] >> 28; - - vector signed short vscales01 = vec_splats((int16_t)(2*ls0+1)); - vector signed short vscales23 = vec_splats((int16_t)(2*ls1+1)); - - vsumi0 = vec_msum(qv0, vscales01, vsumi0); - vsumi1 = vec_msum(qv1, vscales01, vsumi1); - vsumi2 = vec_msum(qv2, vscales23, vsumi2); - vsumi3 = vec_msum(qv3, vscales23, vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = 0.125f * vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; - - __m256 accumf = (__m256)__lasx_xvldi(0); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - __m256i sumi1 = __lasx_xvldi(0); - __m256i sumi2 = __lasx_xvldi(0); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; - - const __m256i q2_1 = lasx_set_d(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); - const __m256i q2_2 = lasx_set_d(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); - const __m256i s2_1 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], - signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); - const __m256i s2_2 = lasx_set_d(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127], - signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); - const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1); - const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2); - const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); - const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); - const uint16_t ls1 = aux32[1] >> 28; - const uint16_t ls2 = aux32[3] >> 28; - const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); - const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); - sumi1 = __lasx_xvadd_w(sumi1, p1); - sumi2 = __lasx_xvadd_w(sumi2, p2); - } - - accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); - } - - *s = 0.125f * hsum_float_8(accumf); - -#else - - uint32_t aux32[2]; - const uint8_t * aux8 = (const uint8_t *)aux32; - - float sumf = 0.f; - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - int32_t bsum = 0; - for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { - memcpy(aux32, q2, 2*sizeof(uint32_t)); - q2 += 4; - const uint32_t ls = 2*(aux32[1] >> 28) + 1; - int32_t sumi = 0; - for (int l = 0; l < 4; ++l) { - const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); - const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; - for (int j = 0; j < 8; ++j) { - sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); - } - q8 += 8; - } - bsum += sumi * ls; - } - sumf += d * bsum; - } - *s = 0.125f * sumf; -#endif -} - -void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq2_xs * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - ggml_int8x16x4_t q2u; - ggml_int8x16x4_t q2s; - ggml_int8x16x4_t q8b; - - int32x4x4_t scales32; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - const uint8x8_t scales8 = vld1_u8(x[i].scales); - const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf)); - const uint8x8_t scales_h = vshr_n_u8(scales8, 4); - uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h)); - scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1)); - const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales)); - const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales)); - scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1))); - scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1))); - scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2))); - scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2))); - int32x4_t sumi = vdupq_n_s32(0); - for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { - q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511)))); - q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511)))); - q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511)))); - q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511)))); - q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9)))); - q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9)))); - q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9)))); - q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9)))); - q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); - q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); - q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); - q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); - const int32x4_t p1 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]); - const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]); - const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]); - const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]); - const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4)); - sumi = vmlaq_s32(sumi, p, scales32.val[ib64]); - q2 += 8; - } - sumf += d*vaddvq_s32(sumi); - } - *s = 0.125f * sumf; - -#elif defined(__AVX2__) - - const __m256i mone = _mm256_set1_epi8(1); - static const char block_sign_shuffle_mask_1[32] = { - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, - 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, - }; - static const char block_sign_shuffle_mask_2[32] = { - 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, - 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, - }; - static const uint8_t bit_selector_mask_bytes[32] = { - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes); - const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1); - const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2); - - static const uint8_t k_bit_helper[32] = { - 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, - 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, - }; - const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper); - const __m256i m511 = _mm256_set1_epi16(511); - const __m128i m4 = _mm_set1_epi8(0xf); - const __m128i m1 = _mm_set1_epi8(1); - - uint64_t aux64; - - // somewhat hacky, but gives a significant boost in performance - __m256i aux_gindex; - const uint16_t * gindex = (const uint16_t *)&aux_gindex; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - memcpy(&aux64, x[i].scales, 8); - __m128i stmp = _mm_set1_epi64x(aux64); - stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4)); - const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1); - - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { - - const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16; - aux_gindex = _mm256_and_si256(q2_data, m511); - - const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9); - const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13); - const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper); - - const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting); - const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits); - - const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - - const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]], - iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]); - const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]], - iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]); - const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]], - iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]); - const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]], - iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); - - const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits); - const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1); - const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l); - const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h); - - __m256i signs; - signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1); - signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone)); - - signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2); - signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone)); - - signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1); - signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone)); - - signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2); - signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone)); - - const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); - const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); - const __m256i dot3 = _mm256_maddubs_epi16(q2_3, q8s_3); - const __m256i dot4 = _mm256_maddubs_epi16(q2_4, q8s_4); - - const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0))); - const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1))); - const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2))); - const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3))); - - sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1)); - sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2)); - sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3)); - sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4)); - } - - accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); - -#elif defined(__AVX__) - const __m128i mone = _mm_set1_epi8(1); - static const char block_sign_shuffle_mask_1[32] = { - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, - 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, - }; - static const char block_sign_shuffle_mask_2[32] = { - 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, - 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, - }; - static const uint8_t bit_selector_mask_bytes[32] = { - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes); - const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1); - const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1); - const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1); - const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2); - const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1); - - static const uint8_t k_bit_helper[32] = { - 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, - 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, - }; - const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper); - const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1); - const __m128i m511 = _mm_set1_epi16(511); - const __m128i m4 = _mm_set1_epi8(0xf); - const __m128i m1 = _mm_set1_epi8(1); - - uint64_t aux64; - - // somewhat hacky, but gives a significant boost in performance - __m256i aux_gindex; - const uint16_t * gindex = (const uint16_t *)&aux_gindex; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - memcpy(&aux64, x[i].scales, 8); - __m128i stmp = _mm_set1_epi64x(aux64); - stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4)); - const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1); - - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { - - const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2); - const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1); q2 += 16; - aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511)); - - const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9); - const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9); - const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13); - const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13); - const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0); - const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1); - - const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, partial_sign_bits_for_counting_0); - const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1); - const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0); - const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1); - - const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - - const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]); - const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]); - const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]); - const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]); - const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]); - const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]); - const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); - const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]); - - // AVX2 full_signs_1 is full_sign_bits_0 here - // AVX2 full_signs_2 is full_sign_bits_1 here - __m128i signs_0, signs_1; - signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0); - signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1); - signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); - signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); - const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone)); - const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone)); - - signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0); - signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1); - signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); - signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); - const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone)); - const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone)); - - signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0); - signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1); - signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); - signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); - const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone)); - const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone)); - - signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0); - signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1); - signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); - signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); - const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone)); - const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone)); - - const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); - const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); - const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); - const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); - const __m128i dot3_0 = _mm_maddubs_epi16(q2_3_0, q8s_3_0); - const __m128i dot3_1 = _mm_maddubs_epi16(q2_3_1, q8s_3_1); - const __m128i dot4_0 = _mm_maddubs_epi16(q2_4_0, q8s_4_0); - const __m128i dot4_1 = _mm_maddubs_epi16(q2_4_1, q8s_4_1); - - __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)); - const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp); - const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); - sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)); - const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp); - const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); - sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)); - const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp); - const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); - sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)); - const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp); - const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); - - sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0)); - sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1)); - sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0)); - sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1)); - sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0)); - sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1)); - sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0)); - sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1)); - } - - accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); - -#elif defined(__loongarch_asx) - - const __m256i mone = __lasx_xvreplgr2vr_b(1); - static const char block_sign_shuffle_mask_1[32] = { - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, - 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, - }; - static const char block_sign_shuffle_mask_2[32] = { - 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, - 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, - }; - static const uint8_t bit_selector_mask_bytes[32] = { - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m256i bit_selector_mask = __lasx_xvld((const __m256i*)bit_selector_mask_bytes, 0); - const __m256i block_sign_shuffle_1 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_1, 0); - const __m256i block_sign_shuffle_2 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_2, 0); - - static const uint8_t k_bit_helper[32] = { - 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, - 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, - }; - const __m256i bit_helper = __lasx_xvld((const __m256i*)k_bit_helper, 0); - const __m256i m511 = __lasx_xvreplgr2vr_h(511); - const __m128i m4 = __lsx_vreplgr2vr_b(0xf); - const __m128i m1 = __lsx_vreplgr2vr_b(1); - - uint64_t aux64; - - // somewhat hacky, but gives a significant boost in performance - __m256i aux_gindex; - const uint16_t * gindex = (const uint16_t *)&aux_gindex; - - __m256 accumf = (__m256)__lasx_xvldi(0); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - memcpy(&aux64, x[i].scales, 8); - __m128i stmp = __lsx_vreplgr2vr_d(aux64); - stmp = __lsx_vilvl_b( __lsx_vand_v(__lsx_vsrli_h(stmp, 4), m4), __lsx_vand_v(stmp, m4)); - const __m128i scales = __lsx_vadd_b(__lsx_vslli_h(stmp, 1), m1); - - __m256i sumi1 = __lasx_xvldi(0); - __m256i sumi2 = __lasx_xvldi(0); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { - - const __m256i q2_data = __lasx_xvld((const __m256i*)q2, 0); q2 += 16; - aux_gindex = __lasx_xvand_v(q2_data, m511); - - const __m256i partial_sign_bits = __lasx_xvsrli_h(q2_data, 9); - const __m256i partial_sign_bits_upper = __lasx_xvsrli_h(q2_data, 13); - const __m256i partial_sign_bits_for_counting = __lasx_xvxor_v(partial_sign_bits, partial_sign_bits_upper); - - const __m256i odd_bits = lasx_shuffle_b(bit_helper, partial_sign_bits_for_counting); - const __m256i full_sign_bits = __lasx_xvor_v(partial_sign_bits, odd_bits); - - const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_3 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_4 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - - const __m256i q2_1 = lasx_set_d(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]], - iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]); - const __m256i q2_2 = lasx_set_d(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]], - iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]); - const __m256i q2_3 = lasx_set_d(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]], - iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]); - const __m256i q2_4 = lasx_set_d(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]], - iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); - - const __m128i full_signs_l = lasx_extracti128(full_sign_bits, 0); - const __m128i full_signs_h = lasx_extracti128(full_sign_bits, 1); - const __m256i full_signs_1 = lasx_insertf128(full_signs_l, full_signs_l); - const __m256i full_signs_2 = lasx_insertf128(full_signs_h, full_signs_h); - - __m256i signs; - signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_1); - signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_1 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_1); - - signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_2); - signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_2 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_2); - - signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_1); - signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_3 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_3); - - signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_2); - signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_4 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_4); - - const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); - const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); - const __m256i dot3 = lasx_maddubs_h(q2_3, q8s_3); - const __m256i dot4 = lasx_maddubs_h(q2_4, q8s_4); - - const __m256i sc1 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+0))); - const __m256i sc2 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+1))); - const __m256i sc3 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+2))); - const __m256i sc4 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+3))); - - sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot1, sc1)); - sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot2, sc2)); - sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot3, sc3)); - sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot4, sc4)); - } - - accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); -#elif defined(__POWER9_VECTOR__) - const vector int v0 = vec_splats((int32_t)0); - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - const uint16_t * restrict q2 = x[i].qs; - const uint8_t * restrict sc = x[i].scales; - const int8_t * restrict q8 = y[i].qs; - - for (int j = 0; j < QK_K/64; ++j) { - __builtin_prefetch(q2, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xs_grid + (q2[0] & 511)), *(const int64_t *)(iq2xs_grid + (q2[1] & 511))}; - vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xs_grid + (q2[2] & 511)), *(const int64_t *)(iq2xs_grid + (q2[3] & 511))}; - vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xs_grid + (q2[4] & 511)), *(const int64_t *)(iq2xs_grid + (q2[5] & 511))}; - vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xs_grid + (q2[6] & 511)), *(const int64_t *)(iq2xs_grid + (q2[7] & 511))}; - - vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((q2[0] >> 9))), *(const int64_t *)(signs64 + ((q2[1] >> 9)))}; - vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((q2[2] >> 9))), *(const int64_t *)(signs64 + ((q2[3] >> 9)))}; - vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((q2[4] >> 9))), *(const int64_t *)(signs64 + ((q2[5] >> 9)))}; - vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((q2[6] >> 9))), *(const int64_t *)(signs64 + ((q2[7] >> 9)))}; - q2 += 8; - - vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0); - vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1); - vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2); - vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3); - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); - - const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); - const uint16_t ls1 = (uint16_t)(sc[0] >> 4); - const uint16_t ls2 = (uint16_t)(sc[1] & 0xf); - const uint16_t ls3 = (uint16_t)(sc[1] >> 4); - sc += 2; - - vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1)); - vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1)); - vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1)); - vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1)); - - vsumi0 = vec_msum(qv0, vscales0, vsumi0); - vsumi1 = vec_msum(qv1, vscales1, vsumi1); - vsumi2 = vec_msum(qv2, vscales2, vsumi2); - vsumi3 = vec_msum(qv3, vscales3, vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = 0.125f * vec_extract(vsumf0, 0); -#else - - float sumf = 0.f; - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * restrict q2 = x[i].qs; - const uint8_t * restrict sc = x[i].scales; - const int8_t * restrict q8 = y[i].qs; - int32_t bsum = 0; - for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { - const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1; - const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1; - int32_t sumi = 0; - for (int l = 0; l < 2; ++l) { - const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); - const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; - for (int j = 0; j < 8; ++j) { - sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); - } - q8 += 8; - } - bsum += sumi * ls1; - sumi = 0; - for (int l = 2; l < 4; ++l) { - const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); - const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; - for (int j = 0; j < 8; ++j) { - sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); - } - q8 += 8; - } - bsum += sumi * ls2; - q2 += 4; - } - sumf += d * bsum; - } - *s = 0.125f * sumf; -#endif -} - -void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq2_s * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; - - const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1); - const uint8x16_t mask2 = vld1q_u8(k_mask2); - const uint8x16_t m1 = vdupq_n_u8(1); - const int32x4_t vzero = vdupq_n_s32(0); - - uint8x16x2_t vs; - ggml_int8x16x4_t q2s; - ggml_int8x16x4_t q8b; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - - const uint8_t * restrict qs = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); - const int8_t * restrict q8 = y[i].qs; - - int sumi1 = 0, sumi2 = 0; - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - q2s.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[0] | ((qh[ib32+0] << 8) & 0x300)))), - vld1_s8((const int8_t *)(iq2s_grid + (qs[1] | ((qh[ib32+0] << 6) & 0x300))))); - q2s.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[2] | ((qh[ib32+0] << 4) & 0x300)))), - vld1_s8((const int8_t *)(iq2s_grid + (qs[3] | ((qh[ib32+0] << 2) & 0x300))))); - q2s.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[4] | ((qh[ib32+1] << 8) & 0x300)))), - vld1_s8((const int8_t *)(iq2s_grid + (qs[5] | ((qh[ib32+1] << 6) & 0x300))))); - q2s.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[6] | ((qh[ib32+1] << 4) & 0x300)))), - vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300))))); - qs += 8; - - vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16))); - vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); - vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); - vs.val[0] = vceqq_u8(vs.val[0], mask2); - vs.val[1] = vceqq_u8(vs.val[1], mask2); - - q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]); - q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]); - - vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16))); - vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); - vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); - vs.val[0] = vceqq_u8(vs.val[0], mask2); - vs.val[1] = vceqq_u8(vs.val[1], mask2); - - signs += 4; - - q2s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[2]); - q2s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[3]); - - const int32x4_t p1 = ggml_vdotq_s32(vzero, q2s.val[0], q8b.val[0]); - const int32x4_t p2 = ggml_vdotq_s32(vzero, q2s.val[1], q8b.val[1]); - const int32x4_t p3 = ggml_vdotq_s32(vzero, q2s.val[2], q8b.val[2]); - const int32x4_t p4 = ggml_vdotq_s32(vzero, q2s.val[3], q8b.val[3]); - - sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32+0] & 0xf)); - sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32+0] >> 4)); - sumi1 += vaddvq_s32(p3) * (1 + 2*(x[i].scales[ib32+1] & 0xf)); - sumi2 += vaddvq_s32(p4) * (1 + 2*(x[i].scales[ib32+1] >> 4)); - } - sumf += d*(sumi1 + sumi2); - } - - *s = 0.125f * sumf; - -#elif defined(__AVX2__) - - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m128i m4 = _mm_set1_epi8(0xf); - const __m128i m1 = _mm_set1_epi8(1); - - const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); - const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); - - uint64_t aux64; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict qs = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); - const int8_t * restrict q8 = y[i].qs; - - memcpy(&aux64, x[i].scales, 8); - const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1); - const __m256i scales16 = _mm256_cvtepi8_epi16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15 - - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q2_1 = _mm256_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], - iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)], - iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], - iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); - const __m256i q2_2 = _mm256_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], - iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)], - iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], - iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); - qs += 8; - - __m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16)); - aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); - const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2); - const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1); - - aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16)); - aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); - const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2); - const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2); - - signs += 4; - - const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1 - const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3 - - const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+0))); - const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+1))); - sumi1 = _mm256_add_epi32(sumi1, p1); - sumi2 = _mm256_add_epi32(sumi2, p2); - } - - accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); - -#elif defined(__AVX__) - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m128i m4 = _mm_set1_epi8(0xf); - const __m128i m1 = _mm_set1_epi8(1); - - const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1); - const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1); - const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2); - const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1); - - uint64_t aux64; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict qs = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); - const int8_t * restrict q8 = y[i].qs; - - memcpy(&aux64, x[i].scales, 8); - const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1); - const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8); - const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8)); - - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], - iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); - const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], - iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]); - const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], - iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); - const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], - iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]); - qs += 8; - - __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16)); - __m128i aux128_1 = aux128_0; - aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); - aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); - const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); - const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); - const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0); - const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1); - - aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16)); - aux128_1 = aux128_0; - aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); - aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); - const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); - const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); - const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0); - const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1); - - signs += 4; - - const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); - const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); - const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); - const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); - - const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0))); - const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1))); - const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0))); - const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1))); - sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); - sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); - sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); - sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); - } - - accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); - -#elif defined(__POWER9_VECTOR__) - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; - - const vector int v0 = vec_splats((int32_t)0); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - const vector unsigned char mask0 = vec_xl( 0, k_mask1); - const vector unsigned char mask1 = vec_xl(16, k_mask1); - const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - const uint8_t * restrict q2 = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); - const uint8_t * restrict sc = x[i].scales; - const int8_t * restrict q8 = y[i].qs; - - for (int j = 0; j < QK_K/32; j += 2) { - __builtin_prefetch(q2, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed long long aux64x2_0 = {*(const int64_t *)(iq2s_grid + (q2[0] | ((qh[0] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[1] | ((qh[0] << 6) & 0x300)))}; - vector signed long long aux64x2_1 = {*(const int64_t *)(iq2s_grid + (q2[2] | ((qh[0] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[3] | ((qh[0] << 2) & 0x300)))}; - vector signed long long aux64x2_2 = {*(const int64_t *)(iq2s_grid + (q2[4] | ((qh[1] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[5] | ((qh[1] << 6) & 0x300)))}; - vector signed long long aux64x2_3 = {*(const int64_t *)(iq2s_grid + (q2[6] | ((qh[1] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[7] | ((qh[1] << 2) & 0x300)))}; - q2 += 8; - qh += 2; - - vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]); - vector signed char vsigns23 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]); - signs += 4; - - vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0); - vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1); - vector signed char vsigns2 = vec_perm(vsigns23, vsigns23, mask0); - vector signed char vsigns3 = vec_perm(vsigns23, vsigns23, mask1); - - vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2); - vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2); - vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2); - vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2); - - vector signed char q2x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux64x2_0), vsigns0); - vector signed char q2x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux64x2_1), vsigns1); - vector signed char q2x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux64x2_2), vsigns2); - vector signed char q2x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux64x2_3), vsigns3); - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); - - const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); - const uint16_t ls1 = (uint16_t)(sc[0] >> 4); - const uint16_t ls2 = (uint16_t)(sc[1] & 0xf); - const uint16_t ls3 = (uint16_t)(sc[1] >> 4); - sc += 2; - - vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1)); - vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1)); - vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1)); - vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1)); - - vsumi0 = vec_msum(qv0, vscales0, vsumi0); - vsumi1 = vec_msum(qv1, vscales1, vsumi1); - vsumi2 = vec_msum(qv2, vscales2, vsumi2); - vsumi3 = vec_msum(qv3, vscales3, vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = 0.125f * vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - - const __m128i m4 = __lsx_vreplgr2vr_b(0xf); - const __m128i m1 = __lsx_vreplgr2vr_b(1); - - const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0); - const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0); - uint64_t aux64; - - __m256 accumf = (__m256)__lasx_xvldi(0); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict qs = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); - const int8_t * restrict q8 = y[i].qs; - - __m128i tmp1; - memcpy(&aux64, x[i].scales, 8); - tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64, 0); - tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64 >> 4, 1); - const __m128i scales8 = __lsx_vadd_b(__lsx_vslli_h(__lsx_vand_v(tmp1, m4), 1), m1); - const __m256i scales16 = lasx_ext8_16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15 - - __m256i sumi1 = __lasx_xvldi(0); - __m256i sumi2 = __lasx_xvldi(0); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q2_1 = lasx_set_d(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], - iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)], - iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], - iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); - const __m256i q2_2 = lasx_set_d(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], - iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)], - iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], - iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); - qs += 8; - - __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | ((uint32_t) signs[1] << 16)); - aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); - const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2); - const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1); - - aux256 = __lasx_xvreplgr2vr_w(signs[2] | ((uint32_t) signs[3] << 16)); - aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); - const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2); - const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2); - - signs += 4; - - const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1 - const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3 - - const __m256i p1 = lasx_madd_h(dot1, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+0))); - const __m256i p2 = lasx_madd_h(dot2, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+1))); - sumi1 = __lasx_xvadd_w(sumi1, p1); - sumi2 = __lasx_xvadd_w(sumi2, p2); - } - - accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); - } - - *s = 0.125f * hsum_float_8(accumf); - -#else - - float sumf = 0; - for (int i = 0; i < nb; i++) { - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint8_t * qh = x[i].qh; - const uint8_t * signs = qs + QK_K/8; - - int bsum = 0; - for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { - int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf); - int ls2 = 1 + 2*(x[i].scales[ib32] >> 4); - int sumi1 = 0, sumi2 = 0; - for (int l = 0; l < 2; ++l) { - const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); - for (int j = 0; j < 8; ++j) { - sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); - } - q8 += 8; - } - for (int l = 2; l < 4; ++l) { - const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); - for (int j = 0; j < 8; ++j) { - sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); - } - q8 += 8; - } - bsum += ls1 * sumi1 + ls2 * sumi2; - qs += 4; - signs += 4; - } - - sumf += d * bsum; - } - - *s = 0.125f * sumf; - -#endif - -} - -void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq3_xxs * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[2]; - - ggml_int8x16x4_t q3s; - ggml_int8x16x4_t q8b; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict gas = x[i].qs + QK_K/4; - const int8_t * restrict q8 = y[i].qs; - float sumf1 = 0, sumf2 = 0; - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t); - const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]); - const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]); - const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]); - const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]); - q3 += 16; - q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127)))); - q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127)))); - q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); - q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); - q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0)); - q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1)); - q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2)); - q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3)); - const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); - const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); - sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28)); - sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28)); - } - sumf += d*(sumf1 + sumf2); - } - *s = 0.5f * sumf; - -#elif defined(__AVX2__) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[2]; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict gas = x[i].qs + QK_K/4; - const int8_t * restrict q8 = y[i].qs; - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], - iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); - q3 += 8; - const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], - iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); - q3 += 8; - memcpy(aux32, gas, 8); gas += 8; - const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127], - signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); - const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], - signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); - const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); - const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); - const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); - const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); - const uint16_t ls1 = aux32[0] >> 28; - const uint16_t ls2 = aux32[1] >> 28; - const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); - const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); - sumi1 = _mm256_add_epi32(sumi1, p1); - sumi2 = _mm256_add_epi32(sumi2, p2); - } - - accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); - - } - - *s = 0.25f * hsum_float_8(accumf); - -#elif defined(__AVX__) - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[2]; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict gas = x[i].qs + QK_K/4; - const int8_t * restrict q8 = y[i].qs; - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); - const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]); - q3 += 8; - const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); - const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]); - q3 += 8; - memcpy(aux32, gas, 8); gas += 8; - const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); - const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]); - const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); - const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]); - const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0); - const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1); - const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0); - const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1); - const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); - const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); - const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); - const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); - const uint16_t ls1 = aux32[0] >> 28; - const uint16_t ls2 = aux32[1] >> 28; - const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); - const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); - const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); - const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); - sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); - sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); - sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); - sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); - } - - accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); - - } - - *s = 0.25f * hsum_float_8(accumf); - -#elif defined(__POWER9_VECTOR__) - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - const vector int v0 = vec_splats((int32_t)0); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - const uint8_t * restrict q3 = x[i].qs; - const uint32_t * restrict signs = (const uint32_t *)(x[i].qs + QK_K/4); - const int8_t * restrict q8 = y[i].qs; - -#pragma GCC unroll 1 - for (int j = 0; j < QK_K/32; j += 2) { - __builtin_prefetch(q3, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector unsigned int aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]}; - vector unsigned int aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]}; - vector unsigned int aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]}; - vector unsigned int aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]}; - q3 += 16; - - vector unsigned long long aux64x2_0 = {(uint64_t)(signs64[(signs[0] >> 0) & 127]), (uint64_t)(signs64[(signs[0] >> 7) & 127])}; - vector unsigned long long aux64x2_1 = {(uint64_t)(signs64[(signs[0] >> 14) & 127]), (uint64_t)(signs64[(signs[0] >> 21) & 127])}; - vector unsigned long long aux64x2_2 = {(uint64_t)(signs64[(signs[1] >> 0) & 127]), (uint64_t)(signs64[(signs[1] >> 7) & 127])}; - vector unsigned long long aux64x2_3 = {(uint64_t)(signs64[(signs[1] >> 14) & 127]), (uint64_t)(signs64[(signs[1] >> 21) & 127])}; - - vector signed char q3x0 = vec_mul((vector signed char)aux64x2_0, (vector signed char)aux32x4_0); - vector signed char q3x1 = vec_mul((vector signed char)aux64x2_1, (vector signed char)aux32x4_1); - vector signed char q3x2 = vec_mul((vector signed char)aux64x2_2, (vector signed char)aux32x4_2); - vector signed char q3x3 = vec_mul((vector signed char)aux64x2_3, (vector signed char)aux32x4_3); - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3)); - - const uint16_t ls0 = (uint16_t)(signs[0] >> 28); - const uint16_t ls1 = (uint16_t)(signs[1] >> 28); - signs += 2; - - vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); - vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); - - vsumi0 = vec_msum(qv0, vscales01, vsumi0); - vsumi1 = vec_msum(qv1, vscales01, vsumi1); - vsumi2 = vec_msum(qv2, vscales23, vsumi2); - vsumi3 = vec_msum(qv3, vscales23, vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = 0.25f * vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[2]; - - __m256 accumf = (__m256)__lasx_xvldi(0); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict gas = x[i].qs + QK_K/4; - const int8_t * restrict q8 = y[i].qs; - __m256i sumi1 = __lasx_xvldi(0); - __m256i sumi2 = __lasx_xvldi(0); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q2_1 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], - iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); - q3 += 8; - const __m256i q2_2 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], - iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); - q3 += 8; - memcpy(aux32, gas, 8); gas += 8; - - const __m256i s2_1 = lasx_set_d(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127], - signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); - const __m256i s2_2 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], - signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); - const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1); - const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2); - const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); - const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); - const uint16_t ls1 = aux32[0] >> 28; - const uint16_t ls2 = aux32[1] >> 28; - - const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); - const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); - sumi1 = __lasx_xvadd_w(sumi1, p1); - sumi2 = __lasx_xvadd_w(sumi2, p2); - } - - accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); - } - - *s = 0.25f * hsum_float_8(accumf); - -#else - - uint32_t aux32; - - float sumf = 0.f; - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict gas = x[i].qs + QK_K/4; - const int8_t * restrict q8 = y[i].qs; - int32_t bsum = 0; - for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { - memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t); - const uint32_t ls = 2*(aux32 >> 28) + 1; - int32_t sumi = 0; - for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]); - const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]); - const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; - for (int j = 0; j < 4; ++j) { - sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1); - sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1); - } - q8 += 8; - } - q3 += 8; - bsum += sumi * ls; - } - sumf += d * bsum; - } - *s = 0.25f * sumf; -#endif -} - -void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq3_s * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - - typedef union { - uint16x8_t vec_index; - uint16_t index[8]; - } vec_index_t; - - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; - - static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; - - const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1); - const uint8x16_t mask2 = vld1q_u8(k_mask2); - - const int16x8_t hshift = vld1q_s16(k_shift); - const uint16x8_t m256 = vdupq_n_u16(256); - const uint8x16_t m1 = vdupq_n_u8(1); - - uint8x16x2_t vs; - ggml_int8x16x4_t q3s; - ggml_int8x16x4_t q8b; - vec_index_t idx; - - uint32_t scales32[2]; - const uint8_t * scales8 = (const uint8_t *)scales32; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict qs = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint16_t * restrict signs = (const uint16_t *)x[i].signs; - const int8_t * restrict q8 = y[i].qs; - - memcpy(scales32, x[i].scales, 4); - scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; - scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; - - int sumi1 = 0, sumi2 = 0; - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - - const uint8x16_t idx_l = vld1q_u8(qs); qs += 16; - idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256)); - const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]], - iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]); - const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]], - iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]); - idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256)); - const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]], - iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]); - const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]], - iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]); - - - vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16))); - vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); - vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); - vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); - vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1); - - q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0)); - q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1)); - - vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16))); - vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); - vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); - vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); - vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1); - - signs += 4; - - q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2)); - q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3)); - - const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); - const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); - - sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0]; - sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4]; - } - sumf += d*(sumi1 + sumi2); - } - *s = sumf; - -#elif defined(__AVX2__) - - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); - const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); - - const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8); - const __m256i idx_mask = _mm256_set1_epi32(256); - - typedef union { - __m256i vec[2]; - uint32_t index[16]; - } index_t; - - index_t idx; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict qs = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint16_t * restrict signs = (const uint16_t *)x[i].signs; - const int8_t * restrict q8 = y[i].qs; - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); qs += 16; - idx.vec[0] = _mm256_set1_epi32(qh[ib32+0]); - idx.vec[1] = _mm256_set1_epi32(qh[ib32+1]); - idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask); - idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask); - idx.vec[0] = _mm256_or_si256(idx.vec[0], _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l))); - idx.vec[1] = _mm256_or_si256(idx.vec[1], _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l, 1))); - - // At leat on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange. - //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4); - //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4); - const __m256i q2_1 = _mm256_set_epi32( - iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]], - iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]] - ); - const __m256i q2_2 = _mm256_set_epi32( - iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]], - iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]] - ); - - __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16)); - aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); - const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2); - const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1); - - aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16)); - aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); - const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2); - const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2); - - signs += 4; - - const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); - const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); - const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; - const uint16_t ls2 = x[i].scales[ib32/2] >> 4; - const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); - const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); - sumi1 = _mm256_add_epi32(sumi1, p1); - sumi2 = _mm256_add_epi32(sumi2, p2); - } - - accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); - - } - - *s = hsum_float_8(accumf); - -#elif defined(__AVX__) - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1); - const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1); - const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2); - const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1); - - const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256); - const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16); - const __m128i idx_mask = _mm_set1_epi32(256); - - typedef union { - __m128i vec[4]; - uint32_t index[16]; - } index_t; - - index_t idx; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict qs = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint16_t * restrict signs = (const uint16_t *)x[i].signs; - const int8_t * restrict q8 = y[i].qs; - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs); - const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp); - const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16; - idx.vec[0] = _mm_set1_epi32(qh[ib32+0]); - idx.vec[1] = idx.vec[0]; - idx.vec[2] = _mm_set1_epi32(qh[ib32+1]); - idx.vec[3] = idx.vec[2]; - - idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask); - idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask); - idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask); - idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask); - - idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0)); - idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8))); - idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1)); - idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8))); - - const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]); - const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]); - const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]); - const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]); - - __m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16)); - __m128i aux128_1 = aux128_0; - aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); - aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); - const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); - const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); - const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0); - const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1); - - aux128_0 = _mm_set1_epi32(signs[2] | (signs[3] << 16)); - aux128_1 = aux128_0; - aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); - aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); - const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); - const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); - const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0); - const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1); - - signs += 4; - - const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); - const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); - const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); - const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); - const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; - const uint16_t ls2 = x[i].scales[ib32/2] >> 4; - const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); - const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); - const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); - const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); - sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); - sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); - sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); - sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); - } - - accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); - - } - - *s = hsum_float_8(accumf); - -#elif defined(__POWER9_VECTOR__) - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; - - const vector int v0 = vec_splats((int32_t)0); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - const vector unsigned char mask0 = vec_xl( 0, k_mask1); - const vector unsigned char mask1 = vec_xl(16, k_mask1); - const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint16_t * restrict signs = (const uint16_t *)(x[i].signs); - const uint8_t * restrict sc = x[i].scales; - const int8_t * restrict q8 = y[i].qs; - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - for (int j = 0; j < QK_K/32; j += 2) { - __builtin_prefetch(q3, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector unsigned int aux32x4_0 = {iq3s_grid[q3[ 0] | ((qh[0] << 8) & 256)], iq3s_grid[q3[ 1] | ((qh[0] << 7) & 256)], - iq3s_grid[q3[ 2] | ((qh[0] << 6) & 256)], iq3s_grid[q3[ 3] | ((qh[0] << 5) & 256)]}; - vector unsigned int aux32x4_1 = {iq3s_grid[q3[ 4] | ((qh[0] << 4) & 256)], iq3s_grid[q3[ 5] | ((qh[0] << 3) & 256)], - iq3s_grid[q3[ 6] | ((qh[0] << 2) & 256)], iq3s_grid[q3[ 7] | ((qh[0] << 1) & 256)]}; - vector unsigned int aux32x4_2 = {iq3s_grid[q3[ 8] | ((qh[1] << 8) & 256)], iq3s_grid[q3[ 9] | ((qh[1] << 7) & 256)], - iq3s_grid[q3[10] | ((qh[1] << 6) & 256)], iq3s_grid[q3[11] | ((qh[1] << 5) & 256)]}; - vector unsigned int aux32x4_3 = {iq3s_grid[q3[12] | ((qh[1] << 4) & 256)], iq3s_grid[q3[13] | ((qh[1] << 3) & 256)], - iq3s_grid[q3[14] | ((qh[1] << 2) & 256)], iq3s_grid[q3[15] | ((qh[1] << 1) & 256)]}; - q3 += 16; - qh += 2; - - vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]); - vector signed char vsigns02 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]); - signs += 4; - - vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0); - vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1); - vector signed char vsigns2 = vec_perm(vsigns02, vsigns02, mask0); - vector signed char vsigns3 = vec_perm(vsigns02, vsigns02, mask1); - - vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2); - vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2); - vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2); - vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2); - - vector signed char q3x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux32x4_0), vsigns0); - vector signed char q3x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux32x4_1), vsigns1); - vector signed char q3x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux32x4_2), vsigns2); - vector signed char q3x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux32x4_3), vsigns3); - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3)); - - const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); - const uint16_t ls1 = (uint16_t)(sc[0] >> 4); - sc ++; - - vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); - vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); - - vsumi0 = vec_msum(qv0, vscales01, vsumi0); - vsumi1 = vec_msum(qv1, vscales01, vsumi1); - vsumi2 = vec_msum(qv2, vscales23, vsumi2); - vsumi3 = vec_msum(qv3, vscales23, vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0); - const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0); - - __m256i idx_shift = lasx_set_w(1, 2, 3, 4, 5, 6, 7, 8); - const __m256i idx_mask = __lasx_xvreplgr2vr_w(256); - - typedef union { - __m256i vec[2]; - uint32_t index[16]; - } index_t; - - index_t idx; - - __m256 accumf = (__m256)__lasx_xvldi(0); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict qs = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint16_t * restrict signs = (const uint16_t *)x[i].signs; - const int8_t * restrict q8 = y[i].qs; - __m256i sumi1 = __lasx_xvldi(0); - __m256i sumi2 = __lasx_xvldi(0); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i idx_l = lasx_extu8_16(__lsx_vld(qs, 0)); qs += 16; - idx.vec[0] = __lasx_xvreplgr2vr_w(qh[ib32+0]); - idx.vec[1] = __lasx_xvreplgr2vr_w(qh[ib32+1]); - idx.vec[0] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[0], idx_shift), idx_mask); - idx.vec[1] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[1], idx_shift), idx_mask); - idx.vec[0] = __lasx_xvor_v(idx.vec[0], lasx_ext16_32(lasx_extracti128(idx_l, 0))); - idx.vec[1] = __lasx_xvor_v(idx.vec[1], lasx_ext16_32(lasx_extracti128(idx_l, 1))); - - // At leat on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange. - //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4); - //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4); - const __m256i q2_1 = lasx_set_w( - iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]], - iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]] - ); - const __m256i q2_2 = lasx_set_w( - iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]], - iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]] - ); - - __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | (signs[1] << 16)); - aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); - const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2); - const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1); - - aux256 = __lasx_xvreplgr2vr_w(signs[2] | (signs[3] << 16)); - aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); - const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2); - const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2); - - signs += 4; - - const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); - const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); - const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; - const uint16_t ls2 = x[i].scales[ib32/2] >> 4; - const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); - const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); - sumi1 = __lasx_xvadd_w(sumi1, p1); - sumi2 = __lasx_xvadd_w(sumi2, p2); - } - - accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); - } - - *s = hsum_float_8(accumf); - -#else - - float sumf = 0.f; - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * restrict qs = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const uint8_t * restrict signs = x[i].signs; - const int8_t * restrict q8 = y[i].qs; - int32_t bsum = 0; - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1; - const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1; - int32_t sumi = 0; - for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))); - for (int j = 0; j < 4; ++j) { - sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); - sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); - } - q8 += 8; - } - qs += 8; - signs += 4; - bsum += sumi * ls1; - sumi = 0; - for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))); - for (int j = 0; j < 4; ++j) { - sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); - sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); - } - q8 += 8; - } - qs += 8; - signs += 4; - bsum += sumi * ls2; - } - sumf += d * bsum; - } - *s = sumf; -#endif -} - - -#if defined(__AVX__) -static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) { - const __m128i ax = _mm_sign_epi8(x, x); - const __m128i sy = _mm_sign_epi8(y, x); - return _mm_maddubs_epi16(ax, sy); -} -#endif - -#if defined(__AVX2__) -static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { - const __m256i ax = _mm256_sign_epi8(x, x); - const __m256i sy = _mm256_sign_epi8(y, x); - return _mm256_maddubs_epi16(ax, sy); -} -#elif defined(__loongarch_asx) -static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { - const __m256i ax = __lasx_xvsigncov_b(x, x); - const __m256i sy = __lasx_xvsigncov_b(x, y); - __m256i tmp1, tmp2, tmp3; - tmp1 = __lasx_xvmulwev_h_bu_b(ax, sy); - tmp2 = __lasx_xvmulwod_h_bu_b(ax, sy); - tmp3 = __lasx_xvadd_h(tmp1, tmp2); - return __lasx_xvsat_h(tmp3, 15); -} -#endif - -void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq1_s * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#if defined __ARM_NEON - - ggml_int8x16x4_t q1b; - ggml_int8x16x4_t q8b; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint16_t * qh = x[i].qh; - - int sumi1 = 0, sumi2 = 0, sumi3 = 0; - - for (int ib = 0; ib < QK_K/32; ib += 2) { - - q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[ib+0] << 8) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[ib+0] << 5) & 0x700))))); - q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[ib+0] << 2) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[ib+0] >> 1) & 0x700))))); - q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[ib+1] << 8) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[ib+1] << 5) & 0x700))))); - q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[ib+1] << 2) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[ib+1] >> 1) & 0x700))))); - qs += 8; - - q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - - const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]); - const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]); - - const int ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; - const int ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; - sumi1 += vaddvq_s32(p1) * ls1; - sumi2 += vaddvq_s32(p2) * ls2; - sumi3 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * ls1 * (qh[ib+0] & 0x8000 ? -1 : 1) - + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * ls2 * (qh[ib+1] & 0x8000 ? -1 : 1); - - } - - sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2 + IQ1S_DELTA * sumi3); - } - - *s = sumf; - -#elif defined __AVX2__ - - __m256 accum = _mm256_setzero_ps(); - float accum1 = 0; - for (int i = 0; i < nb; ++i) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint16_t * qh = x[i].qh; - - __m256i sumi = _mm256_setzero_si256(); - int sumi1 = 0; - for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], - iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]); - const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], - iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]); - qs += 8; - const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); - const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); - const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; - const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; - const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(ls1)); - const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(ls2)); - - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2)); - sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 - + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; - } - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - accum = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), accum); - accum1 += d * sumi1; - - } - - *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; - -#elif defined __AVX__ - __m256 accum = _mm256_setzero_ps(); - float accum1 = 0; - for (int i = 0; i < nb; ++i) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint16_t * qh = x[i].qh; - - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - int sumi1 = 0; - for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m128i q1b_1_0 = _mm_set_epi64x(iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]); - const __m128i q1b_1_1 = _mm_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)]); - const __m128i q1b_2_0 = _mm_set_epi64x(iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]); - const __m128i q1b_2_1 = _mm_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)]); - qs += 8; - const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - - const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0); - const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1); - const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0); - const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1); - const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; - const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; - const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(ls1)); - const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(ls1)); - const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(ls2)); - const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(ls2)); - - sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0)); - sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1)); - sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 - + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; - } - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum); - accum1 += d * sumi1; - - } - - *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; - -#elif defined(__POWER9_VECTOR__) - const vector unsigned char v0 = vec_splats((unsigned char)0x0); - const vector unsigned short vsign = vec_splats((unsigned short)0x8000); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi8 = vec_splats((int32_t)0); - - const uint8_t * restrict q1 = x[i].qs; - const uint16_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - const int16_t * restrict qs = y[i].bsums; - - for (int j = 0; j < QK_K/32; j += 2) { - __builtin_prefetch(q1, 0, 1); - __builtin_prefetch(qh, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed long long aux64x2_0 = {*(const int64_t *)(iq1s_grid + (q1[0] | ((qh[0] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[1] | ((qh[0] << 5) & 0x700)))}; - vector signed long long aux64x2_1 = {*(const int64_t *)(iq1s_grid + (q1[2] | ((qh[0] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[3] | ((qh[0] >> 1) & 0x700)))}; - vector signed long long aux64x2_2 = {*(const int64_t *)(iq1s_grid + (q1[4] | ((qh[1] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[5] | ((qh[1] << 5) & 0x700)))}; - vector signed long long aux64x2_3 = {*(const int64_t *)(iq1s_grid + (q1[6] | ((qh[1] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[7] | ((qh[1] >> 1) & 0x700)))}; - q1 += 8; - - vector signed char q1x0 = (vector signed char)aux64x2_0; - vector signed char q1x1 = (vector signed char)aux64x2_1; - vector signed char q1x2 = (vector signed char)aux64x2_2; - vector signed char q1x3 = (vector signed char)aux64x2_3; - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q1x0, q8y0), vec_mulo(q1x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q1x1, q8y1), vec_mulo(q1x1, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q1x2, q8y2), vec_mulo(q1x2, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q1x3, q8y3), vec_mulo(q1x3, q8y3)); - - const uint16_t ls0 = (uint16_t)((qh[0] >> 12) & 7); - const uint16_t ls1 = (uint16_t)((qh[1] >> 12) & 7); - - vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); - vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); - vector signed short vscales = vec_sld(vscales23, vscales01, 8); - - vsumi0 = vec_msum(qv0, vscales01, vsumi0); - vsumi1 = vec_msum(qv1, vscales01, vsumi1); - vsumi2 = vec_msum(qv2, vscales23, vsumi2); - vsumi3 = vec_msum(qv3, vscales23, vsumi3); - - vector signed short q8ysums = vec_xl_len(qs, 8); - qs += 4; - q8ysums = vec_mergeh(q8ysums, (vector signed short)v0); - - vector signed short qxh = (vector signed short)vec_sld(vec_splats(qh[1]), vec_splats(qh[0]), 8); - qh += 2; - vector __bool short vsel = vec_cmpge(qxh, (vector signed short)v0); - - vector signed short q8ysum = vec_sel((vector signed short)vec_xor((vector unsigned short)q8ysums, vsign), q8ysums, vsel); - - vsumi8 = vec_add(vec_mule(q8ysum, vscales), vsumi8); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - - vsumf0 = vec_madd(vec_ctf(vsumi8, 0), vec_mul(vd, vec_splats(IQ1S_DELTA)), vsumf0); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - - __m256 accum = (__m256)__lasx_xvldi(0); - float accum1 = 0; - for (int i = 0; i < nb; ++i) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint16_t * qh = x[i].qh; - - __m256i sumi = __lasx_xvldi(0); - int sumi1 = 0; - for (int ib = 0; ib < QK_K/32; ib += 2) { - __m256i q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)], 0); - q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], 1); - q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], 2); - q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], 3); - - __m256i q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)], 0); - q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], 1); - q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], 2); - q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], 3); - - qs += 8; - const __m256i q8b_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8b_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - - const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); - const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); - const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; - const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; - - __m256i tmp1, tmp5, tmp6; - tmp1 = __lasx_xvreplgr2vr_h(ls1); - tmp5 = __lasx_xvmulwev_w_h(dot1, tmp1); - tmp6 = __lasx_xvmulwod_w_h(dot1, tmp1); - const __m256i p1 = __lasx_xvadd_w(tmp5, tmp6); - - tmp1 = __lasx_xvreplgr2vr_h(ls2); - tmp5 = __lasx_xvmulwev_w_h(dot2, tmp1); - tmp6 = __lasx_xvmulwod_w_h(dot2, tmp1); - const __m256i p2 = __lasx_xvadd_w(tmp5, tmp6); - - sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p1, p2)); - sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 - + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; - } - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), accum); - accum1 += d * sumi1; - } - - *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; - -#else - - float sumf = 0; - for (int i = 0; i < nb; i++) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint16_t * qh = x[i].qh; - - int sumi = 0, sumi1 = 0; - for (int ib = 0; ib < QK_K/32; ++ib) { - const int ls = 2*((qh[ib] >> 12) & 7) + 1; - const int delta = qh[ib] & 0x8000 ? -1 : 1; - int lsum = 0; - for (int l = 0; l < 4; ++l) { - const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8))); - for (int j = 0; j < 8; ++j) { - lsum += q8[j] * grid[j]; - } - q8 += 8; - } - sumi += ls * lsum; - sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]); - qs += 4; - } - - sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); - } - - *s = sumf; - -#endif -} - -void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq1_m * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - - iq1m_scale_t scale; - -#if defined __ARM_NEON - const int32x4_t mask = vdupq_n_s32(0x7); - const int32x4_t mone = vdupq_n_s32(1); - const int32x4_t mzero = vdupq_n_s32(0); - - ggml_int8x16x4_t deltas; - deltas.val[0] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(+1)); - deltas.val[1] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(+1)); - deltas.val[2] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(-1)); - deltas.val[3] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(-1)); - - ggml_int8x16x4_t q1b; - ggml_int8x16x4_t q8b; - - uint32_t aux32; - const uint8_t * aux8 = (const uint8_t *)&aux32; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint8_t * qh = x[i].qh; - const uint16_t * sc = (const uint16_t *)x[i].scales; - - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); - - int32x4_t sumi1 = mzero; - int32x4_t sumi2 = mzero; - - for (int ib = 0; ib < QK_K/32; ib += 2) { - - q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 4) & 0x700))))); - q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[1] << 8) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[1] << 4) & 0x700))))); - q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[2] << 8) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[2] << 4) & 0x700))))); - q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[3] << 8) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[3] << 4) & 0x700))))); - - q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - - const int32x4_t p1 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1])); - const int32x4_t p2 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3])); - const int32x4_t p12 = vpaddq_s32(p1, p2); - - const uint32_t * qh32 = (const uint32_t *)qh; // we are 4-byte aligned, so we can do that - aux32 = ((qh32[0] >> 3) & 0x01010101) | ((qh32[0] >> 6) & 0x02020202); - - const int32x4_t p3 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[0]], q8b.val[0]), ggml_vdotq_s32(mzero, deltas.val[aux8[1]], q8b.val[1])); - const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3])); - const int32x4_t p34 = vpaddq_s32(p3, p4); - - int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9); - - scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone); - - sumi1 = vmlaq_s32(sumi1, scales_4, p12); - sumi2 = vmlaq_s32(sumi2, scales_4, p34); - - qs += 8; qh += 4; - - } - - sumf += y[i].d * GGML_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2)); - } - - *s = sumf; - -#elif defined __AVX2__ - - const __m256i mask = _mm256_set1_epi16(0x7); - const __m256i mone = _mm256_set1_epi16(1); - - __m256 accum1 = _mm256_setzero_ps(); - __m256 accum2 = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint8_t * qh = x[i].qh; - const uint16_t * sc = (const uint16_t *)x[i].scales; - - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); - - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m256i q1b_1 = _mm256_set_epi64x( - iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)], - iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)] - ); - const __m256i q1b_2 = _mm256_set_epi64x( - iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)], - iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)] - ); - const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); - const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); - - const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101, - qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); - const __m256i delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101, - qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); - - const __m256i dot3 = mul_add_epi8(delta1, q8b_1); - const __m256i dot4 = mul_add_epi8(delta2, q8b_2); - - __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0)); - __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6)); - - scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone); - scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone); - const __m256i p1 = _mm256_madd_epi16(dot1, scale1); - const __m256i p2 = _mm256_madd_epi16(dot2, scale2); - const __m256i p3 = _mm256_madd_epi16(dot3, scale1); - const __m256i p4 = _mm256_madd_epi16(dot4, scale2); - - sumi1 = _mm256_add_epi32(sumi1, _mm256_add_epi32(p1, p2)); - sumi2 = _mm256_add_epi32(sumi2, _mm256_add_epi32(p3, p4)); - - qs += 8; qh += 4; - } - - const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16)); - - accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1); - accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2); - } - - *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2); - -#elif defined __AVX__ - const __m128i mask = _mm_set1_epi16(0x7); - const __m128i mone = _mm_set1_epi16(1); - - __m256 accum1 = _mm256_setzero_ps(); - __m256 accum2 = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint8_t * qh = x[i].qh; - const uint16_t * sc = (const uint16_t *)x[i].scales; - - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); - - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m128i q1b_1_0 = _mm_set_epi64x( - iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]); - const __m128i q1b_1_1 = _mm_set_epi64x( - iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)]); - const __m128i q1b_2_0 = _mm_set_epi64x( - iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]); - const __m128i q1b_2_1 = _mm_set_epi64x( - iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)]); - const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - - const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0); - const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1); - const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0); - const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1); - - const __m128i delta1_0 = _mm_set_epi64x(qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); - const __m128i delta1_1 = _mm_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); - const __m128i delta2_0 = _mm_set_epi64x(qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); - const __m128i delta2_1 = _mm_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); - - const __m128i dot3_0 = mul_add_epi8_sse(delta1_0, q8b_1_0); - const __m128i dot3_1 = mul_add_epi8_sse(delta1_1, q8b_1_1); - const __m128i dot4_0 = mul_add_epi8_sse(delta2_0, q8b_2_0); - const __m128i dot4_1 = mul_add_epi8_sse(delta2_1, q8b_2_1); - - __m128i scale1_0 = _mm_set1_epi16(sc[ib/2] >> 0); - __m128i scale1_1 = _mm_set1_epi16(sc[ib/2] >> 3); - __m128i scale2_0 = _mm_set1_epi16(sc[ib/2] >> 6); - __m128i scale2_1 = _mm_set1_epi16(sc[ib/2] >> 9); - - scale1_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_0, mask), 1), mone); - scale1_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_1, mask), 1), mone); - scale2_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_0, mask), 1), mone); - scale2_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_1, mask), 1), mone); - const __m128i p1_0 = _mm_madd_epi16(dot1_0, scale1_0); - const __m128i p1_1 = _mm_madd_epi16(dot1_1, scale1_1); - const __m128i p2_0 = _mm_madd_epi16(dot2_0, scale2_0); - const __m128i p2_1 = _mm_madd_epi16(dot2_1, scale2_1); - const __m128i p3_0 = _mm_madd_epi16(dot3_0, scale1_0); - const __m128i p3_1 = _mm_madd_epi16(dot3_1, scale1_1); - const __m128i p4_0 = _mm_madd_epi16(dot4_0, scale2_0); - const __m128i p4_1 = _mm_madd_epi16(dot4_1, scale2_1); - - sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0)); - sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1)); - sumi2_0 = _mm_add_epi32(sumi2_0, _mm_add_epi32(p3_0, p4_0)); - sumi2_1 = _mm_add_epi32(sumi2_1, _mm_add_epi32(p3_1, p4_1)); - - qs += 8; qh += 4; - } - - const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16)); - - accum1 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum1); - accum2 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi2_1, sumi2_0))), accum2); - } - - *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2); - -#else - - int sum1[2], sum2[2], delta[4]; - - float sumf = 0; - for (int i = 0; i < nb; i++) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint8_t * qh = x[i].qh; - const uint16_t * sc = (const uint16_t *)x[i].scales; - - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); - - int sumi1 = 0, sumi2 = 0; - for (int ib = 0; ib < QK_K/32; ++ib) { - delta[0] = qh[0] & 0x08 ? -1 : 1; - delta[1] = qh[0] & 0x80 ? -1 : 1; - delta[2] = qh[1] & 0x08 ? -1 : 1; - delta[3] = qh[1] & 0x80 ? -1 : 1; - sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0; - for (int l = 0; l < 4; ++l) { - const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700))); - int lsum1 = 0, lsum2 = 0; - for (int j = 0; j < 8; ++j) { - lsum1 += q8[j] * grid[j]; - lsum2 += q8[j]; - } - q8 += 8; - sum1[l/2] += lsum1; - sum2[l/2] += lsum2*delta[l]; - } - - const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1; - const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1; - - sumi1 += sum1[0] * ls1 + sum1[1] * ls2; - sumi2 += sum2[0] * ls1 + sum2[1] * ls2; - qs += 4; - qh += 2; - } - - sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2); - } - - *s = sumf; - -#endif -} - -void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - assert(n % QK4_NL == 0); - static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); - - const block_iq4_nl * restrict x = vx; - const block_q8_0 * restrict y = vy; - - const int nb = n / QK4_NL; - - int ib = 0; - float sumf = 0; - -#if defined __ARM_NEON - const int8x16_t values = vld1q_s8(kvalues_iq4nl); - const uint8x16_t m4b = vdupq_n_u8(0x0f); - uint8x16x2_t q4bits; - int8x16x4_t q4b; - int8x16x4_t q8b; - int32x4_t prod_1, prod_2; - - for (; ib + 1 < nb; ib += 2) { - - q4bits.val[0] = vld1q_u8(x[ib + 0].qs); - q4bits.val[1] = vld1q_u8(x[ib + 1].qs); - q8b.val[0] = vld1q_s8(y[ib + 0].qs); - q8b.val[1] = vld1q_s8(y[ib + 0].qs + 16); - q8b.val[2] = vld1q_s8(y[ib + 1].qs); - q8b.val[3] = vld1q_s8(y[ib + 1].qs + 16); - - q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); - q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); - q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); - q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); - - prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); - prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); - - sumf += - GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) + - GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2); - } - -#elif defined __AVX2__ - - const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); - const __m128i m4b = _mm_set1_epi8(0x0f); - const __m256i mone = _mm256_set1_epi16(1); - - __m256 accum1 = _mm256_setzero_ps(); - __m256 accum2 = _mm256_setzero_ps(); - for (; ib + 1 < nb; ib += 2) { - const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs); - const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs); - const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs); - const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs); - const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), - _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); - const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), - _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b))); - const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); - const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); - const __m256i p_1 = _mm256_madd_epi16(p16_1, mone); - const __m256i p_2 = _mm256_madd_epi16(p16_2, mone); - accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)), - _mm256_cvtepi32_ps(p_1), accum1); - accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)), - _mm256_cvtepi32_ps(p_2), accum2); - } - - sumf = hsum_float_8(_mm256_add_ps(accum1, accum2)); - -#elif defined __AVX__ - const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); - const __m128i m4b = _mm_set1_epi8(0x0f); - const __m128i mone = _mm_set1_epi16(1); - - __m256 accum1 = _mm256_setzero_ps(); - __m256 accum2 = _mm256_setzero_ps(); - for (; ib + 1 < nb; ib += 2) { - const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs); - const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); - const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs); - const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1); - const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); - const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1); - - const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)); - const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)); - const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)); - const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)); - const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0); - const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1); - const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0); - const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1); - const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone); - const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone); - const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone); - const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone); - accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)), - _mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1); - accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)), - _mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2); - } - - sumf = hsum_float_8(_mm256_add_ps(accum1, accum2)); - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector signed int v0 = vec_splats((int32_t)0); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - - const vector signed char values = vec_xl( 0, kvalues_iq4nl); - -#pragma GCC unroll 4 - for (; ib < nb; ++ib) { - __builtin_prefetch(x[ib].qs, 0, 1); - __builtin_prefetch(y[ib].qs, 0, 1); - - - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); - vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); - vector float vd = vec_mul(vxd, vyd); - - vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); - vector signed char q4x0 = vec_and(qxs, lowMask); - vector signed char q4x1 = vec_sr(qxs, v4); - - q4x0 = vec_perm(values, values, (vector unsigned char)q4x0); - q4x1 = vec_perm(values, values, (vector unsigned char)q4x1); - - vector signed char q8y0 = vec_xl( 0, y[ib].qs); - vector signed char q8y1 = vec_xl(16, y[ib].qs); - - vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - - vsumi0 = vec_sum4s(qv0, vsumi0); - vsumi1 = vec_sum4s(qv1, vsumi1); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - } - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - sumf = vec_extract(vsumf0, 0); - -#elif defined (__loongarch_asx) - - const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0); - const __m128i m4b = __lsx_vreplgr2vr_b(0x0f); - const __m256i mone = __lasx_xvreplgr2vr_h(1); - - __m256 accum1 = (__m256)__lasx_xvldi(0); - __m256 accum2 = (__m256)__lasx_xvldi(0); - for (; ib + 1 < nb; ib += 2) { - const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[ib + 0].qs, 0); - const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[ib + 1].qs, 0); - const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[ib + 0].qs, 0); - const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[ib + 1].qs, 0); - const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)), - lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b))); - const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)), - lsx_shuffle_b(values128, __lsx_vand_v(q4bits_2, m4b))); - const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); - const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); - const __m256i p_1 = lasx_madd_h(p16_1, mone); - const __m256i p_2 = lasx_madd_h(p16_2, mone); - accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)), - __lasx_xvffint_s_w(p_1), accum1); - accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)), - __lasx_xvffint_s_w(p_2), accum2); - } - - sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2)); - -#endif - for (; ib < nb; ++ib) { - const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d); - int sumi1 = 0, sumi2 = 0; - for (int j = 0; j < QK4_NL/2; ++j) { - sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf]; - sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4]; - } - sumf += d * (sumi1 + sumi2); - } - *s = sumf; -} - -void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - assert(n % QK_K == 0); - - const block_iq4_xs * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#if defined __ARM_NEON - const int8x16_t values = vld1q_s8(kvalues_iq4nl); - const uint8x16_t m4b = vdupq_n_u8(0x0f); - ggml_uint8x16x2_t q4bits; - ggml_int8x16x4_t q4b; - ggml_int8x16x4_t q8b; - int32x4_t prod_1, prod_2; - - float sumf = 0; - - for (int ibl = 0; ibl < nb; ++ibl) { - - const int8_t * q8 = y[ibl].qs; - const uint8_t * q4 = x[ibl].qs; - uint16_t h = x[ibl].scales_h; - - int sumi1 = 0, sumi2 = 0; - for (int ib = 0; ib < QK_K/64; ++ib) { - - q4bits = ggml_vld1q_u8_x2(q4); q4 += 32; - q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - - q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); - q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); - q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); - q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); - - prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); - prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); - - int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32; - int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; - h >>= 4; - sumi1 += vaddvq_s32(prod_1) * ls1; - sumi2 += vaddvq_s32(prod_2) * ls2; - - } - - sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); - } - - *s = sumf; - -#elif defined __AVX2__ - - const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); - const __m128i m4b = _mm_set1_epi8(0x0f); - - __m256 accum = _mm256_setzero_ps(); - for (int ibl = 0; ibl < nb; ++ibl) { - const uint8_t * qs = x[ibl].qs; - const int8_t * q8 = y[ibl].qs; - uint16_t sh = x[ibl].scales_h; - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)qs); qs += 16; - const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs); qs += 16; - const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), - _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); - const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), - _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b))); - const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); - const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); - const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; - const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; - sh >>= 4; - const __m256i p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1)); - const __m256i p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2)); - sumi1 = _mm256_add_epi32(p_1, sumi1); - sumi2 = _mm256_add_epi32(p_2, sumi2); - } - accum = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), - _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accum); - } - - *s = hsum_float_8(accum); - -#elif defined __AVX__ - const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); - const __m128i m4b = _mm_set1_epi8(0x0f); - - __m256 accum = _mm256_setzero_ps(); - for (int ibl = 0; ibl < nb; ++ibl) { - const uint8_t * qs = x[ibl].qs; - const int8_t * q8 = y[ibl].qs; - uint16_t sh = x[ibl].scales_h; - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)qs); qs += 16; - const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)qs); qs += 16; - const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)); - const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)); - const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)); - const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)); - const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0); - const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1); - const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0); - const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1); - const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; - const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; - sh >>= 4; - const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, _mm_set1_epi16(ls1)); - const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, _mm_set1_epi16(ls1)); - const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, _mm_set1_epi16(ls2)); - const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, _mm_set1_epi16(ls2)); - sumi1_0 = _mm_add_epi32(p_1_0, sumi1_0); - sumi1_1 = _mm_add_epi32(p_1_1, sumi1_1); - sumi2_0 = _mm_add_epi32(p_2_0, sumi2_0); - sumi2_1 = _mm_add_epi32(p_2_1, sumi2_1); - } - __m128i sumi12_0 = _mm_add_epi32(sumi1_0, sumi2_0); - __m128i sumi12_1 = _mm_add_epi32(sumi1_1, sumi2_1); - accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), - _mm256_cvtepi32_ps(MM256_SET_M128I(sumi12_1, sumi12_0))), accum); - } - - *s = hsum_float_8(accum); - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector int v0 = vec_splats((int32_t)0); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - const vector signed char values = vec_xl( 0, kvalues_iq4nl); - - for (int ibl = 0; ibl < nb; ++ibl) { - - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ibl].d)); - vector float vyd = vec_splats(y[ibl].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - uint16_t h = x[ibl].scales_h; - - const uint8_t * restrict q4 = x[ibl].qs; - const uint8_t * restrict sc = x[ibl].scales_l; - const int8_t * restrict q8 = y[ibl].qs; - - for (int ib = 0; ib < QK_K/64; ib ++ ) { - __builtin_prefetch(q4, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, q4); - vector signed char qxs1 = (vector signed char)vec_xl(16, q4); - q4 += 32; - - vector signed char q4x00 = (vector signed char)vec_and(qxs0, lowMask); - vector signed char q4x01 = (vector signed char)vec_sr(qxs0, v4); - vector signed char q4x10 = (vector signed char)vec_and(qxs1, lowMask); - vector signed char q4x11 = (vector signed char)vec_sr(qxs1, v4); - - q4x00 = vec_perm(values, values, (vector unsigned char)q4x00); - q4x01 = vec_perm(values, values, (vector unsigned char)q4x01); - q4x10 = vec_perm(values, values, (vector unsigned char)q4x10); - q4x11 = vec_perm(values, values, (vector unsigned char)q4x11); - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q4x00, q8y0), vec_mulo(q4x00, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q4x01, q8y1), vec_mulo(q4x01, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q4x10, q8y2), vec_mulo(q4x10, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q4x11, q8y3), vec_mulo(q4x11, q8y3)); - - const uint16_t ls0 = (uint16_t)(((sc[0] & 0xf) | ((h << 4) & 0x30)) - 32); - const uint16_t ls1 = (uint16_t)(((sc[0] >> 4) | ((h << 2) & 0x30)) - 32); - h >>= 4; - sc ++; - - vector signed short vscales01 = vec_splats((int16_t)ls0); - vector signed short vscales23 = vec_splats((int16_t)ls1); - - vsumi0 = vec_msum(qv0, vscales01, vsumi0); - vsumi1 = vec_msum(qv1, vscales01, vsumi1); - vsumi2 = vec_msum(qv2, vscales23, vsumi2); - vsumi3 = vec_msum(qv3, vscales23, vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - - const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0); - const __m128i m4b = __lsx_vreplgr2vr_b(0x0f); - - __m256 accum = (__m256)__lasx_xvldi(0); - __m256i tmp1; - __m128i tmp0, tmp2, tmp3, tmp4, mask_8f, mask; - - mask_8f = __lsx_vreplgr2vr_b(0x8f); - for (int ibl = 0; ibl < nb; ++ibl) { - const uint8_t * qs = x[ibl].qs; - const int8_t * q8 = y[ibl].qs; - uint16_t sh = x[ibl].scales_h; - __m256i sumi1 = __lasx_xvldi(0); - __m256i sumi2 = __lasx_xvldi(0); - __m128i zero = __lsx_vldi(0); - for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16; - const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16; - const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b), mask_8f); - tmp0 = __lsx_vori_b(tmp2, 0x10); - mask = __lsx_vsle_b(zero, tmp2); - tmp3 = __lsx_vand_v(tmp0, mask); - tmp3 = __lsx_vshuf_b(values128, zero, tmp3); - - tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_1, m4b), mask_8f); - tmp0 = __lsx_vori_b(tmp2, 0x10); - mask = __lsx_vsle_b(zero, tmp2); - tmp4 = __lsx_vand_v(tmp0, mask); - tmp4 = __lsx_vshuf_b(values128, zero, tmp4); - - const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4); - - tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f); - tmp0 = __lsx_vori_b(tmp2, 0x10); - mask = __lsx_vsle_b(zero, tmp2); - tmp3 = __lsx_vand_v(tmp0, mask); - tmp3 = __lsx_vshuf_b(values128, zero, tmp3); - - tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_2, m4b), mask_8f); - tmp0 = __lsx_vori_b(tmp2, 0x10); - mask = __lsx_vsle_b(zero, tmp2); - tmp4 = __lsx_vand_v(tmp0, mask); - tmp4 = __lsx_vshuf_b(values128, zero, tmp4); - - const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4); - - const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); - const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); - const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; - const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; - sh >>= 4; - __m256i tmp5, tmp6; - tmp1 = __lasx_xvreplgr2vr_h(ls1); - tmp5 = __lasx_xvmulwev_w_h(p16_1, tmp1); - tmp6 = __lasx_xvmulwod_w_h(p16_1, tmp1); - const __m256i p_1 = __lasx_xvadd_w(tmp5, tmp6); - tmp1 = __lasx_xvreplgr2vr_h(ls2); - tmp5 = __lasx_xvmulwev_w_h(p16_2, tmp1); - tmp6 = __lasx_xvmulwod_w_h(p16_2, tmp1); - const __m256i p_2 = __lasx_xvadd_w(tmp5, tmp6); - sumi1 = __lasx_xvadd_w(p_1, sumi1); - sumi2 = __lasx_xvadd_w(p_2, sumi2); - } - accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), - __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accum); - } - - *s = hsum_float_8(accum); - -#else - float sumf = 0; - for (int ibl = 0; ibl < nb; ++ibl) { - const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d; - uint16_t h = x[ibl].scales_h; - const uint8_t * qs = x[ibl].qs; - const int8_t * q8 = y[ibl].qs; - for (int ib = 0; ib < QK_K/32; ib += 2) { - const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30); - const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30); - h >>= 4; - const float d1 = d4d8*(ls1 - 32); - const float d2 = d4d8*(ls2 - 32); - int sumi1 = 0, sumi2 = 0; - for (int j = 0; j < 16; ++j) { - sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; - sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; - } - sumf += d1 * (sumi1 + sumi2); - qs += 16; - q8 += 32; - sumi1 = sumi2 = 0; - for (int j = 0; j < 16; ++j) { - sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; - sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; - } - sumf += d2 * (sumi1 + sumi2); - qs += 16; - q8 += 32; - } - } - *s = sumf; -#endif -} - // ================================ IQ2 quantization ============================================= typedef struct { @@ -14057,12 +3770,6 @@ size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t return nrow * nblock * sizeof(block_iq3_xxs); } -void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK_K == 0); - block_iq3_xxs * restrict y = vy; - quantize_row_iq3_xxs_ref(x, y, k); -} - void quantize_row_iq3_xxs_ref(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_row_iq3_xxs_impl(256, x, y, k, NULL); @@ -14273,12 +3980,6 @@ size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t n return nrow * nblock * sizeof(block_iq3_s); } -void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK_K == 0); - block_iq3_s * restrict y = vy; - quantize_row_iq3_s_ref(x, y, k); -} - void quantize_row_iq3_s_ref(const float * restrict x, block_iq3_s * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq3_s(x, y, 1, k, NULL); @@ -15002,7 +4703,8 @@ size_t quantize_iq4_nl(const float * restrict src, void * restrict dst, int64_t return nrow * nblock * sizeof(block_iq4_nl); } -void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int64_t k) { +//void quantize_row_iq4_nl_ref(const float * restrict x, void * restrict vy, int64_t k) { +void quantize_row_iq4_nl_ref(const float * restrict x, block_iq4_nl * restrict y, int64_t k) { GGML_ASSERT(k%QK4_NL == 0); int64_t nblock = k/QK4_NL; uint8_t L[QK4_NL]; @@ -15010,18 +4712,13 @@ void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int64_t k uint16_t unused_h; uint8_t * unused_l = NULL; float scale; - block_iq4_nl * iq4 = (block_iq4_nl *)vy; + block_iq4_nl * iq4 = y; for (int ibl = 0; ibl < nblock; ++ibl) { quantize_row_iq4_nl_impl(QK4_NL, 32, x + QK4_NL*ibl, &iq4[ibl].d, iq4[ibl].qs, &unused_h, unused_l, &scale, weight, L, kvalues_iq4nl, NULL, -1); } } -void quantize_row_iq4_nl_ref(const float * restrict x, block_iq4_nl * restrict y, int64_t k) { - assert(k % QK4_NL == 0); - quantize_row_iq4_nl(x, y, k); -} - size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { GGML_ASSERT(n_per_row%QK_K == 0); int64_t nblock = n_per_row/QK_K; @@ -15042,12 +4739,6 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t return nrow * nblock * sizeof(block_iq4_xs); } -void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK_K == 0); - block_iq4_xs * restrict y = vy; - quantize_row_iq4_xs_ref(x, y, k); -} - void quantize_row_iq4_xs_ref(const float * restrict x, block_iq4_xs * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq4_xs(x, y, 1, k, NULL); @@ -15240,11 +4931,7 @@ void quantize_row_iq2_s_ref(const float * restrict x, block_iq2_s * restrict y, quantize_iq2_s(x, y, 1, k, NULL); } -void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK_K == 0); - block_iq2_s * restrict y = vy; - quantize_row_iq2_s_ref(x, y, k); -} +// =============================== data validation static bool validate_float(float f, size_t i) { if (isinf(f)) { @@ -15533,15 +5220,6 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); } break; - case GGML_TYPE_Q4_0_4_4: - case GGML_TYPE_Q4_0_4_8: - { - VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x4, data, nbytes / sizeof(block_q4_0x4), 4); - } break; - case GGML_TYPE_Q4_0_8_8: - { - VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x8, data, nbytes / sizeof(block_q4_0x8), 8); - } break; case GGML_TYPE_I8: case GGML_TYPE_I16: diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index e96ce2b5e..d09173e11 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -11,140 +11,89 @@ extern "C" { #endif +// NOTE: these functions are defined as GGML_API because they used by the CPU backend + // Quantization -void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); -void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); -void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); -void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k); -void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k); -void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); -void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k); -void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k); -void quantize_row_tq2_0_ref(const float * GGML_RESTRICT x, block_tq2_0 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_tq2_0_ref(const float * GGML_RESTRICT x, block_tq2_0 * GGML_RESTRICT y, int64_t k); -void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k); -void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k); -void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k); -void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); -void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); - -void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); - -void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); - -void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); - -void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); // Dequantization -void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -//void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +//GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_tq1_0(const block_tq1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_tq2_0(const block_tq2_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_tq1_0(const block_tq1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_tq2_0(const block_tq2_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_iq1_m (const block_iq1_m * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); - -// Dot product -void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); - -void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); - -void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); - -void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_iq2_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +GGML_API void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_iq3_xxs(const block_iq3_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_iq1_s (const block_iq1_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_iq1_m (const block_iq1_m * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") -size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_iq2_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_iq1_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_iq1_m (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_iq2_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_iq3_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_iq1_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_iq1_m (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_tq1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_tq2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_tq1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_tq2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -void iq2xs_init_impl(enum ggml_type type); -void iq2xs_free_impl(enum ggml_type type); -void iq3xs_init_impl(int grid_size); -void iq3xs_free_impl(int grid_size); - -#if defined(__ARM_FEATURE_SVE) -extern int ggml_sve_cnt_b; -#endif +GGML_API void iq2xs_init_impl(enum ggml_type type); +GGML_API void iq2xs_free_impl(enum ggml_type type); +GGML_API void iq3xs_init_impl(int grid_size); +GGML_API void iq3xs_free_impl(int grid_size); #ifdef __cplusplus } diff --git a/ggml/src/ggml-rpc/CMakeLists.txt b/ggml/src/ggml-rpc/CMakeLists.txt new file mode 100644 index 000000000..f5acb8ec2 --- /dev/null +++ b/ggml/src/ggml-rpc/CMakeLists.txt @@ -0,0 +1,9 @@ +message(STATUS "Using RPC backend") + +ggml_add_backend_library(ggml-rpc + ggml-rpc.cpp + ) + +if (WIN32) + target_link_libraries(ggml-rpc PRIVATE ws2_32) +endif() diff --git a/ggml/src/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp similarity index 54% rename from ggml/src/ggml-rpc.cpp rename to ggml/src/ggml-rpc/ggml-rpc.cpp index 8f9d0a460..3d0c46578 100644 --- a/ggml/src/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -1,5 +1,5 @@ #include "ggml-rpc.h" -#include "ggml.h" +#include "ggml-impl.h" #include "ggml-backend-impl.h" #include @@ -25,16 +25,7 @@ # include # include #endif -#include - -#define UNUSED GGML_UNUSED - -#define GGML_DEBUG 0 -#if (GGML_DEBUG >= 1) -#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG(...) -#endif +#include #ifdef _WIN32 typedef SOCKET sockfd_t; @@ -57,8 +48,9 @@ struct socket_t { } }; -// ggml_tensor is serialized into rpc_tensor +// all RPC structures must be packed #pragma pack(push, 1) +// ggml_tensor is serialized into rpc_tensor struct rpc_tensor { uint64_t id; uint32_t type; @@ -76,7 +68,6 @@ struct rpc_tensor { char padding[4]; }; -#pragma pack(pop) static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8"); @@ -93,9 +84,82 @@ enum rpc_cmd { RPC_CMD_COPY_TENSOR, RPC_CMD_GRAPH_COMPUTE, RPC_CMD_GET_DEVICE_MEMORY, + RPC_CMD_INIT_TENSOR, + RPC_CMD_GET_ALLOC_SIZE, RPC_CMD_COUNT, }; +struct rpc_msg_get_alloc_size_req { + rpc_tensor tensor; +}; + +struct rpc_msg_get_alloc_size_rsp { + uint64_t alloc_size; +}; + +struct rpc_msg_init_tensor_req { + rpc_tensor tensor; +}; + +struct rpc_msg_alloc_buffer_req { + uint64_t size; +}; + +struct rpc_msg_alloc_buffer_rsp { + uint64_t remote_ptr; + uint64_t remote_size; +}; + +struct rpc_msg_get_alignment_rsp { + uint64_t alignment; +}; + +struct rpc_msg_get_max_size_rsp { + uint64_t max_size; +}; + +struct rpc_msg_buffer_get_base_req { + uint64_t remote_ptr; +}; + +struct rpc_msg_buffer_get_base_rsp { + uint64_t base_ptr; +}; + +struct rpc_msg_free_buffer_req { + uint64_t remote_ptr; +}; + +struct rpc_msg_buffer_clear_req { + uint64_t remote_ptr; + uint8_t value; +}; + +struct rpc_msg_get_tensor_req { + rpc_tensor tensor; + uint64_t offset; + uint64_t size; +}; + +struct rpc_msg_copy_tensor_req { + rpc_tensor src; + rpc_tensor dst; +}; + +struct rpc_msg_copy_tensor_rsp { + uint8_t result; +}; + +struct rpc_msg_graph_compute_rsp { + uint8_t result; +}; + +struct rpc_msg_get_device_memory_rsp { + uint64_t free_mem; + uint64_t total_mem; +}; +#pragma pack(pop) + // RPC data structures static ggml_guid_t ggml_backend_rpc_guid() { @@ -117,9 +181,8 @@ struct ggml_backend_rpc_context { struct ggml_backend_rpc_buffer_context { std::shared_ptr sock; - std::unordered_map base_cache; + void * base_ptr; uint64_t remote_ptr; - std::string name; }; // RPC helper functions @@ -240,6 +303,38 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) { return true; } +static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) { + if (!send_data(sockfd, &msg_size, sizeof(msg_size))) { + return false; + } + return send_data(sockfd, msg, msg_size); +} + +static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) { + uint64_t size; + if (!recv_data(sockfd, &size, sizeof(size))) { + return false; + } + if (size != msg_size) { + return false; + } + return recv_data(sockfd, msg, msg_size); +} + +static bool recv_msg(sockfd_t sockfd, std::vector & input) { + uint64_t size; + if (!recv_data(sockfd, &size, sizeof(size))) { + return false; + } + try { + input.resize(size); + } catch (const std::bad_alloc & e) { + fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", size); + return false; + } + return recv_data(sockfd, input.data(), size); +} + static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) { size_t pos = endpoint.find(':'); if (pos == std::string::npos) { @@ -252,28 +347,27 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | // RPC response: | response_size (8 bytes) | response_data (response_size bytes) | -static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const std::vector & input, std::vector & output) { +static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) { uint8_t cmd_byte = cmd; if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) { return false; } - uint64_t input_size = input.size(); if (!send_data(sock->fd, &input_size, sizeof(input_size))) { return false; } - if (!send_data(sock->fd, input.data(), input.size())) { + if (!send_data(sock->fd, input, input_size)) { return false; } - uint64_t output_size; - if (!recv_data(sock->fd, &output_size, sizeof(output_size))) { + // TODO: currently the output_size is always known, do we need support for commands with variable output size? + // even if we do, we can skip sending output_size from the server for commands with known output size + uint64_t out_size; + if (!recv_data(sock->fd, &out_size, sizeof(out_size))) { return false; } - if (output_size == 0) { - output.clear(); - return true; + if (out_size != output_size) { + return false; } - output.resize(output_size); - if (!recv_data(sock->fd, output.data(), output_size)) { + if (!recv_data(sock->fd, output, output_size)) { return false; } return true; @@ -308,7 +402,7 @@ static std::shared_ptr get_socket(const std::string & endpoint) { initialized = true; } #else - UNUSED(initialized); + GGML_UNUSED(initialized); #endif auto sock = socket_connect(host.c_str(), port); if (sock == nullptr) { @@ -319,43 +413,25 @@ static std::shared_ptr get_socket(const std::string & endpoint) { return sock; } -GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) { +static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; - return ctx->name.c_str(); -} - -GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { - ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; - // input serialization format: | remote_ptr (8 bytes) | - std::vector input(sizeof(uint64_t), 0); - uint64_t remote_ptr = ctx->remote_ptr; - memcpy(input.data(), &remote_ptr, sizeof(remote_ptr)); - std::vector output; - bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, input, output); + rpc_msg_free_buffer_req request = {ctx->remote_ptr}; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0); GGML_ASSERT(status); - GGML_ASSERT(output.empty()); delete ctx; } -GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) { +static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) { ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; - if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) { - return ctx->base_cache[buffer]; + if (ctx->base_ptr != nullptr) { + return ctx->base_ptr; } - // input serialization format: | remote_ptr (8 bytes) | - std::vector input(sizeof(uint64_t), 0); - uint64_t remote_ptr = ctx->remote_ptr; - memcpy(input.data(), &remote_ptr, sizeof(remote_ptr)); - std::vector output; - bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, input, output); + rpc_msg_buffer_get_base_req request = {ctx->remote_ptr}; + rpc_msg_buffer_get_base_rsp response; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response)); GGML_ASSERT(status); - GGML_ASSERT(output.size() == sizeof(uint64_t)); - // output serialization format: | base_ptr (8 bytes) | - uint64_t base_ptr; - memcpy(&base_ptr, output.data(), sizeof(base_ptr)); - void * base = reinterpret_cast(base_ptr); - ctx->base_cache[buffer] = base; - return base; + ctx->base_ptr = reinterpret_cast(response.base_ptr); + return ctx->base_ptr; } static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { @@ -388,15 +464,23 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { return result; } -GGML_CALL static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { - UNUSED(buffer); - if (ggml_is_quantized(tensor->type)) { - // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized - GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor"); +static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + + // CUDA backend on the server pads everything to 512 due to CUDA limitations. + // Due to bandwidth constraints, we only call the server init tensor functions if necessary. + // In particular, only quantized tensors need padding + if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) { + rpc_msg_init_tensor_req request; + + request.tensor = serialize_tensor(tensor); + + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0); + GGML_ASSERT(status); } } -GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { +static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) | size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size; @@ -405,29 +489,21 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor)); memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size); - std::vector output; - bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input, output); + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size(), nullptr, 0); GGML_ASSERT(status); } -GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { +static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; - // input serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) | - int input_size = sizeof(rpc_tensor) + 2*sizeof(uint64_t); - std::vector input(input_size, 0); - rpc_tensor rpc_tensor = serialize_tensor(tensor); - memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor)); - memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); - memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size)); - std::vector output; - bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, input, output); + rpc_msg_get_tensor_req request; + request.tensor = serialize_tensor(tensor); + request.offset = offset; + request.size = size; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size); GGML_ASSERT(status); - GGML_ASSERT(output.size() == size); - // output serialization format: | data (size bytes) | - memcpy(data, output.data(), size); } -GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { +static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { // check if src and dst are on the same server ggml_backend_buffer_t src_buffer = src->buffer; ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context; @@ -437,38 +513,27 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b return false; } ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; - // input serialization format: | rpc_tensor src | rpc_tensor dst | - int input_size = 2*sizeof(rpc_tensor); - std::vector input(input_size, 0); - rpc_tensor rpc_src = serialize_tensor(src); - rpc_tensor rpc_dst = serialize_tensor(dst); - memcpy(input.data(), &rpc_src, sizeof(rpc_src)); - memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst)); - std::vector output; - bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, input, output); + rpc_msg_copy_tensor_req request; + request.src = serialize_tensor(src); + request.dst = serialize_tensor(dst); + rpc_msg_copy_tensor_rsp response; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response)); GGML_ASSERT(status); - // output serialization format: | result (1 byte) | - GGML_ASSERT(output.size() == 1); - return output[0]; + return response.result; } -GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { +static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; - // serialization format: | bufptr (8 bytes) | value (1 byte) | - int input_size = sizeof(uint64_t) + sizeof(uint8_t); - std::vector input(input_size, 0); - memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr)); - memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value)); - std::vector output; - bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input, output); + rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value}; + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0); GGML_ASSERT(status); } static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = { - /* .get_name = */ ggml_backend_rpc_buffer_get_name, /* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer, /* .get_base = */ ggml_backend_rpc_buffer_get_base, /* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor, + /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor, /* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor, /* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor, @@ -476,32 +541,23 @@ static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = { /* .reset = */ NULL, }; -GGML_CALL static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) { +static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) { ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; return buft_ctx->name.c_str(); } -GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { +static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; - // input serialization format: | size (8 bytes) | - int input_size = sizeof(uint64_t); - std::vector input(input_size, 0); - memcpy(input.data(), &size, sizeof(size)); - std::vector output; + rpc_msg_alloc_buffer_req request = {size}; + rpc_msg_alloc_buffer_rsp response; auto sock = get_socket(buft_ctx->endpoint); - bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output); + bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response)); GGML_ASSERT(status); - GGML_ASSERT(output.size() == 2*sizeof(uint64_t)); - // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) | - uint64_t remote_ptr; - memcpy(&remote_ptr, output.data(), sizeof(remote_ptr)); - size_t remote_size; - memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size)); - if (remote_ptr != 0) { + if (response.remote_ptr != 0) { ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, ggml_backend_rpc_buffer_interface, - new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"}, - remote_size); + new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr}, + response.remote_size); return buffer; } else { return nullptr; @@ -509,44 +565,47 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer } static size_t get_alignment(const std::shared_ptr & sock) { - // input serialization format: | 0 bytes | - std::vector input; - std::vector output; - bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, input, output); + rpc_msg_get_alignment_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response)); GGML_ASSERT(status); - GGML_ASSERT(output.size() == sizeof(uint64_t)); - // output serialization format: | alignment (8 bytes) | - uint64_t alignment; - memcpy(&alignment, output.data(), sizeof(alignment)); - return alignment; + return response.alignment; } -GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { +static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; return buft_ctx->alignment; } static size_t get_max_size(const std::shared_ptr & sock) { - // input serialization format: | 0 bytes | - std::vector input; - std::vector output; - bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, input, output); + rpc_msg_get_max_size_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response)); GGML_ASSERT(status); - GGML_ASSERT(output.size() == sizeof(uint64_t)); - // output serialization format: | max_size (8 bytes) | - uint64_t max_size; - memcpy(&max_size, output.data(), sizeof(max_size)); - return max_size; + return response.max_size; } -GGML_CALL static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) { +static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) { ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; return buft_ctx->max_size; } -GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { - UNUSED(buft); - return ggml_nbytes(tensor); +static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + // See comments in init_tensor. + if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) { + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + auto sock = get_socket(buft_ctx->endpoint); + + rpc_msg_get_alloc_size_req request; + + request.tensor = serialize_tensor(tensor); + + rpc_msg_get_alloc_size_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response)); + GGML_ASSERT(status); + + return response.alloc_size; + } else { + return ggml_nbytes(tensor); + } } static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { @@ -558,25 +617,20 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { /* .is_host = */ NULL, }; -GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) { +static const char * ggml_backend_rpc_name(ggml_backend_t backend) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; return rpc_ctx->name.c_str(); } -GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) { +static void ggml_backend_rpc_free(ggml_backend_t backend) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; delete rpc_ctx; delete backend; } -GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) { - ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context; - return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str()); -} - -GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) { - UNUSED(backend); +static void ggml_backend_rpc_synchronize(ggml_backend_t backend) { + GGML_UNUSED(backend); // this is no-op because we don't have any async operations } @@ -617,38 +671,20 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector & o memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor)); } -GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { +static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; std::vector input; serialize_graph(cgraph, input); - std::vector output; + rpc_msg_graph_compute_rsp response; auto sock = get_socket(rpc_ctx->endpoint); - bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input, output); + bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response)); GGML_ASSERT(status); - GGML_ASSERT(output.size() == 1); - return (enum ggml_status)output[0]; -} - -GGML_CALL static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) { - UNUSED(backend); - UNUSED(op); - //TODO: call the remote backend and cache the results - return true; -} - -GGML_CALL static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { - if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) { - return false; - } - ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; - ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; - return buft_ctx->endpoint == rpc_ctx->endpoint; + return (enum ggml_status)response.result; } static ggml_backend_i ggml_backend_rpc_interface = { /* .get_name = */ ggml_backend_rpc_name, /* .free = */ ggml_backend_rpc_free, - /* .get_default_buffer_type = */ ggml_backend_rpc_get_default_buffer_type, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, /* .cpy_tensor_async = */ NULL, @@ -658,17 +694,11 @@ static ggml_backend_i ggml_backend_rpc_interface = { /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_rpc_graph_compute, - /* .supports_op = */ ggml_backend_rpc_supports_op, - /* .supports_buft = */ ggml_backend_rpc_supports_buft, - /* .offload_op = */ NULL, - /* .event_new = */ NULL, - /* .event_free = */ NULL, /* .event_record = */ NULL, /* .event_wait = */ NULL, - /* .event_synchronize = */ NULL, }; -GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { +ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { static std::mutex mutex; std::lock_guard lock(mutex); // NOTE: buffer types are allocated and never freed; this is by design @@ -693,13 +723,14 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type { /* .iface = */ ggml_backend_rpc_buffer_type_interface, + /* .device = */ ggml_backend_rpc_add_device(endpoint), /* .context = */ buft_ctx }; buft_map[endpoint] = buft; return buft; } -GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { +ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { /* .endpoint = */ endpoint, /* .name = */ "RPC[" + std::string(endpoint) + "]", @@ -708,32 +739,25 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { ggml_backend_t backend = new ggml_backend { /* .guid = */ ggml_backend_rpc_guid(), /* .interface = */ ggml_backend_rpc_interface, + /* .device = */ ggml_backend_rpc_add_device(endpoint), /* .context = */ ctx }; return backend; } -GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) { +bool ggml_backend_is_rpc(ggml_backend_t backend) { return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid()); } static void get_device_memory(const std::shared_ptr & sock, size_t * free, size_t * total) { - // input serialization format: | 0 bytes | - std::vector input; - std::vector output; - bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, input, output); + rpc_msg_get_device_memory_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response)); GGML_ASSERT(status); - GGML_ASSERT(output.size() == 2*sizeof(uint64_t)); - // output serialization format: | free (8 bytes) | total (8 bytes) | - uint64_t free_mem; - memcpy(&free_mem, output.data(), sizeof(free_mem)); - uint64_t total_mem; - memcpy(&total_mem, output.data() + sizeof(uint64_t), sizeof(total_mem)); - *free = free_mem; - *total = total_mem; + *free = response.free_mem; + *total = response.total_mem; } -GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) { +void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) { auto sock = get_socket(endpoint); if (sock == nullptr) { *free = 0; @@ -750,16 +774,18 @@ public: rpc_server(ggml_backend_t backend) : backend(backend) {} ~rpc_server(); - bool alloc_buffer(const std::vector & input, std::vector & output); - void get_alignment(std::vector & output); - void get_max_size(std::vector & output); - bool buffer_get_base(const std::vector & input, std::vector & output); - bool free_buffer(const std::vector & input); - bool buffer_clear(const std::vector & input); + void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response); + void get_alignment(rpc_msg_get_alignment_rsp & response); + void get_max_size(rpc_msg_get_max_size_rsp & response); + bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response); + bool free_buffer(const rpc_msg_free_buffer_req & request); + bool buffer_clear(const rpc_msg_buffer_clear_req & request); bool set_tensor(const std::vector & input); - bool get_tensor(const std::vector & input, std::vector & output); - bool copy_tensor(const std::vector & input, std::vector & output); - bool graph_compute(const std::vector & input, std::vector & output); + bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector & response); + bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response); + bool graph_compute(const std::vector & input, rpc_msg_graph_compute_rsp & response); + bool init_tensor(const rpc_msg_init_tensor_req & request); + bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response); private: ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor); @@ -773,82 +799,82 @@ private: std::unordered_set buffers; }; -bool rpc_server::alloc_buffer(const std::vector & input, std::vector & output) { - // input serialization format: | size (8 bytes) | - if (input.size() != sizeof(uint64_t)) { +bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) { + ggml_backend_buffer_type_t buft; + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx = ggml_init(params); + ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); + + if (tensor == nullptr) { + GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n"); + ggml_free(ctx); return false; } - uint64_t size; - memcpy(&size, input.data(), sizeof(size)); - ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); - ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size); - uint64_t remote_ptr = 0; - uint64_t remote_size = 0; - if (buffer != nullptr) { - remote_ptr = reinterpret_cast(buffer); - remote_size = buffer->size; - GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size); - buffers.insert(buffer); + + if (tensor->buffer == nullptr) { + //No buffer allocated. + buft = ggml_backend_get_default_buffer_type(backend); } else { - GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, size); + buft = tensor->buffer->buft; } - // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) | - output.resize(2*sizeof(uint64_t), 0); - memcpy(output.data(), &remote_ptr, sizeof(remote_ptr)); - memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size)); + + response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor); + + ggml_free(ctx); return true; } -void rpc_server::get_alignment(std::vector & output) { +void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) { + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); + ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size); + response.remote_ptr = 0; + response.remote_size = 0; + if (buffer != nullptr) { + response.remote_ptr = reinterpret_cast(buffer); + response.remote_size = buffer->size; + GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size); + buffers.insert(buffer); + } else { + GGML_LOG_ERROR("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size); + } +} + +void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) { ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); size_t alignment = ggml_backend_buft_get_alignment(buft); GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment); - // output serialization format: | alignment (8 bytes) | - output.resize(sizeof(uint64_t), 0); - memcpy(output.data(), &alignment, sizeof(alignment)); + response.alignment = alignment; } -void rpc_server::get_max_size(std::vector & output) { +void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) { ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); size_t max_size = ggml_backend_buft_get_max_size(buft); GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size); - // output serialization format: | max_size (8 bytes) | - output.resize(sizeof(uint64_t), 0); - memcpy(output.data(), &max_size, sizeof(max_size)); + response.max_size = max_size; } -bool rpc_server::buffer_get_base(const std::vector & input, std::vector & output) { - // input serialization format: | remote_ptr (8 bytes) | - if (input.size() != sizeof(uint64_t)) { - return false; - } - uint64_t remote_ptr; - memcpy(&remote_ptr, input.data(), sizeof(remote_ptr)); - GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr); - ggml_backend_buffer_t buffer = reinterpret_cast(remote_ptr); +bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) { + GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr); + ggml_backend_buffer_t buffer = reinterpret_cast(request.remote_ptr); if (buffers.find(buffer) == buffers.end()) { - GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__); + GGML_LOG_ERROR("[%s] buffer not found\n", __func__); return false; } void * base = ggml_backend_buffer_get_base(buffer); - // output serialization format: | base_ptr (8 bytes) | - uint64_t base_ptr = reinterpret_cast(base); - output.resize(sizeof(uint64_t), 0); - memcpy(output.data(), &base_ptr, sizeof(base_ptr)); + response.base_ptr = reinterpret_cast(base); return true; } -bool rpc_server::free_buffer(const std::vector & input) { - // input serialization format: | remote_ptr (8 bytes) | - if (input.size() != sizeof(uint64_t)) { - return false; - } - uint64_t remote_ptr; - memcpy(&remote_ptr, input.data(), sizeof(remote_ptr)); - GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr); - ggml_backend_buffer_t buffer = reinterpret_cast(remote_ptr); +bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) { + GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr); + ggml_backend_buffer_t buffer = reinterpret_cast(request.remote_ptr); if (buffers.find(buffer) == buffers.end()) { - GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__); + GGML_LOG_ERROR("[%s] buffer not found\n", __func__); return false; } ggml_backend_buffer_free(buffer); @@ -856,22 +882,14 @@ bool rpc_server::free_buffer(const std::vector & input) { return true; } -bool rpc_server::buffer_clear(const std::vector & input) { - // input serialization format: | remote_ptr (8 bytes) | value (1 byte) | - if (input.size() != sizeof(uint64_t) + sizeof(uint8_t)) { - return false; - } - uint64_t remote_ptr; - memcpy(&remote_ptr, input.data(), sizeof(remote_ptr)); - uint8_t value; - memcpy(&value, input.data() + sizeof(uint64_t), sizeof(value)); - GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, remote_ptr, value); - ggml_backend_buffer_t buffer = reinterpret_cast(remote_ptr); +bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) { + GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value); + ggml_backend_buffer_t buffer = reinterpret_cast(request.remote_ptr); if (buffers.find(buffer) == buffers.end()) { - GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__); + GGML_LOG_ERROR("[%s] buffer not found\n", __func__); return false; } - ggml_backend_buffer_clear(buffer, value); + ggml_backend_buffer_clear(buffer, request.value); return true; } @@ -883,15 +901,17 @@ ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rp } result->buffer = reinterpret_cast(tensor->buffer); if (result->buffer && buffers.find(result->buffer) == buffers.end()) { - return nullptr; + result->buffer = nullptr; } - // require that the tensor data does not go beyond the buffer end - uint64_t tensor_size = (uint64_t) ggml_nbytes(result); - uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer); - uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer); - GGML_ASSERT(tensor->data + tensor_size >= tensor->data); // check for overflow - GGML_ASSERT(tensor->data >= buffer_start && tensor->data + tensor_size <= buffer_start + buffer_size); + if (result->buffer) { + // require that the tensor data does not go beyond the buffer end + uint64_t tensor_size = (uint64_t) ggml_nbytes(result); + uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer); + uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer); + GGML_ASSERT(tensor->data + tensor_size >= tensor->data); // check for overflow + GGML_ASSERT(tensor->data >= buffer_start && tensor->data + tensor_size <= buffer_start + buffer_size); + } result->op = (ggml_op) tensor->op; for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) { @@ -922,7 +942,7 @@ bool rpc_server::set_tensor(const std::vector & input) { struct ggml_context * ctx = ggml_init(params); ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); if (tensor == nullptr) { - GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__); + GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); ggml_free(ctx); return false; } @@ -944,74 +964,89 @@ bool rpc_server::set_tensor(const std::vector & input) { return true; } -bool rpc_server::get_tensor(const std::vector & input, std::vector & output) { - // serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) | - if (input.size() != sizeof(rpc_tensor) + 2*sizeof(uint64_t)) { - return false; - } - const rpc_tensor * in_tensor = (const rpc_tensor *)input.data(); - uint64_t offset; - memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset)); - uint64_t size; - memcpy(&size, input.data() + sizeof(rpc_tensor) + sizeof(offset), sizeof(size)); - +bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) { struct ggml_init_params params { /*.mem_size =*/ ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; struct ggml_context * ctx = ggml_init(params); - ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); + ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); if (tensor == nullptr) { - GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__); + GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n"); ggml_free(ctx); return false; } - GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size); + + // Call the backend's buffer_init_tensor function + ggml_backend_buffer_t buffer = tensor->buffer; + if (buffer && buffer->iface.init_tensor) { + buffer->iface.init_tensor(buffer, tensor); + } else { + GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n"); + } + + if (tensor->extra != nullptr) { + // This pointer can either be passed around client/server, or probably better stored server-side and kept track of. + // Currently unimplemented. + GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n"); + ggml_free(ctx); + return false; + } + + ggml_free(ctx); + return true; +} + +bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector & response) { + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + struct ggml_context * ctx = ggml_init(params); + ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); + if (tensor == nullptr) { + GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); + ggml_free(ctx); + return false; + } + GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size); // sanitize tensor->data { const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer); const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer); - if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) { - GGML_ABORT("[%s] tensor->data out of bounds\n", __func__); + if (request.tensor.data + request.offset < p0 || + request.tensor.data + request.offset >= p1 || + request.size > (p1 - request.tensor.data - request.offset)) { + GGML_ABORT("[%s] tensor->data out of bounds\n", __func__); } } - // output serialization format: | data (size bytes) | - output.resize(size, 0); - ggml_backend_tensor_get(tensor, output.data(), offset, size); + response.resize(request.size, 0); + ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size); ggml_free(ctx); return true; } -bool rpc_server::copy_tensor(const std::vector & input, std::vector & output) { - // serialization format: | rpc_tensor src | rpc_tensor dst | - if (input.size() != 2*sizeof(rpc_tensor)) { - return false; - } - const rpc_tensor * rpc_src = (const rpc_tensor *)input.data(); - const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_src)); - +bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) { struct ggml_init_params params { /*.mem_size =*/ 2*ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; struct ggml_context * ctx = ggml_init(params); - ggml_tensor * src = deserialize_tensor(ctx, rpc_src); - ggml_tensor * dst = deserialize_tensor(ctx, rpc_dst); + ggml_tensor * src = deserialize_tensor(ctx, &request.src); + ggml_tensor * dst = deserialize_tensor(ctx, &request.dst); if (src == nullptr || dst == nullptr) { - GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__); + GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__); ggml_free(ctx); return false; } GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer); - bool result = ggml_backend_buffer_copy_tensor(src, dst); - // output serialization format: | result (1 byte) | - output.resize(1, 0); - output[0] = result; + response.result = ggml_backend_buffer_copy_tensor(src, dst); ggml_free(ctx); return true; } @@ -1040,7 +1075,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id, return result; } -bool rpc_server::graph_compute(const std::vector & input, std::vector & output) { +bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph_compute_rsp & response) { // serialization format: // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | if (input.size() < sizeof(uint32_t)) { @@ -1060,7 +1095,7 @@ bool rpc_server::graph_compute(const std::vector & input, std::vector & input, std::vectornodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map); } ggml_status status = ggml_backend_graph_compute(backend, graph); - // output serialization format: | status (1 byte) | - output.resize(1, 0); - output[0] = status; + response.result = status; ggml_free(ctx); return true; } @@ -1105,89 +1138,182 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre fprintf(stderr, "Unknown command: %d\n", cmd); break; } - std::vector input; - std::vector output; - uint64_t input_size; - if (!recv_data(sockfd, &input_size, sizeof(input_size))) { - break; - } - try { - input.resize(input_size); - } catch (const std::bad_alloc & e) { - fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size); - break; - } - if (!recv_data(sockfd, input.data(), input_size)) { - break; - } - bool ok = true; switch (cmd) { case RPC_CMD_ALLOC_BUFFER: { - ok = server.alloc_buffer(input, output); + rpc_msg_alloc_buffer_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + rpc_msg_alloc_buffer_rsp response; + server.alloc_buffer(request, response); + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } + break; + } + case RPC_CMD_GET_ALLOC_SIZE: { + rpc_msg_get_alloc_size_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + rpc_msg_get_alloc_size_rsp response; + server.get_alloc_size(request, response); + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } break; } case RPC_CMD_GET_ALIGNMENT: { - server.get_alignment(output); + if (!recv_msg(sockfd, nullptr, 0)) { + return; + } + rpc_msg_get_alignment_rsp response; + server.get_alignment(response); + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } break; } case RPC_CMD_GET_MAX_SIZE: { - server.get_max_size(output); + if (!recv_msg(sockfd, nullptr, 0)) { + return; + } + rpc_msg_get_max_size_rsp response; + server.get_max_size(response); + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } break; } case RPC_CMD_BUFFER_GET_BASE: { - ok = server.buffer_get_base(input, output); + rpc_msg_buffer_get_base_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + rpc_msg_buffer_get_base_rsp response; + if (!server.buffer_get_base(request, response)) { + return; + } + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } break; } case RPC_CMD_FREE_BUFFER: { - ok = server.free_buffer(input); + rpc_msg_free_buffer_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + if (!server.free_buffer(request)) { + return; + } + if (!send_msg(sockfd, nullptr, 0)) { + return; + } break; } case RPC_CMD_BUFFER_CLEAR: { - ok = server.buffer_clear(input); + rpc_msg_buffer_clear_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + if (!server.buffer_clear(request)) { + return; + } + if (!send_msg(sockfd, nullptr, 0)) { + return; + } break; } case RPC_CMD_SET_TENSOR: { - ok = server.set_tensor(input); + std::vector input; + if (!recv_msg(sockfd, input)) { + return; + } + if (!server.set_tensor(input)) { + return; + } + if (!send_msg(sockfd, nullptr, 0)) { + return; + } + break; + } + case RPC_CMD_INIT_TENSOR: { + rpc_msg_init_tensor_req request; + if (!recv_msg(sockfd, &request,sizeof(request))) { + return; + } + if (!server.init_tensor(request)) { + return; + } + if (!send_msg(sockfd, nullptr, 0)) { + return; + } break; } case RPC_CMD_GET_TENSOR: { - ok = server.get_tensor(input, output); + rpc_msg_get_tensor_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + std::vector response; + if (!server.get_tensor(request, response)) { + return; + } + if (!send_msg(sockfd, response.data(), response.size())) { + return; + } break; } case RPC_CMD_COPY_TENSOR: { - ok = server.copy_tensor(input, output); + rpc_msg_copy_tensor_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + rpc_msg_copy_tensor_rsp response; + if (!server.copy_tensor(request, response)) { + return; + } + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } break; } case RPC_CMD_GRAPH_COMPUTE: { - ok = server.graph_compute(input, output); + std::vector input; + if (!recv_msg(sockfd, input)) { + return; + } + rpc_msg_graph_compute_rsp response; + if (!server.graph_compute(input, response)) { + return; + } + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } break; } case RPC_CMD_GET_DEVICE_MEMORY: { - // output serialization format: | free (8 bytes) | total (8 bytes) | - output.resize(2*sizeof(uint64_t), 0); - memcpy(output.data(), &free_mem, sizeof(free_mem)); - memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem)); + if (!recv_msg(sockfd, nullptr, 0)) { + return; + } + rpc_msg_get_device_memory_rsp response; + response.free_mem = free_mem; + response.total_mem = total_mem; + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } break; } default: { fprintf(stderr, "Unknown command: %d\n", cmd); - ok = false; + return; } } - if (!ok) { - break; - } - uint64_t output_size = output.size(); - if (!send_data(sockfd, &output_size, sizeof(output_size))) { - break; - } - if (!send_data(sockfd, output.data(), output_size)) { - break; - } } } -void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) { +void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) { std::string host; int port; if (!parse_endpoint(endpoint, host, port)) { @@ -1224,3 +1350,175 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free WSACleanup(); #endif } + +// device interface + +struct ggml_backend_rpc_device_context { + std::string endpoint; + std::string name; +}; + +static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; + + return ctx->name.c_str(); +} + +static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; + + return ctx->name.c_str(); +} + +static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; + + ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total); + + GGML_UNUSED(dev); +} + +static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) { + // TODO: obtain value from the server + return GGML_BACKEND_DEVICE_TYPE_GPU; + + GGML_UNUSED(dev); +} + +static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_rpc_device_get_name(dev); + props->description = ggml_backend_rpc_device_get_description(dev); + props->type = ggml_backend_rpc_device_get_type(dev); + ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; +} + +static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) { + ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; + + return ggml_backend_rpc_init(ctx->endpoint.c_str()); + + GGML_UNUSED(params); +} + +static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; + + return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str()); + + GGML_UNUSED(dev); +} + +static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + GGML_UNUSED(dev); + GGML_UNUSED(op); + //TODO: call the remote backend and cache the results + return true; +} + +static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) { + return false; + } + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context; + return buft_ctx->endpoint == dev_ctx->endpoint; +} + +static const struct ggml_backend_device_i ggml_backend_rpc_device_i = { + /* .get_name = */ ggml_backend_rpc_device_get_name, + /* .get_description = */ ggml_backend_rpc_device_get_description, + /* .get_memory = */ ggml_backend_rpc_device_get_memory, + /* .get_type = */ ggml_backend_rpc_device_get_type, + /* .get_props = */ ggml_backend_rpc_device_get_props, + /* .init_backend = */ ggml_backend_rpc_device_init, + /* .get_buffer_type = */ ggml_backend_rpc_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ NULL, + /* .supports_op = */ ggml_backend_rpc_device_supports_op, + /* .supports_buft = */ ggml_backend_rpc_device_supports_buft, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +// backend reg interface + +static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) { + return "RPC"; + + GGML_UNUSED(reg); +} + +static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) { + return 0; + + GGML_UNUSED(reg); +} + +static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) { + GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead"); + + GGML_UNUSED(reg); + GGML_UNUSED(index); +} + +static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) { + if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) { + return (void *)ggml_backend_rpc_add_device; + } + return NULL; + + GGML_UNUSED(reg); +} + +static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = { + /* .get_name = */ ggml_backend_rpc_reg_get_name, + /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count, + /* .get_device = */ ggml_backend_rpc_reg_get_device, + /* .get_proc_address = */ ggml_backend_rpc_get_proc_address, +}; + +ggml_backend_reg_t ggml_backend_rpc_reg(void) { + static struct ggml_backend_reg ggml_backend_rpc_reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_rpc_reg_i, + /* .context = */ NULL, + }; + + return &ggml_backend_rpc_reg; +} + +ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) { + static std::unordered_map dev_map; + + static std::mutex mutex; + std::lock_guard lock(mutex); + + if (dev_map.find(endpoint) != dev_map.end()) { + return dev_map[endpoint]; + } + + ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context { + /* .endpoint = */ endpoint, + /* .name = */ "RPC[" + std::string(endpoint) + "]", + }; + + ggml_backend_dev_t dev = new ggml_backend_device { + /* .iface = */ ggml_backend_rpc_device_i, + /* .reg = */ ggml_backend_rpc_reg(), + /* .context = */ ctx, + }; + + dev_map[endpoint] = dev; + + return dev; +} + +GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg) diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt new file mode 100644 index 000000000..3579a311a --- /dev/null +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -0,0 +1,84 @@ +if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA|AMD)$") + message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL, NVIDIA, or AMD") +endif() + +check_cxx_compiler_flag("-fsycl" SUPPORTS_SYCL) + +if (DEFINED ENV{ONEAPI_ROOT}) + message(STATUS "Using oneAPI Release SYCL compiler (icpx).") +elseif(SUPPORTS_SYCL) + message(WARNING "Using open-source SYCL compiler (clang++). Didn't detect ENV {ONEAPI_ROOT}. + If you expected the oneAPI Release compiler, please install oneAPI & source it, like: + source /opt/intel/oneapi/setvars.sh") +else() + message(FATAL_ERROR, "C++ compiler lacks SYCL support.") +endif() +message(STATUS "SYCL found") +#todo: AOT + +ggml_add_backend_library(ggml-sycl + ggml-sycl.cpp + ../../include/ggml-sycl.h + ) + +if (GGML_SYCL_F16) + if (GGML_SYCL_TARGET STREQUAL "AMD") + message(WARNING "AMD target does not entirely support FP16 in the SYCL backend.") + endif() + add_compile_definitions(GGML_SYCL_F16) +endif() + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing -fsycl") + +if (GGML_SYCL_TARGET STREQUAL "NVIDIA") + add_compile_definitions(GGML_SYCL_WARP_SIZE=32) +elseif (GGML_SYCL_TARGET STREQUAL "AMD") + # INFO: Allowed Sub_group_sizes are not consistent through all + # hip targets. For example, 64 is used for certain models, but the backend + # does not support it. + # Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32) + add_compile_definitions(GGML_SYCL_WARP_SIZE=32) +else() + add_compile_definitions(GGML_SYCL_WARP_SIZE=16) +endif() + +file(GLOB GGML_HEADERS_SYCL "*.hpp") +file(GLOB GGML_SOURCES_SYCL "*.cpp") +target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL}) + +find_package(DNNL) +message("-- DNNL found:" ${DNNL_FOUND}) + +if (GGML_SYCL_TARGET STREQUAL "INTEL") + add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND}) +else() + add_compile_definitions(GGML_SYCL_DNNL=0) +endif() + +if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL") + target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl) +endif() + +if (WIN32) + find_package(IntelSYCL REQUIRED) + find_package(MKL REQUIRED) + target_link_libraries(ggml-sycl PRIVATE IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL) +else() + if (GGML_SYCL_TARGET STREQUAL "INTEL") + target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread) + elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda") + add_compile_definitions(GGML_SYCL_NVIDIA) + target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl_blas_cublas) + elseif (GGML_SYCL_TARGET STREQUAL "AMD") + if (NOT GGML_SYCL_DEVICE_ARCH) + message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.") + endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=amdgcn-amd-amdhsa") + target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl) + endif() + + if (GGML_SYCL_DEVICE_ARCH) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}") + endif() +endif() diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index d21b5f8dd..b1df4e5db 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -26,5 +26,9 @@ #include "softmax.hpp" #include "tsembd.hpp" #include "im2col.hpp" +#include "wkv6.hpp" +#include "outprod.hpp" +#include "element_wise.hpp" +#include "gla.hpp" #endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml/src/ggml-sycl/common.cpp b/ggml/src/ggml-sycl/common.cpp index cf5291b31..022e7b763 100644 --- a/ggml/src/ggml-sycl/common.cpp +++ b/ggml/src/ggml-sycl/common.cpp @@ -12,6 +12,9 @@ #include "common.hpp" +#include "ggml-backend-impl.h" +#include "ggml-impl.h" + int get_current_device_id() { return dpct::dev_mgr::instance().current_device_id(); } @@ -28,11 +31,7 @@ void* ggml_sycl_host_malloc(size_t size) try { if (err != 0) { // clear the error - fprintf( - stderr, - "WARNING: failed to allocate %.2f MB of pinned memory: %s\n", - size / 1024.0 / 1024.0, - "syclGetErrorString is not supported"); + GGML_LOG_ERROR("WARNING: failed to allocate %.2f MB of pinned memory: %s\n", size / 1024.0 / 1024.0, "syclGetErrorString is not supported"); return nullptr; } @@ -52,6 +51,10 @@ void ggml_sycl_host_free(void* ptr) try { std::exit(1); } +bool gpu_has_xmx(sycl::device &dev) { + return dev.has(sycl::aspect::ext_intel_matrix); +} + int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) { const int64_t max_range = std::numeric_limits::max(); int64_t sycl_down_blk_size = block_size; @@ -62,3 +65,37 @@ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block } return sycl_down_blk_size; } + +void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const ggml_sycl_op_flatten_t op) try { + + const bool use_src1 = src1 != nullptr; + if(use_src1) + GGML_ASSERT(strcmp(src1->buffer->buft->iface.get_name(src1->buffer->buft), GGML_SYCL_NAME "_Split") != 0); + GGML_ASSERT(strcmp(dst->buffer->buft->iface.get_name(dst->buffer->buft), GGML_SYCL_NAME "_Split") != 0); + + // dd = data device + float * src0_ddf = (float *) src0->data; + float * src1_ddf = use_src1 ? (float *) src1->data : nullptr; + float * dst_ddf = (float *) dst->data; + + ggml_sycl_pool_alloc src0_f(ctx.pool()); + ggml_sycl_pool_alloc src1_f(ctx.pool()); + ggml_sycl_pool_alloc dst_f(ctx.pool()); + + ggml_sycl_set_device(ctx.device); + queue_ptr main_stream = ctx.stream(); + // GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n", + // ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device); + + // do the computation + op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream); + // print_ggml_tensor("tensor", dst); +} +catch (sycl::exception const &exc) { + + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 05947ccb7..abad847ca 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -26,7 +26,11 @@ #define GGML_COMMON_DECL_SYCL #define GGML_COMMON_IMPL_SYCL +/* suppress warning spam */ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wnested-anon-types" #include "ggml-common.h" +#pragma clang diagnostic pop void* ggml_sycl_host_malloc(size_t size); void ggml_sycl_host_free(void* ptr); @@ -134,7 +138,6 @@ typedef sycl::float2 dfloat2; #endif // GGML_SYCL_F16 #define MMVQ_MAX_BATCH_SIZE 8 -#define MMVQ_MIN_BATCH_SIZE 4 static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; @@ -330,8 +333,12 @@ struct ggml_backend_sycl_context { // pool std::unique_ptr pools[GGML_SYCL_MAX_DEVICES]; + std::unique_ptr host_pools[GGML_SYCL_MAX_DEVICES]; + static std::unique_ptr new_pool_for_device(queue_ptr qptr, int device); + static std::unique_ptr new_pool_for_host(queue_ptr qptr, int device); + ggml_sycl_pool & pool(int device) { if (pools[device] == nullptr) { pools[device] = new_pool_for_device(stream(device,0), device); @@ -342,6 +349,15 @@ struct ggml_backend_sycl_context { ggml_sycl_pool & pool() { return pool(device); } + + ggml_sycl_pool & host_pool(int device) { + if (host_pools[device] == nullptr) { + host_pools[device] = new_pool_for_host(stream(device, 0), device); + } + return *host_pools[device]; + } + + ggml_sycl_pool & host_pool() { return host_pool(device); } }; // common device functions @@ -405,4 +421,264 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor acc) { int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size); +typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream); + +template +static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, + int ne10, int ne11, int ne12, int ne13, + /*int s0, */ int s1, int s2, int s3, + /*int s10,*/ int s11, int s12, int s13, + const sycl::nd_item<3> &item_ct1) { + const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1)); + const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + + item_ct1.get_local_id(0)) / + ne3; + const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + + item_ct1.get_local_id(0)) % + ne3; + + if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + return; + } + + const int i11 = i1 % ne11; + const int i12 = i2 % ne12; + const int i13 = i3 % ne13; + + const size_t i_src0 = i3*s3 + i2*s2 + i1*s1; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i_src0; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + dst_t * dst_row = dst + i_dst; + + for (int i0 = i0s; i0 < ne0; + i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) { + const int i10 = i0 % ne10; + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); + } +} + +template +static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, + int ne10, int ne11, int ne12, int ne13, + /*int s0, */ int s1, int s2, int s3, + /*int s10,*/ int s11, int s12, int s13, + const sycl::nd_item<3> &item_ct1) { + + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + const int i3 = i/(ne2*ne1*ne0); + const int i2 = (i/(ne1*ne0)) % ne2; + const int i1 = (i/ne0) % ne1; + const int i0 = i % ne0; + + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + return; + } + + const int i11 = i1 % ne11; + const int i12 = i2 % ne12; + const int i13 = i3 % ne13; + + const size_t i_src0 = i3*s3 + i2*s2 + i1*s1; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i_src0; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + dst_t * dst_row = dst + i_dst; + + const int i10 = i0 % ne10; + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); +} + + +template +struct bin_bcast_sycl { + template + void operator()(ggml_backend_sycl_context & ctx, + const struct ggml_tensor *src0, + const struct ggml_tensor *src1, struct ggml_tensor *dst, + const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd, + queue_ptr stream) { + + GGML_TENSOR_BINARY_OP_LOCALS + + int nr0 = ne10/ne0; + int nr1 = ne11/ne1; + int nr2 = ne12/ne2; + int nr3 = ne13/ne3; + + int nr[4] = { nr0, nr1, nr2, nr3 }; + + // collapse dimensions until first broadcast dimension + int64_t cne0[] = {ne0, ne1, ne2, ne3}; + int64_t cne1[] = {ne10, ne11, ne12, ne13}; + size_t cnb0[] = {nb0, nb1, nb2, nb3}; + size_t cnb1[] = {nb10, nb11, nb12, nb13}; + auto collapse = [](int64_t cne[]) { + cne[0] *= cne[1]; + cne[1] = cne[2]; + cne[2] = cne[3]; + cne[3] = 1; + }; + + auto collapse_nb = [](size_t cnb[], int64_t cne[]) { + cnb[1] *= cne[1]; + cnb[2] *= cne[2]; + cnb[3] *= cne[3]; + }; + + for (int i = 0; i < 4; i++) { + if (nr[i] != 1) { + break; + } + if (i > 0) { + collapse_nb(cnb0, cne0); + collapse_nb(cnb1, cne1); + collapse(cne0); + collapse(cne1); + } + } + { + int64_t ne0 = cne0[0]; + int64_t ne1 = cne0[1]; + int64_t ne2 = cne0[2]; + int64_t ne3 = cne0[3]; + + int64_t ne10 = cne1[0]; + int64_t ne11 = cne1[1]; + int64_t ne12 = cne1[2]; + int64_t ne13 = cne1[3]; + + size_t nb0 = cnb0[0]; + size_t nb1 = cnb0[1]; + size_t nb2 = cnb0[2]; + size_t nb3 = cnb0[3]; + + size_t nb10 = cnb1[0]; + size_t nb11 = cnb1[1]; + size_t nb12 = cnb1[2]; + size_t nb13 = cnb1[3]; + + size_t s0 = nb0 / sizeof(dst_t); + size_t s1 = nb1 / sizeof(dst_t); + size_t s2 = nb2 / sizeof(dst_t); + size_t s3 = nb3 / sizeof(dst_t); + + size_t s10 = nb10 / sizeof(src1_t); + size_t s11 = nb11 / sizeof(src1_t); + size_t s12 = nb12 / sizeof(src1_t); + size_t s13 = nb13 / sizeof(src1_t); + + GGML_ASSERT(s0 == 1); + GGML_ASSERT(s10 == 1); + + const int block_size = 128; + + int64_t hne0 = std::max(ne0/2LL, 1LL); + + sycl::range<3> block_dims(1, 1, 1); + block_dims[2] = std::min(hne0, block_size); + block_dims[1] = std::min( + ne1, block_size / (unsigned int)block_dims[2]); + block_dims[0] = std::min( + std::min( + ne2 * ne3, block_size / (unsigned int)block_dims[2] / + (unsigned int)block_dims[1]), + 64U); + + sycl::range<3> block_nums( + (ne2 * ne3 + block_dims[0] - 1) / block_dims[0], + (ne1 + block_dims[1] - 1) / block_dims[1], + (hne0 + block_dims[2] - 1) / block_dims[2]); + + if (block_nums[0] > 65535) { + // this is the maximum number of blocks in z direction, fallback to 1D grid kernel + int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * + sycl::range<3>(1, 1, block_size), + sycl::range<3>(1, 1, block_size)), + [=](sycl::nd_item<3> item_ct1) { + k_bin_bcast_unravel( + src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, + ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12, + s13, item_ct1); + }); + } + } else { + /* + DPCT1049:16: The work-group size passed to the SYCL kernel may + exceed the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if + needed. + */ + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_bin_bcast(src0_dd, src1_dd, dst_dd, ne0, ne1, + ne2, ne3, ne10, ne11, ne12, ne13, + s1, s2, s3, s11, s12, s13, + item_ct1); + }); + } + } + GGML_UNUSED(ctx); + } +}; + +template +inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, + const queue_ptr &main_stream) { + + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, + (sycl::half *)dst_dd, main_stream); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd, + main_stream); + } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { + op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd, + main_stream); + } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) { + op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd, + main_stream); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); + GGML_ABORT("fatal error"); + } +} + +bool gpu_has_xmx(sycl::device &dev); + +void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const ggml_sycl_op_flatten_t op); + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/concat.cpp b/ggml/src/ggml-sycl/concat.cpp index 632eedb9d..d41cfd3a6 100644 --- a/ggml/src/ggml-sycl/concat.cpp +++ b/ggml/src/ggml-sycl/concat.cpp @@ -47,7 +47,7 @@ static void concat_f32_dim1(const float *x, const float *y, float *dst, // operation int offset_dst = nidx + item_ct1.get_group(1) * ne0 + item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1); - if (item_ct1.get_group(1) < ne01) { // src0 + if (item_ct1.get_group(1) < (size_t) ne01) { // src0 int offset_src = nidx + item_ct1.get_group(1) * ne0 + item_ct1.get_group(0) * ne0 * ne01; dst[offset_dst] = x[offset_src]; @@ -70,7 +70,7 @@ static void concat_f32_dim2(const float *x, const float *y, float *dst, // operation int offset_dst = nidx + item_ct1.get_group(1) * ne0 + item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1); - if (item_ct1.get_group(0) < ne02) { // src0 + if (item_ct1.get_group(0) < (size_t) ne02) { // src0 int offset_src = nidx + item_ct1.get_group(1) * ne0 + item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1); dst[offset_dst] = x[offset_src]; @@ -106,6 +106,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst, concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1); }); break; + // dim >=2 will be dispatched to the default path default: stream->parallel_for( sycl::nd_range<3>(gridDim * @@ -157,8 +158,9 @@ static void concat_f32_sycl_non_cont( }); } -void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst) { +void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 = dst->src[1]; queue_ptr stream = ctx.stream(); const int32_t dim = ((int32_t *)dst->op_params)[0]; diff --git a/ggml/src/ggml-sycl/concat.hpp b/ggml/src/ggml-sycl/concat.hpp index 5a04feaab..e5cb7314c 100644 --- a/ggml/src/ggml-sycl/concat.hpp +++ b/ggml/src/ggml-sycl/concat.hpp @@ -15,7 +15,6 @@ #include "common.hpp" -void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst); +void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst); #endif // GGML_SYCL_CONCAT_HPP diff --git a/ggml/src/ggml-sycl/conv.cpp b/ggml/src/ggml-sycl/conv.cpp index bc4ab1ddb..ddba601e1 100644 --- a/ggml/src/ggml-sycl/conv.cpp +++ b/ggml/src/ggml-sycl/conv.cpp @@ -71,8 +71,9 @@ static void conv_transpose_1d_f32_f32_sycl( }); } -void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst) { +void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 = dst->src[1]; const float * src0_d = (const float *)src0->data; const float * src1_d = (const float *)src1->data; diff --git a/ggml/src/ggml-sycl/conv.hpp b/ggml/src/ggml-sycl/conv.hpp index eb20730f9..f9e60dc75 100644 --- a/ggml/src/ggml-sycl/conv.hpp +++ b/ggml/src/ggml-sycl/conv.hpp @@ -15,7 +15,6 @@ #include "common.hpp" -void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst); +void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor *dst); #endif // GGML_SYCL_CONV_HPP diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 5fd15e6cd..05b01db2d 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -424,7 +424,7 @@ static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2); // make each work-item deal with more elements since sycl global range can not exceed max int - const src_t * x = (src_t *) vx; + const src_t * x = (const src_t *) vx; for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) { y[i] = x[i]; } diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index 8f4041fff..b8304c3a2 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -55,12 +55,12 @@ static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib, #ifdef GGML_SYCL_F16 // v = v * {d, d}; // v = v + {m, m}; - v.s0() = (v.s0() * d) + m; - v.s1() = (v.s1() * d) + m; + v.s0() = sycl::fma(v.s0(), d, m); + v.s1() = sycl::fma(v.s1(), d, m); #else - v.x() = (v.x() * d) + m; - v.y() = (v.y() * d) + m; + v.x() = sycl::fma(v.x(), d, m); + v.y() = sycl::fma(v.y(), d, m); #endif // GGML_SYCL_F16 } @@ -110,11 +110,11 @@ static __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib, #ifdef GGML_SYCL_F16 // v = v * {d, d}; // v = v + {m, m}; - v.s0() = (v.s0() * d) + m; - v.s1() = (v.s1() * d) + m; + v.s0() = sycl::fma(v.s0(), d, m); + v.s1() = sycl::fma(v.s1(), d, m); #else - v.x() = (v.x() * d) + m; - v.y() = (v.y() * d) + m; + v.x() = sycl::fma(v.x(), d, m); + v.y() = sycl::fma(v.y(), d, m); #endif // GGML_SYCL_F16 } diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 0c3dfaa37..0d097357c 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -1015,9 +1015,9 @@ void ggml_sycl_op_dequantize_mul_mat_vec( break; } - (void) src1; - (void) dst; - (void) src1_ddq_i; - (void) src1_ncols; - (void) src1_padded_row_size; + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_ddq_i); + GGML_UNUSED(src1_ncols); + GGML_UNUSED(src1_padded_row_size); } diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index fe4a8f744..c96395be6 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -15,6 +15,7 @@ #include #include +#include #include #include @@ -81,6 +82,14 @@ inline std::string get_device_backend_and_type(const sycl::device &device) { return device_type.str(); } +template struct matrix_info_t { + oneapi::mkl::transpose transpose_info[2]; + Ts value_info[2]; + std::int64_t size_info[3]; + std::int64_t ld_info[3]; + std::int64_t groupsize_info; +}; + namespace dpct { typedef sycl::queue *queue_ptr; @@ -1236,7 +1245,7 @@ namespace dpct std::map::iterator get_map_iterator(const void *ptr) { - auto it = m_map.upper_bound((byte_t *)ptr); + auto it = m_map.upper_bound(const_cast(reinterpret_cast(ptr))); if (it == m_map.end()) { // Not a virtual pointer. @@ -1688,9 +1697,14 @@ namespace dpct auto data_a = get_memory(a); auto data_b = get_memory(b); auto data_c = get_memory(c); - oneapi::mkl::blas::column_major::gemm( - q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, - data_b, ldb, beta_value, data_c, ldc); +#ifdef GGML_SYCL_NVIDIA + oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector{ q }, + a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb, + beta_value, data_c, ldc); +#else + oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb, + beta_value, data_c, ldc); +#endif } template @@ -1721,26 +1735,13 @@ namespace dpct }; template - inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void **a, int lda, - const void **b, int ldb, const void *beta, void **c, - int ldc, int batch_size) - { - struct matrix_info_t - { - oneapi::mkl::transpose transpose_info[2]; - Ts value_info[2]; - std::int64_t size_info[3]; - std::int64_t ld_info[3]; - std::int64_t groupsize_info; - }; - + inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, + int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b, + int ldb, const void * beta, void ** c, int ldc, int batch_size, + matrix_info_t * matrix_info) { Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); Ts beta_value = dpct::get_value(reinterpret_cast(beta), q); - matrix_info_t *matrix_info = - (matrix_info_t *)std::malloc(sizeof(matrix_info_t)); matrix_info->transpose_info[0] = a_trans; matrix_info->transpose_info[1] = b_trans; matrix_info->value_info[0] = alpha_value; @@ -1753,19 +1754,22 @@ namespace dpct matrix_info->ld_info[2] = ldc; matrix_info->groupsize_info = batch_size; +#ifdef GGML_SYCL_NVIDIA sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( - q, matrix_info->transpose_info, matrix_info->transpose_info + 1, - matrix_info->size_info, matrix_info->size_info + 1, - matrix_info->size_info + 2, matrix_info->value_info, - reinterpret_cast(a), matrix_info->ld_info, - reinterpret_cast(b), matrix_info->ld_info + 1, - matrix_info->value_info + 1, reinterpret_cast(c), - matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); - - q.submit([&](sycl::handler &cgh) - { - cgh.depends_on(e); - cgh.host_task([=] { std::free(matrix_info); }); }); + oneapi::mkl::backend_selector{ q }, matrix_info->transpose_info, + matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, + matrix_info->size_info + 2, reinterpret_cast(matrix_info->value_info), + reinterpret_cast(a), matrix_info->ld_info, reinterpret_cast(b), + matrix_info->ld_info + 1, reinterpret_cast(matrix_info->value_info + 1), + reinterpret_cast(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); +#else + sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( + q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, + matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast(matrix_info->value_info), + reinterpret_cast(a), matrix_info->ld_info, reinterpret_cast(b), + matrix_info->ld_info + 1, reinterpret_cast(matrix_info->value_info + 1), + reinterpret_cast(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info)); +#endif } template @@ -1782,10 +1786,16 @@ namespace dpct auto data_a = get_memory(a); auto data_b = get_memory(b); auto data_c = get_memory(c); +#ifdef GGML_SYCL_NVIDIA oneapi::mkl::blas::column_major::gemm_batch( - q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, - stride_a, data_b, ldb, stride_b, beta_value, - data_c, ldc, stride_c, batch_size); + oneapi::mkl::backend_selector{ q }, a_trans, b_trans, m, n, k, + alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c, + batch_size); +#else + oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, + stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, + stride_c, batch_size); +#endif } } // namespace detail @@ -1830,31 +1840,10 @@ namespace dpct : id); } - template - sycl::vec extract_and_sign_or_zero_extend4(T val) - { - return sycl::vec(val) - .template as, int8_t, uint8_t>, 4>>() - .template convert(); - } - - template - using dot_product_acc_t = - std::conditional_t && std::is_unsigned_v, - uint32_t, int32_t>; - template inline auto dp4a(T1 a, T2 b, T3 c) { - dot_product_acc_t res = c; - auto va = extract_and_sign_or_zero_extend4(a); - auto vb = extract_and_sign_or_zero_extend4(b); - res += va[0] * vb[0]; - res += va[1] * vb[1]; - res += va[2] * vb[2]; - res += va[3] * vb[3]; - return res; + return syclcompat::dp4a(a, b, c); } struct sub_sat @@ -2423,25 +2412,11 @@ namespace dpct /// \param [in] ldc Leading dimension of C. /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. /// \param [in] scaling_type Data type of the scaling factors. - inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans, - oneapi::mkl::transpose b_trans, int m, int n, int k, - const void *alpha, const void *a[], - library_data_t a_type, int lda, const void *b[], - library_data_t b_type, int ldb, const void *beta, - void *c[], library_data_t c_type, int ldc, - int batch_size, library_data_t scaling_type) - { - if (scaling_type == library_data_t::real_float && - c_type == library_data_t::complex_float) - { - scaling_type = library_data_t::complex_float; - } - else if (scaling_type == library_data_t::real_double && - c_type == library_data_t::complex_double) - { - scaling_type = library_data_t::complex_double; - } - + inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, + int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda, + const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[], + library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type, + matrix_info_t * matrix_info) { std::uint64_t key = detail::get_type_combination_id(a_type, b_type, c_type, scaling_type); switch (key) @@ -2450,48 +2425,24 @@ namespace dpct library_data_t::real_float, library_data_t::real_float, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( library_data_t::real_double, library_data_t::real_double, library_data_t::real_double, library_data_t::real_double): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_float, library_data_t::complex_float, - library_data_t::complex_float, library_data_t::complex_float): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); - break; - } - case detail::get_type_combination_id( - library_data_t::complex_double, library_data_t::complex_double, - library_data_t::complex_double, library_data_t::complex_double): - { - detail::gemm_batch_impl, std::complex, - std::complex, std::complex>( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( library_data_t::real_half, library_data_t::real_half, library_data_t::real_half, library_data_t::real_half): { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, - a, lda, b, ldb, beta, c, ldc, - batch_size); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } #ifdef __INTEL_MKL__ @@ -2499,19 +2450,16 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float): { - detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, alpha, a, lda, - b, ldb, beta, c, ldc, batch_size); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } #endif @@ -2523,10 +2471,9 @@ namespace dpct dpct::get_value(reinterpret_cast(alpha), q); float beta_float = dpct::get_value(reinterpret_cast(beta), q); - detail::gemm_batch_impl(q, a_trans, b_trans, m, n, k, &alpha_float, - a, lda, b, ldb, &beta_float, c, ldc, - batch_size); + detail::gemm_batch_impl( + q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc, batch_size, + matrix_info); break; } case detail::get_type_combination_id( @@ -2534,8 +2481,7 @@ namespace dpct library_data_t::real_float, library_data_t::real_float): { detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2543,8 +2489,7 @@ namespace dpct library_data_t::real_float, library_data_t::real_float): { detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, - batch_size); + q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } case detail::get_type_combination_id( @@ -2558,8 +2503,7 @@ namespace dpct sycl::half alpha_half(alpha_value); sycl::half beta_half(beta_value); detail::gemm_batch_impl( - q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, - batch_size); + q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc, batch_size, matrix_info); break; } default: diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp new file mode 100644 index 000000000..4bcd74376 --- /dev/null +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -0,0 +1,1030 @@ +#include "common.hpp" +#include "element_wise.hpp" + +void acc_f32(const float * x, const float * y, float * dst, const int ne, + const int ne10, const int ne11, const int ne12, + const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + if (i >= ne) { + return; + } + int src1_idx = i - offset; + int oz = src1_idx / nb2; + int oy = (src1_idx - (oz * nb2)) / nb1; + int ox = src1_idx % nb1; + if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) { + dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11]; + } else { + dst[i] = x[i]; + } +} + +void gelu_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const float GELU_COEF_A = 0.044715f; + const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + + float xi = x[i]; + dst[i] = 0.5f * xi * + (1.0f + + sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi))); +} + +void silu_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i])); +} + +void gelu_quick_f32(const float *x, float *dst, int k, + const sycl::nd_item<3> &item_ct1) { + const float GELU_QUICK_COEF = -1.702f; + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + if (i >= k) { + return; + } + dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i]))); +} + +void tanh_f32(const float *x, float *dst, int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + if (i >= k) { + return; + } + dst[i] = sycl::tanh((float)(x[i])); +} + +void relu_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = sycl::fmax((float)(x[i]), (float)0); +} + +void sigmoid_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = 1.0f / (1.0f + sycl::native::exp(-x[i])); +} + +void sqrt_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = sycl::sqrt(x[i]); +} + +void sin_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = sycl::sin(x[i]); +} + +void cos_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = sycl::cos(x[i]); +} + +void hardsigmoid_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f)); +} + +void hardswish_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f)); +} + +void exp_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = sycl::exp(x[i]); +} + +void log_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + float xi = x[i]; + if (xi <= 0) { + dst[i] = -INFINITY; + } else { + dst[i] = sycl::log(xi); + } +} + +void neg_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = -x[i]; +} + +void step_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = x[i] > 0.0f; +} + +void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + if (i >= k) { + return; + } + dst[i] = sycl::fmax((float)(x[i]), (float)0) + + sycl::fmin((float)(x[i]), 0.0f) * negative_slope; +} + +void sqr_f32(const float * x, float * dst, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + dst[i] = x[i] * x[i]; +} + +void upscale_f32(const float *x, float *dst, const int nb00, const int nb01, + const int nb02, const int nb03, const int ne10, const int ne11, + const int ne12, const int ne13, const float sf0, const float sf1, + const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) { + int index = item_ct1.get_local_id(0) + + item_ct1.get_group(0) * item_ct1.get_local_range(0); + if (index >= ne10 * ne11 * ne12 * ne13) { + return; + } + // operation + int i10 = index % ne10; + int i11 = (index / ne10) % ne11; + int i12 = (index / (ne10 * ne11)) % ne12; + int i13 = (index / (ne10 * ne11 * ne12)) % ne13; + + int i00 = i10 / sf0; + int i01 = i11 / sf1; + int i02 = i12 / sf2; + int i03 = i13 / sf3; + + dst[index] = *(const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); +} + +void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02, + const sycl::nd_item<3> &item_ct1) { + int nidx = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (nidx >= ne0) { + return; + } + + // operation + int offset_dst = nidx + item_ct1.get_group(1) * ne0 + + item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1); + if (nidx < ne00 && item_ct1.get_group(1) < (size_t) ne01 && item_ct1.get_group(0) < (size_t) ne02) { + int offset_src = nidx + item_ct1.get_group(1) * ne00 + + item_ct1.get_group(0) * ne00 * ne01; + dst[offset_dst] = x[offset_src]; + } else { + dst[offset_dst] = 0.0f; + } +} + + + +void acc_f32_sycl(const float *x, const float *y, float *dst, + const int n_elements, const int ne10, const int ne11, + const int ne12, const int nb1, const int nb2, + const int offset, queue_ptr stream) { + int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset, + item_ct1); + }); +} + +void gelu_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + gelu_f32(x, dst, k, item_ct1); + }); +} + +void silu_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + silu_f32(x, dst, k, item_ct1); + }); +} + +void gelu_quick_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + gelu_quick_f32(x, dst, k, item_ct1); + }); +} + +void tanh_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + tanh_f32(x, dst, k, item_ct1); + }); +} + +void relu_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + relu_f32(x, dst, k, item_ct1); + }); +} + +void hardsigmoid_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + hardsigmoid_f32(x, dst, k, item_ct1); + }); +} + +void hardswish_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + hardswish_f32(x, dst, k, item_ct1); + }); +} + +void exp_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + exp_f32(x, dst, k, item_ct1); + }); +} + +void log_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_EXP_BLOCK_SIZE - 1) / SYCL_EXP_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_EXP_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + log_f32(x, dst, k, item_ct1); + }); +} + +void neg_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + neg_f32(x, dst, k, item_ct1); + }); +} + +void step_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_NEG_BLOCK_SIZE - 1) / SYCL_NEG_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_NEG_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + step_f32(x, dst, k, item_ct1); + }); +} + +void sigmoid_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_SIGMOID_BLOCK_SIZE - 1) / SYCL_SIGMOID_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_SIGMOID_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + sigmoid_f32(x, dst, k, item_ct1); + }); +} + +void sqrt_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_SQRT_BLOCK_SIZE - 1) / SYCL_SQRT_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_SQRT_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + sqrt_f32(x, dst, k, item_ct1); + }); +} + +void sin_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + sin_f32(x, dst, k, item_ct1); + }); +} + +void cos_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_SIN_BLOCK_SIZE - 1) / SYCL_SIN_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_SIN_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + cos_f32(x, dst, k, item_ct1); + }); +} + +void leaky_relu_f32_sycl(const float *x, float *dst, const int k, + const float negative_slope, + queue_ptr stream) { + const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + leaky_relu_f32(x, dst, k, negative_slope, item_ct1); + }); +} + +void sqr_f32_sycl(const float *x, float *dst, const int k, + queue_ptr stream) { + const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + sqr_f32(x, dst, k, item_ct1); + }); +} + +void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01, + const int nb02, const int nb03, const int ne10, const int ne11, + const int ne12, const int ne13, const float sf0, const float sf1, + const float sf2, const float sf3, queue_ptr stream) { + int dst_size = ne10 * ne11 * ne12 * ne13; + int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; + sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1); + }); +} + +void pad_f32_sycl(const float *x, float *dst, const int ne00, + const int ne01, const int ne02, const int ne0, + const int ne1, const int ne2, queue_ptr stream) { + int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE; + sycl::range<3> gridDim(ne2, ne1, num_blocks); + stream->parallel_for( + sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + pad_f32(x, dst, ne0, ne00, ne01, ne02, item_ct1); + }); +} + +inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} +inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + exp_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + log_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + sigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + sqrt_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + sin_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + cos_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + step_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + neg_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + float negative_slope; + memcpy(&negative_slope, dst->op_params, sizeof(float)); + + leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const float sf0 = (float)dst->ne[0]/src0->ne[0]; + const float sf1 = (float)dst->ne[1]/src0->ne[1]; + const float sf2 = (float)dst->ne[2]/src0->ne[2]; + const float sf3 = (float)dst->ne[3]/src0->ne[3]; + + upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, + main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors + + pad_f32_sycl(src0_dd, dst_dd, + src0->ne[0], src0->ne[1], src0->ne[2], + dst->ne[0], dst->ne[1], dst->ne[2], main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported + + int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 + int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 + // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused + int offset = dst->op_params[3] / 4; // offset in bytes + + acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream); + + GGML_UNUSED(dst); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); +} + +inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); +} + +inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); +} + +inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, + ggml_tensor *dst, const float *src0_dd, + const float *src1_dd, float *dst_dd, + const queue_ptr &main_stream) { + + ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); +} + + +void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sqrt); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sin); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_cos); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_acc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_acc); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_gelu); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_silu); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_gelu_quick); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_tanh); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_relu); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sigmoid); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_hardsigmoid); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_hardswish); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + + +void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_exp); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_log); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_neg); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_step); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_leaky_relu); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sqr); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_upscale); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_pad); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + + + +void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_add); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sub); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_mul); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_div); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp new file mode 100644 index 000000000..464432645 --- /dev/null +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -0,0 +1,76 @@ +#ifndef GGML_SYCL_ELEMENTWISE_HPP +#define GGML_SYCL_ELEMENTWISE_HPP + +#include "common.hpp" + +static __dpct_inline__ float op_repeat(const float a, const float b) { + return b; + GGML_UNUSED(a); +} + +static __dpct_inline__ float op_add(const float a, const float b) { + return a + b; +} + +static __dpct_inline__ float op_sub(const float a, const float b) { + return a - b; +} + +static __dpct_inline__ float op_mul(const float a, const float b) { + return a * b; +} + +static __dpct_inline__ float op_div(const float a, const float b) { + return a / b; +} + + +void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_acc(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif // GGML_SYCL_ELEMENTWISE_HPP diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp index 2ad9b36f4..3f0f34ad6 100644 --- a/ggml/src/ggml-sycl/gemm.hpp +++ b/ggml/src/ggml-sycl/gemm.hpp @@ -51,8 +51,8 @@ public: const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab); const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab); const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab); - auto a_mem = dnnl::memory(a_in_md, eng, (void*)a); - auto b_mem = dnnl::memory(b_in_md, eng, (void*)b); + auto a_mem = dnnl::memory(a_in_md, eng, const_cast(a)); + auto b_mem = dnnl::memory(b_in_md, eng, const_cast(b)); auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md); auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c); @@ -79,8 +79,8 @@ public: const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab); const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab); const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab); - auto a_mem = dnnl::memory(a_in_md, eng, (void*)a); - auto b_mem = dnnl::memory(b_in_md, eng, (void*)b); + auto a_mem = dnnl::memory(a_in_md, eng, const_cast(a)); + auto b_mem = dnnl::memory(b_in_md, eng, const_cast(b)); auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md); auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c); diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp similarity index 73% rename from ggml/src/ggml-sycl.cpp rename to ggml/src/ggml-sycl/ggml-sycl.cpp index 4f03b01e7..2984ed82e 100644 --- a/ggml/src/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -33,24 +33,326 @@ #include #include "ggml-sycl.h" -#include "ggml.h" +#include "ggml-impl.h" #include "ggml-backend-impl.h" #include "ggml-sycl/backend.hpp" #include "ggml-sycl/presets.hpp" #include "ggml-sycl/gemm.hpp" -bool ggml_sycl_loaded(void); -void ggml_sycl_free_data(struct ggml_tensor * tensor); -void ggml_sycl_copy_to_device(struct ggml_tensor * tensor); -void ggml_sycl_set_main_device(int main_device); -void ggml_sycl_set_mul_mat_q(bool mul_mat_q); -void ggml_sycl_get_device_description(int device, char * description, size_t description_size); -bool ggml_backend_is_sycl(ggml_backend_t backend); -int ggml_backend_sycl_get_device(ggml_backend_t backend); -static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer); -static inline int get_sycl_env(const char *env_name, int default_val); +static bool g_sycl_loaded = false; +static ggml_sycl_device_info ggml_sycl_init() { + ggml_sycl_device_info info = {}; + + info.device_count = dpct::dev_mgr::instance().device_count(); + if (info.device_count == 0) { + GGML_LOG_ERROR("%s: failed to initialize: %s\n", GGML_SYCL_NAME, __func__); + return info; + } + + GGML_ASSERT(info.device_count <= GGML_SYCL_MAX_DEVICES); + + int64_t total_vram = 0; +/* This is a bit misleading; reserved for later */ +// #if defined(SYCL_USE_XMX) +// GGML_LOG_INFO("%s: SYCL_USE_XMX: yes\n", __func__); +// #else +// GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__); +// #endif + for (int i = 0; i < info.device_count; ++i) { + info.devices[i].vmm = 0; + dpct::device_info prop; + SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( + prop, dpct::dev_mgr::instance().get_device(i)))); + + info.default_tensor_split[i] = total_vram; + total_vram += prop.get_global_mem_size(); + + info.devices[i].cc = + 100 * prop.get_major_version() + 10 * prop.get_minor_version(); + + info.max_work_group_sizes[i] = prop.get_max_work_group_size(); + } + + for (int id = 0; id < info.device_count; ++id) { + info.default_tensor_split[id] /= total_vram; + } + return info; +} + +const ggml_sycl_device_info & ggml_sycl_info() { + static ggml_sycl_device_info info = ggml_sycl_init(); + return info; +} + +void print_device_detail(int id, sycl::device &device, std::string device_type) { + + dpct::device_info prop; + SYCL_CHECK(CHECK_TRY_ERROR( + dpct::get_device_info(prop, device))); + + std::string version; + version += std::to_string(prop.get_major_version()); + version += "."; + version += std::to_string(prop.get_minor_version()); + + device_type = std::regex_replace(device_type, std::regex("ext_oneapi_"), ""); + std::string name = std::string(prop.get_name()); + name = std::regex_replace(name, std::regex("\\(R\\)"), ""); + name = std::regex_replace(name, std::regex("\\(TM\\)"), ""); + + auto global_mem_size = prop.get_global_mem_size()/1000000; + std::string xmx = gpu_has_xmx(device) ? "yes" : "no"; + GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|%14s|\n", id, device_type.c_str(), + name.c_str(), version.c_str(), prop.get_max_compute_units(), + prop.get_max_work_group_size(), prop.get_max_sub_group_size(), + global_mem_size, device.get_info().c_str(), xmx.c_str()); +} + +void ggml_backend_sycl_print_sycl_devices() { + GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n"); + int device_count = dpct::dev_mgr::instance().device_count(); + std::map DeviceNums; + GGML_LOG_INFO("Found %d SYCL devices:\n", device_count); + + GGML_LOG_INFO( + "| | | | " + " |Max | |Max |Global | | XMX |\n"); + GGML_LOG_INFO( + "| | | | " + " |compute|Max work|sub |mem | | or |\n"); + GGML_LOG_INFO( + "|ID| Device Type| " + "Name|Version|units |group |group|size | Driver version| Tensor Cores |\n"); + GGML_LOG_INFO( + "|--|-------------------|---------------------------------------|------" + "-|-------|--------|-----|-------|---------------------|--------------|\n"); + + for (int id = 0; id < device_count; ++id) { + sycl::device device = dpct::dev_mgr::instance().get_device(id); + std::string backend_type = get_device_backend_and_type(device); + int type_id = DeviceNums[backend_type]++; + std::stringstream device_type; + device_type << "[" << backend_type << ":" << std::to_string(type_id) + << "]"; + print_device_detail(id, device, device_type.str()); + } +} + +static inline int get_sycl_env(const char *env_name, int default_val) { + char *user_device_string = getenv(env_name); + int user_number = default_val; + + unsigned n; + if (user_device_string != NULL && + sscanf(user_device_string, " %u", &n) == 1) { + user_number = (int)n; + } else { + user_number = default_val; + } + return user_number; +} + +static void ggml_check_sycl() try { + static bool initialized = false; + + if (!initialized) { + GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n"); + g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0); + GGML_LOG_INFO("GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); +#if defined(GGML_SYCL_FORCE_MMQ) + GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ: yes\n"); +#else + GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ: no\n"); +#endif +#if defined(GGML_SYCL_F16) + GGML_LOG_INFO("GGML_SYCL_F16: yes\n"); +#else + GGML_LOG_INFO("GGML_SYCL_F16: no\n"); +#endif + +/* NOT REMOVE, keep it for next optimize for XMX. +#if defined(SYCL_USE_XMX) + fprintf(stderr, "%s: SYCL_USE_XMX: yes\n", __func__); +#else + fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__); +#endif +*/ + + if (CHECK_TRY_ERROR(g_all_sycl_device_count = + dpct::dev_mgr::instance().device_count()) != 0) { + initialized = true; + g_sycl_loaded = false; + return; + } + GGML_ASSERT(g_all_sycl_device_count <= GGML_SYCL_MAX_DEVICES); + + initialized = true; + g_sycl_loaded = true; + ggml_backend_sycl_print_sycl_devices(); + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +/* +device_index: device index from 0 to n (continue numbers). + It is used for device select/set in SYCL backend internal data structure. +*/ +inline void check_allow_gpu_index(const int device_index) { + if (device_index >= ggml_sycl_info().device_count) { + char error_buf[256]; + snprintf( + error_buf, + sizeof(error_buf), + "%s error: device_index:%d is out of range: [0-%d]", + __func__, + device_index, + ggml_sycl_info().device_count - 1); + GGML_LOG_ERROR("%s\n", error_buf); + assert(false); + } +} + +GGML_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len) try { + GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_gpu_list\n"); + for(int i=0;i=max_len) break; + id_list[i] = i; + } + return; +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +// sycl buffer + +struct ggml_backend_sycl_buffer_context { + int device; + void * dev_ptr = nullptr; + queue_ptr stream; + std::string name; + + ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) : + device(device), dev_ptr(dev_ptr), stream(stream) { + check_allow_gpu_index(device); + name = (GGML_SYCL_NAME + std::to_string(device)); + } + + + ~ggml_backend_sycl_buffer_context() { + if (dev_ptr != nullptr) { + ggml_sycl_set_device(device); + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream))); + } + } +}; + +static const char * ggml_backend_sycl_buffer_type_get_name(ggml_backend_buffer_type_t buft); + +static bool ggml_backend_buffer_is_sycl(ggml_backend_buffer_t buffer) { + return buffer->buft->iface.get_name == ggml_backend_sycl_buffer_type_get_name; +} + +static void +ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try { + ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; + ggml_sycl_set_device(ctx->device); + + delete ctx; +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) { + ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; + return ctx->dev_ptr; +} + +static void +ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, + ggml_tensor *tensor) try { + ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context; + + if (tensor->view_src != NULL) { + assert(tensor->view_src->buffer->buft == buffer->buft); + return; + } + + + if (ggml_is_quantized(tensor->type)) { + // initialize padding to 0 to avoid possible NaN values + size_t original_size = ggml_nbytes(tensor); + size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor); + + if (padded_size > original_size && tensor->view_src == nullptr) { + SYCL_CHECK(CHECK_TRY_ERROR(ctx->stream->memset( + (char *)tensor->data + original_size, 0, + padded_size - original_size).wait())); + } + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer, + ggml_tensor *tensor, + const void *data, size_t offset, + size_t size) try { + + ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; + + ggml_sycl_set_device(ctx->device); + auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue()); + SYCL_CHECK( + CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw())); + char* host_buf = (char*)malloc(size); + memcpy(host_buf, data, size); + SYCL_CHECK( + CHECK_TRY_ERROR((*stream).memcpy((char *)tensor->data + offset, host_buf, size) + .wait())); + free(host_buf); +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor *tensor, + void *data, size_t offset, + size_t size) try { + + ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; + + ggml_sycl_set_device(ctx->device); + auto stream = dpct::dev_mgr::instance().get_device(ctx->device).default_queue(); + + SYCL_CHECK(CHECK_TRY_ERROR( + stream.memcpy(data, (const char *)tensor->data + offset, size) + .wait())); +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst, const void *ptr_src, size_t size) { @@ -60,8 +362,910 @@ void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst, free(host_buf); } +static bool +ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor *src, + ggml_tensor *dst) try { + if (ggml_backend_buffer_is_sycl(src->buffer)) { + ggml_backend_sycl_buffer_context * src_ctx = (ggml_backend_sycl_buffer_context *)src->buffer->context; + ggml_backend_sycl_buffer_context * dst_ctx = (ggml_backend_sycl_buffer_context *)dst->buffer->context; + + ggml_sycl_set_device(src_ctx->device); + /* + DPCT1009:198: SYCL uses exceptions to report errors and does not use the + error codes. The original code was commented out and a warning string + was inserted. You need to rewrite this code. + */ + SYCL_CHECK(CHECK_TRY_ERROR( + dpct::dev_mgr::instance().get_device(src_ctx->device).queues_wait_and_throw())); + ggml_sycl_set_device(dst_ctx->device); + /* + DPCT1009:199: SYCL uses exceptions to report errors and does not use the + error codes. The original code was commented out and a warning string + was inserted. You need to rewrite this code. + */ + SYCL_CHECK(CHECK_TRY_ERROR( + dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw())); + /* + DPCT1009:200: SYCL uses exceptions to report errors and does not use the + error codes. The original code was commented out and a warning string + was inserted. You need to rewrite this code. + */ + + queue_ptr stream_dst = dst_ctx->stream; + queue_ptr stream_src = src_ctx->stream; + size_t size = ggml_nbytes(src); + + //todo. it's dirty solutino to walkaroud known issue:device2device cross GPUs. + dev2dev_memcpy(*stream_dst, *stream_src, dst->data, src->data, size); + +//todo, it's known issue:error in device2device cross GPUs. reused when the issue is fixed. DON"T remove +#if 0 + SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy( + (char *)dst->data, (const char *)src->data, size).wait())); + + /* + DPCT1009:201: SYCL uses exceptions to report errors and does not use the + error codes. The original code was commented out and a warning string + was inserted. You need to rewrite this code. + */ + SYCL_CHECK(CHECK_TRY_ERROR( + dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw())); +#endif + return true; + } + return false; + GGML_UNUSED(buffer); +} catch (const sycl::exception & exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer, + uint8_t value) try { + ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; + + ggml_sycl_set_device(ctx->device); + queue_ptr stream = ctx->stream; + SYCL_CHECK( + CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw())); + + SYCL_CHECK(CHECK_TRY_ERROR((*stream) + .memset(ctx->dev_ptr, value, buffer->size) + .wait())); +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = { + /* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer, + /* .get_base = */ ggml_backend_sycl_buffer_get_base, + /* .init_tensor = */ ggml_backend_sycl_buffer_init_tensor, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor, + /* .clear = */ ggml_backend_sycl_buffer_clear, + /* .reset = */ NULL, +}; + +// sycl buffer type +struct ggml_backend_sycl_buffer_type_context { + int device; + std::string name; + + // each buffer type has its own stream + queue_ptr stream = nullptr; +}; + +static const char * ggml_backend_sycl_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + ggml_backend_sycl_buffer_type_context * ctx = (ggml_backend_sycl_buffer_type_context *)buft->context; + + return ctx->name.c_str(); +} + +static ggml_backend_buffer_t +ggml_backend_sycl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, + size_t size) try { + ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context; + ggml_sycl_set_device(buft_ctx->device); + const queue_ptr stream = buft_ctx->stream; + size = std::max(size, (size_t)1); // syclMalloc returns null for size 0 + + void * dev_ptr; + SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)sycl::malloc_device( + size, *stream))); + if (!dev_ptr) { + GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device\n", __func__, size); + return nullptr; + } + ggml_backend_sycl_buffer_context * ctx = new ggml_backend_sycl_buffer_context(buft_ctx->device, dev_ptr, buft_ctx->stream); + return ggml_backend_buffer_init(buft, ggml_backend_sycl_buffer_interface, ctx, size); +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return 128; + GGML_UNUSED(buft); +} + +static size_t ggml_backend_sycl_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + return dpct::get_current_device().get_max_mem_alloc_size(); + + GGML_UNUSED(buft); +} + +static size_t ggml_backend_sycl_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + size_t size = ggml_nbytes(tensor); + int64_t ne0 = tensor->ne[0]; + + if (ggml_is_quantized(tensor->type)) { + if (ne0 % MATRIX_ROW_PADDING != 0) { + size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + } + } + + return size; + + GGML_UNUSED(buft); +} + +static const ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = { + /* .get_name = */ ggml_backend_sycl_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_sycl_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_sycl_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_sycl_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_sycl_buffer_type_get_alloc_size, + /* .is_host = */ NULL, +}; + +ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) { + static std::mutex mutex; + std::lock_guard lock(mutex); + + GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n"); + + auto dev_count = ggml_backend_sycl_get_device_count(); + + if (device>=dev_count or device<0) { + GGML_LOG_ERROR("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n", + device, dev_count-1); + GGML_ASSERT(devicedevice; + if (device>=ggml_sycl_info().device_count or device<0) { + GGML_LOG_ERROR("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n", + device, ggml_sycl_info().device_count-1); + GGML_ASSERT(devicestream(i, 0)}, + }; + } + ggml_backend_sycl_buffer_type_initialized = true; + } + return &ggml_backend_sycl_buffer_types[device]; +} + +// sycl split buffer + +static int64_t get_row_rounding(ggml_type type, const std::array & tensor_split) { + int64_t min_compute_capability = INT_MAX; + int64_t max_compute_capability = INT_MIN; + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + if (tensor_split[i] < (i + 1 < ggml_sycl_info().device_count ? tensor_split[i + 1] : 1.0f)) { + if (min_compute_capability > ggml_sycl_info().devices[i].cc) { + min_compute_capability = ggml_sycl_info().devices[i].cc; + } + if (max_compute_capability < ggml_sycl_info().devices[i].cc) { + max_compute_capability = ggml_sycl_info().devices[i].cc; + } + } + } + + switch(type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + return max_compute_capability >= VER_GEN9 ? 128 : 64; + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + return 64; + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return 1; + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + return max_compute_capability >= VER_GEN9 ? 128 : 64; + case GGML_TYPE_IQ3_S: + return max_compute_capability >= VER_GEN9 ? 128 : 64; + case GGML_TYPE_Q6_K: + return 64; + default: + GGML_ABORT("fatal error"); + } +} + +static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array & tensor_split, int id) { + const int64_t nrows = ggml_nrows(tensor); + const int64_t rounding = get_row_rounding(tensor->type, tensor_split); + + *row_low = id == 0 ? 0 : nrows*tensor_split[id]; + *row_low -= *row_low % rounding; + if (id == ggml_sycl_info().device_count - 1) { + *row_high = nrows; + } else { + *row_high = nrows*tensor_split[id + 1]; + *row_high -= *row_high % rounding; + } +} + +static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]); +} + +struct ggml_backend_sycl_split_buffer_type_context { + std::array tensor_split; +}; + +struct ggml_backend_sycl_split_buffer_context { + ~ggml_backend_sycl_split_buffer_context() try { + for (ggml_tensor_extra_gpu * extra : tensor_extras) { + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) { + if (extra->events[i][is] != nullptr) { + /* + DPCT1009:206: SYCL uses exceptions to report errors and + does not use the error codes. The original code was + commented out and a warning string was inserted. You + need to rewrite this code. + */ + SYCL_CHECK(CHECK_TRY_ERROR( + dpct::destroy_event(extra->events[i][is]))); + } + } + if (extra->data_device[i] != nullptr) { + /* + DPCT1009:207: SYCL uses exceptions to report errors and does + not use the error codes. The original code was commented out + and a warning string was inserted. You need to rewrite this + code. + */ + ggml_sycl_set_device(i); + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free( + extra->data_device[i], *(streams[i])))); + } + } + delete extra; + } + } + catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); + } + + std::vector tensor_extras; + std::vector streams; +}; + +static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; + delete ctx; +} + +static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) { + // the pointers are stored in the tensor extras, this is just a dummy address and never dereferenced + return (void *)0x1000; + + GGML_UNUSED(buffer); +} + +static void +ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, + ggml_tensor *tensor) try { + GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported + + ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; + ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context; + + const int64_t ne0 = tensor->ne[0]; + + ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; + + ctx->tensor_extras.push_back(extra); + ctx->streams.push_back(&(dpct::get_current_device().default_queue())); + + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + int64_t row_low, row_high; + get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i); + + int64_t nrows_split = row_high - row_low; + if (nrows_split == 0) { + continue; + } + + size_t size = ggml_nbytes_split(tensor, nrows_split); + const size_t original_size = size; + + // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses + if (ne0 % MATRIX_ROW_PADDING != 0) { + size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + } + + // FIXME: do not crash if SYCL Buffer alloc fails + // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first + ggml_sycl_set_device(i); + const queue_ptr stream = ctx->streams[i]; + char * buf; + /* + DPCT1009:208: SYCL uses exceptions to report errors and does not use the + error codes. The original code was commented out and a warning string + was inserted. You need to rewrite this code. + */ + SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)sycl::malloc_device( + size, *stream))); + if (!buf) { + char err_buf[1024]; + snprintf(err_buf, 1023, "%s: can't allocate %lu Bytes of memory on device\n", __func__, size); + throw std::runtime_error(err_buf); + } + // set padding to 0 to avoid possible NaN values + if (size > original_size) { + /* + DPCT1009:209: SYCL uses exceptions to report errors and does not use + the error codes. The original code was commented out and a warning + string was inserted. You need to rewrite this code. + */ + SYCL_CHECK(CHECK_TRY_ERROR( + (*stream) + .memset(buf + original_size, 0, size - original_size) + .wait())); + } + + extra->data_device[i] = buf; + + for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) { + /* + DPCT1009:210: SYCL uses exceptions to report errors and does not use + the error codes. The original code was commented out and a warning + string was inserted. You need to rewrite this code. + */ + SYCL_CHECK( + CHECK_TRY_ERROR(extra->events[i][is] = new sycl::event())); + } + } + tensor->extra = extra; +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void +ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer, + ggml_tensor *tensor, const void *data, + size_t offset, size_t size) try { + // split tensors must always be set in their entirety at once + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + + ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; + ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context; + + const int64_t ne0 = tensor->ne[0]; + const size_t nb1 = tensor->nb[1]; + ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra; + + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + int64_t row_low, row_high; + get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i); + + int64_t nrows_split = row_high - row_low; + if (nrows_split == 0) { + continue; + } + + const size_t offset_split = row_low*nb1; + size_t size = ggml_nbytes_split(tensor, nrows_split); + const size_t original_size = size; + + // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses + if (ne0 % MATRIX_ROW_PADDING != 0) { + size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + } + + const char * buf_host = (const char *)data + offset_split; + /* + DPCT1009:211: SYCL uses exceptions to report errors and does not use the + error codes. The original code was commented out and a warning string + was inserted. You need to rewrite this code. + */ + ggml_sycl_set_device(i); + const queue_ptr stream = ctx->streams[i]; + SYCL_CHECK(CHECK_TRY_ERROR( + (*stream) + .memcpy(extra->data_device[i], buf_host, original_size) + .wait())); + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void +ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor *tensor, void *data, + size_t offset, size_t size) try { + // split tensors must always be set in their entirety at once + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + + ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; + ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context; + + const int64_t ne0 = tensor->ne[0]; + const size_t nb1 = tensor->nb[1]; + ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra; + + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + int64_t row_low, row_high; + get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i); + + int64_t nrows_split = row_high - row_low; + if (nrows_split == 0) { + continue; + } + + const size_t offset_split = row_low*nb1; + size_t size = ggml_nbytes_split(tensor, nrows_split); + const size_t original_size = size; + + // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses + if (ne0 % MATRIX_ROW_PADDING != 0) { + size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + } + + char * buf_host = (char *)data + offset_split; + /* + DPCT1009:212: SYCL uses exceptions to report errors and does not use the + error codes. The original code was commented out and a warning string + was inserted. You need to rewrite this code. + */ + ggml_sycl_set_device(i); + const queue_ptr stream = ctx->streams[i]; + SYCL_CHECK(CHECK_TRY_ERROR( + (*stream) + .memcpy(buf_host, extra->data_device[i], original_size) + .wait())); + } +} +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_backend_sycl_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + GGML_UNUSED(buffer); + GGML_UNUSED(value); +} + +static struct ggml_backend_buffer_i ggml_backend_sycl_split_buffer_interface = { + /* .free_buffer = */ ggml_backend_sycl_split_buffer_free_buffer, + /* .get_base = */ ggml_backend_sycl_split_buffer_get_base, + /* .init_tensor = */ ggml_backend_sycl_split_buffer_init_tensor, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ ggml_backend_sycl_split_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_sycl_split_buffer_get_tensor, + /* .cpy_tensor = */ NULL, + /* .clear = */ ggml_backend_sycl_split_buffer_clear, + /* .reset = */ NULL, +}; + +// sycl split buffer type + +static const char * ggml_backend_sycl_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return GGML_SYCL_NAME "_Split"; + + GGML_UNUSED(buft); +} + +static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) { + return buffer->buft->iface.get_name == ggml_backend_sycl_split_buffer_type_get_name; +} + +static ggml_backend_buffer_t ggml_backend_sycl_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point + // instead, we allocate them for each tensor separately in init_tensor + // however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated, + // as returned by get_alloc_size. this limit is enforced during tensor allocation by ggml-alloc, so it must be correct. + ggml_backend_sycl_split_buffer_context * ctx = new ggml_backend_sycl_split_buffer_context(); + + return ggml_backend_buffer_init(buft, ggml_backend_sycl_split_buffer_interface, ctx, size); +} + +static size_t ggml_backend_sycl_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return 128; + GGML_UNUSED(buft); +} + +static size_t ggml_backend_sycl_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + ggml_backend_sycl_split_buffer_type_context * ctx = (ggml_backend_sycl_split_buffer_type_context *)buft->context; + + size_t total_size = 0; + + const int64_t ne0 = tensor->ne[0]; + + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + int64_t row_low, row_high; + get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, i); + + int64_t nrows_split = row_high - row_low; + if (nrows_split == 0) { + continue; + } + + total_size += ggml_nbytes_split(tensor, nrows_split); + + // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses + if (ne0 % MATRIX_ROW_PADDING != 0) { + total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + } + } + + return total_size; +} + +static bool ggml_backend_sycl_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + return false; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface = { + /* .get_name = */ ggml_backend_sycl_split_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_sycl_split_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_sycl_split_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ ggml_backend_sycl_split_buffer_type_get_alloc_size, + /* .is_host = */ ggml_backend_sycl_split_buffer_type_is_host, +}; + +ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) { + static std::mutex mutex; + std::lock_guard lock(mutex); + + GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_split_buffer_type\n"); + ggml_check_sycl(); + // FIXME: this is not thread safe + static std::map, struct ggml_backend_buffer_type> buft_map; + + std::array tensor_split_arr = {}; + + bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_SYCL_MAX_DEVICES, [](float x) { return x == 0.0f; }); + if (all_zero) { + tensor_split_arr = ggml_sycl_info().default_tensor_split; + } else { + float split_sum = 0.0f; + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + tensor_split_arr[i] = split_sum; + split_sum += tensor_split[i]; + } + for (int i = 0; i < ggml_sycl_info().device_count; ++i) { + tensor_split_arr[i] /= split_sum; + } + } + + auto it = buft_map.find(tensor_split_arr); + if (it != buft_map.end()) { + return &it->second; + } + + struct ggml_backend_buffer_type buft { + /* .iface = */ ggml_backend_sycl_split_buffer_type_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 0), + /* .context = */ new ggml_backend_sycl_split_buffer_type_context{tensor_split_arr}, + }; + + auto result = buft_map.emplace(tensor_split_arr, buft); + return &result.first->second; +} + +// host buffer type + +static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_type_t buft) { + return GGML_SYCL_NAME "_Host"; + + GGML_UNUSED(buft); +} + +static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_sycl_host_free(buffer->context); +} + +static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + void * ptr = ggml_sycl_host_malloc(size); + + if (ptr == nullptr) { + // fallback to cpu buffer + return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); + } + + // FIXME: this is a hack to avoid having to implement a new buffer type + ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); + buffer->buft = buft; + buffer->iface.free_buffer = ggml_backend_sycl_host_buffer_free_buffer; + + return buffer; +} + +ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type() { + GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_host_buffer_type\n"); + static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_type_host = { + /* .iface = */ { + /* .get_name = */ ggml_backend_sycl_host_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_sycl_host_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment, + /* .get_max_size = */ NULL, // TODO: return device.maxBufferLength + /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, + /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 0), + /* .context = */ nullptr, + }; + + return &ggml_backend_sycl_buffer_type_host; +} + +// buffer pool for sycl (legacy) +struct ggml_sycl_pool_leg : public ggml_sycl_pool { + static const int MAX_SYCL_BUFFERS = 256; + + int device; + queue_ptr qptr; + struct ggml_sycl_buffer { + void * ptr = nullptr; + size_t size = 0; + }; + + ggml_sycl_buffer buffer_pool[MAX_SYCL_BUFFERS] = {}; + size_t pool_size = 0; + + explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) : device(device_), qptr(qptr_) {} + + ~ggml_sycl_pool_leg() { + for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { + ggml_sycl_buffer & b = buffer_pool[i]; + if (b.ptr != nullptr) { + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr))); + pool_size -= b.size; + } + } + GGML_ASSERT(pool_size == 0); + } + + void * alloc(size_t size, size_t * actual_size) override { +#ifdef DEBUG_sycl_MALLOC + int nnz = 0; + size_t max_size = 0; +#endif + size_t best_diff = 1ull << 36; + int ibest = -1; + for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { + ggml_sycl_buffer& b = buffer_pool[i]; + if (b.ptr != nullptr) { +#ifdef DEBUG_sycl_MALLOC + ++nnz; + if (b.size > max_size) max_size = b.size; +#endif + if (b.size >= size) { + size_t diff = b.size - size; + if (diff < best_diff) { + best_diff = diff; + ibest = i; + if (!best_diff) { + void * ptr = b.ptr; + *actual_size = b.size; + b.ptr = nullptr; + b.size = 0; + return ptr; + } + } + } + } + } + if (ibest >= 0) { + ggml_sycl_buffer& b = buffer_pool[ibest]; + void * ptr = b.ptr; + *actual_size = b.size; + b.ptr = nullptr; + b.size = 0; + return ptr; + } + void * ptr; + size_t look_ahead_size = (size_t) (1.05 * size); + + SYCL_CHECK( + CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_device( + look_ahead_size, *qptr))); + if (!ptr) { + GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device/GPU\n", __func__, look_ahead_size); + return nullptr; + } + + *actual_size = look_ahead_size; + pool_size += look_ahead_size; + +#ifdef DEBUG_SYCL_MALLOC + GGML_LOG_DEBUG("%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz, + (uint32_t)(max_size/1024/1024), (uint32_t)(g_sycl_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024)); +#endif + + // GGML_SYCL_DEBUG("ggml_sycl_pool_malloc_leg look_ahead_size=%lu, return %p\n", look_ahead_size, ptr); + return ptr; + } + + void free(void * ptr, size_t size) override { + for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { + ggml_sycl_buffer& b = buffer_pool[i]; + if (b.ptr == nullptr) { + b.ptr = ptr; + b.size = size; + return; + } + } + GGML_LOG_WARN("WARNING: sycl buffer pool full, increase MAX_sycl_BUFFERS\n"); + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr))); + pool_size -= size; + } +}; + +struct ggml_sycl_pool_host : public ggml_sycl_pool { + queue_ptr qptr; + int device; + + inline static int counter{ 0 }; + + struct ggml_sycl_buffer { + void * ptr = nullptr; + size_t size = 0; + }; + + // Set arbitrarly to 64 + static constexpr int MAX_POOL_SIZE{ 64 }; + std::vector buffer_pool = std::vector(MAX_POOL_SIZE); + size_t pool_size = 0; + + explicit ggml_sycl_pool_host(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {} + + ~ggml_sycl_pool_host() { + for (int i = 0; i < MAX_POOL_SIZE; ++i) { + ggml_sycl_buffer & b = buffer_pool[i]; + if (b.ptr != nullptr) { + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr))); + b.ptr = nullptr; + pool_size -= b.size; + b.size = 0; + } + } + counter = 0; + } + + void * alloc(size_t size, size_t * actual_size) override { + if (counter == MAX_POOL_SIZE) { + ggml_sycl_buffer b = buffer_pool[0]; + void * ptr = b.ptr; + *actual_size = b.size; + counter = 1; + return ptr; + } + ggml_sycl_buffer & b = buffer_pool[counter]; + + if (b.ptr == nullptr) { + void * ptr; + + SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *) sycl::malloc_host(size, *qptr))); + if (!ptr) { + GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on host\n", __func__, size); + return nullptr; + } + pool_size += size; + *actual_size = size; + counter = counter + 1; + return ptr; + } else { + ++counter; + b.size = size; + return b.ptr; + } + } + + void free(void * ptr, size_t size) override { + // if the pool is not completed add the pointer to it in place of the first nullptr found. + // Otherwise do nothing, pointers will be freed once the pool is deallocated. + for (int i = 0; i < MAX_POOL_SIZE; ++i) { + ggml_sycl_buffer & b = buffer_pool[i]; + if (b.ptr == nullptr) { + b.ptr = ptr; + b.size = size; + return; + } + } + } +}; + +std::unique_ptr ggml_backend_sycl_context::new_pool_for_host(queue_ptr qptr, int device) { + // return pool for the host to speed up memory management + return std::unique_ptr(new ggml_sycl_pool_host(qptr, device)); +} + +std::unique_ptr ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) { + // TBD: NO VMM support + // if (ggml_sycl_info().devices[device].vmm) { + // return std::unique_ptr(new ggml_sycl_pool_vmm(device)); + // } + return std::unique_ptr(new ggml_sycl_pool_leg(qptr, device)); +} + +// TBD pool with virtual memory management +// struct ggml_sycl_pool_vmm : public ggml_sycl_pool + +/// kernels + typedef void (*cpy_kernel_t)(const char * cx, char * cdst); -typedef void (*ggml_sycl_func_t)(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); typedef void (*ggml_sycl_op_mul_mat_t)( ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, @@ -69,272 +1273,8 @@ typedef void (*ggml_sycl_op_mul_mat_t)( float *dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, const queue_ptr &stream); -typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream); -static __dpct_inline__ float op_repeat(const float a, const float b) { - return b; - GGML_UNUSED(a); -} -static __dpct_inline__ float op_add(const float a, const float b) { - return a + b; -} - -static __dpct_inline__ float op_mul(const float a, const float b) { - return a * b; -} - -static __dpct_inline__ float op_div(const float a, const float b) { - return a / b; -} - -template -static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, - int ne0, int ne1, int ne2, int ne3, - int ne10, int ne11, int ne12, int ne13, - /*int s0, */ int s1, int s2, int s3, - /*int s10,*/ int s11, int s12, int s13, - const sycl::nd_item<3> &item_ct1) { - const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) + - item_ct1.get_local_id(1)); - const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + - item_ct1.get_local_id(0)) / - ne3; - const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + - item_ct1.get_local_id(0)) % - ne3; - - if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { - return; - } - - const int i11 = i1 % ne11; - const int i12 = i2 % ne12; - const int i13 = i3 % ne13; - - const size_t i_src0 = i3*s3 + i2*s2 + i1*s1; - const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; - const size_t i_dst = i_src0; - - const src0_t * src0_row = src0 + i_src0; - const src1_t * src1_row = src1 + i_src1; - dst_t * dst_row = dst + i_dst; - - for (int i0 = i0s; i0 < ne0; - i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) { - const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); - } -} - -template -static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, - int ne0, int ne1, int ne2, int ne3, - int ne10, int ne11, int ne12, int ne13, - /*int s0, */ int s1, int s2, int s3, - /*int s10,*/ int s11, int s12, int s13, - const sycl::nd_item<3> &item_ct1) { - - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - const int i3 = i/(ne2*ne1*ne0); - const int i2 = (i/(ne1*ne0)) % ne2; - const int i1 = (i/ne0) % ne1; - const int i0 = i % ne0; - - if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { - return; - } - - const int i11 = i1 % ne11; - const int i12 = i2 % ne12; - const int i13 = i3 % ne13; - - const size_t i_src0 = i3*s3 + i2*s2 + i1*s1; - const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; - const size_t i_dst = i_src0; - - const src0_t * src0_row = src0 + i_src0; - const src1_t * src1_row = src1 + i_src1; - dst_t * dst_row = dst + i_dst; - - const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); -} - -static void acc_f32(const float * x, const float * y, float * dst, const int ne, - const int ne10, const int ne11, const int ne12, - const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - if (i >= ne) { - return; - } - int src1_idx = i - offset; - int oz = src1_idx / nb2; - int oy = (src1_idx - (oz * nb2)) / nb1; - int ox = src1_idx % nb1; - if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) { - dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11]; - } else { - dst[i] = x[i]; - } -} - -static void gelu_f32(const float * x, float * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const float GELU_COEF_A = 0.044715f; - const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - - float xi = x[i]; - dst[i] = 0.5f * xi * - (1.0f + - sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi))); -} - -static void silu_f32(const float * x, float * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i])); -} - -static void gelu_quick_f32(const float *x, float *dst, int k, - const sycl::nd_item<3> &item_ct1) { - const float GELU_QUICK_COEF = -1.702f; - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - if (i >= k) { - return; - } - dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i]))); -} - -static void tanh_f32(const float *x, float *dst, int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - if (i >= k) { - return; - } - dst[i] = sycl::tanh((float)(x[i])); -} - -static void relu_f32(const float * x, float * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - dst[i] = sycl::fmax((float)(x[i]), (float)0); -} - -static void hardsigmoid_f32(const float * x, float * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f)); -} - -static void hardswish_f32(const float * x, float * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f)); -} - -static void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - if (i >= k) { - return; - } - dst[i] = sycl::fmax((float)(x[i]), (float)0) + - sycl::fmin((float)(x[i]), 0.0f) * negative_slope; -} - -static void sqr_f32(const float * x, float * dst, const int k, - const sycl::nd_item<3> &item_ct1) { - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - if (i >= k) { - return; - } - dst[i] = x[i] * x[i]; -} - -static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01, - const int nb02, const int nb03, const int ne10, const int ne11, - const int ne12, const int ne13, const float sf0, const float sf1, - const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) { - int index = item_ct1.get_local_id(0) + - item_ct1.get_group(0) * item_ct1.get_local_range(0); - if (index >= ne10 * ne11 * ne12 * ne13) { - return; - } - // operation - int i10 = index % ne10; - int i11 = (index / ne10) % ne11; - int i12 = (index / (ne10 * ne11)) % ne12; - int i13 = (index / (ne10 * ne11 * ne12)) % ne13; - - int i00 = i10 / sf0; - int i01 = i11 / sf1; - int i02 = i12 / sf2; - int i03 = i13 / sf3; - - dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); -} - -static void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02, - const sycl::nd_item<3> &item_ct1) { - int nidx = item_ct1.get_local_id(2) + - item_ct1.get_group(2) * item_ct1.get_local_range(2); - if (nidx >= ne0) { - return; - } - - // operation - int offset_dst = nidx + item_ct1.get_group(1) * ne0 + - item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1); - if (nidx < ne00 && item_ct1.get_group(1) < ne01 && - item_ct1.get_group(0) < ne02) { - int offset_src = nidx + item_ct1.get_group(1) * ne00 + - item_ct1.get_group(0) * ne00 * ne01; - dst[offset_dst] = x[offset_src]; - } else { - dst[offset_dst] = 0.0f; - } -} template static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded, @@ -365,7 +1305,7 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, zeros[i] = 0.f; qzeros[i] = 0; } - const TC xi = ix < kx ? *(TC *)&x[iy * kx + ix] : zeros; + const TC xi = ix < kx ? *(const TC *)&x[iy * kx + ix] : zeros; float sum = xi[0]; float amax = sycl::fabs(xi[0]); #pragma unroll @@ -926,6 +1866,9 @@ static void pool2d_nchw_kernel( switch (op) { case GGML_OP_POOL_AVG: res = 0; break; case GGML_OP_POOL_MAX: res = -FLT_MAX; break; + default: + res = (To) sycl::nan(uint32_t(0)); + break; } for (int i = bh; i < eh; i += 1) { @@ -944,6 +1887,9 @@ static void pool2d_nchw_kernel( switch (op) { case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break; case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break; + default: + res = (To) sycl::nan(uint32_t(0)); + break; } } } @@ -982,7 +1928,8 @@ static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *sr s3, nb01, nb02, nb03, s10, s11, s12, item_ct1); }); - (void) dst; + GGML_UNUSED(dst); + GGML_UNUSED(ctx); } template @@ -1020,299 +1967,8 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens }); } - (void) dst; -} - -template -struct bin_bcast_sycl { - template - void operator()(ggml_backend_sycl_context & ctx, - const struct ggml_tensor *src0, - const struct ggml_tensor *src1, struct ggml_tensor *dst, - const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd, - queue_ptr stream) { - - GGML_TENSOR_BINARY_OP_LOCALS - - int nr0 = ne10/ne0; - int nr1 = ne11/ne1; - int nr2 = ne12/ne2; - int nr3 = ne13/ne3; - - int nr[4] = { nr0, nr1, nr2, nr3 }; - - // collapse dimensions until first broadcast dimension - int64_t cne0[] = {ne0, ne1, ne2, ne3}; - int64_t cne1[] = {ne10, ne11, ne12, ne13}; - size_t cnb0[] = {nb0, nb1, nb2, nb3}; - size_t cnb1[] = {nb10, nb11, nb12, nb13}; - auto collapse = [](int64_t cne[]) { - cne[0] *= cne[1]; - cne[1] = cne[2]; - cne[2] = cne[3]; - cne[3] = 1; - }; - - auto collapse_nb = [](size_t cnb[], int64_t cne[]) { - cnb[1] *= cne[1]; - cnb[2] *= cne[2]; - cnb[3] *= cne[3]; - }; - - for (int i = 0; i < 4; i++) { - if (nr[i] != 1) { - break; - } - if (i > 0) { - collapse_nb(cnb0, cne0); - collapse_nb(cnb1, cne1); - collapse(cne0); - collapse(cne1); - } - } - { - int64_t ne0 = cne0[0]; - int64_t ne1 = cne0[1]; - int64_t ne2 = cne0[2]; - int64_t ne3 = cne0[3]; - - int64_t ne10 = cne1[0]; - int64_t ne11 = cne1[1]; - int64_t ne12 = cne1[2]; - int64_t ne13 = cne1[3]; - - size_t nb0 = cnb0[0]; - size_t nb1 = cnb0[1]; - size_t nb2 = cnb0[2]; - size_t nb3 = cnb0[3]; - - size_t nb10 = cnb1[0]; - size_t nb11 = cnb1[1]; - size_t nb12 = cnb1[2]; - size_t nb13 = cnb1[3]; - - size_t s0 = nb0 / sizeof(dst_t); - size_t s1 = nb1 / sizeof(dst_t); - size_t s2 = nb2 / sizeof(dst_t); - size_t s3 = nb3 / sizeof(dst_t); - - size_t s10 = nb10 / sizeof(src1_t); - size_t s11 = nb11 / sizeof(src1_t); - size_t s12 = nb12 / sizeof(src1_t); - size_t s13 = nb13 / sizeof(src1_t); - - GGML_ASSERT(s0 == 1); - GGML_ASSERT(s10 == 1); - - const int block_size = 128; - - int64_t hne0 = std::max(ne0/2LL, 1LL); - - sycl::range<3> block_dims(1, 1, 1); - block_dims[2] = std::min(hne0, block_size); - block_dims[1] = std::min( - ne1, block_size / (unsigned int)block_dims[2]); - block_dims[0] = std::min( - std::min( - ne2 * ne3, block_size / (unsigned int)block_dims[2] / - (unsigned int)block_dims[1]), - 64U); - - sycl::range<3> block_nums( - (ne2 * ne3 + block_dims[0] - 1) / block_dims[0], - (ne1 + block_dims[1] - 1) / block_dims[1], - (hne0 + block_dims[2] - 1) / block_dims[2]); - - if (block_nums[0] > 65535) { - // this is the maximum number of blocks in z direction, fallback to 1D grid kernel - int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * - sycl::range<3>(1, 1, block_size), - sycl::range<3>(1, 1, block_size)), - [=](sycl::nd_item<3> item_ct1) { - k_bin_bcast_unravel( - src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12, - s13, item_ct1); - }); - } - } else { - /* - DPCT1049:16: The work-group size passed to the SYCL kernel may - exceed the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if - needed. - */ - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_bin_bcast(src0_dd, src1_dd, dst_dd, ne0, ne1, - ne2, ne3, ne10, ne11, ne12, ne13, - s1, s2, s3, s11, s12, s13, - item_ct1); - }); - } - } - } -}; - -static void acc_f32_sycl(const float *x, const float *y, float *dst, - const int n_elements, const int ne10, const int ne11, - const int ne12, const int nb1, const int nb2, - const int offset, queue_ptr stream) { - int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset, - item_ct1); - }); -} - -static void gelu_f32_sycl(const float *x, float *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - gelu_f32(x, dst, k, item_ct1); - }); -} - -static void silu_f32_sycl(const float *x, float *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - silu_f32(x, dst, k, item_ct1); - }); -} - -static void gelu_quick_f32_sycl(const float *x, float *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - gelu_quick_f32(x, dst, k, item_ct1); - }); -} - -static void tanh_f32_sycl(const float *x, float *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - tanh_f32(x, dst, k, item_ct1); - }); -} - -static void relu_f32_sycl(const float *x, float *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - relu_f32(x, dst, k, item_ct1); - }); -} - -static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - hardsigmoid_f32(x, dst, k, item_ct1); - }); -} - -static void hardswish_f32_sycl(const float *x, float *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - hardswish_f32(x, dst, k, item_ct1); - }); -} - -static void leaky_relu_f32_sycl(const float *x, float *dst, const int k, - const float negative_slope, - queue_ptr stream) { - const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - leaky_relu_f32(x, dst, k, negative_slope, item_ct1); - }); -} - -static void sqr_f32_sycl(const float *x, float *dst, const int k, - queue_ptr stream) { - const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE; - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * - sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - sqr_f32(x, dst, k, item_ct1); - }); -} - -static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01, - const int nb02, const int nb03, const int ne10, const int ne11, - const int ne12, const int ne13, const float sf0, const float sf1, - const float sf2, const float sf3, queue_ptr stream) { - int dst_size = ne10 * ne11 * ne12 * ne13; - int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; - sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE); - stream->parallel_for( - sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { - upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1); - }); -} - -static void pad_f32_sycl(const float *x, float *dst, const int ne00, - const int ne01, const int ne02, const int ne0, - const int ne1, const int ne2, queue_ptr stream) { - int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE; - sycl::range<3> gridDim(ne2, ne1, num_blocks); - stream->parallel_for( - sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), - sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { - pad_f32(x, dst, ne0, ne00, ne01, ne02, item_ct1); - }); + GGML_UNUSED(dst); + GGML_UNUSED(ctx); } static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx, @@ -1691,6 +2347,58 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, } } +static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols, + const int nrows, queue_ptr stream) { + const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE); + const sycl::range<3> block_nums(1, nrows, 1); + const size_t shared_mem = 256 * sizeof(float); + + stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor shared_data( + sycl::range<1>(shared_mem/sizeof(float)), cgh); + sycl::local_accessor shared_indices( + sycl::range<1>(shared_mem/sizeof(float)), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + const int tid = item_ct1.get_local_id(2); + const int row = item_ct1.get_global_id(1); + + float max_val = -INFINITY; + int max_idx = -1; + + for (int col = tid; col < ncols; col += 256) { + float val = x[row * ncols + col]; + if (val > max_val) { + max_val = val; + max_idx = col; + } + } + + shared_data[tid] = max_val; + shared_indices[tid] = max_idx; + item_ct1.barrier(sycl::access::fence_space::local_space); + + for (int stride = 256/2; stride > 0; stride >>= 1) { + if (tid < stride) { + float val1 = shared_data[tid]; + float val2 = shared_data[tid + stride]; + if (val2 > val1) { + shared_data[tid] = val2; + shared_indices[tid] = shared_indices[tid + stride]; + } + } + item_ct1.barrier(sycl::access::fence_space::local_space); + } + + + if (tid == 0) { + dst[row] = shared_indices[0]; + } + }); + }); +} static void diag_mask_inf_f32_sycl(const float *x, float *dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, @@ -1706,296 +2414,6 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst, }); } -static bool g_sycl_loaded = false; - -bool ggml_sycl_loaded(void) { - return g_sycl_loaded; -} - -void print_device_detail(int id, sycl::device &device, std::string device_type) { - - dpct::device_info prop; - SYCL_CHECK(CHECK_TRY_ERROR( - dpct::get_device_info(prop, device))); - - std::string version; - version += std::to_string(prop.get_major_version()); - version += "."; - version += std::to_string(prop.get_minor_version()); - - device_type = std::regex_replace(device_type, std::regex("ext_oneapi_"), ""); - std::string name = std::string(prop.get_name()); - name = std::regex_replace(name, std::regex("\\(R\\)"), ""); - name = std::regex_replace(name, std::regex("\\(TM\\)"), ""); - - auto global_mem_size = prop.get_global_mem_size()/1000000; - - fprintf(stderr, "|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(), - name.c_str(), version.c_str(), prop.get_max_compute_units(), - prop.get_max_work_group_size(), prop.get_max_sub_group_size(), - global_mem_size, device.get_info().c_str()); -} - -void ggml_backend_sycl_print_sycl_devices() { - GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n"); - int device_count = dpct::dev_mgr::instance().device_count(); - std::map DeviceNums; - fprintf(stderr, "found %d SYCL devices:\n", device_count); - fprintf(stderr, "| | | | |Max | |Max |Global | |\n"); - fprintf(stderr, "| | | | |compute|Max work|sub |mem | |\n"); - fprintf(stderr, "|ID| Device Type| Name|Version|units |group |group|size | Driver version|\n"); - fprintf(stderr, "|--|-------------------|---------------------------------------|-------|-------|--------|-----|-------|---------------------|\n"); - for (int id = 0; id < device_count; ++id) { - sycl::device device = dpct::dev_mgr::instance().get_device(id); - sycl::backend backend = device.get_backend(); - std::string backend_type = get_device_backend_and_type(device); - int type_id=DeviceNums[backend_type]++; - std::stringstream device_type; - device_type << "[" << backend_type << ":" << std::to_string(type_id) << "]"; - print_device_detail(id, device, device_type.str()); - } -} - -static inline int get_sycl_env(const char *env_name, int default_val) { - char *user_device_string = getenv(env_name); - int user_number = default_val; - - unsigned n; - if (user_device_string != NULL && - sscanf(user_device_string, " %u", &n) == 1) { - user_number = (int)n; - } else { - user_number = default_val; - } - return user_number; -} - -static void ggml_check_sycl() try { - static bool initialized = false; - - if (!initialized) { - fprintf(stderr, "[SYCL] call ggml_check_sycl\n"); - g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0); - - fprintf(stderr, "%s: GGML_SYCL_DEBUG: %d\n", __func__, g_ggml_sycl_debug); - -#if defined(GGML_SYCL_F16) - fprintf(stderr, "%s: GGML_SYCL_F16: yes\n", __func__); -#else - fprintf(stderr, "%s: GGML_SYCL_F16: no\n", __func__); -#endif - -/* NOT REMOVE, keep it for next optimize for XMX. -#if defined(SYCL_USE_XMX) - fprintf(stderr, "%s: SYCL_USE_XMX: yes\n", __func__); -#else - fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__); -#endif -*/ - - if (CHECK_TRY_ERROR(g_all_sycl_device_count = - dpct::dev_mgr::instance().device_count()) != 0) { - initialized = true; - g_sycl_loaded = false; - return; - } - GGML_ASSERT(g_all_sycl_device_count <= GGML_SYCL_MAX_DEVICES); - ggml_backend_sycl_print_sycl_devices(); - initialized = true; - g_sycl_loaded = true; - } -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -static ggml_sycl_device_info ggml_sycl_init() { - ggml_sycl_device_info info = {}; - - info.device_count = dpct::dev_mgr::instance().device_count(); - if (info.device_count == 0) { - fprintf(stderr, "%s: failed to initialize " GGML_SYCL_NAME ": %s\n", __func__); - return info; - } - - GGML_ASSERT(info.device_count <= GGML_SYCL_MAX_DEVICES); - - int64_t total_vram = 0; -#if defined(GGML_SYCL_FORCE_MMQ) - fprintf(stderr, "%s: GGML_SYCL_FORCE_MMQ: yes\n", __func__); -#else - fprintf(stderr, "%s: GGML_SYCL_FORCE_MMQ: no\n", __func__); -#endif -#if defined(SYCL_USE_XMX) - fprintf(stderr, "%s: SYCL_USE_XMX: yes\n", __func__); -#else - fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__); -#endif - fprintf(stderr, "%s: found %d " GGML_SYCL_NAME " devices:\n", __func__, info.device_count); - - for (int i = 0; i < info.device_count; ++i) { - info.devices[i].vmm = 0; - dpct::device_info prop; - SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( - prop, dpct::dev_mgr::instance().get_device(i)))); - - info.default_tensor_split[i] = total_vram; - total_vram += prop.get_global_mem_size(); - - info.devices[i].cc = - 100 * prop.get_major_version() + 10 * prop.get_minor_version(); - - info.max_work_group_sizes[i] = prop.get_max_work_group_size(); - } - - for (int id = 0; id < info.device_count; ++id) { - info.default_tensor_split[id] /= total_vram; - } - return info; -} - -const ggml_sycl_device_info & ggml_sycl_info() { - static ggml_sycl_device_info info = ggml_sycl_init(); - return info; -} - -/* -device_index: device index from 0 to n (continue numbers). - It is used for device select/set in SYCL backend internal data structure. -*/ -inline void check_allow_gpu_index(const int device_index) { - if (device_index >= ggml_sycl_info().device_count) { - char error_buf[256]; - snprintf( - error_buf, - sizeof(error_buf), - "%s error: device_index:%d is out of range: [0-%d]", - __func__, - device_index, - ggml_sycl_info().device_count - 1); - fprintf(stderr, "%s\n", error_buf); - assert(false); - } -} - -// buffer pool for sycl (legacy) -struct ggml_sycl_pool_leg : public ggml_sycl_pool { - static const int MAX_SYCL_BUFFERS = 256; - - int device; - queue_ptr qptr; - struct ggml_sycl_buffer { - void * ptr = nullptr; - size_t size = 0; - }; - - ggml_sycl_buffer buffer_pool[MAX_SYCL_BUFFERS] = {}; - size_t pool_size = 0; - - explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) : - qptr(qptr_), - device(device_) { - } - - ~ggml_sycl_pool_leg() { - for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { - ggml_sycl_buffer & b = buffer_pool[i]; - if (b.ptr != nullptr) { - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr))); - pool_size -= b.size; - } - } - GGML_ASSERT(pool_size == 0); - } - - void * alloc(size_t size, size_t * actual_size) override { -#ifdef DEBUG_sycl_MALLOC - int nnz = 0; - size_t max_size = 0; -#endif - size_t best_diff = 1ull << 36; - int ibest = -1; - for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { - ggml_sycl_buffer& b = buffer_pool[i]; - if (b.ptr != nullptr) { -#ifdef DEBUG_sycl_MALLOC - ++nnz; - if (b.size > max_size) max_size = b.size; -#endif - if (b.size >= size) { - size_t diff = b.size - size; - if (diff < best_diff) { - best_diff = diff; - ibest = i; - if (!best_diff) { - void * ptr = b.ptr; - *actual_size = b.size; - b.ptr = nullptr; - b.size = 0; - return ptr; - } - } - } - } - } - if (ibest >= 0) { - ggml_sycl_buffer& b = buffer_pool[ibest]; - void * ptr = b.ptr; - *actual_size = b.size; - b.ptr = nullptr; - b.size = 0; - return ptr; - } - void * ptr; - size_t look_ahead_size = (size_t) (1.05 * size); - - SYCL_CHECK( - CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_device( - look_ahead_size, *qptr))); - if (!ptr) { - fprintf(stderr, "%s: can't malloc %lu Bytes memory on device", __func__, look_ahead_size); - return nullptr; - } - - *actual_size = look_ahead_size; - pool_size += look_ahead_size; - - #ifdef DEBUG_SYCL_MALLOC - fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz, - (uint32_t)(max_size/1024/1024), (uint32_t)(g_sycl_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024)); - #endif - // GGML_SYCL_DEBUG("ggml_sycl_pool_malloc_leg look_ahead_size=%lu, return %p\n", look_ahead_size, ptr); - return ptr; - } - - void free(void * ptr, size_t size) override { - for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { - ggml_sycl_buffer& b = buffer_pool[i]; - if (b.ptr == nullptr) { - b.ptr = ptr; - b.size = size; - return; - } - } - fprintf(stderr, "WARNING: sycl buffer pool full, increase MAX_sycl_BUFFERS\n"); - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr))); - pool_size -= size; - } -}; - -std::unique_ptr ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) { - // TBD: NO VMM support - // if (ggml_sycl_info().devices[device].vmm) { - // return std::unique_ptr(new ggml_sycl_pool_vmm(device)); - // } - return std::unique_ptr(new ggml_sycl_pool_leg(qptr, device)); -} - -// TBD pool with virtual memory management -// struct ggml_sycl_pool_vmm : public ggml_sycl_pool - static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst, const struct ggml_tensor *src, int64_t i3, int64_t i2, @@ -2004,12 +2422,22 @@ static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst, dpct::memcpy_direction kind; char * src_ptr; - if (src->backend == GGML_BACKEND_TYPE_CPU) { + if (ggml_backend_buffer_is_host(src->buffer)) { kind = dpct::host_to_device; + //GGML_SYCL_DEBUG("%s: Host buffer type src tensor\n", __func__); src_ptr = (char *) src->data; // GGML_SYCL_DEBUG("ggml_sycl_cpy_tensor_2d GGML_BACKEND_TYPE_CPU src_ptr %p\n", src_ptr); - } else if (src->backend == GGML_BACKEND_TYPE_GPU || src->backend == GGML_BACKEND_TYPE_GPU_SPLIT) { - GGML_ASSERT(src->backend != GGML_BACKEND_TYPE_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1])); + } else if (ggml_backend_buffer_is_sycl(src->buffer)) { + // If buffer is a SYCL buffer + //GGML_SYCL_DEBUG("%s: SYCL buffer type src tensor\n", __func__); + kind = dpct::device_to_device; + src_ptr = (char *) src->data; + } else if (ggml_backend_buffer_is_sycl_split(src->buffer)) { + /* + If buffer is a SYCL split buffer + */ + //GGML_SYCL_DEBUG("%s: Split buffer type src tensor\n", __func__); + GGML_ASSERT(i1_low == 0 && i1_high == src->ne[1]); kind = dpct::device_to_device; ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra; int id; @@ -2105,39 +2533,12 @@ static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_te break; default: // TODO: k-quants - fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type)); + GGML_LOG_ERROR("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type)); GGML_ABORT("fatal error"); break; } } -template -inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { - - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); - } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { - op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, - (sycl::half *)dst_dd, main_stream); - } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { - op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd, - main_stream); - } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { - op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd, - main_stream); - } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) { - op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd, - main_stream); - } else { - fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, - ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); - GGML_ABORT("fatal error"); - } -} static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, @@ -2147,282 +2548,10 @@ static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tens ggml_sycl_op_bin_bcast>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream); - (void) src1; - (void) src1_d; + GGML_UNUSED(src1); + GGML_UNUSED(src1_d); } -inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { - - ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); -} - -inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported - - int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 - int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 - // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused - int offset = dst->op_params[3] / 4; // offset in bytes - - acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream); - - (void) dst; -} - -inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { - - ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); -} - -inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { - - ggml_sycl_op_bin_bcast>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); -} - -inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -static void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -static void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - float negative_slope; - memcpy(&negative_slope, dst->op_params, sizeof(float)); - - leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - const float sf0 = (float)dst->ne[0]/src0->ne[0]; - const float sf1 = (float)dst->ne[1]/src0->ne[1]; - const float sf2 = (float)dst->ne[2]/src0->ne[2]; - const float sf3 = (float)dst->ne[3]/src0->ne[3]; - - upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], - dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, - main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, - ggml_tensor *dst, const float *src0_dd, - const float *src1_dd, float *dst_dd, - const queue_ptr &main_stream) { - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors - - pad_f32_sycl(src0_dd, dst_dd, - src0->ne[0], src0->ne[1], src0->ne[2], - dst->ne[0], dst->ne[1], dst->ne[2], main_stream); - - (void) src1; - (void) dst; - (void) src1_dd; -} - -static int64_t get_row_rounding(ggml_type type, const std::array & tensor_split) { - int64_t min_compute_capability = INT_MAX; - int64_t max_compute_capability = INT_MIN; - for (int i = 0; i < ggml_sycl_info().device_count; ++i) { - if (tensor_split[i] < (i + 1 < ggml_sycl_info().device_count ? tensor_split[i + 1] : 1.0f)) { - if (min_compute_capability > ggml_sycl_info().devices[i].cc) { - min_compute_capability = ggml_sycl_info().devices[i].cc; - } - if (max_compute_capability < ggml_sycl_info().devices[i].cc) { - max_compute_capability = ggml_sycl_info().devices[i].cc; - } - } - } - - switch(type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - return max_compute_capability >= VER_GEN9 ? 128 : 64; - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - return 64; - case GGML_TYPE_F16: - case GGML_TYPE_F32: - return 1; - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ4_NL: - return max_compute_capability >= VER_GEN9 ? 128 : 64; - case GGML_TYPE_IQ3_S: - return max_compute_capability >= VER_GEN9 ? 128 : 64; - case GGML_TYPE_Q6_K: - return 64; - default: - GGML_ABORT("fatal error"); - } - -} inline void ggml_sycl_op_mul_mat_sycl( ggml_backend_sycl_context & ctx, @@ -2439,17 +2568,18 @@ inline void ggml_sycl_op_mul_mat_sycl( const int64_t ne00 = src0->ne[0]; const int64_t ne10 = src1->ne[0]; - const int64_t ne0 = dst->ne[0]; const int64_t row_diff = row_high - row_low; int id; SYCL_CHECK( CHECK_TRY_ERROR(id = get_current_device_id())); - +#if !GGML_SYCL_DNNL + const int64_t ne0 = dst->ne[0]; // the main device has a larger memory buffer to hold the results from all GPUs // ldc == nrows of the matrix that cuBLAS writes into int ldc = id == ctx.device ? ne0 : row_diff; +#endif #ifdef GGML_SYCL_F16 bool use_fp16 = true; // TODO(Yu) SYCL capability check @@ -2486,9 +2616,9 @@ inline void ggml_sycl_op_mul_mat_sycl( : src1_as_f16.get(); ggml_sycl_pool_alloc dst_f16(ctx.pool(), row_diff * src1_ncols); - const sycl::half alpha_f16 = 1.0f; - const sycl::half beta_f16 = 0.0f; #if !GGML_SYCL_DNNL + const sycl::half alpha_f16 = 1.0f; + const sycl::half beta_f16 = 0.0f; SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, @@ -2525,24 +2655,29 @@ inline void ggml_sycl_op_mul_mat_sycl( const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get(); const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get(); - const float alpha = 1.0f; - const float beta = 0.0f; #if !GGML_SYCL_DNNL + const float alpha = 1.0f; + const float beta = 0.0f; +# ifdef GGML_SYCL_NVIDIA SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( - *stream, oneapi::mkl::transpose::trans, - oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, - dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, - src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), + oneapi::mkl::backend_selector{ *stream }, oneapi::mkl::transpose::trans, + oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, + ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc))); +# else + SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( + *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, + dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc))); +# endif #else auto dnnl_stream = ctx.stream_dnnl(stream); DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt(), src0_ddf_i, DnnlGemmWrapper::to_dt(), dst_dd_i, DnnlGemmWrapper::to_dt()); #endif } - (void) dst; - (void) src1_ddq_i; - (void) src1_padded_row_size; + GGML_UNUSED(dst); + GGML_UNUSED(src1_ddq_i); + GGML_UNUSED(src1_padded_row_size); } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ @@ -2588,8 +2723,27 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens item_ct1); }); - (void) src1; - (void) src1_dd; + GGML_UNUSED(src1); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, + const queue_ptr &main_stream) { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ne = ggml_nelements(src0); + + sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); } inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, @@ -2606,9 +2760,10 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_te sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); - (void) src1; - (void) dst; - (void) src1_dd; + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); } inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, @@ -2627,9 +2782,30 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_ten argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream); - (void) src1; - (void) dst; - (void) src1_dd; + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); +} + +inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, + const ggml_tensor *src1, ggml_tensor *dst, + const float *src0_dd, const float *src1_dd, + float *dst_dd, + const queue_ptr &main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + argmax_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, main_stream); + + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); } inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, @@ -2649,9 +2825,10 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const gg diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream); - (void) src1; - (void) dst; - (void) src1_dd; + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); } inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, @@ -2672,9 +2849,10 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, const ggml_tenso */ SYCL_CHECK(0); - (void) src1; - (void) dst; - (void) src1_dd; + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); } inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, @@ -2697,49 +2875,10 @@ inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tenso */ SYCL_CHECK(0); - (void) src1; - (void) dst; - (void) src1_dd; -} - -static void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const ggml_sycl_op_flatten_t op) try { - const int64_t nrows0 = ggml_nrows(src0); - - const bool use_src1 = src1 != nullptr; - const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1; - - GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT); - GGML_ASSERT( dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT); - - ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; - ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr; - ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; - - // dd = data device - float * src0_ddf = (float *) src0->data; - float * src1_ddf = use_src1 ? (float *) src1->data : nullptr; - float * dst_ddf = (float *) dst->data; - - ggml_sycl_pool_alloc src0_f(ctx.pool()); - ggml_sycl_pool_alloc src1_f(ctx.pool()); - ggml_sycl_pool_alloc dst_f(ctx.pool()); - - ggml_sycl_set_device(ctx.device); - queue_ptr main_stream = ctx.stream(); - // GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n", - // ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device); - - // do the computation - op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream); - // print_ggml_tensor("tensor", dst); -} -catch (sycl::exception const &exc) { - - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); } static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) { @@ -2783,10 +2922,6 @@ static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) { peer_access_enabled = enable_peer_access; } -struct ggml_backend_sycl_split_buffer_type_context { - std::array tensor_split; -}; - static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, ggml_sycl_op_mul_mat_t op, @@ -2805,8 +2940,8 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten const int nb2 = dst->nb[2]; const int nb3 = dst->nb[3]; - GGML_ASSERT(dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT); - GGML_ASSERT(src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT); + GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(dst->buffer)); + GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src1->buffer)); GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1)); GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0); @@ -2820,14 +2955,13 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; - ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; const bool src0_is_contiguous = ggml_is_contiguous(src0); const bool src1_is_contiguous = ggml_is_contiguous(src1); int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING); - const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT; + const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer); GGML_ASSERT(!(split && ne02 > 1)); GGML_ASSERT(!(split && ne03 > 1)); GGML_ASSERT(!(split && ne02 < ne12)); @@ -3113,124 +3247,33 @@ catch (sycl::exception const &exc) { } -static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_repeat); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_repeat); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_get_rows); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_get_rows); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_add); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_norm); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_acc); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rms_norm); GGML_SYCL_DEBUG("call %s done\n", __func__); } -static void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_mul); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_div); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_silu); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu_quick); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_tanh); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_relu); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardsigmoid); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardswish); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_leaky_relu); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqr); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_norm); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_group_norm); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_upscale); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -static void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pad); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - - -static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rms_norm); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm); GGML_SYCL_DEBUG("call %s done\n", __func__); } @@ -3238,7 +3281,7 @@ static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const gg const ggml_tensor *src1, ggml_tensor *dst) try { GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); - GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT); + GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer)); GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation GGML_ASSERT(src0->type == GGML_TYPE_F16); @@ -3271,7 +3314,7 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); GGML_ASSERT(!ggml_is_permuted(src0)); - GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT); + GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer)); GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); @@ -3333,12 +3376,11 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, ggml_tensor *dst) try { GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); - GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT); + GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer)); GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_TENSOR_BINARY_OP_LOCALS - const int64_t ne_dst = ggml_nelements(dst); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); queue_ptr main_stream = ctx.stream();; @@ -3400,6 +3442,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, ggml_sycl_pool_alloc ptrs_src(ctx.pool(), 2*ne23); ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23); + ggml_sycl_pool_alloc> matrix_info(ctx.host_pool(), 1); sycl::range<3> block_dims(1, ne12, ne13); /* @@ -3428,14 +3471,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, }); } SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( - *main_stream, oneapi::mkl::transpose::trans, - oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, - (const void **)(ptrs_src.get() + 0 * ne23), - dpct::library_data_t::real_half, nb01 / nb00, - (const void **)(ptrs_src.get() + 1 * ne23), - dpct::library_data_t::real_half, nb11 / nb10, beta, - (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, - cu_compute_type))); + *main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, + (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, + (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta, + (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get()))); } } catch (sycl::exception const &exc) { @@ -3446,6 +3485,7 @@ catch (sycl::exception const &exc) { inline bool ggml_sycl_supports_mmq(enum ggml_type type) { // TODO: accuracy issues in MMQ + GGML_UNUSED(type); return false; } @@ -3496,8 +3536,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE - && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda || src1->ne[1] > MMVQ_MIN_BATCH_SIZE); + && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; @@ -3514,8 +3553,15 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { - // KQ single-batch - ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst); + // TODO: Refactor and cleanup of mul mat dispatching. + if (src0->ne[3] == 1 && src1->ne[3] == 1) { + // KQ single-batch + // mmv p021 was specific for these dimensions + ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst); + } else { + // The kernel from the if path is faster for that specific case, but does not support all mul mats. + ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); + } } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { // KQV single-batch ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst); @@ -3599,9 +3645,10 @@ __dpct_inline__ static void k_copy_dst_from_contiguous( } } -static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, +static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, ggml_tensor *dst) try { + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 = dst->src[1]; GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers"); const ggml_tensor *ids = dst->src[2]; @@ -3767,12 +3814,12 @@ catch (sycl::exception const &exc) { std::exit(1); } -static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_scale); +static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_scale); } -static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_clamp); +static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_clamp); } static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, @@ -3810,12 +3857,11 @@ static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *sr } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) { ggml_cpy_i32_i32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { - fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, + GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); GGML_ABORT("fatal error"); } - - (void) dst; + GGML_UNUSED(dst); } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ @@ -3823,57 +3869,53 @@ catch (sycl::exception const &exc) { std::exit(1); } -static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { // TODO: why do we pass dst as src1 here? - ggml_sycl_cpy(ctx, src0, dst, nullptr); - (void) src1; + ggml_sycl_cpy(ctx, dst->src[0], dst, nullptr); } -static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_diag_mask_inf); +static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf); } -static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_soft_max); +static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope); } -static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rope); +static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_pool2d); } -static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pool2d); +static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_im2col); } -static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_im2col); +static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(dst->src[0])); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sum); } -static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(src0)); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum_rows); +static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(dst->src[0])); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sum_rows); } -static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(src0)); - ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argsort); +static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(dst->src[0])); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_argsort); } -static void ggml_sycl_nop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - (void) src0; - (void) src1; - (void) dst; +static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(dst->src[0])); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_argmax); } -static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]); -} void ggml_sycl_set_main_device(const int main_device) try { - if (dpct::get_current_device_id() == main_device) return; + if (dpct::get_current_device_id() == static_cast (main_device)) { + return; + } check_allow_gpu_index(main_device); dpct::select_device(main_device); @@ -3881,7 +3923,7 @@ void ggml_sycl_set_main_device(const int main_device) try { dpct::device_info prop; SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( prop, dpct::dev_mgr::instance().get_device(main_device)))); - fprintf(stderr, "Using device %d (%s) as main device\n", + GGML_LOG_INFO("Using device %d (%s) as main device\n", main_device, prop.get_name()); } } @@ -3891,187 +3933,198 @@ catch (sycl::exception const &exc) { std::exit(1); } -bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * tensor) { +bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) { if (!g_sycl_loaded) return false; - ggml_sycl_func_t func; + if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) { + ggml_sycl_set_peer_access(dst->src[1]->ne[1], ctx.device); + } - switch (tensor->op) { + switch (dst->op) { + case GGML_OP_ARGMAX: + ggml_sycl_argmax(ctx, dst); + break; case GGML_OP_CONV_TRANSPOSE_1D: - func = ggml_sycl_op_conv_transpose_1d; + ggml_sycl_op_conv_transpose_1d(ctx, dst); break; case GGML_OP_REPEAT: - func = ggml_sycl_repeat; + ggml_sycl_repeat(ctx, dst); break; case GGML_OP_GET_ROWS: - func = ggml_sycl_get_rows; + ggml_sycl_get_rows(ctx, dst); break; case GGML_OP_DUP: - func = ggml_sycl_dup; + ggml_sycl_dup(ctx, dst); break; case GGML_OP_ADD: - func = ggml_sycl_add; + case GGML_OP_ADD1: // TODO: more efficient implementation + ggml_sycl_add(ctx, dst); + break; + case GGML_OP_SUB: + ggml_sycl_sub(ctx, dst); break; case GGML_OP_ACC: - func = ggml_sycl_acc; + ggml_sycl_acc(ctx, dst); break; case GGML_OP_MUL: - func = ggml_sycl_mul; + ggml_sycl_mul(ctx, dst); + break; + case GGML_OP_LOG: + ggml_sycl_log(ctx, dst); break; case GGML_OP_DIV: - func = ggml_sycl_div; + ggml_sycl_div(ctx, dst); break; case GGML_OP_UNARY: - switch (ggml_get_unary_op(tensor)) { + switch (ggml_get_unary_op(dst)) { + case GGML_UNARY_OP_NEG: + ggml_sycl_neg(ctx, dst); + break; + case GGML_UNARY_OP_STEP: + ggml_sycl_step(ctx, dst); + break; case GGML_UNARY_OP_GELU: - func = ggml_sycl_gelu; + ggml_sycl_gelu(ctx, dst); break; case GGML_UNARY_OP_SILU: - func = ggml_sycl_silu; + ggml_sycl_silu(ctx, dst); break; case GGML_UNARY_OP_GELU_QUICK: - func = ggml_sycl_gelu_quick; + ggml_sycl_gelu_quick(ctx, dst); break; case GGML_UNARY_OP_TANH: - func = ggml_sycl_tanh; + ggml_sycl_tanh(ctx, dst); break; case GGML_UNARY_OP_RELU: - func = ggml_sycl_relu; + ggml_sycl_relu(ctx, dst); + break; + case GGML_UNARY_OP_SIGMOID: + ggml_sycl_sigmoid(ctx, dst); break; case GGML_UNARY_OP_HARDSIGMOID: - func = ggml_sycl_hardsigmoid; + ggml_sycl_hardsigmoid(ctx, dst); break; case GGML_UNARY_OP_HARDSWISH: - func = ggml_sycl_hardswish; + ggml_sycl_hardswish(ctx, dst); + break; + case GGML_UNARY_OP_EXP: + ggml_sycl_exp(ctx, dst); break; default: return false; } break; case GGML_OP_NORM: - func = ggml_sycl_norm; + ggml_sycl_norm(ctx, dst); break; case GGML_OP_GROUP_NORM: - func = ggml_sycl_group_norm; + ggml_sycl_group_norm(ctx, dst); break; case GGML_OP_CONCAT: - func = ggml_sycl_op_concat; + ggml_sycl_op_concat(ctx, dst); break; case GGML_OP_UPSCALE: - func = ggml_sycl_upscale; + ggml_sycl_upscale(ctx, dst); break; case GGML_OP_PAD: - func = ggml_sycl_pad; + ggml_sycl_pad(ctx, dst); break; case GGML_OP_LEAKY_RELU: - func = ggml_sycl_leaky_relu; + ggml_sycl_leaky_relu(ctx, dst); break; case GGML_OP_RMS_NORM: - func = ggml_sycl_rms_norm; + ggml_sycl_rms_norm(ctx, dst); break; case GGML_OP_MUL_MAT: - if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) { + if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) { return false; } - func = ggml_sycl_mul_mat; + /* ggml_sycl_mul_mat_id is dependent on ggml_sycl_mul_mat */ + ggml_sycl_mul_mat(ctx, dst->src[0], dst->src[1], dst); break; case GGML_OP_MUL_MAT_ID: - if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) { + if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) { return false; } - func = ggml_sycl_mul_mat_id; + ggml_sycl_mul_mat_id(ctx, dst); + break; + case GGML_OP_OUT_PROD: + ggml_sycl_op_out_prod(ctx, dst); break; case GGML_OP_SCALE: - func = ggml_sycl_scale; + ggml_sycl_scale(ctx, dst); break; case GGML_OP_SQR: - func = ggml_sycl_sqr; + ggml_sycl_sqr(ctx, dst); + break; + case GGML_OP_SQRT: + ggml_sycl_sqrt(ctx, dst); + break; + case GGML_OP_SIN: + ggml_sycl_sin(ctx, dst); + break; + case GGML_OP_COS: + ggml_sycl_cos(ctx, dst); break; case GGML_OP_CLAMP: - func = ggml_sycl_clamp; + ggml_sycl_clamp(ctx, dst); break; case GGML_OP_CPY: - func = ggml_sycl_cpy; + ggml_sycl_cpy(ctx, dst->src[0], dst->src[1], dst); break; case GGML_OP_CONT: - func = ggml_sycl_dup; + ggml_sycl_dup(ctx, dst); break; case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: - func = ggml_sycl_nop; + GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__); break; case GGML_OP_DIAG_MASK_INF: - func = ggml_sycl_diag_mask_inf; + ggml_sycl_diag_mask_inf(ctx, dst); break; case GGML_OP_SOFT_MAX: - func = ggml_sycl_soft_max; + ggml_sycl_op_soft_max(ctx, dst); break; case GGML_OP_ROPE: - func = ggml_sycl_rope; + ggml_sycl_rope(ctx, dst); break; case GGML_OP_IM2COL: - func = ggml_sycl_im2col; + ggml_sycl_im2col(ctx, dst); break; case GGML_OP_POOL_2D: - func = ggml_sycl_pool2d; + ggml_sycl_pool2d(ctx, dst); + break; + case GGML_OP_SUM: + ggml_sycl_sum(ctx, dst); break; case GGML_OP_SUM_ROWS: - func = ggml_sycl_sum_rows; + ggml_sycl_sum_rows(ctx, dst); break; case GGML_OP_ARGSORT: - func = ggml_sycl_argsort; + ggml_sycl_argsort(ctx, dst); break; case GGML_OP_TIMESTEP_EMBEDDING: - func = ggml_sycl_op_timestep_embedding; + ggml_sycl_op_timestep_embedding(ctx, dst); + break; + case GGML_OP_RWKV_WKV6: + ggml_sycl_op_rwkv_wkv6(ctx, dst); + break; + case GGML_OP_GATED_LINEAR_ATTN: + ggml_sycl_op_gated_linear_attn(ctx, dst); break; default: return false; } - if (tensor->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(tensor->src[0]->buffer)) { - ggml_sycl_set_peer_access(tensor->src[1]->ne[1], ctx.device); - } - - func(ctx, tensor->src[0], tensor->src[1], tensor); return true; } -GGML_API GGML_CALL void ggml_sycl_get_gpu_list(int *id_list, int max_len) try { - GGML_SYCL_DEBUG("[SYCL] call ggml_sycl_get_gpu_list\n"); - for(int i=0;i=max_len) break; - id_list[i] = i; - } - return; -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -int ggml_sycl_get_device_count() try { - int device_count; - if (CHECK_TRY_ERROR(device_count = - dpct::dev_mgr::instance().device_count()) != 0) { - return 0; - } - return device_count; -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -GGML_API GGML_CALL void ggml_sycl_get_device_description(int device, char *description, +GGML_API void ggml_backend_sycl_get_device_description(int device, char *description, size_t description_size) try { - GGML_SYCL_DEBUG("[SYCL] call ggml_sycl_get_device_description\n"); + GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_description\n"); dpct::device_info prop; SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( prop, dpct::dev_mgr::instance().get_device(device)))); @@ -4083,7 +4136,7 @@ catch (sycl::exception const &exc) { std::exit(1); } -GGML_CALL void ggml_backend_sycl_get_device_memory(int device, size_t *free, +void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total) try { GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_memory\n"); ggml_sycl_set_device(device); @@ -4109,815 +4162,23 @@ catch (sycl::exception const &exc) { //////////////////////////////////////////////////////////////////////////////// -// backend interface - -#define UNUSED GGML_UNUSED - -// sycl buffer - -struct ggml_backend_sycl_buffer_context { - int device; - void * dev_ptr = nullptr; - queue_ptr stream; - std::string name; - - ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) : - device(device), dev_ptr(dev_ptr), stream(stream) { - check_allow_gpu_index(device); - name = (GGML_SYCL_NAME + std::to_string(device)); - } - - - ~ggml_backend_sycl_buffer_context() { - if (dev_ptr != nullptr) { - ggml_sycl_set_device(device); - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream))); - } - } -}; - -GGML_CALL static const char * ggml_backend_sycl_buffer_get_name(ggml_backend_buffer_t buffer) { - ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context; - return ctx->name.c_str(); -} - -GGML_CALL static bool ggml_backend_buffer_is_sycl(ggml_backend_buffer_t buffer) { - return buffer->iface.get_name == ggml_backend_sycl_buffer_get_name; -} - -static void -ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try { - ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; - ggml_sycl_set_device(ctx->device); - - delete ctx; -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) { - ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; - return ctx->dev_ptr; -} - -GGML_CALL static void -ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, - ggml_tensor *tensor) try { - ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context; - - if (tensor->view_src != NULL && tensor->view_offs == 0) { - assert(tensor->view_src->buffer->buft == buffer->buft); - tensor->backend = tensor->view_src->backend; - tensor->extra = tensor->view_src->extra; - return; - } - - - if (ggml_is_quantized(tensor->type)) { - // initialize padding to 0 to avoid possible NaN values - size_t original_size = ggml_nbytes(tensor); - size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor); - - if (padded_size > original_size && tensor->view_src == nullptr) { - SYCL_CHECK(CHECK_TRY_ERROR(ctx->stream->memset( - (char *)tensor->data + original_size, 0, - padded_size - original_size).wait())); - } - } -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer, - ggml_tensor *tensor, - const void *data, size_t offset, - size_t size) try { - - ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; - - ggml_sycl_set_device(ctx->device); - auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue()); - SYCL_CHECK( - CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw())); - char* host_buf = (char*)malloc(size); - memcpy(host_buf, data, size); - SYCL_CHECK( - CHECK_TRY_ERROR((*stream).memcpy((char *)tensor->data + offset, host_buf, size) - .wait())); - free(host_buf); -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer, - const ggml_tensor *tensor, - void *data, size_t offset, - size_t size) try { - - ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; - - ggml_sycl_set_device(ctx->device); - auto stream = dpct::dev_mgr::instance().get_device(ctx->device).default_queue(); - - SYCL_CHECK(CHECK_TRY_ERROR( - stream.memcpy(data, (const char *)tensor->data + offset, size) - .wait())); -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -GGML_CALL static bool -ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer, - const ggml_tensor *src, - ggml_tensor *dst) try { - if (ggml_backend_buffer_is_sycl(src->buffer)) { - ggml_backend_sycl_buffer_context * src_ctx = (ggml_backend_sycl_buffer_context *)src->buffer->context; - ggml_backend_sycl_buffer_context * dst_ctx = (ggml_backend_sycl_buffer_context *)dst->buffer->context; - - ggml_sycl_set_device(src_ctx->device); - /* - DPCT1009:198: SYCL uses exceptions to report errors and does not use the - error codes. The original code was commented out and a warning string - was inserted. You need to rewrite this code. - */ - SYCL_CHECK(CHECK_TRY_ERROR( - dpct::dev_mgr::instance().get_device(src_ctx->device).queues_wait_and_throw())); - ggml_sycl_set_device(dst_ctx->device); - /* - DPCT1009:199: SYCL uses exceptions to report errors and does not use the - error codes. The original code was commented out and a warning string - was inserted. You need to rewrite this code. - */ - SYCL_CHECK(CHECK_TRY_ERROR( - dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw())); - /* - DPCT1009:200: SYCL uses exceptions to report errors and does not use the - error codes. The original code was commented out and a warning string - was inserted. You need to rewrite this code. - */ - - queue_ptr stream_dst = dst_ctx->stream; - queue_ptr stream_src = src_ctx->stream; - size_t size = ggml_nbytes(src); - - //todo. it's dirty solutino to walkaroud known issue:device2device cross GPUs. - dev2dev_memcpy(*stream_dst, *stream_src, dst->data, src->data, size); - -//todo, it's known issue:error in device2device cross GPUs. reused when the issue is fixed. DON"T remove -#if 0 - SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy( - (char *)dst->data, (const char *)src->data, size).wait())); - - /* - DPCT1009:201: SYCL uses exceptions to report errors and does not use the - error codes. The original code was commented out and a warning string - was inserted. You need to rewrite this code. - */ - SYCL_CHECK(CHECK_TRY_ERROR( - dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw())); -#endif - return true; - } - return false; -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - - -static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer, - uint8_t value) try { - ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; - - ggml_sycl_set_device(ctx->device); - queue_ptr stream = ctx->stream; - SYCL_CHECK( - CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw())); - - SYCL_CHECK(CHECK_TRY_ERROR((*stream) - .memset(ctx->dev_ptr, value, buffer->size) - .wait())); -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -static struct ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = { - /* .get_name = */ ggml_backend_sycl_buffer_get_name, - /* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer, - /* .get_base = */ ggml_backend_sycl_buffer_get_base, - /* .init_tensor = */ ggml_backend_sycl_buffer_init_tensor, - /* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor, - /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor, - /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor, - /* .clear = */ ggml_backend_sycl_buffer_clear, - /* .reset = */ NULL, -}; - -// sycl buffer type -struct ggml_backend_sycl_buffer_type_context { - int device; - std::string name; - - // each buffer type has its own stream - queue_ptr stream = nullptr; -}; - -GGML_CALL static const char * ggml_backend_sycl_buffer_type_name(ggml_backend_buffer_type_t buft) { - ggml_backend_sycl_buffer_type_context * ctx = (ggml_backend_sycl_buffer_type_context *)buft->context; - - return ctx->name.c_str(); -} -GGML_CALL static ggml_backend_buffer_t -ggml_backend_sycl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, - size_t size) try { - ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context; - ggml_sycl_set_device(buft_ctx->device); - const queue_ptr stream = buft_ctx->stream; - size = std::max(size, (size_t)1); // syclMalloc returns null for size 0 - - void * dev_ptr; - SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)sycl::malloc_device( - size, *stream))); - if (!dev_ptr) { - fprintf(stderr, "%s: can't malloc %lu Bytes memory on device", __func__, size); - return nullptr; - } - ggml_backend_sycl_buffer_context * ctx = new ggml_backend_sycl_buffer_context(buft_ctx->device, dev_ptr, buft_ctx->stream); - return ggml_backend_buffer_init(buft, ggml_backend_sycl_buffer_interface, ctx, size); -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -GGML_CALL static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return 128; - UNUSED(buft); -} - -static size_t ggml_backend_sycl_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { - return dpct::get_current_device().get_max_mem_alloc_size(); - - UNUSED(buft); -} - -GGML_CALL static size_t ggml_backend_sycl_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { - size_t size = ggml_nbytes(tensor); - int64_t ne0 = tensor->ne[0]; - - if (ggml_is_quantized(tensor->type)) { - if (ne0 % MATRIX_ROW_PADDING != 0) { - size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); - } - } - - return size; - - UNUSED(buft); -} - -static ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = { - /* .get_name = */ ggml_backend_sycl_buffer_type_name, - /* .alloc_buffer = */ ggml_backend_sycl_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_sycl_buffer_type_get_alignment, - /* .get_max_size = */ ggml_backend_sycl_buffer_type_get_max_size, - /* .get_alloc_size = */ ggml_backend_sycl_buffer_type_get_alloc_size, - /* .is_host = */ nullptr, -}; - -ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) { - static std::mutex mutex; - std::lock_guard lock(mutex); - - GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n"); - - if (device>=ggml_sycl_info().device_count or device<0) { - printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n", - device, ggml_sycl_info().device_count-1); - GGML_ASSERT(devicedevice; - if (device>=ggml_sycl_info().device_count or device<0) { - printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n", - device, ggml_sycl_info().device_count-1); - GGML_ASSERT(devicestream(i, 0)}, - }; - } - ggml_backend_sycl_buffer_type_initialized = true; - } - return &ggml_backend_sycl_buffer_types[device]; -} - -// sycl split buffer type -static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array & tensor_split, int id) { - const int64_t nrows = ggml_nrows(tensor); - const int64_t rounding = get_row_rounding(tensor->type, tensor_split); - - *row_low = id == 0 ? 0 : nrows*tensor_split[id]; - *row_low -= *row_low % rounding; - if (id == ggml_sycl_info().device_count - 1) { - *row_high = nrows; - } else { - *row_high = nrows*tensor_split[id + 1]; - *row_high -= *row_high % rounding; - } -} - -struct ggml_backend_sycl_split_buffer_context { - ~ggml_backend_sycl_split_buffer_context() try { - for (ggml_tensor_extra_gpu * extra : tensor_extras) { - for (int i = 0; i < ggml_sycl_info().device_count; ++i) { - for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) { - if (extra->events[i][is] != nullptr) { - /* - DPCT1009:206: SYCL uses exceptions to report errors and - does not use the error codes. The original code was - commented out and a warning string was inserted. You - need to rewrite this code. - */ - SYCL_CHECK(CHECK_TRY_ERROR( - dpct::destroy_event(extra->events[i][is]))); - } - } - if (extra->data_device[i] != nullptr) { - /* - DPCT1009:207: SYCL uses exceptions to report errors and does - not use the error codes. The original code was commented out - and a warning string was inserted. You need to rewrite this - code. - */ - ggml_sycl_set_device(i); - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free( - extra->data_device[i], *(streams[i])))); - } - } - delete extra; - } - } - catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); - } - - std::vector tensor_extras; - std::vector streams; -}; - -GGML_CALL static const char * ggml_backend_sycl_split_buffer_get_name(ggml_backend_buffer_t buffer) { - return GGML_SYCL_NAME "_Split"; - - UNUSED(buffer); -} - -static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) { - return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name; -} - -GGML_CALL static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) { - ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; - delete ctx; -} - -GGML_CALL static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) { - // the pointers are stored in the tensor extras, this is just a dummy address and never dereferenced - return (void *)0x1000; - - UNUSED(buffer); -} - -GGML_CALL static void -ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, - ggml_tensor *tensor) try { - GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported - - ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; - ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context; - - const int64_t ne0 = tensor->ne[0]; - - ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; - - ctx->tensor_extras.push_back(extra); - ctx->streams.push_back(&(dpct::get_current_device().default_queue())); - - for (int i = 0; i < ggml_sycl_info().device_count; ++i) { - int64_t row_low, row_high; - get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i); - - int64_t nrows_split = row_high - row_low; - if (nrows_split == 0) { - continue; - } - - size_t size = ggml_nbytes_split(tensor, nrows_split); - const size_t original_size = size; - - // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses - if (ne0 % MATRIX_ROW_PADDING != 0) { - size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); - } - - // FIXME: do not crash if cudaMalloc fails - // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first - ggml_sycl_set_device(i); - const queue_ptr stream = ctx->streams[i]; - char * buf; - /* - DPCT1009:208: SYCL uses exceptions to report errors and does not use the - error codes. The original code was commented out and a warning string - was inserted. You need to rewrite this code. - */ - SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)sycl::malloc_device( - size, *stream))); - if (!buf) { - char err_buf[1024]; - snprintf(err_buf, 1023, "%s: can't malloc %lu Bytes memory on device", __func__, size); - throw std::runtime_error(err_buf); - } - // set padding to 0 to avoid possible NaN values - if (size > original_size) { - /* - DPCT1009:209: SYCL uses exceptions to report errors and does not use - the error codes. The original code was commented out and a warning - string was inserted. You need to rewrite this code. - */ - SYCL_CHECK(CHECK_TRY_ERROR( - (*stream) - .memset(buf + original_size, 0, size - original_size) - .wait())); - } - - extra->data_device[i] = buf; - - for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) { - /* - DPCT1009:210: SYCL uses exceptions to report errors and does not use - the error codes. The original code was commented out and a warning - string was inserted. You need to rewrite this code. - */ - SYCL_CHECK( - CHECK_TRY_ERROR(extra->events[i][is] = new sycl::event())); - } - } - tensor->backend = GGML_BACKEND_TYPE_GPU_SPLIT; - tensor->extra = extra; -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -GGML_CALL static void -ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer, - ggml_tensor *tensor, const void *data, - size_t offset, size_t size) try { - // split tensors must always be set in their entirety at once - GGML_ASSERT(offset == 0); - GGML_ASSERT(size == ggml_nbytes(tensor)); - - ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; - ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context; - - const int64_t ne0 = tensor->ne[0]; - const size_t nb1 = tensor->nb[1]; - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra; - - for (int i = 0; i < ggml_sycl_info().device_count; ++i) { - int64_t row_low, row_high; - get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i); - - int64_t nrows_split = row_high - row_low; - if (nrows_split == 0) { - continue; - } - - const size_t offset_split = row_low*nb1; - size_t size = ggml_nbytes_split(tensor, nrows_split); - const size_t original_size = size; - - // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses - if (ne0 % MATRIX_ROW_PADDING != 0) { - size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); - } - - const char * buf_host = (const char *)data + offset_split; - /* - DPCT1009:211: SYCL uses exceptions to report errors and does not use the - error codes. The original code was commented out and a warning string - was inserted. You need to rewrite this code. - */ - ggml_sycl_set_device(i); - const queue_ptr stream = ctx->streams[i]; - SYCL_CHECK(CHECK_TRY_ERROR( - (*stream) - .memcpy(extra->data_device[i], buf_host, original_size) - .wait())); - } -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -GGML_CALL static void -ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer, - const ggml_tensor *tensor, void *data, - size_t offset, size_t size) try { - // split tensors must always be set in their entirety at once - GGML_ASSERT(offset == 0); - GGML_ASSERT(size == ggml_nbytes(tensor)); - - ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; - ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context; - - const int64_t ne0 = tensor->ne[0]; - const size_t nb1 = tensor->nb[1]; - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra; - - for (int i = 0; i < ggml_sycl_info().device_count; ++i) { - int64_t row_low, row_high; - get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i); - - int64_t nrows_split = row_high - row_low; - if (nrows_split == 0) { - continue; - } - - const size_t offset_split = row_low*nb1; - size_t size = ggml_nbytes_split(tensor, nrows_split); - const size_t original_size = size; - - // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses - if (ne0 % MATRIX_ROW_PADDING != 0) { - size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); - } - - char * buf_host = (char *)data + offset_split; - /* - DPCT1009:212: SYCL uses exceptions to report errors and does not use the - error codes. The original code was commented out and a warning string - was inserted. You need to rewrite this code. - */ - ggml_sycl_set_device(i); - const queue_ptr stream = ctx->streams[i]; - SYCL_CHECK(CHECK_TRY_ERROR( - (*stream) - .memcpy(buf_host, extra->data_device[i], original_size) - .wait())); - } -} -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} - -GGML_CALL static void ggml_backend_sycl_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { - UNUSED(buffer); - UNUSED(value); -} - -static struct ggml_backend_buffer_i ggml_backend_sycl_split_buffer_interface = { - /* .get_name = */ ggml_backend_sycl_split_buffer_get_name, - /* .free_buffer = */ ggml_backend_sycl_split_buffer_free_buffer, - /* .get_base = */ ggml_backend_sycl_split_buffer_get_base, - /* .init_tensor = */ ggml_backend_sycl_split_buffer_init_tensor, - /* .set_tensor = */ ggml_backend_sycl_split_buffer_set_tensor, - /* .get_tensor = */ ggml_backend_sycl_split_buffer_get_tensor, - /* .cpy_tensor = */ NULL, - /* .clear = */ ggml_backend_sycl_split_buffer_clear, - /* .reset = */ NULL, -}; - -GGML_CALL static const char * ggml_backend_sycl_split_buffer_type_name(ggml_backend_buffer_type_t buft) { - return GGML_SYCL_NAME "_Split"; - - UNUSED(buft); -} - -GGML_CALL static ggml_backend_buffer_t ggml_backend_sycl_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point - // instead, we allocate them for each tensor separately in init_tensor - // however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated, - // as returned by get_alloc_size. this limit is enforced during tensor allocation by ggml-alloc, so it must be correct. - ggml_backend_sycl_split_buffer_context * ctx = new ggml_backend_sycl_split_buffer_context(); - - return ggml_backend_buffer_init(buft, ggml_backend_sycl_split_buffer_interface, ctx, size); -} - -GGML_CALL static size_t ggml_backend_sycl_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return 128; - UNUSED(buft); -} - -GGML_CALL static size_t ggml_backend_sycl_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { - ggml_backend_sycl_split_buffer_type_context * ctx = (ggml_backend_sycl_split_buffer_type_context *)buft->context; - - size_t total_size = 0; - - const int64_t ne0 = tensor->ne[0]; - - for (int i = 0; i < ggml_sycl_info().device_count; ++i) { - int64_t row_low, row_high; - get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, i); - - int64_t nrows_split = row_high - row_low; - if (nrows_split == 0) { - continue; - } - - total_size += ggml_nbytes_split(tensor, nrows_split); - - // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses - if (ne0 % MATRIX_ROW_PADDING != 0) { - total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); - } - } - - return total_size; -} - -GGML_CALL static bool ggml_backend_sycl_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) { - return false; - - UNUSED(buft); -} - -static ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface = { - /* .get_name = */ ggml_backend_sycl_split_buffer_type_name, - /* .alloc_buffer = */ ggml_backend_sycl_split_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_sycl_split_buffer_type_get_alignment, - /* .get_max_size = */ NULL, // defaults to SIZE_MAX - /* .get_alloc_size = */ ggml_backend_sycl_split_buffer_type_get_alloc_size, - /* .is_host = */ ggml_backend_sycl_split_buffer_type_is_host, -}; - -GGML_CALL ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) { - static std::mutex mutex; - std::lock_guard lock(mutex); - - GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_split_buffer_type\n"); - ggml_check_sycl(); - // FIXME: this is not thread safe - static std::map, struct ggml_backend_buffer_type> buft_map; - - std::array tensor_split_arr = {}; - - bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_SYCL_MAX_DEVICES, [](float x) { return x == 0.0f; }); - if (all_zero) { - tensor_split_arr = ggml_sycl_info().default_tensor_split; - } else { - float split_sum = 0.0f; - for (int i = 0; i < ggml_sycl_info().device_count; ++i) { - tensor_split_arr[i] = split_sum; - split_sum += tensor_split[i]; - } - for (int i = 0; i < ggml_sycl_info().device_count; ++i) { - tensor_split_arr[i] /= split_sum; - } - } - - auto it = buft_map.find(tensor_split_arr); - if (it != buft_map.end()) { - return &it->second; - } - - struct ggml_backend_buffer_type buft { - /* .iface = */ ggml_backend_sycl_split_buffer_type_interface, - /* .context = */ new ggml_backend_sycl_split_buffer_type_context{tensor_split_arr}, - }; - - auto result = buft_map.emplace(tensor_split_arr, buft); - return &result.first->second; -} - -// host buffer type - -GGML_CALL static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_type_t buft) { - return GGML_SYCL_NAME "_Host"; - - UNUSED(buft); -} - -GGML_CALL static const char * ggml_backend_sycl_host_buffer_name(ggml_backend_buffer_t buffer) { - return GGML_SYCL_NAME "_Host"; - - UNUSED(buffer); -} - -static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { - ggml_sycl_host_free(buffer->context); -} - -static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - void * ptr = ggml_sycl_host_malloc(size); - - if (ptr == nullptr) { - // fallback to cpu buffer - return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); - } - - // FIXME: this is a hack to avoid having to implement a new buffer type - ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); - buffer->buft = buft; - buffer->iface.get_name = ggml_backend_sycl_host_buffer_name; - buffer->iface.free_buffer = ggml_backend_sycl_host_buffer_free_buffer; - - return buffer; -} - -ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type() { - GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_host_buffer_type\n"); - static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_type_host = { - /* .iface = */ { - /* .get_name = */ ggml_backend_sycl_host_buffer_type_name, - /* .alloc_buffer = */ ggml_backend_sycl_host_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment, - /* .get_max_size = */ NULL, // TODO: return device.maxBufferLength - /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, - /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, - }, - /* .context = */ nullptr, - }; - - return &ggml_backend_sycl_buffer_type_host; -} - // backend -GGML_CALL static const char * ggml_backend_sycl_name(ggml_backend_t backend) { +static const char * ggml_backend_sycl_get_name(ggml_backend_t backend) { ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; return sycl_ctx->name.c_str(); } -GGML_CALL static void ggml_backend_sycl_free(ggml_backend_t backend) { +static void ggml_backend_sycl_free(ggml_backend_t backend) { ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; delete sycl_ctx; delete backend; } - -GGML_CALL static ggml_backend_buffer_type_t ggml_backend_sycl_get_default_buffer_type(ggml_backend_t backend) { - ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; - return ggml_backend_sycl_buffer_type(sycl_ctx->device); -} - -GGML_CALL static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend, +static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend, ggml_tensor *tensor, const void *data, size_t offset, size_t size) try { @@ -4926,8 +4187,8 @@ GGML_CALL static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend, GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type"); const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0); - SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy( - (char *)tensor->data + offset, data, size).wait())); + SYCL_CHECK(CHECK_TRY_ERROR( + (stream)->memcpy((char *)tensor->data + offset, data, size))); } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ @@ -4935,7 +4196,7 @@ catch (sycl::exception const &exc) { std::exit(1); } -GGML_CALL static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend, +static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend, const ggml_tensor *tensor, void *data, size_t offset, size_t size) try { @@ -4953,9 +4214,9 @@ catch (sycl::exception const &exc) { std::exit(1); } -GGML_CALL static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend, - const ggml_tensor *src, - ggml_tensor *dst) try { +static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend, + const ggml_tensor *src, + ggml_tensor *dst) try { ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; if (dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && ggml_backend_buffer_is_sycl(src->buffer)) { /* @@ -4982,7 +4243,7 @@ static void ggml_backend_sycl_synchronize(ggml_backend_t backend) try { const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0); SYCL_CHECK(CHECK_TRY_ERROR((stream)->wait())); - UNUSED(backend); + GGML_UNUSED(backend); } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ @@ -4990,7 +4251,7 @@ catch (sycl::exception const &exc) { std::exit(1); } -GGML_CALL static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { +static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; ggml_sycl_set_main_device(sycl_ctx->device); @@ -5010,7 +4271,7 @@ GGML_CALL static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t back #endif bool ok = ggml_sycl_compute_forward(*sycl_ctx, node); if (!ok) { - fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); } GGML_ASSERT(ok); } @@ -5018,7 +4279,148 @@ GGML_CALL static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t back return GGML_STATUS_SUCCESS; } -GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, const ggml_tensor * op) { +static void ggml_backend_sycl_event_record(ggml_backend_t backend, ggml_backend_event_t event) +try +{ + ggml_backend_sycl_context *sycl_ctx = + (ggml_backend_sycl_context *)backend->context; + + sycl::event *sycl_event = static_cast(event->context); + + const queue_ptr &stream = sycl_ctx->stream(sycl_ctx->device, 0); + // Record the current state of the queue + SYCL_CHECK(CHECK_TRY_ERROR(*sycl_event = stream->ext_oneapi_submit_barrier())); +} +catch (sycl::exception const &exc) +{ + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static void ggml_backend_sycl_event_wait(ggml_backend_t backend, ggml_backend_event_t event) try { + + sycl::event* sycl_event = static_cast(event->context); + + if (ggml_backend_is_sycl(backend)) { + SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait())); + } else + GGML_ABORT("fatal error"); +} catch (sycl::exception const& exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static ggml_backend_i ggml_backend_sycl_interface = { + /* .get_name = */ ggml_backend_sycl_get_name, + /* .free = */ ggml_backend_sycl_free, + /* .set_tensor_async = */ ggml_backend_sycl_set_tensor_async, + /* .get_tensor_async = */ ggml_backend_sycl_get_tensor_async, + /* .cpy_tensor_async = */ NULL, // ggml_backend_sycl_cpy_tensor_async, + // // TODO: update for the new + // interface + /* .synchronize = */ ggml_backend_sycl_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_sycl_graph_compute, + /* .event_record = */ ggml_backend_sycl_event_record, + /* .event_wait = */ ggml_backend_sycl_event_wait, +}; + +static ggml_guid_t ggml_backend_sycl_guid() { + static ggml_guid guid = { 0x58, 0x05, 0x13, 0x8f, 0xcd, 0x3a, 0x61, 0x9d, 0xe7, 0xcd, 0x98, 0xa9, 0x03, 0xfd, 0x7c, 0x53 }; + return &guid; +} + +bool ggml_backend_is_sycl(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_sycl_guid()); +} + +int ggml_backend_sycl_get_device_count() { + GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count\n"); + return ggml_sycl_info().device_count; +} + + +// backend device + +struct ggml_backend_sycl_device_context { + int device; + std::string name; + std::string description; +}; + +static const char * ggml_backend_sycl_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context; + return ctx->name.c_str(); +} + +static const char * ggml_backend_sycl_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context; + return ctx->description.c_str(); +} + +static void ggml_backend_sycl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context; + ggml_sycl_set_device(ctx->device); + SYCL_CHECK(CHECK_TRY_ERROR( + dpct::dev_mgr::instance().get_device(ctx->device).get_memory_info(*free, *total))); +} + +static enum ggml_backend_dev_type ggml_backend_sycl_device_get_type(ggml_backend_dev_t dev) { + GGML_UNUSED(dev); + return GGML_BACKEND_DEVICE_TYPE_GPU; +} + +static void ggml_backend_sycl_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + props->name = ggml_backend_sycl_device_get_name(dev); + props->description = ggml_backend_sycl_device_get_description(dev); + props->type = ggml_backend_sycl_device_get_type(dev); + ggml_backend_sycl_device_get_memory(dev, &props->memory_free, &props->memory_total); + + bool host_buffer = getenv("GGML_SYCL_NO_PINNED") == nullptr; +#ifdef GGML_SYCL_NO_PEER_COPY + bool events = false; +#else + bool events = true; +#endif + + props->caps = { + /* .async = */ true, + /* .host_buffer = */ host_buffer, + /* .buffer_from_host_ptr = */ false, + /* .events = */ events, + }; +} + +static ggml_backend_t ggml_backend_sycl_device_init(ggml_backend_dev_t dev, const char * params) { + GGML_UNUSED(params); + ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context; + return ggml_backend_sycl_init(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_sycl_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context; + return ggml_backend_sycl_buffer_type(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_sycl_device_get_host_buffer_type(ggml_backend_dev_t dev) { + GGML_UNUSED(dev); + return ggml_backend_sycl_host_buffer_type(); +} + +static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + GGML_UNUSED(dev); + GGML_UNUSED(ptr); + GGML_UNUSED(size); + GGML_UNUSED(max_tensor_size); + return nullptr; +} + +static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { switch (op->op) { case GGML_OP_CONV_TRANSPOSE_1D: { @@ -5031,13 +4433,17 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons } break; case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_STEP: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_SIGMOID: case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_EXP: return ggml_is_contiguous(op->src[0]); default: return false; @@ -5074,6 +4480,8 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons } return true; } break; + case GGML_OP_OUT_PROD: + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1; case GGML_OP_GET_ROWS: { switch (op->src[0]->type) { @@ -5119,10 +4527,10 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons case GGML_OP_CONCAT: { ggml_type src0_type = op->src[0]->type; - int dim = op->op_params[0]; - return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2; + return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; } break; case GGML_OP_DUP: + case GGML_OP_ARGMAX: case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_REPEAT: @@ -5131,20 +4539,40 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons case GGML_OP_TRANSPOSE: case GGML_OP_NORM: case GGML_OP_ADD: + case GGML_OP_ADD1: + case GGML_OP_LOG: + case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_RMS_NORM: case GGML_OP_SCALE: case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: case GGML_OP_CLAMP: + return true; case GGML_OP_CONT: + return op->src[0]->type != GGML_TYPE_BF16; case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: return true; case GGML_OP_ROPE: - return ggml_is_contiguous(op->src[0]); + { + const int mode = ((const int32_t *) op->op_params)[2]; + if (mode & GGML_ROPE_TYPE_MROPE) { + return false; + } + if (mode & GGML_ROPE_TYPE_VISION) { + return false; + } + return ggml_is_contiguous(op->src[0]); + } case GGML_OP_IM2COL: + // TODO: add support for the new F32 operations + return op->src[0]->type == GGML_TYPE_F16; case GGML_OP_POOL_2D: + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: case GGML_OP_ACC: @@ -5153,58 +4581,200 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons case GGML_OP_PAD: case GGML_OP_LEAKY_RELU: case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_RWKV_WKV6: + case GGML_OP_GATED_LINEAR_ATTN: return true; default: return false; } - UNUSED(backend); + GGML_UNUSED(dev); } -GGML_CALL static bool ggml_backend_sycl_offload_op(ggml_backend_t backend, const ggml_tensor * op) { - const int min_batch_size = 32; - return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID; - GGML_UNUSED(backend); -} - -GGML_CALL static bool ggml_backend_sycl_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { - if (buft->iface.get_name != ggml_backend_sycl_buffer_type_name) { +static bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + if (buft->iface.get_name != ggml_backend_sycl_buffer_type_get_name) { return false; } ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context; - ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; + ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context; return buft_ctx->device == sycl_ctx->device; } -static ggml_backend_i ggml_backend_sycl_interface = { - /* .get_name = */ ggml_backend_sycl_name, - /* .free = */ ggml_backend_sycl_free, - /* .get_default_buffer_type = */ ggml_backend_sycl_get_default_buffer_type, - /* .set_tensor_async = */ ggml_backend_sycl_set_tensor_async, - /* .get_tensor_async = */ ggml_backend_sycl_get_tensor_async, - /* .cpy_tensor_async = */ NULL, //ggml_backend_sycl_cpy_tensor_async, // TODO: update for the new interface - /* .synchronize = */ ggml_backend_sycl_synchronize, - /* .graph_plan_create = */ NULL, - /* .graph_plan_free = */ NULL, - /* .graph_plan_update = */ NULL, - /* .graph_plan_compute = */ NULL, - /* .graph_compute = */ ggml_backend_sycl_graph_compute, - /* .supports_op = */ ggml_backend_sycl_supports_op, - /* .supports_buft = */ ggml_backend_sycl_supports_buft, - /* .offload_op = */ ggml_backend_sycl_offload_op, - /* .event_new = */ NULL, - /* .event_free = */ NULL, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, - /* .event_synchronize = */ NULL, -}; - -static ggml_guid_t ggml_backend_sycl_guid() { - static ggml_guid guid = { 0x58, 0x05, 0x13, 0x8f, 0xcd, 0x3a, 0x61, 0x9d, 0xe7, 0xcd, 0x98, 0xa9, 0x03, 0xfd, 0x7c, 0x53 }; - return &guid; +static int64_t get_op_batch_size(const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_GET_ROWS: + return 0; + case GGML_OP_MUL_MAT: + return op->ne[1]; + case GGML_OP_MUL_MAT_ID: + case GGML_OP_ROPE: + return op->ne[2]; + default: + return ggml_nrows(op); + } } -GGML_CALL ggml_backend_t ggml_backend_sycl_init(int device) { +static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + const int min_batch_size = 32; + return get_op_batch_size(op) >= min_batch_size; + GGML_UNUSED(dev); +} + +static ggml_backend_event_t +ggml_backend_sycl_device_event_new(ggml_backend_dev_t dev) { + +#ifdef GGML_SYCL_NO_PEER_COPY + return nullptr; +#else + sycl::event *event_ptr = new sycl::event(); + + return new ggml_backend_event{ + /* .device = */ dev, + /* .context = */ event_ptr, + }; +#endif +} + +static void ggml_backend_sycl_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) try { + GGML_UNUSED(dev); + if (event == nullptr) { + return; + } + + if (event->context != nullptr) { + sycl::event *sycl_event = static_cast(event->context); + delete sycl_event; + event->context = nullptr; + } + + delete event; +} catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + + +static void ggml_backend_sycl_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) try { + GGML_UNUSED(dev); + + sycl::event *sycl_event = static_cast(event->context); + SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait())); +} catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} + +static const ggml_backend_device_i ggml_backend_sycl_device_interface = { + /* .get_name = */ ggml_backend_sycl_device_get_name, + /* .get_description = */ ggml_backend_sycl_device_get_description, + /* .get_memory = */ ggml_backend_sycl_device_get_memory, + /* .get_type = */ ggml_backend_sycl_device_get_type, + /* .get_props = */ ggml_backend_sycl_device_get_props, + /* .init_backend = */ ggml_backend_sycl_device_init, + /* .get_buffer_type = */ ggml_backend_sycl_device_get_buffer_type, + /* .get_host_buffer_type = */ ggml_backend_sycl_device_get_host_buffer_type, + /* .buffer_from_host_ptr = */ ggml_backend_sycl_device_buffer_from_host_ptr, + /* .supports_op = */ ggml_backend_sycl_device_supports_op, + /* .supports_buft = */ ggml_backend_sycl_device_supports_buft, + /* .offload_op = */ ggml_backend_sycl_device_offload_op, + /* .event_new = */ ggml_backend_sycl_device_event_new, + /* .event_free = */ ggml_backend_sycl_device_event_free, + /* .event_synchronize = */ ggml_backend_sycl_device_event_synchronize, +}; + +// backend reg + +struct ggml_backend_sycl_reg_context { + std::vector devices; +}; + +static const char * ggml_backend_sycl_reg_get_name(ggml_backend_reg_t reg) { + GGML_UNUSED(reg); + return GGML_SYCL_NAME; +} + +static size_t ggml_backend_sycl_reg_get_device_count(ggml_backend_reg_t reg) { + ggml_backend_sycl_reg_context * ctx = (ggml_backend_sycl_reg_context *)reg->context; + return ctx->devices.size(); +} + +static ggml_backend_dev_t ggml_backend_sycl_reg_get_device(ggml_backend_reg_t reg, size_t index) { + ggml_backend_sycl_reg_context * ctx = (ggml_backend_sycl_reg_context *)reg->context; + GGML_ASSERT(index < ctx->devices.size()); + return ctx->devices[index]; +} + +static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, const char *name) { + GGML_UNUSED(reg); + + if (strcmp(name, "ggml_backend_split_buffer_type") == 0) { + return (void *)ggml_backend_sycl_split_buffer_type; + } + + // SYCL doesn't support registering host memory, left here for reference + // "ggml_backend_register_host_buffer" + // "ggml_backend_unregister_host_buffer" + GGML_UNUSED(name); + return nullptr; +} + +static const ggml_backend_reg_i ggml_backend_sycl_reg_interface = { + /* .get_name = */ ggml_backend_sycl_reg_get_name, + /* .get_device_count = */ ggml_backend_sycl_reg_get_device_count, + /* .get_device = */ ggml_backend_sycl_reg_get_device, + /* .get_proc_address = */ ggml_backend_sycl_reg_get_proc_address, +}; + + +// backend registry + +ggml_backend_reg_t ggml_backend_sycl_reg() { + static ggml_backend_reg reg; + static bool initialized = false; + + { + static std::mutex mutex; + std::lock_guard lock(mutex); + if (!initialized) { + ggml_backend_sycl_reg_context * ctx = new ggml_backend_sycl_reg_context; + + for (int i = 0; i < ggml_sycl_info().device_count; i++) { + ggml_backend_sycl_device_context * dev_ctx = new ggml_backend_sycl_device_context; + dev_ctx->device = i; + dev_ctx->name = GGML_SYCL_NAME + std::to_string(i); + + ggml_sycl_set_device(i); + + dpct::device_info prop; + SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( + prop, dpct::dev_mgr::instance().get_device(i)))); + + dev_ctx->description = prop.get_name(); + + ggml_backend_dev_t dev = new ggml_backend_device { + /* .iface = */ ggml_backend_sycl_device_interface, + /* .reg = */ ®, + /* .context = */ dev_ctx + }; + ctx->devices.push_back(dev); + } + + reg = ggml_backend_reg { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_sycl_reg_interface, + /* .context = */ ctx + }; + } + + initialized = true; + } + + return ® +} + +ggml_backend_t ggml_backend_sycl_init(int device) { GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_init\n"); ggml_check_sycl(); @@ -5212,43 +4782,18 @@ GGML_CALL ggml_backend_t ggml_backend_sycl_init(int device) { ggml_backend_sycl_context * ctx = new ggml_backend_sycl_context(device); if (ctx == nullptr) { - fprintf(stderr, "%s: error: failed to allocate context\n", __func__); + GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); return nullptr; }; ggml_backend_t sycl_backend = new ggml_backend { /* .guid = */ ggml_backend_sycl_guid(), /* .interface = */ ggml_backend_sycl_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device), /* .context = */ ctx }; return sycl_backend; } -bool ggml_backend_is_sycl(ggml_backend_t backend) { - return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_sycl_guid()); -} - -GGML_CALL int ggml_backend_sycl_get_device_count() { - GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count\n"); - return ggml_sycl_info().device_count; -} - -GGML_CALL static ggml_backend_t ggml_backend_reg_sycl_init(const char * params, void * user_data) { - ggml_backend_t sycl_backend = ggml_backend_sycl_init((int) (intptr_t) user_data); - return sycl_backend; - - UNUSED(params); -} - -extern "C" int ggml_backend_sycl_reg_devices(); - -int ggml_backend_sycl_reg_devices() { - assert(ggml_sycl_info().device_count>0); - for (int i = 0; i < ggml_sycl_info().device_count; i++) { - char name[128]; - snprintf(name, sizeof(name), "%s%d", GGML_SYCL_NAME, i); - ggml_backend_register(name, ggml_backend_reg_sycl_init, ggml_backend_sycl_buffer_type(i), (void *) (intptr_t) i); - } - return ggml_sycl_info().device_count; -} +GGML_BACKEND_DL_IMPL(ggml_backend_sycl_reg) diff --git a/ggml/src/ggml-sycl/gla.cpp b/ggml/src/ggml-sycl/gla.cpp new file mode 100644 index 000000000..eedb47486 --- /dev/null +++ b/ggml/src/ggml-sycl/gla.cpp @@ -0,0 +1,105 @@ +#include + +#include "common.hpp" + +template +static void gated_linear_attn_f32_kernel(const dpct::queue_ptr stream, u_int B, u_int T, u_int C, u_int H, float scale, + const float * k, const float * v, const float * r, const float * td, + const float * s, float * dst) { + const u_int head_size = HEAD_SIZE; + const u_int state_size = C * head_size; + const u_int n_seq_tokens = T / B; + sycl::range<1> block_dims((C / H)); + sycl::range<1> grid_dims((B * H)); + stream->submit([&](sycl::handler & cgh) { + /* local memory accessors*/ + auto _k = sycl::local_accessor(sycl::range<1>(head_size), cgh); + auto _r = sycl::local_accessor(sycl::range<1>(head_size), cgh); + auto _td = sycl::local_accessor(sycl::range<1>(head_size), cgh); + + cgh.parallel_for(sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) { + u_int tid = item.get_local_id(0); + u_int bid = item.get_group(0); + + u_int batch_i = bid / H; + u_int head_i = bid % H; + + float state[head_size]; + +#pragma unroll + for (u_int i = 0; i < head_size; i++) { + state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid]; + } + + for (u_int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; + t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) { + + item.barrier(sycl::access::fence_space::local_space); //sync threads + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + item.barrier(sycl::access::fence_space::local_space); //sync threads + + const float _v = v[t]; + float y = 0; + + for (u_int j = 0; j < head_size; j += 4) { + const sycl::float4 & k = (sycl::float4 &) (_k[j]); + const sycl::float4 & r = (sycl::float4 &) (_r[j]); + const sycl::float4 & td = (sycl::float4 &) (_td[j]); + sycl::float4 & s = (sycl::float4 &) (state[j]); + sycl::float4 kv; + + kv.x() = k.x() * _v; + kv.y() = k.y() * _v; + kv.z() = k.z() * _v; + kv.w() = k.w() * _v; + + s.x() = s.x() * td.x() + kv.x(); + s.y() = s.y() * td.y() + kv.y(); + s.z() = s.z() * td.z() + kv.z(); + s.w() = s.w() * td.w() + kv.w(); + + y += r.x() * s.x(); + y += r.y() * s.y(); + y += r.z() * s.z(); + y += r.w() * s.w(); + } + dst[t] = y * scale; + } +#pragma unroll + for (u_int i = 0; i < head_size; i++) { + dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i]; + } + }); + }); +} + +void ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const float * k_d = static_cast(dst->src[0]->data); + const float * v_d = static_cast(dst->src[1]->data); + const float * r_d = static_cast(dst->src[2]->data); + const float * td_d = static_cast(dst->src[3]->data); + const float * s_d = static_cast(dst->src[4]->data); + + const int64_t B = dst->src[4]->ne[1]; + const int64_t T = dst->src[0]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[1]; + + dpct::queue_ptr stream = ctx.stream(); + GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == 64 || C / H == 128); + + float scale; + memcpy(&scale, dst->op_params, sizeof(float)); + + float * dst_d = (float *) dst->data; + + if (C / H == 64) { + gated_linear_attn_f32_kernel<64>(stream, B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d); + } else { + gated_linear_attn_f32_kernel<128>(stream, B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d); + } +} diff --git a/ggml/src/ggml-sycl/gla.hpp b/ggml/src/ggml-sycl/gla.hpp new file mode 100644 index 000000000..607cf3a7f --- /dev/null +++ b/ggml/src/ggml-sycl/gla.hpp @@ -0,0 +1,8 @@ +#ifndef GGML_SYCL_GLA_HPP +#define GGML_SYCL_GLA_HPP + +#include "common.hpp" + +void ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif // GGML_SYCL_GLA_HPP diff --git a/ggml/src/ggml-sycl/im2col.cpp b/ggml/src/ggml-sycl/im2col.cpp index 6a0a0fcd0..6146a99ed 100644 --- a/ggml/src/ggml-sycl/im2col.cpp +++ b/ggml/src/ggml-sycl/im2col.cpp @@ -120,6 +120,7 @@ void ggml_sycl_op_im2col( im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream); } - (void) src0; - (void) src0_dd; + GGML_UNUSED(src0); + GGML_UNUSED(src0_dd); + GGML_UNUSED(ctx); } diff --git a/ggml/src/ggml-sycl/mmq.cpp b/ggml/src/ggml-sycl/mmq.cpp index e952533d3..8ea82c940 100644 --- a/ggml/src/ggml-sycl/mmq.cpp +++ b/ggml/src/ggml-sycl/mmq.cpp @@ -813,7 +813,7 @@ load_tiles_q4_K(const void *__restrict__ vx, int *__restrict__ x_ql, x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 + constexpr int blocks_per_tile_x_row = QI4_K > WARP_SIZE ? 1 : WARP_SIZE / QI4_K; // == 1 if QK_K == 256 const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 #pragma unroll @@ -961,7 +961,7 @@ load_tiles_q5_K(const void *__restrict__ vx, int *__restrict__ x_ql, x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1; } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 + constexpr int blocks_per_tile_x_row = QI5_K > WARP_SIZE ? 1 : WARP_SIZE / QI5_K; // == 1 if QK_K == 256 const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 #pragma unroll @@ -1109,7 +1109,7 @@ load_tiles_q6_K(const void *__restrict__ vx, int *__restrict__ x_ql, dpct::sub_sat()); } - const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 + constexpr int blocks_per_tile_x_row = QI6_K > WARP_SIZE ? 1 : WARP_SIZE / QI6_K; // == 1 if QK_K == 256 const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 float * x_dmf = (float *) x_dm; @@ -3020,9 +3020,9 @@ void ggml_sycl_op_mul_mat_q( break; } - (void) src1; - (void) dst; - (void) src1_ddf_i; + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_ddf_i); } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 1b96925e1..221f65c21 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -1,6 +1,6 @@ #include "mmvq.hpp" #include "vecdotq.hpp" - +#include template static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows, @@ -13,7 +13,8 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_ } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; + const int blocks_per_warp = vdr * QK_WARP_SIZE / qi; + assert(blocks_per_warp>0); // partial sum for each thread float tmp = 0.0f; @@ -37,7 +38,7 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_ // sum up partial sums and write back result #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -61,7 +62,8 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx, } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; + const int blocks_per_warp = vdr * QK_WARP_SIZE / qi; + assert(blocks_per_warp>0); // partial sum for each thread float tmp = 0.0f; @@ -85,7 +87,7 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -109,8 +111,8 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx, } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - + const int blocks_per_warp = vdr * QK_WARP_SIZE / qi; + assert(blocks_per_warp>0); // partial sum for each thread float tmp = 0.0f; @@ -133,7 +135,7 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -157,8 +159,8 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx, } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - + const int blocks_per_warp = vdr * QK_WARP_SIZE / qi; + assert(blocks_per_warp>0); // partial sum for each thread float tmp = 0.0f; @@ -181,7 +183,7 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -205,8 +207,8 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx, } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - + const int blocks_per_warp = vdr * QK_WARP_SIZE / qi; + assert(blocks_per_warp>0); // partial sum for each thread float tmp = 0.0f; @@ -229,7 +231,7 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -253,8 +255,8 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx, } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - + const int blocks_per_warp = vdr * QK_WARP_SIZE / qi; + assert(blocks_per_warp>0); // partial sum for each thread float tmp = 0.0f; @@ -277,7 +279,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -301,8 +303,8 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx, } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - + const int blocks_per_warp = vdr * QK_WARP_SIZE / qi; + assert(blocks_per_warp>0); // partial sum for each thread float tmp = 0.0f; @@ -325,7 +327,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -349,8 +351,8 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx, } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - + const int blocks_per_warp = vdr * QK_WARP_SIZE / qi; + assert(blocks_per_warp>0); // partial sum for each thread float tmp = 0.0f; @@ -373,7 +375,7 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -397,8 +399,8 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx, } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - + const int blocks_per_warp = vdr * QK_WARP_SIZE / qi; + assert(blocks_per_warp>0); // partial sum for each thread float tmp = 0.0f; @@ -421,7 +423,7 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -446,8 +448,8 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx, } const int blocks_per_row = ncols / qk; - const int blocks_per_warp = vdr * WARP_SIZE / qi; - + const int blocks_per_warp = vdr * QK_WARP_SIZE / qi; + assert(blocks_per_warp>0); // partial sum for each thread float tmp = 0.0f; @@ -470,7 +472,7 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx, // sum up partial sums and write back result #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); } @@ -487,7 +489,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK4_0 == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -495,7 +497,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -511,7 +513,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK4_1 == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -519,7 +521,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -535,7 +537,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK5_0 == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -543,7 +545,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -559,7 +561,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK5_1 == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -567,7 +569,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -583,7 +585,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK8_0 == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -591,7 +593,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -607,7 +609,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -615,7 +617,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -631,7 +633,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -639,7 +641,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -655,7 +657,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -663,7 +665,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -679,7 +681,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -687,7 +689,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -703,7 +705,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { @@ -711,7 +713,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy, cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q( vx, vy, dst, ncols, nrows, item_ct1); @@ -728,13 +730,13 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q_iq2_xxs_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); @@ -749,17 +751,13 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { - - stream->submit([&](sycl::handler &cgh) { - auto iq2xs_grid_ptr_ct1 = &iq2xs_grid[0]; - auto ksigns64_ptr_ct1 = &ksigns64[0]; - + stream->submit([&](sycl::handler & cgh) { cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q_iq2_xs_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); @@ -774,17 +772,14 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { - auto iq2xs_grid_ptr_ct1 = &iq2xs_grid[0]; - auto ksigns64_ptr_ct1 = &ksigns64[0]; - cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q_iq2_s_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); @@ -799,17 +794,14 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { - auto iq3xxs_grid_ptr_ct1 = &iq3xxs_grid[0]; - auto ksigns64_ptr_ct1 = &ksigns64[0]; - cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q_iq3_xxs_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); @@ -824,16 +816,14 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { - auto iq3s_grid_ptr_ct1 = &iq3s_grid[0]; - cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q_iq3_s_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); @@ -848,17 +838,14 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { - auto iq1s_grid_ptr_ct1 = &iq1s_grid_gpu[0]; - auto ksigns64_ptr_ct1 = &ksigns64[0]; - cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q_iq1_s_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); @@ -873,13 +860,13 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q_iq1_m_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); @@ -894,14 +881,14 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK4_NL == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q_iq4_nl_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); @@ -916,14 +903,14 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy, GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE); { stream->submit([&](sycl::handler &cgh) { cgh.parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[intel::reqd_sub_group_size(WARP_SIZE)]] { + [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] { mul_mat_vec_q_iq4_xs_q8_1( vx, vy, dst, ncols, nrows, item_ct1); }); @@ -952,7 +939,7 @@ void ggml_sycl_op_mul_mat_vec_q( const size_t q8_1_bs = QK8_1; // the main device has a larger memory buffer to hold the results from all GPUs // nrows_dst == nrows of the matrix that the kernel writes into - const int64_t nrows_dst = id == ctx.device ? ne00 : row_diff; + for (int i = 0; i < src1_ncols; i++) { const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs; @@ -1021,7 +1008,8 @@ void ggml_sycl_op_mul_mat_vec_q( break; } } - (void) src1; - (void) dst; - (void) src1_ddf_i; + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_ddf_i); + GGML_UNUSED(ctx); } diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp index b3159b9d1..9cf2be155 100644 --- a/ggml/src/ggml-sycl/norm.cpp +++ b/ggml/src/ggml-sycl/norm.cpp @@ -8,7 +8,6 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep const int nthreads = item_ct1.get_local_range(2); const int nwarps = nthreads / WARP_SIZE; - assert(nwarps % WARP_SIZE == 0); sycl::float2 mean_var = sycl::float2(0.f, 0.f); for (int col = tid; col < ncols; col += block_size) { @@ -32,7 +31,7 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep */ item_ct1.barrier(sycl::access::fence_space::local_space); mean_var = 0.f; - int nreduce = nwarps / WARP_SIZE; + size_t nreduce = nwarps / WARP_SIZE; for (size_t i = 0; i < nreduce; i += 1) { mean_var += s_sum[lane_id + i * WARP_SIZE]; @@ -55,9 +54,8 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con int end = start + group_size; const int nthreads = item_ct1.get_local_range(2); const int nwarps = nthreads / WARP_SIZE; - assert(nwarps % WARP_SIZE == 0); start += item_ct1.get_local_id(2); - int nreduce = nwarps / WARP_SIZE; + size_t nreduce = nwarps / WARP_SIZE; if (end >= ne_elements) { end = ne_elements; @@ -144,7 +142,6 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa const int tid = item_ct1.get_local_id(2); const int nthreads = item_ct1.get_local_range(2); const int nwarps = nthreads / WARP_SIZE; - assert(nwarps % WARP_SIZE == 0); float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += block_size) { @@ -166,7 +163,7 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa converged control flow. You may need to adjust the code. */ item_ct1.barrier(sycl::access::fence_space::local_space); - int nreduce = nwarps / WARP_SIZE; + size_t nreduce = nwarps / WARP_SIZE; tmp = 0.f; for (size_t i = 0; i < nreduce; i += 1) { @@ -202,6 +199,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols, } else { const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; + assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); const sycl::range<3> block_dims(1, 1, work_group_size); /* DPCT1049:17: The work-group size passed to the SYCL kernel may exceed @@ -244,6 +242,7 @@ static void group_norm_f32_sycl(const float* x, float* dst, } else { const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; + assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); const sycl::range<3> block_dims(1, 1, work_group_size); /* DPCT1049:18: The work-group size passed to the SYCL kernel may exceed @@ -290,6 +289,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, } else { const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; + assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); const sycl::range<3> block_dims(1, 1, work_group_size); /* DPCT1049:19: The work-group size passed to the SYCL kernel may exceed @@ -352,6 +352,7 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* (void)src1; (void)dst; (void)src1_dd; + GGML_UNUSED(ctx); } void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, diff --git a/ggml/src/ggml-sycl/outprod.cpp b/ggml/src/ggml-sycl/outprod.cpp new file mode 100644 index 000000000..8e8347ff4 --- /dev/null +++ b/ggml/src/ggml-sycl/outprod.cpp @@ -0,0 +1,56 @@ +#include +#include +#include "outprod.hpp" + + +void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + GGML_TENSOR_BINARY_OP_LOCALS + + // Get SYCL queue + dpct::queue_ptr stream = ctx.stream(); + + // Dimension checks + GGML_ASSERT(ne01 == ne11); // Inner dimensions must match + GGML_ASSERT(ne0 == ne00); // Output rows match src0 rows + GGML_ASSERT(ne1 == ne10); // Output cols match src1 cols + + // Get data pointers + const float* src0_d = (const float*)src0->data; + const float* src1_d = (const float*)src1->data; + float* dst_d = (float*)dst->data; + + // GEMM parameters + const float alpha = 1.0f; + const float beta = 0.0f; + + // Handle transposition of src1 + const bool src1_T = ggml_is_transposed(src1); + const oneapi::mkl::transpose src1_op = + src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans; + const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); + + try { + // Perform matrix multiplication using oneMKL GEMM +#ifdef GGML_SYCL_NVIDIA + oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector{ *stream }, + oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d, + ne00, src1_d, ldb, beta, dst_d, ne0); +#else + oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, + src0_d, ne00, src1_d, ldb, beta, dst_d, ne0); +#endif + } + catch (sycl::exception const& exc) { + std::cerr << exc.what() << std::endl; + GGML_ASSERT(false); + } +} diff --git a/ggml/src/ggml-sycl/outprod.hpp b/ggml/src/ggml-sycl/outprod.hpp new file mode 100644 index 000000000..f50413d3f --- /dev/null +++ b/ggml/src/ggml-sycl/outprod.hpp @@ -0,0 +1,10 @@ +#ifndef GGML_SYCL_OUTPROD_HPP +#define GGML_SYCL_OUTPROD_HPP + +#include "common.hpp" + +void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst); + + +#endif // GGML_SYCL_OUTPROD_HPP + diff --git a/ggml/src/ggml-sycl/presets.hpp b/ggml/src/ggml-sycl/presets.hpp index 340ab8e93..af1890727 100644 --- a/ggml/src/ggml-sycl/presets.hpp +++ b/ggml/src/ggml-sycl/presets.hpp @@ -25,6 +25,11 @@ #define SYCL_RELU_BLOCK_SIZE 256 #define SYCL_HARDSIGMOID_BLOCK_SIZE 256 #define SYCL_HARDSWISH_BLOCK_SIZE 256 +#define SYCL_EXP_BLOCK_SIZE 256 +#define SYCL_NEG_BLOCK_SIZE 256 +#define SYCL_SIGMOID_BLOCK_SIZE 256 +#define SYCL_SQRT_BLOCK_SIZE 256 +#define SYCL_SIN_BLOCK_SIZE 256 #define SYCL_SQR_BLOCK_SIZE 256 #define SYCL_CPY_BLOCK_SIZE 32 #define SYCL_SCALE_BLOCK_SIZE 256 @@ -41,6 +46,7 @@ #define SYCL_ACC_BLOCK_SIZE 256 #define SYCL_IM2COL_BLOCK_SIZE 256 #define SYCL_POOL2D_BLOCK_SIZE 256 +#define SYCL_ARGMAX_BLOCK_SIZE 256 #define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256 #define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256 diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp index 1f06f78fa..1244b231a 100644 --- a/ggml/src/ggml-sycl/rope.cpp +++ b/ggml/src/ggml-sycl/rope.cpp @@ -269,7 +269,8 @@ void ggml_sycl_op_rope( } } - (void) src1; - (void) dst; - (void) src1_dd; + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_dd); + GGML_UNUSED(ctx); } diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index 17a542e49..563e0655f 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -1,7 +1,7 @@ -#include "norm.hpp" +#include "softmax.hpp" -template -static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par, +template +static void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) { const int ncols = ncols_template == 0 ? ncols_par : ncols_template; @@ -16,7 +16,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; const int nthreads = block_size; const int nwarps = nthreads / WARP_SIZE; - int nreduce = nwarps / WARP_SIZE; + size_t nreduce = nwarps / WARP_SIZE; float slope = 1.0f; // ALiBi @@ -29,7 +29,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const slope = sycl::pow(base, float(exp)); } - float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols; + float *vals = vals_smem ? buf + sycl::max(nwarps, WARP_SIZE) : dst + rowx * ncols; float max_val = -INFINITY; for (int col0 = 0; col0 < ncols; col0 += block_size) { @@ -42,7 +42,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const const int ix = rowx*ncols + col; const int iy = rowy*ncols + col; - const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f); + const float val = x[ix]*scale + (mask ? slope*static_cast(mask[iy]) : 0.0f); vals[col] = val; max_val = sycl::max(max_val, val); @@ -53,8 +53,9 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const if (block_size > WARP_SIZE) { if (warp_id == 0) { buf[lane_id] = -INFINITY; - for (size_t i = 1; i < nreduce; i += 1) + for (size_t i = 1; i < nreduce; i += 1) { buf[lane_id + i * WARP_SIZE] = -INFINITY; + } } item_ct1.barrier(sycl::access::fence_space::local_space); @@ -63,9 +64,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const } item_ct1.barrier(sycl::access::fence_space::local_space); max_val = buf[lane_id]; - for (size_t i = 1; i < nreduce; i += 1) - { - max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]); + for (size_t i = 1; i < nreduce; i += 1) { + max_val = sycl::max(max_val, buf[lane_id + i * WARP_SIZE]); } max_val = warp_reduce_max(max_val, item_ct1); } @@ -89,8 +89,9 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const item_ct1.barrier(sycl::access::fence_space::local_space); if (warp_id == 0) { buf[lane_id] = 0.f; - for (size_t i = 1; i < nreduce; i += 1) + for (size_t i = 1; i < nreduce; i += 1) { buf[lane_id + i * WARP_SIZE] = 0.f; + } } item_ct1.barrier(sycl::access::fence_space::local_space); @@ -100,8 +101,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const item_ct1.barrier(sycl::access::fence_space::local_space); tmp = buf[lane_id]; - for (size_t i = 1; i < nreduce; i += 1) - { + for (size_t i = 1; i < nreduce; i += 1) { tmp += buf[lane_id + i * WARP_SIZE]; } tmp = warp_reduce_sum(tmp, item_ct1); @@ -122,8 +122,8 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const } } -template -static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par, +template +static void soft_max_f32_submitter(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims, const size_t n_local_scratch, queue_ptr stream) { @@ -141,7 +141,8 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float * }); } -static void soft_max_f32_sycl(const float * x, const float * mask, +template +static void soft_max_f32_sycl(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, queue_ptr stream, int device) { @@ -223,22 +224,16 @@ static void soft_max_f32_sycl(const float * x, const float * mask, } } -void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream) { +void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); -#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support") -#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021") - GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + GGML_ASSERT(!dst->src[1] || dst->src[1]->type == GGML_TYPE_F16 || dst->src[1]->type == GGML_TYPE_F32); // src1 contains mask and it is optional - const int64_t ne00 = src0->ne[0]; - const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src0->ne[1]; + const int64_t ne00 = dst->src[0]->ne[0]; + const int64_t nrows_x = ggml_nrows(dst->src[0]); + const int64_t nrows_y = dst->src[0]->ne[1]; float scale = 1.0f; float max_bias = 0.0f; @@ -246,6 +241,21 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *s memcpy(&scale, dst->op_params + 0, sizeof(float)); memcpy(&max_bias, dst->op_params + 1, sizeof(float)); - soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, - nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); + const float * src0_dd = static_cast(dst->src[0]->data); + float * dst_dd = static_cast(dst->data); + + ggml_sycl_set_device(ctx.device); + dpct::queue_ptr main_stream = ctx.stream(); + + if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) { + const sycl::half * src1_dd = static_cast(dst->src[1]->data); + soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, + main_stream, ctx.device); + } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) { + const float * src1_dd = static_cast(dst->src[1]->data); + soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); + } else { + /* mask unavailable */ + soft_max_f32_sycl(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); + } } diff --git a/ggml/src/ggml-sycl/softmax.hpp b/ggml/src/ggml-sycl/softmax.hpp index bdb8f712e..2cf8582ec 100644 --- a/ggml/src/ggml-sycl/softmax.hpp +++ b/ggml/src/ggml-sycl/softmax.hpp @@ -15,10 +15,6 @@ #include "common.hpp" -void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst, - const float *src0_dd, const float *src1_dd, - float *dst_dd, - const queue_ptr &main_stream); +void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, ggml_tensor *dst); #endif // GGML_SYCL_SOFTMAX_HPP diff --git a/ggml/src/ggml-sycl/tsembd.cpp b/ggml/src/ggml-sycl/tsembd.cpp index d5c227cd1..b877d18c1 100644 --- a/ggml/src/ggml-sycl/tsembd.cpp +++ b/ggml/src/ggml-sycl/tsembd.cpp @@ -55,8 +55,9 @@ static void timestep_embedding_f32_sycl( }); } -void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor * dst) { +void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 = dst->src[1]; const float * src0_d = (const float *)src0->data; float * dst_d = (float *)dst->data; dpct::queue_ptr stream = ctx.stream(); @@ -68,4 +69,5 @@ void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml const int max_period = dst->op_params[1]; timestep_embedding_f32_sycl(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream); + GGML_UNUSED(src1); } diff --git a/ggml/src/ggml-sycl/tsembd.hpp b/ggml/src/ggml-sycl/tsembd.hpp index ff854c337..4c18748bb 100644 --- a/ggml/src/ggml-sycl/tsembd.hpp +++ b/ggml/src/ggml-sycl/tsembd.hpp @@ -15,7 +15,6 @@ #include "common.hpp" -void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor * dst); +void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst); #endif // GGML_SYCL_TSEMBD_HPP diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index d2dccade2..c5942008a 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -968,8 +968,8 @@ vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq, grid1[0] ^ signs[0], signs[0], std::minus<>()); const int grid_h = dpct::vectorized_binary( grid2[0] ^ signs[1], signs[1], std::minus<>()); - sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi); - sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi); + sumi = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi); + sumi = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi); q8 += 8; aux32 >>= 7; } @@ -1009,8 +1009,8 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq, grid1[0] ^ signs0, signs0, std::minus<>()); const int grid_h = dpct::vectorized_binary( grid2[0] ^ signs1, signs1, std::minus<>()); - sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi); - sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi); + sumi = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi); + sumi = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi); q8 += 8; } const float d = diff --git a/ggml/src/ggml-sycl/wkv6.cpp b/ggml/src/ggml-sycl/wkv6.cpp new file mode 100644 index 000000000..b54c20964 --- /dev/null +++ b/ggml/src/ggml-sycl/wkv6.cpp @@ -0,0 +1,143 @@ +#include +#include "wkv6.hpp" + +constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE + +// Helper function for the main kernel +static void rwkv_wkv_f32_kernel( + const int B, const int T, const int C, const int H, + const float* k, const float* v, const float* r, + const float* tf, const float* td, const float* s, + float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) { + + const int tid = item_ct1.get_local_id(2); + const int bid = item_ct1.get_group(2); + + const int head_size = WKV_BLOCK_SIZE; + const int batch_i = bid / H; + const int head_i = bid % H; + const int state_size = C * head_size; + const int n_seq_tokens = T / B; + + // Set up shared memory pointers + float* _k = shared_mem; + float* _r = _k + head_size; + float* _tf = _r + head_size; + float* _td = _tf + head_size; + + // Local state array + float state[WKV_BLOCK_SIZE]; + + // Load initial state + #pragma unroll + for (int i = 0; i < head_size; i++) { + state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid]; + } + + // Sync threads before shared memory operations + item_ct1.barrier(sycl::access::fence_space::local_space); + + // Load time-mixing parameters + _tf[tid] = tf[head_i * head_size + tid]; + item_ct1.barrier(sycl::access::fence_space::local_space); + + // Main sequence processing loop + for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; + t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; + t += C) { + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // Load current timestep data to shared memory + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + + item_ct1.barrier(sycl::access::fence_space::local_space); + + const float _v = v[t]; + float y = 0; + + // Process in chunks of 4 for better vectorization + sycl::float4 k4, r4, tf4, td4, s4; + #pragma unroll + for (int j = 0; j < head_size; j += 4) { + // Load data in vec4 chunks + k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); + td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]); + s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]); + + // Compute key-value product + sycl::float4 kv4 = k4 * _v; + + // Accumulate weighted sum + y += sycl::dot(r4, tf4 * kv4 + s4); + + // Update state + s4 = s4 * td4 + kv4; + + // Store updated state + state[j] = s4.x(); + state[j+1] = s4.y(); + state[j+2] = s4.z(); + state[j+3] = s4.w(); + } + + dst[t] = y; + } + + // Save final state + #pragma unroll + for (int i = 0; i < head_size; i++) { + dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i]; + } +} + +void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { + + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 = dst->src[1]; + + const float* k_d = (const float*)dst->src[0]->data; + const float* v_d = (const float*)dst->src[1]->data; + const float* r_d = (const float*)dst->src[2]->data; + const float* tf_d = (const float*)dst->src[3]->data; + const float* td_d = (const float*)dst->src[4]->data; + const float* s_d = (const float*)dst->src[5]->data; + float* dst_d = (float*)dst->data; + + const int64_t B = dst->src[5]->ne[1]; + const int64_t T = dst->src[0]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[1]; + + GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64 + + dpct::queue_ptr stream = ctx.stream(); + + // Calculate execution configuration + const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td + sycl::range<3> block_dims(1, 1, C / H); + sycl::range<3> grid_dims(1, 1, B * H); + + // Submit kernel + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor shared_mem_acc(shared_mem_size, cgh); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + rwkv_wkv_f32_kernel( + B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d, + item_ct1, (float*)shared_mem_acc.get_multi_ptr().get() + ); + }); + }); + + GGML_UNUSED(src0); + GGML_UNUSED(src1); +} diff --git a/ggml/src/ggml-sycl/wkv6.hpp b/ggml/src/ggml-sycl/wkv6.hpp new file mode 100644 index 000000000..8c596a997 --- /dev/null +++ b/ggml/src/ggml-sycl/wkv6.hpp @@ -0,0 +1,9 @@ +#ifndef GGML_SYCL_WKV6_HPP +#define GGML_SYCL_WKV6_HPP + +#include "common.hpp" + +void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + + +#endif // GGML_SYCL_WKV6_HPP diff --git a/ggml/src/ggml-threading.cpp b/ggml/src/ggml-threading.cpp new file mode 100644 index 000000000..25a19eedb --- /dev/null +++ b/ggml/src/ggml-threading.cpp @@ -0,0 +1,12 @@ +#include "ggml-threading.h" +#include + +std::mutex ggml_critical_section_mutex; + +void ggml_critical_section_start() { + ggml_critical_section_mutex.lock(); +} + +void ggml_critical_section_end(void) { + ggml_critical_section_mutex.unlock(); +} diff --git a/ggml/src/ggml-threading.h b/ggml/src/ggml-threading.h new file mode 100644 index 000000000..dec2c8840 --- /dev/null +++ b/ggml/src/ggml-threading.h @@ -0,0 +1,14 @@ +#pragma once + +#include "ggml.h" + +#ifdef __cplusplus +extern "C" { +#endif + +GGML_API void ggml_critical_section_start(void); +GGML_API void ggml_critical_section_end(void); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt new file mode 100644 index 000000000..d970f7e20 --- /dev/null +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -0,0 +1,162 @@ +cmake_minimum_required(VERSION 3.19) +cmake_policy(SET CMP0114 NEW) + +find_package(Vulkan COMPONENTS glslc REQUIRED) + +function(detect_host_compiler) + if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") + find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH) + find_program(HOST_CXX_COMPILER NAMES cl g++ clang++ NO_CMAKE_FIND_ROOT_PATH) + else() + find_program(HOST_C_COMPILER NAMES gcc clang NO_CMAKE_FIND_ROOT_PATH) + find_program(HOST_CXX_COMPILER NAMES g++ clang++ NO_CMAKE_FIND_ROOT_PATH) + endif() + set(HOST_C_COMPILER "${HOST_C_COMPILER}" PARENT_SCOPE) + set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE) +endfunction() + +if (Vulkan_FOUND) + message(STATUS "Vulkan found") + + ggml_add_backend_library(ggml-vulkan + ggml-vulkan.cpp + ../../include/ggml-vulkan.h + ) + + # Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported. + # If it's not, there will be an error to stderr. + # If it's supported, set a define to indicate that we should compile those shaders + execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error) + + if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*") + message(STATUS "GL_KHR_cooperative_matrix not supported by glslc") + else() + message(STATUS "GL_KHR_cooperative_matrix supported by glslc") + add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + endif() + + # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported. + # If it's not, there will be an error to stderr. + # If it's supported, set a define to indicate that we should compile those shaders + execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error) + + if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*") + message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc") + else() + message(STATUS "GL_NV_cooperative_matrix2 supported by glslc") + add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + endif() + + target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan) + target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) + + # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build + # Posssibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector + if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0) + endif() + + if (GGML_VULKAN_CHECK_RESULTS) + add_compile_definitions(GGML_VULKAN_CHECK_RESULTS) + endif() + + if (GGML_VULKAN_DEBUG) + add_compile_definitions(GGML_VULKAN_DEBUG) + endif() + + if (GGML_VULKAN_MEMORY_DEBUG) + add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG) + endif() + + if (GGML_VULKAN_SHADER_DEBUG_INFO) + add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO) + endif() + + if (GGML_VULKAN_PERF) + add_compile_definitions(GGML_VULKAN_PERF) + endif() + + if (GGML_VULKAN_VALIDATE) + add_compile_definitions(GGML_VULKAN_VALIDATE) + endif() + + if (GGML_VULKAN_RUN_TESTS) + add_compile_definitions(GGML_VULKAN_RUN_TESTS) + endif() + + if (NOT CMAKE_CROSSCOMPILING) + add_subdirectory(vulkan-shaders) + if (MSVC) + foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES}) + string(TOUPPER ${CONFIG} CONFIG) + set_target_properties(vulkan-shaders-gen PROPERTIES + RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) + endforeach() + endif() + else() + if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN) + set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN}) + else() + detect_host_compiler() + if (NOT HOST_C_COMPILER OR NOT HOST_CXX_COMPILER) + message(FATAL_ERROR "Host compiler not found") + else() + message(STATUS "Host compiler: ${HOST_C_COMPILER} ${HOST_CXX_COMPILER}") + endif() + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY) + set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake) + endif() + message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}") + + include(ExternalProject) + # Native build through ExternalProject_Add + ExternalProject_Add( + vulkan-shaders-gen + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders + CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE} + -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR} + BUILD_COMMAND ${CMAKE_COMMAND} --build . + INSTALL_COMMAND ${CMAKE_COMMAND} --install . + INSTALL_DIR ${CMAKE_BINARY_DIR} + ) + ExternalProject_Add_StepTargets(vulkan-shaders-gen build install) + endif() + set (_ggml_vk_host_suffix $,.exe,>) + set (_ggml_vk_genshaders_cmd ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/vulkan-shaders-gen${_ggml_vk_host_suffix}) + set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp) + set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp) + set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders) + set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv) + + file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp") + set (_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen) + + if (CMAKE_CROSSCOMPILING) + set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install) + endif() + + add_custom_command( + OUTPUT ${_ggml_vk_header} + ${_ggml_vk_source} + + COMMAND ${_ggml_vk_genshaders_cmd} + --glslc ${Vulkan_GLSLC_EXECUTABLE} + --input-dir ${_ggml_vk_input_dir} + --output-dir ${_ggml_vk_output_dir} + --target-hpp ${_ggml_vk_header} + --target-cpp ${_ggml_vk_source} + --no-clean + + DEPENDS ${_ggml_vk_shader_deps} + COMMENT "Generate vulkan shaders" + ) + + target_sources(ggml-vulkan PRIVATE ${_ggml_vk_source} ${_ggml_vk_header}) + +else() + message(WARNING "Vulkan not found") +endif() diff --git a/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in b/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in new file mode 100644 index 000000000..b6af747a5 --- /dev/null +++ b/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in @@ -0,0 +1,15 @@ +set(CMAKE_BUILD_TYPE Release) +set(CMAKE_C_FLAGS -O2) +set(CMAKE_CXX_FLAGS -O2) +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE NEVER) +set(CMAKE_C_COMPILER @HOST_C_COMPILER@) +set(CMAKE_CXX_COMPILER @HOST_CXX_COMPILER@) +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY @CMAKE_RUNTIME_OUTPUT_DIRECTORY@) + +if("@CMAKE_C_COMPILER_ID@" STREQUAL "MSVC") + foreach(CONFIG IN ITEMS DEBUG RELEASE MINSIZEREL RELWITHDEBINFO) + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) + endforeach() +endif() diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp similarity index 60% rename from ggml/src/ggml-vulkan.cpp rename to ggml/src/ggml-vulkan/ggml-vulkan.cpp index d6f647c89..9ca3959ab 100644 --- a/ggml/src/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1,7 +1,8 @@ #include "ggml-vulkan.h" #include -#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF) +#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF) || defined(GGML_VULKAN_CHECK_RESULTS) #include +#include "ggml-cpu.h" #endif #include @@ -20,14 +21,14 @@ #include #include #include +#include +#include -#include "ggml.h" +#include "ggml-impl.h" #include "ggml-backend-impl.h" #include "ggml-vulkan-shaders.hpp" -#define VK_API_VERSION VK_API_VERSION_1_2 - #define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) #define VK_VENDOR_ID_AMD 0x1002 @@ -41,12 +42,6 @@ #define MAX_VK_BUFFERS 256 -#ifndef K_QUANTS_PER_ITERATION -#define K_QUANTS_PER_ITERATION 1 -#else -static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); -#endif - #define VK_CHECK(err, msg) \ do { \ vk::Result err_ = (err); \ @@ -90,6 +85,10 @@ struct vk_pipeline_struct { uint32_t parameter_count; std::array wg_denoms; uint32_t align; + // set to true to request the pipeline is compiled after the dryrun + bool needed {}; + // set to true when the shader has been compiled + bool compiled {}; }; typedef std::shared_ptr vk_pipeline; @@ -104,6 +103,15 @@ struct vk_matmul_pipeline_struct { typedef std::shared_ptr vk_matmul_pipeline; +struct vk_matmul_pipeline2 { + vk_matmul_pipeline2() { + f16acc = std::make_shared(); + f32acc = std::make_shared(); + } + vk_matmul_pipeline f32acc; + vk_matmul_pipeline f16acc; +}; + struct vk_device_struct; typedef std::shared_ptr vk_device; typedef std::weak_ptr vk_device_ref; @@ -117,11 +125,11 @@ struct ggml_backend_vk_buffer_type_context { vk_device device; }; -GGML_CALL static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft); -GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); -GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft); -GGML_CALL static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft); -GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor); +static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft); +static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); +static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft); +static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft); +static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor); static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = { /* .get_name = */ ggml_backend_vk_buffer_type_name, /* .alloc_buffer = */ ggml_backend_vk_buffer_type_alloc_buffer, @@ -139,6 +147,8 @@ class vk_perf_logger; #endif static void ggml_vk_destroy_buffer(vk_buffer& buf); +static constexpr uint32_t mul_mat_vec_max_cols = 8; + struct vk_device_struct { std::mutex mutex; @@ -147,33 +157,60 @@ struct vk_device_struct { std::string name; uint64_t max_memory_allocation_size; bool fp16; + bool pipeline_robustness; vk::Device device; uint32_t vendor_id; vk_queue compute_queue; vk_queue transfer_queue; bool single_queue; uint32_t subgroup_size; + uint32_t shader_core_count; bool uma; + bool float_controls_rte_fp16; + + bool subgroup_size_control; + uint32_t subgroup_min_size; + uint32_t subgroup_max_size; + bool subgroup_require_full_support; + + bool coopmat_support; + bool coopmat_acc_f32_support; + bool coopmat_acc_f16_support; + uint32_t coopmat_m; + uint32_t coopmat_n; + uint32_t coopmat_k; + bool coopmat2; size_t idx; - vk_matmul_pipeline pipeline_matmul_f32; - vk_matmul_pipeline pipeline_matmul_f32_f16; - vk_matmul_pipeline pipeline_matmul_f16; - vk_matmul_pipeline pipeline_matmul_f16_f32; + bool mul_mat_l; + bool mul_mat_m; + bool mul_mat_s; + bool mul_mat_id_l; + bool mul_mat_id_m; + bool mul_mat_id_s; + + // set to true to indicate that some shaders need to be compiled after the dryrun + bool need_compiles {}; + + vk_matmul_pipeline pipeline_matmul_f32 {}; + vk_matmul_pipeline pipeline_matmul_f32_f16 {}; + vk_matmul_pipeline2 pipeline_matmul_f16; + vk_matmul_pipeline2 pipeline_matmul_f16_f32; vk_pipeline pipeline_matmul_split_k_reduce; - vk_matmul_pipeline pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; - vk_matmul_pipeline pipeline_matmul_id_f32; - vk_matmul_pipeline pipeline_matmul_id_f16; - vk_matmul_pipeline pipeline_matmul_id_f16_f32; + vk_matmul_pipeline pipeline_matmul_id_f32 {}; + vk_matmul_pipeline2 pipeline_matmul_id_f16; + vk_matmul_pipeline2 pipeline_matmul_id_f16_f32; - vk_matmul_pipeline pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT]; vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; - vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT]; - vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_mul_mat_vec_p021_f16_f32; @@ -181,9 +218,10 @@ struct vk_device_struct { vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_acc_f32; - vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16; - vk_pipeline pipeline_mul_f32; - vk_pipeline pipeline_div_f32; + vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat; + vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat; + vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat; + vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat; vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; vk_pipeline pipeline_upscale_f32; vk_pipeline pipeline_scale_f32; @@ -194,6 +232,9 @@ struct vk_device_struct { vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_repeat_f32; vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16; + vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16; + vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; + vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_norm_f32; vk_pipeline pipeline_group_norm_f32; vk_pipeline pipeline_rms_norm_f32; @@ -205,12 +246,23 @@ struct vk_device_struct { vk_pipeline pipeline_tanh_f32; vk_pipeline pipeline_diag_mask_inf_f32; vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; + vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16; vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; vk_pipeline pipeline_argsort_f32; vk_pipeline pipeline_sum_rows_f32; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; + vk_pipeline pipeline_pool2d_f32; + vk_pipeline pipeline_rwkv_wkv6_f32; + + // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} + vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2]; std::unordered_map pipelines; std::unordered_map pipeline_descriptor_set_requirements; @@ -323,6 +375,43 @@ struct vk_mat_vec_id_push_constants { uint32_t nei0; uint32_t ne11; }; +struct vk_flash_attn_push_constants { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t nb21; + uint32_t nb22; + uint32_t nb23; + uint32_t nb31; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask; + uint32_t n_head_log2; + float m0; + float m1; +}; + struct vk_op_push_constants { uint32_t KX; uint32_t KY; @@ -334,16 +423,55 @@ struct vk_op_unary_push_constants { uint32_t ne; uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; - uint32_t d_offset; + uint32_t misalign_offsets; float param1; float param2; + uint32_t ne0_012mp; uint32_t ne0_012L; + uint32_t ne0_01mp; uint32_t ne0_01L; + uint32_t ne0_0mp; uint32_t ne0_0L; + uint32_t ne1_012mp; uint32_t ne1_012L; + uint32_t ne1_01mp; uint32_t ne1_01L; + uint32_t ne1_0mp; uint32_t ne1_0L; }; +static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128"); + +// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. +// Precompute mp (m' in the paper) and L such that division +// can be computed using a multiply (high 32b of 64b result) +// and a shift: +// +// n/d = (mulhi(n, mp) + n) >> L; +static void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L) +{ + // compute L = ceil(log2(d)); + L = 0; + while (L < 32 && (uint32_t{1} << L) < d) { + L++; + } + + mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1); +} + +template void init_pushconst_fastdiv(T &p) { + GGML_UNUSED(p); + static_assert(!std::is_const::value, "unexpected type"); +} + +template <> void init_pushconst_fastdiv(vk_op_unary_push_constants &p) { + // Compute magic values to divide by these six numbers. + init_fastdiv_values(p.ne02*p.ne01*p.ne00, p.ne0_012mp, p.ne0_012L); + init_fastdiv_values(p.ne01*p.ne00, p.ne0_01mp, p.ne0_01L); + init_fastdiv_values(p.ne00, p.ne0_0mp, p.ne0_0L); + init_fastdiv_values(p.ne12*p.ne11*p.ne10, p.ne1_012mp, p.ne1_012L); + init_fastdiv_values(p.ne11*p.ne10, p.ne1_01mp, p.ne1_01L); + init_fastdiv_values(p.ne10, p.ne1_0mp, p.ne1_0L); +} struct vk_op_binary_push_constants { uint32_t ne; uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23; - uint32_t d_offset; + uint32_t misalign_offsets; float param1; float param2; int32_t param3; }; @@ -374,6 +502,7 @@ struct vk_op_soft_max_push_constants { float m0; float m1; uint32_t n_head_log2; + uint32_t nrows_x; }; struct vk_op_argsort_push_constants { @@ -401,6 +530,24 @@ struct vk_op_timestep_embedding_push_constants { uint32_t max_period; }; +struct vk_op_pool2d_push_constants { + uint32_t IW; uint32_t IH; + uint32_t OW; uint32_t OH; + uint32_t OC; + uint32_t pelements; + uint32_t op; + int32_t k0; int32_t k1; + int32_t s0; int32_t s1; + int32_t p0; int32_t p1; +}; + +struct vk_op_rwkv_wkv6_push_constants { + uint32_t B; + uint32_t T; + uint32_t C; + uint32_t H; +}; + // Allow pre-recording command buffers struct vk_staging_memcpy { vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} @@ -411,7 +558,7 @@ struct vk_staging_memcpy { }; struct vk_op_upscale_push_constants { - uint32_t ne; uint32_t d_offset; + uint32_t ne; uint32_t a_offset; uint32_t d_offset; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; float sf0; float sf1; float sf2; float sf3; @@ -431,16 +578,6 @@ struct vk_context_struct { typedef std::shared_ptr vk_context; typedef std::weak_ptr vk_context_ref; -struct ggml_tensor_extra_gpu { - vk_buffer_ref buffer_gpu; - uint64_t offset; - - void reset() { - buffer_gpu.reset(); - offset = 0; - } -}; - struct ggml_vk_garbage_collector { std::vector tl_semaphores; std::vector semaphores; @@ -551,6 +688,31 @@ struct ggml_backend_vk_context { std::vector tensor_ctxs; }; +static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT + +static uint64_t vk_tensor_offset(const ggml_tensor * tensor) { + if (tensor->view_src) { + return (uint8_t *) tensor->view_src->data - (uint8_t *) vk_ptr_base; + } + return (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base; +} + +struct ggml_backend_vk_buffer_context { + vk_device_ref device; + vk_buffer dev_buffer; + std::string name; + + ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) : + device(device), + dev_buffer(dev_buffer), + name(name) { + } + + ~ggml_backend_vk_buffer_context() { + ggml_vk_destroy_buffer(dev_buffer); + } +}; + #ifdef GGML_VULKAN_MEMORY_DEBUG void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) { std::lock_guard guard(log_mutex); @@ -605,22 +767,22 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor); typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); -GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend); +static void ggml_backend_vk_free(ggml_backend_t backend); -static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, const std::string& name, size_t spv_size, const void* spv_data, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, std::vector&& specialization_constants, uint32_t align) { - VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")"); +// variables to track number of compiles in progress +static uint32_t compile_count = 0; +static std::mutex compile_count_mutex; +static std::condition_variable compile_count_cond; + +static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint, + uint32_t parameter_count, std::array wg_denoms, std::vector specialization_constants, + bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) { + VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << pipeline->name << ", " << entrypoint << ", " << parameter_count << + ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << + disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")"); GGML_ASSERT(parameter_count > 0); GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT - std::lock_guard guard(device->mutex); - - pipeline = std::make_shared(); - pipeline->name = name; - pipeline->parameter_count = parameter_count; - pipeline->push_constant_size = push_constant_size; - pipeline->wg_denoms = wg_denoms; - pipeline->align = align; - vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast(spv_data)); pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); @@ -669,19 +831,59 @@ static void ggml_vk_create_pipeline(vk_device& device, vk_pipeline& pipeline, co specialization_constants.data() ); + vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{}; + + if (device->subgroup_require_full_support && require_full_subgroups) { + pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT; + } + vk::PipelineShaderStageCreateInfo pipeline_shader_create_info( - vk::PipelineShaderStageCreateFlags(), + pipeline_shader_stage_create_flags, vk::ShaderStageFlagBits::eCompute, pipeline->shader_module, entrypoint.c_str(), &specialization_info); + + vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info; + pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size; + if (device->subgroup_size_control && required_subgroup_size > 0) { + GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size); + pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info); + } + vk::ComputePipelineCreateInfo compute_pipeline_create_info( - vk::PipelineCreateFlags(), + vk::PipelineCreateFlags{}, pipeline_shader_create_info, pipeline->layout); - pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; - device->pipelines.insert({ pipeline->name, pipeline }); + vk::PipelineRobustnessCreateInfoEXT rci; + + if (device->pipeline_robustness && disable_robustness) { + rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled; + rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled; + compute_pipeline_create_info.setPNext(&rci); + } + + try { + pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; + } catch (const vk::SystemError& e) { + std::cerr << "ggml_vulkan: Compute pipeline creation failed for " << pipeline->name << std::endl; + std::cerr << "ggml_vulkan: " << e.what() << std::endl; + throw e; + } + pipeline->compiled = true; + + { + std::lock_guard guard(device->mutex); + device->pipelines.insert({ pipeline->name, pipeline }); + } + + { + std::lock_guard guard(compile_count_mutex); + assert(compile_count > 0); + compile_count--; + } + compile_count_cond.notify_all(); } static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) { @@ -705,6 +907,10 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) { VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); device->pipeline_descriptor_set_requirements[pipeline->name] += n; + if (!pipeline->compiled) { + pipeline->needed = true; + device->need_compiles = true; + } } static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) { @@ -787,6 +993,9 @@ static vk_submission ggml_vk_create_submission(vk_device& device, vk_queue& q, s static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { if (ctx->seqs.empty()) { + if (fence) { + ctx->q->queue.submit({}, fence); + } return; } VK_LOG_DEBUG("ggml_vk_submit(" << ctx << ", " << fence << ")"); @@ -1002,7 +1211,6 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor return buf; } - buf->size = size; vk::BufferCreateInfo buffer_create_info{ vk::BufferCreateFlags(), size, @@ -1030,17 +1238,29 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor if (memory_type_index == UINT32_MAX) { device->device.destroyBuffer(buf->buffer); - buf->size = 0; throw vk::OutOfDeviceMemoryError("No suitable memory type found"); } try { buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index }); } catch (const vk::SystemError& e) { - // Out of Host/Device memory, clean up buffer - device->device.destroyBuffer(buf->buffer); - buf->size = 0; - throw e; + if (buf->memory_property_flags != fallback_flags) { + // Try again with fallback flags + memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags); + buf->memory_property_flags = fallback_flags; + + try { + buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index }); + } + catch (const vk::SystemError& e) { + device->device.destroyBuffer(buf->buffer); + throw e; + } + } else { + // Out of Host/Device memory, clean up buffer + device->device.destroyBuffer(buf->buffer); + throw e; + } } buf->ptr = nullptr; @@ -1051,6 +1271,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0); buf->device = device; + buf->size = size; #ifdef GGML_VULKAN_MEMORY_DEBUG device->memory_logger->log_allocation(buf, size); @@ -1076,7 +1297,8 @@ static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) { // Fall back to host memory type buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); } else { - buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal); + // use rebar if available, otherwise fallback to device only visible memory + buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal); } } catch (const vk::SystemError& e) { std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl; @@ -1139,507 +1361,646 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events ); } +// number of rows/cols for flash attention shader +static constexpr uint32_t flash_attention_num_small_rows = 32; +static std::array fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { + GGML_UNUSED(clamp); + + // small rows, large cols + if (small_rows) { + return {flash_attention_num_small_rows, 128}; + } + // small cols to reduce register count + if (ggml_is_quantized(type) || D == 256) { + return {64, 32}; + } + return {64, 64}; +}; + +static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id) { + // Needs to be kept up to date on shader changes + const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1; + const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); + const uint32_t warps = warptile[0] / warptile[10]; + + const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size; + const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0; + const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; + + return (load_bufs + mmid_row_ids + coopmat_stage) <= device->properties.limits.maxComputeSharedMemorySize; +} + static void ggml_vk_load_shaders(vk_device& device) { VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); + // some shaders have a minimum subgroup size + const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); + const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u); + // mulmat - std::initializer_list warptile_l = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size }; - std::initializer_list warptile_m = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size }; - std::initializer_list warptile_s = { device->subgroup_size, 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size }; + std::vector l_warptile, m_warptile, s_warptile, + l_warptile_mmq, m_warptile_mmq, s_warptile_mmq, + l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k, + l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid; + std::array l_wg_denoms, m_wg_denoms, s_wg_denoms, + l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms, + l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k, + l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms; - std::initializer_list warptile_mmq_l = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size }; - std::initializer_list warptile_mmq_m = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size }; - std::initializer_list warptile_mmq_s = { device->subgroup_size, 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size }; + uint32_t l_align, m_align, s_align; + if (device->coopmat2) { + // spec constants and tile sizes for non-quant matmul/matmul_id + l_warptile = { 256, 128, 256, 64 }; + m_warptile = { 256, 128, 128, 64 }; + s_warptile = { 128, 64, 64, 64 }; + l_wg_denoms = {128, 256, 1 }; + m_wg_denoms = {128, 128, 1 }; + s_wg_denoms = { 64, 64, 1 }; - std::array l_wg_denoms = {128, 128, 1 }; - std::array m_wg_denoms = { 64, 64, 1 }; - std::array s_wg_denoms = { 32, 32, 1 }; + // spec constants and tile sizes for quant matmul (non-Qi_K) + l_warptile_mmq = { 256, 128, 256, 64 }; + m_warptile_mmq = { 256, 128, 128, 64 }; + s_warptile_mmq = { 256, 128, 128, 64 }; + l_mmq_wg_denoms = { 128, 256, 1 }; + m_mmq_wg_denoms = { 128, 128, 1 }; + s_mmq_wg_denoms = { 128, 128, 1 }; - uint32_t l_align = 128; - uint32_t m_align = 64; - uint32_t s_align = 32; + // spec constants and tile sizes for quant matmul (Qi_K) + l_warptile_mmq_k = { 256, 128, 512, 16 }; + m_warptile_mmq_k = { 256, 128, 256, 16 }; + s_warptile_mmq_k = { 256, 32, 128, 64 }; + l_mmq_wg_denoms_k = { 128, 512, 1 }; + m_mmq_wg_denoms_k = { 128, 256, 1 }; + s_mmq_wg_denoms_k = { 32, 128, 1 }; - device->pipeline_matmul_f32 = std::make_shared(); - device->pipeline_matmul_f32_f16 = std::make_shared(); - device->pipeline_matmul_f16_f32 = std::make_shared(); - device->pipeline_matmul_f16 = std::make_shared(); - device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0] = std::make_shared(); - device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1] = std::make_shared(); - device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0] = std::make_shared(); - device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1] = std::make_shared(); - device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0] = std::make_shared(); - device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K] = std::make_shared(); - device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K] = std::make_shared(); - device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K] = std::make_shared(); - device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K] = std::make_shared(); - device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K] = std::make_shared(); - device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL] = std::make_shared(); + // spec constants and tile sizes for quant matmul_id + l_warptile_mmqid = { 256, 128, 128, 16 }; + m_warptile_mmqid = { 256, 128, 64, 16 }; + s_warptile_mmqid = { 256, 64, 64, 16 }; + l_mmqid_wg_denoms = { 128, 128, 1 }; + m_mmqid_wg_denoms = { 128, 64, 1 }; + s_mmqid_wg_denoms = { 64, 64, 1 }; - device->pipeline_matmul_id_f32 = std::make_shared(); - device->pipeline_matmul_id_f16_f32 = std::make_shared(); - device->pipeline_matmul_id_f16 = std::make_shared(); - device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0] = std::make_shared(); - device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1] = std::make_shared(); - device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0] = std::make_shared(); - device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1] = std::make_shared(); - device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0] = std::make_shared(); - device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K] = std::make_shared(); - device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K] = std::make_shared(); - device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K] = std::make_shared(); - device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K] = std::make_shared(); - device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared(); - device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared(); - - if (device->fp16) { - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->l, "matmul_f32_f16_l", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->m, "matmul_f32_f16_m", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->s, "matmul_f32_f16_s", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_l, "matmul_f32_f16_aligned_l", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_m, "matmul_f32_f16_aligned_m", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_s, "matmul_f32_f16_aligned_s", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_1_f32_l", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_1_f32_m", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_1_f32_s", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_1_f32_aligned_l", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_1_f32_aligned_m", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_1_f32_aligned_s", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->l, "matmul_q2_k_f32_l", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->m, "matmul_q2_k_f32_m", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->s, "matmul_q2_k_f32_s", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_l, "matmul_q2_k_f32_aligned_l", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_m, "matmul_q2_k_f32_aligned_m", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_s, "matmul_q2_k_f32_aligned_s", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->l, "matmul_q3_k_f32_l", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->m, "matmul_q3_k_f32_m", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->s, "matmul_q3_k_f32_s", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_l, "matmul_q3_k_f32_aligned_l", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_m, "matmul_q3_k_f32_aligned_m", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_s, "matmul_q3_k_f32_aligned_s", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->l, "matmul_q4_k_f32_l", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->m, "matmul_q4_k_f32_m", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->s, "matmul_q4_k_f32_s", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_l, "matmul_q4_k_f32_aligned_l", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_m, "matmul_q4_k_f32_aligned_m", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_s, "matmul_q4_k_f32_aligned_s", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->l, "matmul_q5_k_f32_l", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->m, "matmul_q5_k_f32_m", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->s, "matmul_q5_k_f32_s", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_l, "matmul_q5_k_f32_aligned_l", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_m, "matmul_q5_k_f32_aligned_m", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_s, "matmul_q5_k_f32_aligned_s", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->l, "matmul_q6_k_f32_l", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->m, "matmul_q6_k_f32_m", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->s, "matmul_q6_k_f32_s", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_l, "matmul_q6_k_f32_aligned_l", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->l, "matmul_iq4_nl_f32_l", matmul_iq4_nl_f32_len, matmul_iq4_nl_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->m, "matmul_iq4_nl_f32_m", matmul_iq4_nl_f32_len, matmul_iq4_nl_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->s, "matmul_iq4_nl_f32_s", matmul_iq4_nl_f32_len, matmul_iq4_nl_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_l, "matmul_iq4_nl_f32_aligned_l", matmul_iq4_nl_f32_aligned_len, matmul_iq4_nl_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_m, "matmul_iq4_nl_f32_aligned_m", matmul_iq4_nl_f32_aligned_len, matmul_iq4_nl_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_s, "matmul_iq4_nl_f32_aligned_s", matmul_iq4_nl_f32_aligned_len, matmul_iq4_nl_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->m, "matmul_id_f32_m", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->s, "matmul_id_f32_s", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_l, "matmul_id_f32_aligned_l", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_m, "matmul_id_f32_aligned_m", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_s, "matmul_id_f32_aligned_s", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->l, "matmul_id_f16_l", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->m, "matmul_id_f16_m", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->s, "matmul_id_f16_s", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_l, "matmul_id_f16_aligned_l", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_m, "matmul_id_f16_aligned_m", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_s, "matmul_id_f16_aligned_s", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->l, "matmul_id_f16_f32_l", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->m, "matmul_id_f16_f32_m", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->s, "matmul_id_f16_f32_s", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_l, "matmul_id_f16_f32_aligned_l", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_m, "matmul_id_f16_f32_aligned_m", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_s, "matmul_id_f16_f32_aligned_s", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->l, "matmul_id_q4_0_f32_l", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->m, "matmul_id_q4_0_f32_m", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->s, "matmul_id_q4_0_f32_s", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_l, "matmul_id_q4_0_f32_aligned_l", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_m, "matmul_id_q4_0_f32_aligned_m", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_s, "matmul_id_q4_0_f32_aligned_s", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->l, "matmul_id_q4_1_f32_l", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->m, "matmul_id_q4_1_f32_m", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->s, "matmul_id_q4_1_f32_s", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_l, "matmul_id_q4_1_f32_aligned_l", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_m, "matmul_id_q4_1_f32_aligned_m", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_s, "matmul_id_q4_1_f32_aligned_s", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->l, "matmul_id_q5_0_f32_l", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->m, "matmul_id_q5_0_f32_m", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->s, "matmul_id_q5_0_f32_s", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_l, "matmul_id_q5_0_f32_aligned_l", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_m, "matmul_id_q5_0_f32_aligned_m", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_s, "matmul_id_q5_0_f32_aligned_s", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->l, "matmul_id_q5_1_f32_l", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->m, "matmul_id_q5_1_f32_m", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->s, "matmul_id_q5_1_f32_s", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_l, "matmul_id_q5_1_f32_aligned_l", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_m, "matmul_id_q5_1_f32_aligned_m", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_s, "matmul_id_q5_1_f32_aligned_s", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->l, "matmul_id_q8_0_f32_l", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->m, "matmul_id_q8_0_f32_m", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->s, "matmul_id_q8_0_f32_s", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_l, "matmul_id_q8_0_f32_aligned_l", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_m, "matmul_id_q8_0_f32_aligned_m", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_s, "matmul_id_q8_0_f32_aligned_s", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->l, "matmul_id_q2_k_f32_l", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->m, "matmul_id_q2_k_f32_m", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->s, "matmul_id_q2_k_f32_s", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_l, "matmul_id_q2_k_f32_aligned_l", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_m, "matmul_id_q2_k_f32_aligned_m", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_s, "matmul_id_q2_k_f32_aligned_s", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->l, "matmul_id_q3_k_f32_l", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->m, "matmul_id_q3_k_f32_m", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->s, "matmul_id_q3_k_f32_s", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_l, "matmul_id_q3_k_f32_aligned_l", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_m, "matmul_id_q3_k_f32_aligned_m", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_s, "matmul_id_q3_k_f32_aligned_s", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->l, "matmul_id_q4_k_f32_l", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->m, "matmul_id_q4_k_f32_m", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->s, "matmul_id_q4_k_f32_s", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_l, "matmul_id_q4_k_f32_aligned_l", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_m, "matmul_id_q4_k_f32_aligned_m", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_s, "matmul_id_q4_k_f32_aligned_s", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->l, "matmul_id_q5_k_f32_l", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->m, "matmul_id_q5_k_f32_m", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->s, "matmul_id_q5_k_f32_s", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_l, "matmul_id_q5_k_f32_aligned_l", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_m, "matmul_id_q5_k_f32_aligned_m", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_s, "matmul_id_q5_k_f32_aligned_s", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->l, "matmul_id_q6_k_f32_l", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->m, "matmul_id_q6_k_f32_m", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->s, "matmul_id_q6_k_f32_s", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_l, "matmul_id_q6_k_f32_aligned_l", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_m, "matmul_id_q6_k_f32_aligned_m", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->l, "matmul_id_iq4_nl_f32_l", matmul_id_iq4_nl_f32_len, matmul_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->m, "matmul_id_iq4_nl_f32_m", matmul_id_iq4_nl_f32_len, matmul_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->s, "matmul_id_iq4_nl_f32_s", matmul_id_iq4_nl_f32_len, matmul_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_l, "matmul_id_iq4_nl_f32_aligned_l", matmul_id_iq4_nl_f32_aligned_len, matmul_id_iq4_nl_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_m, "matmul_id_iq4_nl_f32_aligned_m", matmul_id_iq4_nl_f32_aligned_len, matmul_id_iq4_nl_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_s, "matmul_id_iq4_nl_f32_aligned_s", matmul_id_iq4_nl_f32_aligned_len, matmul_id_iq4_nl_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + l_align = 128; + m_align = 64; + s_align = 32; } else { - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align); + // Matrix cores require different warp group sizes + const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4; + const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4; + const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2; + const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4; + const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2; + const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2; + const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1; + const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1; + const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1; - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->l, "matmul_f32_f16_l", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->m, "matmul_f32_f16_m", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->s, "matmul_f32_f16_s", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_l, "matmul_f32_f16_aligned_l", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_m, "matmul_f32_f16_aligned_m", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_s, "matmul_f32_f16_aligned_s", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align); + l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size }; + m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size }; + s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size }; - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align); + l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size }; + m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size }; + s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size }; - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align); + l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; + m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; + s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 }; + l_align = 128; + m_align = 64; + s_align = 32; - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders + // and tile sizes, this should handle 16KB, 32KB, and 48KB+. + // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders. + // But the numbers happen to work out for 32KB shared memory size that when using the medium + // size there's enough room for everything, and we assert for this. + uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float); + if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) { + l_warptile = m_warptile; + l_wg_denoms = m_wg_denoms; + shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float); + GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize); + } + if (device->properties.limits.maxComputeSharedMemorySize >= 32768) { + // assert mul_mat_mat_id shaders will fit. + GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize); + } - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_1_f32_l", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_1_f32_m", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_1_f32_s", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_1_f32_aligned_l", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_1_f32_aligned_m", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_1_f32_aligned_s", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float); + if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) { + if (device->properties.limits.maxComputeSharedMemorySize == 32768) { + l_warptile_mmq = m_warptile_mmq; + l_mmq_wg_denoms = m_mmq_wg_denoms; + } else { + l_warptile_mmq = s_warptile_mmq; + l_mmq_wg_denoms = s_mmq_wg_denoms; + } + shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float); + GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize); + } + if (device->properties.limits.maxComputeSharedMemorySize >= 32768) { + // assert mul_mat_mat_id shaders will fit. + GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize); + } + // Disable medium and large matrix multiplication if not enough shared memory is available + // Check mmq warptiles as the largest configuration + // Throw an error if not enough for any matrix multiplication is available + if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false)) { + std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl; + throw std::runtime_error("Shared memory size too small for matrix multiplication."); + } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false)) { + device->mul_mat_m = false; + device->mul_mat_l = false; + } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false)) { + device->mul_mat_l = false; + } - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + // Disable mul_mat_id if not enough shared memory is available + if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true)) { + device->mul_mat_id_s = false; + device->mul_mat_id_m = false; + device->mul_mat_id_l = false; + } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true)) { + device->mul_mat_id_m = false; + device->mul_mat_id_l = false; + } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true)) { + device->mul_mat_id_l = false; + } + } - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + if (!device->pipeline_matmul_f32) { + device->pipeline_matmul_f32 = std::make_shared(); + } + if (!device->pipeline_matmul_f32_f16) { + device->pipeline_matmul_f32_f16 = std::make_shared(); + } + if (!device->pipeline_matmul_id_f32) { + device->pipeline_matmul_id_f32 = std::make_shared(); + } - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + std::vector> compiles; + auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, + uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants, + uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->l, "matmul_q2_k_f32_l", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->m, "matmul_q2_k_f32_m", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->s, "matmul_q2_k_f32_s", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_l, "matmul_q2_k_f32_aligned_l", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_m, "matmul_q2_k_f32_aligned_m", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_s, "matmul_q2_k_f32_aligned_s", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + if (!pipeline) { + pipeline = std::make_shared(); + pipeline->name = name; + pipeline->parameter_count = parameter_count; + pipeline->push_constant_size = push_constant_size; + pipeline->wg_denoms = wg_denoms; + pipeline->align = align; + } - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->l, "matmul_q3_k_f32_l", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->m, "matmul_q3_k_f32_m", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->s, "matmul_q3_k_f32_s", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_l, "matmul_q3_k_f32_aligned_l", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_m, "matmul_q3_k_f32_aligned_m", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_s, "matmul_q3_k_f32_aligned_s", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + if (!pipeline->needed || pipeline->compiled) { + return; + } + { + // wait until fewer than N compiles are in progress + uint32_t N = std::max(1u, std::thread::hardware_concurrency()); + std::unique_lock guard(compile_count_mutex); + while (compile_count >= N) { + compile_count_cond.wait(guard); + } + compile_count++; + } + compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint, + parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); + }; - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->l, "matmul_q4_k_f32_l", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->m, "matmul_q4_k_f32_m", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->s, "matmul_q4_k_f32_s", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_l, "matmul_q4_k_f32_aligned_l", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_m, "matmul_q4_k_f32_aligned_m", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_s, "matmul_q4_k_f32_aligned_s", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->l, "matmul_q5_k_f32_l", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->m, "matmul_q5_k_f32_m", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->s, "matmul_q5_k_f32_s", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_l, "matmul_q5_k_f32_aligned_l", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_m, "matmul_q5_k_f32_aligned_m", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_s, "matmul_q5_k_f32_aligned_s", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { + return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1}; + }; - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->l, "matmul_q6_k_f32_l", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->m, "matmul_q6_k_f32_m", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->s, "matmul_q6_k_f32_s", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_l, "matmul_q6_k_f32_aligned_l", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { + // For large number of rows, 128 invocations seems to work best. + // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we + // can't use 256 for D==80. + uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128; + auto rows_cols = fa_rows_cols(D, clamp, type, small_rows); + return {wg_size, rows_cols[0], rows_cols[1], (D), clamp}; + }; - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->l, "matmul_iq4_nl_f32_l", matmul_iq4_nl_f32_fp32_len, matmul_iq4_nl_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->m, "matmul_iq4_nl_f32_m", matmul_iq4_nl_f32_fp32_len, matmul_iq4_nl_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->s, "matmul_iq4_nl_f32_s", matmul_iq4_nl_f32_fp32_len, matmul_iq4_nl_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_l, "matmul_iq4_nl_f32_aligned_l", matmul_iq4_nl_f32_aligned_fp32_len, matmul_iq4_nl_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_m, "matmul_iq4_nl_f32_aligned_m", matmul_iq4_nl_f32_aligned_fp32_len, matmul_iq4_nl_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_s, "matmul_iq4_nl_f32_aligned_s", matmul_iq4_nl_f32_aligned_fp32_len, matmul_iq4_nl_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align); +#define CREATE_FA2(TYPE, NAMELC, D) \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \ - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->m, "matmul_id_f32_m", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->s, "matmul_id_f32_s", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_l, "matmul_id_f32_aligned_l", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_m, "matmul_id_f32_aligned_m", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_s, "matmul_id_f32_aligned_s", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align); +#define CREATE_FA(TYPE, NAMELC) \ + CREATE_FA2(TYPE, NAMELC, 64) \ + CREATE_FA2(TYPE, NAMELC, 80) \ + CREATE_FA2(TYPE, NAMELC, 96) \ + CREATE_FA2(TYPE, NAMELC, 112) \ + CREATE_FA2(TYPE, NAMELC, 128) \ + CREATE_FA2(TYPE, NAMELC, 256) - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->l, "matmul_id_f16_l", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->m, "matmul_id_f16_m", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->s, "matmul_id_f16_s", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_l, "matmul_id_f16_aligned_l", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_m, "matmul_id_f16_aligned_m", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_s, "matmul_id_f16_aligned_s", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align); + CREATE_FA(GGML_TYPE_F16, f16) + CREATE_FA(GGML_TYPE_Q4_0, q4_0) + CREATE_FA(GGML_TYPE_Q4_1, q4_1) + CREATE_FA(GGML_TYPE_Q5_0, q5_0) + CREATE_FA(GGML_TYPE_Q5_1, q5_1) + CREATE_FA(GGML_TYPE_Q8_0, q8_0) + // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently + //CREATE_FA(GGML_TYPE_Q2_K, q2_k) + //CREATE_FA(GGML_TYPE_Q3_K, q3_k) + //CREATE_FA(GGML_TYPE_Q4_K, q4_k) + //CREATE_FA(GGML_TYPE_Q5_K, q5_k) + //CREATE_FA(GGML_TYPE_Q6_K, q6_k) + //CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs) + //CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs) + //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s) + //CREATE_FA(GGML_TYPE_IQ3_XXS, iq3_xxs) + //CREATE_FA(GGML_TYPE_IQ3_S, iq3_s) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl) +#undef CREATE_FA - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->l, "matmul_id_f16_f32_l", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->m, "matmul_id_f16_f32_m", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->s, "matmul_id_f16_f32_s", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_l, "matmul_id_f16_f32_aligned_l", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_m, "matmul_id_f16_f32_aligned_m", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_s, "matmul_id_f16_f32_aligned_s", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align); + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->l, "matmul_id_q4_0_f32_l", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->m, "matmul_id_q4_0_f32_m", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->s, "matmul_id_q4_0_f32_s", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_l, "matmul_id_q4_0_f32_aligned_l", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_m, "matmul_id_q4_0_f32_aligned_m", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_s, "matmul_id_q4_0_f32_aligned_s", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + // Create 2 variants, {f16,f32} accumulator +#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->l, "matmul_id_q4_1_f32_l", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->m, "matmul_id_q4_1_f32_m", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->s, "matmul_id_q4_1_f32_s", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_l, "matmul_id_q4_1_f32_aligned_l", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_m, "matmul_id_q4_1_f32_aligned_m", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_s, "matmul_id_q4_1_f32_aligned_s", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->l, "matmul_id_q5_0_f32_l", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->m, "matmul_id_q5_0_f32_m", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->s, "matmul_id_q5_0_f32_s", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_l, "matmul_id_q5_0_f32_aligned_l", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_m, "matmul_id_q5_0_f32_aligned_m", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_s, "matmul_id_q5_0_f32_aligned_s", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) +#undef CREATE_MM +#undef CREATE_MM2 + } else +#endif // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (device->coopmat_support) { + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ + if (device->mul_mat ## ID ## _m) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ + if (device->mul_mat ## ID ## _s) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ + if (device->mul_mat ## ID ## _l) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ + if (device->mul_mat ## ID ## _m) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ + if (device->mul_mat ## ID ## _s) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->l, "matmul_id_q5_1_f32_l", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->m, "matmul_id_q5_1_f32_m", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->s, "matmul_id_q5_1_f32_s", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_l, "matmul_id_q5_1_f32_aligned_l", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_m, "matmul_id_q5_1_f32_aligned_m", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_s, "matmul_id_q5_1_f32_aligned_s", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + // Create 2 variants, {f16,f32} accumulator +#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->coopmat_acc_f16_support) { \ + CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + } \ + if (device->coopmat_acc_f32_support) { \ + CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + } \ - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->l, "matmul_id_q8_0_f32_l", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->m, "matmul_id_q8_0_f32_m", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->s, "matmul_id_q8_0_f32_s", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_l, "matmul_id_q8_0_f32_aligned_l", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_m, "matmul_id_q8_0_f32_aligned_m", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_s, "matmul_id_q8_0_f32_aligned_s", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->l, "matmul_id_q2_k_f32_l", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->m, "matmul_id_q2_k_f32_m", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->s, "matmul_id_q2_k_f32_s", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_l, "matmul_id_q2_k_f32_aligned_l", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_m, "matmul_id_q2_k_f32_aligned_m", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_s, "matmul_id_q2_k_f32_aligned_s", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + if (device->coopmat_acc_f16_support) { + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->l, "matmul_id_q3_k_f32_l", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->m, "matmul_id_q3_k_f32_m", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->s, "matmul_id_q3_k_f32_s", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_l, "matmul_id_q3_k_f32_aligned_l", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_m, "matmul_id_q3_k_f32_aligned_m", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_s, "matmul_id_q3_k_f32_aligned_s", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + } else { + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->l, "matmul_id_q4_k_f32_l", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->m, "matmul_id_q4_k_f32_m", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->s, "matmul_id_q4_k_f32_s", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_l, "matmul_id_q4_k_f32_aligned_l", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_m, "matmul_id_q4_k_f32_aligned_m", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_s, "matmul_id_q4_k_f32_aligned_s", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + } - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->l, "matmul_id_q5_k_f32_l", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->m, "matmul_id_q5_k_f32_m", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->s, "matmul_id_q5_k_f32_s", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_l, "matmul_id_q5_k_f32_aligned_l", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_m, "matmul_id_q5_k_f32_aligned_m", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_s, "matmul_id_q5_k_f32_aligned_s", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. + if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { + CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->l, "matmul_id_q6_k_f32_l", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->m, "matmul_id_q6_k_f32_m", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->s, "matmul_id_q6_k_f32_s", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_l, "matmul_id_q6_k_f32_aligned_l", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_m, "matmul_id_q6_k_f32_aligned_m", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + if (device->coopmat_acc_f16_support) { + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->l, "matmul_id_iq4_nl_f32_l", matmul_id_iq4_nl_f32_fp32_len, matmul_id_iq4_nl_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->m, "matmul_id_iq4_nl_f32_m", matmul_id_iq4_nl_f32_fp32_len, matmul_id_iq4_nl_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->s, "matmul_id_iq4_nl_f32_s", matmul_id_iq4_nl_f32_fp32_len, matmul_id_iq4_nl_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_l, "matmul_id_iq4_nl_f32_aligned_l", matmul_id_iq4_nl_f32_aligned_fp32_len, matmul_id_iq4_nl_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_m, "matmul_id_iq4_nl_f32_aligned_m", matmul_id_iq4_nl_f32_aligned_fp32_len, matmul_id_iq4_nl_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_s, "matmul_id_iq4_nl_f32_aligned_s", matmul_id_iq4_nl_f32_aligned_fp32_len, matmul_id_iq4_nl_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + } else { + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + } + } +#undef CREATE_MM2 +#undef CREATE_MM + } else +#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (device->fp16) { + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _m) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _s) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _l) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ + if (device->mul_mat ## ID ## _m) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ + if (device->mul_mat ## ID ## _s) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ + + // Create 2 variants, {f16,f32} accumulator +#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + + CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. + if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { + CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + } +#undef CREATE_MM2 +#undef CREATE_MM + } else { + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _m) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _s) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _l) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ + if (device->mul_mat ## ID ## _m) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ + if (device->mul_mat ## ID ## _s) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ + + CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines. + if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) { + CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + } +#undef CREATE_MM } // mul mat vec - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + // the number of rows computed per shader depends on GPU model and quant + uint32_t rm_stdq = 1; + uint32_t rm_kq = 2; + if (device->vendor_id == VK_VENDOR_ID_AMD) { + if (device->subgroup_min_size == 64 && device->subgroup_max_size == 64) { // GCN + rm_stdq = 2; + rm_kq = 4; + } + } else if (device->vendor_id == VK_VENDOR_ID_INTEL) + rm_stdq = 2; - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) { + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); + } + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true); // dequant shaders ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); @@ -1653,7 +2014,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XXS], "dequant_iq2_xxs", dequant_iq2_xxs_len, dequant_iq2_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XS], "dequant_iq2_xs", dequant_iq2_xs_len, dequant_iq2_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S], "dequant_iq2_s", dequant_iq2_s_len, dequant_iq2_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_XXS], "dequant_iq3_xxs", dequant_iq3_xxs_len, dequant_iq3_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); // get_rows ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -1663,7 +2029,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs", get_rows_iq2_xs_len, get_rows_iq2_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S], "get_rows_iq2_s", get_rows_iq2_s_len, get_rows_iq2_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs", get_rows_iq3_xxs_len, get_rows_iq3_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -1672,9 +2043,14 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs_f32", get_rows_iq2_xs_f32_len, get_rows_iq2_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S], "get_rows_iq2_s_f32", get_rows_iq2_s_f32_len, get_rows_iq2_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs_f32", get_rows_iq3_xxs_f32_len, get_rows_iq3_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1); @@ -1687,13 +2063,35 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_1], "cpy_q5_1_f32", cpy_q5_1_f32_len, cpy_q5_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); @@ -1720,27 +2118,49 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + + if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + } else { + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + } ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + } else { + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); + } ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + + for (auto &c : compiles) { + c.wait(); + } + device->need_compiles = false; } +static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props); + static vk_device ggml_vk_get_device(size_t idx) { VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")"); @@ -1768,12 +2188,40 @@ static vk_device ggml_vk_get_device(size_t idx) { device->physical_device = physical_devices[dev_num]; const std::vector ext_props = device->physical_device.enumerateDeviceExtensionProperties(); + bool fp16_storage = false; + bool fp16_compute = false; bool maintenance4_support = false; + bool sm_builtins = false; + bool amd_shader_core_properties2 = false; + bool pipeline_robustness = false; + bool coopmat2_support = false; + device->coopmat_support = false; // Check if maintenance4 is supported for (const auto& properties : ext_props) { if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { maintenance4_support = true; + } else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { + fp16_storage = true; + } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { + fp16_compute = true; + } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) { + sm_builtins = true; + } else if (strcmp("VK_AMD_shader_core_properties2", properties.extensionName) == 0) { + amd_shader_core_properties2 = true; + } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) { + pipeline_robustness = true; + } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { + device->subgroup_size_control = true; + } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT")) { + device->coopmat_support = true; + device->coopmat_m = 0; + device->coopmat_n = 0; + device->coopmat_k = 0; + } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT2")) { + coopmat2_support = true; } } @@ -1781,18 +2229,51 @@ static vk_device ggml_vk_get_device(size_t idx) { vk::PhysicalDeviceMaintenance3Properties props3; vk::PhysicalDeviceMaintenance4Properties props4; vk::PhysicalDeviceSubgroupProperties subgroup_props; + vk::PhysicalDeviceDriverProperties driver_props; + vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; + vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props; + vk::PhysicalDeviceVulkan12Properties vk12_props; + vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; + props2.pNext = &props3; props3.pNext = &subgroup_props; + subgroup_props.pNext = &driver_props; + driver_props.pNext = &vk12_props; + + VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props; + if (maintenance4_support) { - subgroup_props.pNext = &props4; + last_struct->pNext = (VkBaseOutStructure *)&props4; + last_struct = (VkBaseOutStructure *)&props4; } + if (sm_builtins) { + last_struct->pNext = (VkBaseOutStructure *)&sm_props; + last_struct = (VkBaseOutStructure *)&sm_props; + } + if (amd_shader_core_properties2) { + last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props; + last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props; + } + if (device->subgroup_size_control) { + last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props; + last_struct = (VkBaseOutStructure *)&subgroup_size_control_props; + } + +#if defined(VK_NV_cooperative_matrix2) + vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props; + if (coopmat2_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props; + last_struct = (VkBaseOutStructure *)&coopmat2_props; + } +#endif + device->physical_device.getProperties2(&props2); device->properties = props2.properties; const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE"); if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) { - device->max_memory_allocation_size = std::stoi(GGML_VK_FORCE_MAX_ALLOCATION_SIZE); + device->max_memory_allocation_size = std::stoul(GGML_VK_FORCE_MAX_ALLOCATION_SIZE); } else if (maintenance4_support) { device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize); } else { @@ -1802,23 +2283,23 @@ static vk_device ggml_vk_get_device(size_t idx) { device->vendor_id = device->properties.vendorID; device->subgroup_size = subgroup_props.subgroupSize; device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; - - bool fp16_storage = false; - bool fp16_compute = false; - - for (const auto& properties : ext_props) { - if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { - fp16_storage = true; - } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { - fp16_compute = true; - } + if (sm_builtins) { + device->shader_core_count = sm_props.shaderSMCount; + } else if (amd_shader_core_properties2) { + device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount; + } else { + device->shader_core_count = 0; } + device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16; - const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16"); - const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr; + const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; + if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) { + device->coopmat_support = false; + } + std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); // Try to find a non-graphics compute queue and transfer-focused queues @@ -1856,10 +2337,161 @@ static vk_device ggml_vk_get_device(size_t idx) { vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; vk11_features.pNext = &vk12_features; + last_struct = (VkBaseOutStructure *)&vk12_features; + + VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features; + pl_robustness_features.pNext = nullptr; + pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT; + pl_robustness_features.pipelineRobustness = VK_FALSE; + + if (pipeline_robustness) { + last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features; + last_struct = (VkBaseOutStructure *)&pl_robustness_features; + device_extensions.push_back("VK_EXT_pipeline_robustness"); + } + + VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features; + subgroup_size_control_features.pNext = nullptr; + subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT; + subgroup_size_control_features.computeFullSubgroups = false; + subgroup_size_control_features.subgroupSizeControl = false; + + if (device->subgroup_size_control) { + last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features; + last_struct = (VkBaseOutStructure *)&subgroup_size_control_features; + } + +#if defined(VK_KHR_cooperative_matrix) + VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; + coopmat_features.pNext = nullptr; + coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; + coopmat_features.cooperativeMatrix = VK_FALSE; + + if (device->coopmat_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; + last_struct = (VkBaseOutStructure *)&coopmat_features; + } +#endif + +#if defined(VK_NV_cooperative_matrix2) + VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {}; + coopmat2_features.pNext = nullptr; + coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV; + if (coopmat2_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features; + last_struct = (VkBaseOutStructure *)&coopmat2_features; + device_extensions.push_back("VK_NV_cooperative_matrix2"); + } +#endif + + VkPhysicalDeviceMaintenance4Features maint4_features {}; + maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES; + if (maintenance4_support) { + last_struct->pNext = (VkBaseOutStructure *)&maint4_features; + last_struct = (VkBaseOutStructure *)&maint4_features; + device_extensions.push_back("VK_KHR_maintenance4"); + } + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); device->fp16 = device->fp16 && vk12_features.shaderFloat16; + device->pipeline_robustness = pl_robustness_features.pipelineRobustness; + + if (device->subgroup_size_control) { + device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize; + device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize; + device_extensions.push_back("VK_EXT_subgroup_size_control"); + } + + device->subgroup_size_control = device->subgroup_size_control && + (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) && + subgroup_size_control_features.subgroupSizeControl; + + if (device->subgroup_size_control) { + device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups; + } + +#if defined(VK_KHR_cooperative_matrix) + device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; +#endif + + if (coopmat2_support) { +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (coopmat2_features.cooperativeMatrixWorkgroupScope && + coopmat2_features.cooperativeMatrixFlexibleDimensions && + coopmat2_features.cooperativeMatrixReductions && + coopmat2_features.cooperativeMatrixConversions && + coopmat2_features.cooperativeMatrixPerElementOperations && + coopmat2_features.cooperativeMatrixTensorAddressing && + coopmat2_features.cooperativeMatrixBlockLoads && + vk12_features.bufferDeviceAddress) { + + std::vector flexible_dimensions; + uint32_t count = 0; + + PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV + _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV = + (PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV) + vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV"); + + _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr); + + VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {}; + empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV; + flexible_dimensions.resize(count, empty_prop); + + _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data()); + + bool found_fp16_128 = false, + found_fp16_256 = false, + found_fp32_128 = false, + found_fp32_256 = false; + // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128 + // with 32x16x16 and 256 with 32x32x16. + for (auto &prop : flexible_dimensions) { + if (prop.saturatingAccumulation == VK_FALSE && + prop.scope == VK_SCOPE_WORKGROUP_KHR && + prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + + if (prop.workgroupInvocations == 128 && + prop.MGranularity <= 32 && + prop.NGranularity <= 16 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_128 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_128 = true; + } + } + if (prop.workgroupInvocations == 256 && + prop.MGranularity <= 32 && + prop.NGranularity <= 32 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_256 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_256 = true; + } + } + } + } + if (found_fp16_128 && found_fp16_256 && + found_fp32_128 && found_fp32_256 && + coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) { + device->coopmat2 = true; + } + } +#endif + } + if (!vk11_features.storageBuffer16BitAccess) { std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl; throw std::runtime_error("Unsupported device"); @@ -1874,7 +2506,76 @@ static vk_device ggml_vk_get_device(size_t idx) { if (device->fp16) { device_extensions.push_back("VK_KHR_shader_float16_int8"); } - device->name = device->properties.deviceName.data(); + +#if defined(VK_KHR_cooperative_matrix) + if (device->coopmat_support) { + // Query supported shapes + std::vector cm_props; + + PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR = + (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR"); + + uint32_t cm_props_num; + + pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr); + + cm_props.resize(cm_props_num); + + for (auto& prop : cm_props) { + prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR; + } + + pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data()); + + VK_LOG_DEBUG("ggml_vulkan: Cooperative Matrix Shapes: " << cm_props.size()); + + for (auto& prop : cm_props) { + VK_LOG_DEBUG("ggml_vulkan: M: " << prop.MSize << " N: " << prop.NSize << " K: " << prop.KSize << " A: " << vk::to_string((vk::ComponentTypeKHR)prop.AType) << " B: " << vk::to_string((vk::ComponentTypeKHR)prop.BType) << " C: " << vk::to_string((vk::ComponentTypeKHR)prop.CType) << " Result: " << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << " saturatingAccumulation: " << prop.saturatingAccumulation << " scope: " << vk::to_string((vk::ScopeKHR)prop.scope)); + + if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 && + (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 && + (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup + ) { + if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 && + (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) { + // coopmat sizes not set yet + if (device->coopmat_m == 0) { + device->coopmat_acc_f32_support = true; + device->coopmat_m = prop.MSize; + device->coopmat_n = prop.NSize; + device->coopmat_k = prop.KSize; + } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { + // Only enable if shape is identical + device->coopmat_acc_f32_support = true; + } + } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 && + (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) { + // coopmat sizes not set yet + if (device->coopmat_m == 0) { + device->coopmat_acc_f16_support = true; + device->coopmat_m = prop.MSize; + device->coopmat_n = prop.NSize; + device->coopmat_k = prop.KSize; + } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { + // Only enable if shape is identical + device->coopmat_acc_f16_support = true; + } + } + } + } + + if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) { + // No suitable matmul mode found + GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n"); + device->coopmat_support = false; + } + } + + if (device->coopmat_support) { + device_extensions.push_back("VK_KHR_cooperative_matrix"); + } +#endif + device->name = GGML_VK_NAME + std::to_string(idx); device_create_info = { vk::DeviceCreateFlags(), @@ -1889,6 +2590,37 @@ static vk_device ggml_vk_get_device(size_t idx) { ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false); // Shaders + // Disable matmul tile sizes early if performance low or not supported + switch (device->vendor_id) { +#ifndef GGML_VULKAN_RUN_TESTS + case VK_VENDOR_ID_AMD: + case VK_VENDOR_ID_INTEL: + device->mul_mat_l = false; + device->mul_mat_m = true; + device->mul_mat_s = true; + device->mul_mat_id_l = false; + device->mul_mat_id_m = true; + device->mul_mat_id_s = true; + break; + case VK_VENDOR_ID_APPLE: + device->mul_mat_l = false; + device->mul_mat_m = true; + device->mul_mat_s = false; + device->mul_mat_id_l = false; + device->mul_mat_id_m = true; + device->mul_mat_id_s = false; + break; +#endif + default: + device->mul_mat_l = true; + device->mul_mat_m = true; + device->mul_mat_s = true; + device->mul_mat_id_l = true; + device->mul_mat_id_m = true; + device->mul_mat_id_s = true; + break; + } + ggml_vk_load_shaders(device); if (!device->single_queue) { @@ -1901,6 +2633,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->buffer_type = { /* .iface = */ ggml_backend_vk_buffer_type_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx), /* .context = */ new ggml_backend_vk_buffer_type_context{ device->name, device }, }; @@ -1914,7 +2647,6 @@ static vk_device ggml_vk_get_device(size_t idx) { return vk_instance.devices[idx]; } - static void ggml_vk_print_gpu_info(size_t idx) { GGML_ASSERT(idx < vk_instance.device_indices.size()); size_t dev_num = vk_instance.device_indices[idx]; @@ -1945,15 +2677,31 @@ static void ggml_vk_print_gpu_info(size_t idx) { bool fp16_storage = false; bool fp16_compute = false; + bool coopmat_support = false; + bool coopmat2_support = false; for (auto properties : ext_props) { if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { fp16_storage = true; } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { fp16_compute = true; +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT")) { + coopmat_support = true; +#endif +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT2")) { + coopmat2_support = true; +#endif } } + if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) { + coopmat_support = false; + } + const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16"); bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr; @@ -1976,15 +2724,35 @@ static void ggml_vk_print_gpu_info(size_t idx) { vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; vk11_features.pNext = &vk12_features; + // Pointer to the last chain element + VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features; + +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; + coopmat_features.pNext = nullptr; + coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; + coopmat_features.cooperativeMatrix = VK_FALSE; + + if (coopmat_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; + last_struct = (VkBaseOutStructure *)&coopmat_features; + } + vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); fp16 = fp16 && vk12_features.shaderFloat16; + coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix; +#endif + + std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none"; + std::string device_name = props2.properties.deviceName.data(); - std::cerr << GGML_VK_NAME << idx << ": " << device_name << " (" << driver_props.driverName << ") | uma: " << uma << " | fp16: " << fp16 << " | warp size: " << subgroup_size << std::endl; + GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %s\n", + idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, matrix_cores.c_str()); if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { - std::cerr << "ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want." << std::endl; + GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n"); } } @@ -1999,7 +2767,14 @@ void ggml_vk_instance_init() { vk_instance_initialized = true; - vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION }; + uint32_t api_version = vk::enumerateInstanceVersion(); + + if (api_version < VK_API_VERSION_1_2) { + std::cerr << "ggml_vulkan: Error: Vulkan 1.2 required." << std::endl; + GGML_ABORT("fatal error"); + } + + vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version }; const std::vector instance_extensions = vk::enumerateInstanceExtensionProperties(); const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions); @@ -2039,8 +2814,7 @@ void ggml_vk_instance_init() { }; validation_features.setPNext(nullptr); instance_create_info.setPNext(&validation_features); - - std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl; + GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n"); } vk_instance.instance = vk::createInstance(instance_create_info); @@ -2154,8 +2928,7 @@ void ggml_vk_instance_init() { vk_instance.device_indices.push_back(0); } } - - std::cerr << "ggml_vulkan: Found " << vk_instance.device_indices.size() << " Vulkan devices:" << std::endl; + GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size()); for (size_t i = 0; i < vk_instance.device_indices.size(); i++) { ggml_vk_print_gpu_info(i); @@ -2202,6 +2975,11 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: break; default: @@ -2211,7 +2989,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type return ctx->device->pipeline_dequant[type]; } -static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) { +static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { return ctx->device->pipeline_matmul_f32; @@ -2219,14 +2997,23 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { return ctx->device->pipeline_matmul_f32_f16; } - if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { - return ctx->device->pipeline_matmul_f16_f32; - } - if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { - return ctx->device->pipeline_matmul_f16; + if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_f16_f32.f16acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_f16.f16acc; + } + } else { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_f16_f32.f32acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_f16.f32acc; + } } - if (src1_type != GGML_TYPE_F32) { + if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) { return nullptr; } @@ -2241,18 +3028,28 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: break; default: return nullptr; } - return ctx->device->pipeline_dequant_mul_mat_mat[src0_type]; + if (ctx->device->coopmat2) { + assert(src1_type == GGML_TYPE_F16); + return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc; + } + return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; } -static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) { +static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) { VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()"); GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16); + GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols); switch (a_type) { case GGML_TYPE_F32: @@ -2267,28 +3064,42 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: break; default: return nullptr; } - return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type]; + return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type][num_cols-1]; } -static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) { +static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()"); if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { return ctx->device->pipeline_matmul_id_f32; } - if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { - return ctx->device->pipeline_matmul_id_f16_f32; - } - if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { - return ctx->device->pipeline_matmul_id_f16; + if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_id_f16_f32.f16acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_id_f16.f16acc; + } + } else { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_id_f16_f32.f32acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_id_f16.f32acc; + } } - GGML_ASSERT(src1_type == GGML_TYPE_F32); + GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16)); switch (src0_type) { case GGML_TYPE_Q4_0: @@ -2301,13 +3112,18 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: break; default: return nullptr; } - return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type]; + return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc; } static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) { @@ -2327,6 +3143,11 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: break; default: @@ -2559,8 +3380,8 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont GGML_ABORT("fatal error"); } // Check if src is pinned memory - vk_buffer buf; - size_t buf_offset; + vk_buffer buf = nullptr; + size_t buf_offset = 0; ggml_vk_host_get(ctx->device, tensor->data, buf, buf_offset); const uint64_t ne0 = tensor->ne[0]; @@ -2623,7 +3444,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont VkBufferCopy buf_copy{ 0, offset, copy_size }; ggml_vk_sync_buffers(subctx); - vkCmdCopyBuffer(subctx->s->buffer, staging->buffer, dst->buffer, 1, &buf_copy); + vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); for (uint64_t i3 = 0; i3 < ne3; i3++) { for (uint64_t i2 = 0; i2 < ne2; i2++) { @@ -2656,7 +3477,7 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz } // Check if src is pinned memory vk_buffer buf = nullptr; - size_t buf_offset; + size_t buf_offset = 0; ggml_vk_host_get(dst->device, src, buf, buf_offset); if (buf != nullptr) { @@ -2698,7 +3519,7 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz copy_size}; ggml_vk_sync_buffers(subctx); - vkCmdCopyBuffer(subctx->s->buffer, staging_buffer->buffer, dst->buffer, 1, &buf_copy); + vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); if (width == spitch) { deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys); @@ -2754,7 +3575,7 @@ static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size // Check if dst is pinned memory vk_buffer buf = nullptr; - size_t buf_offset; + size_t buf_offset = 0; ggml_vk_host_get(src->device, dst, buf, buf_offset); std::vector slices(1); @@ -2803,7 +3624,11 @@ static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")"); - if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + + // If the device is not an UMA device the memory is host-accessible through rebar. While writing + // through PCIe is sufficient fast reading back data from PCIe is slower than going through + // the HW device to host copy path. + if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) { GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); memcpy(dst, (uint8_t *) src->ptr + offset, size); @@ -2830,7 +3655,7 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds VkBufferCopy bc{ src_offset, dst_offset, size }; - vkCmdCopyBuffer(ctx->s->buffer, src->buffer, dst->buffer, 1, &bc); + vkCmdCopyBuffer(ctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc); } static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { @@ -2872,55 +3697,44 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz dst->device->device.resetFences({ dst->device->fence }); } -static uint32_t ggml_vk_guess_split_k(int m, int n, int k) { +static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) { VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")"); - // if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) { - // return 4; - // } - return 1; - - GGML_UNUSED(m); GGML_UNUSED(n); GGML_UNUSED(k); -} - -static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) { - if (m <= 32 || n <= 32) { - return aligned ? mmp->a_s : mmp->s; + uint32_t split_k = 1; + if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) { + // If k is 'large' and the SMs will fill less than halfway, use split_k. + uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]); + uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]); + if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) { + split_k = ctx->device->shader_core_count / (m_tiles * n_tiles); + // Clamp to 2 or 4 + split_k = std::min(split_k, 4u); + if (split_k == 3) { + split_k = 2; + } + } } - return aligned ? mmp->a_m : mmp->m; - GGML_UNUSED(ctx); -} - -static vk_pipeline ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) { - return aligned ? mmp->a_m : mmp->m; - - GGML_UNUSED(ctx); -} - -static vk_pipeline ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) { - return aligned ? mmp->a_s : mmp->s; - - GGML_UNUSED(ctx); + return split_k; } static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) { VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")"); - switch (ctx->device->vendor_id) { - case VK_VENDOR_ID_AMD: - return ggml_vk_guess_matmul_pipeline_amd(ctx, mmp, m, n, aligned); - case VK_VENDOR_ID_APPLE: - return ggml_vk_guess_matmul_pipeline_apple(ctx, mmp, aligned); - case VK_VENDOR_ID_INTEL: - return ggml_vk_guess_matmul_pipeline_intel(ctx, mmp, aligned); - default: - break; - } - if (m <= 32 || n <= 32) { + if (ctx->device->coopmat2) { + if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) { + return aligned ? mmp->a_l : mmp->l; + } + if ((ctx->device->mul_mat_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s) { + return aligned ? mmp->a_m : mmp->m; + } return aligned ? mmp->a_s : mmp->s; } - if (m <= 64 || n <= 64) { + + if ((ctx->device->mul_mat_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_l)) { + return aligned ? mmp->a_s : mmp->s; + } + if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l; @@ -2955,6 +3769,33 @@ static void ggml_vk_matmul( ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 }); } +static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")"); + + if (ctx->device->coopmat2) { + if ((ctx->device->mul_mat_id_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_s)) { + return aligned ? mmp->a_l : mmp->l; + } + if ((ctx->device->mul_mat_id_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? mmp->a_s : mmp->s; + } + + if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) { + return aligned ? mmp->a_s : mmp->s; + } + if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? mmp->a_l : mmp->l; +} + +static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")"); + return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true)->align; +} + static void ggml_vk_matmul_id( ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, @@ -2978,18 +3819,61 @@ static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) { tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; } -static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, ggml_type from, ggml_type to) { - if (from == GGML_TYPE_F32 && to == GGML_TYPE_F32) { - return ctx->device->pipeline_cpy_f32_f32; +static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) { + + // Choose "contiguous copy" shader if src/dst are contiguous + bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst)); + + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_f32; + } else { + return ctx->device->pipeline_cpy_f32_f32; + } } - if (from == GGML_TYPE_F32 && to == GGML_TYPE_F16) { - return ctx->device->pipeline_cpy_f32_f16; + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F16) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_f16; + } else { + return ctx->device->pipeline_cpy_f32_f16; + } } - if (from == GGML_TYPE_F16 && to == GGML_TYPE_F16) { - return ctx->device->pipeline_cpy_f16_f16; + if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F16) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f16_f16; + } else { + return ctx->device->pipeline_cpy_f16_f16; + } + } + if (src->type == GGML_TYPE_F32) { + switch (to) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return ctx->device->pipeline_cpy_f32_quant[to]; + default: + break; + } } - std::cerr << "Missing CPY op for types: " << ggml_type_name(from) << " " << ggml_type_name(to) << std::endl; + if (to == GGML_TYPE_F32) { + switch (src->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return ctx->device->pipeline_cpy_quant_f32[src->type]; + default: + break; + } + } + + std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl; GGML_ABORT("fatal error"); } @@ -2999,16 +3883,27 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& const int tensor_type_size = ggml_type_size(tensor->type); const uint32_t ne = ggml_nelements(tensor); + std::array elements; - const vk_op_unary_push_constants pc = { + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + + vk_op_unary_push_constants pc = { (uint32_t)ne, (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size, (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], 1 , (uint32_t)tensor->ne[0] , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]), 0, 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }; + init_pushconst_fastdiv(pc); ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, { ne, 1, 1 }); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements); } static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -3035,13 +3930,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub const uint64_t r2 = ne12 / ne02; const uint64_t r3 = ne13 / ne03; - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra; - ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra; - ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; - vk_buffer d_Qx; + vk_buffer d_Qx = nullptr; size_t qx_buf_offset = 0; - vk_buffer d_Qy; + vk_buffer d_Qy = nullptr; size_t qy_buf_offset = 0; bool src0_uma = false; @@ -3054,19 +3949,22 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub src1_uma = d_Qy != nullptr; } - const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); - const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); + // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf + const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src1); const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; - vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type); + vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]); const bool qx_needs_dequant = mmp == nullptr || x_non_contig; const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; - if (mmp == nullptr) { + if (qx_needs_dequant) { // Fall back to dequant + f16 mulmat - mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16); + mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]); } // Not implemented @@ -3079,10 +3977,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11)); const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8; - const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10); - vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned); + const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline); + const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; @@ -3093,12 +3991,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub vk_pipeline to_fp16_vk_1 = nullptr; if (x_non_contig) { - to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16); + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16); } else { to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); } if (y_non_contig) { - to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, GGML_TYPE_F16); + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16); } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } @@ -3108,7 +4006,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub if (dryrun) { const uint64_t x_sz_upd = x_sz * ne02 * ne03; const uint64_t y_sz_upd = y_sz * ne12 * ne13; - const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * 4 : 0; + const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0; if ( (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) || (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) || @@ -3139,8 +4037,8 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub return; } - vk_buffer d_D = extra->buffer_gpu.lock(); - const uint64_t d_buf_offset = extra->offset + dst->view_offs; + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; GGML_ASSERT(d_D != nullptr); GGML_ASSERT(d_D->size >= d_buf_offset + d_sz * ne02 * ne03); vk_buffer d_X; @@ -3148,13 +4046,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub vk_buffer d_Y; uint64_t y_buf_offset = 0; if (!src0_uma) { - d_Qx = extra_src0->buffer_gpu.lock(); - qx_buf_offset = extra_src0->offset + src0->view_offs; + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; GGML_ASSERT(d_Qx != nullptr); } if (!src1_uma) { - d_Qy = extra_src1->buffer_gpu.lock(); - qy_buf_offset = extra_src1->offset + src1->view_offs; + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; GGML_ASSERT(d_Qy != nullptr); } if (qx_needs_dequant) { @@ -3225,8 +4123,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t ne12 = src1->ne[2]; const uint64_t ne13 = src1->ne[3]; - GGML_ASSERT(ne11 == 1); - const uint64_t ne20 = dst->ne[0]; const uint64_t ne21 = dst->ne[1]; const uint64_t ne22 = dst->ne[2]; @@ -3235,13 +4131,18 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t r2 = ne12 / ne02; const uint64_t r3 = ne13 / ne03; - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra; - ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra; - ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra; + // batch_n indicates that we need to compute a few vector results, and this assumes + // ne12 and ne13 are 1. It overloads the batch_strides to hold the row strides. + GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1); + bool batch_n = ne11 > 1; - vk_buffer d_Qx; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qx = nullptr; size_t qx_buf_offset = 0; - vk_buffer d_Qy; + vk_buffer d_Qy = nullptr; size_t qy_buf_offset = 0; bool src0_uma = false; @@ -3278,14 +4179,14 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& vk_pipeline to_fp16_vk_0 = nullptr; vk_pipeline to_fp16_vk_1 = nullptr; if (x_non_contig) { - to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type); + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type); } if (y_non_contig) { - to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, src1->type); + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type); } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } - vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type); + vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11); GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT GGML_ASSERT(dmmv != nullptr); @@ -3316,21 +4217,21 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& return; } - vk_buffer d_D = extra->buffer_gpu.lock(); - const uint64_t d_buf_offset = extra->offset + dst->view_offs; + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; GGML_ASSERT(d_D != nullptr); vk_buffer d_X; uint64_t x_buf_offset = 0; vk_buffer d_Y; uint64_t y_buf_offset = 0; if(!src0_uma) { - d_Qx = extra_src0->buffer_gpu.lock(); - qx_buf_offset = extra_src0->offset + src0->view_offs; + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; GGML_ASSERT(d_Qx != nullptr); } if(!src1_uma) { - d_Qy = extra_src1->buffer_gpu.lock(); - qy_buf_offset = extra_src1->offset + src1->view_offs; + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; GGML_ASSERT(d_Qy != nullptr); } if (qx_needs_dequant) { @@ -3357,8 +4258,10 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); } - uint32_t stride_batch_x = ne00*ne01; - uint32_t stride_batch_y = ne10*ne11; + // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride + uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01; + uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11); + uint32_t stride_batch_d = batch_n ? ne20 : (ne20*ne21); if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); @@ -3375,13 +4278,13 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& if (ne01 > max_groups_x) { groups_z = 64; - groups_x /= groups_z; + groups_x = CEIL_DIV(groups_x, groups_z); } // compute const vk_mat_vec_push_constants pc = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, - stride_batch_x, stride_batch_y, (uint32_t)(ne20*ne21), + stride_batch_x, stride_batch_y, stride_batch_d, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, }; ggml_vk_sync_buffers(subctx); @@ -3413,11 +4316,11 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c GGML_ASSERT(ne11 == 1); - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra; - ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra; - ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; - vk_buffer d_Qy; + vk_buffer d_Qy = nullptr; size_t qy_buf_offset = 0; bool src1_uma = false; @@ -3441,15 +4344,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c return; } - vk_buffer d_D = extra->buffer_gpu.lock(); - const uint64_t d_buf_offset = extra->offset + dst->view_offs; + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; GGML_ASSERT(d_D != nullptr); - vk_buffer d_Qx = extra_src0->buffer_gpu.lock(); - const uint64_t qx_buf_offset = extra_src0->offset + src0->view_offs; + vk_buffer d_Qx = src0_buf_ctx->dev_buffer; + const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; GGML_ASSERT(d_Qx != nullptr); if (!src1_uma) { - d_Qy = extra_src1->buffer_gpu.lock(); - qy_buf_offset = extra_src1->offset + src1->view_offs; + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; GGML_ASSERT(d_Qx != nullptr); } @@ -3491,9 +4394,9 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con GGML_ASSERT(ne11 == 1); - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra; - ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra; - ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; vk_buffer d_Qy = nullptr; size_t qy_buf_offset = 0; @@ -3520,15 +4423,15 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con return; } - vk_buffer d_D = extra->buffer_gpu.lock(); - const uint64_t d_buf_offset = extra->offset + dst->view_offs; + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; GGML_ASSERT(d_D != nullptr); - vk_buffer d_Qx = extra_src0->buffer_gpu.lock(); - const uint64_t qx_buf_offset = extra_src0->offset + src0->view_offs; + vk_buffer d_Qx = src0_buf_ctx->dev_buffer; + const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; GGML_ASSERT(d_Qx != nullptr); if (!src1_uma) { - d_Qy = extra_src1->buffer_gpu.lock(); - qy_buf_offset = extra_src1->offset + src1->view_offs; + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; GGML_ASSERT(d_Qx != nullptr); } @@ -3547,11 +4450,24 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")"); - if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1) { + if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 && + // detect 0213 permutation, and batch size of 1 + src0->nb[0] <= src0->nb[2] && + src0->nb[2] <= src0->nb[1] && + src0->nb[1] <= src0->nb[3] && + src1->nb[0] <= src1->nb[2] && + src1->nb[2] <= src1->nb[1] && + src1->nb[1] <= src1->nb[3] && + src0->ne[3] == 1 && + src1->ne[3] == 1) { ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst, dryrun); - } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1) { + } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 && + !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) { ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun); - } else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { + // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four) + // when ne12 and ne13 are one. + } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun); } else { ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun); @@ -3590,16 +4506,16 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t n_as = ne02; - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra; - ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra; - ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra; - ggml_tensor_extra_gpu * extra_ids = (ggml_tensor_extra_gpu *) ids->extra; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; - vk_buffer d_Qx; + vk_buffer d_Qx = nullptr; size_t qx_buf_offset = 0; - vk_buffer d_Qy; + vk_buffer d_Qy = nullptr; size_t qy_buf_offset = 0; - vk_buffer d_ids; + vk_buffer d_ids = nullptr; size_t ids_buf_offset = 0; bool src0_uma = false; @@ -3615,18 +4531,22 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& ids_uma = d_ids != nullptr; } - const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); - const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); + // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf + const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src1); const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; - vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type); + vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]); const bool qx_needs_dequant = mmp == nullptr || x_non_contig; const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; - if (mmp == nullptr) { - GGML_ABORT("fatal error"); + if (qx_needs_dequant) { + // Fall back to dequant + f16 mulmat + mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]); } // Not implemented @@ -3636,10 +4556,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t y_ne = ne11 * ne10; const uint64_t d_ne = ne21 * ne20; - const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, nei1)); + const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1)); const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, nei1, aligned); + vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned); const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); @@ -3652,12 +4572,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& vk_pipeline to_fp16_vk_1 = nullptr; if (x_non_contig) { - to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16); + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16); } else { to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); } if (y_non_contig) { - to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, GGML_TYPE_F16); + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16); } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } @@ -3690,26 +4610,26 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& return; } - vk_buffer d_D = extra->buffer_gpu.lock(); - const uint64_t d_buf_offset = extra->offset + dst->view_offs; + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; GGML_ASSERT(d_D != nullptr); vk_buffer d_X; uint64_t x_buf_offset = 0; vk_buffer d_Y; uint64_t y_buf_offset = 0; if (!src0_uma) { - d_Qx = extra_src0->buffer_gpu.lock(); - qx_buf_offset = extra_src0->offset + src0->view_offs; + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; GGML_ASSERT(d_Qx != nullptr); } if (!src1_uma) { - d_Qy = extra_src1->buffer_gpu.lock(); - qy_buf_offset = extra_src1->offset + src1->view_offs; + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; GGML_ASSERT(d_Qy != nullptr); } if (!ids_uma) { - d_ids = extra_ids->buffer_gpu.lock(); - ids_buf_offset = extra_ids->offset + ids->view_offs; + d_ids = ids_buf_ctx->dev_buffer; + ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; GGML_ASSERT(d_ids != nullptr); } if (qx_needs_dequant) { @@ -3795,16 +4715,16 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const uint64_t ne22 = dst->ne[2]; const uint64_t ne23 = dst->ne[3]; - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra; - ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra; - ggml_tensor_extra_gpu * extra_src1 = (ggml_tensor_extra_gpu *) src1->extra; - ggml_tensor_extra_gpu * extra_ids = (ggml_tensor_extra_gpu *) ids->extra; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; - vk_buffer d_Qx; + vk_buffer d_Qx = nullptr; size_t qx_buf_offset = 0; - vk_buffer d_Qy; + vk_buffer d_Qy = nullptr; size_t qy_buf_offset = 0; - vk_buffer d_ids; + vk_buffer d_ids = nullptr; size_t ids_buf_offset = 0; bool src0_uma = false; @@ -3845,10 +4765,10 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte vk_pipeline to_fp16_vk_0 = nullptr; vk_pipeline to_fp16_vk_1 = nullptr; if (x_non_contig) { - to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type); + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type); } if (y_non_contig) { - to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, src1->type); + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type); } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } @@ -3883,26 +4803,26 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte return; } - vk_buffer d_D = extra->buffer_gpu.lock(); - const uint64_t d_buf_offset = extra->offset + dst->view_offs; + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; GGML_ASSERT(d_D != nullptr); vk_buffer d_X; uint64_t x_buf_offset = 0; vk_buffer d_Y; uint64_t y_buf_offset = 0; if(!src0_uma) { - d_Qx = extra_src0->buffer_gpu.lock(); - qx_buf_offset = extra_src0->offset + src0->view_offs; + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; GGML_ASSERT(d_Qx != nullptr); } if(!src1_uma) { - d_Qy = extra_src1->buffer_gpu.lock(); - qy_buf_offset = extra_src1->offset + src1->view_offs; + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; GGML_ASSERT(d_Qy != nullptr); } if(!ids_uma) { - d_ids = extra_ids->buffer_gpu.lock(); - ids_buf_offset = extra_ids->offset + ids->view_offs; + d_ids = ids_buf_ctx->dev_buffer; + ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; GGML_ASSERT(d_ids != nullptr); } if (qx_needs_dequant) { @@ -3942,7 +4862,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (ne01 > max_groups_x) { groups_z = 64; - groups_x /= groups_z; + groups_x = CEIL_DIV(groups_x, groups_z); } // compute @@ -3967,6 +4887,185 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } +static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3]; + std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3]; + std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const uint32_t nem1 = mask ? mask->ne[1] : 0; + const uint32_t nbm1 = mask ? mask->nb[1] : 0; + + const uint32_t D = neq0; + const uint32_t N = neq1; + const uint32_t KV = nek1; + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne2 == N); + + // input tensor rows must be contiguous + GGML_ASSERT(nbq0 == ggml_type_size(q->type)); + GGML_ASSERT(nbk0 == ggml_type_size(k->type)); + GGML_ASSERT(nbv0 == ggml_type_size(v->type)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev0 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nev0 == D); + + GGML_ASSERT(nev1 == nek1); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + assert(dst->type == GGML_TYPE_F32); + assert(q->type == GGML_TYPE_F32); + assert(k->type == v->type); + + vk_pipeline *pipelines; + // XXX TODO other backends may be changing accumulator precision to default to f32 soon + bool f32acc = dst->op_params[3] == GGML_PREC_F32; + bool small_rows = N <= flash_attention_num_small_rows; + switch (D) { + case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; + case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; + case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break; + case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break; + case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break; + case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break; + default: + assert(!"unsupported D value"); + return; + } + assert(pipelines); + + const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); + const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); + const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); + + bool aligned = (KV % pipelines[1]->align) == 0 && + // the "aligned" shader variant will forcibly align strides, for performance + (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; + + vk_pipeline pipeline = pipelines[aligned]; + assert(pipeline); + + if (dryrun) { + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + return; + } + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head_kv = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + ggml_vk_sync_buffers(subctx); + + vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr; + size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0; + + bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset); + ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset); + ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset); + ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset); + Q_uma = d_Q != nullptr; + K_uma = d_K != nullptr; + V_uma = d_V != nullptr; + D_uma = d_D != nullptr; + if (mask) { + ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset); + M_uma = d_M != nullptr; + } + } + + + ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * q_buf_ctx = (ggml_backend_vk_buffer_context *)q->buffer->context; + ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; + ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; + + if (!Q_uma) { + d_Q = q_buf_ctx->dev_buffer; + q_buf_offset = vk_tensor_offset(q) + q->view_offs; + } + if (!K_uma) { + d_K = k_buf_ctx->dev_buffer; + k_buf_offset = vk_tensor_offset(k) + k->view_offs; + } + if (!V_uma) { + d_V = v_buf_ctx->dev_buffer; + v_buf_offset = vk_tensor_offset(v) + v->view_offs; + } + if (!D_uma) { + d_D = d_buf_ctx->dev_buffer; + d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + } + + if (!M_uma) { + d_M = d_Q; + m_buf_offset = q_buf_offset; + if (mask) { + ggml_backend_vk_buffer_context * m_buf_ctx = (ggml_backend_vk_buffer_context*)mask->buffer->context; + d_M = m_buf_ctx->dev_buffer; + m_buf_offset = vk_tensor_offset(mask) + mask->view_offs; + } + } + + const vk_flash_attn_push_constants pc = { N, KV, + (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, + (uint32_t)neq2, (uint32_t)neq3, + (uint32_t)nek2, (uint32_t)nek3, + (uint32_t)nev2, (uint32_t)nev3, + nem1, + q_stride, (uint32_t)nbq2, (uint32_t)nbq3, + k_stride, (uint32_t)nbk2, (uint32_t)nbk3, + v_stride, (uint32_t)nbv2, (uint32_t)nbv3, + nbm1, + scale, max_bias, logit_softcap, + mask != nullptr, n_head_log2, m0, m1 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, + }, + sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 }); +} + static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) { switch (op) { case GGML_OP_GET_ROWS: @@ -3985,20 +5084,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_ADD: if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_add_f32; + return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32; } if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { - return ctx->device->pipeline_add_f16_f32_f16; + return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16; } return nullptr; case GGML_OP_MUL: if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_mul_f32; + return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32; } return nullptr; case GGML_OP_DIV: if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_div_f32; + return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32; } return nullptr; case GGML_OP_CONCAT: @@ -4055,7 +5154,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const case GGML_OP_CPY: case GGML_OP_CONT: case GGML_OP_DUP: - return ggml_vk_get_cpy_pipeline(ctx, src0->type, dst->type); + return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type); case GGML_OP_NORM: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_norm_f32; @@ -4111,10 +5210,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_soft_max_f32; + return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32; } if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_soft_max_f32_f16; + return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16; } return nullptr; case GGML_OP_ROPE: @@ -4162,6 +5261,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_timestep_embedding_f32; } return nullptr; + case GGML_OP_POOL_2D: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_pool2d_f32; + } + return nullptr; + case GGML_OP_RWKV_WKV6: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rwkv_wkv6_f32; + } + return nullptr; case GGML_OP_LEAKY_RELU: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_leaky_relu_f32; @@ -4183,7 +5292,6 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_DIV: case GGML_OP_CONCAT: case GGML_OP_UPSCALE: - case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SIN: case GGML_OP_COS: @@ -4196,8 +5304,59 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { } } +static uint32_t get_misalign_bytes(ggml_backend_vk_context * ctx, const ggml_tensor * t) +{ + return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));; +} + +template void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + GGML_UNUSED(p); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(src2); + GGML_UNUSED(dst); + static_assert(!std::is_const::value, "unexpected type"); + GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0); + GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0); + GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0); + GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + GGML_ASSERT(dst->op != GGML_OP_GET_ROWS || (a_offset == 0 && b_offset == 0 && d_offset == 0)); + + p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset; + + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.a_offset = a_offset; + p.d_offset = d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + template -static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc, bool dryrun = false) { +static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; if (src1 != nullptr) { std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; @@ -4207,9 +5366,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")"); - GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT + GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT - GGML_ASSERT(dst->extra != nullptr); + GGML_ASSERT(dst->buffer != nullptr); const uint64_t ne00 = src0->ne[0]; const uint64_t ne01 = src0->ne[1]; const uint64_t ne02 = src0->ne[2]; @@ -4237,6 +5396,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint64_t ned3 = dst->ne[3]; const uint64_t ned = ned0 * ned1; + init_pushconst_fastdiv(pc); + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op); if (pipeline == nullptr) { @@ -4255,10 +5416,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op); - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra; - ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra; - ggml_tensor_extra_gpu * extra_src1 = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr; - ggml_tensor_extra_gpu * extra_src2 = use_src2 ? (ggml_tensor_extra_gpu *) src2->extra : nullptr; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr; + ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr; vk_buffer d_X = nullptr; size_t x_buf_offset = 0; @@ -4289,7 +5450,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co uint64_t z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 : 0; uint64_t d_sz = ggml_type_size(dst->type) * ned; - vk_buffer d_D = extra->buffer_gpu.lock(); + vk_buffer d_D = dst_buf_ctx->dev_buffer; // Workaround for tiny tensor inputs on ROPE if (op == GGML_OP_ROPE && use_src1 && y_sz > d_D->size) { @@ -4297,23 +5458,28 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } GGML_ASSERT(d_D != nullptr); - uint64_t d_buf_offset = ((extra->offset + dst->view_offs) / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment; - GGML_ASSERT(d_buf_offset == extra->offset || op == GGML_OP_CPY); // NOLINT + uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; if(!src0_uma) { - d_X = extra_src0->buffer_gpu.lock(); - x_buf_offset = extra_src0->offset + src0->view_offs; + d_X = src0_buf_ctx->dev_buffer; + x_buf_offset = vk_tensor_offset(src0) + src0->view_offs; GGML_ASSERT(d_X != nullptr); } if (use_src1 && !src1_uma) { - d_Y = extra_src1->buffer_gpu.lock(); - y_buf_offset = extra_src1->offset + src1->view_offs; + d_Y = src1_buf_ctx->dev_buffer; + y_buf_offset = vk_tensor_offset(src1) + src1->view_offs; GGML_ASSERT(d_Y != nullptr); } if (use_src2 && !src2_uma) { - d_Z = extra_src2->buffer_gpu.lock(); - z_buf_offset = extra_src2->offset + src2->view_offs; + d_Z = src2_buf_ctx->dev_buffer; + z_buf_offset = vk_tensor_offset(src2) + src2->view_offs; GGML_ASSERT(d_Z != nullptr); } + // Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets. + init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, dst); + x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); if (op_supports_incontiguous) { x_sz = ggml_nbytes(src0); @@ -4382,7 +5548,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t OH = is_2D ? dst->ne[2] : 1; const uint32_t OW = dst->ne[1]; - const uint32_t batch = src1->ne[3]; + const uint32_t batch = src1->ne[is_2D ? 3 : 2]; elements = { OW * KW * KH, OH, batch * IC }; } break; @@ -4392,6 +5558,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co uint32_t half_ceil = (dim + 1) / 2; elements = { half_ceil, (uint32_t)src0->ne[0], 1 }; } break; + case GGML_OP_POOL_2D: + { + const uint32_t N = dst->ne[3]; + const uint32_t OC = dst->ne[2]; + const uint32_t OH = dst->ne[1]; + const uint32_t OW = dst->ne[0]; + elements = { N * OC * OH * OW, 1, 1}; + } break; case GGML_OP_ADD: case GGML_OP_DIV: case GGML_OP_MUL: @@ -4490,11 +5664,9 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, } static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra; const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); - const uint32_t d_offset = ((extra->offset + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size; int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 @@ -4506,7 +5678,7 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size, - d_offset, + 0, 0.0f, 0.0f, offset, }, dryrun); } @@ -4556,6 +5728,134 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const }, dryrun); } +static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) { + const ggml_tensor * k = dst->src[0]; + const ggml_tensor * v = dst->src[1]; + const ggml_tensor * r = dst->src[2]; + const ggml_tensor * tf = dst->src[3]; + const ggml_tensor * td = dst->src[4]; + const ggml_tensor * state = dst->src[5]; + + GGML_ASSERT(!ggml_is_quantized(k->type)); + GGML_ASSERT(!ggml_is_quantized(v->type)); + GGML_ASSERT(!ggml_is_quantized(r->type)); + GGML_ASSERT(!ggml_is_quantized(tf->type)); + GGML_ASSERT(!ggml_is_quantized(td->type)); + GGML_ASSERT(!ggml_is_quantized(state->type)); + GGML_ASSERT(dst->buffer != nullptr); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6); + GGML_ASSERT(pipeline != nullptr); + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + return; + } + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; + ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; + ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context; + ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context; + ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context; + ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context; + + ggml_vk_sync_buffers(subctx); + + vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr; + size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0; + bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, k->data, d_K, k_offset); + ggml_vk_host_get(ctx->device, v->data, d_V, v_offset); + ggml_vk_host_get(ctx->device, r->data, d_R, r_offset); + ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset); + ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset); + ggml_vk_host_get(ctx->device, state->data, d_State, state_offset); + ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset); + + K_uma = d_K != nullptr; + V_uma = d_V != nullptr; + R_uma = d_R != nullptr; + TF_uma = d_TF != nullptr; + TD_uma = d_TD != nullptr; + STATE_uma = d_State != nullptr; + DST_uma = d_D != nullptr; + } + + if (!K_uma) { + d_K = k_buf_ctx->dev_buffer; + k_offset = vk_tensor_offset(k) + k->view_offs; + } + if (!V_uma) { + d_V = v_buf_ctx->dev_buffer; + v_offset = vk_tensor_offset(v) + v->view_offs; + } + if (!R_uma) { + d_R = r_buf_ctx->dev_buffer; + r_offset = vk_tensor_offset(r) + r->view_offs; + } + if (!TF_uma) { + d_TF = tf_buf_ctx->dev_buffer; + tf_offset = vk_tensor_offset(tf) + tf->view_offs; + } + if (!TD_uma) { + d_TD = td_buf_ctx->dev_buffer; + td_offset = vk_tensor_offset(td) + td->view_offs; + } + if (!STATE_uma) { + d_State = state_buf_ctx->dev_buffer; + state_offset = vk_tensor_offset(state) + state->view_offs; + } + if (!DST_uma) { + d_D = dst_buf_ctx->dev_buffer; + dst_offset = vk_tensor_offset(dst) + dst->view_offs; + } + + const uint64_t k_size = ggml_nbytes(k); + const uint64_t v_size = ggml_nbytes(v); + const uint64_t r_size = ggml_nbytes(r); + const uint64_t tf_size = ggml_nbytes(tf); + const uint64_t td_size = ggml_nbytes(td); + const uint64_t state_size = ggml_nbytes(state); + const uint64_t dst_size = ggml_nbytes(dst); + + std::array elements = { + (uint32_t)(pc.B * pc.H), + 1, + 1 + }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_K, k_offset, k_size }, + vk_subbuffer{ d_V, v_offset, v_size }, + vk_subbuffer{ d_R, r_offset, r_size }, + vk_subbuffer{ d_TF, tf_offset, tf_size }, + vk_subbuffer{ d_TD, td_offset, td_size }, + vk_subbuffer{ d_State, state_offset, state_size }, + vk_subbuffer{ d_D, dst_offset, dst_size } + }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements); +} + +static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const size_t seq_length = dst->src[0]->ne[2]; + const size_t n_embed = dst->ne[0]; + const size_t n_heads = dst->src[0]->ne[1]; + const size_t n_seqs = dst->src[5]->ne[1]; + + ggml_vk_op_f32_rwkv6( + ctx, subctx, dst, + { + (uint32_t)n_seqs, + (uint32_t)seq_length, + (uint32_t)n_embed, + (uint32_t)n_heads, + }, + dryrun + ); +} + static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { int * op_params = (int *)dst->op_params; @@ -4582,7 +5882,7 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c const float sf3 = (float)dst->ne[3] / src0->ne[3]; ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, { - (uint32_t)ggml_nelements(dst), 0, + (uint32_t)ggml_nelements(dst), 0, 0, (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3], sf0, sf1, sf2, sf3, @@ -4599,7 +5899,8 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, con (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, - op_params[0], 0.0f + op_params[0], 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }, dryrun); } @@ -4613,6 +5914,7 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }, dryrun); } @@ -4626,6 +5928,7 @@ static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }, dryrun); } @@ -4639,6 +5942,7 @@ static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }, dryrun); } @@ -4653,6 +5957,7 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, con (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, op_params[0], op_params[1], + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }, dryrun); } @@ -4666,6 +5971,7 @@ static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }, dryrun); } @@ -4679,21 +5985,21 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }, dryrun); } static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra; const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t dst_type_size = ggml_type_size(dst->type); - const uint32_t d_offset = ((extra->offset + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size; ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, { (uint32_t)ggml_nelements(src0), (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, - d_offset, + 0, 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }, dryrun); } @@ -4750,6 +6056,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, scale, max_bias, m0, m1, n_head_log2, + nrows_x, }, dryrun); } @@ -4821,7 +6128,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t OW = dst->ne[1]; const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 - const uint32_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32 + const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 const uint32_t pelements = OW * KW * KH; @@ -4844,6 +6151,34 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context }, dryrun); } +static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + uint32_t op = static_cast(dst->op_params[0]); + const int32_t k1 = dst->op_params[1]; + const int32_t k0 = dst->op_params[2]; + const int32_t s1 = dst->op_params[3]; + const int32_t s0 = dst->op_params[4]; + const int32_t p1 = dst->op_params[5]; + const int32_t p0 = dst->op_params[6]; + + const uint32_t IH = src0->ne[1]; + const uint32_t IW = src0->ne[0]; + + const uint32_t N = dst->ne[3]; + + const uint32_t OC = dst->ne[2]; + const uint32_t OH = dst->ne[1]; + const uint32_t OW = dst->ne[0]; + + const uint32_t parallel_elements = N * OC * OH * OW; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, { + IW, IH, OW, OH, OC, + parallel_elements, + op, + k0, k1, s0, s1, p0, p1, + }, dryrun); +} + static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { const float * op_params = (const float *)dst->op_params; ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun); @@ -4900,10 +6235,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t p = ctx->device->pipeline_matmul_f32_f16->a_s; shname = "F32_F16_ALIGNED_S"; } else if (std::is_same() && std::is_same()) { - p = ctx->device->pipeline_matmul_f16_f32->a_s; + p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_s; shname = "F16_F32_ALIGNED_S"; } else if (std::is_same() && std::is_same()) { - p = ctx->device->pipeline_matmul_f16->a_s; + p = ctx->device->pipeline_matmul_f16.f32acc->a_s; shname = "F16_ALIGNED_S"; } else { GGML_ABORT("fatal error"); @@ -4916,10 +6251,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t p = ctx->device->pipeline_matmul_f32_f16->a_m; shname = "F32_F16_ALIGNED_M"; } else if (std::is_same() && std::is_same()) { - p = ctx->device->pipeline_matmul_f16_f32->a_m; + p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_m; shname = "F16_F32_ALIGNED_M"; } else if (std::is_same() && std::is_same()) { - p = ctx->device->pipeline_matmul_f16->a_m; + p = ctx->device->pipeline_matmul_f16.f32acc->a_m; shname = "F16_ALIGNED_M"; } else { GGML_ABORT("fatal error"); @@ -4932,10 +6267,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t p = ctx->device->pipeline_matmul_f32_f16->a_l; shname = "F32_F16_ALIGNED_L"; } else if (std::is_same() && std::is_same()) { - p = ctx->device->pipeline_matmul_f16_f32->a_l; + p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_l; shname = "F16_F32_ALIGNED_L"; } else if (std::is_same() && std::is_same()) { - p = ctx->device->pipeline_matmul_f16->a_l; + p = ctx->device->pipeline_matmul_f16.f32acc->a_l; shname = "F16_ALIGNED_L"; } else { GGML_ABORT("fatal error"); @@ -4955,10 +6290,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t p = ctx->device->pipeline_matmul_f32_f16->s; shname = "F32_F16_S"; } else if (std::is_same() && std::is_same()) { - p = ctx->device->pipeline_matmul_f16_f32->s; + p = ctx->device->pipeline_matmul_f16_f32.f32acc->s; shname = "F16_F32_S"; } else if (std::is_same() && std::is_same()) { - p = ctx->device->pipeline_matmul_f16->s; + p = ctx->device->pipeline_matmul_f16.f32acc->s; shname = "F16_S"; } } else if (shader_size == 1) { @@ -4969,10 +6304,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t p = ctx->device->pipeline_matmul_f32_f16->m; shname = "F32_F16_M"; } else if (std::is_same() && std::is_same()) { - p = ctx->device->pipeline_matmul_f16_f32->m; + p = ctx->device->pipeline_matmul_f16_f32.f32acc->m; shname = "F16_F32_M"; } else if (std::is_same() && std::is_same()) { - p = ctx->device->pipeline_matmul_f16->m; + p = ctx->device->pipeline_matmul_f16.f32acc->m; shname = "F16_M"; } } else if (shader_size == 2) { @@ -4983,10 +6318,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t p = ctx->device->pipeline_matmul_f32_f16->l; shname = "F32_F16_L"; } else if (std::is_same() && std::is_same()) { - p = ctx->device->pipeline_matmul_f16_f32->l; + p = ctx->device->pipeline_matmul_f16_f32.f32acc->l; shname = "F16_F32_L"; } else if (std::is_same() && std::is_same()) { - p = ctx->device->pipeline_matmul_f16->l; + p = ctx->device->pipeline_matmul_f16.f32acc->l; shname = "F16_L"; } } @@ -5005,6 +6340,8 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t } } + ggml_pipeline_allocate_descriptor_sets(ctx->device); + vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); @@ -5016,19 +6353,27 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t for (size_t i = 0; i < x_ne; i++) { if (std::is_same()) { x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // x[i] = 1.0f; + // x[i] = i + 1; + // x[i] = (i % k == i / k) ? 1.0f : 0.0f; } else if (std::is_same()) { x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f); + // x[i] = ggml_fp32_to_fp16(1.0f); + // x[i] = ggml_fp32_to_fp16(i + 1); + // x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f); } else { GGML_ABORT("fatal error"); } } for (size_t i = 0; i < y_ne; i++) { if (std::is_same()) { - // y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; - y[i] = (i % k == i / k) ? 1.0f : 0.0f; + y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // y[i] = (i % k == i / k) ? 1.0f : 0.0f; + // y[i] = i + 1; } else if (std::is_same()) { - // y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f); - y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f); + y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f); + // y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f); + // y[i] = ggml_fp32_to_fp16(i + 1); } else { GGML_ABORT("fatal error"); } @@ -5038,16 +6383,16 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch); vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + ggml_vk_ctx_begin(ctx->device, subctx); for (size_t i = 0; i < num_it; i++) { - ggml_vk_ctx_begin(ctx->device, subctx); ggml_vk_matmul( ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, k*m, k*n, m*n, split_k, batch, batch, batch, 1, 1 ); - ggml_vk_ctx_end(subctx); } + ggml_vk_ctx_end(subctx); auto begin = std::chrono::high_resolution_clock::now(); ggml_vk_submit(subctx, ctx->fence); @@ -5112,7 +6457,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t double err = std::fabs(d[i] - d_chk[i]); avg_err += err; - if (err > 0.05f && first_err_n == -1) { + if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) { first_err_b = i / (m * n); first_err_n = (i % (m * n)) / m; first_err_m = (i % (m * n)) % m; @@ -5121,14 +6466,14 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t avg_err /= m * n; - std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms avg_err=" << avg_err << std::endl; + double tflops = 2.0*m*n*k*batch*num_it / (time / 1000.0) / (1000.0*1000.0*1000.0*1000.0); - if (avg_err > 0.1) { + std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; + + if (avg_err > 0.1 || std::isnan(avg_err)) { std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; std::cerr << "Actual result: " << std::endl << std::endl; ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); - std::cerr << std::endl; - ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 15, first_err_b); std::cerr << "Expected result: " << std::endl << std::endl; ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); @@ -5213,9 +6558,9 @@ static void ggml_vk_dequantize_data(const void * from, float * to, size_t ne, gg return; } - ggml_type_traits_t tt = ggml_internal_get_type_traits(quant); + const auto * tt = ggml_get_type_traits(quant); - ggml_to_float_t dequant_fn = tt.to_float; + ggml_to_float_t dequant_fn = tt->to_float; dequant_fn(from, to, ne); } @@ -5243,12 +6588,14 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ ggml_pipeline_request_descriptor_sets(ctx->device, p, 1); + ggml_pipeline_allocate_descriptor_sets(ctx->device); + ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); ggml_vk_ctx_begin(ctx->device, subctx); const std::vector pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne }; - ggml_vk_dispatch_pipeline(ctx, subctx, p, { { qx_buf, 0, qx_sz }, { x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1}); + ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1}); ggml_vk_ctx_end(subctx); auto begin = std::chrono::high_resolution_clock::now(); @@ -5309,13 +6656,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, vk_pipeline p; std::string shname; if (shader_size == 0) { - p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_s; + p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s; shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S"; } else if (shader_size == 1) { - p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_m; + p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m; shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M"; } else if (shader_size == 2) { - p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_l; + p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l; shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L"; } else { GGML_ASSERT(0); @@ -5325,13 +6672,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, if (k != kpad) { if (shader_size == 0) { - p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->s; + p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s; shname = std::string(ggml_type_name(quant)) + "_S"; } else if (shader_size == 1) { - p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->m; + p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m; shname = std::string(ggml_type_name(quant)) + "_M"; } else if (shader_size == 2) { - p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->l; + p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l; shname = std::string(ggml_type_name(quant)) + "_L"; } else { GGML_ASSERT(0); @@ -5375,20 +6722,22 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, } } + ggml_pipeline_allocate_descriptor_sets(ctx->device); + ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); ggml_vk_buffer_write(y_buf, 0, y, y_sz); vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + ggml_vk_ctx_begin(ctx->device, subctx); for (size_t i = 0; i < num_it; i++) { - ggml_vk_ctx_begin(ctx->device, subctx); ggml_vk_matmul( ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, k*m, k*n, m*n, split_k, batch, batch, batch, 1, 1 ); - ggml_vk_ctx_end(subctx); } + ggml_vk_ctx_end(subctx); auto begin = std::chrono::high_resolution_clock::now(); @@ -5442,7 +6791,9 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, avg_err /= m * n; - std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms avg_err=" << avg_err << std::endl; + double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0); + + std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; if (avg_err > 0.01 || std::isnan(avg_err)) { std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; @@ -5484,118 +6835,15 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, } #endif -static ggml_tensor_extra_gpu * ggml_vk_tensor_create_extra(ggml_tensor * tensor) { - VK_LOG_DEBUG("ggml_vk_create_extra(" << tensor << " (" << tensor->name << ", " << ggml_op_name(tensor->op) << "))"); - ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu; - extra->reset(); - tensor->extra = extra; - return extra; -} - static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { #if defined(GGML_VULKAN_RUN_TESTS) - ctx->staging = ggml_vk_create_buffer_check(ctx->device, 100ul * 1024ul * 1024ul, - vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached, - vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); - ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_F32); - ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_0); - ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_1); - ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_0); - ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_1); - ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q8_0); - ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q2_K); - ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q3_K); - ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_K); - ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_K); - ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q6_K); - ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_IQ4_NL); - - ggml_vk_test_matmul(ctx, 512, 512, 100, 32, 100, 1, 2); - - ggml_vk_test_matmul(ctx, 128, 512, 512, 2, 100, 1, 0); - ggml_vk_test_matmul(ctx, 128, 512, 512, 2, 100, 1, 1); - ggml_vk_test_matmul(ctx, 128, 512, 512, 2, 100, 1, 2); - // ggml_vk_test_matmul(ctx, 128, 512, 512, 2, 100, 4, 0); - // ggml_vk_test_matmul(ctx, 128, 512, 512, 2, 100, 4, 1); - // ggml_vk_test_matmul(ctx, 128, 512, 512, 2, 100, 4, 2); - - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_0); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_0); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_0); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_0); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_0); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_0); - - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_1); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_1); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_1); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_1); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_1); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_1); - - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_0); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_0); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_0); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_0); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_0); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_0); - - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_1); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_1); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_1); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_1); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_1); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_1); - - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q8_0); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q8_0); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q8_0); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q8_0); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q8_0); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q8_0); - - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q2_K); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q2_K); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q2_K); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q2_K); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q2_K); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q2_K); - - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q3_K); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q3_K); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q3_K); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q3_K); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q3_K); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q3_K); - - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_K); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_K); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_K); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_K); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_K); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_K); - - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_K); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_K); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_K); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_K); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_K); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_K); - - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q6_K); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q6_K); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q6_K); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q6_K); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q6_K); - // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q6_K); - - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_IQ4_NL); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_IQ4_NL); - ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_IQ4_NL); - - std::cerr << std::endl; - const std::vector vals { + 512, 512, 128, + 128, 512, 512, + 4096, 512, 4096, + 11008, 512, 4096, + 4096, 512, 11008, + 32000, 512, 4096, 8, 8, 8, 100, 46, 576, 623, 111, 128, @@ -5608,25 +6856,52 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { 49, 49, 128, 128, 49, 49, 4096, 49, 4096, - 11008, 49, 4096, - 4096, 49, 11008, - 32000, 49, 4096, - 512, 512, 128, - 128, 512, 512, - 4096, 512, 4096, - 11008, 512, 4096, - 4096, 512, 11008, - 32000, 512, 4096, }; - const size_t num_it = 1; + const size_t num_it = 100; + for (size_t i = 0; i < vals.size(); i += 3) { ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0); ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1); ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2); - // ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0); - // ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1); - // ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2); - std::cerr << std::endl; + std::cerr << '\n'; + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2); + std::cerr << '\n'; + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2); + std::cerr << '\n' << std::endl; + + if (vals[i + 2] % 32 == 0) { + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0); + std::cerr << '\n' << std::endl; + } + + if (vals[i + 2] % 256 == 0) { + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K); + std::cerr << '\n' << std::endl; + } } GGML_ABORT("fatal error"); @@ -5658,11 +6933,13 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { } } -static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, bool last_node, bool dryrun){ - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) node->extra; +static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence); - if (ggml_is_empty(node) || extra == nullptr) { - return; +// Returns true if node has enqueued work into the queue, false otherwise +// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution. +static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool submit){ + if (ggml_is_empty(node) || !node->buffer) { + return false; } VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")"); @@ -5671,6 +6948,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod const ggml_tensor * src0 = node->src[0]; const ggml_tensor * src1 = node->src[1]; const ggml_tensor * src2 = node->src[2]; + const ggml_tensor * src3 = node->src[3]; switch (node->op) { // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor @@ -5679,7 +6957,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: case GGML_OP_NONE: - return; + return false; case GGML_OP_UNARY: switch (ggml_get_unary_op(node)) { case GGML_UNARY_OP_SILU: @@ -5689,7 +6967,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_UNARY_OP_TANH: break; default: - return; + return false; } break; case GGML_OP_REPEAT: @@ -5721,12 +6999,15 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_SUM_ROWS: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_POOL_2D: + case GGML_OP_RWKV_WKV6: case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: break; default: std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl; GGML_ABORT("fatal error"); - return; + return false; } vk_context compute_ctx; @@ -5739,6 +7020,48 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod } else { compute_ctx = ctx->compute_ctx.lock(); } + } else { + switch (node->op) { + case GGML_OP_REPEAT: + case GGML_OP_ACC: + case GGML_OP_GET_ROWS: + case GGML_OP_ADD: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_UNARY: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_ROPE: + case GGML_OP_ARGSORT: + case GGML_OP_SUM_ROWS: + case GGML_OP_IM2COL: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_POOL_2D: + case GGML_OP_LEAKY_RELU: + { + // These operations all go through ggml_vk_op_f32, so short-circuit and + // do the only thing needed for the dryrun. + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op); + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + return false; + } + default: + break; + } } switch (node->op) { @@ -5826,7 +7149,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun); break; default: - return; + return false; } break; case GGML_OP_DIAG_MASK_INF: @@ -5856,6 +7179,10 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_TIMESTEP_EMBEDDING: ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_POOL_2D: + ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun); + break; case GGML_OP_LEAKY_RELU: ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun); @@ -5868,13 +7195,23 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_MUL_MAT_ID: ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun); + break; + + case GGML_OP_FLASH_ATTN_EXT: + ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun); + + break; + + case GGML_OP_RWKV_WKV6: + ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun); + break; default: - return; + return false; } if (dryrun) { - return; + return false; } ctx->tensor_ctxs[node_idx] = compute_ctx; @@ -5885,15 +7222,35 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod last_node = true; #endif - if (last_node) { + if (submit || last_node) { ggml_vk_ctx_end(compute_ctx); - compute_ctx->exit_tensor_idx = node_idx; + + // TODO probably it'd be better to pass a exit_node flag to ggml_vk_compute_forward + if (last_node) { + compute_ctx->exit_tensor_idx = node_idx_begin; + } + else { + compute_ctx->exit_tensor_idx = -1; + } + ctx->compute_ctx.reset(); + + bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false); + if (!ok) { + if (node->op == GGML_OP_UNARY) { + std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; + } + else { + std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl; + } + } + } + return true; } -static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx){ - ggml_tensor_extra_gpu * extra = nullptr; +static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true){ + ggml_backend_buffer * buf = nullptr; switch (tensor->op) { case GGML_OP_ADD: @@ -5927,9 +7284,11 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_SUM_ROWS: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_POOL_2D: + case GGML_OP_RWKV_WKV6: case GGML_OP_LEAKY_RELU: case GGML_OP_REPEAT: - extra = (ggml_tensor_extra_gpu *) tensor->extra; + buf = tensor->buffer; break; case GGML_OP_UNARY: @@ -5939,7 +7298,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: - extra = (ggml_tensor_extra_gpu *) tensor->extra; + buf = tensor->buffer; break; default: return false; @@ -5947,53 +7306,52 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * break; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: - extra = (ggml_tensor_extra_gpu *) tensor->extra; + case GGML_OP_FLASH_ATTN_EXT: + buf = tensor->buffer; break; default: return false; } - if (extra == nullptr) { + if (buf == nullptr) { return false; } VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")"); -#ifdef GGML_VULKAN_CHECK_RESULTS - ggml_vk_check_results_0(tensor); -#endif - vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock(); -#ifdef GGML_VULKAN_PERF - std::chrono::steady_clock::time_point start; -#endif // GGML_VULKAN_PERF + // always wait for the GPU work to be done for the last submit + if (tensor_idx == subctx->exit_tensor_idx) { + use_fence = true; + } // Only run if ctx hasn't been submitted yet if (!subctx->seqs.empty()) { +#ifdef GGML_VULKAN_CHECK_RESULTS + ggml_vk_check_results_0(tensor); + use_fence = true; +#endif + // Do staging buffer copies for (auto& cpy : subctx->in_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); } -#ifdef GGML_VULKAN_PERF - start = std::chrono::steady_clock::now(); -#endif // GGML_VULKAN_PERF + ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{}); - ggml_vk_submit(subctx, ctx->fence); + if (use_fence) { + VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences"); + + ctx->device->device.resetFences({ ctx->fence }); + } +#ifdef GGML_VULKAN_CHECK_RESULTS + ggml_vk_check_results_1(tensor); +#endif } if (tensor_idx == subctx->exit_tensor_idx) { - VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences"); - -#ifdef GGML_VULKAN_PERF - auto duration = std::chrono::duration_cast(std::chrono::steady_clock::now() - start); - ctx->device->perf_logger->log_timing(tensor, duration.count()); -#endif // GGML_VULKAN_PERF - - ctx->device->device.resetFences({ ctx->fence }); - // Do staging buffer copies for (auto& cpy : subctx->out_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); @@ -6074,13 +7432,13 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ctx->device->device.destroyFence(ctx->fence); } -GGML_CALL static int ggml_vk_get_device_count() { +static int ggml_vk_get_device_count() { ggml_vk_instance_init(); return vk_instance.device_indices.size(); } -GGML_CALL static void ggml_vk_get_device_description(int device, char * description, size_t description_size) { +static void ggml_vk_get_device_description(int device, char * description, size_t description_size) { ggml_vk_instance_init(); std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); @@ -6097,111 +7455,56 @@ GGML_CALL static void ggml_vk_get_device_description(int device, char * descript // device backend -static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT - -struct ggml_backend_vk_buffer_context { - vk_device_ref device; - vk_buffer dev_buffer; - ggml_tensor_extra_gpu * temp_tensor_extras = nullptr; - size_t temp_tensor_extra_index = 0; - std::string name; - - ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) : - device(device), - dev_buffer(dev_buffer), - name(name) { - } - - ~ggml_backend_vk_buffer_context() { - ggml_vk_destroy_buffer(dev_buffer); - if (temp_tensor_extras != nullptr) { - delete[] temp_tensor_extras; - } - } - - ggml_tensor_extra_gpu * ggml_vk_alloc_temp_tensor_extra() { - if (temp_tensor_extras == nullptr) { - temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_VK_MAX_NODES]; - } - - size_t alloc_index = temp_tensor_extra_index; - temp_tensor_extra_index = (temp_tensor_extra_index + 1) % GGML_VK_MAX_NODES; - ggml_tensor_extra_gpu * extra = &temp_tensor_extras[alloc_index]; - extra->reset(); - - return extra; - } -}; - -GGML_CALL static const char * ggml_backend_vk_buffer_get_name(ggml_backend_buffer_t buffer) { - ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; - return ctx->name.c_str(); +static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) { + return buffer->buft->iface.get_name == ggml_backend_vk_buffer_type_name; } -GGML_CALL static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) { - return buffer->iface.get_name == ggml_backend_vk_buffer_get_name; -} - -GGML_CALL static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { +static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { VK_LOG_MEMORY("ggml_backend_vk_buffer_free_buffer()"); ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; ggml_vk_destroy_buffer(ctx->dev_buffer); delete ctx; } -GGML_CALL static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) { +static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) { return vk_ptr_base; UNUSED(buffer); } -GGML_CALL static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { +static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")"); - ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; - if (tensor->view_src != nullptr) { GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); - GGML_ASSERT(tensor->view_src->extra != nullptr); - tensor->extra = tensor->view_src->extra; - } else { - ggml_tensor_extra_gpu * extra = ctx->ggml_vk_alloc_temp_tensor_extra(); - extra->buffer_gpu = ctx->dev_buffer; - extra->offset = (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base; - tensor->extra = extra; } } -GGML_CALL static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { +static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + vk_buffer buf = buf_ctx->dev_buffer; - vk_buffer buf = extra->buffer_gpu.lock(); - - ggml_vk_buffer_write(buf, extra->offset + tensor->view_offs + offset, data, size); - - GGML_UNUSED(buffer); + ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); } -GGML_CALL static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { +static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; - vk_buffer buf = extra->buffer_gpu.lock(); + vk_buffer buf = buf_ctx->dev_buffer; - ggml_vk_buffer_read(buf, extra->offset + tensor->view_offs + offset, data, size); - - GGML_UNUSED(buffer); + ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); } -GGML_CALL static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { +static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { if (ggml_backend_buffer_is_vk(src->buffer)) { - ggml_tensor_extra_gpu * src_extra = (ggml_tensor_extra_gpu *) src->extra; - ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; - vk_buffer src_buf = src_extra->buffer_gpu.lock(); - vk_buffer dst_buf = dst_extra->buffer_gpu.lock(); + vk_buffer src_buf = src_buf_ctx->dev_buffer; + vk_buffer dst_buf = dst_buf_ctx->dev_buffer; - ggml_vk_buffer_copy(dst_buf, dst_extra->offset + dst->view_offs, src_buf, src_extra->offset + src->view_offs, ggml_nbytes(src)); + ggml_vk_buffer_copy(dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); return true; } @@ -6210,17 +7513,17 @@ GGML_CALL static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t bu UNUSED(buffer); } -GGML_CALL static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { +static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; ggml_vk_buffer_memset(ctx->dev_buffer, 0, value, buffer->size); } static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { - /* .get_name = */ ggml_backend_vk_buffer_get_name, /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer, /* .get_base = */ ggml_backend_vk_buffer_get_base, /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor, + /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor, /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor, /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor, @@ -6229,13 +7532,13 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { }; // vk buffer type -GGML_CALL static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft) { +static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft) { ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *)buft->context; return ctx->name.c_str(); } -GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { +static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { VK_LOG_MEMORY("ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")"); ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; @@ -6251,23 +7554,23 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer( return ggml_backend_buffer_init(buft, ggml_backend_vk_buffer_interface, bufctx, size); } -GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { +static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; return ctx->device->properties.limits.minStorageBufferOffsetAlignment; } -GGML_CALL static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { +static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; return ctx->device->max_memory_allocation_size; } -GGML_CALL static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { +static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { return ggml_nbytes(tensor); UNUSED(buft); } -GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) { +ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) { ggml_vk_instance_init(); VK_LOG_DEBUG("ggml_backend_vk_buffer_type(" << dev_num << ")"); @@ -6279,24 +7582,24 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) // host buffer type -GGML_CALL static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_type_t buft) { +static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_type_t buft) { return GGML_VK_NAME "_Host"; UNUSED(buft); } -GGML_CALL static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) { +static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) { return GGML_VK_NAME "_Host"; UNUSED(buffer); } -GGML_CALL static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { +static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()"); ggml_vk_host_free(vk_instance.devices[0], buffer->context); } -GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { +static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { VK_LOG_MEMORY("ggml_backend_vk_host_buffer_type_alloc_buffer(" << size << ")"); size += 32; // Behave like the CPU buffer type @@ -6312,7 +7615,6 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_bu ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); buffer->buft = buft; - buffer->iface.get_name = ggml_backend_vk_host_buffer_name; buffer->iface.free_buffer = ggml_backend_vk_host_buffer_free_buffer; return buffer; @@ -6320,7 +7622,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_bu UNUSED(buft); } -GGML_CALL static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { +static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { return vk_instance.devices[0]->properties.limits.minMemoryMapAlignment; UNUSED(buft); @@ -6328,7 +7630,7 @@ GGML_CALL static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_back // Should be changed to return device-specific host buffer type // but that probably requires changes in llama.cpp -GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() { +ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() { static struct ggml_backend_buffer_type ggml_backend_vk_buffer_type_host = { /* .iface = */ { /* .get_name = */ ggml_backend_vk_host_buffer_type_name, @@ -6338,6 +7640,7 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() { /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0), /* .context = */ nullptr, }; @@ -6351,13 +7654,13 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() { // backend -GGML_CALL static const char * ggml_backend_vk_name(ggml_backend_t backend) { +static const char * ggml_backend_vk_name(ggml_backend_t backend) { ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; return ctx->name.c_str(); } -GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend) { +static void ggml_backend_vk_free(ggml_backend_t backend) { ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; VK_LOG_DEBUG("ggml_backend_vk_free(" << ctx->name << ")"); @@ -6367,18 +7670,18 @@ GGML_CALL static void ggml_backend_vk_free(ggml_backend_t backend) { delete backend; } -GGML_CALL static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) { +static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) { ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; return &ctx->device->buffer_type; } -GGML_CALL static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { +static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; vk_context transfer_ctx; @@ -6391,17 +7694,17 @@ GGML_CALL static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, g transfer_ctx = ctx->transfer_ctx.lock(); } - vk_buffer buf = extra->buffer_gpu.lock(); + vk_buffer buf = buf_ctx->dev_buffer; - ggml_vk_buffer_write_async(transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size); + ggml_vk_buffer_write_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); } -GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { +static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; vk_context transfer_ctx; @@ -6414,17 +7717,17 @@ GGML_CALL static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, c transfer_ctx = ctx->transfer_ctx.lock(); } - vk_buffer buf = extra->buffer_gpu.lock(); + vk_buffer buf = buf_ctx->dev_buffer; - ggml_vk_buffer_read_async(transfer_ctx, buf, extra->offset + tensor->view_offs + offset, data, size); + ggml_vk_buffer_read_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); } -GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { +static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) { - ggml_tensor_extra_gpu * src_extra = (ggml_tensor_extra_gpu *) src->extra; - ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; vk_context transfer_ctx; @@ -6437,17 +7740,17 @@ GGML_CALL static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, c transfer_ctx = ctx->transfer_ctx.lock(); } - vk_buffer src_buf = src_extra->buffer_gpu.lock(); - vk_buffer dst_buf = dst_extra->buffer_gpu.lock(); + vk_buffer src_buf = src_buf_ctx->dev_buffer; + vk_buffer dst_buf = dst_buf_ctx->dev_buffer; - ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, dst_extra->offset + dst->view_offs, src_buf, src_extra->offset + src->view_offs, ggml_nbytes(src)); + ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); return true; } return false; } -GGML_CALL static void ggml_backend_vk_synchronize(ggml_backend_t backend) { +static void ggml_backend_vk_synchronize(ggml_backend_t backend) { VK_LOG_DEBUG("ggml_backend_vk_synchronize()"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; if(ctx->transfer_ctx.expired()) { @@ -6477,12 +7780,15 @@ static bool ggml_vk_is_empty(ggml_tensor * node) { return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; } -GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { +static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_vk_build_graph(ctx, cgraph->nodes[i], i, 0, true); + ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false); + } + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); } ggml_vk_preallocate_buffers(ctx); ggml_pipeline_allocate_descriptor_sets(ctx->device); @@ -6497,31 +7803,46 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen // Reserve tensor context space for all nodes ctx->tensor_ctxs.resize(cgraph->n_nodes); - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_vk_build_graph(ctx, cgraph->nodes[i], i, i == last_node, false); - } + bool first_node_in_batch = true; // true if next node will be first node in a batch + int submit_node_idx = 0; // index to first node in a batch + // Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution. + // Start with a smaller count to get work submitted right away, and increase it after each submit. + int nodes_per_submit = 20; + int submitted_nodes = 0; + int submit_count = 0; for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor * node = cgraph->nodes[i]; - - if (ggml_vk_is_empty(node)) { - continue; + if (first_node_in_batch) { + submit_node_idx = i; } - bool ok = ggml_vk_compute_forward(ctx, node, i); - if (!ok) { - if (node->op == GGML_OP_UNARY) { - std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; - } else { - std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl; + bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node); + + bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit); + + if (enqueued) { + ++submitted_nodes; + +#ifndef GGML_VULKAN_CHECK_RESULTS + if (first_node_in_batch) { + first_node_in_batch = false; } - } -#ifdef GGML_VULKAN_CHECK_RESULTS - else { - ggml_vk_check_results_1(node); - } #endif - GGML_ASSERT(ok); + } + + if (submit) { + first_node_in_batch = true; + submitted_nodes = 0; + switch (submit_count) { + case 0: + nodes_per_submit = 50; + break; + default: + nodes_per_submit = 100; + break; + } + submit_count++; + } } #ifdef GGML_VULKAN_PERF @@ -6535,9 +7856,132 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen UNUSED(backend); } -GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tensor * op) { - // ggml_backend_vk_context * ctx = (ggml_backend_vk_context *) backend->context; +// TODO: enable async and synchronize +static ggml_backend_i ggml_backend_vk_interface = { + /* .get_name = */ ggml_backend_vk_name, + /* .free = */ ggml_backend_vk_free, + /* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async, + /* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async, + /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async, + /* .synchronize = */ NULL, // ggml_backend_vk_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_vk_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, +}; +static ggml_guid_t ggml_backend_vk_guid() { + static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b }; + return &guid; +} + +ggml_backend_t ggml_backend_vk_init(size_t dev_num) { + VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")"); + + ggml_backend_vk_context * ctx = new ggml_backend_vk_context; + ggml_vk_init(ctx, dev_num); + + ggml_backend_t vk_backend = new ggml_backend { + /* .guid = */ ggml_backend_vk_guid(), + /* .interface = */ ggml_backend_vk_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num), + /* .context = */ ctx, + }; + + return vk_backend; +} + +bool ggml_backend_is_vk(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid()); +} + +int ggml_backend_vk_get_device_count() { + return ggml_vk_get_device_count(); +} + +void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) { + GGML_ASSERT(device < (int) vk_instance.device_indices.size()); + int dev_idx = vk_instance.device_indices[device]; + ggml_vk_get_device_description(dev_idx, description, description_size); +} + +void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { + GGML_ASSERT(device < (int) vk_instance.device_indices.size()); + + vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; + + vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties(); + + for (const vk::MemoryHeap& heap : memprops.memoryHeaps) { + if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { + *total = heap.size; + *free = heap.size; + break; + } + } +} + +////////////////////////// + +struct ggml_backend_vk_device_context { + size_t device; + std::string name; + std::string description; +}; + +static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ctx->name.c_str(); +} + +static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ctx->description.c_str(); +} + +static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; + ggml_backend_vk_get_device_memory(ctx->device, free, total); +} + +static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ggml_backend_vk_buffer_type(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) { + UNUSED(dev); + return ggml_backend_vk_host_buffer_type(); +} + +static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) { + UNUSED(dev); + return GGML_BACKEND_DEVICE_TYPE_GPU; +} + +static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_vk_device_get_name(dev); + props->description = ggml_backend_vk_device_get_description(dev); + props->type = ggml_backend_vk_device_get_type(dev); + ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, + /* .host_buffer = */ true, + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; +} + +static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { + UNUSED(params); + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ggml_backend_vk_init(ctx->device); +} + +static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { switch (op->op) { case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { @@ -6554,6 +7998,12 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s && !device->mul_mat_id_m && !device->mul_mat_id_l) { + // If there's not enough shared memory for row_ids and the result tile, fallback to CPU + return false; + } switch (op->src[0]->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: @@ -6567,6 +8017,11 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: break; default: @@ -6584,8 +8039,69 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const if (a->ne[3] != b->ne[3]) { return false; } + if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) || + !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) { + return false; + } + return true; } break; + case GGML_OP_FLASH_ATTN_EXT: + { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + if (!ggml_vk_get_device(ctx->device)->coopmat2) { + return false; + } + switch (op->src[0]->ne[0]) { + case 64: + case 80: + case 96: + case 112: + case 128: + case 256: + break; + default: + return false; + } + if (op->src[0]->type != GGML_TYPE_F32) { + return false; + } + if (op->type != GGML_TYPE_F32) { + return false; + } + if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { + return false; + } + // It's straightforward to support different K/V dequant, but would + // significantly increase the number of pipelines + if (op->src[1]->type != op->src[2]->type) { + return false; + } + switch (op->src[1]->type) { + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently + //case GGML_TYPE_Q2_K: + //case GGML_TYPE_Q3_K: + //case GGML_TYPE_Q4_K: + //case GGML_TYPE_Q5_K: + //case GGML_TYPE_Q6_K: + //case GGML_TYPE_IQ2_XXS: + //case GGML_TYPE_IQ2_XS: + //case GGML_TYPE_IQ2_S: + //case GGML_TYPE_IQ3_XXS: + //case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + break; + default: + return false; + } + return true; + } case GGML_OP_GET_ROWS: { switch (op->src[0]->type) { @@ -6596,6 +8112,11 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: return true; default: @@ -6608,12 +8129,36 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const { ggml_type src0_type = op->src[0]->type; ggml_type src1_type = op->src[1] != nullptr ? op->src[1]->type : src0_type; - if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { - return true; + + if (src0_type == GGML_TYPE_F32) { + switch (src1_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return true; + default: + break; + } } - if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { - return true; + if (src1_type == GGML_TYPE_F32) { + switch (src0_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return true; + default: + break; + } } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { return true; } @@ -6622,7 +8167,16 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const case GGML_OP_REPEAT: return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float); case GGML_OP_ROPE: - return ggml_is_contiguous(op->src[0]); + { + const int mode = ((const int32_t *) op->op_params)[2]; + if (mode & GGML_ROPE_TYPE_MROPE) { + return false; + } + if (mode & GGML_ROPE_TYPE_VISION) { + return false; + } + return ggml_is_contiguous(op->src[0]); + } case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -6649,126 +8203,110 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const case GGML_OP_SUM_ROWS: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_POOL_2D: + case GGML_OP_RWKV_WKV6: case GGML_OP_LEAKY_RELU: return true; default: return false; } - UNUSED(backend); + UNUSED(dev); } -GGML_CALL static bool ggml_backend_vk_offload_op(ggml_backend_t backend, const ggml_tensor * op) { +static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) { + return false; + } + + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context; + + return buft_ctx->device->idx == ctx->device; +} + +static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { const int min_batch_size = 32; return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) || (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID); - UNUSED(backend); + UNUSED(dev); } -GGML_CALL static bool ggml_backend_vk_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { - if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) { - return false; - } - - ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context; - ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; - - return buft_ctx->device == ctx->device; -} - -// TODO: enable async and synchronize -static ggml_backend_i ggml_backend_vk_interface = { - /* .get_name = */ ggml_backend_vk_name, - /* .free = */ ggml_backend_vk_free, - /* .get_default_buffer_type = */ ggml_backend_vk_get_default_buffer_type, - /* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async, - /* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async, - /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async, - /* .synchronize = */ NULL, // ggml_backend_vk_synchronize, - /* .graph_plan_create = */ NULL, - /* .graph_plan_free = */ NULL, - /* .graph_plan_update = */ NULL, - /* .graph_plan_compute = */ NULL, - /* .graph_compute = */ ggml_backend_vk_graph_compute, - /* .supports_op = */ ggml_backend_vk_supports_op, - /* .supports_buft = */ ggml_backend_vk_supports_buft, - /* .offload_op = */ ggml_backend_vk_offload_op, - /* .event_new = */ NULL, - /* .event_free = */ NULL, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, - /* .event_synchronize = */ NULL, +static const struct ggml_backend_device_i ggml_backend_vk_device_i = { + /* .get_name = */ ggml_backend_vk_device_get_name, + /* .get_description = */ ggml_backend_vk_device_get_description, + /* .get_memory = */ ggml_backend_vk_device_get_memory, + /* .get_type = */ ggml_backend_vk_device_get_type, + /* .get_props = */ ggml_backend_vk_device_get_props, + /* .init_backend = */ ggml_backend_vk_device_init, + /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type, + /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type, + /* .buffer_from_host_ptr = */ NULL, + /* .supports_op = */ ggml_backend_vk_device_supports_op, + /* .supports_buft = */ ggml_backend_vk_device_supports_buft, + /* .offload_op = */ ggml_backend_vk_device_offload_op, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, }; -static ggml_guid_t ggml_backend_vk_guid() { - static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b }; - return &guid; +static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) { + UNUSED(reg); + return GGML_VK_NAME; } -GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t dev_num) { - VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")"); - - ggml_backend_vk_context * ctx = new ggml_backend_vk_context; - ggml_vk_init(ctx, dev_num); - - ggml_backend_t vk_backend = new ggml_backend { - /* .guid = */ ggml_backend_vk_guid(), - /* .interface = */ ggml_backend_vk_interface, - /* .context = */ ctx, - }; - - return vk_backend; +static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) { + UNUSED(reg); + return ggml_backend_vk_get_device_count(); } -GGML_CALL bool ggml_backend_is_vk(ggml_backend_t backend) { - return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid()); -} +static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) { + static std::vector devices; -GGML_CALL int ggml_backend_vk_get_device_count() { - return ggml_vk_get_device_count(); -} + static bool initialized = false; -GGML_CALL void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) { - ggml_vk_get_device_description(device, description, description_size); -} - -GGML_CALL void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { - GGML_ASSERT(device < (int) vk_instance.device_indices.size()); - - vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; - - vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties(); - - for (const vk::MemoryHeap& heap : memprops.memoryHeaps) { - if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { - *total = heap.size; - *free = heap.size; - break; + { + static std::mutex mutex; + std::lock_guard lock(mutex); + if (!initialized) { + for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { + ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; + char desc[256]; + ggml_backend_vk_get_device_description(i, desc, sizeof(desc)); + ctx->device = i; + ctx->name = GGML_VK_NAME + std::to_string(i); + ctx->description = desc; + devices.push_back(new ggml_backend_device { + /* .iface = */ ggml_backend_vk_device_i, + /* .reg = */ reg, + /* .context = */ ctx, + }); + } + initialized = true; } } + + GGML_ASSERT(device < devices.size()); + return devices[device]; } -// backend registry -GGML_CALL static ggml_backend_t ggml_backend_reg_vk_init(const char * params, void * user_data) { - ggml_backend_t vk_backend = ggml_backend_vk_init((int) (intptr_t) user_data); - return vk_backend; +static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = { + /* .get_name = */ ggml_backend_vk_reg_get_name, + /* .get_device_count = */ ggml_backend_vk_reg_get_device_count, + /* .get_device = */ ggml_backend_vk_reg_get_device, + /* .get_proc_address = */ NULL, +}; - UNUSED(params); -} +ggml_backend_reg_t ggml_backend_vk_reg() { + static ggml_backend_reg reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_vk_reg_i, + /* .context = */ nullptr, + }; -extern "C" GGML_CALL int ggml_backend_vk_reg_devices(); - -GGML_CALL int ggml_backend_vk_reg_devices() { - ggml_vk_instance_init(); - - for (size_t i = 0; i < vk_instance.device_indices.size(); i++) { - char name[128]; - snprintf(name, sizeof(name), "%s%ld", GGML_VK_NAME, i); - ggml_backend_register(name, ggml_backend_reg_vk_init, ggml_backend_vk_buffer_type(i), (void *) (intptr_t) i); // NOLINT - } - return vk_instance.device_indices.size(); + return ® } // Extension availability @@ -6807,6 +8345,25 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve UNUSED(instance_extensions); } +static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) { + switch (props.vendorID) { + case VK_VENDOR_ID_INTEL: + // Intel drivers don't support coopmat properly yet + return false; + case VK_VENDOR_ID_AMD: + if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) { + // Workaround for AMD proprietary driver reporting support on all GPUs + const std::string name = props.deviceName; + return name.rfind("AMD Radeon RX 7", 0) == 0 || name.rfind("AMD Radeon(TM) RX 7", 0) == 0 || // RDNA 3 consumer GPUs + name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs + name.rfind("AMD Radeon 7", 0) == 0 || name.rfind("AMD Radeon(TM) 7", 0) == 0; // RDNA 3 APUs + } + return true; + default: + return true; + } +} + // checks #ifdef GGML_VULKAN_CHECK_RESULTS @@ -6873,10 +8430,10 @@ static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name) const size_t tensor_size = ggml_nbytes(tensor); tensor_data = malloc(tensor_size); - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; - vk_buffer buffer_gpu = extra->buffer_gpu.lock(); - ggml_vk_buffer_read(buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size); + vk_buffer buffer_gpu = buf_ctx->dev_buffer; + ggml_vk_buffer_read(buffer_gpu, vk_tensor_offset(tensor) + tensor->view_offs, tensor_data, tensor_size); } std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl; @@ -6917,6 +8474,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { ggml_tensor * src0 = tensor->src[0]; ggml_tensor * src1 = tensor->src[1]; ggml_tensor * src2 = tensor->src[2]; + ggml_tensor * src3 = tensor->src[3]; struct ggml_init_params iparams = { /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul, @@ -6929,15 +8487,18 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { struct ggml_tensor * src0_clone = nullptr; struct ggml_tensor * src1_clone = nullptr; struct ggml_tensor * src2_clone = nullptr; + struct ggml_tensor * src3_clone = nullptr; struct ggml_tensor * tensor_clone = nullptr; size_t src0_size; size_t src1_size; size_t src2_size; + size_t src3_size; void * src0_buffer = nullptr; void * src1_buffer = nullptr; void * src2_buffer = nullptr; + void * src3_buffer = nullptr; if (src0 != nullptr) { src0_clone = ggml_dup_tensor(ggml_ctx, src0); @@ -6950,9 +8511,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { memcpy(src0_clone->data, src0->data, src0_size); memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS); } else if (ggml_backend_buffer_is_vk(src0->buffer)) { - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src0->extra; - vk_buffer buffer_gpu = extra->buffer_gpu.lock(); - uint64_t offset = extra->offset + src0->view_offs; + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(src0) + src0->view_offs; if (!ggml_is_contiguous(src0) && ggml_vk_dim01_contiguous(src0)) { for (int i3 = 0; i3 < src0->ne[3]; i3++) { for (int i2 = 0; i2 < src0->ne[2]; i2++) { @@ -6992,9 +8553,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { memcpy(src1_clone->data, src1->data, src1_size); memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS); } else if (ggml_backend_buffer_is_vk(src1->buffer)) { - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src1->extra; - vk_buffer buffer_gpu = extra->buffer_gpu.lock(); - uint64_t offset = extra->offset + src1->view_offs; + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(src1) + src1->view_offs; if (!ggml_is_contiguous(src1) && ggml_vk_dim01_contiguous(src1)) { for (int i3 = 0; i3 < src1->ne[3]; i3++) { for (int i2 = 0; i2 < src1->ne[2]; i2++) { @@ -7034,9 +8595,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { memcpy(src2_clone->data, src2->data, src2_size); memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS); } else if (ggml_backend_buffer_is_vk(src2->buffer)) { - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src2->extra; - vk_buffer buffer_gpu = extra->buffer_gpu.lock(); - uint64_t offset = extra->offset + src2->view_offs; + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src2->buffer->context; + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(src2) + src2->view_offs; if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) { for (int i3 = 0; i3 < src2->ne[3]; i3++) { for (int i2 = 0; i2 < src2->ne[2]; i2++) { @@ -7065,8 +8626,53 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { ggml_vk_print_tensor(src2, "src2"); } } + if (src3 != nullptr) { + src3_clone = ggml_dup_tensor(ggml_ctx, src3); - if (tensor->op == GGML_OP_MUL_MAT) { + src3_size = ggml_nbytes(src3); + + src3_buffer = malloc(src3_size); + src3_clone->data = src3_buffer; + if (ggml_backend_buffer_is_host(src3->buffer)) { + memcpy(src3_clone->data, src3->data, src3_size); + memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS); + } else if (ggml_backend_buffer_is_vk(src3->buffer)) { + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src3->buffer->context; + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(src3) + src3->view_offs; + if (!ggml_is_contiguous(src3) && ggml_vk_dim01_contiguous(src3)) { + for (int i3 = 0; i3 < src3->ne[3]; i3++) { + for (int i2 = 0; i2 < src3->ne[2]; i2++) { + const int idx = i3*src3->ne[2] + i2; + ggml_vk_buffer_read(buffer_gpu, offset + idx * src3->nb[2], ((char *)src3_clone->data + idx * src3_clone->nb[2]), src3->ne[1] * src3->nb[1]); + } + } + + src3_clone->nb[0] = src3->nb[0]; + src3_clone->nb[1] = src3->nb[1]; + for (int i = 2; i < GGML_MAX_DIMS; i++) { + src3_clone->nb[i] = src3_clone->nb[i - 1]*src3_clone->ne[i - 1]; + } + } else { + if (offset + src3_size >= buffer_gpu->size) { + src3_size = buffer_gpu->size - offset; + } + ggml_vk_buffer_read(buffer_gpu, offset, src3_clone->data, src3_size); + memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS); + } + } else { + GGML_ABORT("fatal error"); + } + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + ggml_vk_print_tensor(src3, "src3"); + } + } + + if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { + const float *params = (const float *)tensor->op_params; + tensor_clone = ggml_flash_attn_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, src3_clone, params[0], params[1], params[2]); + } else if (tensor->op == GGML_OP_MUL_MAT) { tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone); } else if (tensor->op == GGML_OP_MUL_MAT_ID) { tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone); @@ -7091,7 +8697,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_PAD) { tensor_clone = ggml_pad(ggml_ctx, src0_clone, tensor->ne[0] - src0_clone->ne[0], tensor->ne[1] - src0_clone->ne[1], tensor->ne[2] - src0_clone->ne[2], tensor->ne[3] - src0_clone->ne[3]); } else if (tensor->op == GGML_OP_REPEAT) { - tensor_clone = ggml_repeat(ggml_ctx, src0_clone, src1_clone); + tensor_clone = ggml_repeat(ggml_ctx, src0_clone, tensor); } else if (tensor->op == GGML_OP_ADD) { tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone); } else if (tensor->op == GGML_OP_ACC) { @@ -7181,10 +8787,24 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { const int32_t dim = tensor->op_params[0]; const int32_t max_period = tensor->op_params[1]; tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period); + } else if (tensor->op == GGML_OP_POOL_2D) { + enum ggml_op_pool op = static_cast(tensor->op_params[0]); + const int32_t k0 = tensor->op_params[1]; + const int32_t k1 = tensor->op_params[2]; + const int32_t s0 = tensor->op_params[3]; + const int32_t s1 = tensor->op_params[4]; + const int32_t p0 = tensor->op_params[5]; + const int32_t p1 = tensor->op_params[6]; + + tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1); } else if (tensor->op == GGML_OP_LEAKY_RELU) { const float * op_params = (const float *)tensor->op_params; tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false); - } else { + } else if (tensor->op == GGML_OP_RWKV_WKV6) { + tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], + tensor->src[4], tensor->src[5]); + } + else { std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); } @@ -7229,6 +8849,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { ggml_tensor * src0 = tensor->src[0]; ggml_tensor * src1 = tensor->src[1]; ggml_tensor * src2 = tensor->src[2]; + ggml_tensor * src3 = tensor->src[3]; void * tensor_data = tensor->data; @@ -7236,14 +8857,15 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { size_t tensor_size = ggml_nbytes(tensor); tensor_data = malloc(tensor_size); - ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra; + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; - vk_buffer buffer_gpu = extra->buffer_gpu.lock(); - if (extra->offset + tensor->view_offs + tensor_size >= buffer_gpu->size) { - tensor_size = buffer_gpu->size - (extra->offset + tensor->view_offs); + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(tensor) + tensor->view_offs; + if (offset + tensor_size >= buffer_gpu->size) { + tensor_size = buffer_gpu->size - offset; } - ggml_vk_buffer_read(buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size); + ggml_vk_buffer_read(buffer_gpu, offset, tensor_data, tensor_size); } float first_error_result = -1.0f; @@ -7290,6 +8912,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { if (src2 != nullptr) { std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; std::cerr << std::endl << "Result:" << std::endl; ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3); @@ -7334,6 +8959,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { if (src2 != nullptr) { std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; std::cerr << std::endl << "Result:" << std::endl; ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); @@ -7356,6 +8984,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { if (src2 != nullptr) { std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; std::cerr << std::endl << "Result:" << std::endl; ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]); @@ -7380,3 +9011,5 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")"); } #endif + +GGML_BACKEND_DL_IMPL(ggml_backend_vk_reg) diff --git a/ggml/src/vulkan-shaders/CMakeLists.txt b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt similarity index 56% rename from ggml/src/vulkan-shaders/CMakeLists.txt rename to ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt index 10075db33..074031087 100644 --- a/ggml/src/vulkan-shaders/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -1,7 +1,11 @@ find_package (Threads REQUIRED) +find_program(GLSLC_EXECUTABLE glslc) +if(NOT GLSLC_EXECUTABLE) + message(FATAL_ERROR "glslc not found.") +endif() set(TARGET vulkan-shaders-gen) add_executable(${TARGET} vulkan-shaders-gen.cpp) install(TARGETS ${TARGET} RUNTIME) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp new file mode 100644 index 000000000..d896f1ef0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp @@ -0,0 +1,29 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = gl_GlobalInvocationID.x; + if (idx >= p.ne) { + return; + } + + const uint offset = p.param3; + const uint src1_i = idx - offset; + const uint oz = src1_i / p.nb02; + const uint oy = (src1_i - (oz * p.nb02)) / p.nb01; + const uint ox = src1_i % p.nb01; + + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) { + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11])); + } else { + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)])); + } +} + diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/add.comp new file mode 100644 index 000000000..2b4085c4f --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/add.comp @@ -0,0 +1,29 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +#include "types.comp" +#include "generic_binary_head.comp" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git a/ggml/src/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp similarity index 94% rename from ggml/src/vulkan-shaders/argsort.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp index e55414b03..d4fa45b1e 100644 --- a/ggml/src/vulkan-shaders/argsort.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp @@ -29,20 +29,18 @@ void main() { const int col = int(gl_LocalInvocationID.x); const uint row = gl_WorkGroupID.y; - if (col >= p.ncols_pad) { - return; - } - const uint row_offset = row * p.ncols; // initialize indices - dst_row[col] = col; + if (col < p.ncols_pad) { + dst_row[col] = col; + } barrier(); for (uint k = 2; k <= p.ncols_pad; k *= 2) { for (uint j = k / 2; j > 0; j /= 2) { const uint ixj = col ^ j; - if (ixj > col) { + if (col < p.ncols_pad && ixj > col) { if ((col & k) == 0) { if (dst_row[col] >= p.ncols || (dst_row[ixj] < p.ncols && (p.order == ASC ? diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp b/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp new file mode 100644 index 000000000..1e5cb8dae --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val)); +} diff --git a/ggml/src/vulkan-shaders/concat.comp b/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp similarity index 76% rename from ggml/src/vulkan-shaders/concat.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/concat.comp index c23b6eb1b..9ee2f1fae 100644 --- a/ggml/src/vulkan-shaders/concat.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp @@ -3,6 +3,8 @@ #include "types.comp" #include "generic_binary_head.comp" +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + void main() { const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; const int dim = p.param3; @@ -28,12 +30,12 @@ void main() { const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; #ifndef OPTIMIZATION_ERROR_WORKAROUND - data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : data_b[src1_idx]); + data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : data_b[get_boffset() + src1_idx]); #else if (is_src0) { - data_d[p.d_offset + dst_idx] = data_a[src0_idx]; + data_d[get_doffset() + dst_idx] = data_a[get_aoffset() + src0_idx]; } else { - data_d[p.d_offset + dst_idx] = data_b[src1_idx]; + data_d[get_doffset() + dst_idx] = data_b[get_boffset() + src1_idx]; } #endif } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp new file mode 100644 index 000000000..dd828c232 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp @@ -0,0 +1,42 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +#extension GL_EXT_control_flow_attributes : require + +const uint num_threads = 128; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 4; + + // fast path for when all four iterations are in-bounds + if (idx + (num_iter-1)*num_threads < p.ne) { + [[unroll]] for (uint i = 0; i < num_iter; ++i) { +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); +#else + data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; +#endif + idx += num_threads; + } + } else { + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); +#else + data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; +#endif + idx += num_threads; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp new file mode 100644 index 000000000..29c906494 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]); +#else + data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)]; +#endif +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp new file mode 100644 index 000000000..aeae5400d --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -0,0 +1,51 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" +#include "dequant_funcs.comp" + +#if defined(DATA_A_IQ4_NL) +// 16 invocations needed for init_iq4nl_shmem +layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; +#else +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; +#endif + +void main() { +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) + init_iq_shmem(gl_WorkGroupSize); + if (gl_LocalInvocationIndex.x != 0) { + return; + } +#endif + + const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K; + + if (idx >= p.ne) { + return; + } + + uint dst_idx = get_doffset() + dst_idx(idx); + uint src_idx = src0_idx_quant(idx, QUANT_K); + + const uint a_offset = 0; + const uint ib = src_idx; + const vec2 dm = get_dm(ib, a_offset); + + [[unroll]] for (int j = 0; j < QUANT_K; j += 4) { + vec4 v = dequantize4(ib, j / QUANT_R, a_offset); + v = v * dm.x + vec4(dm.y); + +#if QUANT_R == 2 + data_d[dst_idx + j/2 + 0] = v[0]; + data_d[dst_idx + j/2 + QUANT_K/2 + 0] = v[1]; + data_d[dst_idx + j/2 + 1] = v[2]; + data_d[dst_idx + j/2 + QUANT_K/2 + 1] = v[3]; +#else + data_d[dst_idx + j + 0] = v[0]; + data_d[dst_idx + j + 1] = v[1]; + data_d[dst_idx + j + 2] = v[2]; + data_d[dst_idx + j + 3] = v[3]; +#endif + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp new file mode 100644 index 000000000..d4b068e61 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -0,0 +1,237 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +#if defined(DATA_A_IQ4_NL) +// 16 invocations needed for init_iq4nl_shmem +layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; +#else +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; +#endif + +layout (binding = 0) readonly buffer S {float data_s[];}; +layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];}; + +#if defined(DATA_A_Q4_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_0; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + const float d = vmax / -8; + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_0/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_Q4_0/2 + j]*id; + + const uint xi0 = min(15, int(x0 + 8.5)); + const uint xi1 = min(15, int(x1 + 8.5)); + + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + } +} +#endif + +#if defined(DATA_A_Q4_1) +void quantize(uint dst_idx, uint src_idx) +{ + float vmin = 1.0/0.0; + float vmax = -vmin; + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_1; ++j) { + const float v = data_s[src_idx + j]; + + if (v < vmin) vmin = v; + if (v > vmax) vmax = v; + } + + const float d = (vmax - vmin) / ((1 << 4) - 1); + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + data_q[dst_idx].m = float16_t(vmin); + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_1/2; ++j) { + const float x0 = (data_s[src_idx + 0 + j] - vmin)*id; + const float x1 = (data_s[src_idx + QUANT_K_Q4_1/2 + j] - vmin)*id; + + const uint xi0 = min(15, int(x0 + 0.5)); + const uint xi1 = min(15, int(x1 + 0.5)); + + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + } +} +#endif + +#if defined(DATA_A_Q5_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_Q5_0; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + const float d = vmax / -16; + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + uint32_t qh = 0; + [[unroll]] for (int j = 0; j < QUANT_K_Q5_0/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_Q5_0/2 + j]*id; + + const uint xi0 = min(31, int(x0 + 16.5)); + const uint xi1 = min(31, int(x1 + 16.5)); + + data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4)); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_0/2); + } + data_q[dst_idx].qh[0] = uint16_t(qh & 0xFFFF); + data_q[dst_idx].qh[1] = uint16_t(qh >> 16); +} +#endif + +#if defined(DATA_A_Q5_1) +void quantize(uint dst_idx, uint src_idx) +{ + float min = data_s[src_idx + 0]; + float max = min; + + [[unroll]] for (int j = 1; j < QUANT_K_Q5_1; ++j) { + const float v = data_s[src_idx + j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = (d != 0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + data_q[dst_idx].m = float16_t(min); + + uint32_t qh = 0; + [[unroll]] for (int j = 0; j < QUANT_K_Q5_1/2; ++j) { + const float x0 = (data_s[src_idx + 0 + j] - min)*id; + const float x1 = (data_s[src_idx + QUANT_K_Q5_1/2 + j] - min)*id; + + const uint xi0 = uint(x0 + 0.5); + const uint xi1 = uint(x1 + 0.5); + + data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4)); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_1/2); + } + data_q[dst_idx].qh = qh; +} +#endif + +#if defined(DATA_A_Q8_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; // absolute max + + [[unroll]] for (int j = 0; j < QUANT_K_Q8_0; j++) { + const float v = data_s[src_idx + j]; + amax = max(amax, abs(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + [[unroll]] for (int j = 0; j < QUANT_K_Q8_0; ++j) { + const float x0 = data_s[src_idx + j]*id; + + data_q[dst_idx].qs[j] = int8_t(round(x0)); + } +} +#endif + +#if defined(DATA_A_IQ4_NL) +uint best_index(float x) { + if (x <= kvalues_iq4nl[0]) return 0; + if (x >= kvalues_iq4nl[15]) return 15; + int ml = 0, mu = 15; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < kvalues_iq4nl[mav]) mu = mav; else ml = mav; + } + return x - kvalues_iq4nl[mu-1] < kvalues_iq4nl[mu] - x ? mu-1 : mu; +} + +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + float d = vmax / kvalues_iq4nl[0]; + const float id = (d != 0.0) ? 1.0/d : 0.0; + + float sumqx = 0, sumq2 = 0; + [[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*id; + const uint xi0 = best_index(x0); + const uint xi1 = best_index(x1); + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + const float v0 = kvalues_iq4nl[xi0]; + const float v1 = kvalues_iq4nl[xi1]; + const float w0 = data_s[src_idx + 0 + j]*data_s[src_idx + 0 + j]; + const float w1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*data_s[src_idx + QUANT_K_IQ4_NL/2 + j]; + sumqx += w0*v0*data_s[src_idx + j] + w1*v1*data_s[src_idx + QUANT_K_IQ4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + } + + data_q[dst_idx].d = float16_t(sumq2 > 0 ? sumqx/sumq2 : d); + +} +#endif + +void main() { +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) + init_iq_shmem(gl_WorkGroupSize); + if (gl_LocalInvocationIndex.x != 0) { + return; + } +#endif + + const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * QUANT_K; + + if (idx >= p.ne) { + return; + } + + uint dst_idx = dst_idx_quant(idx, QUANT_K); + uint src_idx = get_aoffset() + src0_idx(idx); + + quantize(dst_idx, src_idx); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp new file mode 100644 index 000000000..0b8d02f58 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val)); +} diff --git a/ggml/src/vulkan-shaders/dequant_f32.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp similarity index 100% rename from ggml/src/vulkan-shaders/dequant_f32.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp new file mode 100644 index 000000000..ee6877531 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -0,0 +1,334 @@ +#if !defined(DATA_A_F32) && !defined(DATA_A_F16) +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#endif + +#include "types.comp" + +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif + +#if defined(DATA_A_F32) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); +} +#endif + +#if defined(DATA_A_F16) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); +} +#endif + +#if defined(DATA_A_Q4_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return (vec2(vui & 0xF, vui >> 4) - 8.0f); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) - 8.0f); +} +#endif + +#if defined(DATA_A_Q4_1) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2(vui & 0xF, vui >> 4); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12); +} +#endif + +#if defined(DATA_A_Q5_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0]; + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0]; + const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10); + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f); +} +#endif + +#if defined(DATA_A_Q5_1) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = data_a[a_offset + ib].qh; + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = data_a_packed16[a_offset + ib].qh; + const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10); + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y); +} +#endif + +#if defined(DATA_A_Q8_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2]; + uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1]; + return vec4(int8_t(v0 & 0xFF), int8_t(v0 >> 8), int8_t(v1 & 0xFF), int8_t(v1 >> 8)); +} +#endif + +#if defined(DATA_A_IQ2_XXS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = (iqs / 8) % 4; + const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8]; + // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale) + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2], + data_a_packed16[a_offset + ib].qs[4 * ib32 + 3])); + const float db = 0.25 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = (iqs / 8) % 4; + const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8]; + // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale) + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2], + data_a_packed16[a_offset + ib].qs[4 * ib32 + 3])); + const float db = 0.25 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ2_XS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[iqs / 8]; + const float db = 0.25 * (0.5 + scale); + const uint sign7 = qs >> 9; + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[iqs / 8]; + const float db = 0.25 * (0.5 + scale); + const uint sign7 = qs >> 9; + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ2_S) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + + const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8); + + const float db = 0.25 * (0.5 + scale); + const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid[iqs % 4] * (sign0 ? -1.0 : 1.0), + grid[(iqs % 4) + 1] * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + + const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8); + + const float db = 0.25 * (0.5 + scale); + const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ3_XXS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib4 = iqs / 4; + const uint ib32 = iqs / 32; + const uint is = QUANT_K / 4 + 4 * ib32; + const uint qs = data_a[a_offset + ib].qs[ib4]; + // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale) + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2], + data_a_packed16[a_offset + ib].qs[is / 2 + 1])); + const float db = 0.5 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq3xxs_grid[qs] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib4 = iqs / 4; + const uint ib32 = iqs / 32; + const uint is = QUANT_K / 4 + 4 * ib32; + const uint qs = data_a[a_offset + ib].qs[ib4]; + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2], + data_a_packed16[a_offset + ib].qs[is / 2 + 1])); + const float db = 0.5 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq3xxs_grid[qs]); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ3_S) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint qs = data_a[a_offset + ib].qs[iqs / 4]; + const uint qh = data_a[a_offset + ib].qh[iqs / 32]; + const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8); + const uint scale = data_a[a_offset + ib].scales[iqs / 64]; + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + const float db = 1 + 2 * ((scale >> (4 * ((iqs / 32) & 1))) & 0xf); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ((iqs / 4) % 8))) & 256)] >> (8 * (iqs % 4)); + return db * vec2( + int(grid & 0xFF) * (sign0 ? -1.0 : 1.0), + int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib4 = iqs / 4; + const uint ib32 = iqs / 32; + const uint qs = data_a[a_offset + ib].qs[ib4]; + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8); + const uint scale = data_a[a_offset + ib].scales[ib32 / 2]; + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + const float db = 1 + 2 * ((scale >> (4 * (ib32 & 1))) & 0xf); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ib4 % 8)) & 256)] >> (8 * (iqs % 4)); + return db * vec4( + int(grid & 0xFF) * (sign0 ? -1.0 : 1.0), + int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0), + int((grid >> 16) & 0xFF) * (sign2 ? -1.0 : 1.0), + int((grid >> 24) & 0xFF) * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ4_NL) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[vui >> 12]); +} +#endif + +#if defined(DATA_A_F32) || defined(DATA_A_F16) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(0, 0); +} +#endif + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(float(data_a[a_offset + ib].d), 0); +} +#endif + +#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m)); +} +#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp new file mode 100644 index 000000000..974efd3f9 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -0,0 +1,509 @@ + +#include "types.comp" + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { + block_q4_0_packed16 block; +}; + +float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]); + qs >>= shift; + qs &= 0x0F0F; + qs = unpack8(qs)[idx & 1]; + float16_t ret = (float16_t(qs) - float16_t(8)) * d; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 { + block_q4_1 block; +}; + +float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + float16_t ret = float16_t(qs) * d + m; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 { + block_q5_0 block; +}; + +float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + + const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0]; + const uint qh = ((uint_qh >> idx) << 4) & 0x10; + + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + + float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 { + block_q5_1 block; +}; + +float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + + const uint uint_qh = bl.block.qh; + const uint qh = ((uint_qh >> idx) << 4) & 0x10; + + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + + float16_t ret = float16_t(qs | qh) * d + m; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 { + block_q8_0_packed16 block; +}; + +float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx; + + // Load 16b and select the byte for this element + int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1]; + float16_t ret = float16_t(qs) * d; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K { + block_q2_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2_K_packed16 { + block_q2_K_packed16 block; +}; + +float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl); + const f16vec2 d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint scalesi = (idx & 0xF0) >> 4; // 0..15 + const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6 + + uint qs = uint32_t(bl16.block.qs[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]); + qs = (qs >> qsshift) & 0x0303; + qs = unpack8(qs)[idx & 1]; + + const uint scales = bl.block.scales[scalesi]; + float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4); + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K { + block_q3_K block; +}; + +float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + const uint iqs = idx; + + const uint n = iqs / 128; // 0,1 + const uint qsi = n * 32 + (iqs % 32); // 0..63 + const uint hmi = (iqs % 32); // 0..31 + const uint j = (iqs % 128) / 8; // 0..15 + const uint is = iqs / 16; // 0..15 + const uint halfsplit = ((iqs % 128) / 32); // 0,1,2,3 + const uint qsshift = halfsplit * 2; // 0,2,4,6 + const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 + + uint32_t scaleidx0 = (is < 8) ? is : (is-8); + uint32_t scaleidx0shift = (is < 8) ? 0 : 4; + uint32_t scaleidx1 = is + 8 - (is/4)*4; + uint32_t scaleidx1shift = (is/4)*2; + + const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4)); + + const float16_t dl = bl.block.d * float16_t(us - 32); + + float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi ] >> qsshift) & 3) - (((bl.block.hmask[hmi ] & m) != 0) ? 0 : 4)); + + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K { + block_q4_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 { + block_q4_K_packed16 block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 { + block_q4_K_packed128 block; +}; + +float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl); + decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x20) >> 5; // 0,1 + const uint is = (idx & 0xE0) >> 5; // 0..7 + + uvec4 v = bl128.block.q4k[0]; + + const f16vec2 loadd = unpackFloat2x16(v.x); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float16_t d = loadd.x * float16_t(sc); + const float16_t m = loadd.y * float16_t(mbyte); + + uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); + qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF; + + float16_t ret = d * float16_t(qs) - m; + + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K { + block_q5_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 { + block_q5_K_packed16 block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 { + block_q5_K_packed128 block; +}; + +float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl); + decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x20) >> 5; // 0,1 + const uint is = (idx & 0xE0) >> 5; // 0..7 + + uvec4 v = bl128.block.q5k[0]; + + const f16vec2 loadd = unpackFloat2x16(v.x); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float16_t d = loadd.x * float16_t(sc); + const float16_t m = loadd.y * float16_t(mbyte); + + uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]); + qh = ((qh >> is) & 0x101) << 4; + + uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); + qs = (qs >> (b * 4)) & 0x0F0F; + qs = unpack8(qs | qh)[idx & 1]; + + float16_t ret = d * (float16_t(qs)) - m; + + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K { + block_q6_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ6_K_packed16 { + block_q6_K_packed16 block; +}; + +float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x40) >> 6; // 0,1 + const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6 + const uint is = (idx & 0xF0) >> 4; // 0..15 + + const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]); + + uint ql = uint32_t(bl16.block.ql[((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1)]); + ql = (ql >> (b * 4)) & 0x0F0F; + + uint qh = uint32_t(bl16.block.qh[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]); + qh = ((qh >> qhshift) & 0x0303) << 4; + + int q = unpack8(ql | qh)[idx & 1]; + + float16_t ret = dscale * float16_t(q - 32); + + return ret; +} + +#if defined(DATA_A_IQ2_XXS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS { + block_iq2_xxs block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS_packed16 { + block_iq2_xxs_packed16 block; +}; + +float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl); + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = (idx & 0xE0) >> 5; // 0..7 + const uint ib8 = (idx & 0x18) >> 3; // 0..3 + const uint iqs = 8 * ib32 + ib8; + + const uint8_t qs = bl.block.qs[iqs]; + const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3])); + + const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t(signscale >> 28)); + uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7); + sign |= bitCount(sign) << 7; + + const uint8_t g = unpack8(iq2xxs_grid[qs][(idx & 4) >> 2])[idx & 3]; + + float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); + + return ret; +} +#endif + +#if defined(DATA_A_IQ2_XS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XS { + block_iq2_xs block; +}; + +float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint is = (idx & 0xE0) >> 5; // 0..8 + const uint sshift = (idx & 0x10) >> 2; // 0,4 + const uint iqs = (idx & 0xF8) >> 3; // 0..63 + + const uint16_t qs = bl.block.qs[iqs]; + const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t((bl.block.scales[is] >> sshift) & 0xF)); + + uint sign = uint(qs >> 9); + sign |= bitCount(sign) << 7; + const uint8_t g = unpack8(iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2])[idx & 3]; + + float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); + return ret; +} +#endif + +#if defined(DATA_A_IQ2_S) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_S { + block_iq2_s block; +}; + +float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + uint idx = coordInBlock[1]; + uint lsb = idx & 1; + idx /= 2; + + const uint ib8 = (idx % 128) / 4; // 0..31 + const uint ib32 = ib8 / 4; // 0..7 + + const uint scale = (bl.block.scales[ib32] >> (2 * (ib8 & 2))) & 0xf; + const uint qs = bl.block.qs[ib8]; + const uint qh = bl.block.qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4)); + + const float d = float(bl.block.d); + const float db = d * 0.25 * (0.5 + scale); + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); + const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid)); + return float16_t(v[lsb]); +} +#endif + +#if defined(DATA_A_IQ3_XXS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS { + block_iq3_xxs block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS_packed16 { + block_iq3_xxs_packed16 block; +}; + +float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + uint idx = coordInBlock[1]; + uint lsb = idx & 1; + idx /= 2; + + const uint iqs = (idx % 128) / 2; // 0..63 + const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values + + const float d = float(bl.block.d); + const uint qs = bl.block.qs[iqs]; + const uint signs = pack32(u8vec4( + bl.block.qs[is+0], + bl.block.qs[is+1], + bl.block.qs[is+2], + bl.block.qs[is+3] + )); + const float db = d * 0.5 * (0.5 + (signs >> 28)); + const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); + const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + return float16_t(v[lsb]); +} +#endif + +#if defined(DATA_A_IQ3_S) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_S { + block_iq3_s block; +}; + +float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + uint idx = coordInBlock[1]; + uint lsb = idx & 1; + idx /= 2; + + const uint iqs = (idx % 128) / 2; // 0..63 + const uint iqh = iqs / 8; + + const float d = float(bl.block.d); + const uint qs = bl.block.qs[iqs]; + const uint qh = bl.block.qh[iqh]; + const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (2 * (idx % 4))); + const uint scale = bl.block.scales[iqs / 16]; + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); + const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + + return float16_t(v[lsb]); +} +#endif + + +#if defined(DATA_A_IQ4_NL) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL { + block_iq4_nl block; +}; + +float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + float16_t ret = float16_t(kvalues_iq4nl[qs]) * d; + return ret; +} +#endif + +#if defined(DATA_A_Q4_0) +#define dequantFuncA dequantFuncQ4_0 +#elif defined(DATA_A_Q4_1) +#define dequantFuncA dequantFuncQ4_1 +#elif defined(DATA_A_Q5_0) +#define dequantFuncA dequantFuncQ5_0 +#elif defined(DATA_A_Q5_1) +#define dequantFuncA dequantFuncQ5_1 +#elif defined(DATA_A_Q8_0) +#define dequantFuncA dequantFuncQ8_0 +#elif defined(DATA_A_Q2_K) +#define dequantFuncA dequantFuncQ2_K +#elif defined(DATA_A_Q3_K) +#define dequantFuncA dequantFuncQ3_K +#elif defined(DATA_A_Q4_K) +#define dequantFuncA dequantFuncQ4_K +#elif defined(DATA_A_Q5_K) +#define dequantFuncA dequantFuncQ5_K +#elif defined(DATA_A_Q6_K) +#define dequantFuncA dequantFuncQ6_K +#elif defined(DATA_A_IQ2_XXS) +#define dequantFuncA dequantFuncIQ2_XXS +#elif defined(DATA_A_IQ2_XS) +#define dequantFuncA dequantFuncIQ2_XS +#elif defined(DATA_A_IQ2_S) +#define dequantFuncA dequantFuncIQ2_S +#elif defined(DATA_A_IQ3_XXS) +#define dequantFuncA dequantFuncIQ3_XXS +#elif defined(DATA_A_IQ3_S) +#define dequantFuncA dequantFuncIQ3_S +#elif defined(DATA_A_IQ4_NL) +#define dequantFuncA dequantFuncIQ4_NL +#endif diff --git a/ggml/src/vulkan-shaders/dequant_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp similarity index 100% rename from ggml/src/vulkan-shaders/dequant_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp new file mode 100644 index 000000000..48f6b65bc --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp @@ -0,0 +1,44 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq2_s data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * ib32; + + const float d = float(data_a[ib].d); + const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4); + const vec2 db = d * (0.5 + scale) * 0.25; + + uint qh = data_a[ib].qh[ib32]; + [[unroll]] for (uint l = 0; l < 4; ++l) { + uint qs = data_a[ib].qs[4 * ib32 + l]; + const uint8_t sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l]; + qs |= (qh << (8 - 2 * l)) & 0x300; + const uvec2 grid = iq2s_grid[qs & 511]; + const u8vec4 grid0 = unpack8(grid.x); + const u8vec4 grid1 = unpack8(grid.y); + data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign & 128) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp new file mode 100644 index 000000000..a08331c40 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp @@ -0,0 +1,43 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq2_xs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * ib32; + + const float d = float(data_a[ib].d); + const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4); + const vec2 db = d * (0.5 + scale) * 0.25; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + uint16_t qs = data_a[ib].qs[4 * ib32 + l]; + const uint sign7 = qs >> 9; + const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit + const uvec2 grid = iq2xs_grid[qs & 511]; + const u8vec4 grid0 = unpack8(grid.x); + const u8vec4 grid1 = unpack8(grid.y); + data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp new file mode 100644 index 000000000..e370690bc --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp @@ -0,0 +1,48 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq2_xxs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 scale block (32 values) + // Each block is described by 4 lattice indices, 4x7 sign bits and 4 scale bits + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint is = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * is; + + const float d = float(data_a[ib].d); + uint signscale = pack32(u8vec4( + data_a[ib].qs[8*is + 4], + data_a[ib].qs[8*is + 5], + data_a[ib].qs[8*is + 6], + data_a[ib].qs[8*is + 7] + )); + const float db = d * (0.5 + (signscale >> 28)) * 0.25; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7); + const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit + const uvec2 grid = iq2xxs_grid[data_a[ib].qs[8 * is + l]]; + const u8vec4 grid0 = unpack8(grid.x); + const u8vec4 grid1 = unpack8(grid.y); + data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp new file mode 100644 index 000000000..c3f4bca5d --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp @@ -0,0 +1,39 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq3_s data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 scale nibble. + // Each block contains 4 scale bytes (8 scales) for 256 output values. + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint is = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * is; + + const float d = float(data_a[ib].d); + const float db = d * (1 + 2 * ((data_a[ib].scales[is] >> (4 * (is % 2))) & 0xf)); + + // We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes. + uint qh = data_a[ib].qh[is]; + [[unroll]] for (uint l = 0; l < 8; ++l) { + uint qs = data_a[ib].qs[8 * is + l]; + uint gidx = qs | ((qh << (8 - l)) & 256); + uint8_t signs = data_a[ib].signs[8 * is + l / 2] >> (4 * (l & 1)); + u8vec4 grid = unpack8(iq3s_grid[gidx]); + data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 4 * l + 3] = D_TYPE(db * grid.w * ((signs & 8) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp new file mode 100644 index 000000000..a92b82961 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp @@ -0,0 +1,49 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq3_xxs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 scale block (32 values) + // 8 threads handle 1 superblock + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint is = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * is; + const uint s_idx = QUANT_K / 4 + 4 * is; + + const float d = float(data_a[ib].d); + uint signscale = pack32(u8vec4( + data_a[ib].qs[s_idx + 0], + data_a[ib].qs[s_idx + 1], + data_a[ib].qs[s_idx + 2], + data_a[ib].qs[s_idx + 3] + )); + const float db = d * (0.5 + (signscale >> 28)) * 0.5; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7); + // Restore parity bit. + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const u8vec4 grid0 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l]]); + const u8vec4 grid1 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l + 1]]); + data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ggml/src/vulkan-shaders/dequant_iq4_nl.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp similarity index 95% rename from ggml/src/vulkan-shaders/dequant_iq4_nl.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp index 34ef3da30..46d9ad15e 100644 --- a/ggml/src/vulkan-shaders/dequant_iq4_nl.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp @@ -10,6 +10,8 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + init_iq_shmem(gl_WorkGroupSize); + const uint tid = gl_LocalInvocationID.x % 64; const uint il = tid/32; const uint ir = tid%32; diff --git a/ggml/src/vulkan-shaders/dequant_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp similarity index 100% rename from ggml/src/vulkan-shaders/dequant_q2_k.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp diff --git a/ggml/src/vulkan-shaders/dequant_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp similarity index 100% rename from ggml/src/vulkan-shaders/dequant_q3_k.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp diff --git a/ggml/src/vulkan-shaders/dequant_q4_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp similarity index 100% rename from ggml/src/vulkan-shaders/dequant_q4_0.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp diff --git a/ggml/src/vulkan-shaders/dequant_q4_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp similarity index 100% rename from ggml/src/vulkan-shaders/dequant_q4_1.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp new file mode 100644 index 000000000..987f113a3 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp @@ -0,0 +1,68 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint ib = gl_WorkGroupID.x * 256 + wgy; + if (ib >= p.M * p.K / QUANT_K) { + return; + } + + const uint tid = gl_LocalInvocationID.x; + const uint il = tid / 8; + const uint ir = tid % 8; + const uint is = 2 * il; + const uint n = 4; + + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); + + const uint y_idx = ib * QUANT_K + 64 * il + n * ir; + const uint qs_idx = 32*il + n * ir; + + uint scidx0 = (is < 4) ? is : (is + 4); + uint scidx1 = (is < 4) ? is : (is - 4); + uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint scidxshift1 = (is < 4) ? 0 : 2; + uint mbidx0 = is + 4; + uint mbidx1 = (is < 4) ? is + 4 : is; + uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + uint mbidxshift0 = (is < 4) ? 0 : 4; + uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint mbidxshift1 = (is < 4) ? 0 : 2; + + uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d1 = dall * sc; + const FLOAT_TYPE m1 = dmin * mbyte; + + scidx0 = (is < 4) ? is + 1 : (is + 5); + scidx1 = (is < 4) ? is + 1 : (is - 3); + scidxmask1 = (is < 4) ? 0x30 : 0xC0; + scidxshift1 = (is < 4) ? 0 : 2; + mbidx0 = is + 5; + mbidx1 = (is < 4) ? is + 5 : is + 1; + mbidxmask0 = (is < 4) ? 0xF : 0xF0; + mbidxshift0 = (is < 4) ? 0 : 4; + mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + mbidxshift1 = (is < 4) ? 0 : 2; + + sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d2 = dall * sc; + const FLOAT_TYPE m2 = dmin * mbyte; + + [[unroll]] for (uint l = 0; l < n; ++l) { + data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] & 0xF) - m1); + data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] >> 4) - m2); + } + } +} diff --git a/ggml/src/vulkan-shaders/dequant_q5_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp similarity index 100% rename from ggml/src/vulkan-shaders/dequant_q5_0.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp diff --git a/ggml/src/vulkan-shaders/dequant_q5_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp similarity index 100% rename from ggml/src/vulkan-shaders/dequant_q5_1.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp new file mode 100644 index 000000000..6db5403b6 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp @@ -0,0 +1,70 @@ +#version 450 + +#include "dequant_head.comp" + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint ib = gl_WorkGroupID.x * 256 + wgy; + if (ib >= p.M * p.K / QUANT_K) { + return; + } + + const uint tid = gl_LocalInvocationID.x; + const uint il = tid / 16; + const uint ir = tid % 16; + const uint is = 2 * il; + + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); + + const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir; + const uint qs_idx = 32*il + 2 * ir; + const uint qh_idx = 2 * ir; + + uint scidx0 = (is < 4) ? is : (is + 4); + uint scidx1 = (is < 4) ? is : (is - 4); + uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint scidxshift1 = (is < 4) ? 0 : 2; + uint mbidx0 = is + 4; + uint mbidx1 = (is < 4) ? is + 4 : is; + uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + uint mbidxshift0 = (is < 4) ? 0 : 4; + uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint mbidxshift1 = (is < 4) ? 0 : 2; + + uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d1 = dall * sc; + const FLOAT_TYPE m1 = dmin * mbyte; + + scidx0 = (is < 4) ? is + 1 : (is + 5); + scidx1 = (is < 4) ? is + 1 : (is - 3); + scidxmask1 = (is < 4) ? 0x30 : 0xC0; + scidxshift1 = (is < 4) ? 0 : 2; + mbidx0 = is + 5; + mbidx1 = (is < 4) ? is + 5 : is + 1; + mbidxmask0 = (is < 4) ? 0xF : 0xF0; + mbidxshift0 = (is < 4) ? 0 : 4; + mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + mbidxshift1 = (is < 4) ? 0 : 2; + + sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d2 = dall * sc; + const FLOAT_TYPE m2 = dmin * mbyte; + + const uint8_t hm1 = uint8_t(1 << (2 * il )); + const uint8_t hm2 = uint8_t(1 << (2 * il + 1)); + data_b[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] & 0xF) + (((data_a[ib].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1); + data_b[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] & 0xF) + (((data_a[ib].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1); + data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] >> 4) + (((data_a[ib].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2); + data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] >> 4) + (((data_a[ib].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2); + } +} diff --git a/ggml/src/vulkan-shaders/dequant_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp similarity index 100% rename from ggml/src/vulkan-shaders/dequant_q6_k.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp diff --git a/ggml/src/vulkan-shaders/dequant_q8_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp similarity index 100% rename from ggml/src/vulkan-shaders/dequant_q8_0.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp diff --git a/ggml/src/vulkan-shaders/diag_mask_inf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp similarity index 91% rename from ggml/src/vulkan-shaders/diag_mask_inf.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp index 4e68742b5..26d8bc22a 100644 --- a/ggml/src/vulkan-shaders/diag_mask_inf.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp @@ -12,7 +12,7 @@ layout (push_constant) uniform parameter #include "types.comp" -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in; layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/div.comp b/ggml/src/ggml-vulkan/vulkan-shaders/div.comp new file mode 100644 index 000000000..9fb69c6c1 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/div.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp new file mode 100644 index 000000000..043a53023 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -0,0 +1,309 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable +#extension GL_NV_cooperative_matrix2 : enable +#extension GL_EXT_buffer_reference : enable +#extension GL_KHR_shader_subgroup_ballot : enable +#extension GL_KHR_shader_subgroup_vote : enable +#extension GL_EXT_null_initializer : enable + +#include "types.comp" +#include "dequant_funcs_cm2.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 1) const uint32_t Br = 32; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t D = 32; +layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV; + +layout (push_constant) uniform parameter { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t nb21; + uint32_t nb22; + uint32_t nb23; + uint32_t nb31; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask; + uint32_t n_head_log2; + float m0; + float m1; +} p; + +layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; +layout (binding = 1) readonly buffer K {uint8_t data_k[];}; +layout (binding = 2) readonly buffer V {uint8_t data_v[];}; +layout (binding = 3) readonly buffer M {uint8_t data_m[];}; +layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) { + return max(x, y); +} + +ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) { + return x; +} + +// Replace matrix elements >= numRows or numCols with 'replace' +ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) { + if (row >= numRows || col >= numCols) { + return replace; + } + return elem; +} + +ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem) +{ + return exp(elem); +} + +ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1) +{ + return max(elem0, elem1); +} + +#if defined(BLOCK_SIZE) +#define DECODEFUNC , DEQUANTFUNC +#else +#define DECODEFUNC +#endif + +void main() { +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) + init_iq_shmem(gl_WorkGroupSize); +#endif + + const uint32_t N = p.N; + const uint32_t KV = p.KV; + + const uint32_t Tr = CEIL_DIV(N, Br); + const uint32_t Tc = CEIL_DIV(KV, Bc); + + const uint32_t i = gl_WorkGroupID.x; + + const uint32_t iq2 = gl_WorkGroupID.y; + const uint32_t iq3 = gl_WorkGroupID.z; + + // broadcast factors + const uint32_t rk2 = p.neq2/p.nek2; + const uint32_t rk3 = p.neq3/p.nek3; + + const uint32_t rv2 = p.neq2/p.nev2; + const uint32_t rv3 = p.neq3/p.nev3; + + // k indices + const uint32_t ik3 = iq3 / rk3; + const uint32_t ik2 = iq2 / rk2; + + // v indices + const uint32_t iv3 = iq3 / rv3; + const uint32_t iv2 = iq2 / rv2; + + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp); + tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp); + + tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); + +#if defined(BLOCK_SIZE) + tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE); + tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); +#endif + + tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D); + tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); + tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); + + // nb?1 are already divided by the type size and are in units of elements + uint32_t q_stride = p.nb01; + uint32_t k_stride = p.nb11; + uint32_t v_stride = p.nb21; + // hint to the compiler that strides are aligned for the aligned variant of the shader + if (Clamp != gl_CooperativeMatrixClampModeConstantNV) + { + q_stride &= ~7; +#if !defined(BLOCK_SIZE) + k_stride &= ~7; + v_stride &= ~7; +#endif + } + tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1); + tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); + tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); + + coopmat Q; + coopmat Qf16; + + uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; + coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D)); + + Qf16 = coopmat(Q); + Qf16 *= float16_t(p.scale); + + coopmat O = coopmat(0); + + coopmat L, M; + + L = coopmat(0); + M = coopmat(-1.0/0.0); + + ACC_TYPE slope = ACC_TYPE(1.0); + + // ALiBi + if (p.max_bias > 0.0f) { + const uint32_t h = iq2; + + const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); + + slope = pow(base, ACC_TYPE(exph)); + } + + [[dont_unroll]] + for (uint32_t j = 0; j < Tc; ++j) { + + coopmat S = coopmat(0); + + coopmat K_T; + + uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC); + S = coopMatMulAdd(Qf16, K_T, S); + + if (p.logit_softcap != 0.0f) { + [[unroll]] + for (int k = 0; k < S.length(); ++k) { + S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]); + } + } + + if (p.mask != 0) { + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); + + coopmat mv; + + coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + + S += slope*coopmat(mv); + } + + // Clear padding elements to -inf, so they don't contribute to rowmax + if (Clamp != 0 && + ((j + 1) * Bc > KV || + (i + 1) * Br > N)) { + + uint R = ((i + 1) * Br > N) ? (N % Br) : Br; + uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; + + coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C); + } + + coopmat rowmax, P, rowsum, eM; + + coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce); + + coopmat Mold = M; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + coopMatPerElementNV(M, rowmax, Max, Mold); + coopMatPerElementNV(P, S - M, Exp); + coopMatPerElementNV(eM, Mold - M, Exp); + + // Clear padding elements to 0, so they don't contribute to rowsum + if (Clamp != 0 && + ((j + 1) * Bc > KV || + (i + 1) * Br > N)) { + + uint R = ((i + 1) * Br > N) ? (N % Br) : Br; + uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; + + coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C); + } + + coopmat P_A = coopmat(P); + + // compute rowsum by multiplying by matrix of all ones. + coopmat One = coopmat(1.0); + + rowsum = coopmat(0.0); + rowsum = coopMatMulAdd(P_A, One, rowsum); + + coopmat V; + uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC); + + L = eM*L + rowsum; + + // This is the "diagonal" matrix in the paper, but since we do componentwise + // multiply rather than matrix multiply it has the diagonal element smeared + // across the row + coopmat eMdiag; + + // resize eM by using smear/reduce + coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); + + O = eMdiag * O; + + O = coopMatMulAdd(P_A, V, O); + } + + coopmat Ldiag; + + // resize L by using smear/reduce + coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce); + + [[unroll]] + for (int k = 0; k < Ldiag.length(); ++k) { + Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k]; + } + + O = Ldiag*O; + + tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); + + // permute dimensions + tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); + uint32_t o_offset = iq3*p.ne2*p.ne1; + + coopmat O_D = coopmat(O); + coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute); +} diff --git a/ggml/src/vulkan-shaders/gelu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp similarity index 100% rename from ggml/src/vulkan-shaders/gelu.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp diff --git a/ggml/src/vulkan-shaders/gelu_quick.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp similarity index 100% rename from ggml/src/vulkan-shaders/gelu_quick.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp new file mode 100644 index 000000000..062e2a4cd --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp @@ -0,0 +1,64 @@ +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require + +layout (push_constant) uniform parameter +{ + uint ne; + uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; + uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23; + uint misalign_offsets; + float param1; float param2; int param3; +} p; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +// true if src0/src1 are the same shape and the indices can be reused without additional modulus +layout(constant_id = 0) const bool norepeat = false; + +uint get_idx() { + return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; +} + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; } +uint get_doffset() { return p.misalign_offsets & 0xFF; } + +// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1 +uint fastmod(uint a, uint b) { + if ((b & (b-1)) == 0) { + return a & (b-1); + } + return a % b; +} + +uint fastdiv(uint a, uint b) { + return (a < b) ? 0 : (a / b); +} + +void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) { + i03 = fastdiv(idx, (p.ne02*p.ne01*p.ne00)); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + i02 = fastdiv((idx - i03_offset), (p.ne01*p.ne00)); + const uint i02_offset = i02*p.ne01*p.ne00; + i01 = (idx - i03_offset - i02_offset) / p.ne00; + i00 = idx - i03_offset - i02_offset - i01*p.ne00; +} + +uint src0_idx(uint i00, uint i01, uint i02, uint i03) { + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; +} + +uint src1_idx(uint i00, uint i01, uint i02, uint i03) { + if (norepeat) { + return i03*p.nb13 + i02*p.nb12 + i01*p.nb11 + i00*p.nb10; + } else { + return fastmod(i03, p.ne13)*p.nb13 + fastmod(i02, p.ne12)*p.nb12 + fastmod(i01, p.ne11)*p.nb11 + fastmod(i00, p.ne10)*p.nb10; + } +} + +uint dst_idx(uint i00, uint i01, uint i02, uint i03) { + return i03*p.nb23 + i02*p.nb22 + i01*p.nb21 + i00*p.nb20; +} diff --git a/ggml/src/vulkan-shaders/generic_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp similarity index 100% rename from ggml/src/vulkan-shaders/generic_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp new file mode 100644 index 000000000..8dc9d360d --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp @@ -0,0 +1,76 @@ +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require + +layout (push_constant) uniform parameter +{ + uint ne; + uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; + uint misalign_offsets; + float param1; float param2; + + uint ne0_012mp; uint ne0_012L; + uint ne0_01mp; uint ne0_01L; + uint ne0_0mp; uint ne0_0L; + uint ne1_012mp; uint ne1_012L; + uint ne1_01mp; uint ne1_01L; + uint ne1_0mp; uint ne1_0L; +} p; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +uint get_idx() { + return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; +} + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + +uint src0_idx(uint idx) { + const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L); + const uint i02_offset = i02*p.ne01*p.ne00; + const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L); + const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; +} + +uint dst_idx(uint idx) { + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; +} + +uint src0_idx_quant(uint idx, uint qk) { + const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L); + const uint i02_offset = i02*p.ne01*p.ne00; + const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L); + const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + (i00/qk)*p.nb00; +} + +uint dst_idx_quant(uint idx, uint qk) { + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + (i10/qk)*p.nb10; +} diff --git a/ggml/src/vulkan-shaders/get_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp similarity index 61% rename from ggml/src/vulkan-shaders/get_rows.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp index e9ff22efa..e877ed779 100644 --- a/ggml/src/vulkan-shaders/get_rows.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp @@ -3,6 +3,8 @@ #include "types.comp" #include "generic_binary_head.comp" +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + void main() { const uint i00 = gl_GlobalInvocationID.x; const uint i10 = gl_GlobalInvocationID.y; @@ -13,10 +15,10 @@ void main() { return; } - const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; + const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; - const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03; - const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23; + const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03; + const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23; #ifndef OPTIMIZATION_ERROR_WORKAROUND data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]); diff --git a/ggml/src/vulkan-shaders/get_rows_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp similarity index 74% rename from ggml/src/vulkan-shaders/get_rows_quant.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp index 53a9a96f2..09dc43d8d 100644 --- a/ggml/src/vulkan-shaders/get_rows_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp @@ -4,12 +4,18 @@ #include "generic_binary_head.comp" #include "dequant_funcs.comp" +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + void main() { const uint i00 = (gl_GlobalInvocationID.x)*2; const uint i10 = gl_GlobalInvocationID.y; const uint i11 = (gl_GlobalInvocationID.z)/p.ne12; const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) + init_iq_shmem(gl_WorkGroupSize); +#endif + if (i00 >= p.ne00) { return; } @@ -25,6 +31,8 @@ void main() { const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; vec2 v = dequantize(ib, iqs, 0); + const vec2 dm = get_dm(ib, 0); + v = v * dm.x + dm.y; data_d[d_offset + iybs + iqs ] = D_TYPE(v.x); data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y); diff --git a/ggml/src/vulkan-shaders/group_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp similarity index 96% rename from ggml/src/vulkan-shaders/group_norm.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp index 5ad9b28da..b6a0d5645 100644 --- a/ggml/src/vulkan-shaders/group_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp @@ -19,7 +19,7 @@ void main() { const uint tid = gl_LocalInvocationID.x; const uint start = gl_WorkGroupID.x * group_size + tid; - const uint end = start + group_size; + const uint end = (gl_WorkGroupID.x + 1) * group_size; tmp[tid] = 0.0f; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp new file mode 100644 index 000000000..122b1e93f --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -0,0 +1,87 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_spirv_intrinsics: enable +#extension GL_EXT_control_flow_attributes : require + +#if RTE16 +spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits +#endif + +layout (push_constant) uniform parameter +{ + uint batch_offset; uint offset_delta; + uint IC; + uint IW; uint IH; + uint OW; uint OH; + uint KW; uint KH; + uint pelements; + uint CHW; + int s0; int s1; + int p0; int p1; + int d0; int d1; +} p; + +#include "types.comp" + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; + +const uint NUM_ITER = 512 / BLOCK_SIZE; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint gidx = gl_GlobalInvocationID.x; + + const uint oh = gl_GlobalInvocationID.y; + const uint batch = gl_GlobalInvocationID.z / p.IC; + const uint ic = gl_GlobalInvocationID.z % p.IC; + + A_TYPE values[NUM_ITER]; + uint offset_dst[NUM_ITER]; + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + values[idx] = A_TYPE(0); + } + + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + + const uint i = gidx * NUM_ITER + idx; + + const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1); + const uint kx = i / ksize; + const uint kd = kx * ksize; + const uint ky = (i - kd) / p.OW; + const uint ix = i % p.OW; + + const uint iiw = ix * p.s0 + kx * p.d0 - p.p0; + const uint iih = oh * p.s1 + ky * p.d1 - p.p1; + + offset_dst[idx] = + ((batch * p.OH + oh) * p.OW + ix) * p.CHW + + (ic * (p.KW * p.KH) + ky * p.KW + kx); + + if (i >= p.pelements) { + continue; + } + + if (iih < p.IH && iiw < p.IW) { + const uint offset_src = ic * p.offset_delta + batch * p.batch_offset; + values[idx] = data_a[offset_src + iih * p.IW + iiw]; + } + } + + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + + const uint i = gidx * NUM_ITER + idx; + + if (i >= p.pelements) { + continue; + } + + data_d[offset_dst[idx]] = D_TYPE(values[idx]); + } + +} diff --git a/ggml/src/vulkan-shaders/leaky_relu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp similarity index 100% rename from ggml/src/vulkan-shaders/leaky_relu.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp new file mode 100644 index 000000000..43de19df8 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "types.comp" +#include "generic_binary_head.comp" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp new file mode 100644 index 000000000..4c64fd47a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp @@ -0,0 +1,48 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float data_a[];}; +layout (binding = 0) readonly buffer A4 {vec4 data_a4[];}; +layout (binding = 1) writeonly buffer D {float data_d[];}; +layout (binding = 1) writeonly buffer D4 {vec4 data_d4[];}; + +layout (push_constant) uniform parameter { + uint ne; + uint k_num; +} p; + +void main() { + // Each invocation handles four consecutive components + const uint idx = gl_GlobalInvocationID.x * 4; + + if (idx >= p.ne) { + return; + } + + // Check if all four components are in bounds and aligned, + // then use vector loads + if (idx + 3 < p.ne && (p.ne % 4) == 0) { + vec4 result = vec4(0.0f); + + [[unroll]] for (uint i = 0; i < p.k_num; i++) { + result += data_a4[(i * p.ne + idx) / 4]; + } + + data_d4[idx / 4] = result; + } else { + [[unroll]] for (uint j = 0; j < 4; ++j) { + if (idx + j < p.ne) { + float result = 0.0f; + + [[unroll]] for (uint i = 0; i < p.k_num; i++) { + result += data_a[i * p.ne + idx + j]; + } + + data_d[idx + j] = result; + } + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp new file mode 100644 index 000000000..48156e7ba --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -0,0 +1,149 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +#if !defined(DATA_A_F32) && !defined(DATA_A_F16) +#define K_PER_ITER 8 +#else +#define K_PER_ITER 2 +#endif + + +uint a_offset, b_offset, d_offset, y_offset; + +void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter) +{ + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const uint col = i*BLOCK_SIZE + K_PER_ITER*tid; + const uint iqs = (col%QUANT_K)/QUANT_R; // quant index + const uint iybs = col - col%QUANT_K; // y block start index + +#if K_PER_ITER == 8 +#if QUANT_R == 2 + const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); + const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]); + const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y); + const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w); +#else + const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); + const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]); +#endif +#else + // Check if the second of the pair of elements is OOB, and don't fetch B or + // accumulate it. We still fetch a pair of elements for A, which is fine for + // quantized formats since they'll be within the same block. We should + // probably skip fetching the second element for F16/F32, but as of now we + // still do. + const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols); + + FLOAT_TYPE b0 = 0, b1 = 0; + b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]); + if (!OOB) { + b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]); + } +#endif + uint ibi = first_row*p.ncols; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib = (ibi + col)/QUANT_K; // block index + ibi += p.ncols; + +#if K_PER_ITER == 8 + vec4 v = dequantize4(ib, iqs, a_offset); + vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset); + + const vec2 dm = get_dm(ib, a_offset); + if (dm.y != 0) { // quant has min component + v = v * dm.x + dm.y; + v2 = v2 * dm.x + dm.y; + } + + // matrix multiplication + FLOAT_TYPE rowtmp = dot(bv0, v); + rowtmp += dot(bv1, v2); + + if (dm.y == 0) + rowtmp *= dm.x; + + temp[j][n] += rowtmp; +#else + const vec2 v = dequantize(ib, iqs, a_offset); + + // matrix multiplication + temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]); + if (!OOB) { + temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]); + } +#endif + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + const uint tid = gl_LocalInvocationID.x; + + get_offsets(a_offset, b_offset, d_offset); + a_offset /= QUANT_K; + + y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; + + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE); + if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) { + num_iters++; + } + int unroll_count = 4; + uint unrolled_iters = num_iters & ~(unroll_count - 1); + + uint i = 0; + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false); + i++; + } + } + unroll_count = 2; + unrolled_iters = num_iters & ~(unroll_count - 1); + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false); + i++; + } + } + while (i < num_iters) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true); + i++; + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) + init_iq_shmem(gl_WorkGroupSize); +#endif + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp new file mode 100644 index 000000000..903753c7e --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp @@ -0,0 +1,118 @@ +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_shader_8bit_storage : require + +#ifdef MUL_MAT_ID +#define EXPERT_COUNT 8 +#endif + +#include "types.comp" + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];}; +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; + +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; +#endif + +#include "dequant_funcs.comp" + +layout (push_constant) uniform parameter +{ + uint ncols; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint ne11; +#else + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif +} p; + +void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.y; +#else + const uint batch_idx = gl_GlobalInvocationID.y; +#endif + +#ifndef MUL_MAT_ID + uint batch_idx_a = 0; + if (batch_idx != 0) { + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + batch_idx_a = i03 * p.ne02 + i02; + } +#else + const uint expert_id = data_ids[expert_idx]; +#endif + + a_offset = +#ifdef MUL_MAT_ID + expert_id * p.batch_stride_a; +#else + batch_idx_a * p.batch_stride_a; +#endif + b_offset = +#ifdef MUL_MAT_ID + (expert_idx % p.ne11) * p.stride_b; +#else + batch_idx * p.batch_stride_b; +#endif + d_offset = +#ifdef MUL_MAT_ID + expert_idx * p.stride_d; +#else + batch_idx * p.batch_stride_d; +#endif +} + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; +layout (constant_id = 1) const uint NUM_ROWS = 1; +layout (constant_id = 2) const uint NUM_COLS = 1; + +shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE]; + +void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) { + // sum up partial sums and write back result + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[j][n][tid] = temp[j][n]; + } + } + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[j][n][tid] += tmpsh[j][n][tid + s]; + } + } + } + barrier(); + } + if (tid == 0) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]); + } + } + } +} diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_nc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp similarity index 100% rename from ggml/src/vulkan-shaders/mul_mat_vec_nc.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_p021.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp similarity index 100% rename from ggml/src/vulkan-shaders/mul_mat_vec_p021.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp new file mode 100644 index 000000000..8cdc640e8 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -0,0 +1,129 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16]; +shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + + barrier(); + if (!all_threads) { // when we don't have enough blocks to use all threads + if (i < num_blocks_per_row) { + const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); + sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF); + sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); + } + barrier(); + + if (i >= num_blocks_per_row) + continue; + } else { + const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); + sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF); + sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); + barrier(); + } + + const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); + vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]); + vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); + vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); + vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); + vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); + vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); + vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); + + FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); + FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[ix][ 8*v_im] * qs_u32_0[l ], + fma(FLOAT_TYPE(b16[l]), sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2], + fma(FLOAT_TYPE(b32[l]), sccache1[ix][2 + 8*v_im] * qs_u32_2[l ], + fma(FLOAT_TYPE(b48[l]), sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2], + fma(FLOAT_TYPE(b64[l]), sccache1[ix][4 + 8*v_im] * qs_u32_4[l ], + fma(FLOAT_TYPE(b80[l]), sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2], + fma(FLOAT_TYPE(b96[l]), sccache1[ix][6 + 8*v_im] * qs_u32_6[l ], + fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1)))))))); + sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[ix][ 8*v_im], + fma(FLOAT_TYPE(b16[l]), sccache2[ix][1 + 8*v_im], + fma(FLOAT_TYPE(b32[l]), sccache2[ix][2 + 8*v_im], + fma(FLOAT_TYPE(b48[l]), sccache2[ix][3 + 8*v_im], + fma(FLOAT_TYPE(b64[l]), sccache2[ix][4 + 8*v_im], + fma(FLOAT_TYPE(b80[l]), sccache2[ix][5 + 8*v_im], + fma(FLOAT_TYPE(b96[l]), sccache2[ix][6 + 8*v_im], + fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2)))))))); + } + temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = itid - 8*v_im; // 0...7 + + const uint l0 = 2*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 128*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp new file mode 100644 index 000000000..3116fad16 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -0,0 +1,132 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + + if (!all_threads) { // when we don't have enough blocks to use all threads + barrier(); + if (i < num_blocks_per_row) + sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); + barrier(); + + if (i >= num_blocks_per_row) + continue; + } + + const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16)); + const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> ( v_im4)) << 2)); + const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2)); + const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2)); + const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2)); + + // 0, 1, 16, 17 + uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8); + qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16; + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + + if (all_threads) { + barrier(); + sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); + barrier(); + } + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); + vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]); + vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); + vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); + vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); + vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); + vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); + vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - hmk_0[l ], + fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2], + fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - hmk_1[l ], + fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2], + fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - hmk_2[l ], + fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2], + fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - hmk_3[l ], + fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum)))))))); + } + temp[j][n] = fma(d, sum, temp[j][n]); + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + const uint itid8 = itid%8; + + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_im4 = v_im*4; + const uint v_in = itid - 8*v_im; // 0...7 + + const uint32_t m = 0x01010101 << (4 * v_im); + uint32_t hm_m[4]; + [[unroll]] for (uint j = 0; j < 4; ++j) + hm_m[j] = m << j; + + const uint l0 = 2*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 128*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + const uint s_shift = v_im4 + 2*(itid8/4); + + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, false); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp new file mode 100644 index 000000000..f9cde0648 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -0,0 +1,136 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + + const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; + const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); + const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); + + const FLOAT_TYPE sc0 = scale_0_4_l_f.x; + const FLOAT_TYPE sc1 = scale_0_4_l_f.y; + const FLOAT_TYPE sc2 = scale_0_4_l_f.z; + const FLOAT_TYPE sc3 = scale_0_4_l_f.w; + const FLOAT_TYPE sc4 = scale8_f.x; + const FLOAT_TYPE sc5 = scale8_f.y; + const FLOAT_TYPE sc6 = scale8_f.z; + const FLOAT_TYPE sc7 = scale8_f.w; + + const uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; + const uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; + + const uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; + const uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F; + const uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; + const uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; + + const vec4 qs0_lo4 = vec4(unpack8(qs0_u32_lo4)); + const vec4 qs64_lo4 = vec4(unpack8(qs64_u32_lo4)); + const vec4 qs0_hi4 = vec4(unpack8(qs0_u32_hi4)); + const vec4 qs64_hi4 = vec4(unpack8(qs64_u32_hi4)); + + const FLOAT_TYPE q4_0 = qs0_lo4.x; + const FLOAT_TYPE q4_1 = qs0_lo4.y; + const FLOAT_TYPE q4_2 = qs0_lo4.z; + const FLOAT_TYPE q4_3 = qs0_lo4.w; + const FLOAT_TYPE q4_4 = qs0_hi4.x; + const FLOAT_TYPE q4_5 = qs0_hi4.y; + const FLOAT_TYPE q4_6 = qs0_hi4.z; + const FLOAT_TYPE q4_7 = qs0_hi4.w; + const FLOAT_TYPE q4_8 = qs64_lo4.x; + const FLOAT_TYPE q4_9 = qs64_lo4.y; + const FLOAT_TYPE q4_10 = qs64_lo4.z; + const FLOAT_TYPE q4_11 = qs64_lo4.w; + const FLOAT_TYPE q4_12 = qs64_hi4.x; + const FLOAT_TYPE q4_13 = qs64_hi4.y; + const FLOAT_TYPE q4_14 = qs64_hi4.z; + const FLOAT_TYPE q4_15 = qs64_hi4.w; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]); + vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]); + vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]); + vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]); + + const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3))); + const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7))); + const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); + const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, + fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, + fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, + fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + + const uint il = itid/4; // 0...3 + const uint ir = itid - 4*il; // 0...3 + const uint n = 4; + + const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const uint v_in = il % 2; + + const uint l0 = n * (2 * ir + v_in); // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 64*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) + calc_superblock(a_offset, b_offset, v_im, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp new file mode 100644 index 000000000..6c84ef3cd --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -0,0 +1,167 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint l0, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + + const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; + const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); + const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); + + const FLOAT_TYPE sc0 = scale_0_4_l_f.x; + const FLOAT_TYPE sc1 = scale_0_4_l_f.y; + const FLOAT_TYPE sc2 = scale_0_4_l_f.z; + const FLOAT_TYPE sc3 = scale_0_4_l_f.w; + const FLOAT_TYPE sc4 = scale8_f.x; + const FLOAT_TYPE sc5 = scale8_f.y; + const FLOAT_TYPE sc6 = scale8_f.z; + const FLOAT_TYPE sc7 = scale8_f.w; + + const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); + + uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F; + uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F; + uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; + uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; + + const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); + + const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; + const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; + const uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010); + const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; + + qs0_16_u32_lo4 += qs0_16_lo4_offset16; + qs0_16_u32_hi4 += qs0_16_hi4_offset16; + qs64_80_u32_lo4 += qs64_80_lo4_offset16; + qs64_80_u32_hi4 += qs64_80_hi4_offset16; + + const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4)); + const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4)); + const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4)); + const vec4 qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4)); + + const FLOAT_TYPE q4_0 = qs0_16_lo4.x; + const FLOAT_TYPE q4_1 = qs0_16_lo4.y; + const FLOAT_TYPE q4_2 = qs0_16_lo4.z; + const FLOAT_TYPE q4_3 = qs0_16_lo4.w; + const FLOAT_TYPE q4_4 = qs0_16_hi4.x; + const FLOAT_TYPE q4_5 = qs0_16_hi4.y; + const FLOAT_TYPE q4_6 = qs0_16_hi4.z; + const FLOAT_TYPE q4_7 = qs0_16_hi4.w; + const FLOAT_TYPE q4_8 = qs64_80_lo4.x; + const FLOAT_TYPE q4_9 = qs64_80_lo4.y; + const FLOAT_TYPE q4_10 = qs64_80_lo4.z; + const FLOAT_TYPE q4_11 = qs64_80_lo4.w; + const FLOAT_TYPE q4_12 = qs64_80_hi4.x; + const FLOAT_TYPE q4_13 = qs64_80_hi4.y; + const FLOAT_TYPE q4_14 = qs64_80_hi4.z; + const FLOAT_TYPE q4_15 = qs64_80_hi4.w; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]); + vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]); + vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]); + vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]); + vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]); + vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]); + vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]); + vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]); + + const FLOAT_TYPE sx = + fma(FLOAT_TYPE(by10.x), q4_0, + fma(FLOAT_TYPE(by10.y), q4_1, + fma(FLOAT_TYPE(by116.x), q4_2, + FLOAT_TYPE(by116.y) * q4_3))); + const FLOAT_TYPE sy = + fma(FLOAT_TYPE(by132.x), q4_4, + fma(FLOAT_TYPE(by132.y), q4_5, + fma(FLOAT_TYPE(by148.x), q4_6, + FLOAT_TYPE(by148.y) * q4_7))); + const FLOAT_TYPE sz = + fma(FLOAT_TYPE(by20.x), q4_8, + fma(FLOAT_TYPE(by20.y), q4_9, + fma(FLOAT_TYPE(by216.x), q4_10, + FLOAT_TYPE(by216.y) * q4_11))); + const FLOAT_TYPE sw = + fma(FLOAT_TYPE(by232.x), q4_12, + fma(FLOAT_TYPE(by232.y), q4_13, + fma(FLOAT_TYPE(by248.x), q4_14, + FLOAT_TYPE(by248.y) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, + fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, + fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, + (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + + const uint il = itid/4; // 0...3 + const uint ir = itid - 4*il; // 0...3 + + const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const uint v_in = il % 2; + + const uint l0 = 4*ir + 2*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 64*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) + calc_superblock(a_offset, b_offset, v_im, l0, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp new file mode 100644 index 000000000..f05f96b5e --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -0,0 +1,130 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + + if (!all_threads) { // when we don't have enough blocks to use all threads + barrier(); + if (i < num_blocks_per_row) + sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + barrier(); + + if (i >= num_blocks_per_row) + continue; + } + + const uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); + const uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); + + const uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; + const uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; + const uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; + const uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; + + const uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); + const uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; + const uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; + const uint32_t qh4_u32 = (qh_u32 & 0x30303030); + const uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; + + const uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; + const uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; + const uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; + const uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; + + const vec4 q0 = vec4(unpack8(q0_u32)) - 32; + const vec4 q1 = vec4(unpack8(q1_u32)) - 32; + const vec4 q2 = vec4(unpack8(q2_u32)) - 32; + const vec4 q3 = vec4(unpack8(q3_u32)) - 32; + + if (all_threads) { + barrier(); + sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + barrier(); + } + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]); + vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]); + vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]); + vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]); + + FLOAT_TYPE sum[4] = {0, 0, 0, 0}; + [[unroll]] for (uint l = 0; l < 4; ++l) { + sum[0] = fma(FLOAT_TYPE(by0[l]), q0[l], sum[0]); + sum[1] = fma(FLOAT_TYPE(by32[l]), q1[l], sum[1]); + sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]); + sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]); + } + temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]); + } + } +} + +void compute_outputs(const uint first_row, const uint num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = itid - 8*v_im; // 0...7 + + const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28 + const uint is = v_in / 4; + + const uint ql_offset = 64*v_im + l0; + const uint qh_offset = 32*v_im + l0; + const uint s_offset = 8*v_im + is; + const uint y_offset = 128*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ggml/src/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp similarity index 53% rename from ggml/src/vulkan-shaders/mul_mm.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index fffdd1818..d0559aac8 100644 --- a/ggml/src/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -7,6 +7,12 @@ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #endif +#ifdef COOPMAT +#extension GL_KHR_cooperative_matrix : enable +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_shader_subgroup_basic : enable +#endif + #ifdef MUL_MAT_ID #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #endif @@ -57,6 +63,7 @@ layout (push_constant) uniform parameter #endif } p; +layout (constant_id = 0) const uint BLOCK_SIZE = 64; layout (constant_id = 1) const uint BM = 64; layout (constant_id = 2) const uint BN = 64; layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant @@ -65,16 +72,33 @@ layout (constant_id = 5) const uint WN = 32; layout (constant_id = 6) const uint WMITER = 2; layout (constant_id = 7) const uint TM = 4; layout (constant_id = 8) const uint TN = 2; -layout (constant_id = 9) const uint WARP = 32; +layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat +layout (constant_id = 10) const uint WARP = 32; -shared FLOAT_TYPE buf_a[BM * (BK+1)]; -shared FLOAT_TYPE buf_b[BN * (BK+1)]; +#ifdef COOPMAT +#define SHMEM_STRIDE (BK + 8) +#else +#define SHMEM_STRIDE (BK + 1) +#endif + +shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE]; +shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; #ifdef MUL_MAT_ID shared u16vec2 row_ids[3072]; +#endif // MUL_MAT_ID + +#define NUM_WARPS (BLOCK_SIZE / WARP) + +#ifdef COOPMAT +shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; #endif void main() { +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) + init_iq_shmem(gl_WorkGroupSize); +#endif + #ifdef MUL_MAT_ID const uint expert_idx = gl_GlobalInvocationID.z; #else @@ -94,17 +118,32 @@ void main() { const uint ik = gl_WorkGroupID.x / blocks_m; const uint ic = gl_WorkGroupID.y; - const uint warp_i = gl_LocalInvocationID.x / WARP; - const uint warp_r = warp_i % (BM / WM); - const uint warp_c = warp_i / (BM / WM); - const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); const uint WSUBM = WM / WMITER; const uint WSUBN = WN / WNITER; +#ifdef COOPMAT + const uint warp_i = gl_SubgroupID; + + const uint tiw = gl_SubgroupInvocationID; + + const uint cms_per_row = WM / TM; + const uint cms_per_col = WN / TN; + + const uint storestride = WARP / TM; + const uint store_r = tiw % TM; + const uint store_c = tiw / TM; +#else + const uint warp_i = gl_LocalInvocationID.x / WARP; + const uint tiw = gl_LocalInvocationID.x % WARP; + const uint tiwr = tiw % (WSUBM / TM); const uint tiwc = tiw / (WSUBM / TM); +#endif + + const uint warp_r = warp_i % (BM / WM); + const uint warp_c = warp_i / (BM / WM); const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); @@ -152,21 +191,31 @@ void main() { uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; #endif - float sums[WMITER * TM * WNITER * TN]; +#ifdef COOPMAT + coopmat cache_a; + coopmat cache_b; + coopmat sums[cms_per_row * cms_per_col]; + + [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { + sums[i] = coopmat(0.0f); + } +#else + ACC_TYPE sums[WMITER * TM * WNITER * TN]; FLOAT_TYPE cache_a[WMITER * TM]; FLOAT_TYPE cache_b[WNITER * TN]; [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { - sums[i] = 0.0f; + sums[i] = ACC_TYPE(0.0f); } +#endif - [[unroll]] for (uint block = start_k; block < end_k; block += BK) { + for (uint block = start_k; block < end_k; block += BK) { [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) { #if defined(DATA_A_F32) || defined(DATA_A_F16) #if LOAD_VEC_A == 8 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x); buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y); buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z); @@ -177,21 +226,21 @@ void main() { buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w); #elif LOAD_VEC_A == 4 const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x); buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y); buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z); buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w); #else if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { - buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); + buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); } else { - buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f); + buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f); } #endif #elif defined(DATA_A_Q4_0) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; const uint ib = idx / 16; const uint iqs = idx & 0xF; @@ -204,7 +253,7 @@ void main() { buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); #elif defined(DATA_A_Q4_1) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; const uint ib = idx / 16; const uint iqs = idx & 0xF; @@ -218,7 +267,7 @@ void main() { buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); #elif defined(DATA_A_Q5_0) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; const uint ib = idx / 16; const uint iqs = idx & 0xF; @@ -233,7 +282,7 @@ void main() { buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); #elif defined(DATA_A_Q5_1) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; const uint ib = idx / 16; const uint iqs = idx & 0xF; @@ -249,7 +298,7 @@ void main() { buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); #elif defined(DATA_A_Q8_0) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; const uint ib = idx / 16; const uint iqs = (idx & 0xF) * 2; @@ -261,7 +310,7 @@ void main() { buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); #elif defined(DATA_A_Q2_K) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; const uint ib = idx / 128; // 2 values per idx const uint iqs = idx % 128; // 0..127 @@ -280,7 +329,7 @@ void main() { buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); #elif defined(DATA_A_Q3_K) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; const uint ib = idx / 128; // 2 values per idx const uint iqs = idx % 128; // 0..127 @@ -294,17 +343,15 @@ void main() { const uint qsshift = halfsplit * 2; // 0,2,4,6 const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 - const int8_t us = int8_t(is < 4 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+8] >> 0) & 3) << 4) : - is < 8 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+4] >> 2) & 3) << 4) : - is < 12 ? (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is+0] >> 4) & 3) << 4) : - (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is-4] >> 6) & 3) << 4)); + const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF) + | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4)); const float dl = float(data_a[ib].d) * float(us - 32); buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4))); buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); #elif defined(DATA_A_Q4_K) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; const uint ib = idx / 128; // 2 values per idx const uint iqs = idx % 128; // 0..127 @@ -316,15 +363,20 @@ void main() { const vec2 loadd = vec2(data_a[ib].d); - uint8_t sc; - uint8_t mbyte; - if (is < 4) { - sc = uint8_t(data_a[ib].scales[is ] & 63); - mbyte = uint8_t(data_a[ib].scales[is + 4] & 63); - } else { - sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4)); - mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4)); - } + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + const float d = loadd.x * sc; const float m = -loadd.y * mbyte; @@ -332,7 +384,7 @@ void main() { buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); #elif defined(DATA_A_Q5_K) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; const uint ib = idx / 128; // 2 values per idx const uint iqs = idx % 128; // 0..127 @@ -347,15 +399,20 @@ void main() { const vec2 loadd = vec2(data_a[ib].d); - uint8_t sc; - uint8_t mbyte; - if (is < 4) { - sc = uint8_t(data_a[ib].scales[is ] & 63); - mbyte = uint8_t(data_a[ib].scales[is + 4] & 63); - } else { - sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4)); - mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4)); - } + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + const float d = loadd.x * sc; const float m = -loadd.y * mbyte; @@ -363,7 +420,7 @@ void main() { buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); #elif defined(DATA_A_Q6_K) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; const uint ib = idx / 128; // 2 values per idx const uint iqs = idx % 128; // 0..127 @@ -380,9 +437,121 @@ void main() { buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32)); buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); +#elif defined(DATA_A_IQ2_XXS) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib32 = (idx % 128) / 16; // 0..7 + const uint ib8 = (idx / 4) % 4; + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[8 * ib32 + ib8]; + const uint signs = pack32(u8vec4( + data_a[ib].qs[8*ib32 + 4], + data_a[ib].qs[8*ib32 + 5], + data_a[ib].qs[8*ib32 + 6], + data_a[ib].qs[8*ib32 + 7] + )); + const float db = d * 0.25 * (0.5 + (signs >> 28)); + const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); + const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ2_XS) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib32 = (idx % 128) / 16; // 0..7 + const uint ib8 = (idx / 4) % 4; // 0..3 + + const float d = float(data_a[ib].d); + const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; + const float db = d * 0.25 * (0.5 + scale); + const uint qs = data_a[ib].qs[4 * ib32 + ib8]; + const uint sign7 = qs >> 9; + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); + const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ2_S) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint ib8 = (idx % 128) / 4; // 0..31 + const uint ib32 = ib8 / 4; // 0..7 + + const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; + const uint qs = data_a[ib].qs[ib8]; + const uint qh = data_a[ib].qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4)); + + const float d = float(data_a[ib].d); + const float db = d * 0.25 * (0.5 + scale); + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); + const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid)); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ3_XXS) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = (idx % 128) / 2; // 0..63 + const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[iqs]; + const uint signs = pack32(u8vec4( + data_a[ib].qs[is+0], + data_a[ib].qs[is+1], + data_a[ib].qs[is+2], + data_a[ib].qs[is+3] + )); + const float db = d * 0.5 * (0.5 + (signs >> 28)); + const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); + const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); +#elif defined(DATA_A_IQ3_S) + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = (idx % 128) / 2; // 0..63 + const uint iqh = iqs / 8; + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[iqs]; + const uint qh = data_a[ib].qh[iqh]; + const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4))); + const uint scale = data_a[ib].scales[iqs / 16]; + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); + const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + + buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); #elif defined(DATA_A_IQ4_NL) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; const uint ib = idx / 16; const uint iqs = idx & 0xF; @@ -403,7 +572,7 @@ void main() { #else const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; #endif - const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B; + const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z); @@ -419,24 +588,24 @@ void main() { #else const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; #endif - const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B; + const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x); buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y); buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z); buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w); #elif !MUL_MAT_ID if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { - buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); } else { - buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f); + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); } #else const uint row_i = ic * BN + loadc_b + l; if (row_i < _ne1) { const u16vec2 row_idx = row_ids[row_i]; - buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); } else { - buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f); + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); } #endif } @@ -446,16 +615,30 @@ void main() { pos_a += BK / LOAD_VEC_A; pos_b += BK / LOAD_VEC_B; - for (uint i = 0; i < BK; i++) { +#ifdef COOPMAT + [[unroll]] for (uint i = 0; i < BK; i += TK) { + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + // Load from shared into cache + coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); + + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); + + sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]); + } + } + } +#else + [[unroll]] for (uint i = 0; i < BK; i++) { // Load from shared into cache [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint j = 0; j < TM; j++) { - cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i]; + cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i]; } } [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint j = 0; j < TN; j++) { - cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i]; + cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; } } @@ -464,12 +647,13 @@ void main() { [[unroll]] for (uint cc = 0; cc < TN; cc++) { [[unroll]] for (uint cr = 0; cr < TM; cr++) { const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; - sums[sums_idx] = fma(float(cache_a[wsir * TM + cr]), float(cache_b[wsic * TN + cc]), sums[sums_idx]); + sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]); } } } } } +#endif barrier(); } @@ -481,6 +665,54 @@ void main() { const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; #endif +#ifdef COOPMAT +#ifdef MUL_MAT_ID + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < BN; col += storestride) { + const uint row_i = dc + cm_col * TN + col + store_c; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i]; + + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } +#else + const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float + + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; + + if (is_aligned && is_in_bounds) { + // Full coopMat is within bounds and stride_d is aligned with 16B + coopmat cm_dtype = coopmat(sums[cm_col * cms_per_row + cm_row]); + coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); + } else if (is_in_bounds) { + // Full coopMat is within bounds, but stride_d is not aligned + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { + // Partial coopMat is within bounds + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } + } + } +#endif // MUL_MAT_ID +#else [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { @@ -492,7 +724,7 @@ void main() { if (row_i >= _ne1) break; const u16vec2 row_idx = row_ids[row_i]; -#endif +#endif // MUL_MAT_ID [[unroll]] for (uint cr = 0; cr < TM; cr++) { #ifdef MUL_MAT_ID data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); @@ -500,9 +732,10 @@ void main() { if (dr_warp + cr < p.M && dc_warp + cc < p.N) { data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); } -#endif +#endif // MUL_MAT_ID } } } } +#endif // COOPMAT } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp new file mode 100644 index 000000000..27c5d68b3 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -0,0 +1,311 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable +#extension GL_NV_cooperative_matrix2 : enable +#extension GL_EXT_buffer_reference : enable +#extension GL_KHR_shader_subgroup_ballot : enable +#extension GL_KHR_shader_subgroup_vote : enable + +#include "types.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 1) const uint BM = 64; +layout (constant_id = 2) const uint BN = 64; +layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant + +layout (push_constant) uniform parameter +{ + uint M; + uint N; + uint K; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint nei1; + uint nbi1; + uint ne11; +#else + uint k_split; + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif +} p; + + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +#if QUANT_K > 1 +#define DECODEFUNCA , dequantFuncA + +#include "dequant_funcs_cm2.comp" + +#else +#define DECODEFUNCA +#endif + +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; + +shared u16vec4 row_ids[3072]; + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { + B_TYPE b[]; +}; + +uint _ne1; +shared uint _ne1_sh; + +B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint row_i = blockCoords[0]; + + if (row_i >= _ne1) { + return B_TYPE(0.0); + } + + const u16vec4 row_idx = row_ids[row_i]; + B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]]; + + return ret; +} + +D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic) +{ + uint dr = ir * BM + r; + uint dc = ic * BN + c; + + if (dr < p.M && dc < _ne1) { + uint row_i = dc; + const u16vec4 row_idx = row_ids[row_i]; + data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem; + } + return elem; +} + +#endif + +void main() { +#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) + init_iq_shmem(gl_WorkGroupSize); +#endif + +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.z; +#else + const uint batch_idx = gl_GlobalInvocationID.z; + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; +#endif + + const uint blocks_m = (p.M + BM - 1) / BM; + const uint ir = gl_WorkGroupID.x % blocks_m; + const uint ik = gl_WorkGroupID.x / blocks_m; + const uint ic = gl_WorkGroupID.y; + +#ifdef MUL_MAT_ID + // Spread the search across all elements in the first subgroup + if (gl_SubgroupID == 0) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + + for (uint i = gl_SubgroupInvocationID; subgroupAny(i < num_elements); i += gl_SubgroupSize) { + bool in_range = i < num_elements; + uint ii0 = i % p.nei0; + uint ii1 = i / p.nei0; + uint id = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + uint idx = subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx) { + row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0); + } + _ne1 += subgroupBallotBitCount(ballot); + } + _ne1_sh = _ne1; + } + + barrier(); + + _ne1 = _ne1_sh; + + // Workgroup has no work + if (ic * BN >= _ne1) return; +#endif + +#ifdef MUL_MAT_ID + uint start_k = 0; + const uint end_k = p.K; +#else + uint start_k = ik * p.k_split; + const uint end_k = min(p.K, (ik + 1) * p.k_split); +#endif + + coopmat sum; + sum = coopmat(0.0); + +#ifdef MUL_MAT_ID + uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K; + uint pos_b = 0; +#else + uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K; + uint pos_b = batch_idx * p.batch_stride_b; +#endif + + uint stride_a = p.stride_a / QUANT_K; + uint stride_b = p.stride_b; + + // Hint to the compiler that values are aligned (want 16B alignment). + // Quants are always block-aligned, no alignment needed. +#if ALIGNED +#if QUANT_K == 1 + stride_a &= ~7; +#endif + stride_b &= ~7; +#endif + + // Create layouts for both clamped and unclamped accesses + tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2); + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutAClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2); + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + +#if QUANT_K > 1 + tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K); + tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K); +#endif + + // Use end_k rather than p.K as the dimension because that's what + // we need to bound check against when using split_k + tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k); + tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.N, end_k); + tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M); + tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k); + tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.N, end_k); + + tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); + +#if !defined(MUL_MAT_ID) + // Detect a fast path where all loads are entirely in bounds and no clamping is required + if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.N && (start_k % BK) == 0 && (end_k % BK) == 0 && +#if QUANT_K == 1 + (stride_a % 8) == 0 && +#endif + (stride_b % 8) == 0 && (start_k % 8) == 0) { + // Hint to the compiler that values are aligned (want 16B alignment) + start_k &= ~7; + stride_b &= ~7; +#if QUANT_K == 1 + stride_a &= ~7; +#endif + + tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); + tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); + + uint k_iters = (end_k - start_k + BK - 1) / BK; + + for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } + } else +#endif // !defined(MUL_MAT_ID) + { + tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); + + tensorLayoutAClamp = setTensorLayoutStrideNV(tensorLayoutAClamp, stride_a, 1); + + tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); + + tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1); + + [[dont_unroll]] + for (uint block_k = start_k; block_k < end_k; block_k += BK) { + + coopmat mat_a; + coopmat mat_b; + + // Clamping is expensive, so detect different code paths for each combination + // of A and B needing clamping. + bool unclampedA = (ir + 1) * BM <= p.M && block_k + BK <= end_k && (block_k % 8) == 0; +#ifdef MUL_MAT_ID + bool unclampedB = true; +#else + bool unclampedB = (ic + 1) * BN <= p.N && block_k + BK <= end_k && (block_k % 8) == 0; +#endif + if (unclampedA && unclampedB) { + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); +#ifdef MUL_MAT_ID + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); +#else + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); +#endif + sum = coopMatMulAdd(mat_a, mat_b, sum); + } else if (unclampedA && !unclampedB) { + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } else if (!unclampedA && unclampedB) { + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); +#ifdef MUL_MAT_ID + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); +#else + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); +#endif + sum = coopMatMulAdd(mat_a, mat_b, sum); + } else if (!unclampedA && !unclampedB) { + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } + } + } + + // Convert from ACC_TYPE to D_TYPE + coopmat mat_d; + mat_d = coopmat(sum); + +#ifdef MUL_MAT_ID + // Call callback to store each element, remapping row through shared memory + coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); +#else + tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1); + + uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); +#endif +} diff --git a/ggml/src/vulkan-shaders/norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp similarity index 100% rename from ggml/src/vulkan-shaders/norm.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/norm.comp diff --git a/ggml/src/vulkan-shaders/pad.comp b/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp similarity index 83% rename from ggml/src/vulkan-shaders/pad.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/pad.comp index a465cd52b..450b67fc5 100644 --- a/ggml/src/vulkan-shaders/pad.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp @@ -3,6 +3,8 @@ #include "types.comp" #include "generic_unary_head.comp" +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + void main() { const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; @@ -22,5 +24,5 @@ void main() { const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; - data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : 0.0f); + data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp new file mode 100644 index 000000000..b6124411a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp @@ -0,0 +1,74 @@ +#version 450 + +#include "types.comp" + +#extension GL_EXT_shader_16bit_storage : require + +layout(push_constant) uniform parameter { + uint IW; uint IH; + uint OW; uint OH; + uint OC; + uint pelements; + uint op; + int k0; int k1; + int s0; int s1; + int p0; int p1; +} p; + +#define BLOCK_SIZE 512 +#define FLT_MAX 3.402823466e+38F +#define OP_POOL_MAX 0u +#define OP_POOL_AVG 1u + +layout (local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout(binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint idx = gl_GlobalInvocationID.x; + if (idx >= p.pelements) { + return; + } + + const uint O_HW = p.OW * p.OH; + + const uint nc = idx / O_HW; + const uint cur_oh = (idx % O_HW) / p.OW; + const uint cur_ow = (idx % O_HW) % p.OW; + + const int start_h = int(cur_oh) * p.s0 - p.p0; + const uint bh = max(start_h, 0); + const uint eh = min(start_h + p.k0, p.IH); + + const int start_w = int(cur_ow) * p.s1 - p.p1; + const uint bw = max(start_w, 0); + const uint ew = min(start_w + p.k1, p.IW); + + const float scale = 1.0 / float(p.k0 * p.k1); + float res; + + if (p.op == OP_POOL_AVG) { + res = 0.0; + } else if (p.op == OP_POOL_MAX) { + res = -FLT_MAX; + } else { + return; + } + + #pragma unroll + for (uint i = bh; i < eh; i++) { + #pragma unroll + for (uint j = bw; j < ew; j++) { + const float cur = D_TYPE(data_a[nc * p.IH * p.IW + i * p.IW + j]); + + if (p.op == OP_POOL_AVG) { + res += cur * scale; + } else if (p.op == OP_POOL_MAX) { + res = max(res, cur); + } + } + } + + data_d[nc * O_HW + cur_oh * p.OW + cur_ow] = res; +} diff --git a/ggml/src/vulkan-shaders/relu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp similarity index 100% rename from ggml/src/vulkan-shaders/relu.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/relu.comp diff --git a/ggml/src/vulkan-shaders/repeat.comp b/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp similarity index 79% rename from ggml/src/vulkan-shaders/repeat.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp index a86af87e7..1568b141d 100644 --- a/ggml/src/vulkan-shaders/repeat.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp @@ -3,6 +3,8 @@ #include "types.comp" #include "generic_unary_head.comp" +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + uint src0_idx_mod(uint idx) { const uint i13 = idx / (p.ne12*p.ne11*p.ne10); const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; @@ -20,5 +22,5 @@ void main() { return; } - data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx_mod(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx_mod(idx)]); } diff --git a/ggml/src/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp similarity index 100% rename from ggml/src/vulkan-shaders/rms_norm.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp diff --git a/ggml/src/vulkan-shaders/rope_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp similarity index 91% rename from ggml/src/vulkan-shaders/rope_head.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp index ea8954226..574b51ca5 100644 --- a/ggml/src/vulkan-shaders/rope_head.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp @@ -1,6 +1,11 @@ #include "types.comp" #extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_spirv_intrinsics: enable + +#if RTE16 +spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits +#endif layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; diff --git a/ggml/src/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp similarity index 100% rename from ggml/src/vulkan-shaders/rope_neox.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp diff --git a/ggml/src/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp similarity index 100% rename from ggml/src/vulkan-shaders/rope_norm.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp new file mode 100644 index 000000000..4663428de --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp @@ -0,0 +1,24 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +const uint num_threads = 128; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 4; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1)); + idx += num_threads; + } +} diff --git a/ggml/src/vulkan-shaders/silu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp similarity index 100% rename from ggml/src/vulkan-shaders/silu.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/silu.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp new file mode 100644 index 000000000..d7c15a169 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val)); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp new file mode 100644 index 000000000..51fc2dc7e --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp @@ -0,0 +1,173 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +layout (push_constant) uniform parameter +{ + uint KX; + uint KY; + float scale; + float max_bias; + float m0; + float m1; + uint n_head_log2; + uint nrows_x; +} p; + +#include "types.comp" + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; +layout (binding = 2) buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE vals[BLOCK_SIZE]; + +// num_iters is the number of BLOCK_SIZE loop iterations we need to iterate +// over all the columns. The main function tries to pass a constant here, +// as if it were a template function, to allow unrolling. +void soft_max(uint num_iters) { + const uint tid = gl_LocalInvocationID.x; + const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint rowy = (p.KY > 0) ? (rowx % p.KY) : 0; + + if (rowx >= p.nrows_x) { + return; + } + + float slope = 1.0f; + + // ALiBi + if (p.max_bias > 0.0f) { + const uint h = rowx/p.KY; // head index + + const float base = h < p.n_head_log2 ? p.m0 : p.m1; + const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; + + slope = pow(base, exp); + } + + // Find max + FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000); + + // Cache values while we compute the max, so we don't need to read them + // again when we're ready to compute exp(x-max). + const uint DATA_CACHE_SIZE = 16; + FLOAT_TYPE data_cache[DATA_CACHE_SIZE]; + + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + FLOAT_TYPE a = FLOAT_TYPE(0); + if (col < p.KX) { + a = data_a[rowx * p.KX + col]; + } + + FLOAT_TYPE b = FLOAT_TYPE(0); + if (p.KY > 0 && col < p.KX) { + b = data_b[rowy * p.KX + col]; + } + + FLOAT_TYPE v = a * p.scale + slope * b; + + if (col < p.KX) { + max_val = max(max_val, v); + } + + if (idx < DATA_CACHE_SIZE) { + data_cache[idx] = v; + } + } + + // reduce across the workgroup + vals[tid] = max_val; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] = max(vals[tid], vals[tid + s]); + } + barrier(); + } + + max_val = vals[0]; + barrier(); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); + + // Compute sum{exp(x - max)} + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + if (col >= p.KX) { + break; + } + + // compute exp(a*scale+b*slope), add it to sum, and cache the new value + // in data_cache if possible. + const uint i = rowx * p.KX + col; + FLOAT_TYPE val; + if (idx < DATA_CACHE_SIZE) { + val = exp(data_cache[idx] - max_val); + } else { + val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val); + } + sum += val; + if (idx < DATA_CACHE_SIZE) { + data_cache[idx] = val; + } else { + data_d[i] = D_TYPE(val); + } + } + + // reduce across the workgroup + vals[tid] = sum; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] += vals[tid + s]; + } + barrier(); + } + sum = vals[0]; + + FLOAT_TYPE rcpdivisor = 1.0/sum; + + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + if (col >= p.KX) { + continue; + } + + if (idx < DATA_CACHE_SIZE) { + data_d[rowx*p.KX + col] = D_TYPE(data_cache[idx] * rcpdivisor); + } else { + data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor); + } + } +} + +void main() { + // instantiate the soft_max function for several different + // dimensions, to allow loop unrolling + uint num_blocks = (p.KX + BLOCK_SIZE - 1) / BLOCK_SIZE; + if (num_blocks > 32) { + soft_max(num_blocks); + } else if (num_blocks > 16) { + soft_max(32); + } else if (num_blocks > 8) { + soft_max(16); + } else if (num_blocks > 4) { + soft_max(8); + } else if (num_blocks == 4) { + soft_max(4); + } else if (num_blocks == 3) { + soft_max(3); + } else if (num_blocks == 2) { + soft_max(2); + } else if (num_blocks == 1) { + soft_max(1); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/square.comp b/ggml/src/ggml-vulkan/vulkan-shaders/square.comp new file mode 100644 index 000000000..ef43598ba --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/square.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.comp" +#include "generic_unary_head.comp" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val); +} diff --git a/ggml/src/vulkan-shaders/sum_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp similarity index 100% rename from ggml/src/vulkan-shaders/sum_rows.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp diff --git a/ggml/src/vulkan-shaders/tanh.comp b/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp similarity index 88% rename from ggml/src/vulkan-shaders/tanh.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp index 74630dc7f..495f966bd 100644 --- a/ggml/src/vulkan-shaders/tanh.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp @@ -16,6 +16,5 @@ void main() { if (i >= p.KX) { return; } - - data_d[i] = D_TYPE(tanh(data_a[i])); + data_d[i] = D_TYPE(1. - 2. / (exp(2.*data_a[i]) + 1.)); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp new file mode 100644 index 000000000..28eb24e11 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp @@ -0,0 +1,7 @@ +#version 460 + +#extension GL_NV_cooperative_matrix2 : require + +void main() +{ +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp new file mode 100644 index 000000000..8c5dd1bd1 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp @@ -0,0 +1,7 @@ +#version 460 + +#extension GL_KHR_cooperative_matrix : require + +void main() +{ +} diff --git a/ggml/src/vulkan-shaders/timestep_embedding.comp b/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp similarity index 100% rename from ggml/src/vulkan-shaders/timestep_embedding.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp new file mode 100644 index 000000000..9e56a3530 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -0,0 +1,1068 @@ + +#if !defined(GGML_TYPES_COMP) +#define GGML_TYPES_COMP + +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#extension GL_EXT_shader_16bit_storage : require + +#if defined(DATA_A_F32) +#define QUANT_K 1 +#define QUANT_R 1 + +#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 +#define A_TYPE float +#elif LOAD_VEC_A == 4 +#define A_TYPE vec4 +#elif LOAD_VEC_A == 8 +#define A_TYPE mat2x4 +#endif +#endif + +#if defined(DATA_A_F16) +#define QUANT_K 1 +#define QUANT_R 1 + +#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 +#define A_TYPE float16_t +#elif LOAD_VEC_A == 4 +#define A_TYPE f16vec4 +#elif LOAD_VEC_A == 8 +#define A_TYPE f16mat2x4 +#endif +#endif + +#define QUANT_K_Q4_0 32 +#define QUANT_R_Q4_0 2 + +struct block_q4_0 +{ + float16_t d; + uint8_t qs[16]; +}; +struct block_q4_0_packed16 +{ + float16_t d; + uint16_t qs[16/2]; +}; + +#if defined(DATA_A_Q4_0) +#define QUANT_K QUANT_K_Q4_0 +#define QUANT_R QUANT_R_Q4_0 +#define A_TYPE block_q4_0 +#define A_TYPE_PACKED16 block_q4_0_packed16 +#endif + +#define QUANT_K_Q4_1 32 +#define QUANT_R_Q4_1 2 + +struct block_q4_1 +{ + float16_t d; + float16_t m; + uint8_t qs[16]; +}; + +struct block_q4_1_packed16 +{ + float16_t d; + float16_t m; + uint16_t qs[16/2]; +}; + +#if defined(DATA_A_Q4_1) +#define QUANT_K QUANT_K_Q4_1 +#define QUANT_R QUANT_R_Q4_1 +#define A_TYPE block_q4_1 +#define A_TYPE_PACKED16 block_q4_1_packed16 +#endif + +#define QUANT_K_Q5_0 32 +#define QUANT_R_Q5_0 2 + +struct block_q5_0 +{ + float16_t d; + uint16_t qh[2]; + uint8_t qs[16]; +}; + +struct block_q5_0_packed16 +{ + float16_t d; + uint16_t qh[2]; + uint16_t qs[16/2]; +}; + +#if defined(DATA_A_Q5_0) +#define QUANT_K QUANT_K_Q5_0 +#define QUANT_R QUANT_R_Q5_0 +#define A_TYPE block_q5_0 +#define A_TYPE_PACKED16 block_q5_0_packed16 +#endif + +#define QUANT_K_Q5_1 32 +#define QUANT_R_Q5_1 2 + +struct block_q5_1 +{ + float16_t d; + float16_t m; + uint qh; + uint8_t qs[16]; +}; + +struct block_q5_1_packed16 +{ + float16_t d; + float16_t m; + uint qh; + uint16_t qs[16/2]; +}; + +#if defined(DATA_A_Q5_1) +#define QUANT_K QUANT_K_Q5_1 +#define QUANT_R QUANT_R_Q5_1 +#define A_TYPE block_q5_1 +#define A_TYPE_PACKED16 block_q5_1_packed16 +#endif + +#define QUANT_K_Q8_0 32 +#define QUANT_R_Q8_0 1 + +struct block_q8_0 +{ + float16_t d; + int8_t qs[32]; +}; +struct block_q8_0_packed16 +{ + float16_t d; + uint16_t qs[32/2]; +}; + +#if defined(DATA_A_Q8_0) +#define QUANT_K QUANT_K_Q8_0 +#define QUANT_R QUANT_R_Q8_0 +#define A_TYPE block_q8_0 +#define A_TYPE_PACKED16 block_q8_0_packed16 +#endif + +// K-quants +#define QUANT_K_Q2_K 256 + +struct block_q2_K +{ + uint8_t scales[QUANT_K_Q2_K/16]; + uint8_t qs[QUANT_K_Q2_K/4]; + f16vec2 d; +}; + +struct block_q2_K_packed16 +{ + uint16_t scales[QUANT_K_Q2_K/16/2]; + uint16_t qs[QUANT_K_Q2_K/4/2]; + f16vec2 d; +}; + +struct block_q2_K_packed32 +{ + uint32_t scales[QUANT_K_Q2_K/16/4]; + uint32_t qs[QUANT_K_Q2_K/4/4]; + f16vec2 d; +}; + +#if defined(DATA_A_Q2_K) +#define QUANT_K QUANT_K_Q2_K +#define A_TYPE block_q2_K +#define A_TYPE_PACKED16 block_q2_K_packed16 +#define A_TYPE_PACKED32 block_q2_K_packed32 +#endif + +#define QUANT_K_Q3_K 256 + +struct block_q3_K +{ + uint8_t hmask[QUANT_K_Q3_K/8]; + uint8_t qs[QUANT_K_Q3_K/4]; + uint8_t scales[12]; + float16_t d; +}; + +struct block_q3_K_packed16 +{ + uint16_t hmask[QUANT_K_Q3_K/8/2]; + uint16_t qs[QUANT_K_Q3_K/4/2]; + uint16_t scales[12/2]; + float16_t d; +}; + +#if defined(DATA_A_Q3_K) +#define QUANT_K QUANT_K_Q3_K +#define A_TYPE block_q3_K +#define A_TYPE_PACKED16 block_q3_K_packed16 +#endif + +#define QUANT_K_Q4_K 256 + +struct block_q4_K +{ + f16vec2 d; + uint8_t scales[3*QUANT_K_Q4_K/64]; + uint8_t qs[QUANT_K_Q4_K/2]; +}; + +struct block_q4_K_packed16 +{ + f16vec2 d; + uint16_t scales[3*QUANT_K_Q4_K/64/2]; + uint16_t qs[QUANT_K_Q4_K/2/2]; +}; + +struct block_q4_K_packed32 +{ + f16vec2 d; + uint32_t scales[3*QUANT_K_Q4_K/64/4]; + uint32_t qs[QUANT_K_Q4_K/2/4]; +}; + +struct block_q4_K_packed128 +{ + uvec4 q4k[9]; +}; + +#if defined(DATA_A_Q4_K) +#define QUANT_K QUANT_K_Q4_K +#define A_TYPE block_q4_K +#define A_TYPE_PACKED16 block_q4_K_packed16 +#define A_TYPE_PACKED32 block_q4_K_packed32 +#endif + +#define QUANT_K_Q5_K 256 + +struct block_q5_K +{ + f16vec2 d; + uint8_t scales[12]; + uint8_t qh[QUANT_K_Q5_K/8]; + uint8_t qs[QUANT_K_Q5_K/2]; +}; + +struct block_q5_K_packed16 +{ + f16vec2 d; + uint16_t scales[12/2]; + uint16_t qh[QUANT_K_Q5_K/8/2]; + uint16_t qs[QUANT_K_Q5_K/2/2]; +}; + +struct block_q5_K_packed128 +{ + uvec4 q5k[11]; +}; + +#if defined(DATA_A_Q5_K) +#define QUANT_K QUANT_K_Q5_K +#define A_TYPE block_q5_K +#define A_TYPE_PACKED16 block_q5_K_packed16 +#endif + +#define QUANT_K_Q6_K 256 + +struct block_q6_K +{ + uint8_t ql[QUANT_K_Q6_K/2]; + uint8_t qh[QUANT_K_Q6_K/4]; + int8_t scales[QUANT_K_Q6_K/16]; + float16_t d; +}; + +struct block_q6_K_packed16 +{ + uint16_t ql[QUANT_K_Q6_K/2/2]; + uint16_t qh[QUANT_K_Q6_K/4/2]; + int8_t scales[QUANT_K_Q6_K/16]; + float16_t d; +}; + +#if defined(DATA_A_Q6_K) +#define QUANT_K QUANT_K_Q6_K +#define A_TYPE block_q6_K +#define A_TYPE_PACKED16 block_q6_K_packed16 +#endif + +// IQuants + +#define QUANT_K_IQ2_XXS 256 +#define QUANT_R_IQ2_XXS 1 + +struct block_iq2_xxs +{ + float16_t d; + uint8_t qs[QUANT_K_IQ2_XXS/4]; +}; + +struct block_iq2_xxs_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_XXS/8]; +}; + +#if defined(DATA_A_IQ2_XXS) + +const uvec2[256] iq2xxs_grid_const = { + uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808), + uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x082b0808, 0x08080808), + uvec2(0x082b082b, 0x08080808), uvec2(0x082b2b08, 0x08080808), uvec2(0x082b2b2b, 0x08080808), uvec2(0x19080819, 0x08080808), + uvec2(0x19081908, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), + uvec2(0x192b1908, 0x08080808), uvec2(0x2b080808, 0x08080808), uvec2(0x2b08082b, 0x08080808), uvec2(0x2b082b2b, 0x08080808), + uvec2(0x2b2b082b, 0x08080808), uvec2(0x08080819, 0x08080819), uvec2(0x08081908, 0x08080819), uvec2(0x08190808, 0x08080819), + uvec2(0x08191919, 0x08080819), uvec2(0x19080808, 0x08080819), uvec2(0x2b081908, 0x08080819), uvec2(0x2b192b08, 0x08080819), + uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), uvec2(0x082b082b, 0x0808082b), uvec2(0x2b08082b, 0x0808082b), + uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x082b0819, 0x08081908), + uvec2(0x082b1908, 0x08081908), uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19082b08, 0x08081908), + uvec2(0x192b0808, 0x08081908), uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b190808, 0x08081908), + uvec2(0x2b2b1908, 0x08081908), uvec2(0x08080808, 0x08081919), uvec2(0x0808082b, 0x08081919), uvec2(0x08082b08, 0x08081919), + uvec2(0x082b0808, 0x08081919), uvec2(0x1908192b, 0x08081919), uvec2(0x192b2b19, 0x08081919), uvec2(0x2b080808, 0x08081919), + uvec2(0x2b190819, 0x08081919), uvec2(0x08082b19, 0x0808192b), uvec2(0x08190808, 0x0808192b), uvec2(0x19080808, 0x0808192b), + uvec2(0x2b081908, 0x0808192b), uvec2(0x2b2b1908, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x08081919, 0x08082b08), + uvec2(0x08082b08, 0x08082b08), uvec2(0x08191908, 0x08082b08), uvec2(0x082b2b08, 0x08082b08), uvec2(0x19080819, 0x08082b08), + uvec2(0x19081908, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x1919082b, 0x08082b08), uvec2(0x2b082b08, 0x08082b08), + uvec2(0x08081908, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x0808082b, 0x08082b2b), uvec2(0x08191908, 0x08082b2b), + uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x082b0819, 0x08190808), + uvec2(0x19080808, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x2b081908, 0x08190808), uvec2(0x2b190808, 0x08190808), + uvec2(0x2b191919, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x08082b08, 0x08190819), uvec2(0x082b0808, 0x08190819), + uvec2(0x19190808, 0x08190819), uvec2(0x19192b2b, 0x08190819), uvec2(0x2b080808, 0x08190819), uvec2(0x082b1908, 0x0819082b), + uvec2(0x19081919, 0x0819082b), uvec2(0x08080808, 0x08191908), uvec2(0x08082b08, 0x08191908), uvec2(0x082b0808, 0x08191908), + uvec2(0x082b1919, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x08192b08, 0x08191919), + uvec2(0x192b082b, 0x08191919), uvec2(0x08080808, 0x0819192b), uvec2(0x0819192b, 0x0819192b), uvec2(0x08080819, 0x08192b08), + uvec2(0x08081908, 0x08192b08), uvec2(0x08190808, 0x08192b08), uvec2(0x19080808, 0x08192b08), uvec2(0x2b080819, 0x08192b08), + uvec2(0x08080808, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x2b2b0808, 0x08192b19), uvec2(0x19190819, 0x08192b2b), + uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08082b2b, 0x082b0808), uvec2(0x19081908, 0x082b0808), + uvec2(0x192b0819, 0x082b0808), uvec2(0x2b080808, 0x082b0808), uvec2(0x2b08082b, 0x082b0808), uvec2(0x082b2b19, 0x082b0819), + uvec2(0x19082b08, 0x082b0819), uvec2(0x08080808, 0x082b082b), uvec2(0x0808082b, 0x082b082b), uvec2(0x08080819, 0x082b1908), + uvec2(0x08081908, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x19080808, 0x082b1908), uvec2(0x1919192b, 0x082b1908), + uvec2(0x08080808, 0x082b1919), uvec2(0x19080819, 0x082b1919), uvec2(0x192b1908, 0x082b1919), uvec2(0x2b190808, 0x082b192b), + uvec2(0x08082b08, 0x082b2b08), uvec2(0x082b0808, 0x082b2b08), uvec2(0x2b191908, 0x082b2b08), uvec2(0x19081908, 0x082b2b2b), + uvec2(0x08080819, 0x19080808), uvec2(0x08081908, 0x19080808), uvec2(0x08190808, 0x19080808), uvec2(0x08192b08, 0x19080808), + uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), uvec2(0x19080808, 0x19080808), uvec2(0x19082b08, 0x19080808), + uvec2(0x1919192b, 0x19080808), uvec2(0x192b0808, 0x19080808), uvec2(0x2b080819, 0x19080808), uvec2(0x2b081908, 0x19080808), + uvec2(0x2b190808, 0x19080808), uvec2(0x08080808, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x192b0819, 0x19080819), + uvec2(0x2b080808, 0x19080819), uvec2(0x2b081919, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08190808, 0x1908082b), + uvec2(0x19082b08, 0x1908082b), uvec2(0x1919192b, 0x1908082b), uvec2(0x192b2b08, 0x1908082b), uvec2(0x08080808, 0x19081908), + uvec2(0x08082b08, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x2b080808, 0x19081908), uvec2(0x2b192b19, 0x19081908), + uvec2(0x0819082b, 0x19081919), uvec2(0x082b1908, 0x19081919), uvec2(0x08080808, 0x1908192b), uvec2(0x08080819, 0x19082b08), + uvec2(0x08081908, 0x19082b08), uvec2(0x08190808, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x19081919, 0x19082b08), + uvec2(0x08080808, 0x19082b19), uvec2(0x19192b08, 0x19082b19), uvec2(0x192b0819, 0x19082b19), uvec2(0x2b08082b, 0x19082b19), + uvec2(0x19081919, 0x19082b2b), uvec2(0x2b190808, 0x19082b2b), uvec2(0x08080808, 0x19190808), uvec2(0x08082b08, 0x19190808), + uvec2(0x08190819, 0x19190808), uvec2(0x08192b19, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x2b080808, 0x19190808), + uvec2(0x2b082b08, 0x19190808), uvec2(0x08081908, 0x19190819), uvec2(0x1908082b, 0x19190819), uvec2(0x2b2b1908, 0x19190819), + uvec2(0x2b190819, 0x1919082b), uvec2(0x2b190808, 0x19191908), uvec2(0x2b19082b, 0x19191908), uvec2(0x08082b2b, 0x19191919), + uvec2(0x08080819, 0x1919192b), uvec2(0x19191908, 0x1919192b), uvec2(0x08080808, 0x19192b08), uvec2(0x08190819, 0x19192b08), + uvec2(0x08192b19, 0x19192b08), uvec2(0x192b1908, 0x19192b08), uvec2(0x19080808, 0x19192b19), uvec2(0x08082b08, 0x19192b2b), + uvec2(0x08081908, 0x192b0808), uvec2(0x08190808, 0x192b0808), uvec2(0x19080808, 0x192b0808), uvec2(0x192b2b08, 0x192b0808), + uvec2(0x08080808, 0x192b0819), uvec2(0x19191919, 0x192b0819), uvec2(0x08192b08, 0x192b082b), uvec2(0x192b0808, 0x192b082b), + uvec2(0x08080808, 0x192b1908), uvec2(0x08081919, 0x192b1908), uvec2(0x08190808, 0x192b1919), uvec2(0x0819082b, 0x192b1919), + uvec2(0x2b081908, 0x192b1919), uvec2(0x1908082b, 0x192b2b08), uvec2(0x08080808, 0x2b080808), uvec2(0x0808082b, 0x2b080808), + uvec2(0x08082b2b, 0x2b080808), uvec2(0x19080819, 0x2b080808), uvec2(0x2b08082b, 0x2b080808), uvec2(0x08081908, 0x2b080819), + uvec2(0x08192b08, 0x2b080819), uvec2(0x19080808, 0x2b080819), uvec2(0x08190819, 0x2b08082b), uvec2(0x08080819, 0x2b081908), + uvec2(0x08081908, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x08191919, 0x2b081908), uvec2(0x19080808, 0x2b081908), + uvec2(0x192b0808, 0x2b081908), uvec2(0x08080808, 0x2b081919), uvec2(0x1908192b, 0x2b081919), uvec2(0x2b191908, 0x2b081919), + uvec2(0x08082b19, 0x2b08192b), uvec2(0x19080808, 0x2b08192b), uvec2(0x192b0808, 0x2b08192b), uvec2(0x0808082b, 0x2b082b08), + uvec2(0x08081908, 0x2b082b19), uvec2(0x08190819, 0x2b082b2b), uvec2(0x08081908, 0x2b190808), uvec2(0x08190808, 0x2b190808), + uvec2(0x082b1908, 0x2b190808), uvec2(0x19080808, 0x2b190808), uvec2(0x2b2b0819, 0x2b190808), uvec2(0x0819192b, 0x2b190819), + uvec2(0x2b080808, 0x2b190819), uvec2(0x19081919, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x082b082b, 0x2b191908), + uvec2(0x19081908, 0x2b191908), uvec2(0x19190819, 0x2b191919), uvec2(0x2b080819, 0x2b192b08), uvec2(0x082b0808, 0x2b192b19), + uvec2(0x0808082b, 0x2b2b0808), uvec2(0x19190808, 0x2b2b0808), uvec2(0x2b081919, 0x2b2b0808), uvec2(0x08082b19, 0x2b2b0819), + uvec2(0x08080808, 0x2b2b082b), uvec2(0x08192b08, 0x2b2b1908), uvec2(0x19190808, 0x2b2b2b08), uvec2(0x08081908, 0x2b2b2b19) +}; + +shared uvec2 iq2xxs_grid[256]; + +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < iq2xxs_grid.length(); i += wgsize.x) { + iq2xxs_grid[i] = iq2xxs_grid_const[i]; + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ2_XXS +#define QUANT_R QUANT_R_IQ2_XXS +#define A_TYPE block_iq2_xxs +#define A_TYPE_PACKED16 block_iq2_xxs_packed16 +#endif + +#define QUANT_K_IQ2_XS 256 +#define QUANT_R_IQ2_XS 1 + +struct block_iq2_xs +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_XS/8]; + uint8_t scales[QUANT_K_IQ2_XS/32]; +}; + +struct block_iq2_xs_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_XS/8]; + uint16_t scales[QUANT_K_IQ2_XS/64]; +}; + +#if defined(DATA_A_IQ2_XS) + +const uvec2 iq2xs_grid_const[512] = { + uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808), + uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x0819192b, 0x08080808), + uvec2(0x08192b19, 0x08080808), uvec2(0x082b0808, 0x08080808), uvec2(0x082b082b, 0x08080808), uvec2(0x082b1919, 0x08080808), + uvec2(0x082b2b08, 0x08080808), uvec2(0x19080819, 0x08080808), uvec2(0x19081908, 0x08080808), uvec2(0x1908192b, 0x08080808), + uvec2(0x19082b19, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x1919082b, 0x08080808), uvec2(0x19191919, 0x08080808), + uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), uvec2(0x192b1908, 0x08080808), uvec2(0x2b080808, 0x08080808), + uvec2(0x2b08082b, 0x08080808), uvec2(0x2b081919, 0x08080808), uvec2(0x2b082b08, 0x08080808), uvec2(0x2b190819, 0x08080808), + uvec2(0x2b191908, 0x08080808), uvec2(0x2b192b19, 0x08080808), uvec2(0x2b2b0808, 0x08080808), uvec2(0x08080819, 0x08080819), + uvec2(0x08081908, 0x08080819), uvec2(0x0808192b, 0x08080819), uvec2(0x08082b19, 0x08080819), uvec2(0x08190808, 0x08080819), + uvec2(0x0819082b, 0x08080819), uvec2(0x08191919, 0x08080819), uvec2(0x08192b08, 0x08080819), uvec2(0x08192b2b, 0x08080819), + uvec2(0x082b0819, 0x08080819), uvec2(0x082b1908, 0x08080819), uvec2(0x19080808, 0x08080819), uvec2(0x1908082b, 0x08080819), + uvec2(0x19081919, 0x08080819), uvec2(0x19082b08, 0x08080819), uvec2(0x19190819, 0x08080819), uvec2(0x19191908, 0x08080819), + uvec2(0x192b0808, 0x08080819), uvec2(0x192b2b08, 0x08080819), uvec2(0x2b080819, 0x08080819), uvec2(0x2b081908, 0x08080819), + uvec2(0x2b190808, 0x08080819), uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), uvec2(0x08081919, 0x0808082b), + uvec2(0x08082b08, 0x0808082b), uvec2(0x08190819, 0x0808082b), uvec2(0x08191908, 0x0808082b), uvec2(0x082b0808, 0x0808082b), + uvec2(0x19080819, 0x0808082b), uvec2(0x19081908, 0x0808082b), uvec2(0x19190808, 0x0808082b), uvec2(0x19191919, 0x0808082b), + uvec2(0x2b080808, 0x0808082b), uvec2(0x2b082b2b, 0x0808082b), uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), + uvec2(0x0808192b, 0x08081908), uvec2(0x08082b19, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x0819082b, 0x08081908), + uvec2(0x08191919, 0x08081908), uvec2(0x08192b08, 0x08081908), uvec2(0x082b0819, 0x08081908), uvec2(0x082b1908, 0x08081908), + uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19081919, 0x08081908), uvec2(0x19082b08, 0x08081908), + uvec2(0x19190819, 0x08081908), uvec2(0x19191908, 0x08081908), uvec2(0x1919192b, 0x08081908), uvec2(0x192b0808, 0x08081908), + uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b190808, 0x08081908), uvec2(0x08080808, 0x08081919), + uvec2(0x0808082b, 0x08081919), uvec2(0x08081919, 0x08081919), uvec2(0x08082b08, 0x08081919), uvec2(0x08190819, 0x08081919), + uvec2(0x08191908, 0x08081919), uvec2(0x082b0808, 0x08081919), uvec2(0x19080819, 0x08081919), uvec2(0x19081908, 0x08081919), + uvec2(0x19190808, 0x08081919), uvec2(0x192b0819, 0x08081919), uvec2(0x2b080808, 0x08081919), uvec2(0x08080819, 0x0808192b), + uvec2(0x08081908, 0x0808192b), uvec2(0x08190808, 0x0808192b), uvec2(0x082b192b, 0x0808192b), uvec2(0x19080808, 0x0808192b), + uvec2(0x1908082b, 0x0808192b), uvec2(0x2b081908, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x0808082b, 0x08082b08), + uvec2(0x08081919, 0x08082b08), uvec2(0x08082b08, 0x08082b08), uvec2(0x08082b2b, 0x08082b08), uvec2(0x08190819, 0x08082b08), + uvec2(0x08191908, 0x08082b08), uvec2(0x082b0808, 0x08082b08), uvec2(0x082b1919, 0x08082b08), uvec2(0x19080819, 0x08082b08), + uvec2(0x19081908, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x19192b08, 0x08082b08), uvec2(0x2b080808, 0x08082b08), + uvec2(0x2b2b0808, 0x08082b08), uvec2(0x2b2b2b2b, 0x08082b08), uvec2(0x08080819, 0x08082b19), uvec2(0x08081908, 0x08082b19), + uvec2(0x08190808, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x2b080819, 0x08082b19), uvec2(0x2b082b19, 0x08082b19), + uvec2(0x08080808, 0x08082b2b), uvec2(0x082b0808, 0x08082b2b), uvec2(0x082b2b08, 0x08082b2b), uvec2(0x2b19192b, 0x08082b2b), + uvec2(0x2b2b0808, 0x08082b2b), uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), uvec2(0x0808192b, 0x08190808), + uvec2(0x08082b19, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x0819082b, 0x08190808), uvec2(0x08191919, 0x08190808), + uvec2(0x08192b08, 0x08190808), uvec2(0x082b0819, 0x08190808), uvec2(0x082b1908, 0x08190808), uvec2(0x19080808, 0x08190808), + uvec2(0x1908082b, 0x08190808), uvec2(0x19081919, 0x08190808), uvec2(0x19082b08, 0x08190808), uvec2(0x19190819, 0x08190808), + uvec2(0x19191908, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x192b2b2b, 0x08190808), uvec2(0x2b080819, 0x08190808), + uvec2(0x2b081908, 0x08190808), uvec2(0x2b190808, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x0808082b, 0x08190819), + uvec2(0x08081919, 0x08190819), uvec2(0x08082b08, 0x08190819), uvec2(0x08190819, 0x08190819), uvec2(0x08191908, 0x08190819), + uvec2(0x082b0808, 0x08190819), uvec2(0x19080819, 0x08190819), uvec2(0x19081908, 0x08190819), uvec2(0x19190808, 0x08190819), + uvec2(0x2b080808, 0x08190819), uvec2(0x2b191908, 0x08190819), uvec2(0x2b19192b, 0x08190819), uvec2(0x08080819, 0x0819082b), + uvec2(0x08081908, 0x0819082b), uvec2(0x0808192b, 0x0819082b), uvec2(0x08190808, 0x0819082b), uvec2(0x19080808, 0x0819082b), + uvec2(0x192b0808, 0x0819082b), uvec2(0x08080808, 0x08191908), uvec2(0x0808082b, 0x08191908), uvec2(0x08081919, 0x08191908), + uvec2(0x08082b08, 0x08191908), uvec2(0x08190819, 0x08191908), uvec2(0x08191908, 0x08191908), uvec2(0x082b0808, 0x08191908), + uvec2(0x19080819, 0x08191908), uvec2(0x19081908, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x19190808, 0x08191908), + uvec2(0x192b1908, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x08080819, 0x08191919), uvec2(0x08081908, 0x08191919), + uvec2(0x08190808, 0x08191919), uvec2(0x19080808, 0x08191919), uvec2(0x08080808, 0x0819192b), uvec2(0x08191908, 0x0819192b), + uvec2(0x19082b19, 0x0819192b), uvec2(0x08080819, 0x08192b08), uvec2(0x08081908, 0x08192b08), uvec2(0x08190808, 0x08192b08), + uvec2(0x0819082b, 0x08192b08), uvec2(0x19080808, 0x08192b08), uvec2(0x19191908, 0x08192b08), uvec2(0x2b08192b, 0x08192b08), + uvec2(0x08080808, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x192b192b, 0x08192b19), uvec2(0x19190819, 0x08192b2b), + uvec2(0x2b2b2b19, 0x08192b2b), uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08081919, 0x082b0808), + uvec2(0x08082b08, 0x082b0808), uvec2(0x08082b2b, 0x082b0808), uvec2(0x08190819, 0x082b0808), uvec2(0x08191908, 0x082b0808), + uvec2(0x082b0808, 0x082b0808), uvec2(0x19080819, 0x082b0808), uvec2(0x19081908, 0x082b0808), uvec2(0x19190808, 0x082b0808), + uvec2(0x2b080808, 0x082b0808), uvec2(0x2b2b0808, 0x082b0808), uvec2(0x08080819, 0x082b0819), uvec2(0x08081908, 0x082b0819), + uvec2(0x08190808, 0x082b0819), uvec2(0x19080808, 0x082b0819), uvec2(0x19082b08, 0x082b0819), uvec2(0x192b1919, 0x082b0819), + uvec2(0x08080808, 0x082b082b), uvec2(0x082b082b, 0x082b082b), uvec2(0x2b080808, 0x082b082b), uvec2(0x2b2b2b08, 0x082b082b), + uvec2(0x08080819, 0x082b1908), uvec2(0x08081908, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x082b2b19, 0x082b1908), + uvec2(0x19080808, 0x082b1908), uvec2(0x08080808, 0x082b1919), uvec2(0x19080819, 0x082b1919), uvec2(0x1919082b, 0x082b1919), + uvec2(0x2b192b19, 0x082b1919), uvec2(0x08080819, 0x082b192b), uvec2(0x08192b2b, 0x082b192b), uvec2(0x2b2b192b, 0x082b192b), + uvec2(0x08080808, 0x082b2b08), uvec2(0x08082b08, 0x082b2b08), uvec2(0x08082b2b, 0x082b2b08), uvec2(0x082b0808, 0x082b2b08), + uvec2(0x19191919, 0x082b2b08), uvec2(0x2b082b08, 0x082b2b08), uvec2(0x2b2b082b, 0x082b2b08), uvec2(0x192b2b08, 0x082b2b19), + uvec2(0x2b190808, 0x082b2b19), uvec2(0x08082b08, 0x082b2b2b), uvec2(0x082b0808, 0x082b2b2b), uvec2(0x2b08082b, 0x082b2b2b), + uvec2(0x2b082b08, 0x082b2b2b), uvec2(0x2b082b2b, 0x082b2b2b), uvec2(0x08080819, 0x19080808), uvec2(0x08081908, 0x19080808), + uvec2(0x0808192b, 0x19080808), uvec2(0x08082b19, 0x19080808), uvec2(0x08190808, 0x19080808), uvec2(0x0819082b, 0x19080808), + uvec2(0x08191919, 0x19080808), uvec2(0x08192b08, 0x19080808), uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), + uvec2(0x19080808, 0x19080808), uvec2(0x1908082b, 0x19080808), uvec2(0x19081919, 0x19080808), uvec2(0x19082b08, 0x19080808), + uvec2(0x19082b2b, 0x19080808), uvec2(0x19190819, 0x19080808), uvec2(0x19191908, 0x19080808), uvec2(0x192b0808, 0x19080808), + uvec2(0x192b1919, 0x19080808), uvec2(0x2b080819, 0x19080808), uvec2(0x2b081908, 0x19080808), uvec2(0x2b190808, 0x19080808), + uvec2(0x08080808, 0x19080819), uvec2(0x0808082b, 0x19080819), uvec2(0x08081919, 0x19080819), uvec2(0x08082b08, 0x19080819), + uvec2(0x08190819, 0x19080819), uvec2(0x08191908, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x19080819, 0x19080819), + uvec2(0x19081908, 0x19080819), uvec2(0x19190808, 0x19080819), uvec2(0x2b080808, 0x19080819), uvec2(0x2b081919, 0x19080819), + uvec2(0x2b2b082b, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08081908, 0x1908082b), uvec2(0x08190808, 0x1908082b), + uvec2(0x0819082b, 0x1908082b), uvec2(0x082b2b19, 0x1908082b), uvec2(0x19080808, 0x1908082b), uvec2(0x08080808, 0x19081908), + uvec2(0x0808082b, 0x19081908), uvec2(0x08081919, 0x19081908), uvec2(0x08082b08, 0x19081908), uvec2(0x08190819, 0x19081908), + uvec2(0x08191908, 0x19081908), uvec2(0x08192b19, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x19080819, 0x19081908), + uvec2(0x19081908, 0x19081908), uvec2(0x19190808, 0x19081908), uvec2(0x2b080808, 0x19081908), uvec2(0x2b191908, 0x19081908), + uvec2(0x08080819, 0x19081919), uvec2(0x08081908, 0x19081919), uvec2(0x08190808, 0x19081919), uvec2(0x082b1908, 0x19081919), + uvec2(0x19080808, 0x19081919), uvec2(0x2b192b2b, 0x19081919), uvec2(0x08080808, 0x1908192b), uvec2(0x08082b2b, 0x1908192b), + uvec2(0x19081908, 0x1908192b), uvec2(0x19190808, 0x1908192b), uvec2(0x08080819, 0x19082b08), uvec2(0x08081908, 0x19082b08), + uvec2(0x08190808, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x19081919, 0x19082b08), uvec2(0x19191908, 0x19082b08), + uvec2(0x192b082b, 0x19082b08), uvec2(0x08080808, 0x19082b19), uvec2(0x08190819, 0x19082b19), uvec2(0x19081908, 0x19082b19), + uvec2(0x19190808, 0x19082b19), uvec2(0x192b2b19, 0x19082b19), uvec2(0x08081908, 0x19082b2b), uvec2(0x08080808, 0x19190808), + uvec2(0x0808082b, 0x19190808), uvec2(0x08081919, 0x19190808), uvec2(0x08082b08, 0x19190808), uvec2(0x08190819, 0x19190808), + uvec2(0x08191908, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x082b2b08, 0x19190808), uvec2(0x19080819, 0x19190808), + uvec2(0x19081908, 0x19190808), uvec2(0x19190808, 0x19190808), uvec2(0x2b080808, 0x19190808), uvec2(0x08080819, 0x19190819), + uvec2(0x08081908, 0x19190819), uvec2(0x08190808, 0x19190819), uvec2(0x08191919, 0x19190819), uvec2(0x19080808, 0x19190819), + uvec2(0x1908082b, 0x19190819), uvec2(0x08080808, 0x1919082b), uvec2(0x19081908, 0x1919082b), uvec2(0x2b2b2b2b, 0x1919082b), + uvec2(0x08080819, 0x19191908), uvec2(0x08081908, 0x19191908), uvec2(0x08190808, 0x19191908), uvec2(0x082b0819, 0x19191908), + uvec2(0x19080808, 0x19191908), uvec2(0x192b0808, 0x19191908), uvec2(0x2b080819, 0x19191908), uvec2(0x2b2b0819, 0x19191908), + uvec2(0x08080808, 0x19191919), uvec2(0x08082b08, 0x19191919), uvec2(0x2b080808, 0x19191919), uvec2(0x2b082b08, 0x19191919), + uvec2(0x082b0819, 0x1919192b), uvec2(0x192b2b08, 0x1919192b), uvec2(0x2b2b0819, 0x1919192b), uvec2(0x08080808, 0x19192b08), + uvec2(0x08191908, 0x19192b08), uvec2(0x19080819, 0x19192b08), uvec2(0x19190808, 0x19192b08), uvec2(0x2b192b19, 0x19192b08), + uvec2(0x08192b2b, 0x19192b19), uvec2(0x19080808, 0x19192b19), uvec2(0x1908082b, 0x19192b19), uvec2(0x2b081919, 0x19192b2b), + uvec2(0x08080819, 0x192b0808), uvec2(0x08081908, 0x192b0808), uvec2(0x08190808, 0x192b0808), uvec2(0x19080808, 0x192b0808), + uvec2(0x19191908, 0x192b0808), uvec2(0x192b082b, 0x192b0808), uvec2(0x2b08192b, 0x192b0808), uvec2(0x2b2b2b19, 0x192b0808), + uvec2(0x08080808, 0x192b0819), uvec2(0x082b1908, 0x192b082b), uvec2(0x19082b2b, 0x192b082b), uvec2(0x2b19082b, 0x192b082b), + uvec2(0x08080808, 0x192b1908), uvec2(0x0819192b, 0x192b1908), uvec2(0x08190808, 0x192b1919), uvec2(0x19080808, 0x192b1919), + uvec2(0x19081919, 0x192b1919), uvec2(0x2b2b1908, 0x192b1919), uvec2(0x08080819, 0x192b2b08), uvec2(0x192b2b2b, 0x192b2b08), + uvec2(0x082b1919, 0x192b2b19), uvec2(0x0808192b, 0x192b2b2b), uvec2(0x19191908, 0x192b2b2b), uvec2(0x192b082b, 0x192b2b2b), + uvec2(0x08080808, 0x2b080808), uvec2(0x0808082b, 0x2b080808), uvec2(0x08081919, 0x2b080808), uvec2(0x08082b08, 0x2b080808), + uvec2(0x08190819, 0x2b080808), uvec2(0x08191908, 0x2b080808), uvec2(0x082b0808, 0x2b080808), uvec2(0x082b2b2b, 0x2b080808), + uvec2(0x19080819, 0x2b080808), uvec2(0x19081908, 0x2b080808), uvec2(0x19190808, 0x2b080808), uvec2(0x2b080808, 0x2b080808), + uvec2(0x2b08082b, 0x2b080808), uvec2(0x2b2b2b08, 0x2b080808), uvec2(0x2b2b2b2b, 0x2b080808), uvec2(0x08080819, 0x2b080819), + uvec2(0x08081908, 0x2b080819), uvec2(0x0808192b, 0x2b080819), uvec2(0x08190808, 0x2b080819), uvec2(0x19080808, 0x2b080819), + uvec2(0x19190819, 0x2b080819), uvec2(0x19192b19, 0x2b080819), uvec2(0x08080808, 0x2b08082b), uvec2(0x082b0808, 0x2b08082b), + uvec2(0x2b080808, 0x2b08082b), uvec2(0x2b08082b, 0x2b08082b), uvec2(0x2b2b0808, 0x2b08082b), uvec2(0x2b2b2b08, 0x2b08082b), + uvec2(0x08080819, 0x2b081908), uvec2(0x08081908, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x0819082b, 0x2b081908), + uvec2(0x08191919, 0x2b081908), uvec2(0x19080808, 0x2b081908), uvec2(0x192b0808, 0x2b081908), uvec2(0x2b082b19, 0x2b081908), + uvec2(0x08080808, 0x2b081919), uvec2(0x19081908, 0x2b081919), uvec2(0x2b2b1919, 0x2b081919), uvec2(0x08192b08, 0x2b08192b), + uvec2(0x192b2b2b, 0x2b08192b), uvec2(0x08080808, 0x2b082b08), uvec2(0x08082b08, 0x2b082b08), uvec2(0x082b1919, 0x2b082b08), + uvec2(0x19192b2b, 0x2b082b08), uvec2(0x2b080808, 0x2b082b08), uvec2(0x2b08082b, 0x2b082b08), uvec2(0x2b2b2b08, 0x2b082b08), + uvec2(0x0808192b, 0x2b082b19), uvec2(0x082b082b, 0x2b082b2b), uvec2(0x2b080808, 0x2b082b2b), uvec2(0x2b082b08, 0x2b082b2b), + uvec2(0x2b19192b, 0x2b082b2b), uvec2(0x2b2b2b08, 0x2b082b2b), uvec2(0x08080819, 0x2b190808), uvec2(0x08081908, 0x2b190808), + uvec2(0x08190808, 0x2b190808), uvec2(0x19080808, 0x2b190808), uvec2(0x1919192b, 0x2b190808), uvec2(0x2b081908, 0x2b190808), + uvec2(0x08080808, 0x2b190819), uvec2(0x082b082b, 0x2b190819), uvec2(0x192b1908, 0x2b190819), uvec2(0x1919192b, 0x2b19082b), + uvec2(0x2b082b19, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x08081919, 0x2b191908), uvec2(0x19081908, 0x2b191908), + uvec2(0x19190808, 0x2b191908), uvec2(0x19192b08, 0x2b191908), uvec2(0x082b2b19, 0x2b191919), uvec2(0x2b190808, 0x2b191919), + uvec2(0x2b19082b, 0x2b191919), uvec2(0x19080819, 0x2b19192b), uvec2(0x19190819, 0x2b192b08), uvec2(0x2b2b192b, 0x2b192b08), + uvec2(0x19082b19, 0x2b192b19), uvec2(0x08191919, 0x2b192b2b), uvec2(0x192b0808, 0x2b192b2b), uvec2(0x08080808, 0x2b2b0808), + uvec2(0x0808082b, 0x2b2b0808), uvec2(0x08082b08, 0x2b2b0808), uvec2(0x08082b2b, 0x2b2b0808), uvec2(0x082b0808, 0x2b2b0808), + uvec2(0x082b2b2b, 0x2b2b0808), uvec2(0x2b2b0808, 0x2b2b0808), uvec2(0x19190819, 0x2b2b0819), uvec2(0x19192b19, 0x2b2b0819), + uvec2(0x2b2b192b, 0x2b2b0819), uvec2(0x08080808, 0x2b2b082b), uvec2(0x0808082b, 0x2b2b082b), uvec2(0x08082b08, 0x2b2b082b), + uvec2(0x082b2b2b, 0x2b2b082b), uvec2(0x2b080808, 0x2b2b082b), uvec2(0x2b2b0808, 0x2b2b082b), uvec2(0x19080808, 0x2b2b1908), + uvec2(0x2b191919, 0x2b2b1908), uvec2(0x192b1919, 0x2b2b192b), uvec2(0x2b192b08, 0x2b2b192b), uvec2(0x08082b2b, 0x2b2b2b08), + uvec2(0x082b0808, 0x2b2b2b08), uvec2(0x082b082b, 0x2b2b2b08), uvec2(0x082b2b08, 0x2b2b2b08), uvec2(0x2b2b0808, 0x2b2b2b08), + uvec2(0x2b2b2b08, 0x2b2b2b08), uvec2(0x08081908, 0x2b2b2b19), uvec2(0x2b081908, 0x2b2b2b19), uvec2(0x2b08192b, 0x2b2b2b19), + uvec2(0x082b2b08, 0x2b2b2b2b), uvec2(0x082b2b2b, 0x2b2b2b2b), uvec2(0x2b190819, 0x2b2b2b2b), uvec2(0x2b2b2b2b, 0x2b2b2b2b), +}; + +shared uvec2 iq2xs_grid[512]; + +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < iq2xs_grid.length(); i += wgsize.x) { + iq2xs_grid[i] = iq2xs_grid_const[i]; + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ2_XS +#define QUANT_R QUANT_R_IQ2_XS +#define A_TYPE block_iq2_xs +#define A_TYPE_PACKED16 block_iq2_xs_packed16 +#endif + +#define QUANT_K_IQ2_S 256 +#define QUANT_R_IQ2_S 1 + +struct block_iq2_s +{ + float16_t d; + uint8_t qs[QUANT_K_IQ2_S/4]; + uint8_t qh[QUANT_K_IQ2_S/32]; + uint8_t scales[QUANT_K_IQ2_S/32]; +}; + +#if defined(DATA_A_IQ2_S) + +const uvec2 iq2s_grid_const[1024] = { + uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808), + uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x0819192b, 0x08080808), + uvec2(0x08192b19, 0x08080808), uvec2(0x082b0808, 0x08080808), uvec2(0x082b082b, 0x08080808), uvec2(0x082b1919, 0x08080808), + uvec2(0x082b2b08, 0x08080808), uvec2(0x19080819, 0x08080808), uvec2(0x19081908, 0x08080808), uvec2(0x1908192b, 0x08080808), + uvec2(0x19082b19, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x1919082b, 0x08080808), uvec2(0x19191919, 0x08080808), + uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), uvec2(0x192b1908, 0x08080808), uvec2(0x192b192b, 0x08080808), + uvec2(0x192b2b19, 0x08080808), uvec2(0x2b080808, 0x08080808), uvec2(0x2b08082b, 0x08080808), uvec2(0x2b081919, 0x08080808), + uvec2(0x2b082b08, 0x08080808), uvec2(0x2b190819, 0x08080808), uvec2(0x2b191908, 0x08080808), uvec2(0x2b2b0808, 0x08080808), + uvec2(0x2b2b1919, 0x08080808), uvec2(0x2b2b2b2b, 0x08080808), uvec2(0x08080819, 0x08080819), uvec2(0x08081908, 0x08080819), + uvec2(0x0808192b, 0x08080819), uvec2(0x08082b19, 0x08080819), uvec2(0x08190808, 0x08080819), uvec2(0x0819082b, 0x08080819), + uvec2(0x08191919, 0x08080819), uvec2(0x08192b08, 0x08080819), uvec2(0x082b0819, 0x08080819), uvec2(0x082b1908, 0x08080819), + uvec2(0x19080808, 0x08080819), uvec2(0x1908082b, 0x08080819), uvec2(0x19081919, 0x08080819), uvec2(0x19082b08, 0x08080819), + uvec2(0x19190819, 0x08080819), uvec2(0x19191908, 0x08080819), uvec2(0x1919192b, 0x08080819), uvec2(0x19192b19, 0x08080819), + uvec2(0x192b0808, 0x08080819), uvec2(0x192b1919, 0x08080819), uvec2(0x192b2b08, 0x08080819), uvec2(0x2b080819, 0x08080819), + uvec2(0x2b081908, 0x08080819), uvec2(0x2b190808, 0x08080819), uvec2(0x2b19082b, 0x08080819), uvec2(0x2b191919, 0x08080819), + uvec2(0x2b2b0819, 0x08080819), uvec2(0x2b2b1908, 0x08080819), uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), + uvec2(0x08081919, 0x0808082b), uvec2(0x08082b08, 0x0808082b), uvec2(0x08190819, 0x0808082b), uvec2(0x08191908, 0x0808082b), + uvec2(0x082b0808, 0x0808082b), uvec2(0x082b2b2b, 0x0808082b), uvec2(0x19080819, 0x0808082b), uvec2(0x19081908, 0x0808082b), + uvec2(0x1908192b, 0x0808082b), uvec2(0x19082b19, 0x0808082b), uvec2(0x19190808, 0x0808082b), uvec2(0x19191919, 0x0808082b), + uvec2(0x2b080808, 0x0808082b), uvec2(0x2b081919, 0x0808082b), uvec2(0x2b082b2b, 0x0808082b), uvec2(0x2b191908, 0x0808082b), + uvec2(0x2b2b082b, 0x0808082b), uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), uvec2(0x0808192b, 0x08081908), + uvec2(0x08082b19, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x0819082b, 0x08081908), uvec2(0x08191919, 0x08081908), + uvec2(0x08192b08, 0x08081908), uvec2(0x082b0819, 0x08081908), uvec2(0x082b1908, 0x08081908), uvec2(0x082b192b, 0x08081908), + uvec2(0x082b2b19, 0x08081908), uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19081919, 0x08081908), + uvec2(0x19082b08, 0x08081908), uvec2(0x19082b2b, 0x08081908), uvec2(0x19190819, 0x08081908), uvec2(0x19191908, 0x08081908), + uvec2(0x1919192b, 0x08081908), uvec2(0x19192b19, 0x08081908), uvec2(0x192b0808, 0x08081908), uvec2(0x192b082b, 0x08081908), + uvec2(0x192b1919, 0x08081908), uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b08192b, 0x08081908), + uvec2(0x2b082b19, 0x08081908), uvec2(0x2b190808, 0x08081908), uvec2(0x2b191919, 0x08081908), uvec2(0x2b192b08, 0x08081908), + uvec2(0x2b2b0819, 0x08081908), uvec2(0x2b2b1908, 0x08081908), uvec2(0x08080808, 0x08081919), uvec2(0x0808082b, 0x08081919), + uvec2(0x08081919, 0x08081919), uvec2(0x08082b08, 0x08081919), uvec2(0x08082b2b, 0x08081919), uvec2(0x08190819, 0x08081919), + uvec2(0x08191908, 0x08081919), uvec2(0x0819192b, 0x08081919), uvec2(0x08192b19, 0x08081919), uvec2(0x082b0808, 0x08081919), + uvec2(0x082b1919, 0x08081919), uvec2(0x082b2b08, 0x08081919), uvec2(0x19080819, 0x08081919), uvec2(0x19081908, 0x08081919), + uvec2(0x1908192b, 0x08081919), uvec2(0x19082b19, 0x08081919), uvec2(0x19190808, 0x08081919), uvec2(0x1919082b, 0x08081919), + uvec2(0x19191919, 0x08081919), uvec2(0x19192b08, 0x08081919), uvec2(0x192b0819, 0x08081919), uvec2(0x192b1908, 0x08081919), + uvec2(0x2b080808, 0x08081919), uvec2(0x2b08082b, 0x08081919), uvec2(0x2b081919, 0x08081919), uvec2(0x2b082b08, 0x08081919), + uvec2(0x2b190819, 0x08081919), uvec2(0x2b191908, 0x08081919), uvec2(0x2b2b0808, 0x08081919), uvec2(0x08080819, 0x0808192b), + uvec2(0x08081908, 0x0808192b), uvec2(0x0808192b, 0x0808192b), uvec2(0x08082b19, 0x0808192b), uvec2(0x08190808, 0x0808192b), + uvec2(0x08191919, 0x0808192b), uvec2(0x19080808, 0x0808192b), uvec2(0x19081919, 0x0808192b), uvec2(0x19082b08, 0x0808192b), + uvec2(0x19190819, 0x0808192b), uvec2(0x19191908, 0x0808192b), uvec2(0x192b0808, 0x0808192b), uvec2(0x2b080819, 0x0808192b), + uvec2(0x2b081908, 0x0808192b), uvec2(0x2b190808, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x0808082b, 0x08082b08), + uvec2(0x08081919, 0x08082b08), uvec2(0x08082b08, 0x08082b08), uvec2(0x08190819, 0x08082b08), uvec2(0x08191908, 0x08082b08), + uvec2(0x0819192b, 0x08082b08), uvec2(0x08192b19, 0x08082b08), uvec2(0x082b0808, 0x08082b08), uvec2(0x082b1919, 0x08082b08), + uvec2(0x082b2b2b, 0x08082b08), uvec2(0x19080819, 0x08082b08), uvec2(0x19081908, 0x08082b08), uvec2(0x1908192b, 0x08082b08), + uvec2(0x19082b19, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x1919082b, 0x08082b08), uvec2(0x19191919, 0x08082b08), + uvec2(0x19192b08, 0x08082b08), uvec2(0x192b0819, 0x08082b08), uvec2(0x192b1908, 0x08082b08), uvec2(0x2b080808, 0x08082b08), + uvec2(0x2b081919, 0x08082b08), uvec2(0x2b191908, 0x08082b08), uvec2(0x2b2b2b2b, 0x08082b08), uvec2(0x08080819, 0x08082b19), + uvec2(0x08081908, 0x08082b19), uvec2(0x08190808, 0x08082b19), uvec2(0x0819082b, 0x08082b19), uvec2(0x08191919, 0x08082b19), + uvec2(0x08192b08, 0x08082b19), uvec2(0x082b0819, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x19081919, 0x08082b19), + uvec2(0x19082b08, 0x08082b19), uvec2(0x19190819, 0x08082b19), uvec2(0x19191908, 0x08082b19), uvec2(0x192b0808, 0x08082b19), + uvec2(0x2b080819, 0x08082b19), uvec2(0x2b190808, 0x08082b19), uvec2(0x08080808, 0x08082b2b), uvec2(0x08190819, 0x08082b2b), + uvec2(0x08191908, 0x08082b2b), uvec2(0x082b082b, 0x08082b2b), uvec2(0x082b2b08, 0x08082b2b), uvec2(0x082b2b2b, 0x08082b2b), + uvec2(0x19190808, 0x08082b2b), uvec2(0x2b192b19, 0x08082b2b), uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), + uvec2(0x0808192b, 0x08190808), uvec2(0x08082b19, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x0819082b, 0x08190808), + uvec2(0x08191919, 0x08190808), uvec2(0x08192b08, 0x08190808), uvec2(0x082b0819, 0x08190808), uvec2(0x082b1908, 0x08190808), + uvec2(0x082b192b, 0x08190808), uvec2(0x19080808, 0x08190808), uvec2(0x1908082b, 0x08190808), uvec2(0x19081919, 0x08190808), + uvec2(0x19082b08, 0x08190808), uvec2(0x19190819, 0x08190808), uvec2(0x19191908, 0x08190808), uvec2(0x1919192b, 0x08190808), + uvec2(0x19192b19, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x192b082b, 0x08190808), uvec2(0x192b1919, 0x08190808), + uvec2(0x192b2b08, 0x08190808), uvec2(0x2b080819, 0x08190808), uvec2(0x2b081908, 0x08190808), uvec2(0x2b08192b, 0x08190808), + uvec2(0x2b190808, 0x08190808), uvec2(0x2b191919, 0x08190808), uvec2(0x2b192b08, 0x08190808), uvec2(0x2b2b0819, 0x08190808), + uvec2(0x2b2b1908, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x0808082b, 0x08190819), uvec2(0x08081919, 0x08190819), + uvec2(0x08082b08, 0x08190819), uvec2(0x08082b2b, 0x08190819), uvec2(0x08190819, 0x08190819), uvec2(0x08191908, 0x08190819), + uvec2(0x0819192b, 0x08190819), uvec2(0x08192b19, 0x08190819), uvec2(0x082b0808, 0x08190819), uvec2(0x082b082b, 0x08190819), + uvec2(0x082b1919, 0x08190819), uvec2(0x082b2b08, 0x08190819), uvec2(0x19080819, 0x08190819), uvec2(0x19081908, 0x08190819), + uvec2(0x1908192b, 0x08190819), uvec2(0x19082b19, 0x08190819), uvec2(0x19190808, 0x08190819), uvec2(0x1919082b, 0x08190819), + uvec2(0x19191919, 0x08190819), uvec2(0x19192b08, 0x08190819), uvec2(0x192b0819, 0x08190819), uvec2(0x192b1908, 0x08190819), + uvec2(0x2b080808, 0x08190819), uvec2(0x2b08082b, 0x08190819), uvec2(0x2b081919, 0x08190819), uvec2(0x2b082b08, 0x08190819), + uvec2(0x2b190819, 0x08190819), uvec2(0x2b191908, 0x08190819), uvec2(0x08080819, 0x0819082b), uvec2(0x08081908, 0x0819082b), + uvec2(0x08082b19, 0x0819082b), uvec2(0x08190808, 0x0819082b), uvec2(0x08191919, 0x0819082b), uvec2(0x082b0819, 0x0819082b), + uvec2(0x082b1908, 0x0819082b), uvec2(0x19080808, 0x0819082b), uvec2(0x19081919, 0x0819082b), uvec2(0x19190819, 0x0819082b), + uvec2(0x19191908, 0x0819082b), uvec2(0x2b080819, 0x0819082b), uvec2(0x2b081908, 0x0819082b), uvec2(0x2b190808, 0x0819082b), + uvec2(0x08080808, 0x08191908), uvec2(0x0808082b, 0x08191908), uvec2(0x08081919, 0x08191908), uvec2(0x08082b08, 0x08191908), + uvec2(0x08190819, 0x08191908), uvec2(0x08191908, 0x08191908), uvec2(0x0819192b, 0x08191908), uvec2(0x08192b19, 0x08191908), + uvec2(0x082b0808, 0x08191908), uvec2(0x082b1919, 0x08191908), uvec2(0x082b2b08, 0x08191908), uvec2(0x19080819, 0x08191908), + uvec2(0x19081908, 0x08191908), uvec2(0x1908192b, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x19190808, 0x08191908), + uvec2(0x1919082b, 0x08191908), uvec2(0x19191919, 0x08191908), uvec2(0x19192b08, 0x08191908), uvec2(0x192b0819, 0x08191908), + uvec2(0x192b1908, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x2b08082b, 0x08191908), uvec2(0x2b081919, 0x08191908), + uvec2(0x2b082b08, 0x08191908), uvec2(0x2b190819, 0x08191908), uvec2(0x2b191908, 0x08191908), uvec2(0x2b2b0808, 0x08191908), + uvec2(0x08080819, 0x08191919), uvec2(0x08081908, 0x08191919), uvec2(0x0808192b, 0x08191919), uvec2(0x08082b19, 0x08191919), + uvec2(0x08190808, 0x08191919), uvec2(0x0819082b, 0x08191919), uvec2(0x08191919, 0x08191919), uvec2(0x08192b08, 0x08191919), + uvec2(0x082b0819, 0x08191919), uvec2(0x082b1908, 0x08191919), uvec2(0x19080808, 0x08191919), uvec2(0x1908082b, 0x08191919), + uvec2(0x19081919, 0x08191919), uvec2(0x19082b08, 0x08191919), uvec2(0x19190819, 0x08191919), uvec2(0x19191908, 0x08191919), + uvec2(0x192b0808, 0x08191919), uvec2(0x2b080819, 0x08191919), uvec2(0x2b081908, 0x08191919), uvec2(0x2b190808, 0x08191919), + uvec2(0x08080808, 0x0819192b), uvec2(0x08081919, 0x0819192b), uvec2(0x08082b08, 0x0819192b), uvec2(0x08190819, 0x0819192b), + uvec2(0x08191908, 0x0819192b), uvec2(0x082b0808, 0x0819192b), uvec2(0x19080819, 0x0819192b), uvec2(0x19081908, 0x0819192b), + uvec2(0x19190808, 0x0819192b), uvec2(0x2b080808, 0x0819192b), uvec2(0x2b2b2b2b, 0x0819192b), uvec2(0x08080819, 0x08192b08), + uvec2(0x08081908, 0x08192b08), uvec2(0x0808192b, 0x08192b08), uvec2(0x08082b19, 0x08192b08), uvec2(0x08190808, 0x08192b08), + uvec2(0x08191919, 0x08192b08), uvec2(0x08192b08, 0x08192b08), uvec2(0x082b0819, 0x08192b08), uvec2(0x19080808, 0x08192b08), + uvec2(0x1908082b, 0x08192b08), uvec2(0x19081919, 0x08192b08), uvec2(0x19082b08, 0x08192b08), uvec2(0x19190819, 0x08192b08), + uvec2(0x19191908, 0x08192b08), uvec2(0x192b0808, 0x08192b08), uvec2(0x2b080819, 0x08192b08), uvec2(0x2b081908, 0x08192b08), + uvec2(0x08080808, 0x08192b19), uvec2(0x0808082b, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x08082b08, 0x08192b19), + uvec2(0x08190819, 0x08192b19), uvec2(0x08191908, 0x08192b19), uvec2(0x082b0808, 0x08192b19), uvec2(0x19080819, 0x08192b19), + uvec2(0x19081908, 0x08192b19), uvec2(0x19190808, 0x08192b19), uvec2(0x192b2b19, 0x08192b19), uvec2(0x2b2b082b, 0x08192b19), + uvec2(0x08081908, 0x08192b2b), uvec2(0x08190808, 0x08192b2b), uvec2(0x19080808, 0x08192b2b), uvec2(0x1919192b, 0x08192b2b), + uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08081919, 0x082b0808), uvec2(0x08082b08, 0x082b0808), + uvec2(0x08190819, 0x082b0808), uvec2(0x08191908, 0x082b0808), uvec2(0x0819192b, 0x082b0808), uvec2(0x08192b19, 0x082b0808), + uvec2(0x082b0808, 0x082b0808), uvec2(0x082b1919, 0x082b0808), uvec2(0x082b2b2b, 0x082b0808), uvec2(0x19080819, 0x082b0808), + uvec2(0x19081908, 0x082b0808), uvec2(0x19190808, 0x082b0808), uvec2(0x1919082b, 0x082b0808), uvec2(0x19191919, 0x082b0808), + uvec2(0x192b1908, 0x082b0808), uvec2(0x2b080808, 0x082b0808), uvec2(0x2b082b2b, 0x082b0808), uvec2(0x2b191908, 0x082b0808), + uvec2(0x2b2b2b2b, 0x082b0808), uvec2(0x08080819, 0x082b0819), uvec2(0x08081908, 0x082b0819), uvec2(0x08190808, 0x082b0819), + uvec2(0x0819082b, 0x082b0819), uvec2(0x08191919, 0x082b0819), uvec2(0x082b0819, 0x082b0819), uvec2(0x19080808, 0x082b0819), + uvec2(0x1908082b, 0x082b0819), uvec2(0x19081919, 0x082b0819), uvec2(0x19190819, 0x082b0819), uvec2(0x19191908, 0x082b0819), + uvec2(0x192b0808, 0x082b0819), uvec2(0x2b080819, 0x082b0819), uvec2(0x2b081908, 0x082b0819), uvec2(0x2b190808, 0x082b0819), + uvec2(0x08080808, 0x082b082b), uvec2(0x08082b2b, 0x082b082b), uvec2(0x082b082b, 0x082b082b), uvec2(0x082b2b08, 0x082b082b), + uvec2(0x082b2b2b, 0x082b082b), uvec2(0x19081908, 0x082b082b), uvec2(0x19190808, 0x082b082b), uvec2(0x2b082b08, 0x082b082b), + uvec2(0x2b082b2b, 0x082b082b), uvec2(0x2b2b2b08, 0x082b082b), uvec2(0x08080819, 0x082b1908), uvec2(0x08081908, 0x082b1908), + uvec2(0x0808192b, 0x082b1908), uvec2(0x08082b19, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x08191919, 0x082b1908), + uvec2(0x08192b08, 0x082b1908), uvec2(0x082b0819, 0x082b1908), uvec2(0x082b1908, 0x082b1908), uvec2(0x19080808, 0x082b1908), + uvec2(0x1908082b, 0x082b1908), uvec2(0x19081919, 0x082b1908), uvec2(0x19082b08, 0x082b1908), uvec2(0x19190819, 0x082b1908), + uvec2(0x19191908, 0x082b1908), uvec2(0x192b0808, 0x082b1908), uvec2(0x2b080819, 0x082b1908), uvec2(0x2b081908, 0x082b1908), + uvec2(0x2b190808, 0x082b1908), uvec2(0x08080808, 0x082b1919), uvec2(0x08081919, 0x082b1919), uvec2(0x08082b08, 0x082b1919), + uvec2(0x08190819, 0x082b1919), uvec2(0x08191908, 0x082b1919), uvec2(0x082b0808, 0x082b1919), uvec2(0x19080819, 0x082b1919), + uvec2(0x19081908, 0x082b1919), uvec2(0x19190808, 0x082b1919), uvec2(0x192b192b, 0x082b1919), uvec2(0x2b080808, 0x082b1919), + uvec2(0x08080819, 0x082b192b), uvec2(0x08081908, 0x082b192b), uvec2(0x08190808, 0x082b192b), uvec2(0x19080808, 0x082b192b), + uvec2(0x19192b19, 0x082b192b), uvec2(0x08080808, 0x082b2b08), uvec2(0x08081919, 0x082b2b08), uvec2(0x08190819, 0x082b2b08), + uvec2(0x08191908, 0x082b2b08), uvec2(0x19080819, 0x082b2b08), uvec2(0x19081908, 0x082b2b08), uvec2(0x19190808, 0x082b2b08), + uvec2(0x2b082b2b, 0x082b2b08), uvec2(0x2b2b2b2b, 0x082b2b08), uvec2(0x08080819, 0x082b2b19), uvec2(0x08081908, 0x082b2b19), + uvec2(0x08190808, 0x082b2b19), uvec2(0x2b191919, 0x082b2b19), uvec2(0x08082b2b, 0x082b2b2b), uvec2(0x082b082b, 0x082b2b2b), + uvec2(0x192b1908, 0x082b2b2b), uvec2(0x2b082b08, 0x082b2b2b), uvec2(0x2b082b2b, 0x082b2b2b), uvec2(0x08080819, 0x19080808), + uvec2(0x08081908, 0x19080808), uvec2(0x0808192b, 0x19080808), uvec2(0x08082b19, 0x19080808), uvec2(0x08190808, 0x19080808), + uvec2(0x0819082b, 0x19080808), uvec2(0x08191919, 0x19080808), uvec2(0x08192b08, 0x19080808), uvec2(0x08192b2b, 0x19080808), + uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), uvec2(0x082b192b, 0x19080808), uvec2(0x19080808, 0x19080808), + uvec2(0x1908082b, 0x19080808), uvec2(0x19081919, 0x19080808), uvec2(0x19082b08, 0x19080808), uvec2(0x19082b2b, 0x19080808), + uvec2(0x19190819, 0x19080808), uvec2(0x19191908, 0x19080808), uvec2(0x1919192b, 0x19080808), uvec2(0x19192b19, 0x19080808), + uvec2(0x192b0808, 0x19080808), uvec2(0x192b082b, 0x19080808), uvec2(0x192b1919, 0x19080808), uvec2(0x2b080819, 0x19080808), + uvec2(0x2b081908, 0x19080808), uvec2(0x2b190808, 0x19080808), uvec2(0x2b191919, 0x19080808), uvec2(0x2b192b08, 0x19080808), + uvec2(0x2b2b0819, 0x19080808), uvec2(0x2b2b1908, 0x19080808), uvec2(0x08080808, 0x19080819), uvec2(0x0808082b, 0x19080819), + uvec2(0x08081919, 0x19080819), uvec2(0x08082b08, 0x19080819), uvec2(0x08190819, 0x19080819), uvec2(0x08191908, 0x19080819), + uvec2(0x0819192b, 0x19080819), uvec2(0x08192b19, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x082b082b, 0x19080819), + uvec2(0x082b1919, 0x19080819), uvec2(0x19080819, 0x19080819), uvec2(0x19081908, 0x19080819), uvec2(0x1908192b, 0x19080819), + uvec2(0x19082b19, 0x19080819), uvec2(0x19190808, 0x19080819), uvec2(0x1919082b, 0x19080819), uvec2(0x19191919, 0x19080819), + uvec2(0x19192b08, 0x19080819), uvec2(0x192b0819, 0x19080819), uvec2(0x192b1908, 0x19080819), uvec2(0x2b080808, 0x19080819), + uvec2(0x2b08082b, 0x19080819), uvec2(0x2b081919, 0x19080819), uvec2(0x2b082b08, 0x19080819), uvec2(0x2b190819, 0x19080819), + uvec2(0x2b191908, 0x19080819), uvec2(0x2b2b0808, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08081908, 0x1908082b), + uvec2(0x08190808, 0x1908082b), uvec2(0x0819082b, 0x1908082b), uvec2(0x08191919, 0x1908082b), uvec2(0x08192b08, 0x1908082b), + uvec2(0x082b1908, 0x1908082b), uvec2(0x19080808, 0x1908082b), uvec2(0x19081919, 0x1908082b), uvec2(0x19082b08, 0x1908082b), + uvec2(0x19190819, 0x1908082b), uvec2(0x19191908, 0x1908082b), uvec2(0x192b0808, 0x1908082b), uvec2(0x2b080819, 0x1908082b), + uvec2(0x2b081908, 0x1908082b), uvec2(0x08080808, 0x19081908), uvec2(0x0808082b, 0x19081908), uvec2(0x08081919, 0x19081908), + uvec2(0x08082b08, 0x19081908), uvec2(0x08082b2b, 0x19081908), uvec2(0x08190819, 0x19081908), uvec2(0x08191908, 0x19081908), + uvec2(0x0819192b, 0x19081908), uvec2(0x08192b19, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x082b082b, 0x19081908), + uvec2(0x082b1919, 0x19081908), uvec2(0x082b2b08, 0x19081908), uvec2(0x19080819, 0x19081908), uvec2(0x19081908, 0x19081908), + uvec2(0x1908192b, 0x19081908), uvec2(0x19082b19, 0x19081908), uvec2(0x19190808, 0x19081908), uvec2(0x1919082b, 0x19081908), + uvec2(0x19191919, 0x19081908), uvec2(0x19192b08, 0x19081908), uvec2(0x192b0819, 0x19081908), uvec2(0x192b1908, 0x19081908), + uvec2(0x2b080808, 0x19081908), uvec2(0x2b08082b, 0x19081908), uvec2(0x2b081919, 0x19081908), uvec2(0x2b082b08, 0x19081908), + uvec2(0x2b190819, 0x19081908), uvec2(0x2b191908, 0x19081908), uvec2(0x2b2b0808, 0x19081908), uvec2(0x08080819, 0x19081919), + uvec2(0x08081908, 0x19081919), uvec2(0x0808192b, 0x19081919), uvec2(0x08082b19, 0x19081919), uvec2(0x08190808, 0x19081919), + uvec2(0x0819082b, 0x19081919), uvec2(0x08191919, 0x19081919), uvec2(0x08192b08, 0x19081919), uvec2(0x082b0819, 0x19081919), + uvec2(0x082b1908, 0x19081919), uvec2(0x19080808, 0x19081919), uvec2(0x1908082b, 0x19081919), uvec2(0x19081919, 0x19081919), + uvec2(0x19082b08, 0x19081919), uvec2(0x19190819, 0x19081919), uvec2(0x19191908, 0x19081919), uvec2(0x192b0808, 0x19081919), + uvec2(0x192b2b2b, 0x19081919), uvec2(0x2b080819, 0x19081919), uvec2(0x2b081908, 0x19081919), uvec2(0x2b190808, 0x19081919), + uvec2(0x08080808, 0x1908192b), uvec2(0x0808082b, 0x1908192b), uvec2(0x08081919, 0x1908192b), uvec2(0x08082b08, 0x1908192b), + uvec2(0x08190819, 0x1908192b), uvec2(0x08191908, 0x1908192b), uvec2(0x082b0808, 0x1908192b), uvec2(0x19080819, 0x1908192b), + uvec2(0x19081908, 0x1908192b), uvec2(0x19190808, 0x1908192b), uvec2(0x2b080808, 0x1908192b), uvec2(0x2b2b1919, 0x1908192b), + uvec2(0x08080819, 0x19082b08), uvec2(0x08081908, 0x19082b08), uvec2(0x08082b19, 0x19082b08), uvec2(0x08190808, 0x19082b08), + uvec2(0x0819082b, 0x19082b08), uvec2(0x08191919, 0x19082b08), uvec2(0x08192b08, 0x19082b08), uvec2(0x082b0819, 0x19082b08), + uvec2(0x082b1908, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x1908082b, 0x19082b08), uvec2(0x19081919, 0x19082b08), + uvec2(0x19082b08, 0x19082b08), uvec2(0x19190819, 0x19082b08), uvec2(0x19191908, 0x19082b08), uvec2(0x192b0808, 0x19082b08), + uvec2(0x2b081908, 0x19082b08), uvec2(0x2b190808, 0x19082b08), uvec2(0x08080808, 0x19082b19), uvec2(0x0808082b, 0x19082b19), + uvec2(0x08081919, 0x19082b19), uvec2(0x08082b08, 0x19082b19), uvec2(0x08190819, 0x19082b19), uvec2(0x08191908, 0x19082b19), + uvec2(0x082b0808, 0x19082b19), uvec2(0x19080819, 0x19082b19), uvec2(0x19081908, 0x19082b19), uvec2(0x19190808, 0x19082b19), + uvec2(0x2b080808, 0x19082b19), uvec2(0x2b19192b, 0x19082b19), uvec2(0x08080819, 0x19082b2b), uvec2(0x08081908, 0x19082b2b), + uvec2(0x08190808, 0x19082b2b), uvec2(0x19080808, 0x19082b2b), uvec2(0x08080808, 0x19190808), uvec2(0x0808082b, 0x19190808), + uvec2(0x08081919, 0x19190808), uvec2(0x08082b08, 0x19190808), uvec2(0x08190819, 0x19190808), uvec2(0x08191908, 0x19190808), + uvec2(0x0819192b, 0x19190808), uvec2(0x08192b19, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x082b082b, 0x19190808), + uvec2(0x082b1919, 0x19190808), uvec2(0x082b2b08, 0x19190808), uvec2(0x19080819, 0x19190808), uvec2(0x19081908, 0x19190808), + uvec2(0x1908192b, 0x19190808), uvec2(0x19082b19, 0x19190808), uvec2(0x19190808, 0x19190808), uvec2(0x1919082b, 0x19190808), + uvec2(0x19191919, 0x19190808), uvec2(0x19192b08, 0x19190808), uvec2(0x192b0819, 0x19190808), uvec2(0x192b1908, 0x19190808), + uvec2(0x2b080808, 0x19190808), uvec2(0x2b08082b, 0x19190808), uvec2(0x2b081919, 0x19190808), uvec2(0x2b082b08, 0x19190808), + uvec2(0x2b190819, 0x19190808), uvec2(0x2b191908, 0x19190808), uvec2(0x08080819, 0x19190819), uvec2(0x08081908, 0x19190819), + uvec2(0x0808192b, 0x19190819), uvec2(0x08082b19, 0x19190819), uvec2(0x08190808, 0x19190819), uvec2(0x0819082b, 0x19190819), + uvec2(0x08191919, 0x19190819), uvec2(0x08192b08, 0x19190819), uvec2(0x082b0819, 0x19190819), uvec2(0x082b1908, 0x19190819), + uvec2(0x19080808, 0x19190819), uvec2(0x1908082b, 0x19190819), uvec2(0x19081919, 0x19190819), uvec2(0x19082b08, 0x19190819), + uvec2(0x19190819, 0x19190819), uvec2(0x19191908, 0x19190819), uvec2(0x192b0808, 0x19190819), uvec2(0x2b080819, 0x19190819), + uvec2(0x2b081908, 0x19190819), uvec2(0x2b190808, 0x19190819), uvec2(0x08080808, 0x1919082b), uvec2(0x08081919, 0x1919082b), + uvec2(0x08082b08, 0x1919082b), uvec2(0x08190819, 0x1919082b), uvec2(0x08191908, 0x1919082b), uvec2(0x082b0808, 0x1919082b), + uvec2(0x19080819, 0x1919082b), uvec2(0x19081908, 0x1919082b), uvec2(0x19190808, 0x1919082b), uvec2(0x192b2b19, 0x1919082b), + uvec2(0x2b080808, 0x1919082b), uvec2(0x08080819, 0x19191908), uvec2(0x08081908, 0x19191908), uvec2(0x0808192b, 0x19191908), + uvec2(0x08082b19, 0x19191908), uvec2(0x08190808, 0x19191908), uvec2(0x0819082b, 0x19191908), uvec2(0x08191919, 0x19191908), + uvec2(0x08192b08, 0x19191908), uvec2(0x082b0819, 0x19191908), uvec2(0x082b1908, 0x19191908), uvec2(0x19080808, 0x19191908), + uvec2(0x1908082b, 0x19191908), uvec2(0x19081919, 0x19191908), uvec2(0x19082b08, 0x19191908), uvec2(0x19190819, 0x19191908), + uvec2(0x19191908, 0x19191908), uvec2(0x192b0808, 0x19191908), uvec2(0x2b080819, 0x19191908), uvec2(0x2b081908, 0x19191908), + uvec2(0x2b190808, 0x19191908), uvec2(0x08080808, 0x19191919), uvec2(0x0808082b, 0x19191919), uvec2(0x08081919, 0x19191919), + uvec2(0x08082b08, 0x19191919), uvec2(0x08190819, 0x19191919), uvec2(0x08191908, 0x19191919), uvec2(0x082b0808, 0x19191919), + uvec2(0x19080819, 0x19191919), uvec2(0x19081908, 0x19191919), uvec2(0x19190808, 0x19191919), uvec2(0x2b080808, 0x19191919), + uvec2(0x08080819, 0x1919192b), uvec2(0x08081908, 0x1919192b), uvec2(0x08190808, 0x1919192b), uvec2(0x082b192b, 0x1919192b), + uvec2(0x19080808, 0x1919192b), uvec2(0x08080808, 0x19192b08), uvec2(0x0808082b, 0x19192b08), uvec2(0x08081919, 0x19192b08), + uvec2(0x08082b08, 0x19192b08), uvec2(0x08190819, 0x19192b08), uvec2(0x08191908, 0x19192b08), uvec2(0x082b0808, 0x19192b08), + uvec2(0x19080819, 0x19192b08), uvec2(0x19081908, 0x19192b08), uvec2(0x19190808, 0x19192b08), uvec2(0x19192b2b, 0x19192b08), + uvec2(0x2b080808, 0x19192b08), uvec2(0x08080819, 0x19192b19), uvec2(0x08081908, 0x19192b19), uvec2(0x08190808, 0x19192b19), + uvec2(0x19080808, 0x19192b19), uvec2(0x08080808, 0x19192b2b), uvec2(0x08192b19, 0x19192b2b), uvec2(0x2b081919, 0x19192b2b), + uvec2(0x2b2b2b08, 0x19192b2b), uvec2(0x08080819, 0x192b0808), uvec2(0x08081908, 0x192b0808), uvec2(0x0808192b, 0x192b0808), + uvec2(0x08190808, 0x192b0808), uvec2(0x0819082b, 0x192b0808), uvec2(0x08191919, 0x192b0808), uvec2(0x08192b08, 0x192b0808), + uvec2(0x082b0819, 0x192b0808), uvec2(0x082b1908, 0x192b0808), uvec2(0x19080808, 0x192b0808), uvec2(0x19081919, 0x192b0808), + uvec2(0x19082b08, 0x192b0808), uvec2(0x19190819, 0x192b0808), uvec2(0x19191908, 0x192b0808), uvec2(0x192b0808, 0x192b0808), + uvec2(0x2b081908, 0x192b0808), uvec2(0x2b190808, 0x192b0808), uvec2(0x08080808, 0x192b0819), uvec2(0x0808082b, 0x192b0819), + uvec2(0x08081919, 0x192b0819), uvec2(0x08082b08, 0x192b0819), uvec2(0x08190819, 0x192b0819), uvec2(0x08191908, 0x192b0819), + uvec2(0x082b0808, 0x192b0819), uvec2(0x19080819, 0x192b0819), uvec2(0x19081908, 0x192b0819), uvec2(0x19190808, 0x192b0819), + uvec2(0x2b080808, 0x192b0819), uvec2(0x2b192b19, 0x192b0819), uvec2(0x08081908, 0x192b082b), uvec2(0x08190808, 0x192b082b), + uvec2(0x19080808, 0x192b082b), uvec2(0x1919192b, 0x192b082b), uvec2(0x2b2b0819, 0x192b082b), uvec2(0x08080808, 0x192b1908), + uvec2(0x08081919, 0x192b1908), uvec2(0x08082b08, 0x192b1908), uvec2(0x08190819, 0x192b1908), uvec2(0x08191908, 0x192b1908), + uvec2(0x082b0808, 0x192b1908), uvec2(0x19080819, 0x192b1908), uvec2(0x19081908, 0x192b1908), uvec2(0x19190808, 0x192b1908), + uvec2(0x2b080808, 0x192b1908), uvec2(0x08080819, 0x192b1919), uvec2(0x08081908, 0x192b1919), uvec2(0x08190808, 0x192b1919), + uvec2(0x19080808, 0x192b1919), uvec2(0x19082b2b, 0x192b1919), uvec2(0x192b2b08, 0x192b1919), uvec2(0x2b19082b, 0x192b1919), + uvec2(0x08080808, 0x192b192b), uvec2(0x2b191908, 0x192b192b), uvec2(0x08080819, 0x192b2b08), uvec2(0x08081908, 0x192b2b08), + uvec2(0x08190808, 0x192b2b08), uvec2(0x192b1919, 0x192b2b08), uvec2(0x2b192b08, 0x192b2b08), uvec2(0x08080808, 0x192b2b19), + uvec2(0x082b2b2b, 0x192b2b19), uvec2(0x1908082b, 0x192b2b2b), uvec2(0x2b2b0819, 0x192b2b2b), uvec2(0x08080808, 0x2b080808), + uvec2(0x0808082b, 0x2b080808), uvec2(0x08081919, 0x2b080808), uvec2(0x08082b08, 0x2b080808), uvec2(0x08190819, 0x2b080808), + uvec2(0x08191908, 0x2b080808), uvec2(0x08192b19, 0x2b080808), uvec2(0x082b0808, 0x2b080808), uvec2(0x082b1919, 0x2b080808), + uvec2(0x19080819, 0x2b080808), uvec2(0x19081908, 0x2b080808), uvec2(0x19190808, 0x2b080808), uvec2(0x1919082b, 0x2b080808), + uvec2(0x19191919, 0x2b080808), uvec2(0x19192b08, 0x2b080808), uvec2(0x192b0819, 0x2b080808), uvec2(0x2b080808, 0x2b080808), + uvec2(0x2b081919, 0x2b080808), uvec2(0x2b190819, 0x2b080808), uvec2(0x2b191908, 0x2b080808), uvec2(0x08080819, 0x2b080819), + uvec2(0x08081908, 0x2b080819), uvec2(0x08082b19, 0x2b080819), uvec2(0x08190808, 0x2b080819), uvec2(0x0819082b, 0x2b080819), + uvec2(0x08191919, 0x2b080819), uvec2(0x08192b08, 0x2b080819), uvec2(0x082b0819, 0x2b080819), uvec2(0x082b1908, 0x2b080819), + uvec2(0x19080808, 0x2b080819), uvec2(0x1908082b, 0x2b080819), uvec2(0x19081919, 0x2b080819), uvec2(0x19082b08, 0x2b080819), + uvec2(0x19190819, 0x2b080819), uvec2(0x19191908, 0x2b080819), uvec2(0x2b080819, 0x2b080819), uvec2(0x2b081908, 0x2b080819), + uvec2(0x2b190808, 0x2b080819), uvec2(0x2b2b2b19, 0x2b080819), uvec2(0x08080808, 0x2b08082b), uvec2(0x08081919, 0x2b08082b), + uvec2(0x08082b2b, 0x2b08082b), uvec2(0x08190819, 0x2b08082b), uvec2(0x08191908, 0x2b08082b), uvec2(0x19080819, 0x2b08082b), + uvec2(0x19081908, 0x2b08082b), uvec2(0x19190808, 0x2b08082b), uvec2(0x08080819, 0x2b081908), uvec2(0x08081908, 0x2b081908), + uvec2(0x0808192b, 0x2b081908), uvec2(0x08082b19, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x0819082b, 0x2b081908), + uvec2(0x08191919, 0x2b081908), uvec2(0x08192b08, 0x2b081908), uvec2(0x082b0819, 0x2b081908), uvec2(0x19080808, 0x2b081908), + uvec2(0x1908082b, 0x2b081908), uvec2(0x19081919, 0x2b081908), uvec2(0x19082b08, 0x2b081908), uvec2(0x19190819, 0x2b081908), + uvec2(0x19191908, 0x2b081908), uvec2(0x192b0808, 0x2b081908), uvec2(0x2b080819, 0x2b081908), uvec2(0x2b081908, 0x2b081908), + uvec2(0x2b190808, 0x2b081908), uvec2(0x08080808, 0x2b081919), uvec2(0x0808082b, 0x2b081919), uvec2(0x08081919, 0x2b081919), + uvec2(0x08082b08, 0x2b081919), uvec2(0x08190819, 0x2b081919), uvec2(0x08191908, 0x2b081919), uvec2(0x082b0808, 0x2b081919), + uvec2(0x19080819, 0x2b081919), uvec2(0x19081908, 0x2b081919), uvec2(0x19190808, 0x2b081919), uvec2(0x2b080808, 0x2b081919), + uvec2(0x2b082b2b, 0x2b081919), uvec2(0x08080819, 0x2b08192b), uvec2(0x08081908, 0x2b08192b), uvec2(0x08190808, 0x2b08192b), + uvec2(0x082b2b19, 0x2b08192b), uvec2(0x19080808, 0x2b08192b), uvec2(0x08080808, 0x2b082b08), uvec2(0x08081919, 0x2b082b08), + uvec2(0x08190819, 0x2b082b08), uvec2(0x08191908, 0x2b082b08), uvec2(0x19080819, 0x2b082b08), uvec2(0x19081908, 0x2b082b08), + uvec2(0x19190808, 0x2b082b08), uvec2(0x2b2b082b, 0x2b082b08), uvec2(0x08080819, 0x2b082b19), uvec2(0x08081908, 0x2b082b19), + uvec2(0x19080808, 0x2b082b19), uvec2(0x192b1919, 0x2b082b19), uvec2(0x082b082b, 0x2b082b2b), uvec2(0x19192b08, 0x2b082b2b), + uvec2(0x19192b2b, 0x2b082b2b), uvec2(0x2b08082b, 0x2b082b2b), uvec2(0x2b2b082b, 0x2b082b2b), uvec2(0x08080819, 0x2b190808), + uvec2(0x08081908, 0x2b190808), uvec2(0x08082b19, 0x2b190808), uvec2(0x08190808, 0x2b190808), uvec2(0x0819082b, 0x2b190808), + uvec2(0x08191919, 0x2b190808), uvec2(0x08192b08, 0x2b190808), uvec2(0x082b1908, 0x2b190808), uvec2(0x19080808, 0x2b190808), + uvec2(0x1908082b, 0x2b190808), uvec2(0x19081919, 0x2b190808), uvec2(0x19082b08, 0x2b190808), uvec2(0x19190819, 0x2b190808), + uvec2(0x19191908, 0x2b190808), uvec2(0x192b0808, 0x2b190808), uvec2(0x2b080819, 0x2b190808), uvec2(0x2b081908, 0x2b190808), + uvec2(0x2b190808, 0x2b190808), uvec2(0x08080808, 0x2b190819), uvec2(0x08081919, 0x2b190819), uvec2(0x08190819, 0x2b190819), + uvec2(0x08191908, 0x2b190819), uvec2(0x19080819, 0x2b190819), uvec2(0x19081908, 0x2b190819), uvec2(0x19190808, 0x2b190819), + uvec2(0x19192b2b, 0x2b190819), uvec2(0x08080819, 0x2b19082b), uvec2(0x08081908, 0x2b19082b), uvec2(0x08190808, 0x2b19082b), + uvec2(0x19080808, 0x2b19082b), uvec2(0x2b2b192b, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x0808082b, 0x2b191908), + uvec2(0x08081919, 0x2b191908), uvec2(0x08082b08, 0x2b191908), uvec2(0x08190819, 0x2b191908), uvec2(0x08191908, 0x2b191908), + uvec2(0x082b0808, 0x2b191908), uvec2(0x19080819, 0x2b191908), uvec2(0x19081908, 0x2b191908), uvec2(0x19190808, 0x2b191908), + uvec2(0x2b080808, 0x2b191908), uvec2(0x2b19192b, 0x2b191908), uvec2(0x08080819, 0x2b191919), uvec2(0x08081908, 0x2b191919), + uvec2(0x08190808, 0x2b191919), uvec2(0x19080808, 0x2b191919), uvec2(0x2b192b08, 0x2b191919), uvec2(0x2b2b0819, 0x2b191919), + uvec2(0x08080808, 0x2b19192b), uvec2(0x1908192b, 0x2b19192b), uvec2(0x192b1908, 0x2b19192b), uvec2(0x08080819, 0x2b192b08), + uvec2(0x08081908, 0x2b192b08), uvec2(0x08190808, 0x2b192b08), uvec2(0x082b192b, 0x2b192b08), uvec2(0x19080808, 0x2b192b08), + uvec2(0x2b2b2b19, 0x2b192b08), uvec2(0x08080808, 0x2b192b19), uvec2(0x19082b19, 0x2b192b19), uvec2(0x1919082b, 0x2b192b19), + uvec2(0x2b190808, 0x2b192b2b), uvec2(0x08080808, 0x2b2b0808), uvec2(0x08081919, 0x2b2b0808), uvec2(0x08082b2b, 0x2b2b0808), + uvec2(0x08191908, 0x2b2b0808), uvec2(0x082b082b, 0x2b2b0808), uvec2(0x082b2b2b, 0x2b2b0808), uvec2(0x19080819, 0x2b2b0808), + uvec2(0x19081908, 0x2b2b0808), uvec2(0x19190808, 0x2b2b0808), uvec2(0x2b2b082b, 0x2b2b0808), uvec2(0x2b2b2b2b, 0x2b2b0808), + uvec2(0x19080808, 0x2b2b0819), uvec2(0x192b1919, 0x2b2b0819), uvec2(0x0808082b, 0x2b2b082b), uvec2(0x08082b2b, 0x2b2b082b), + uvec2(0x082b082b, 0x2b2b082b), uvec2(0x082b2b08, 0x2b2b082b), uvec2(0x082b2b2b, 0x2b2b082b), uvec2(0x2b08082b, 0x2b2b082b), + uvec2(0x2b082b08, 0x2b2b082b), uvec2(0x2b082b2b, 0x2b2b082b), uvec2(0x2b2b2b08, 0x2b2b082b), uvec2(0x08080819, 0x2b2b1908), + uvec2(0x08081908, 0x2b2b1908), uvec2(0x08190808, 0x2b2b1908), uvec2(0x19080808, 0x2b2b1908), uvec2(0x2b082b19, 0x2b2b1908), + uvec2(0x2b2b1908, 0x2b2b1908), uvec2(0x08080808, 0x2b2b1919), uvec2(0x08192b19, 0x2b2b1919), uvec2(0x19190819, 0x2b2b192b), + uvec2(0x08082b2b, 0x2b2b2b08), uvec2(0x082b2b08, 0x2b2b2b08), uvec2(0x2b2b082b, 0x2b2b2b08), uvec2(0x19191908, 0x2b2b2b19), + uvec2(0x2b08192b, 0x2b2b2b19), uvec2(0x08082b08, 0x2b2b2b2b), uvec2(0x08082b2b, 0x2b2b2b2b), uvec2(0x082b0808, 0x2b2b2b2b), + uvec2(0x082b082b, 0x2b2b2b2b), uvec2(0x082b2b08, 0x2b2b2b2b), uvec2(0x2b082b08, 0x2b2b2b2b), uvec2(0x2b2b2b2b, 0x2b2b2b2b) +}; + +shared uvec2 iq2s_grid[1024]; + +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < iq2s_grid.length(); i += wgsize.x) { + iq2s_grid[i] = iq2s_grid_const[i]; + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ2_S +#define QUANT_R QUANT_R_IQ2_S +#define A_TYPE block_iq2_s +#endif + +#define QUANT_K_IQ3_XXS 256 +#define QUANT_R_IQ3_XXS 1 + +struct block_iq3_xxs +{ + float16_t d; + uint8_t qs[QUANT_K_IQ3_XXS/4 + QUANT_K_IQ3_XXS/8]; +}; + +struct block_iq3_xxs_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ3_XXS/8 + QUANT_K_IQ3_XXS/16]; +}; + +#if defined(DATA_A_IQ3_XXS) + +const uint32_t iq3xxs_grid_const[256] = { + 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, + 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, + 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404, + 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e, + 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c, + 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c, + 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34, + 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c, + 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c, + 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04, + 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c, + 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414, + 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434, + 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c, + 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e, + 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24, + 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24, + 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c, + 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c, + 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14, + 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414, + 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e, + 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404, + 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c, + 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c, + 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14, + 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c, + 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c, + 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14, + 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14, + 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, + 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, +}; + +shared uint32_t iq3xxs_grid[256]; + +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < iq3xxs_grid.length(); i += wgsize.x) { + iq3xxs_grid[i] = iq3xxs_grid_const[i]; + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ3_XXS +#define QUANT_R QUANT_R_IQ3_XXS +#define A_TYPE block_iq3_xxs +#define A_TYPE_PACKED16 block_iq3_xxs_packed16 +#endif + +#define QUANT_K_IQ3_S 256 +#define QUANT_R_IQ3_S 1 + +struct block_iq3_s +{ + float16_t d; + uint8_t qs[QUANT_K_IQ3_S/4]; + uint8_t qh[QUANT_K_IQ3_S/32]; + uint8_t signs[QUANT_K_IQ3_S/8]; + uint8_t scales[QUANT_K_IQ3_S/64]; +}; + +struct block_iq3_s_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ3_S/4/2]; + uint16_t qh[QUANT_K_IQ3_S/32/2]; + uint16_t signs[QUANT_K_IQ3_S/8/2]; + uint16_t scales[QUANT_K_IQ3_S/64/2]; +}; + +#if defined(DATA_A_IQ3_S) + +const uint32_t iq3s_grid_const[512] = { + 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, + 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, + 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09, + 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b, + 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b, + 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d, + 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03, + 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505, + 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03, + 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901, + 0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d, + 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303, + 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501, + 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105, + 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505, + 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101, + 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707, + 0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b, + 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01, + 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f, + 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305, + 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103, + 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509, + 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503, + 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b, + 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f, + 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f, + 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f, + 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109, + 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f, + 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509, + 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501, + 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303, + 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f, + 0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907, + 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703, + 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03, + 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01, + 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01, + 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903, + 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505, + 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b, + 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107, + 0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509, + 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303, + 0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103, + 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05, + 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b, + 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f, + 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701, + 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909, + 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305, + 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d, + 0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b, + 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d, + 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307, + 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09, + 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309, + 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709, + 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f, + 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303, + 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503, + 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, + 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, +}; + +shared uint32_t iq3s_grid[512]; + +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < iq3s_grid.length(); i += wgsize.x) { + iq3s_grid[i] = iq3s_grid_const[i]; + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ3_S +#define QUANT_R QUANT_R_IQ3_S +#define A_TYPE block_iq3_s +#define A_TYPE_PACKED16 block_iq3_s_packed16 +#endif + +#define QUANT_K_IQ4_NL 32 +#define QUANT_R_IQ4_NL 2 + +struct block_iq4_nl +{ + float16_t d; + uint8_t qs[QUANT_K_IQ4_NL/2]; +}; + +struct block_iq4_nl_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ4_NL/2/2]; +}; + +#if defined(DATA_A_IQ4_NL) + +const int8_t kvalues_iq4nl_const[16] = { + int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), + int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113) +}; + +shared FLOAT_TYPE kvalues_iq4nl[16]; + +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < kvalues_iq4nl.length(); i += wgsize.x) { + kvalues_iq4nl[i] = FLOAT_TYPE(kvalues_iq4nl_const[i]); + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ4_NL +#define QUANT_R QUANT_R_IQ4_NL +#define A_TYPE block_iq4_nl +#define A_TYPE_PACKED16 block_iq4_nl_packed16 +#endif + +#endif // !defined(GGML_TYPES_COMP) diff --git a/ggml/src/vulkan-shaders/upscale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp similarity index 85% rename from ggml/src/vulkan-shaders/upscale.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp index 511a086ea..6f607380d 100644 --- a/ggml/src/vulkan-shaders/upscale.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp @@ -2,7 +2,7 @@ layout (push_constant) uniform parameter { - uint ne; uint d_offset; + uint ne; uint a_offset; uint d_offset; uint nb00; uint nb01; uint nb02; uint nb03; uint ne10; uint ne11; uint ne12; uint ne13; float sf0; float sf1; float sf2; float sf3; @@ -32,5 +32,5 @@ void main() { const uint i02 = uint(i12 / p.sf2); const uint i03 = uint(i13 / p.sf3); - data_d[p.d_offset + idx] = D_TYPE(data_a[i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]); + data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp new file mode 100644 index 000000000..93ddbfadc --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -0,0 +1,608 @@ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 + #include + #include // For _mkdir on Windows +#else + #include + #include + #include +#endif + +#define ASYNCIO_CONCURRENCY 64 + +std::mutex lock; +std::vector> shader_fnames; + +std::string GLSLC = "glslc"; +std::string input_dir = "vulkan-shaders"; +std::string output_dir = "/tmp"; +std::string target_hpp = "ggml-vulkan-shaders.hpp"; +std::string target_cpp = "ggml-vulkan-shaders.cpp"; +bool no_clean = false; + +const std::vector type_names = { + "f32", + "f16", + "q4_0", + "q4_1", + "q5_0", + "q5_1", + "q8_0", + "q2_k", + "q3_k", + "q4_k", + "q5_k", + "q6_k", + "iq2_xxs", + "iq2_xs", + "iq2_s", + "iq3_xxs", + "iq3_s", + "iq4_nl" +}; + +namespace { +void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) { +#ifdef _WIN32 + HANDLE stdout_read, stdout_write; + HANDLE stderr_read, stderr_write; + SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE }; + + if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) || + !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) { + throw std::runtime_error("Failed to create stdout pipe"); + } + + if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) || + !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) { + throw std::runtime_error("Failed to create stderr pipe"); + } + + PROCESS_INFORMATION pi; + STARTUPINFOA si = {}; + si.cb = sizeof(STARTUPINFOA); + si.dwFlags = STARTF_USESTDHANDLES; + si.hStdOutput = stdout_write; + si.hStdError = stderr_write; + + std::vector cmd(command.begin(), command.end()); + cmd.push_back('\0'); + + if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) { + throw std::runtime_error("Failed to create process"); + } + + CloseHandle(stdout_write); + CloseHandle(stderr_write); + + std::array buffer; + DWORD bytes_read; + + while (ReadFile(stdout_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) { + stdout_str.append(buffer.data(), bytes_read); + } + + while (ReadFile(stderr_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) { + stderr_str.append(buffer.data(), bytes_read); + } + + CloseHandle(stdout_read); + CloseHandle(stderr_read); + WaitForSingleObject(pi.hProcess, INFINITE); + CloseHandle(pi.hProcess); + CloseHandle(pi.hThread); +#else +int stdout_pipe[2]; + int stderr_pipe[2]; + + if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) { + throw std::runtime_error("Failed to create pipes"); + } + + pid_t pid = fork(); + if (pid < 0) { + throw std::runtime_error("Failed to fork process"); + } + + if (pid == 0) { + close(stdout_pipe[0]); + close(stderr_pipe[0]); + dup2(stdout_pipe[1], STDOUT_FILENO); + dup2(stderr_pipe[1], STDERR_FILENO); + close(stdout_pipe[1]); + close(stderr_pipe[1]); + execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr); + _exit(EXIT_FAILURE); + } else { + close(stdout_pipe[1]); + close(stderr_pipe[1]); + + std::array buffer; + ssize_t bytes_read; + + while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) { + stdout_str.append(buffer.data(), bytes_read); + } + + while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) { + stderr_str.append(buffer.data(), bytes_read); + } + + close(stdout_pipe[0]); + close(stderr_pipe[0]); + waitpid(pid, nullptr, 0); + } +#endif +} + +bool directory_exists(const std::string& path) { + struct stat info; + if (stat(path.c_str(), &info) != 0) { + return false; // Path doesn't exist or can't be accessed + } + return (info.st_mode & S_IFDIR) != 0; // Check if it is a directory +} + +bool create_directory(const std::string& path) { +#ifdef _WIN32 + return _mkdir(path.c_str()) == 0 || errno == EEXIST; // EEXIST means the directory already exists +#else + return mkdir(path.c_str(), 0755) == 0 || errno == EEXIST; // 0755 is the directory permissions +#endif +} + +std::string to_uppercase(const std::string& input) { + std::string result = input; + for (char& c : result) { + c = std::toupper(c); + } + return result; +} + +bool string_ends_with(const std::string& str, const std::string& suffix) { + if (suffix.size() > str.size()) { + return false; + } + return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); +} + +static const char path_separator = '/'; + +std::string join_paths(const std::string& path1, const std::string& path2) { + return path1 + path_separator + path2; +} + +std::string basename(const std::string &path) { + return path.substr(path.find_last_of("/\\") + 1); +} + +// variables to track number of compiles in progress +static uint32_t compile_count = 0; +static std::mutex compile_count_mutex; +static std::condition_variable compile_count_cond; + +void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { + std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); + std::string out_fname = join_paths(output_dir, name + ".spv"); + std::string in_path = join_paths(input_dir, in_fname); + + std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2"; + + // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734 + std::string opt_level = coopmat ? "" : "-O"; + + #ifdef _WIN32 + std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""}; + #else + std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_fname}; + #endif + + #ifdef GGML_VULKAN_SHADER_DEBUG_INFO + cmd.push_back("-g"); + #endif + + for (const auto& define : defines) { + cmd.push_back("-D" + define.first + "=" + define.second); + } + + std::string command; + for (const auto& part : cmd) { + command += part + " "; + } + + std::string stdout_str, stderr_str; + try { + // std::cout << "Executing command: "; + // for (const auto& part : cmd) { + // std::cout << part << " "; + // } + // std::cout << std::endl; + + execute_command(command, stdout_str, stderr_str); + if (!stderr_str.empty()) { + std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl; + return; + } + + std::lock_guard guard(lock); + shader_fnames.push_back(std::make_pair(name, out_fname)); + } catch (const std::exception& e) { + std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl; + } + { + std::lock_guard guard(compile_count_mutex); + assert(compile_count > 0); + compile_count--; + } + compile_count_cond.notify_all(); +} + +std::map merge_maps(const std::map& a, const std::map& b) { + std::map result = a; + result.insert(b.begin(), b.end()); + return result; +} + +static std::vector> compiles; +void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { + { + // wait until fewer than N compiles are in progress. + // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors. + uint32_t N = 16; + std::unique_lock guard(compile_count_mutex); + while (compile_count >= N) { + compile_count_cond.wait(guard); + } + compile_count++; + } + compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc)); +} + +void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) { + std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; + std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; + std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; + + std::map base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}}; + std::string shader_name = "matmul"; + + if (matmul_id) { + base_dict["MUL_MAT_ID"] = "1"; + shader_name = "matmul_id"; + } + + if (fp16) { + base_dict["FLOAT16"] = "1"; + } + + base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; + + if (coopmat) { + base_dict["COOPMAT"] = "1"; + } + + base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; + + std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; + + // Shaders with f16 B_TYPE + string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + + string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + + for (const auto& tname : type_names) { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + // For unaligned, load one at a time for f32/f16, or two at a time for quants + std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2"; + // For aligned matmul loads + std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2"; + + // don't generate f32 variants for coopmat2 + if (!coopmat2) { + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + } + + if (tname != "f16" && tname != "f32") { + string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + } + } +} + +void process_shaders() { + std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl; + std::map base_dict = {{"FLOAT_TYPE", "float"}}; + + // matmul + for (const auto& matmul_id : {false, true}) { + // No coopmats + // fp32 + matmul_shaders(false, matmul_id, false, false, false); + + // fp16, fp32acc and fp16acc + matmul_shaders(true, matmul_id, false, false, false); + matmul_shaders(true, matmul_id, false, false, true); + +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + // Coopmat, fp32acc and fp16acc + matmul_shaders(true, matmul_id, true, false, false); + matmul_shaders(true, matmul_id, true, false, true); +#endif + +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + // Coopmat2, fp32acc and fp16acc + matmul_shaders(true, matmul_id, false, true, false); + matmul_shaders(true, matmul_id, false, true, true); +#endif + } + +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + // flash attention + for (const auto& f16acc : {false, true}) { + std::string acctype = f16acc ? "float16_t" : "float"; + + for (const auto& tname : type_names) { + if (tname == "f32") { + continue; + } + + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", + merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc); + } else { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", + merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); + } + } + } +#endif + + for (const auto& tname : type_names) { + // mul mat vec + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; + + string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); + + string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + + // Dequant shaders + if (tname != "f16") { + string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}})); + } + + if (!string_ends_with(tname, "_k")) { + shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp"; + + if (tname == "f16") { + string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}})); + } else { + string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}})); + } + string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}})); + } + } + + string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + + // Norms + string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); + string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); + string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + + for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { + string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + } + + string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); + + string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}); + + string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); + + string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); + + string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); + string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}})); + + string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + + for (auto &c : compiles) { + c.wait(); + } +} + +void write_output_files() { + FILE* hdr = fopen(target_hpp.c_str(), "w"); + FILE* src = fopen(target_cpp.c_str(), "w"); + + fprintf(hdr, "#include \n\n"); + fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str()); + + std::sort(shader_fnames.begin(), shader_fnames.end()); + for (const auto& pair : shader_fnames) { + const std::string& name = pair.first; + #ifdef _WIN32 + std::string path = pair.second; + std::replace(path.begin(), path.end(), '/', '\\' ); + #else + const std::string& path = pair.second; + #endif + + FILE* spv = fopen(path.c_str(), "rb"); + if (!spv) { + std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n"; + continue; + } + + fseek(spv, 0, SEEK_END); + size_t size = ftell(spv); + fseek(spv, 0, SEEK_SET); + + std::vector data(size); + size_t read_size = fread(data.data(), 1, size, spv); + fclose(spv); + if (read_size != size) { + std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n"; + continue; + } + + fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size); + fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size); + + fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size); + for (size_t i = 0; i < size; ++i) { + fprintf(src, "0x%02x,", data[i]); + if ((i + 1) % 12 == 0) fprintf(src, "\n"); + } + fprintf(src, "\n};\n\n"); + + if (!no_clean) { + std::remove(path.c_str()); + } + } + + fclose(hdr); + fclose(src); +} +} + +int main(int argc, char** argv) { + std::map args; + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg.rfind("--", 0) == 0) { + if (i + 1 < argc && argv[i + 1][0] != '-') { + args[arg] = argv[i + 1]; + ++i; + } else { + args[arg] = ""; + } + } + } + + if (args.find("--glslc") != args.end()) { + GLSLC = args["--glslc"]; // Path to glslc + } + if (args.find("--input-dir") != args.end()) { + input_dir = args["--input-dir"]; // Directory containing shader sources + } + if (args.find("--output-dir") != args.end()) { + output_dir = args["--output-dir"]; // Directory for containing SPIR-V output + } + if (args.find("--target-hpp") != args.end()) { + target_hpp = args["--target-hpp"]; // Path to generated header file + } + if (args.find("--target-cpp") != args.end()) { + target_cpp = args["--target-cpp"]; // Path to generated cpp file + } + if (args.find("--no-clean") != args.end()) { + no_clean = true; // Keep temporary SPIR-V files in output-dir after build + } + + if (!directory_exists(input_dir)) { + std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl; + return EXIT_FAILURE; + } + + if (!directory_exists(output_dir)) { + if (!create_directory(output_dir)) { + std::cerr << "Error creating output directory: " << output_dir << "\n"; + return EXIT_FAILURE; + } + } + + process_shaders(); + + write_output_files(); + + return EXIT_SUCCESS; +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp b/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp new file mode 100644 index 000000000..35cc6c45f --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp @@ -0,0 +1,87 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#define BLOCK_SIZE 64 +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout(push_constant) uniform Parameters { + uint B; + uint T; + uint C; + uint H; +}; + +layout(binding = 0) readonly buffer KBuf { A_TYPE k[]; }; +layout(binding = 1) readonly buffer VBuf { A_TYPE v[]; }; +layout(binding = 2) readonly buffer RBuf { A_TYPE r[]; }; +layout(binding = 3) readonly buffer TimeFBuf { A_TYPE tf[]; }; +layout(binding = 4) readonly buffer TimeDBuf { A_TYPE td[]; }; +layout(binding = 5) readonly buffer StateBuf { A_TYPE state_in[]; }; +layout(binding = 6) buffer DstBuf { A_TYPE dst[]; }; + +shared A_TYPE _k[BLOCK_SIZE], _r[BLOCK_SIZE], _tf[BLOCK_SIZE], _td[BLOCK_SIZE]; + +void main() { + const uint head_size = BLOCK_SIZE; + const uint batch_id = gl_WorkGroupID.x / H; + const uint head_id = gl_WorkGroupID.x % H; + const uint tid = gl_LocalInvocationID.x; + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + if (batch_id >= B || head_id >= H) { + return; + } + + A_TYPE state[BLOCK_SIZE]; + [[unroll]] for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid]; + } + + barrier(); + _tf[tid] = tf[head_id * head_size + tid]; + barrier(); + + const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + + for (uint t = start_t; t < end_t; t += C) { + barrier(); + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + barrier(); + + const A_TYPE v_val = v[t]; + A_TYPE y = 0.0; + + [[unroll]] for (uint j = 0; j < head_size; j += 4) { + vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); + vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]); + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + + vec4 kv = k_vec * v_val; + + vec4 temp = tf_vec * kv + s_vec; + y += dot(r_vec, temp); + + s_vec = s_vec * td_vec + kv; + state[j] = s_vec.x; + state[j+1] = s_vec.y; + state[j+2] = s_vec.z; + state[j+3] = s_vec.w; + } + + dst[t] = y; + } + + [[unroll]] for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid] = state[i]; + } +} diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 28ee46e04..3b4861542 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1,10 +1,17 @@ -#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows +#define _CRT_SECURE_NO_DEPRECATE // Disables "unsafe" warnings on Windows #define _USE_MATH_DEFINES // For M_PI on MSVC +#include "ggml-backend.h" #include "ggml-impl.h" -#include "ggml-quants.h" +#include "ggml-threading.h" #include "ggml.h" -#include "ggml-aarch64.h" + +// FIXME: required here for quantization functions +#include "ggml-quants.h" + +#ifdef GGML_USE_CPU_HBM +#include +#endif #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW @@ -29,149 +36,38 @@ #include #endif -#ifdef GGML_USE_OPENMP -#include -#endif - -#ifdef GGML_USE_METAL +#if defined(__APPLE__) #include -#endif - -#if defined(__ARM_FEATURE_SVE) -int ggml_sve_cnt_b = 0; -#endif -#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8) -#undef GGML_USE_LLAMAFILE -#endif - -#ifdef GGML_USE_LLAMAFILE -#include -#endif - -#if defined(_MSC_VER) -// disable "possible loss of data" to avoid hundreds of casts -// we should just be careful :) -#pragma warning(disable: 4244 4267) - -// disable POSIX deprecation warnings -// these functions are never going away, anyway -#pragma warning(disable: 4996) - -// unreachable code because of multiple instances of code after GGML_ABORT -#pragma warning(disable: 4702) +#include +#include #endif #if defined(_WIN32) - #define WIN32_LEAN_AND_MEAN #ifndef NOMINMAX #define NOMINMAX #endif #include - -#if !defined(__clang__) -typedef volatile LONG atomic_int; -typedef atomic_int atomic_bool; -typedef atomic_int atomic_flag; - -#define ATOMIC_FLAG_INIT 0 - -typedef enum { - memory_order_relaxed, - memory_order_consume, - memory_order_acquire, - memory_order_release, - memory_order_acq_rel, - memory_order_seq_cst -} memory_order; - -static void atomic_store(atomic_int * ptr, LONG val) { - InterlockedExchange(ptr, val); -} -static void atomic_store_explicit(atomic_int * ptr, LONG val, memory_order mo) { - // TODO: add support for explicit memory order - InterlockedExchange(ptr, val); -} -static LONG atomic_load(atomic_int * ptr) { - return InterlockedCompareExchange(ptr, 0, 0); -} -static LONG atomic_load_explicit(atomic_int * ptr, memory_order mo) { - // TODO: add support for explicit memory order - return InterlockedCompareExchange(ptr, 0, 0); -} -static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) { - return InterlockedExchangeAdd(ptr, inc); -} -static LONG atomic_fetch_add_explicit(atomic_int * ptr, LONG inc, memory_order mo) { - // TODO: add support for explicit memory order - return InterlockedExchangeAdd(ptr, inc); -} -static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) { - return InterlockedExchange(ptr, 1); -} -static void atomic_flag_clear(atomic_flag * ptr) { - InterlockedExchange(ptr, 0); -} -#else // clang -#include #endif -typedef HANDLE pthread_t; +#define UNUSED GGML_UNUSED -typedef DWORD thread_ret_t; -static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) { - (void) unused; - HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL); - if (handle == NULL) - { - return EAGAIN; - } - - *out = handle; - return 0; -} - -static int pthread_join(pthread_t thread, void * unused) { - (void) unused; - int ret = (int) WaitForSingleObject(thread, INFINITE); - CloseHandle(thread); - return ret; -} - -static int sched_yield (void) { - Sleep (0); - return 0; -} +#if defined(_MSC_VER) +#define m512bh(p) p +#define m512i(p) p #else - -#include -#include -#include -#if defined(__FreeBSD__) -#include +#define m512bh(p) (__m512bh)(p) +#define m512i(p) (__m512i)(p) #endif -typedef void * thread_ret_t; - -#include -#include -#include - -#endif - -typedef pthread_t ggml_thread_t; - -#ifdef GGML_USE_CPU_HBM -#include -#endif - -#if defined(__APPLE__) -#include -#endif +// precomputed f32 table for f16 (256 KB) (ggml-impl.h) +float ggml_table_f32_f16[1 << 16]; #if (defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)) && \ (!defined(TARGET_OS_TV) && !defined(TARGET_OS_WATCH)) - +#include +#include +#include #include #if defined(__ANDROID__) @@ -232,6 +128,10 @@ static void ggml_print_backtrace_symbols(void) { #endif static void ggml_print_backtrace(void) { + const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE"); + if (GGML_NO_BACKTRACE) { + return; + } char attach[32]; snprintf(attach, sizeof(attach), "attach %d", getpid()); int pid = fork(); @@ -284,37 +184,49 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) { abort(); } -#define GGML_DEBUG 0 -#define GGML_GELU_FP16 -#define GGML_GELU_QUICK_FP16 - -#define GGML_SOFT_MAX_UNROLL 4 -#define GGML_VEC_DOT_UNROLL 2 -#define GGML_VEC_MAD_UNROLL 32 - // // logging // -#if (GGML_DEBUG >= 1) -#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG(...) -#endif +struct ggml_logger_state { + ggml_log_callback log_callback; + void * log_callback_user_data; +}; +static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL}; -#if (GGML_DEBUG >= 5) -#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG_5(...) -#endif +static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) { + if (format == NULL) { + return; + } + va_list args_copy; + va_copy(args_copy, args); + char buffer[128]; + int len = vsnprintf(buffer, 128, format, args); + if (len < 128) { + g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data); + } else { + char * buffer2 = (char *) calloc(len + 1, sizeof(char)); + vsnprintf(buffer2, len + 1, format, args_copy); + buffer2[len] = 0; + g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data); + free(buffer2); + } + va_end(args_copy); +} -#if (GGML_DEBUG >= 10) -#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG_10(...) -#endif +void ggml_log_internal(enum ggml_log_level level, const char * format, ...) { + va_list args; + va_start(args, format); + ggml_log_internal_v(level, format, args); + va_end(args); +} -#define GGML_PRINT(...) printf(__VA_ARGS__) +void ggml_log_callback_default(enum ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} // // end of logging block @@ -326,23 +238,41 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) { //#define GGML_SOFT_MAX_ACCELERATE #endif + +void * ggml_aligned_malloc(size_t size) { + const int alignment = 64; + #if defined(_MSC_VER) || defined(__MINGW32__) -#define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN) -#define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr) + return _aligned_malloc(size, alignment); #else -inline static void * ggml_aligned_malloc(size_t size) { if (size == 0) { - GGML_PRINT("WARNING: Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n"); + GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n"); return NULL; } void * aligned_memory = NULL; -#ifdef GGML_USE_CPU_HBM - int result = hbw_posix_memalign(&aligned_memory, 16, size); -#elif GGML_USE_METAL - int result = posix_memalign(&aligned_memory, sysconf(_SC_PAGESIZE), size); -#else - int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size); -#endif + #ifdef GGML_USE_CPU_HBM + int result = hbw_posix_memalign(&aligned_memory, alignment, size); + #elif TARGET_OS_OSX + GGML_UNUSED(alignment); + kern_return_t alloc_status = vm_allocate((vm_map_t) mach_task_self(), (vm_address_t *) &aligned_memory, size, VM_FLAGS_ANYWHERE); + int result = EFAULT; + switch (alloc_status) { + case KERN_SUCCESS: + result = 0; + break; + case KERN_INVALID_ADDRESS: + result = EINVAL; + break; + case KERN_NO_SPACE: + result = ENOMEM; + break; + default: + result = EFAULT; + break; + } + #else + int result = posix_memalign(&aligned_memory, alignment, size); + #endif if (result != 0) { // Handle allocation failure const char *error_desc = "unknown allocation error"; @@ -354,28 +284,39 @@ inline static void * ggml_aligned_malloc(size_t size) { error_desc = "insufficient memory"; break; } - GGML_PRINT("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0)); - GGML_ABORT("fatal error"); + GGML_LOG_ERROR("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0)); return NULL; } return aligned_memory; +#endif } -#define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size) -#ifdef GGML_USE_CPU_HBM -#define GGML_ALIGNED_FREE(ptr) if(NULL != ptr) hbw_free(ptr) + +void ggml_aligned_free(void * ptr, size_t size) { + GGML_UNUSED(size); +#if defined(_MSC_VER) || defined(__MINGW32__) + _aligned_free(ptr); +#elif GGML_USE_CPU_HBM + if (ptr != NULL) { + hbw_free(ptr); + } +#elif TARGET_OS_OSX + if (ptr != NULL) { + vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ptr, size); + } #else -#define GGML_ALIGNED_FREE(ptr) free(ptr) -#endif + free(ptr); #endif +} + inline static void * ggml_malloc(size_t size) { if (size == 0) { - GGML_PRINT("WARNING: Behavior may be unexpected when allocating 0 bytes for ggml_malloc!\n"); + GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_malloc!\n"); return NULL; } void * result = malloc(size); if (result == NULL) { - GGML_PRINT("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0)); + GGML_LOG_ERROR("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0)); GGML_ABORT("fatal error"); } return result; @@ -384,12 +325,12 @@ inline static void * ggml_malloc(size_t size) { // calloc inline static void * ggml_calloc(size_t num, size_t size) { if (num == 0 || size == 0) { - GGML_PRINT("WARNING: Behavior may be unexpected when allocating 0 bytes for ggml_calloc!\n"); + GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_calloc!\n"); return NULL; } void * result = calloc(num, size); if (result == NULL) { - GGML_PRINT("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0)); + GGML_LOG_ERROR("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0)); GGML_ABORT("fatal error"); } return result; @@ -400,36 +341,7 @@ inline static void * ggml_calloc(size_t num, size_t size) { #define GGML_FREE(ptr) free(ptr) -#define UNUSED GGML_UNUSED -#define SWAP(x, y, T) do { T SWAP = x; (x) = y; (y) = SWAP; } while (0) - -#if defined(GGML_USE_ACCELERATE) -#include -#endif - -// floating point type used to accumulate sums -typedef double ggml_float; - -#undef MIN -#undef MAX - -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) - -// -// global data -// - -// precomputed gelu table for f16 (128 KB) -static ggml_fp16_t ggml_table_gelu_f16[1 << 16]; - -// precomputed quick gelu table for f16 (128 KB) -static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16]; - -// precomputed f32 table for f16 (256 KB) (ggml-impl.h) -float ggml_table_f32_f16[1 << 16]; - -GGML_CALL const char * ggml_status_to_string(enum ggml_status status) { +const char * ggml_status_to_string(enum ggml_status status) { switch (status) { case GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)"; case GGML_STATUS_FAILED: return "GGML status: error (operation failed)"; @@ -466,19 +378,23 @@ void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) { } } +// FIXME: these functions must detect the instruction set at runtime, since they are part of the core ggml library +// currently, the ggml_cpu_has_* functions are entirely compile-time void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) { int64_t i = 0; #if defined(__F16C__) - for (; i + 7 < n; i += 8) { - __m256 x_vec = _mm256_loadu_ps(x + i); - __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); - _mm_storeu_si128((__m128i *)(y + i), y_vec); - } - for(; i + 3 < n; i += 4) { - __m128 x_vec = _mm_loadu_ps(x + i); - __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); - _mm_storel_epi64((__m128i *)(y + i), y_vec); - } + //if (ggml_cpu_has_f16c()) { + for (; i + 7 < n; i += 8) { + __m256 x_vec = _mm256_loadu_ps(x + i); + __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); + _mm_storeu_si128((__m128i *)(y + i), y_vec); + } + for(; i + 3 < n; i += 4) { + __m128 x_vec = _mm_loadu_ps(x + i); + __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); + _mm_storel_epi64((__m128i *)(y + i), y_vec); + } + //} #endif for (; i < n; i++) { y[i] = GGML_FP32_TO_FP16(x[i]); @@ -488,25 +404,30 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) { void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) { int64_t i = 0; #if defined(__AVX512F__) - for (; i + 16 <= n; i += 16) { - _mm512_storeu_ps(y + i, - _mm512_castsi512_ps( - _mm512_slli_epi32( - _mm512_cvtepu16_epi32( - _mm256_loadu_si256( - (const __m256i *)(x + i))), - 16))); - } -#elif defined(__AVX2__) - for (; i + 8 <= n; i += 8) { - _mm256_storeu_ps(y + i, - _mm256_castsi256_ps( - _mm256_slli_epi32( - _mm256_cvtepu16_epi32( - _mm_loadu_si128( - (const __m128i *)(x + i))), - 16))); - } + //if (ggml_cpu_has_avx512()) { + for (; i + 16 <= n; i += 16) { + _mm512_storeu_ps(y + i, + _mm512_castsi512_ps( + _mm512_slli_epi32( + _mm512_cvtepu16_epi32( + _mm256_loadu_si256( + (const __m256i *)(x + i))), + 16))); + } + //} +#endif +#if defined(__AVX2__) + //if (ggml_cpu_has_avx2()) { + for (; i + 8 <= n; i += 8) { + _mm256_storeu_ps(y + i, + _mm256_castsi256_ps( + _mm256_slli_epi32( + _mm256_cvtepu16_epi32( + _mm_loadu_si128( + (const __m128i *)(x + i))), + 16))); + } + //} #endif for (; i < n; i++) { y[i] = GGML_BF16_TO_FP32(x[i]); @@ -638,29 +559,13 @@ FILE * ggml_fopen(const char * fname, const char * mode) { #else return fopen(fname, mode); #endif + } - -// -// cache line -// - -#if defined(__cpp_lib_hardware_interference_size) -#define CACHE_LINE_SIZE hardware_destructive_interference_size -#else -#if defined(__POWER9_VECTOR__) -#define CACHE_LINE_SIZE 128 -#else -#define CACHE_LINE_SIZE 64 -#endif -#endif - -static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); - static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc); static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc); static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc); -static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { +static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { [GGML_TYPE_I8] = { .type_name = "i8", .blck_size = 1, @@ -690,16 +595,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .blck_size = 1, .type_size = sizeof(double), .is_quantized = false, - .nrows = 1, }, [GGML_TYPE_F32] = { .type_name = "f32", .blck_size = 1, .type_size = sizeof(float), .is_quantized = false, - .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32, - .vec_dot_type = GGML_TYPE_F32, - .nrows = 1, }, [GGML_TYPE_F16] = { .type_name = "f16", @@ -707,11 +608,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(ggml_fp16_t), .is_quantized = false, .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row, - .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row, .from_float_ref = (ggml_from_float_t) ggml_fp32_to_fp16_row, - .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16, - .vec_dot_type = GGML_TYPE_F16, - .nrows = 1, }, [GGML_TYPE_Q4_0] = { .type_name = "q4_0", @@ -719,15 +616,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_q4_0), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q4_0, - .from_float = quantize_row_q4_0, .from_float_ref = (ggml_from_float_t) quantize_row_q4_0_ref, - .vec_dot = ggml_vec_dot_q4_0_q8_0, - .vec_dot_type = GGML_TYPE_Q8_0, -#if defined (__ARM_FEATURE_MATMUL_INT8) - .nrows = 2, -#else - .nrows = 1, -#endif }, [GGML_TYPE_Q4_1] = { .type_name = "q4_1", @@ -735,39 +624,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_q4_1), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q4_1, - .from_float = quantize_row_q4_1, .from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref, - .vec_dot = ggml_vec_dot_q4_1_q8_1, - .vec_dot_type = GGML_TYPE_Q8_1, -#if defined (__ARM_FEATURE_MATMUL_INT8) - .nrows = 2, -#else - .nrows = 1, -#endif }, [4] = { // GGML_TYPE_Q4_2 .type_name = "DEPRECATED", .blck_size = 0, .type_size = 0, .is_quantized = false, - .to_float = NULL, - .from_float = NULL, - .from_float_ref = NULL, - .vec_dot = NULL, - .vec_dot_type = GGML_TYPE_COUNT, - .nrows = 1, }, [5] = { // GGML_TYPE_Q4_3 .type_name = "DEPRECATED", .blck_size = 0, .type_size = 0, .is_quantized = false, - .to_float = NULL, - .from_float = NULL, - .from_float_ref = NULL, - .vec_dot = NULL, - .vec_dot_type = GGML_TYPE_COUNT, - .nrows = 1, }, [GGML_TYPE_Q5_0] = { .type_name = "q5_0", @@ -775,11 +644,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_q5_0), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q5_0, - .from_float = quantize_row_q5_0, .from_float_ref = (ggml_from_float_t) quantize_row_q5_0_ref, - .vec_dot = ggml_vec_dot_q5_0_q8_0, - .vec_dot_type = GGML_TYPE_Q8_0, - .nrows = 1, }, [GGML_TYPE_Q5_1] = { .type_name = "q5_1", @@ -787,11 +652,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_q5_1), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q5_1, - .from_float = quantize_row_q5_1, .from_float_ref = (ggml_from_float_t) quantize_row_q5_1_ref, - .vec_dot = ggml_vec_dot_q5_1_q8_1, - .vec_dot_type = GGML_TYPE_Q8_1, - .nrows = 1, }, [GGML_TYPE_Q8_0] = { .type_name = "q8_0", @@ -799,26 +660,14 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_q8_0), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q8_0, - .from_float = quantize_row_q8_0, .from_float_ref = (ggml_from_float_t) quantize_row_q8_0_ref, - .from_float_to_mat = quantize_mat_q8_0, - .vec_dot = ggml_vec_dot_q8_0_q8_0, - .vec_dot_type = GGML_TYPE_Q8_0, -#if defined (__ARM_FEATURE_MATMUL_INT8) - .nrows = 2, -#else - .nrows = 1, -#endif }, [GGML_TYPE_Q8_1] = { .type_name = "q8_1", .blck_size = QK8_1, .type_size = sizeof(block_q8_1), .is_quantized = true, - .from_float = quantize_row_q8_1, .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref, - .vec_dot_type = GGML_TYPE_Q8_1, - .nrows = 1, }, [GGML_TYPE_Q2_K] = { .type_name = "q2_K", @@ -826,11 +675,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_q2_K), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q2_K, - .from_float = quantize_row_q2_K, .from_float_ref = (ggml_from_float_t) quantize_row_q2_K_ref, - .vec_dot = ggml_vec_dot_q2_K_q8_K, - .vec_dot_type = GGML_TYPE_Q8_K, - .nrows = 1, }, [GGML_TYPE_Q3_K] = { .type_name = "q3_K", @@ -838,11 +683,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_q3_K), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q3_K, - .from_float = quantize_row_q3_K, .from_float_ref = (ggml_from_float_t) quantize_row_q3_K_ref, - .vec_dot = ggml_vec_dot_q3_K_q8_K, - .vec_dot_type = GGML_TYPE_Q8_K, - .nrows = 1, }, [GGML_TYPE_Q4_K] = { .type_name = "q4_K", @@ -850,11 +691,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_q4_K), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q4_K, - .from_float = quantize_row_q4_K, .from_float_ref = (ggml_from_float_t) quantize_row_q4_K_ref, - .vec_dot = ggml_vec_dot_q4_K_q8_K, - .vec_dot_type = GGML_TYPE_Q8_K, - .nrows = 1, }, [GGML_TYPE_Q5_K] = { .type_name = "q5_K", @@ -862,11 +699,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_q5_K), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q5_K, - .from_float = quantize_row_q5_K, .from_float_ref = (ggml_from_float_t) quantize_row_q5_K_ref, - .vec_dot = ggml_vec_dot_q5_K_q8_K, - .vec_dot_type = GGML_TYPE_Q8_K, - .nrows = 1, }, [GGML_TYPE_Q6_K] = { .type_name = "q6_K", @@ -874,11 +707,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_q6_K), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q6_K, - .from_float = quantize_row_q6_K, .from_float_ref = (ggml_from_float_t) quantize_row_q6_K_ref, - .vec_dot = ggml_vec_dot_q6_K_q8_K, - .vec_dot_type = GGML_TYPE_Q8_K, - .nrows = 1, }, [GGML_TYPE_IQ2_XXS] = { .type_name = "iq2_xxs", @@ -886,11 +715,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_iq2_xxs), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq2_xxs, - .from_float = NULL, .from_float_ref = NULL, - .vec_dot = ggml_vec_dot_iq2_xxs_q8_K, - .vec_dot_type = GGML_TYPE_Q8_K, - .nrows = 1, }, [GGML_TYPE_IQ2_XS] = { .type_name = "iq2_xs", @@ -898,11 +723,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_iq2_xs), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq2_xs, - .from_float = NULL, .from_float_ref = NULL, - .vec_dot = ggml_vec_dot_iq2_xs_q8_K, - .vec_dot_type = GGML_TYPE_Q8_K, - .nrows = 1, }, [GGML_TYPE_IQ3_XXS] = { .type_name = "iq3_xxs", @@ -910,11 +731,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_iq3_xxs), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq3_xxs, - .from_float = quantize_row_iq3_xxs, .from_float_ref = (ggml_from_float_t)quantize_row_iq3_xxs_ref, - .vec_dot = ggml_vec_dot_iq3_xxs_q8_K, - .vec_dot_type = GGML_TYPE_Q8_K, - .nrows = 1, }, [GGML_TYPE_IQ3_S] = { .type_name = "iq3_s", @@ -922,11 +739,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_iq3_s), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq3_s, - .from_float = quantize_row_iq3_s, .from_float_ref = (ggml_from_float_t)quantize_row_iq3_s_ref, - .vec_dot = ggml_vec_dot_iq3_s_q8_K, - .vec_dot_type = GGML_TYPE_Q8_K, - .nrows = 1, }, [GGML_TYPE_IQ2_S] = { .type_name = "iq2_s", @@ -934,11 +747,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_iq2_s), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq2_s, - .from_float = quantize_row_iq2_s, .from_float_ref = (ggml_from_float_t)quantize_row_iq2_s_ref, - .vec_dot = ggml_vec_dot_iq2_s_q8_K, - .vec_dot_type = GGML_TYPE_Q8_K, - .nrows = 1, }, [GGML_TYPE_IQ1_S] = { .type_name = "iq1_s", @@ -946,11 +755,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_iq1_s), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq1_s, - .from_float = NULL, .from_float_ref = NULL, - .vec_dot = ggml_vec_dot_iq1_s_q8_K, - .vec_dot_type = GGML_TYPE_Q8_K, - .nrows = 1, }, [GGML_TYPE_IQ1_M] = { .type_name = "iq1_m", @@ -958,11 +763,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_iq1_m), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq1_m, - .from_float = NULL, .from_float_ref = NULL, - .vec_dot = ggml_vec_dot_iq1_m_q8_K, - .vec_dot_type = GGML_TYPE_Q8_K, - .nrows = 1, }, [GGML_TYPE_IQ4_NL] = { .type_name = "iq4_nl", @@ -970,11 +771,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_iq4_nl), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq4_nl, - .from_float = quantize_row_iq4_nl, .from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref, - .vec_dot = ggml_vec_dot_iq4_nl_q8_0, - .vec_dot_type = GGML_TYPE_Q8_0, - .nrows = 1, }, [GGML_TYPE_IQ4_XS] = { .type_name = "iq4_xs", @@ -982,18 +779,13 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_iq4_xs), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq4_xs, - .from_float = quantize_row_iq4_xs, .from_float_ref = (ggml_from_float_t)quantize_row_iq4_xs_ref, - .vec_dot = ggml_vec_dot_iq4_xs_q8_K, - .vec_dot_type = GGML_TYPE_Q8_K, - .nrows = 1, }, [GGML_TYPE_Q8_K] = { .type_name = "q8_K", .blck_size = QK_K, .type_size = sizeof(block_q8_K), .is_quantized = true, - .from_float = quantize_row_q8_K, }, [GGML_TYPE_BF16] = { .type_name = "bf16", @@ -1001,59 +793,25 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(ggml_bf16_t), .is_quantized = false, .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row, - .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row, .from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row_ref, - .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16, - .vec_dot_type = GGML_TYPE_BF16, - .nrows = 1, }, - [GGML_TYPE_Q4_0_4_4] = { - .type_name = "q4_0_4x4", - .blck_size = QK4_0, - .blck_size_interleave = 4, - .type_size = sizeof(block_q4_0), - .is_quantized = true, - .to_float = NULL, - .from_float = NULL, - .from_float_ref = NULL, - .vec_dot = NULL, - .vec_dot_type = GGML_TYPE_Q8_0, - .nrows = 1, - .ncols = 4, - .gemv = ggml_gemv_q4_0_4x4_q8_0, - .gemm = ggml_gemm_q4_0_4x4_q8_0, + [31] = { // GGML_TYPE_Q4_0_4_4 + .type_name = "TYPE_Q4_0_4_4 REMOVED, use Q4_0 with runtime repacking", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, }, - [GGML_TYPE_Q4_0_4_8] = { - .type_name = "q4_0_4x8", - .blck_size = QK4_0, - .blck_size_interleave = 8, - .type_size = sizeof(block_q4_0), - .is_quantized = true, - .to_float = NULL, - .from_float = NULL, - .from_float_ref = NULL, - .vec_dot = NULL, - .vec_dot_type = GGML_TYPE_Q8_0, - .nrows = 1, - .ncols = 4, - .gemv = ggml_gemv_q4_0_4x8_q8_0, - .gemm = ggml_gemm_q4_0_4x8_q8_0, + [32] = { // GGML_TYPE_Q4_0_4_8 + .type_name = "TYPE_Q4_0_4_8 REMOVED, use Q4_0 with runtime repacking", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, }, - [GGML_TYPE_Q4_0_8_8] = { - .type_name = "q4_0_8x8", - .blck_size = QK4_0, - .blck_size_interleave = 8, - .type_size = sizeof(block_q4_0), - .is_quantized = true, - .to_float = NULL, - .from_float = NULL, - .from_float_ref = NULL, - .vec_dot = NULL, - .vec_dot_type = GGML_TYPE_Q8_0, - .nrows = 1, - .ncols = 8, - .gemv = ggml_gemv_q4_0_8x8_q8_0, - .gemm = ggml_gemm_q4_0_8x8_q8_0, + [33] = { // GGML_TYPE_Q4_0_8_8 + .type_name = "TYPE_Q4_0_8_8 REMOVED, use Q4_0 with runtime repacking", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, }, [GGML_TYPE_TQ1_0] = { .type_name = "tq1_0", @@ -1061,11 +819,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_tq1_0), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_tq1_0, - .from_float = quantize_row_tq1_0, .from_float_ref = (ggml_from_float_t) quantize_row_tq1_0_ref, - .vec_dot = ggml_vec_dot_tq1_0_q8_K, - .vec_dot_type = GGML_TYPE_Q8_K, - .nrows = 1, }, [GGML_TYPE_TQ2_0] = { .type_name = "tq2_0", @@ -1073,825 +827,49 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_tq2_0), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_tq2_0, - .from_float = quantize_row_tq2_0, .from_float_ref = (ggml_from_float_t) quantize_row_tq2_0_ref, - .vec_dot = ggml_vec_dot_tq2_0_q8_K, - .vec_dot_type = GGML_TYPE_Q8_K, - .nrows = 1, + }, + [36] = { // GGML_TYPE_IQ4_NL_4_4 + .type_name = "TYPE_IQ4_NL_4_4 REMOVED, use IQ4_NL with runtime repacking", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, + }, + [37] = { // GGML_TYPE_IQ4_NL_4_8 + .type_name = "TYPE_IQ4_NL_4_8 REMOVED, use IQ4_NL with runtime repacking", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, + }, + [38] = { // GGML_TYPE_IQ4_NL_8_8 + .type_name = "TYPE_IQ4_NL_8_8 REMOVED, use IQ4_NL with runtime repacking", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, }, }; -// For internal test use -ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) { +const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) { GGML_ASSERT(type < GGML_TYPE_COUNT); - return type_traits[type]; + return &type_traits[type]; } // -// simd mappings +// ggml object // -// we define a common set of C macros which map to specific intrinsics based on the current architecture -// we then implement the fundamental computation operations below using only these macros -// adding support for new architectures requires to define the corresponding SIMD macros -// -// GGML_F32_STEP / GGML_F16_STEP -// number of elements to process in a single step -// -// GGML_F32_EPR / GGML_F16_EPR -// number of elements to fit in a single register -// +struct ggml_object { + size_t offs; + size_t size; -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) + struct ggml_object * next; -#define GGML_SIMD + enum ggml_object_type type; -// F32 NEON + char padding[4]; +}; -#define GGML_F32_STEP 16 -#define GGML_F32_EPR 4 - -#define GGML_F32x4 float32x4_t -#define GGML_F32x4_ZERO vdupq_n_f32(0.0f) -#define GGML_F32x4_SET1(x) vdupq_n_f32(x) -#define GGML_F32x4_LOAD vld1q_f32 -#define GGML_F32x4_STORE vst1q_f32 -#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c) -#define GGML_F32x4_ADD vaddq_f32 -#define GGML_F32x4_MUL vmulq_f32 -#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) -#define GGML_F32x4_REDUCE(res, x) \ -{ \ - int offset = GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vaddq_f32(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vaddq_f32(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vaddq_f32(x[i], x[offset+i]); \ - } \ - res = GGML_F32x4_REDUCE_ONE(x[0]); \ -} - -#define GGML_F32_VEC GGML_F32x4 -#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO -#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 -#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD -#define GGML_F32_VEC_STORE GGML_F32x4_STORE -#define GGML_F32_VEC_FMA GGML_F32x4_FMA -#define GGML_F32_VEC_ADD GGML_F32x4_ADD -#define GGML_F32_VEC_MUL GGML_F32x4_MUL -#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE - -// F16 NEON - -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - #define GGML_F16_STEP 32 - #define GGML_F16_EPR 8 - - #define GGML_F16x8 float16x8_t - #define GGML_F16x8_ZERO vdupq_n_f16(0.0f) - #define GGML_F16x8_SET1(x) vdupq_n_f16(x) - #define GGML_F16x8_LOAD(x) vld1q_f16((const ggml_fp16_internal_t *)(x)) - #define GGML_F16x8_STORE vst1q_f16 - #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) - #define GGML_F16x8_ADD vaddq_f16 - #define GGML_F16x8_MUL vmulq_f16 - #define GGML_F16x8_REDUCE(res, x) \ - do { \ - int offset = GGML_F16_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vaddq_f16(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vaddq_f16(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vaddq_f16(x[i], x[offset+i]); \ - } \ - const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \ - const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \ - res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \ - } while (0) - - #define GGML_F16_VEC GGML_F16x8 - #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO - #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 - #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), r[i]) - #define GGML_F16_VEC_FMA GGML_F16x8_FMA - #define GGML_F16_VEC_ADD GGML_F16x8_ADD - #define GGML_F16_VEC_MUL GGML_F16x8_MUL - #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE -#else - // if FP16 vector arithmetic is not supported, we use FP32 instead - // and take advantage of the vcvt_ functions to convert to/from FP16 - - #define GGML_F16_STEP 16 - #define GGML_F16_EPR 4 - - #define GGML_F32Cx4 float32x4_t - #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f) - #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x) - #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const ggml_fp16_internal_t *)(x))) - #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y)) - #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) - #define GGML_F32Cx4_ADD vaddq_f32 - #define GGML_F32Cx4_MUL vmulq_f32 - #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE - - #define GGML_F16_VEC GGML_F32Cx4 - #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO - #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 - #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i]) - #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA - #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD - #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL - #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE -#endif - -#elif defined(__AVX512F__) - -#define GGML_SIMD - -// F32 AVX512 - -#define GGML_F32_STEP 64 -#define GGML_F32_EPR 16 - -#define GGML_F32x16 __m512 -#define GGML_F32x16_ZERO _mm512_setzero_ps() -#define GGML_F32x16_SET1(x) _mm512_set1_ps(x) -#define GGML_F32x16_LOAD _mm512_loadu_ps -#define GGML_F32x16_STORE _mm512_storeu_ps -// _mm512_fmadd_ps is defined in AVX512F so no guard is required -#define GGML_F32x16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a) -#define GGML_F32x16_ADD _mm512_add_ps -#define GGML_F32x16_MUL _mm512_mul_ps -#define GGML_F32x16_REDUCE(res, x) \ -do { \ - int offset = GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm512_add_ps(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm512_add_ps(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm512_add_ps(x[i], x[offset+i]); \ - } \ - res = _mm512_reduce_add_ps(x[0]); \ -} while (0) - -// TODO: is this optimal ? - -#define GGML_F32_VEC GGML_F32x16 -#define GGML_F32_VEC_ZERO GGML_F32x16_ZERO -#define GGML_F32_VEC_SET1 GGML_F32x16_SET1 -#define GGML_F32_VEC_LOAD GGML_F32x16_LOAD -#define GGML_F32_VEC_STORE GGML_F32x16_STORE -#define GGML_F32_VEC_FMA GGML_F32x16_FMA -#define GGML_F32_VEC_ADD GGML_F32x16_ADD -#define GGML_F32_VEC_MUL GGML_F32x16_MUL -#define GGML_F32_VEC_REDUCE GGML_F32x16_REDUCE - -// F16 AVX512 - -// F16 AVX - -#define GGML_F16_STEP 64 -#define GGML_F16_EPR 16 - -// AVX512 has FP16 extension (AVX512_FP16) but I don't have it on my machine so I use FP32 instead - -#define GGML_F32Cx16 __m512 -#define GGML_F32Cx16_ZERO _mm512_setzero_ps() -#define GGML_F32Cx16_SET1(x) _mm512_set1_ps(x) - -// unlike _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F -// so F16C guard isn't required -#define GGML_F32Cx16_LOAD(x) _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x))) -#define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0)) - -#define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a) -#define GGML_F32Cx16_ADD _mm512_add_ps -#define GGML_F32Cx16_MUL _mm512_mul_ps -#define GGML_F32Cx16_REDUCE(res, x) \ -do { \ - int offset = GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm512_add_ps(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm512_add_ps(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm512_add_ps(x[i], x[offset+i]); \ - } \ - res = _mm512_reduce_add_ps(x[0]); \ -} while (0) - -#define GGML_F16_VEC GGML_F32Cx16 -#define GGML_F16_VEC_ZERO GGML_F32Cx16_ZERO -#define GGML_F16_VEC_SET1 GGML_F32Cx16_SET1 -#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx16_LOAD(p) -#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx16_STORE(p, r[i]) -#define GGML_F16_VEC_FMA GGML_F32Cx16_FMA -#define GGML_F16_VEC_ADD GGML_F32Cx16_ADD -#define GGML_F16_VEC_MUL GGML_F32Cx16_MUL -#define GGML_F16_VEC_REDUCE GGML_F32Cx16_REDUCE - -#elif defined(__AVX__) - -#define GGML_SIMD - -// F32 AVX - -#define GGML_F32_STEP 32 -#define GGML_F32_EPR 8 - -#define GGML_F32x8 __m256 -#define GGML_F32x8_ZERO _mm256_setzero_ps() -#define GGML_F32x8_SET1(x) _mm256_set1_ps(x) -#define GGML_F32x8_LOAD _mm256_loadu_ps -#define GGML_F32x8_STORE _mm256_storeu_ps -#if defined(__FMA__) - #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a) -#else - #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a) -#endif -#define GGML_F32x8_ADD _mm256_add_ps -#define GGML_F32x8_MUL _mm256_mul_ps -#define GGML_F32x8_REDUCE(res, x) \ -do { \ - int offset = GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm256_add_ps(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm256_add_ps(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm256_add_ps(x[i], x[offset+i]); \ - } \ - const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]), \ - _mm256_extractf128_ps(x[0], 1)); \ - const __m128 t1 = _mm_hadd_ps(t0, t0); \ - res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \ -} while (0) -// TODO: is this optimal ? - -#define GGML_F32_VEC GGML_F32x8 -#define GGML_F32_VEC_ZERO GGML_F32x8_ZERO -#define GGML_F32_VEC_SET1 GGML_F32x8_SET1 -#define GGML_F32_VEC_LOAD GGML_F32x8_LOAD -#define GGML_F32_VEC_STORE GGML_F32x8_STORE -#define GGML_F32_VEC_FMA GGML_F32x8_FMA -#define GGML_F32_VEC_ADD GGML_F32x8_ADD -#define GGML_F32_VEC_MUL GGML_F32x8_MUL -#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE - -// F16 AVX - -#define GGML_F16_STEP 32 -#define GGML_F16_EPR 8 - -// F16 arithmetic is not supported by AVX, so we use F32 instead - -#define GGML_F32Cx8 __m256 -#define GGML_F32Cx8_ZERO _mm256_setzero_ps() -#define GGML_F32Cx8_SET1(x) _mm256_set1_ps(x) - -#if defined(__F16C__) -// the _mm256_cvt intrinsics require F16C -#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) -#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) -#else -static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { - float tmp[8]; - - for (int i = 0; i < 8; i++) { - tmp[i] = GGML_FP16_TO_FP32(x[i]); - } - - return _mm256_loadu_ps(tmp); -} -static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) { - float arr[8]; - - _mm256_storeu_ps(arr, y); - - for (int i = 0; i < 8; i++) - x[i] = GGML_FP32_TO_FP16(arr[i]); -} -#define GGML_F32Cx8_LOAD(x) __avx_f32cx8_load(x) -#define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y) -#endif - -#define GGML_F32Cx8_FMA GGML_F32x8_FMA -#define GGML_F32Cx8_ADD _mm256_add_ps -#define GGML_F32Cx8_MUL _mm256_mul_ps -#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE - -#define GGML_F16_VEC GGML_F32Cx8 -#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO -#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1 -#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p) -#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i]) -#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA -#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD -#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL -#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE - -#elif defined(__POWER9_VECTOR__) - -#define GGML_SIMD - -// F32 POWER9 - -#define GGML_F32_STEP 32 -#define GGML_F32_EPR 4 - -#define GGML_F32x4 vector float -#define GGML_F32x4_ZERO 0.0f -#define GGML_F32x4_SET1 vec_splats -#define GGML_F32x4_LOAD(p) vec_xl(0, p) -#define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p) -#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a) -#define GGML_F32x4_ADD vec_add -#define GGML_F32x4_MUL vec_mul -#define GGML_F32x4_REDUCE(res, x) \ -{ \ - int offset = GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vec_add(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vec_add(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = vec_add(x[i], x[offset+i]); \ - } \ - res = vec_extract(x[0], 0) + \ - vec_extract(x[0], 1) + \ - vec_extract(x[0], 2) + \ - vec_extract(x[0], 3); \ -} - -#define GGML_F32_VEC GGML_F32x4 -#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO -#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 -#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD -#define GGML_F32_VEC_STORE GGML_F32x4_STORE -#define GGML_F32_VEC_FMA GGML_F32x4_FMA -#define GGML_F32_VEC_ADD GGML_F32x4_ADD -#define GGML_F32_VEC_MUL GGML_F32x4_MUL -#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE - -// F16 POWER9 -#define GGML_F16_STEP GGML_F32_STEP -#define GGML_F16_EPR GGML_F32_EPR -#define GGML_F16_VEC GGML_F32x4 -#define GGML_F16_VEC_ZERO GGML_F32x4_ZERO -#define GGML_F16_VEC_SET1 GGML_F32x4_SET1 -#define GGML_F16_VEC_FMA GGML_F32x4_FMA -#define GGML_F16_VEC_ADD GGML_F32x4_ADD -#define GGML_F16_VEC_MUL GGML_F32x4_MUL -#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE -// Use vec_xl, not vec_ld, in case the load address is not aligned. -#define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ? \ - vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \ - vec_extract_fp32_from_shortl(vec_xl(0, p)) -#define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i] -#define GGML_F16_VEC_STORE(p, r, i) \ - if (i & 0x1) \ - vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)], \ - r[i - GGML_ENDIAN_BYTE(0)]), \ - 0, p - GGML_F16_EPR) - -#elif defined(__wasm_simd128__) - -#define GGML_SIMD - -// F32 WASM - -#define GGML_F32_STEP 16 -#define GGML_F32_EPR 4 - -#define GGML_F32x4 v128_t -#define GGML_F32x4_ZERO wasm_f32x4_splat(0.0f) -#define GGML_F32x4_SET1(x) wasm_f32x4_splat(x) -#define GGML_F32x4_LOAD wasm_v128_load -#define GGML_F32x4_STORE wasm_v128_store -#define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a) -#define GGML_F32x4_ADD wasm_f32x4_add -#define GGML_F32x4_MUL wasm_f32x4_mul -#define GGML_F32x4_REDUCE(res, x) \ -{ \ - int offset = GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - res = wasm_f32x4_extract_lane(x[0], 0) + \ - wasm_f32x4_extract_lane(x[0], 1) + \ - wasm_f32x4_extract_lane(x[0], 2) + \ - wasm_f32x4_extract_lane(x[0], 3); \ -} - -#define GGML_F32_VEC GGML_F32x4 -#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO -#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 -#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD -#define GGML_F32_VEC_STORE GGML_F32x4_STORE -#define GGML_F32_VEC_FMA GGML_F32x4_FMA -#define GGML_F32_VEC_ADD GGML_F32x4_ADD -#define GGML_F32_VEC_MUL GGML_F32x4_MUL -#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE - -// F16 WASM - -#define GGML_F16_STEP 16 -#define GGML_F16_EPR 4 - -inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) { - float tmp[4]; - - tmp[0] = GGML_FP16_TO_FP32(p[0]); - tmp[1] = GGML_FP16_TO_FP32(p[1]); - tmp[2] = GGML_FP16_TO_FP32(p[2]); - tmp[3] = GGML_FP16_TO_FP32(p[3]); - - return wasm_v128_load(tmp); -} - -inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) { - float tmp[4]; - - wasm_v128_store(tmp, x); - - p[0] = GGML_FP32_TO_FP16(tmp[0]); - p[1] = GGML_FP32_TO_FP16(tmp[1]); - p[2] = GGML_FP32_TO_FP16(tmp[2]); - p[3] = GGML_FP32_TO_FP16(tmp[3]); -} - -#define GGML_F16x4 v128_t -#define GGML_F16x4_ZERO wasm_f32x4_splat(0.0f) -#define GGML_F16x4_SET1(x) wasm_f32x4_splat(x) -#define GGML_F16x4_LOAD(x) __wasm_f16x4_load(x) -#define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y) -#define GGML_F16x4_FMA GGML_F32x4_FMA -#define GGML_F16x4_ADD wasm_f32x4_add -#define GGML_F16x4_MUL wasm_f32x4_mul -#define GGML_F16x4_REDUCE(res, x) \ -{ \ - int offset = GGML_F16_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = wasm_f32x4_add(x[i], x[offset+i]); \ - } \ - res = wasm_f32x4_extract_lane(x[0], 0) + \ - wasm_f32x4_extract_lane(x[0], 1) + \ - wasm_f32x4_extract_lane(x[0], 2) + \ - wasm_f32x4_extract_lane(x[0], 3); \ -} - -#define GGML_F16_VEC GGML_F16x4 -#define GGML_F16_VEC_ZERO GGML_F16x4_ZERO -#define GGML_F16_VEC_SET1 GGML_F16x4_SET1 -#define GGML_F16_VEC_LOAD(p, i) GGML_F16x4_LOAD(p) -#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i]) -#define GGML_F16_VEC_FMA GGML_F16x4_FMA -#define GGML_F16_VEC_ADD GGML_F16x4_ADD -#define GGML_F16_VEC_MUL GGML_F16x4_MUL -#define GGML_F16_VEC_REDUCE GGML_F16x4_REDUCE - -#elif defined(__SSE3__) - -#define GGML_SIMD - -// F32 SSE - -#define GGML_F32_STEP 32 -#define GGML_F32_EPR 4 - -#define GGML_F32x4 __m128 -#define GGML_F32x4_ZERO _mm_setzero_ps() -#define GGML_F32x4_SET1(x) _mm_set1_ps(x) -#define GGML_F32x4_LOAD _mm_loadu_ps -#define GGML_F32x4_STORE _mm_storeu_ps -#if defined(__FMA__) - // TODO: Does this work? - #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a) -#else - #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a) -#endif -#define GGML_F32x4_ADD _mm_add_ps -#define GGML_F32x4_MUL _mm_mul_ps -#define GGML_F32x4_REDUCE(res, x) \ -{ \ - int offset = GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm_add_ps(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm_add_ps(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = _mm_add_ps(x[i], x[offset+i]); \ - } \ - const __m128 t0 = _mm_hadd_ps(x[0], x[0]); \ - res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t0, t0)); \ -} -// TODO: is this optimal ? - -#define GGML_F32_VEC GGML_F32x4 -#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO -#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 -#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD -#define GGML_F32_VEC_STORE GGML_F32x4_STORE -#define GGML_F32_VEC_FMA GGML_F32x4_FMA -#define GGML_F32_VEC_ADD GGML_F32x4_ADD -#define GGML_F32_VEC_MUL GGML_F32x4_MUL -#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE - -// F16 SSE - -#define GGML_F16_STEP 32 -#define GGML_F16_EPR 4 - -static inline __m128 __sse_f16x4_load(ggml_fp16_t *x) { - float tmp[4]; - - tmp[0] = GGML_FP16_TO_FP32(x[0]); - tmp[1] = GGML_FP16_TO_FP32(x[1]); - tmp[2] = GGML_FP16_TO_FP32(x[2]); - tmp[3] = GGML_FP16_TO_FP32(x[3]); - - return _mm_loadu_ps(tmp); -} - -static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) { - float arr[4]; - - _mm_storeu_ps(arr, y); - - x[0] = GGML_FP32_TO_FP16(arr[0]); - x[1] = GGML_FP32_TO_FP16(arr[1]); - x[2] = GGML_FP32_TO_FP16(arr[2]); - x[3] = GGML_FP32_TO_FP16(arr[3]); -} - -#define GGML_F32Cx4 __m128 -#define GGML_F32Cx4_ZERO _mm_setzero_ps() -#define GGML_F32Cx4_SET1(x) _mm_set1_ps(x) -#define GGML_F32Cx4_LOAD(x) __sse_f16x4_load(x) -#define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y) -#define GGML_F32Cx4_FMA GGML_F32x4_FMA -#define GGML_F32Cx4_ADD _mm_add_ps -#define GGML_F32Cx4_MUL _mm_mul_ps -#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE - -#define GGML_F16_VEC GGML_F32Cx4 -#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO -#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 -#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) -#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) -#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA -#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD -#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL -#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE - -#elif defined(__loongarch_asx) - -#define GGML_SIMD - -// F32 LASX -#define GGML_F32_STEP 32 -#define GGML_F32_EPR 8 - -#define GGML_F32x8 __m256 -#define GGML_F32x8_ZERO (__m256)__lasx_xvldi(0) -#define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x)) -#define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0) -#define GGML_F32x8_STORE(x,y) __lasx_xvst((y), (x), 0) -#define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a) -#define GGML_F32x8_ADD __lasx_xvfadd_s -#define GGML_F32x8_MUL __lasx_xvfmul_s -#define GGML_F32x8_REDUCE(res, x) \ -do { \ - int offset = GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \ - } \ - float *tmp_p = (float *)&x[0]; \ - res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7]; \ -} while (0) -// TODO: is this optimal ? - -#define GGML_F32_VEC GGML_F32x8 -#define GGML_F32_VEC_ZERO GGML_F32x8_ZERO -#define GGML_F32_VEC_SET1 GGML_F32x8_SET1 -#define GGML_F32_VEC_LOAD GGML_F32x8_LOAD -#define GGML_F32_VEC_STORE GGML_F32x8_STORE -#define GGML_F32_VEC_FMA GGML_F32x8_FMA -#define GGML_F32_VEC_ADD GGML_F32x8_ADD -#define GGML_F32_VEC_MUL GGML_F32x8_MUL -#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE - -// F16 LASX - -#define GGML_F16_STEP 32 -#define GGML_F16_EPR 8 - -// F16 arithmetic is not supported by AVX, so we use F32 instead - -#define GGML_F32Cx8 __m256 -#define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0) -#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x)) - -static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) { - float tmp[8]; - - for (int i = 0; i < 8; i++) { - tmp[i] = GGML_FP16_TO_FP32(x[i]); - } - - return (__m256)__lasx_xvld(tmp, 0); -} -static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) { - float arr[8]; - - __lasx_xvst(y, arr, 0); - - for (int i = 0; i < 8; i++) { - x[i] = GGML_FP32_TO_FP16(arr[i]); - } -} -#define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x) -#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y) - -#define GGML_F32Cx8_FMA GGML_F32x8_FMA -#define GGML_F32Cx8_ADD __lasx_xvfadd_s -#define GGML_F32Cx8_MUL __lasx_xvfmul_s -#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE - -#define GGML_F16_VEC GGML_F32Cx8 -#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO -#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1 -#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p) -#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i]) -#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA -#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD -#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL -#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE - -#elif defined(__loongarch_sx) - -#define GGML_SIMD - -// F32 LSX - -#define GGML_F32_STEP 32 -#define GGML_F32_EPR 4 - -#define GGML_F32x4 __m128 -#define GGML_F32x4_ZERO __lsx_vldi(0) -#define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0) -#define GGML_F32x4_LOAD(x) __lsx_vld((x), 0) -#define GGML_F32x4_STORE((x),(y)) __lsx_vst((y), (x), 0) -#define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a) -#define GGML_F32x4_ADD __lsx_vfadd_s -#define GGML_F32x4_MUL __lsx_vfmul_s -#define GGML_F32x4_REDUCE(res, x) \ -{ \ - int offset = GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \ - } \ - __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \ - tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \ - tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \ - const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \ - tmp = __lsx_vsrli_d((__m128i)t0, 32); \ - tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \ - tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \ - res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \ -} - -#define GGML_F32_VEC GGML_F32x4 -#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO -#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 -#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD -#define GGML_F32_VEC_STORE GGML_F32x4_STORE -#define GGML_F32_VEC_FMA GGML_F32x4_FMA -#define GGML_F32_VEC_ADD GGML_F32x4_ADD -#define GGML_F32_VEC_MUL GGML_F32x4_MUL -#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE - -// F16 LSX - -#define GGML_F16_STEP 32 -#define GGML_F16_EPR 4 - -static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) { - float tmp[4]; - - tmp[0] = GGML_FP16_TO_FP32(x[0]); - tmp[1] = GGML_FP16_TO_FP32(x[1]); - tmp[2] = GGML_FP16_TO_FP32(x[2]); - tmp[3] = GGML_FP16_TO_FP32(x[3]); - - return __lsx_vld(tmp, 0); -} - -static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { - float arr[4]; - - __lsx_vst(y, arr, 0); - - x[0] = GGML_FP32_TO_FP16(arr[0]); - x[1] = GGML_FP32_TO_FP16(arr[1]); - x[2] = GGML_FP32_TO_FP16(arr[2]); - x[3] = GGML_FP32_TO_FP16(arr[3]); -} - -#define GGML_F32Cx4 __m128 -#define GGML_F32Cx4_ZERO __lsx_vldi(0) -#define GGML_F32Cx4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0) -#define GGML_F32Cx4_LOAD(x) __lsx_f16x4_load(x) -#define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y) -#define GGML_F32Cx4_FMA GGML_F32x4_FMA -#define GGML_F32Cx4_ADD __lsx_vfadd_s -#define GGML_F32Cx4_MUL __lsx_vfmul_s -#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE - -#define GGML_F16_VEC GGML_F32Cx4 -#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO -#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 -#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) -#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) -#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA -#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD -#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL -#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE - -#endif - -// GGML_F32_ARR / GGML_F16_ARR -// number of registers to use per step -#ifdef GGML_SIMD -#define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR) -#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR) -#endif +static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object); // // ggml context @@ -1899,18 +877,14 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { struct ggml_context { size_t mem_size; - void* mem_buffer; + void * mem_buffer; bool mem_buffer_owned; bool no_alloc; - bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers int n_objects; struct ggml_object * objects_begin; struct ggml_object * objects_end; - - struct ggml_scratch scratch; - struct ggml_scratch scratch_save; }; struct ggml_context_container { @@ -1919,971 +893,6 @@ struct ggml_context_container { struct ggml_context context; }; -// -// Threading defs -// - -typedef pthread_t ggml_thread_t; - -#if defined(_WIN32) - -typedef CONDITION_VARIABLE ggml_cond_t; -typedef SRWLOCK ggml_mutex_t; - -#define ggml_mutex_init(m) InitializeSRWLock(m) -#define ggml_mutex_destroy(m) -#define ggml_mutex_lock(m) AcquireSRWLockExclusive(m) -#define ggml_mutex_unlock(m) ReleaseSRWLockExclusive(m) -#define ggml_mutex_lock_shared(m) AcquireSRWLockShared(m) -#define ggml_mutex_unlock_shared(m) ReleaseSRWLockShared(m) - -#define ggml_cond_init(c) InitializeConditionVariable(c) -#define ggml_cond_destroy(c) -#define ggml_cond_wait(c, m) SleepConditionVariableSRW(c, m, INFINITE, CONDITION_VARIABLE_LOCKMODE_SHARED) -#define ggml_cond_broadcast(c) WakeAllConditionVariable(c) - -#define ggml_thread_create pthread_create -#define ggml_thread_join pthread_join - -#else - -typedef pthread_cond_t ggml_cond_t; -typedef pthread_mutex_t ggml_mutex_t; - -#define ggml_mutex_init(m) pthread_mutex_init(m, NULL) -#define ggml_mutex_destroy(m) pthread_mutex_destroy(m) -#define ggml_mutex_lock(m) pthread_mutex_lock(m) -#define ggml_mutex_unlock(m) pthread_mutex_unlock(m) -#define ggml_mutex_lock_shared(m) pthread_mutex_lock(m) -#define ggml_mutex_unlock_shared(m) pthread_mutex_unlock(m) - -#define ggml_lock_init(x) UNUSED(x) -#define ggml_lock_destroy(x) UNUSED(x) -#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64)) -#define ggml_lock_lock(x) _mm_pause() -#else -#define ggml_lock_lock(x) UNUSED(x) -#endif -#define ggml_lock_unlock(x) UNUSED(x) - -#define GGML_LOCK_INITIALIZER 0 -#define ggml_cond_init(c) pthread_cond_init(c, NULL) -#define ggml_cond_destroy(c) pthread_cond_destroy(c) -#define ggml_cond_wait(c, m) pthread_cond_wait(c, m) -#define ggml_cond_broadcast(c) pthread_cond_broadcast(c) - -#define ggml_thread_create pthread_create -#define ggml_thread_join pthread_join - -#endif - -// Threadpool def -struct ggml_threadpool { - ggml_mutex_t mutex; // mutex for cond.var - ggml_cond_t cond; // cond.var for waiting for new work - - struct ggml_cgraph * cgraph; - struct ggml_cplan * cplan; - - // synchronization primitives - atomic_int n_graph; // incremented when there is work to be done (i.e each graph) - atomic_int n_barrier; - atomic_int n_barrier_passed; - atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads. - - // these are atomic as an annotation for thread-sanitizer - atomic_bool stop; // Used for stopping the threadpool altogether - atomic_bool pause; // Used for pausing the threadpool or individual threads - - struct ggml_compute_state * workers; // per thread state - int n_threads_max; // number of threads in the pool - int n_threads_cur; // number of threads used in the current graph - - int32_t prio; // Scheduling priority - uint32_t poll; // Polling level (0 - no polling) - - enum ggml_status ec; -}; - -// Per-thread state -struct ggml_compute_state { -#ifndef GGML_USE_OPENMP - ggml_thread_t thrd; - bool cpumask[GGML_MAX_N_THREADS]; - int last_graph; - bool pending; -#endif - struct ggml_threadpool * threadpool; - int ith; -}; - -struct ggml_compute_params { - // ith = thread index, nth = number of threads - int ith, nth; - - // work buffer for all threads - size_t wsize; - void * wdata; - - struct ggml_threadpool * threadpool; -}; - -// -// fundamental operations -// - -inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } -inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; } -inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } -inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } -inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } -inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } -inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } -inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } -inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } -inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } - -static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - -#if defined(GGML_SIMD) - float sumf = 0.0f; - const int np = (n & ~(GGML_F32_STEP - 1)); - - GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; - - GGML_F32_VEC ax[GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; - - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - - sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); - } - } - - // reduce sum0..sum3 to sum0 - GGML_F32_VEC_REDUCE(sumf, sum); - - // leftovers - for (int i = np; i < n; ++i) { - sumf += x[i]*y[i]; - } -#else - // scalar - ggml_float sumf = 0.0; - for (int i = 0; i < n; ++i) { - sumf += (ggml_float)(x[i]*y[i]); - } -#endif - - *s = sumf; -} - -static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - int i = 0; - ggml_float sumf = 0; - -#if defined(__AVX512BF16__) - __m512 c1 = _mm512_setzero_ps(); - __m512 c2 = _mm512_setzero_ps(); - for (; i + 64 <= n; i += 64) { - c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))), - m512bh(_mm512_loadu_si512((y + i)))); - c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))), - m512bh(_mm512_loadu_si512((y + i + 32)))); - } - sumf += (ggml_float)_mm512_reduce_add_ps(c1); - sumf += (ggml_float)_mm512_reduce_add_ps(c2); - -#elif defined(__AVX512F__) -#define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16)) - __m512 c1 = _mm512_setzero_ps(); - __m512 c2 = _mm512_setzero_ps(); - for (; i + 32 <= n; i += 32) { - c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1); - c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2); - } - sumf += (ggml_float)_mm512_reduce_add_ps(c1); - sumf += (ggml_float)_mm512_reduce_add_ps(c2); - -#undef LOAD -#elif defined(__AVX2__) -#define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)) - __m256 c1 = _mm256_setzero_ps(); - __m256 c2 = _mm256_setzero_ps(); - __m256 c3 = _mm256_setzero_ps(); - __m256 c4 = _mm256_setzero_ps(); - for (; i + 32 <= n; i += 32) { - c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1); - c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2); - c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3); - c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4); - } - __m128 g; - c1 = _mm256_add_ps(_mm256_add_ps(c1, c3), - _mm256_add_ps(c2, c4)); - g = _mm_add_ps(_mm256_extractf128_ps(c1, 1), - _mm256_castps256_ps128(c1)); - g = _mm_add_ps(g, _mm_movehl_ps(g, g)); - g = _mm_add_ss(g, _mm_movehdup_ps(g)); - sumf += (ggml_float)_mm_cvtss_f32(g); - -#undef LOAD -#endif - - for (; i < n; ++i) { - sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) * - GGML_BF16_TO_FP32(y[i])); - } - *s = sumf; -} - -static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - ggml_float sumf = 0.0; - -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F16_STEP - 1)); - - GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; - - GGML_F16_VEC ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; - - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - - sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); - } - } - - // reduce sum0..sum3 to sum0 - GGML_F16_VEC_REDUCE(sumf, sum); - - // leftovers - for (int i = np; i < n; ++i) { - sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i])); - } -#else - for (int i = 0; i < n; ++i) { - sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i])); - } -#endif - - *s = sumf; -} - -// compute GGML_VEC_DOT_UNROLL dot products at once -// xs - x row stride in bytes -inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) { - ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 }; - - ggml_fp16_t * restrict x[GGML_VEC_DOT_UNROLL]; - - for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { - x[i] = (ggml_fp16_t *) ((char *) xv + i*xs); - } - -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F16_STEP - 1)); - - GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } }; - - GGML_F16_VEC ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; - - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - - for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { - ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j); - - sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); - } - } - } - - // reduce sum0..sum3 to sum0 - for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { - GGML_F16_VEC_REDUCE(sumf[k], sum[k]); - } - - // leftovers - for (int i = np; i < n; ++i) { - for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { - sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i])); - } - } -#else - for (int i = 0; i < n; ++i) { - for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { - sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i])); - } - } -#endif - - for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { - s[i] = sumf[i]; - } -} - -inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float * restrict x, const float v) { -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); - - GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); - - GGML_F32_VEC ax[GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; - - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx); - - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); - } - } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] += x[i]*v; - } -#else - // scalar - for (int i = 0; i < n; ++i) { - y[i] += x[i]*v; - } -#endif -} - -inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) { -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F16_STEP - 1)); - - GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); - - GGML_F16_VEC ax[GGML_F16_ARR]; - GGML_F16_VEC ay[GGML_F16_ARR]; - - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); - - GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); - } - } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); - } -#else - // scalar - for (int i = 0; i < n; ++i) { - y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); - } -#endif -} - -// xs and vs are byte strides of x and v -inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { - - const float * restrict x[GGML_VEC_MAD_UNROLL]; - const float * restrict v[GGML_VEC_MAD_UNROLL]; - - for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) { - x[i] = (const float *) ((const char *) xv + i*xs); - v[i] = (const float *) ((const char *) vv + i*vs); - } - -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); - - GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL]; - - for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { - vx[k] = GGML_F32_VEC_SET1(v[k][0]); - } - - GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; - - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - - for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { - ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]); - } - - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); - } - } - - // leftovers - for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { - for (int i = np; i < n; ++i) { - y[i] += x[k][i]*v[k][0]; - } - } -#else - // scalar - for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { - for (int i = 0; i < n; ++i) { - y[i] += x[k][i]*v[k][0]; - } - } -#endif -} - -//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } -inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { -#if defined(GGML_USE_ACCELERATE) - vDSP_vsmul(y, 1, &v, y, 1, n); -#elif defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); - - GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); - - GGML_F32_VEC ay[GGML_F32_ARR]; - - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_MUL(ay[j], vx); - - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); - } - } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] *= v; - } -#else - // scalar - for (int i = 0; i < n; ++i) { - y[i] *= v; - } -#endif -} - -inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F16_STEP - 1)); - - GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); - - GGML_F16_VEC ay[GGML_F16_ARR]; - - for (int i = 0; i < np; i += GGML_F16_STEP) { - for (int j = 0; j < GGML_F16_ARR; j++) { - ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); - ay[j] = GGML_F16_VEC_MUL(ay[j], vx); - - GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); - } - } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); - } -#else - // scalar - for (int i = 0; i < n; ++i) { - y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); - } -#endif -} - -inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); } -inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } -inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } -inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); } -inline static void ggml_vec_sin_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]); } -inline static void ggml_vec_cos_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]); } -inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } -inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } -inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } -inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); } -inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); } -inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } -inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); } -inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); } -// TODO: optimize performance -inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } -inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } -inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); } - -static const float GELU_COEF_A = 0.044715f; -static const float GELU_QUICK_COEF = -1.702f; -static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - -inline static float ggml_gelu_f32(float x) { - return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { - const uint16_t * i16 = (const uint16_t *) x; - for (int i = 0; i < n; ++i) { - y[i] = ggml_table_gelu_f16[i16[i]]; - } -} - -#ifdef GGML_GELU_FP16 -inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { - uint16_t t; - for (int i = 0; i < n; ++i) { - if (x[i] <= -10.0f) { - y[i] = 0.0f; - } else if (x[i] >= 10.0f) { - y[i] = x[i]; - } else { - ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); - memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]); - } - } -} -#else -inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { - for (int i = 0; i < n; ++i) { - y[i] = ggml_gelu_f32(x[i]); - } -} -#endif - -inline static float ggml_gelu_quick_f32(float x) { - return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x))); -} - -//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { -// const uint16_t * i16 = (const uint16_t *) x; -// for (int i = 0; i < n; ++i) { -// y[i] = ggml_table_gelu_quick_f16[i16[i]]; -// } -//} - -#ifdef GGML_GELU_QUICK_FP16 -inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { - uint16_t t; - for (int i = 0; i < n; ++i) { - ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); - memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]); - } -} -#else -inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { - for (int i = 0; i < n; ++i) { - y[i] = ggml_gelu_quick_f32(x[i]); - } -} -#endif - -// Sigmoid Linear Unit (SiLU) function -inline static float ggml_silu_f32(float x) { - return x/(1.0f + expf(-x)); -} - -#if __FINITE_MATH_ONLY__ -#error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix" -#error "ref: https://github.com/ggerganov/llama.cpp/pull/7154#issuecomment-2143844461" -#endif - -#if defined(__ARM_NEON) && defined(__aarch64__) - -// adapted from arm limited optimized routine -// the maximum error is 1.45358 plus 0.5 ulps -// numbers above 88.38 will flush to infinity -// numbers beneath -103.97 will flush to zero -inline static float32x4_t ggml_v_expf(float32x4_t x) { - const float32x4_t r = vdupq_n_f32(0x1.8p23f); - const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f)); - const float32x4_t n = vsubq_f32(z, r); - const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n, - vdupq_n_f32(0x1.7f7d1cp-20f)); - const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23); - const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1)))); - const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126)); - const float32x4_t u = vmulq_f32(b, b); - const float32x4_t j = vfmaq_f32( - vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b), - vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b), - vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u); - if (!vpaddd_u64(vreinterpretq_u64_u32(c))) - return vfmaq_f32(k, j, k); - const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000)); - const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000))); - const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d)); - return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1), - vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j))); -} - -// computes silu x/(1+exp(-x)) in single precision vector -inline static float32x4_t ggml_v_silu(float32x4_t x) { - const float32x4_t one = vdupq_n_f32(1.0f); - const float32x4_t zero = vdupq_n_f32(0.0f); - const float32x4_t neg_x = vsubq_f32(zero, x); - const float32x4_t exp_neg_x = ggml_v_expf(neg_x); - const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); - return vdivq_f32(x, one_plus_exp_neg_x); -} - -#elif defined(__AVX512F__) && defined(__AVX512DQ__) - -// adapted from arm limited optimized routine -// the maximum error is 1.45358 plus 0.5 ulps -// numbers above 88.38 will flush to infinity -// numbers beneath -103.97 will flush to zero -inline static __m512 ggml_v_expf(__m512 x) { - const __m512 r = _mm512_set1_ps(0x1.8p23f); - const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r); - const __m512 n = _mm512_sub_ps(z, r); - const __m512 b = - _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f), - _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x)); - const __mmask16 d = - _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ); - const __m512 u = _mm512_mul_ps(b, b); - const __m512 j = _mm512_fmadd_ps( - _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b, - _mm512_set1_ps(0x1.573e2ep-5f)), - u, - _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b, - _mm512_set1_ps(0x1.fffdb6p-2f))), - u, - _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F))); - const __m512 res = _mm512_scalef_ps(j, n); - if (_mm512_kortestz(d, d)) - return res; - const __m512 zero = _mm512_setzero_ps(); - const __m512 alt = _mm512_mask_blend_ps( - _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero); - return _mm512_mask_blend_ps(d, res, alt); -} - -// computes silu x/(1+exp(-x)) in single precision vector -inline static __m512 ggml_v_silu(__m512 x) { - const __m512 one = _mm512_set1_ps(1); - const __m512 zero = _mm512_setzero_ps(); - const __m512 neg_x = _mm512_sub_ps(zero, x); - const __m512 exp_neg_x = ggml_v_expf(neg_x); - const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); - return _mm512_div_ps(x, one_plus_exp_neg_x); -} - -#elif defined(__AVX2__) && defined(__FMA__) - -// adapted from arm limited optimized routine -// the maximum error is 1.45358 plus 0.5 ulps -// numbers above 88.38 will flush to infinity -// numbers beneath -103.97 will flush to zero -inline static __m256 ggml_v_expf(__m256 x) { - const __m256 r = _mm256_set1_ps(0x1.8p23f); - const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r); - const __m256 n = _mm256_sub_ps(z, r); - const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f), - _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x)); - const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23); - const __m256 k = _mm256_castsi256_ps( - _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1)))); - const __m256i c = _mm256_castps_si256( - _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), - _mm256_set1_ps(126), _CMP_GT_OQ)); - const __m256 u = _mm256_mul_ps(b, b); - const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b, - _mm256_set1_ps(0x1.573e2ep-5f)), u, - _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b, - _mm256_set1_ps(0x1.fffdb6p-2f))), - u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b)); - if (!_mm256_movemask_ps(_mm256_castsi256_ps(c))) - return _mm256_fmadd_ps(j, k, k); - const __m256i g = _mm256_and_si256( - _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)), - _mm256_set1_epi32(0x82000000u)); - const __m256 s1 = - _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u))); - const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g)); - const __m256i d = _mm256_castps_si256( - _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), - _mm256_set1_ps(192), _CMP_GT_OQ)); - return _mm256_or_ps( - _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)), - _mm256_andnot_ps( - _mm256_castsi256_ps(d), - _mm256_or_ps( - _mm256_and_ps(_mm256_castsi256_ps(c), - _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)), - _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k))))); -} - -// computes silu x/(1+exp(-x)) in single precision vector -inline static __m256 ggml_v_silu(__m256 x) { - const __m256 one = _mm256_set1_ps(1); - const __m256 zero = _mm256_setzero_ps(); - const __m256 neg_x = _mm256_sub_ps(zero, x); - const __m256 exp_neg_x = ggml_v_expf(neg_x); - const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); - return _mm256_div_ps(x, one_plus_exp_neg_x); -} - -#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON - -#if defined(__FMA__) -#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z) -#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z) -#else -#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z) -#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y)) -#endif - -// adapted from arm limited optimized routine -// the maximum error is 1.45358 plus 0.5 ulps -// numbers above 88.38 will flush to infinity -// numbers beneath -103.97 will flush to zero -inline static __m128 ggml_v_expf(__m128 x) { - const __m128 r = _mm_set1_ps(0x1.8p23f); - const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r); - const __m128 n = _mm_sub_ps(z, r); - const __m128 b = - NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x)); - const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23); - const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1)))); - const __m128i c = - _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126))); - const __m128 u = _mm_mul_ps(b, b); - const __m128 j = - MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u, - MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))), - u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b)); - if (!_mm_movemask_epi8(c)) - return MADD128(j, k, k); - const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())), - _mm_set1_epi32(0x82000000u)); - const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u))); - const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g)); - const __m128i d = - _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192))); - return _mm_or_ps( - _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)), - _mm_andnot_ps(_mm_castsi128_ps(d), - _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)), - _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k))))); -} - -// computes silu x/(1+exp(-x)) in single precision vector -inline static __m128 ggml_v_silu(__m128 x) { - const __m128 one = _mm_set1_ps(1); - const __m128 zero = _mm_setzero_ps(); - const __m128 neg_x = _mm_sub_ps(zero, x); - const __m128 exp_neg_x = ggml_v_expf(neg_x); - const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x); - return _mm_div_ps(x, one_plus_exp_neg_x); -} - -#endif // __ARM_NEON / __AVX2__ / __SSE2__ - -static void ggml_vec_silu_f32(const int n, float * y, const float * x) { - int i = 0; -#if defined(__AVX512F__) && defined(__AVX512DQ__) - for (; i + 15 < n; i += 16) { - _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i))); - } -#elif defined(__AVX2__) && defined(__FMA__) - for (; i + 7 < n; i += 8) { - _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i))); - } -#elif defined(__SSE2__) - for (; i + 3 < n; i += 4) { - _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i))); - } -#elif defined(__ARM_NEON) && defined(__aarch64__) - for (; i + 3 < n; i += 4) { - vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i))); - } -#endif - for (; i < n; ++i) { - y[i] = ggml_silu_f32(x[i]); - } -} - -static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) { - int i = 0; - ggml_float sum = 0; -#if defined(__AVX512F__) && defined(__AVX512DQ__) - for (; i + 15 < n; i += 16) { - __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i), - _mm512_set1_ps(max))); - _mm512_storeu_ps(y + i, val); - sum += (ggml_float)_mm512_reduce_add_ps(val); - } -#elif defined(__AVX2__) && defined(__FMA__) - for (; i + 7 < n; i += 8) { - __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i), - _mm256_set1_ps(max))); - _mm256_storeu_ps(y + i, val); - __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1), - _mm256_castps256_ps128(val)); - val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2)); - val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2)); - sum += (ggml_float)_mm_cvtss_f32(val2); - } -#elif defined(__SSE2__) - for (; i + 3 < n; i += 4) { - __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i), - _mm_set1_ps(max))); - _mm_storeu_ps(y + i, val); -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) - val = _mm_add_ps(val, _mm_movehl_ps(val, val)); - val = _mm_add_ss(val, _mm_movehdup_ps(val)); -#else - __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1)); - val = _mm_add_ps(val, tmp); - tmp = _mm_movehl_ps(tmp, val); - val = _mm_add_ss(val, tmp); -#endif - sum += (ggml_float)_mm_cvtss_f32(val); - } -#elif defined(__ARM_NEON) && defined(__aarch64__) - for (; i + 3 < n; i += 4) { - float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i), - vdupq_n_f32(max))); - vst1q_f32(y + i, val); - sum += (ggml_float)vaddvq_f32(val); - } -#endif - for (; i < n; ++i) { - float val = expf(x[i] - max); - sum += (ggml_float)val; - y[i] = val; - } - return sum; -} - -static ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) { - // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_i) - - int i = 0; - ggml_float sum = 0; - for (; i < n; ++i) { - float val = x[i] - max; - y[i] = val; - sum += (ggml_float)expf(val); - } - return sum = (ggml_float)logf(sum); -} - -inline static float ggml_silu_backward_f32(float x, float dy) { - const float s = 1.0f/(1.0f + expf(-x)); - return dy*s*(1.0f + x*(1.0f - s)); -} - -inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) { - for (int i = 0; i < n; ++i) { - dx[i] = ggml_silu_backward_f32(x[i], dy[i]); - } -} - -inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { -#ifndef GGML_USE_ACCELERATE - ggml_float sum = 0.0; - for (int i = 0; i < n; ++i) { - sum += (ggml_float)x[i]; - } - *s = sum; -#else - vDSP_sve(x, 1, s, n); -#endif -} - -inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) { - ggml_float sum = 0.0; - for (int i = 0; i < n; ++i) { - sum += (ggml_float)x[i]; - } - *s = sum; -} - -inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) { - float sum = 0.0f; - for (int i = 0; i < n; ++i) { - sum += GGML_FP16_TO_FP32(x[i]); - } - *s = sum; -} - -inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) { - float sum = 0.0f; - for (int i = 0; i < n; ++i) { - sum += GGML_BF16_TO_FP32(x[i]); - } - *s = sum; -} - -inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { -#ifndef GGML_USE_ACCELERATE - float max = -INFINITY; - for (int i = 0; i < n; ++i) { - max = MAX(max, x[i]); - } - *s = max; -#else - vDSP_maxv(x, 1, s, n); -#endif -} - -inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { - ggml_vec_norm_f32(n, s, x); - *s = 1.f/(*s); -} - -inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) { - float max = -INFINITY; - int idx = 0; - for (int i = 0; i < n; ++i) { - max = MAX(max, x[i]); - if (max == x[i]) { idx = i; } - } - *s = idx; -} - // // data types // @@ -2907,6 +916,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "SUM_ROWS", "MEAN", "ARGMAX", + "COUNT_EQUAL", "REPEAT", "REPEAT_BACK", "CONCAT", @@ -2947,6 +957,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "POOL_2D_BACK", "UPSCALE", "PAD", + "PAD_REFLECT_1D", "ARANGE", "TIMESTEP_EMBEDDING", "ARGSORT", @@ -2960,7 +971,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "WIN_UNPART", "GET_REL_POS", "ADD_REL_POS", - "RWKV_WKV", + "RWKV_WKV6", + "GATED_LINEAR_ATTN", "UNARY", @@ -2977,9 +989,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS", "CROSS_ENTROPY_LOSS_BACK", + "OPT_STEP_ADAMW", }; -static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79"); +static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3000,6 +1013,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "Σx_k", "Σx/n", "argmax(x)", + "count_equal(x)", "repeat(x)", "repeat_back(x)", "concat(x, y)", @@ -3040,6 +1054,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "pool_2d_back(x)", "upscale(x)", "pad(x)", + "pad_reflect_1d(x)", "arange(start, stop, step)", "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", @@ -3053,7 +1068,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "win_unpart(x)", "get_rel_pos(x)", "add_rel_pos(x)", - "rwkv_wkv(k, v, r, tf, td, s)", + "rwkv_wkv6(k, v, r, tf, td, s)", + "gated_linear_attn(k, v, q, gate, s)", "unary(x)", @@ -3070,9 +1086,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss(x,y)", "cross_entropy_loss_back(x,y)", + "adamw(x)", }; -static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79"); +static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -3100,249 +1117,42 @@ static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); -// Helpers for polling loops -#if defined(__aarch64__) && ( defined(__clang__) || defined(__GNUC__) ) -static inline void ggml_thread_cpu_relax(void) { - __asm__ volatile("yield" ::: "memory"); -} -#elif defined(__x86_64__) -static inline void ggml_thread_cpu_relax(void) { - _mm_pause(); -} -#else -static inline void ggml_thread_cpu_relax(void) {;} -#endif - -// -// NUMA support -// - -#define GGML_NUMA_MAX_NODES 8 -#define GGML_NUMA_MAX_CPUS 512 - -struct ggml_numa_node { - uint32_t cpus[GGML_NUMA_MAX_CPUS]; // hardware threads on this node - uint32_t n_cpus; -}; - -struct ggml_numa_nodes { - enum ggml_numa_strategy numa_strategy; - struct ggml_numa_node nodes[GGML_NUMA_MAX_NODES]; - uint32_t n_nodes; - uint32_t total_cpus; // hardware threads on system - uint32_t current_node; // node on which main process is execting -#if defined(__gnu_linux__) - cpu_set_t cpuset; // cpuset from numactl -#else - uint32_t cpuset; // no NUMA support outside of Linux at this time. Use a portable datatype -#endif -}; - -// -// ggml state -// - -struct ggml_state { - struct ggml_context_container contexts[GGML_MAX_CONTEXTS]; - struct ggml_numa_nodes numa; -}; - -// global state -static struct ggml_state g_state; -static atomic_flag g_state_critical = ATOMIC_FLAG_INIT; - -// critical section via spin lock -inline static void ggml_critical_section_start(void) { - while (atomic_flag_test_and_set(&g_state_critical)) { - // spin - sched_yield(); - } -} - -#ifdef GGML_USE_OPENMP -static void ggml_barrier(struct ggml_threadpool * threadpool) { - if (threadpool->n_threads_cur == 1) { - return; - } - - #pragma omp barrier -} -#else -static void ggml_barrier(struct ggml_threadpool * threadpool) { - if (threadpool->n_threads_cur == 1) { - return; - } - - atomic_int * n_barrier = &threadpool->n_barrier; - atomic_int * n_barrier_passed = &threadpool->n_barrier_passed; - - int n_threads = threadpool->n_threads_cur; - int passed_old = atomic_load_explicit(n_barrier_passed, memory_order_relaxed); - - if (atomic_fetch_add(n_barrier, 1) == n_threads - 1) { - // last thread - atomic_store(n_barrier, 0); - atomic_fetch_add_explicit(n_barrier_passed, 1, memory_order_relaxed); - } else { - // wait for other threads - while (true) { - if (atomic_load_explicit(n_barrier_passed, memory_order_relaxed) != passed_old) { - return; - } - ggml_thread_cpu_relax(); - } - } -} -#endif - -// TODO: make this somehow automatically executed -// some sort of "sentry" mechanism -inline static void ggml_critical_section_end(void) { - atomic_flag_clear(&g_state_critical); -} - -#if defined(__gnu_linux__) -static cpu_set_t ggml_get_numa_affinity(void) { - cpu_set_t cpuset; - pthread_t thread; - thread = pthread_self(); - CPU_ZERO(&cpuset); - pthread_getaffinity_np(thread, sizeof(cpu_set_t), &cpuset); - return cpuset; -} -#else -static uint32_t ggml_get_numa_affinity(void) { - return 0; // no NUMA support -} -#endif - -void ggml_numa_init(enum ggml_numa_strategy numa_flag) { - if (g_state.numa.n_nodes > 0) { - fprintf(stderr, "ggml_numa_init: NUMA already initialized\n"); - - return; - } - -#if defined(__gnu_linux__) - struct stat st; - char path[256]; - int rv; - - // set numa scheme - g_state.numa.numa_strategy = numa_flag; - - GGML_PRINT_DEBUG("numa strategy %u\n",g_state.numa.numa_strategy); - - g_state.numa.cpuset = ggml_get_numa_affinity(); - - // enumerate nodes - while (g_state.numa.n_nodes < GGML_NUMA_MAX_NODES) { - rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u", g_state.numa.n_nodes); - GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); - if (stat(path, &st) != 0) { break; } - ++g_state.numa.n_nodes; - } - - // enumerate CPUs - while (g_state.numa.total_cpus < GGML_NUMA_MAX_CPUS) { - rv = snprintf(path, sizeof(path), "/sys/devices/system/cpu/cpu%u", g_state.numa.total_cpus); - GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); - if (stat(path, &st) != 0) { break; } - ++g_state.numa.total_cpus; - } - - GGML_PRINT_DEBUG("found %u numa nodes, %u CPUs\n", g_state.numa.n_nodes, g_state.numa.total_cpus); - - // figure out which node we're on - uint current_cpu; - int getcpu_ret = 0; -#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__) - getcpu_ret = getcpu(¤t_cpu, &g_state.numa.current_node); -#else - // old glibc doesn't have a wrapper for this call. Fall back on direct syscall -# if !defined(SYS_getcpu) && defined(SYS_get_cpu) -# define SYS_getcpu SYS_get_cpu // some older glibc versions use this name -# endif - getcpu_ret = syscall(SYS_getcpu, ¤t_cpu, &g_state.numa.current_node); -#endif - - if (g_state.numa.n_nodes < 1 || g_state.numa.total_cpus < 1 || getcpu_ret != 0) { - g_state.numa.n_nodes = 0; - return; - } - - GGML_PRINT_DEBUG("found our process on numa node %u, CPU %u\n", g_state.numa.current_node, current_cpu); - - for (uint32_t n = 0; n < g_state.numa.n_nodes; ++n) { - struct ggml_numa_node * node = &g_state.numa.nodes[n]; - GGML_PRINT_DEBUG("CPUs on node %u:", n); - node->n_cpus = 0; - for (uint32_t c = 0; c < g_state.numa.total_cpus; ++c) { - rv = snprintf(path, sizeof(path), "/sys/devices/system/node/node%u/cpu%u", n, c); - GGML_ASSERT(rv > 0 && (unsigned)rv < sizeof(path)); - if (stat(path, &st) == 0) { - node->cpus[node->n_cpus++] = c; - GGML_PRINT_DEBUG(" %u", c); - } - } - GGML_PRINT_DEBUG("\n"); - } - - if (ggml_is_numa()) { - FILE *fptr = fopen("/proc/sys/kernel/numa_balancing", "r"); - if (fptr != NULL) { - char buf[42]; - if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) { - GGML_PRINT("WARNING: /proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n"); - } - fclose(fptr); - } - } -#else - UNUSED(numa_flag); - // TODO -#endif -} - -bool ggml_is_numa(void) { - return g_state.numa.n_nodes > 1; -} //////////////////////////////////////////////////////////////////////////////// void ggml_print_object(const struct ggml_object * obj) { - GGML_PRINT(" - ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n", + GGML_LOG_INFO(" - ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n", obj->type, obj->offs, obj->size, (const void *) obj->next); } void ggml_print_objects(const struct ggml_context * ctx) { struct ggml_object * obj = ctx->objects_begin; - GGML_PRINT("%s: objects in context %p:\n", __func__, (const void *) ctx); + GGML_LOG_INFO("%s: objects in context %p:\n", __func__, (const void *) ctx); while (obj != NULL) { ggml_print_object(obj); obj = obj->next; } - GGML_PRINT("%s: --- end ---\n", __func__); + GGML_LOG_INFO("%s: --- end ---\n", __func__); } -GGML_CALL int64_t ggml_nelements(const struct ggml_tensor * tensor) { +int64_t ggml_nelements(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; } -GGML_CALL int64_t ggml_nrows(const struct ggml_tensor * tensor) { +int64_t ggml_nrows(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); return tensor->ne[1]*tensor->ne[2]*tensor->ne[3]; } -GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) { +size_t ggml_nbytes(const struct ggml_tensor * tensor) { size_t nbytes; - size_t blck_size = ggml_blck_size(tensor->type); + const size_t blck_size = ggml_blck_size(tensor->type); if (blck_size == 1) { nbytes = ggml_type_size(tensor->type); for (int i = 0; i < GGML_MAX_DIMS; ++i) { @@ -3363,15 +1173,15 @@ size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) { return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN); } -GGML_CALL int64_t ggml_blck_size(enum ggml_type type) { +int64_t ggml_blck_size(enum ggml_type type) { return type_traits[type].blck_size; } -GGML_CALL size_t ggml_type_size(enum ggml_type type) { +size_t ggml_type_size(enum ggml_type type) { return type_traits[type].type_size; } -GGML_CALL size_t ggml_row_size(enum ggml_type type, int64_t ne) { +size_t ggml_row_size(enum ggml_type type, int64_t ne) { assert(ne % ggml_blck_size(type) == 0); return ggml_type_size(type)*ne/ggml_blck_size(type); } @@ -3380,15 +1190,15 @@ double ggml_type_sizef(enum ggml_type type) { return ((double)(type_traits[type].type_size))/type_traits[type].blck_size; } -GGML_CALL const char * ggml_type_name(enum ggml_type type) { - return type_traits[type].type_name; +const char * ggml_type_name(enum ggml_type type) { + return type < GGML_TYPE_COUNT ? type_traits[type].type_name : "NONE"; } -GGML_CALL bool ggml_is_quantized(enum ggml_type type) { +bool ggml_is_quantized(enum ggml_type type) { return type_traits[type].is_quantized; } -GGML_CALL const char * ggml_op_name(enum ggml_op op) { +const char * ggml_op_name(enum ggml_op op) { return GGML_OP_NAME[op]; } @@ -3400,7 +1210,7 @@ const char * ggml_unary_op_name(enum ggml_unary_op op) { return GGML_UNARY_OP_NAME[op]; } -GGML_CALL const char * ggml_op_desc(const struct ggml_tensor * t) { +const char * ggml_op_desc(const struct ggml_tensor * t) { if (t->op == GGML_OP_UNARY) { enum ggml_unary_op uop = ggml_get_unary_op(t); return ggml_unary_op_name(uop); @@ -3408,7 +1218,7 @@ GGML_CALL const char * ggml_op_desc(const struct ggml_tensor * t) { return ggml_op_name(t->op); } -GGML_CALL size_t ggml_element_size(const struct ggml_tensor * tensor) { +size_t ggml_element_size(const struct ggml_tensor * tensor) { return ggml_type_size(tensor->type); } @@ -3443,22 +1253,6 @@ int ggml_n_dims(const struct ggml_tensor * tensor) { return 1; } -static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return (t0->ne[0] == t1->ne[0]) && - (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable - (t1->ne[3]%t0->ne[3] == 0); -} - -static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return (t0->ne[1] == t1->ne[1]) && - (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable - (t1->ne[3]%t0->ne[3] == 0); -} - enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { enum ggml_type wtype = GGML_TYPE_COUNT; @@ -3485,9 +1279,6 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break; case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break; - case GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = GGML_TYPE_Q4_0_4_4; break; - case GGML_FTYPE_MOSTLY_Q4_0_4_8: wtype = GGML_TYPE_Q4_0_4_8; break; - case GGML_FTYPE_MOSTLY_Q4_0_8_8: wtype = GGML_TYPE_Q4_0_8_8; break; case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; } @@ -3501,7 +1292,7 @@ size_t ggml_tensor_overhead(void) { return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE; } -GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor) { +bool ggml_is_transposed(const struct ggml_tensor * tensor) { return tensor->nb[0] > tensor->nb[1]; } @@ -3527,23 +1318,23 @@ static bool ggml_is_contiguous_n(const struct ggml_tensor * tensor, int n) { return true; } -GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) { +bool ggml_is_contiguous(const struct ggml_tensor * tensor) { return ggml_is_contiguous_0(tensor); } -GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) { +bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) { return ggml_is_contiguous_n(tensor, 0); } -GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) { +bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) { return ggml_is_contiguous_n(tensor, 1); } -GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) { +bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) { return ggml_is_contiguous_n(tensor, 2); } -GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) { +bool ggml_is_permuted(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3]; @@ -3558,7 +1349,7 @@ static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) { tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; } -GGML_CALL bool ggml_is_empty(const struct ggml_tensor * tensor) { +bool ggml_is_empty(const struct ggml_tensor * tensor) { for (int i = 0; i < GGML_MAX_DIMS; ++i) { if (tensor->ne[i] == 0) { // empty if any dimension has no elements @@ -3605,20 +1396,6 @@ static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const str return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1); } -static inline int ggml_up32(int n) { - return (n + 31) & ~31; -} - -//static inline int ggml_up64(int n) { -// return (n + 63) & ~63; -//} - -static inline int ggml_up(int n, int m) { - // assert m is a power of 2 - GGML_ASSERT((m & (m - 1)) == 0); - return (n + m - 1) & ~(m - 1); -} - // assert that pointer is aligned to GGML_MEM_ALIGN #define GGML_ASSERT_ALIGNED(ptr) \ GGML_ASSERT(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0) @@ -3626,78 +1403,28 @@ static inline int ggml_up(int n, int m) { //////////////////////////////////////////////////////////////////////////////// struct ggml_context * ggml_init(struct ggml_init_params params) { - // make this function thread safe - ggml_critical_section_start(); - static bool is_first_call = true; + ggml_critical_section_start(); + if (is_first_call) { // initialize time system (required on Windows) ggml_time_init(); - // initialize GELU, Quick GELU, SILU and EXP F32 tables - { - const uint64_t t_start = ggml_time_us(); UNUSED(t_start); - - for (int i = 0; i < (1 << 16); ++i) { - union { - uint16_t u16; - ggml_fp16_t fp16; - } u = {i}; - float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16); - ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f)); - ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f)); - } - - const uint64_t t_end = ggml_time_us(); UNUSED(t_end); - - GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); - } - - // initialize g_state - { - const uint64_t t_start = ggml_time_us(); UNUSED(t_start); - - g_state = (struct ggml_state) { - /*.contexts =*/ { { 0 } }, - /*.numa =*/ { - .n_nodes = 0, - .total_cpus = 0, - }, - }; - - for (int i = 0; i < GGML_MAX_CONTEXTS; ++i) { - g_state.contexts[i].used = false; - } - - const uint64_t t_end = ggml_time_us(); UNUSED(t_end); - - GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); + for (int i = 0; i < (1 << 16); ++i) { + union { + uint16_t u16; + ggml_fp16_t fp16; + } u = {i}; + ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16); } is_first_call = false; } - // find non-used context in g_state - struct ggml_context * ctx = NULL; + ggml_critical_section_end(); - for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { - if (!g_state.contexts[i].used) { - g_state.contexts[i].used = true; - ctx = &g_state.contexts[i].context; - - GGML_PRINT_DEBUG("%s: found unused context %d\n", __func__, i); - break; - } - } - - if (ctx == NULL) { - GGML_PRINT_DEBUG("%s: no unused context found\n", __func__); - - ggml_critical_section_end(); - - return NULL; - } + struct ggml_context * ctx = GGML_MALLOC(sizeof(struct ggml_context)); // allow to call ggml_init with 0 size if (params.mem_size == 0) { @@ -3708,79 +1435,49 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { *ctx = (struct ggml_context) { /*.mem_size =*/ mem_size, - /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size), + /*.mem_buffer =*/ params.mem_buffer ? params.mem_buffer : ggml_aligned_malloc(mem_size), /*.mem_buffer_owned =*/ params.mem_buffer ? false : true, /*.no_alloc =*/ params.no_alloc, - /*.no_alloc_save =*/ params.no_alloc, /*.n_objects =*/ 0, /*.objects_begin =*/ NULL, /*.objects_end =*/ NULL, - /*.scratch =*/ { 0, 0, NULL, }, - /*.scratch_save =*/ { 0, 0, NULL, }, }; GGML_ASSERT(ctx->mem_buffer != NULL); GGML_ASSERT_ALIGNED(ctx->mem_buffer); -#if defined(__ARM_FEATURE_SVE) - if (!ggml_sve_cnt_b) { - ggml_sve_cnt_b = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); - } -#endif - GGML_PRINT_DEBUG("%s: context initialized\n", __func__); - ggml_critical_section_end(); - return ctx; } +void ggml_reset(struct ggml_context * ctx) { + if (ctx == NULL) { + return; + } + + ctx->n_objects = 0; + ctx->objects_begin = NULL; + ctx->objects_end = NULL; +} + void ggml_free(struct ggml_context * ctx) { if (ctx == NULL) { return; } - // make this function thread safe - ggml_critical_section_start(); - - bool found = false; - - for (int i = 0; i < GGML_MAX_CONTEXTS; i++) { - if (&g_state.contexts[i].context == ctx) { - g_state.contexts[i].used = false; - - GGML_PRINT_DEBUG("%s: context %d has been freed. memory used = %zu\n", - __func__, i, ggml_used_mem(ctx)); - - if (ctx->mem_buffer_owned) { - GGML_ALIGNED_FREE(ctx->mem_buffer); - } - - found = true; - break; - } + if (ctx->mem_buffer_owned) { + ggml_aligned_free(ctx->mem_buffer, ctx->mem_size); } - if (!found) { - GGML_PRINT_DEBUG("%s: context not found\n", __func__); - } - - ggml_critical_section_end(); + GGML_FREE(ctx); } size_t ggml_used_mem(const struct ggml_context * ctx) { return ctx->objects_end == NULL ? 0 : ctx->objects_end->offs + ctx->objects_end->size; } -size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch) { - const size_t result = ctx->scratch.data ? ctx->scratch.offs : 0; - - ctx->scratch = scratch; - - return result; -} - bool ggml_get_no_alloc(struct ggml_context * ctx) { return ctx->no_alloc; } @@ -3808,27 +1505,6 @@ size_t ggml_get_max_tensor_size(const struct ggml_context * ctx) { return max_size; } -// IMPORTANT: -// when creating "opt" tensors, always save and load the scratch buffer -// this is an error prone process, but it is necessary to support inplace -// operators when using scratch buffers -// TODO: implement a better way -static void ggml_scratch_save(struct ggml_context * ctx) { - // this is needed to allow opt tensors to store their data - // TODO: again, need to find a better way - ctx->no_alloc_save = ctx->no_alloc; - ctx->no_alloc = false; - - ctx->scratch_save = ctx->scratch; - ctx->scratch.data = NULL; -} - -static void ggml_scratch_load(struct ggml_context * ctx) { - ctx->no_alloc = ctx->no_alloc_save; - - ctx->scratch = ctx->scratch_save; -} - //////////////////////////////////////////////////////////////////////////////// static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml_object_type type, size_t size) { @@ -3846,9 +1522,11 @@ static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end); if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) { - GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", - __func__, cur_end + size_needed, ctx->mem_size); - assert(false); + GGML_LOG_WARN("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", + __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size); +#ifndef NDEBUG + GGML_ABORT("not enough space in the context's memory pool"); +#endif return NULL; } @@ -3907,60 +1585,32 @@ static struct ggml_tensor * ggml_new_tensor_impl( size_t obj_alloc_size = 0; if (view_src == NULL && !ctx->no_alloc) { - if (ctx->scratch.data != NULL) { - // allocate tensor data in the scratch buffer - if (ctx->scratch.offs + data_size > ctx->scratch.size) { - GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n", - __func__, ctx->scratch.offs + data_size, ctx->scratch.size); - assert(false); - return NULL; - } - - data = (char * const) ctx->scratch.data + ctx->scratch.offs; - - ctx->scratch.offs += data_size; - } else { - // allocate tensor data in the context's memory pool - obj_alloc_size = data_size; - } + // allocate tensor data in the context's memory pool + obj_alloc_size = data_size; } struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size); GGML_ASSERT(obj_new); - // TODO: for recoverable errors, we would need to free the data allocated from the scratch buffer here - struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs); -#ifdef __clang__ - // temporary until ggml_tensor::backend is removed - #pragma clang diagnostic push - #pragma clang diagnostic ignored "-Wdeprecated-declarations" -#endif - *result = (struct ggml_tensor) { /*.type =*/ type, - /*.backend =*/ GGML_BACKEND_TYPE_CPU, /*.buffer =*/ NULL, /*.ne =*/ { 1, 1, 1, 1 }, /*.nb =*/ { 0, 0, 0, 0 }, /*.op =*/ GGML_OP_NONE, /*.op_params =*/ { 0 }, /*.flags =*/ 0, - /*.grad =*/ NULL, /*.src =*/ { NULL }, /*.view_src =*/ view_src, /*.view_offs =*/ view_offs, /*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data, /*.name =*/ { 0 }, /*.extra =*/ NULL, - ///*.padding =*/ { 0 }, + /*.padding =*/ { 0 }, }; -#ifdef __clang__ - #pragma clang diagnostic pop -#endif - // TODO: this should not be needed as long as we don't rely on aligned SIMD loads //GGML_ASSERT_ALIGNED(result->data); @@ -4024,183 +1674,16 @@ struct ggml_tensor * ggml_new_tensor_4d( return ggml_new_tensor(ctx, type, 4, ne); } -struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) { - ggml_scratch_save(ctx); +void * ggml_new_buffer(struct ggml_context * ctx, size_t nbytes) { + struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_WORK_BUFFER, nbytes); - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); - - ggml_scratch_load(ctx); - - ggml_set_i32(result, value); - - return result; -} - -struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value) { - ggml_scratch_save(ctx); - - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); - - ggml_scratch_load(ctx); - - ggml_set_f32(result, value); - - return result; + return (uint8_t *)ctx->mem_buffer + obj->offs; } struct ggml_tensor * ggml_dup_tensor(struct ggml_context * ctx, const struct ggml_tensor * src) { return ggml_new_tensor(ctx, src->type, GGML_MAX_DIMS, src->ne); } -static void ggml_set_op_params(struct ggml_tensor * tensor, const void * params, size_t params_size) { - GGML_ASSERT(tensor != NULL); // silence -Warray-bounds warnings - assert(params_size <= GGML_MAX_OP_PARAMS); - memcpy(tensor->op_params, params, params_size); -} - -static int32_t ggml_get_op_params_i32(const struct ggml_tensor * tensor, uint32_t i) { - assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t)); - return ((const int32_t *)(tensor->op_params))[i]; -} - -static float ggml_get_op_params_f32(const struct ggml_tensor * tensor, uint32_t i) { - assert(i < GGML_MAX_OP_PARAMS / sizeof(float)); - return ((const float *)(tensor->op_params))[i]; -} - -static void ggml_set_op_params_i32(struct ggml_tensor * tensor, uint32_t i, int32_t value) { - assert(i < GGML_MAX_OP_PARAMS / sizeof(int32_t)); - ((int32_t *)(tensor->op_params))[i] = value; -} - -static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, float value) { - assert(i < GGML_MAX_OP_PARAMS / sizeof(float)); - ((float *)(tensor->op_params))[i] = value; -} - -struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) { - memset(tensor->data, 0, ggml_nbytes(tensor)); - return tensor; -} - -struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { - const int n = ggml_nrows(tensor); - const int nc = tensor->ne[0]; - const size_t n1 = tensor->nb[1]; - - char * const data = tensor->data; - - switch (tensor->type) { - case GGML_TYPE_I8: - { - assert(tensor->nb[0] == sizeof(int8_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); - } - } break; - case GGML_TYPE_I16: - { - assert(tensor->nb[0] == sizeof(int16_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); - } - } break; - case GGML_TYPE_I32: - { - assert(tensor->nb[0] == sizeof(int32_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); - } - } break; - case GGML_TYPE_F16: - { - assert(tensor->nb[0] == sizeof(ggml_fp16_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value)); - } - } break; - case GGML_TYPE_BF16: - { - assert(tensor->nb[0] == sizeof(ggml_fp16_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value)); - } - } break; - case GGML_TYPE_F32: - { - assert(tensor->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { - ggml_vec_set_f32(nc, (float *)(data + i*n1), value); - } - } break; - default: - { - GGML_ABORT("fatal error"); - } - } - - return tensor; -} - -struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { - const int n = ggml_nrows(tensor); - const int nc = tensor->ne[0]; - const size_t n1 = tensor->nb[1]; - - char * const data = tensor->data; - - switch (tensor->type) { - case GGML_TYPE_I8: - { - assert(tensor->nb[0] == sizeof(int8_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_i8(nc, (int8_t *)(data + i*n1), value); - } - } break; - case GGML_TYPE_I16: - { - assert(tensor->nb[0] == sizeof(int16_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_i16(nc, (int16_t *)(data + i*n1), value); - } - } break; - case GGML_TYPE_I32: - { - assert(tensor->nb[0] == sizeof(int32_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_i32(nc, (int32_t *)(data + i*n1), value); - } - } break; - case GGML_TYPE_F16: - { - assert(tensor->nb[0] == sizeof(ggml_fp16_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value)); - } - } break; - case GGML_TYPE_BF16: - { - assert(tensor->nb[0] == sizeof(ggml_bf16_t)); - for (int i = 0; i < n; i++) { - ggml_vec_set_bf16(nc, (ggml_bf16_t *)(data + i*n1), GGML_FP32_TO_BF16(value)); - } - } break; - case GGML_TYPE_F32: - { - assert(tensor->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { - ggml_vec_set_f32(nc, (float *)(data + i*n1), value); - } - } break; - default: - { - GGML_ABORT("fatal error"); - } - } - - return tensor; -} - void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) { const int64_t ne2 = tensor->ne[2]; const int64_t ne1 = tensor->ne[1]; @@ -4225,280 +1708,6 @@ void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * } } -int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { - if (!ggml_is_contiguous(tensor)) { - int64_t id[4] = { 0, 0, 0, 0 }; - ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); - return ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]); - } - switch (tensor->type) { - case GGML_TYPE_I8: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); - return ((int8_t *)(tensor->data))[i]; - } - case GGML_TYPE_I16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); - return ((int16_t *)(tensor->data))[i]; - } - case GGML_TYPE_I32: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); - return ((int32_t *)(tensor->data))[i]; - } - case GGML_TYPE_F16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); - return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); - } - case GGML_TYPE_BF16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t)); - return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]); - } - case GGML_TYPE_F32: - { - GGML_ASSERT(tensor->nb[0] == sizeof(float)); - return ((float *)(tensor->data))[i]; - } - default: - { - GGML_ABORT("fatal error"); - } - } -} - -void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { - if (!ggml_is_contiguous(tensor)) { - int64_t id[4] = { 0, 0, 0, 0 }; - ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); - ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value); - return; - } - switch (tensor->type) { - case GGML_TYPE_I8: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); - ((int8_t *)(tensor->data))[i] = value; - } break; - case GGML_TYPE_I16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); - ((int16_t *)(tensor->data))[i] = value; - } break; - case GGML_TYPE_I32: - { - GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); - ((int32_t *)(tensor->data))[i] = value; - } break; - case GGML_TYPE_F16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); - ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); - } break; - case GGML_TYPE_BF16: - { - GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t)); - ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value); - } break; - case GGML_TYPE_F32: - { - GGML_ASSERT(tensor->nb[0] == sizeof(float)); - ((float *)(tensor->data))[i] = value; - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) { - void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; - switch (tensor->type) { - case GGML_TYPE_I8: - return ((int8_t *) data)[0]; - case GGML_TYPE_I16: - return ((int16_t *) data)[0]; - case GGML_TYPE_I32: - return ((int32_t *) data)[0]; - case GGML_TYPE_F16: - return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]); - case GGML_TYPE_BF16: - return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]); - case GGML_TYPE_F32: - return ((float *) data)[0]; - default: - GGML_ABORT("fatal error"); - } -} - -void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) { - void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; - switch (tensor->type) { - case GGML_TYPE_I8: - { - ((int8_t *)(data))[0] = value; - } break; - case GGML_TYPE_I16: - { - ((int16_t *)(data))[0] = value; - } break; - case GGML_TYPE_I32: - { - ((int32_t *)(data))[0] = value; - } break; - case GGML_TYPE_F16: - { - ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value); - } break; - case GGML_TYPE_BF16: - { - ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value); - } break; - case GGML_TYPE_F32: - { - ((float *)(data))[0] = value; - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { - if (!ggml_is_contiguous(tensor)) { - int64_t id[4] = { 0, 0, 0, 0 }; - ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); - return ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]); - } - switch (tensor->type) { - case GGML_TYPE_I8: - { - return ((int8_t *)(tensor->data))[i]; - } - case GGML_TYPE_I16: - { - return ((int16_t *)(tensor->data))[i]; - } - case GGML_TYPE_I32: - { - return ((int32_t *)(tensor->data))[i]; - } - case GGML_TYPE_F16: - { - return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); - } - case GGML_TYPE_BF16: - { - return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]); - } - case GGML_TYPE_F32: - { - return ((float *)(tensor->data))[i]; - } - default: - { - GGML_ABORT("fatal error"); - } - } -} - -void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { - if (!ggml_is_contiguous(tensor)) { - int64_t id[4] = { 0, 0, 0, 0 }; - ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); - ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value); - return; - } - switch (tensor->type) { - case GGML_TYPE_I8: - { - ((int8_t *)(tensor->data))[i] = value; - } break; - case GGML_TYPE_I16: - { - ((int16_t *)(tensor->data))[i] = value; - } break; - case GGML_TYPE_I32: - { - ((int32_t *)(tensor->data))[i] = value; - } break; - case GGML_TYPE_F16: - { - ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); - } break; - case GGML_TYPE_BF16: - { - ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value); - } break; - case GGML_TYPE_F32: - { - ((float *)(tensor->data))[i] = value; - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3) { - void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; - switch (tensor->type) { - case GGML_TYPE_I8: - return ((int8_t *) data)[0]; - case GGML_TYPE_I16: - return ((int16_t *) data)[0]; - case GGML_TYPE_I32: - return ((int32_t *) data)[0]; - case GGML_TYPE_F16: - return GGML_FP16_TO_FP32(((ggml_fp16_t *) data)[0]); - case GGML_TYPE_BF16: - return GGML_BF16_TO_FP32(((ggml_bf16_t *) data)[0]); - case GGML_TYPE_F32: - return ((float *) data)[0]; - default: - GGML_ABORT("fatal error"); - } -} - -void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) { - void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; - switch (tensor->type) { - case GGML_TYPE_I8: - { - ((int8_t *)(data))[0] = value; - } break; - case GGML_TYPE_I16: - { - ((int16_t *)(data))[0] = value; - } break; - case GGML_TYPE_I32: - { - ((int32_t *)(data))[0] = value; - } break; - case GGML_TYPE_F16: - { - ((ggml_fp16_t *)(data))[0] = GGML_FP32_TO_FP16(value); - } break; - case GGML_TYPE_BF16: - { - ((ggml_bf16_t *)(data))[0] = GGML_FP32_TO_BF16(value); - } break; - case GGML_TYPE_F32: - { - ((float *)(data))[0] = value; - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - void * ggml_get_data(const struct ggml_tensor * tensor) { return tensor->data; } @@ -4508,7 +1717,7 @@ float * ggml_get_data_f32(const struct ggml_tensor * tensor) { return (float *)(tensor->data); } -GGML_CALL enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) { +enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) { GGML_ASSERT(tensor->op == GGML_OP_UNARY); return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0); } @@ -4605,18 +1814,11 @@ struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * nam static struct ggml_tensor * ggml_dup_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - + struct ggml_tensor * a, + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_DUP; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_DUP; result->src[0] = a; return result; @@ -4624,13 +1826,13 @@ static struct ggml_tensor * ggml_dup_impl( struct ggml_tensor * ggml_dup( struct ggml_context * ctx, - struct ggml_tensor * a) { + struct ggml_tensor * a) { return ggml_dup_impl(ctx, a, false); } struct ggml_tensor * ggml_dup_inplace( struct ggml_context * ctx, - struct ggml_tensor * a) { + struct ggml_tensor * a) { return ggml_dup_impl(ctx, a, true); } @@ -4638,21 +1840,14 @@ struct ggml_tensor * ggml_dup_inplace( static struct ggml_tensor * ggml_add_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { GGML_ASSERT(ggml_can_repeat(b, a)); - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_ADD; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_ADD; result->src[0] = a; result->src[1] = b; @@ -4661,15 +1856,15 @@ static struct ggml_tensor * ggml_add_impl( struct ggml_tensor * ggml_add( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { + struct ggml_tensor * a, + struct ggml_tensor * b) { return ggml_add_impl(ctx, a, b, false); } struct ggml_tensor * ggml_add_inplace( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { + struct ggml_tensor * a, + struct ggml_tensor * b) { return ggml_add_impl(ctx, a, b, true); } @@ -4677,9 +1872,9 @@ struct ggml_tensor * ggml_add_inplace( static struct ggml_tensor * ggml_add_cast_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - enum ggml_type type) { + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_type type) { // TODO: support less-strict constraint // GGML_ASSERT(ggml_can_repeat(b, a)); GGML_ASSERT(ggml_can_repeat_rows(b, a)); @@ -4689,18 +1884,9 @@ static struct ggml_tensor * ggml_add_cast_impl( a->type == GGML_TYPE_F16 || a->type == GGML_TYPE_BF16); - bool is_node = false; - - if (a->grad || b->grad) { - // TODO: support backward pass for broadcasting - GGML_ASSERT(ggml_are_same_shape(a, b)); - is_node = true; - } - struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne); - result->op = GGML_OP_ADD; - result->grad = is_node ? ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne) : NULL; + result->op = GGML_OP_ADD; result->src[0] = a; result->src[1] = b; @@ -4709,9 +1895,9 @@ static struct ggml_tensor * ggml_add_cast_impl( struct ggml_tensor * ggml_add_cast( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - enum ggml_type type) { + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_type type) { return ggml_add_cast_impl(ctx, a, b, type); } @@ -4719,22 +1905,15 @@ struct ggml_tensor * ggml_add_cast( static struct ggml_tensor * ggml_add1_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { GGML_ASSERT(ggml_is_scalar(b)); GGML_ASSERT(ggml_is_padded_1d(a)); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_ADD1; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_ADD1; result->src[0] = a; result->src[1] = b; @@ -4743,15 +1922,15 @@ static struct ggml_tensor * ggml_add1_impl( struct ggml_tensor * ggml_add1( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { + struct ggml_tensor * a, + struct ggml_tensor * b) { return ggml_add1_impl(ctx, a, b, false); } struct ggml_tensor * ggml_add1_inplace( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { + struct ggml_tensor * a, + struct ggml_tensor * b) { return ggml_add1_impl(ctx, a, b, true); } @@ -4759,31 +1938,24 @@ struct ggml_tensor * ggml_add1_inplace( static struct ggml_tensor * ggml_acc_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset, - bool inplace) { + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset, + bool inplace) { GGML_ASSERT(ggml_nelements(b) <= ggml_nelements(a)); GGML_ASSERT(ggml_is_contiguous(a)); GGML_ASSERT(a->type == GGML_TYPE_F32); GGML_ASSERT(b->type == GGML_TYPE_F32); - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_ACC; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_ACC; result->src[0] = a; result->src[1] = b; @@ -4792,23 +1964,23 @@ static struct ggml_tensor * ggml_acc_impl( struct ggml_tensor * ggml_acc( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset) { + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); } struct ggml_tensor * ggml_acc_inplace( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - size_t nb1, - size_t nb2, - size_t nb3, - size_t offset) { + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset) { return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true); } @@ -4816,23 +1988,14 @@ struct ggml_tensor * ggml_acc_inplace( static struct ggml_tensor * ggml_sub_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { GGML_ASSERT(ggml_can_repeat(b, a)); - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - // TODO: support backward pass for broadcasting - GGML_ASSERT(ggml_are_same_shape(a, b)); - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_SUB; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SUB; result->src[0] = a; result->src[1] = b; @@ -4841,15 +2004,15 @@ static struct ggml_tensor * ggml_sub_impl( struct ggml_tensor * ggml_sub( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { + struct ggml_tensor * a, + struct ggml_tensor * b) { return ggml_sub_impl(ctx, a, b, false); } struct ggml_tensor * ggml_sub_inplace( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { + struct ggml_tensor * a, + struct ggml_tensor * b) { return ggml_sub_impl(ctx, a, b, true); } @@ -4857,27 +2020,14 @@ struct ggml_tensor * ggml_sub_inplace( static struct ggml_tensor * ggml_mul_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { GGML_ASSERT(ggml_can_repeat(b, a)); - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - // TODO: support backward pass for broadcasting - GGML_ASSERT(ggml_are_same_shape(a, b)); - is_node = true; - } - - if (inplace) { - GGML_ASSERT(!is_node); - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_MUL; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MUL; result->src[0] = a; result->src[1] = b; @@ -4902,25 +2052,14 @@ struct ggml_tensor * ggml_mul_inplace( static struct ggml_tensor * ggml_div_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - bool inplace) { + struct ggml_tensor * a, + struct ggml_tensor * b, + bool inplace) { GGML_ASSERT(ggml_can_repeat(b, a)); - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - - if (inplace) { - GGML_ASSERT(!is_node); - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_DIV; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_DIV; result->src[0] = a; result->src[1] = b; @@ -4945,18 +2084,11 @@ struct ggml_tensor * ggml_div_inplace( static struct ggml_tensor * ggml_sqr_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - + struct ggml_tensor * a, + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_SQR; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SQR; result->src[0] = a; return result; @@ -4978,18 +2110,11 @@ struct ggml_tensor * ggml_sqr_inplace( static struct ggml_tensor * ggml_sqrt_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - + struct ggml_tensor * a, + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_SQRT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SQRT; result->src[0] = a; return result; @@ -5012,17 +2137,10 @@ struct ggml_tensor * ggml_sqrt_inplace( static struct ggml_tensor * ggml_log_impl( struct ggml_context * ctx, struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_LOG; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_LOG; result->src[0] = a; return result; @@ -5045,17 +2163,10 @@ struct ggml_tensor * ggml_log_inplace( static struct ggml_tensor * ggml_sin_impl( struct ggml_context * ctx, struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_SIN; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SIN; result->src[0] = a; return result; @@ -5078,17 +2189,10 @@ struct ggml_tensor * ggml_sin_inplace( static struct ggml_tensor * ggml_cos_impl( struct ggml_context * ctx, struct ggml_tensor * a, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_COS; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_COS; result->src[0] = a; return result; @@ -5110,17 +2214,10 @@ struct ggml_tensor * ggml_cos_inplace( struct ggml_tensor * ggml_sum( struct ggml_context * ctx, - struct ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - + struct ggml_tensor * a) { struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1); - result->op = GGML_OP_SUM; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SUM; result->src[0] = a; return result; @@ -5130,13 +2227,7 @@ struct ggml_tensor * ggml_sum( struct ggml_tensor * ggml_sum_rows( struct ggml_context * ctx, - struct ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - + struct ggml_tensor * a) { int64_t ne[GGML_MAX_DIMS] = { 1 }; for (int i = 1; i < GGML_MAX_DIMS; ++i) { ne[i] = a->ne[i]; @@ -5144,8 +2235,7 @@ struct ggml_tensor * ggml_sum_rows( struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne); - result->op = GGML_OP_SUM_ROWS; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SUM_ROWS; result->src[0] = a; return result; @@ -5155,19 +2245,11 @@ struct ggml_tensor * ggml_sum_rows( struct ggml_tensor * ggml_mean( struct ggml_context * ctx, - struct ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - GGML_ABORT("fatal error"); // TODO: implement - is_node = true; - } - + struct ggml_tensor * a) { int64_t ne[4] = { 1, a->ne[1], a->ne[2], a->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - result->op = GGML_OP_MEAN; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MEAN; result->src[0] = a; return result; @@ -5177,42 +2259,46 @@ struct ggml_tensor * ggml_mean( struct ggml_tensor * ggml_argmax( struct ggml_context * ctx, - struct ggml_tensor * a) { + struct ggml_tensor * a) { GGML_ASSERT(ggml_is_matrix(a)); - bool is_node = false; - - if (a->grad) { - GGML_ABORT("fatal error"); - is_node = true; - } + GGML_ASSERT(a->ne[0] <= INT32_MAX); struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, a->ne[1]); - result->op = GGML_OP_ARGMAX; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_ARGMAX; result->src[0] = a; return result; } +// ggml_count_equal + +struct ggml_tensor * ggml_count_equal( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + GGML_ASSERT(ggml_are_same_shape(a, b)); + + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 1); + + result->op = GGML_OP_COUNT_EQUAL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + // ggml_repeat struct ggml_tensor * ggml_repeat( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { + struct ggml_tensor * a, + struct ggml_tensor * b) { GGML_ASSERT(ggml_can_repeat(a, b)); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne); - result->op = GGML_OP_REPEAT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_REPEAT; result->src[0] = a; return result; @@ -5222,24 +2308,13 @@ struct ggml_tensor * ggml_repeat( struct ggml_tensor * ggml_repeat_back( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { + struct ggml_tensor * a, + struct ggml_tensor * b) { GGML_ASSERT(ggml_can_repeat(b, a)); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - if (ggml_are_same_shape(a, b) && !is_node) { - return a; - } - struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, b->ne); - result->op = GGML_OP_REPEAT_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_REPEAT_BACK; result->src[0] = a; return result; @@ -5249,9 +2324,9 @@ struct ggml_tensor * ggml_repeat_back( struct ggml_tensor * ggml_concat( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int dim) { + struct ggml_tensor * a, + struct ggml_tensor * b, + int dim) { GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS); int64_t ne[GGML_MAX_DIMS]; @@ -5264,19 +2339,11 @@ struct ggml_tensor * ggml_concat( ne[d] = a->ne[d]; } - bool is_node = false; - - if (a->grad || b->grad) { - GGML_ABORT("fatal error"); // TODO: implement - is_node = true; - } - struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne); ggml_set_op_params_i32(result, 0, dim); - result->op = GGML_OP_CONCAT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_CONCAT; result->src[0] = a; result->src[1] = b; @@ -5385,20 +2452,14 @@ struct ggml_tensor * ggml_relu_inplace( struct ggml_tensor * ggml_leaky_relu( struct ggml_context * ctx, - struct ggml_tensor * a, float negative_slope, bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - GGML_ABORT("fatal error"); // TODO: not implemented - is_node = true; - } - + struct ggml_tensor * a, + float negative_slope, + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params(result, &negative_slope, sizeof(negative_slope)); - result->op = GGML_OP_LEAKY_RELU; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_LEAKY_RELU; result->src[0] = a; return result; @@ -5466,17 +2527,9 @@ struct ggml_tensor * ggml_silu_back( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { - bool is_node = false; - - if (a->grad || b->grad) { - // TODO: implement backward - is_node = true; - } - struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - result->op = GGML_OP_SILU_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SILU_BACK; result->src[0] = a; result->src[1] = b; @@ -5484,6 +2537,7 @@ struct ggml_tensor * ggml_silu_back( } // ggml hardswish + struct ggml_tensor * ggml_hardswish( struct ggml_context * ctx, struct ggml_tensor * a) { @@ -5491,6 +2545,7 @@ struct ggml_tensor * ggml_hardswish( } // ggml hardsigmoid + struct ggml_tensor * ggml_hardsigmoid( struct ggml_context * ctx, struct ggml_tensor * a) { @@ -5498,6 +2553,7 @@ struct ggml_tensor * ggml_hardsigmoid( } // ggml exp + struct ggml_tensor * ggml_exp( struct ggml_context * ctx, struct ggml_tensor * a) { @@ -5515,21 +2571,13 @@ struct ggml_tensor * ggml_exp_inplace( static struct ggml_tensor * ggml_norm_impl( struct ggml_context * ctx, struct ggml_tensor * a, - float eps, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - + float eps, + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params(result, &eps, sizeof(eps)); - result->op = GGML_OP_NORM; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_NORM; result->src[0] = a; return result; @@ -5538,14 +2586,14 @@ static struct ggml_tensor * ggml_norm_impl( struct ggml_tensor * ggml_norm( struct ggml_context * ctx, struct ggml_tensor * a, - float eps) { + float eps) { return ggml_norm_impl(ctx, a, eps, false); } struct ggml_tensor * ggml_norm_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - float eps) { + float eps) { return ggml_norm_impl(ctx, a, eps, true); } @@ -5554,20 +2602,13 @@ struct ggml_tensor * ggml_norm_inplace( static struct ggml_tensor * ggml_rms_norm_impl( struct ggml_context * ctx, struct ggml_tensor * a, - float eps, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - + float eps, + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params(result, &eps, sizeof(eps)); - result->op = GGML_OP_RMS_NORM; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_RMS_NORM; result->src[0] = a; return result; @@ -5576,14 +2617,14 @@ static struct ggml_tensor * ggml_rms_norm_impl( struct ggml_tensor * ggml_rms_norm( struct ggml_context * ctx, struct ggml_tensor * a, - float eps) { + float eps) { return ggml_rms_norm_impl(ctx, a, eps, false); } struct ggml_tensor * ggml_rms_norm_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - float eps) { + float eps) { return ggml_rms_norm_impl(ctx, a, eps, true); } @@ -5593,20 +2634,12 @@ struct ggml_tensor * ggml_rms_norm_back( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, - float eps) { - bool is_node = false; - - if (a->grad) { - // TODO: implement backward - is_node = true; - } - + float eps) { struct ggml_tensor * result = ggml_dup_tensor(ctx, a); ggml_set_op_params(result, &eps, sizeof(eps)); - result->op = GGML_OP_RMS_NORM_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_RMS_NORM_BACK; result->src[0] = a; result->src[1] = b; @@ -5616,48 +2649,48 @@ struct ggml_tensor * ggml_rms_norm_back( // ggml_group_norm static struct ggml_tensor * ggml_group_norm_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_groups, - float eps, - bool inplace) { - - bool is_node = false; - if (!inplace && (a->grad)) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups, + float eps, + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params_i32(result, 0, n_groups); ggml_set_op_params_f32(result, 1, eps); - result->op = GGML_OP_GROUP_NORM; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_GROUP_NORM; result->src[0] = a; return result; } struct ggml_tensor * ggml_group_norm( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_groups, - float eps) { + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups, + float eps) { return ggml_group_norm_impl(ctx, a, n_groups, eps, false); } struct ggml_tensor * ggml_group_norm_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - int n_groups, - float eps) { + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups, + float eps) { return ggml_group_norm_impl(ctx, a, n_groups, eps, true); } // ggml_mul_mat +static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return (t0->ne[0] == t1->ne[0]) && + (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable + (t1->ne[3]%t0->ne[3] == 0); +} + struct ggml_tensor * ggml_mul_mat( struct ggml_context * ctx, struct ggml_tensor * a, @@ -5665,17 +2698,10 @@ struct ggml_tensor * ggml_mul_mat( GGML_ASSERT(ggml_can_mul_mat(a, b)); GGML_ASSERT(!ggml_is_transposed(a)); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - result->op = GGML_OP_MUL_MAT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MUL_MAT; result->src[0] = a; result->src[1] = b; @@ -5721,17 +2747,10 @@ struct ggml_tensor * ggml_mul_mat_id( GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast - bool is_node = false; - - if (as->grad || b->grad) { - is_node = true; - } - const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - result->op = GGML_OP_MUL_MAT_ID; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MUL_MAT_ID; result->src[0] = as; result->src[1] = b; result->src[2] = ids; @@ -5741,6 +2760,14 @@ struct ggml_tensor * ggml_mul_mat_id( // ggml_out_prod +static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return (t0->ne[1] == t1->ne[1]) && + (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable + (t1->ne[3]%t0->ne[3] == 0); +} + struct ggml_tensor * ggml_out_prod( struct ggml_context * ctx, struct ggml_tensor * a, @@ -5748,18 +2775,11 @@ struct ggml_tensor * ggml_out_prod( GGML_ASSERT(ggml_can_out_prod(a, b)); GGML_ASSERT(!ggml_is_transposed(a)); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3] const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - result->op = GGML_OP_OUT_PROD; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_OUT_PROD; result->src[0] = a; result->src[1] = b; @@ -5772,21 +2792,14 @@ static struct ggml_tensor * ggml_scale_impl( struct ggml_context * ctx, struct ggml_tensor * a, float s, - bool inplace) { + bool inplace) { GGML_ASSERT(ggml_is_padded_1d(a)); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params(result, &s, sizeof(s)); - result->op = GGML_OP_SCALE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SCALE; result->src[0] = a; return result; @@ -5794,15 +2807,15 @@ static struct ggml_tensor * ggml_scale_impl( struct ggml_tensor * ggml_scale( struct ggml_context * ctx, - struct ggml_tensor * a, - float s) { + struct ggml_tensor * a, + float s) { return ggml_scale_impl(ctx, a, s, false); } struct ggml_tensor * ggml_scale_inplace( struct ggml_context * ctx, - struct ggml_tensor * a, - float s) { + struct ggml_tensor * a, + float s) { return ggml_scale_impl(ctx, a, s, true); } @@ -5816,15 +2829,9 @@ static struct ggml_tensor * ggml_set_impl( size_t nb2, size_t nb3, size_t offset, - bool inplace) { + bool inplace) { GGML_ASSERT(ggml_nelements(a) >= ggml_nelements(b)); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - // make a view of the destination struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); @@ -5832,8 +2839,7 @@ static struct ggml_tensor * ggml_set_impl( int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_SET; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SET; result->src[0] = a; result->src[1] = b; @@ -5842,8 +2848,8 @@ static struct ggml_tensor * ggml_set_impl( struct ggml_tensor * ggml_set( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, + struct ggml_tensor * a, + struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, @@ -5853,8 +2859,8 @@ struct ggml_tensor * ggml_set( struct ggml_tensor * ggml_set_inplace( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, + struct ggml_tensor * a, + struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, @@ -5864,24 +2870,24 @@ struct ggml_tensor * ggml_set_inplace( struct ggml_tensor * ggml_set_1d( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, + struct ggml_tensor * a, + struct ggml_tensor * b, size_t offset) { return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, false); } struct ggml_tensor * ggml_set_1d_inplace( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, + struct ggml_tensor * a, + struct ggml_tensor * b, size_t offset) { return ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, true); } struct ggml_tensor * ggml_set_2d( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, + struct ggml_tensor * a, + struct ggml_tensor * b, size_t nb1, size_t offset) { return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false); @@ -5889,8 +2895,8 @@ struct ggml_tensor * ggml_set_2d( struct ggml_tensor * ggml_set_2d_inplace( struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, + struct ggml_tensor * a, + struct ggml_tensor * b, size_t nb1, size_t offset) { return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true); @@ -5904,13 +2910,6 @@ static struct ggml_tensor * ggml_cpy_impl( struct ggml_tensor * b) { GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); - bool is_node = false; - - if (a->grad || b->grad) { - // inplace is false and either one have a grad - is_node = true; - } - // make a view of the destination struct ggml_tensor * result = ggml_view_tensor(ctx, b); if (strlen(b->name) > 0) { @@ -5919,8 +2918,7 @@ static struct ggml_tensor * ggml_cpy_impl( ggml_format_name(result, "%s (copy)", a->name); } - result->op = GGML_OP_CPY; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_CPY; result->src[0] = a; result->src[1] = b; @@ -5938,13 +2936,10 @@ struct ggml_tensor * ggml_cast( struct ggml_context * ctx, struct ggml_tensor * a, enum ggml_type type) { - bool is_node = false; - struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne); ggml_format_name(result, "%s (copy)", a->name); - result->op = GGML_OP_CPY; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_CPY; result->src[0] = a; result->src[1] = result; @@ -5956,17 +2951,10 @@ struct ggml_tensor * ggml_cast( static struct ggml_tensor * ggml_cont_impl( struct ggml_context * ctx, struct ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result = ggml_dup_tensor(ctx, a); ggml_format_name(result, "%s (cont)", a->name); - result->op = GGML_OP_CONT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_CONT; result->src[0] = a; return result; @@ -6012,13 +3000,10 @@ struct ggml_tensor * ggml_cont_4d( int64_t ne3) { GGML_ASSERT(ggml_nelements(a) == (ne0*ne1*ne2*ne3)); - bool is_node = false; - struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); ggml_format_name(result, "%s (cont)", a->name); - result->op = GGML_OP_CONT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_CONT; result->src[0] = a; return result; @@ -6034,22 +3019,10 @@ struct ggml_tensor * ggml_reshape( // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous. GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - - if (b->grad) { - // gradient propagation is not supported - //GGML_ABORT("fatal error"); - } - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b->ne, a, 0); ggml_format_name(result, "%s (reshaped)", a->name); - result->op = GGML_OP_RESHAPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_RESHAPE; result->src[0] = a; return result; @@ -6062,18 +3035,11 @@ struct ggml_tensor * ggml_reshape_1d( GGML_ASSERT(ggml_is_contiguous(a)); GGML_ASSERT(ggml_nelements(a) == ne0); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - const int64_t ne[1] = { ne0 }; struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 1, ne, a, 0); ggml_format_name(result, "%s (reshaped)", a->name); - result->op = GGML_OP_RESHAPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_RESHAPE; result->src[0] = a; return result; @@ -6087,18 +3053,11 @@ struct ggml_tensor * ggml_reshape_2d( GGML_ASSERT(ggml_is_contiguous(a)); GGML_ASSERT(ggml_nelements(a) == ne0*ne1); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - const int64_t ne[2] = { ne0, ne1 }; struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 2, ne, a, 0); ggml_format_name(result, "%s (reshaped)", a->name); - result->op = GGML_OP_RESHAPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_RESHAPE; result->src[0] = a; return result; @@ -6113,18 +3072,11 @@ struct ggml_tensor * ggml_reshape_3d( GGML_ASSERT(ggml_is_contiguous(a)); GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - const int64_t ne[3] = { ne0, ne1, ne2 }; struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 3, ne, a, 0); ggml_format_name(result, "%s (reshaped)", a->name); - result->op = GGML_OP_RESHAPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_RESHAPE; result->src[0] = a; return result; @@ -6140,18 +3092,11 @@ struct ggml_tensor * ggml_reshape_4d( GGML_ASSERT(ggml_is_contiguous(a)); GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0); ggml_format_name(result, "%s (reshaped)", a->name); - result->op = GGML_OP_RESHAPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_RESHAPE; result->src[0] = a; return result; @@ -6163,20 +3108,12 @@ static struct ggml_tensor * ggml_view_impl( int n_dims, const int64_t * ne, size_t offset) { - - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, n_dims, ne, a, offset); ggml_format_name(result, "%s (view)", a->name); ggml_set_op_params(result, &offset, sizeof(offset)); - result->op = GGML_OP_VIEW; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_VIEW; result->src[0] = a; return result; @@ -6189,7 +3126,6 @@ struct ggml_tensor * ggml_view_1d( struct ggml_tensor * a, int64_t ne0, size_t offset) { - struct ggml_tensor * result = ggml_view_impl(ctx, a, 1, &ne0, offset); return result; @@ -6204,7 +3140,6 @@ struct ggml_tensor * ggml_view_2d( int64_t ne1, size_t nb1, size_t offset) { - const int64_t ne[2] = { ne0, ne1 }; struct ggml_tensor * result = ggml_view_impl(ctx, a, 2, ne, offset); @@ -6227,7 +3162,6 @@ struct ggml_tensor * ggml_view_3d( size_t nb1, size_t nb2, size_t offset) { - const int64_t ne[3] = { ne0, ne1, ne2 }; struct ggml_tensor * result = ggml_view_impl(ctx, a, 3, ne, offset); @@ -6252,7 +3186,6 @@ struct ggml_tensor * ggml_view_4d( size_t nb2, size_t nb3, size_t offset) { - const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; struct ggml_tensor * result = ggml_view_impl(ctx, a, 4, ne, offset); @@ -6285,12 +3218,6 @@ struct ggml_tensor * ggml_permute( GGML_ASSERT(axis1 != axis3); GGML_ASSERT(axis2 != axis3); - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result = ggml_view_tensor(ctx, a); ggml_format_name(result, "%s (permuted)", a->name); @@ -6317,8 +3244,7 @@ struct ggml_tensor * ggml_permute( result->nb[2] = nb[2]; result->nb[3] = nb[3]; - result->op = GGML_OP_PERMUTE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_PERMUTE; result->src[0] = a; int32_t params[] = { axis0, axis1, axis2, axis3 }; @@ -6332,12 +3258,6 @@ struct ggml_tensor * ggml_permute( struct ggml_tensor * ggml_transpose( struct ggml_context * ctx, struct ggml_tensor * a) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result = ggml_view_tensor(ctx, a); ggml_format_name(result, "%s (transposed)", a->name); @@ -6347,8 +3267,7 @@ struct ggml_tensor * ggml_transpose( result->nb[0] = a->nb[1]; result->nb[1] = a->nb[0]; - result->op = GGML_OP_TRANSPOSE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_TRANSPOSE; result->src[0] = a; return result; @@ -6364,12 +3283,6 @@ struct ggml_tensor * ggml_get_rows( GGML_ASSERT(b->ne[3] == 1); GGML_ASSERT(b->type == GGML_TYPE_I32); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - // TODO: implement non F32 return enum ggml_type type = GGML_TYPE_F32; if (a->type == GGML_TYPE_I32) { @@ -6377,8 +3290,7 @@ struct ggml_tensor * ggml_get_rows( } struct ggml_tensor * result = ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]); - result->op = GGML_OP_GET_ROWS; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_GET_ROWS; result->src[0] = a; result->src[1] = b; @@ -6395,18 +3307,11 @@ struct ggml_tensor * ggml_get_rows_back( GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32); GGML_ASSERT(ggml_is_matrix(c) && (a->ne[0] == c->ne[0])); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } - // TODO: implement non F32 return //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]); struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, c->ne[0], c->ne[1]); - result->op = GGML_OP_GET_ROWS_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_GET_ROWS_BACK; result->src[0] = a; result->src[1] = b; @@ -6419,17 +3324,11 @@ struct ggml_tensor * ggml_diag( struct ggml_context * ctx, struct ggml_tensor * a) { GGML_ASSERT(a->ne[1] == 1); - bool is_node = false; - - if (a->grad) { - is_node = true; - } const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 4, ne); - result->op = GGML_OP_DIAG; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_DIAG; result->src[0] = a; return result; @@ -6442,19 +3341,12 @@ static struct ggml_tensor * ggml_diag_mask_inf_impl( struct ggml_tensor * a, int n_past, bool inplace) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); int32_t params[] = { n_past }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_DIAG_MASK_INF; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_DIAG_MASK_INF; result->src[0] = a; return result; @@ -6481,19 +3373,12 @@ static struct ggml_tensor * ggml_diag_mask_zero_impl( struct ggml_tensor * a, int n_past, bool inplace) { - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); int32_t params[] = { n_past }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_DIAG_MASK_ZERO; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_DIAG_MASK_ZERO; result->src[0] = a; return result; @@ -6536,19 +3421,12 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(mask); } - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); float params[] = { scale, max_bias }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_SOFT_MAX; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SOFT_MAX; result->src[0] = a; result->src[1] = mask; @@ -6576,41 +3454,43 @@ struct ggml_tensor * ggml_soft_max_ext( return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false); } -// ggml_soft_max_back +// ggml_soft_max_ext_back -static struct ggml_tensor * ggml_soft_max_back_impl( +static struct ggml_tensor * ggml_soft_max_ext_back_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + float scale, + float max_bias, bool inplace) { - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; // TODO : implement backward pass - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - result->op = GGML_OP_SOFT_MAX_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SOFT_MAX_BACK; result->src[0] = a; result->src[1] = b; + memcpy((float *) result->op_params + 0, &scale, sizeof(float)); + memcpy((float *) result->op_params + 1, &max_bias, sizeof(float)); + return result; } -struct ggml_tensor * ggml_soft_max_back( +struct ggml_tensor * ggml_soft_max_ext_back( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_soft_max_back_impl(ctx, a, b, false); + struct ggml_tensor * b, + float scale, + float max_bias) { + return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false); } -struct ggml_tensor * ggml_soft_max_back_inplace( +struct ggml_tensor * ggml_soft_max_ext_back_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b) { - return ggml_soft_max_back_impl(ctx, a, b, true); + struct ggml_tensor * b, + float scale, + float max_bias) { + return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true); } // ggml_rope @@ -6641,25 +3521,21 @@ static struct ggml_tensor * ggml_rope_impl( GGML_ASSERT(c->ne[0] >= n_dims / 2); } - bool is_node = false; - - if (a->grad) { - is_node = true; - } + int sections[4] = {0, 0, 0, 0}; struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; + int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; memcpy(params + 5, &freq_base, sizeof(float)); memcpy(params + 6, &freq_scale, sizeof(float)); memcpy(params + 7, &ext_factor, sizeof(float)); memcpy(params + 8, &attn_factor, sizeof(float)); memcpy(params + 9, &beta_fast, sizeof(float)); memcpy(params + 10, &beta_slow, sizeof(float)); + memcpy(params + 11, §ions, sizeof(int)*4); ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_ROPE; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_ROPE; result->src[0] = a; result->src[1] = b; result->src[2] = c; @@ -6678,6 +3554,53 @@ struct ggml_tensor * ggml_rope( ); } +struct ggml_tensor * ggml_rope_multi( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int sections[4], + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + // Multimodal Rotary Position Embedding + GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported"); + + GGML_ASSERT(ggml_is_vector(b)); + GGML_ASSERT(b->type == GGML_TYPE_I32); + GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token + + if (c) { + GGML_ASSERT(c->type == GGML_TYPE_F32); + GGML_ASSERT(c->ne[0] >= n_dims / 2); + } + + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; + memcpy(params + 5, &freq_base, sizeof(float)); + memcpy(params + 6, &freq_scale, sizeof(float)); + memcpy(params + 7, &ext_factor, sizeof(float)); + memcpy(params + 8, &attn_factor, sizeof(float)); + memcpy(params + 9, &beta_fast, sizeof(float)); + memcpy(params + 10, &beta_slow, sizeof(float)); + memcpy(¶ms[11], sections, sizeof(int)*4); + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_ROPE; + result->src[0] = a; + result->src[1] = b; + result->src[2] = c; + + return result; +} + struct ggml_tensor * ggml_rope_inplace( struct ggml_context * ctx, struct ggml_tensor * a, @@ -6767,9 +3690,25 @@ struct ggml_tensor * ggml_rope_custom_inplace( ); } +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); +} + +void ggml_rope_yarn_corr_dims( + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2] +) { + // start and end correction dims + float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base)); + float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base)); + dims[0] = MAX(0, start); + dims[1] = MIN(n_dims - 1, end); +} + // ggml_rope_back -struct ggml_tensor * ggml_rope_back( +struct ggml_tensor * ggml_rope_ext_back( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -6783,37 +3722,32 @@ struct ggml_tensor * ggml_rope_back( float attn_factor, float beta_fast, float beta_slow) { - GGML_ASSERT(ggml_is_vector(b)); - GGML_ASSERT(b->type == GGML_TYPE_I32); - GGML_ASSERT(a->ne[2] == b->ne[0]); - - bool is_node = false; - - if (a->grad) { - GGML_ASSERT(false && "backwards pass not implemented"); - is_node = false; - } - - struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - - int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; - memcpy(params + 5, &freq_base, sizeof(float)); - memcpy(params + 6, &freq_scale, sizeof(float)); - memcpy(params + 7, &ext_factor, sizeof(float)); - memcpy(params + 8, &attn_factor, sizeof(float)); - memcpy(params + 9, &beta_fast, sizeof(float)); - memcpy(params + 10, &beta_slow, sizeof(float)); - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_ROPE_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - result->src[2] = c; - + struct ggml_tensor * result = ggml_rope_ext( + ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + result->op = GGML_OP_ROPE_BACK; return result; } +struct ggml_tensor * ggml_rope_multi_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int sections[4], + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + struct ggml_tensor * result = ggml_rope_multi( + ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + result->op = GGML_OP_ROPE_BACK; + return result; +} // ggml_clamp struct ggml_tensor * ggml_clamp( @@ -6821,33 +3755,96 @@ struct ggml_tensor * ggml_clamp( struct ggml_tensor * a, float min, float max) { - bool is_node = false; - - if (a->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - // TODO: when implement backward, fix this: struct ggml_tensor * result = ggml_view_tensor(ctx, a); float params[] = { min, max }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_CLAMP; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_CLAMP; result->src[0] = a; return result; } +static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) { + return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; +} + +// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] +// a: [OC,IC, KH, KW] +// b: [N, IC, IH, IW] +// result: [N, OH, OW, IC*KH*KW] +struct ggml_tensor * ggml_im2col( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + bool is_2D, + enum ggml_type dst_type) { + if (is_2D) { + GGML_ASSERT(a->ne[2] == b->ne[2]); + } else { + //GGML_ASSERT(b->ne[1] % a->ne[1] == 0); + GGML_ASSERT(b->ne[1] == a->ne[1]); + GGML_ASSERT(b->ne[3] == 1); + } + + const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0; + const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + + GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a"); + GGML_ASSERT((OW > 0) && "b too small compared to a"); + + const int64_t ne[4] = { + is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0], + OW, + is_2D ? OH : b->ne[2], + is_2D ? b->ne[3] : 1, + }; + + struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne); + int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_IM2COL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_im2col_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int64_t * ne, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + bool is_2D) { + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_IM2COL_BACK; + result->src[0] = a; + result->src[1] = b; + + return result; +} + // ggml_conv_1d -static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) { - return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; -} - -GGML_API struct ggml_tensor * ggml_conv_1d( +struct ggml_tensor * ggml_conv_1d( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, @@ -6877,6 +3874,38 @@ struct ggml_tensor* ggml_conv_1d_ph( return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d); } +// ggml_conv_1d_dw + +struct ggml_tensor * ggml_conv_1d_dw( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int p0, + int d0) { + struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], 1, a->ne[1], a->ne[2]); + struct ggml_tensor * new_b = ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]); + + struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); + + struct ggml_tensor * result = ggml_mul_mat(ctx, im2col, a); + + result = ggml_reshape_3d(ctx, result, b->ne[0], b->ne[1], 1); + + return result; +} + +// ggml_conv_1d_dw_ph + +struct ggml_tensor * ggml_conv_1d_dw_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int d0) { + return ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0); +} + // ggml_conv_transpose_1d static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) { @@ -6897,13 +3926,6 @@ GGML_API struct ggml_tensor * ggml_conv_transpose_1d( GGML_ASSERT(p0 == 0); GGML_ASSERT(d0 == 1); - bool is_node = false; - - if (a->grad || b->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - const int64_t ne[4] = { ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/), a->ne[1], b->ne[2], 1, @@ -6913,125 +3935,15 @@ GGML_API struct ggml_tensor * ggml_conv_transpose_1d( int32_t params[] = { s0, p0, d0 }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_CONV_TRANSPOSE_1D; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_CONV_TRANSPOSE_1D; result->src[0] = a; result->src[1] = b; return result; } -// ggml_conv_depthwise -struct ggml_tensor * ggml_conv_depthwise_2d( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1) { - - struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]); - struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, - ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]), - s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW] - struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW] - - new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW] - struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b); - result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW] - - return result; -} // ggml_conv_2d -// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] -// a: [OC,IC, KH, KW] -// b: [N, IC, IH, IW] -// result: [N, OH, OW, IC*KH*KW] -struct ggml_tensor * ggml_im2col( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1, - bool is_2D, - enum ggml_type dst_type) { - - if(is_2D) { - GGML_ASSERT(a->ne[2] == b->ne[2]); - } else { - GGML_ASSERT(a->ne[1] == b->ne[1]); - GGML_ASSERT(b->ne[3] == 1); - } - bool is_node = false; - - if (/*a->grad ||*/ b->grad) { // a is only used for its shape, not its data - is_node = true; - } - - const int64_t OH = is_2D ? ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0; - const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); - - GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a"); - GGML_ASSERT((OW > 0) && "b too small compared to a"); - - const int64_t ne[4] = { - is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0], - OW, - is_2D ? OH : b->ne[2], - is_2D ? b->ne[3] : 1, - }; - - struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne); - int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_IM2COL; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - -struct ggml_tensor * ggml_im2col_back( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int64_t * ne, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1, - bool is_2D) { - - bool is_node = false; - - if (/*a->grad ||*/ b->grad) { // a is only used for its shape, not its data - is_node = true; - } - - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_IM2COL_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b; - - return result; -} - // a: [OC,IC, KH, KW] // b: [N, IC, IH, IW] // result: [N, OC, OH, OW] @@ -7039,12 +3951,12 @@ struct ggml_tensor * ggml_conv_2d( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1) { + int s0, + int s1, + int p0, + int p1, + int d0, + int d1) { struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW] struct ggml_tensor * result = @@ -7060,6 +3972,7 @@ struct ggml_tensor * ggml_conv_2d( } // ggml_conv_2d_sk_p0 + struct ggml_tensor * ggml_conv_2d_sk_p0( struct ggml_context * ctx, struct ggml_tensor * a, @@ -7076,6 +3989,31 @@ struct ggml_tensor * ggml_conv_2d_s1_ph( return ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1); } +// ggml_conv_2d_dw + +struct ggml_tensor * ggml_conv_2d_dw( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1) { + struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]); + struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, + ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]), + s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW] + struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW] + + new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW] + struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b); + result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW] + + return result; +} + // ggml_conv_transpose_2d_p0 static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) { @@ -7089,13 +4027,6 @@ struct ggml_tensor * ggml_conv_transpose_2d_p0( int stride) { GGML_ASSERT(a->ne[3] == b->ne[2]); - bool is_node = false; - - if (a->grad || b->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - const int64_t ne[4] = { ggml_calc_conv_transpose_output_size(b->ne[0], a->ne[0], stride, 0 /*p0*/), ggml_calc_conv_transpose_output_size(b->ne[1], a->ne[1], stride, 0 /*p1*/), @@ -7106,8 +4037,7 @@ struct ggml_tensor * ggml_conv_transpose_2d_p0( ggml_set_op_params_i32(result, 0, stride); - result->op = GGML_OP_CONV_TRANSPOSE_2D; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_CONV_TRANSPOSE_2D; result->src[0] = a; result->src[1] = b; @@ -7129,14 +4059,6 @@ struct ggml_tensor * ggml_pool_1d( int k0, int s0, int p0) { - - bool is_node = false; - - if (a->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - const int64_t ne[4] = { ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), a->ne[1], @@ -7148,8 +4070,7 @@ struct ggml_tensor * ggml_pool_1d( int32_t params[] = { op, k0, s0, p0 }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_POOL_1D; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_POOL_1D; result->src[0] = a; return result; @@ -7167,13 +4088,6 @@ struct ggml_tensor * ggml_pool_2d( int s1, float p0, float p1) { - - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result; const int64_t ne[4] = { ggml_calc_pool_output_size(a->ne[0], k0, s0, p0), @@ -7186,9 +4100,9 @@ struct ggml_tensor * ggml_pool_2d( int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_POOL_2D; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_POOL_2D; result->src[0] = a; + return result; } @@ -7203,100 +4117,105 @@ struct ggml_tensor * ggml_pool_2d_back( int s1, float p0, float p1) { - - bool is_node = false; - - if (a->grad) { - is_node = true; - } - struct ggml_tensor * result; result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, af->ne); int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_POOL_2D_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_POOL_2D_BACK; result->src[0] = a; result->src[1] = af; + return result; } // ggml_upscale static struct ggml_tensor * ggml_upscale_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - int ne0, - int ne1, - int ne2, - int ne3) { - bool is_node = false; - - if (a->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + int ne2, + int ne3) { GGML_ASSERT(a->ne[0] <= ne0); GGML_ASSERT(a->ne[1] <= ne1); GGML_ASSERT(a->ne[2] <= ne2); GGML_ASSERT(a->ne[3] <= ne3); - struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, - ne0, - ne1, - ne2, - ne3 - ); + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); - result->op = GGML_OP_UPSCALE; - - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_UPSCALE; result->src[0] = a; return result; } struct ggml_tensor * ggml_upscale( - struct ggml_context * ctx, - struct ggml_tensor * a, - int scale_factor) { + struct ggml_context * ctx, + struct ggml_tensor * a, + int scale_factor) { return ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]); } struct ggml_tensor * ggml_upscale_ext( - struct ggml_context * ctx, - struct ggml_tensor * a, - int ne0, - int ne1, - int ne2, - int ne3) { + struct ggml_context * ctx, + struct ggml_tensor * a, + int ne0, + int ne1, + int ne2, + int ne3) { return ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3); } // ggml_pad struct ggml_tensor * ggml_pad( - struct ggml_context * ctx, - struct ggml_tensor * a, - int p0, int p1, int p2, int p3) { - bool is_node = false; - - if (a->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - + struct ggml_context * ctx, + struct ggml_tensor * a, + int p0, + int p1, + int p2, + int p3) { struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0] + p0, a->ne[1] + p1, a->ne[2] + p2, a->ne[3] + p3); - result->op = GGML_OP_PAD; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_PAD; + result->src[0] = a; + + return result; +} + +// ggml_pad_reflect_1d + +struct ggml_tensor * ggml_pad_reflect_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int p0, + int p1) { + GGML_ASSERT(p0 >= 0); + GGML_ASSERT(p1 >= 0); + + GGML_ASSERT(p0 < a->ne[0]); // padding length on each size must be less than the + GGML_ASSERT(p1 < a->ne[0]); // existing length of the dimension being padded + + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(a->type == GGML_TYPE_F32); + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, + a->ne[0] + p0 + p1, + a->ne[1], + a->ne[2], + a->ne[3]); + + int32_t params[] = { p0, p1 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_PAD_REFLECT_1D; result->src[0] = a; return result; @@ -7305,39 +4224,32 @@ struct ggml_tensor * ggml_pad( // ggml_arange struct ggml_tensor * ggml_arange( - struct ggml_context * ctx, - float start, - float stop, - float step) { - + struct ggml_context * ctx, + float start, + float stop, + float step) { GGML_ASSERT(stop > start); const int64_t steps = (int64_t) ceilf((stop - start) / step); struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, steps); - result->op = GGML_OP_ARANGE; ggml_set_op_params_f32(result, 0, start); ggml_set_op_params_f32(result, 1, stop); ggml_set_op_params_f32(result, 2, step); + result->op = GGML_OP_ARANGE; + return result; } // ggml_timestep_embedding struct ggml_tensor * ggml_timestep_embedding( - struct ggml_context * ctx, - struct ggml_tensor * timesteps, - int dim, - int max_period) { - bool is_node = false; - - if (timesteps->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - + struct ggml_context * ctx, + struct ggml_tensor * timesteps, + int dim, + int max_period) { int actual_dim = dim; if (dim % 2 != 0) { actual_dim = dim + 1; @@ -7345,11 +4257,10 @@ struct ggml_tensor * ggml_timestep_embedding( struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, actual_dim, timesteps->ne[0]); - result->op = GGML_OP_TIMESTEP_EMBEDDING; ggml_set_op_params_i32(result, 0, dim); ggml_set_op_params_i32(result, 1, max_period); - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_TIMESTEP_EMBEDDING; result->src[0] = timesteps; return result; @@ -7358,22 +4269,15 @@ struct ggml_tensor * ggml_timestep_embedding( // ggml_argsort struct ggml_tensor * ggml_argsort( - struct ggml_context * ctx, - struct ggml_tensor * a, - enum ggml_sort_order order) { - bool is_node = false; - - if (a->grad) { - GGML_ABORT("fatal error"); // TODO: not implemented - is_node = true; - } - + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_sort_order order) { + GGML_ASSERT(a->ne[0] <= INT32_MAX); struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne); ggml_set_op_params_i32(result, 0, (int32_t) order); - result->op = GGML_OP_ARGSORT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_ARGSORT; result->src[0] = a; return result; @@ -7424,12 +4328,6 @@ struct ggml_tensor * ggml_flash_attn_ext( GGML_ASSERT(mask); } - bool is_node = false; - - if (q->grad || k->grad || v->grad) { - is_node = true; - } - // permute(0, 2, 1, 3) int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); @@ -7437,8 +4335,7 @@ struct ggml_tensor * ggml_flash_attn_ext( float params[] = { scale, max_bias, logit_softcap }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_FLASH_ATTN_EXT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_FLASH_ATTN_EXT; result->src[0] = q; result->src[1] = k; result->src[2] = v; @@ -7457,6 +4354,15 @@ void ggml_flash_attn_ext_set_prec( ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second } +enum ggml_prec ggml_flash_attn_ext_get_prec( + const struct ggml_tensor * a) { + GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT); + + const int32_t prec_i32 = ggml_get_op_params_i32(a, 3); + + return (enum ggml_prec) prec_i32; +} + // ggml_flash_attn_back struct ggml_tensor * ggml_flash_attn_back( @@ -7497,14 +4403,6 @@ struct ggml_tensor * ggml_flash_attn_back( GGML_ASSERT(ne2 % kvne2 == 0); - bool is_node = false; - - if (q->grad || k->grad || v->grad) { - // when using this operation (in backwards pass) these grads are set. - // we don't want to create (big) grad of our result, so is_node is false. - is_node = false; - } - // store gradients of q, k and v as continuous tensors concatenated in result. // note: v and gradv are actually transposed, i.e. v->ne[0] != D. const int64_t elem_q = ggml_nelements(q); @@ -7527,8 +4425,7 @@ struct ggml_tensor * ggml_flash_attn_back( int32_t masked_i = masked ? 1 : 0; ggml_set_op_params(result, &masked_i, sizeof(masked_i)); - result->op = GGML_OP_FLASH_ATTN_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_FLASH_ATTN_BACK; result->src[0] = q; result->src[1] = k; result->src[2] = v; @@ -7552,21 +4449,14 @@ struct ggml_tensor * ggml_ssm_conv( const int64_t n_s = sx->ne[2]; // TODO: maybe support other strides than 1? + // FIXME: this is always true? GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t); GGML_ASSERT(sx->ne[1] == d_inner); GGML_ASSERT(n_t >= 0); - bool is_node = false; - - if (sx->grad || c->grad) { - GGML_ABORT("fatal error"); // TODO: implement - is_node = true; - } - struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s); - result->op = GGML_OP_SSM_CONV; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_SSM_CONV; result->src[0] = sx; result->src[1] = c; @@ -7610,18 +4500,10 @@ struct ggml_tensor * ggml_ssm_scan( GGML_ASSERT(B->ne[2] == n_seqs); } - bool is_node = false; - - if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad) { - GGML_ABORT("fatal error"); // TODO: implement - is_node = true; - } - // concatenated y + ssm_states struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); result->op = GGML_OP_SSM_SCAN; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = s; result->src[1] = x; result->src[2] = dt; @@ -7641,13 +4523,6 @@ struct ggml_tensor * ggml_win_part( GGML_ASSERT(a->ne[3] == 1); GGML_ASSERT(a->type == GGML_TYPE_F32); - bool is_node = false; - - if (a->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - // padding const int px = (w - a->ne[1]%w)%w; const int py = (w - a->ne[2]%w)%w; @@ -7662,8 +4537,7 @@ struct ggml_tensor * ggml_win_part( int32_t params[] = { npx, npy, w }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_WIN_PART; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_WIN_PART; result->src[0] = a; return result; @@ -7679,21 +4553,13 @@ struct ggml_tensor * ggml_win_unpart( int w) { GGML_ASSERT(a->type == GGML_TYPE_F32); - bool is_node = false; - - if (a->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - const int64_t ne[4] = { a->ne[0], w0, h0, 1, }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); int32_t params[] = { w }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_WIN_UNPART; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_WIN_UNPART; result->src[0] = a; return result; @@ -7709,18 +4575,10 @@ struct ggml_tensor * ggml_get_rel_pos( GGML_ASSERT(qh == kh); GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]); - bool is_node = false; - - if (a->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - const int64_t ne[4] = { a->ne[0], kh, qh, 1, }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 3, ne); - result->op = GGML_OP_GET_REL_POS; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_GET_REL_POS; result->src[0] = a; return result; @@ -7744,17 +4602,10 @@ static struct ggml_tensor * ggml_add_rel_pos_impl( GGML_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]); GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]); - bool is_node = false; - - if (!inplace && (a->grad || pw->grad || ph->grad)) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params_i32(result, 0, inplace ? 1 : 0); - result->op = GGML_OP_ADD_REL_POS; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_ADD_REL_POS; result->src[0] = a; result->src[1] = pw; result->src[2] = ph; @@ -7778,16 +4629,16 @@ struct ggml_tensor * ggml_add_rel_pos_inplace( return ggml_add_rel_pos_impl(ctx, a, pw, ph, true); } -// ggml_rwkv_wkv +// ggml_rwkv_wkv6 -struct ggml_tensor * ggml_rwkv_wkv( +struct ggml_tensor * ggml_rwkv_wkv6( struct ggml_context * ctx, - struct ggml_tensor * k, - struct ggml_tensor * v, - struct ggml_tensor * r, - struct ggml_tensor * tf, - struct ggml_tensor * td, - struct ggml_tensor * state) { + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * r, + struct ggml_tensor * tf, + struct ggml_tensor * td, + struct ggml_tensor * state) { GGML_ASSERT(ggml_is_contiguous(k)); GGML_ASSERT(ggml_is_contiguous(v)); GGML_ASSERT(ggml_is_contiguous(r)); @@ -7796,31 +4647,21 @@ struct ggml_tensor * ggml_rwkv_wkv( GGML_ASSERT(ggml_is_contiguous(state)); const int64_t S = k->ne[0]; - const int64_t H = k->ne[2]; - const int64_t n_tokens = k->ne[3]; + const int64_t H = k->ne[1]; + const int64_t n_tokens = k->ne[2]; const int64_t n_seqs = state->ne[1]; { - GGML_ASSERT(k->ne[1] == 1); - GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens); - GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens); - // TODO: RWKV v4 and v5 - GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens); + GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens); + GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens); + GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens); GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs); } - bool is_node = false; - - if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) { - GGML_ABORT("fatal error"); // TODO: implement backward - is_node = true; - } - // concat output and new_state const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - result->op = GGML_OP_RWKV_WKV; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_RWKV_WKV6; result->src[0] = k; result->src[1] = v; result->src[2] = r; @@ -7831,27 +4672,63 @@ struct ggml_tensor * ggml_rwkv_wkv( return result; } +// ggml_gated_linear_attn + +struct ggml_tensor * ggml_gated_linear_attn( + struct ggml_context * ctx, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * q, + struct ggml_tensor * g, + struct ggml_tensor * state, + float scale) { + GGML_ASSERT(ggml_is_contiguous(k)); + GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(q)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(state)); + + const int64_t S = k->ne[0]; + const int64_t H = k->ne[1]; + const int64_t n_tokens = k->ne[2]; + const int64_t n_seqs = state->ne[1]; + { + GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens); + GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens); + GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs); + } + + // concat output and new_state + const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + ggml_set_op_params_f32(result, 0, scale); + + result->op = GGML_OP_GATED_LINEAR_ATTN; + result->src[0] = k; + result->src[1] = v; + result->src[2] = q; + result->src[3] = g; + result->src[4] = state; + + return result; +} + // ggml_unary static struct ggml_tensor * ggml_unary_impl( struct ggml_context * ctx, - struct ggml_tensor * a, - enum ggml_unary_op op, - bool inplace) { + struct ggml_tensor * a, + enum ggml_unary_op op, + bool inplace) { GGML_ASSERT(ggml_is_contiguous_1(a)); - bool is_node = false; - - if (!inplace && (a->grad)) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params_i32(result, 0, (int32_t) op); - result->op = GGML_OP_UNARY; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_UNARY; result->src[0] = a; return result; @@ -7860,14 +4737,14 @@ static struct ggml_tensor * ggml_unary_impl( struct ggml_tensor * ggml_unary( struct ggml_context * ctx, struct ggml_tensor * a, - enum ggml_unary_op op) { + enum ggml_unary_op op) { return ggml_unary_impl(ctx, a, op, false); } struct ggml_tensor * ggml_unary_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - enum ggml_unary_op op) { + enum ggml_unary_op op) { return ggml_unary_impl(ctx, a, op, true); } @@ -7876,20 +4753,13 @@ struct ggml_tensor * ggml_unary_inplace( static struct ggml_tensor * ggml_map_unary_impl_f32( struct ggml_context * ctx, struct ggml_tensor * a, - const ggml_unary_op_f32_t fun, - bool inplace) { - bool is_node = false; - - if (!inplace && a->grad) { - is_node = true; - } - + const ggml_unary_op_f32_t fun, + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - result->op = GGML_OP_MAP_UNARY; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MAP_UNARY; result->src[0] = a; return result; @@ -7898,14 +4768,14 @@ static struct ggml_tensor * ggml_map_unary_impl_f32( struct ggml_tensor * ggml_map_unary_f32( struct ggml_context * ctx, struct ggml_tensor * a, - const ggml_unary_op_f32_t fun) { + const ggml_unary_op_f32_t fun) { return ggml_map_unary_impl_f32(ctx, a, fun, false); } struct ggml_tensor * ggml_map_unary_inplace_f32( struct ggml_context * ctx, struct ggml_tensor * a, - const ggml_unary_op_f32_t fun) { + const ggml_unary_op_f32_t fun) { return ggml_map_unary_impl_f32(ctx, a, fun, true); } @@ -7915,22 +4785,15 @@ static struct ggml_tensor * ggml_map_binary_impl_f32( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, - const ggml_binary_op_f32_t fun, - bool inplace) { + const ggml_binary_op_f32_t fun, + bool inplace) { GGML_ASSERT(ggml_are_same_shape(a, b)); - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - result->op = GGML_OP_MAP_BINARY; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MAP_BINARY; result->src[0] = a; result->src[1] = b; @@ -7941,7 +4804,7 @@ struct ggml_tensor * ggml_map_binary_f32( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, - const ggml_binary_op_f32_t fun) { + const ggml_binary_op_f32_t fun) { return ggml_map_binary_impl_f32(ctx, a, b, fun, false); } @@ -7949,7 +4812,7 @@ struct ggml_tensor * ggml_map_binary_inplace_f32( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, - const ggml_binary_op_f32_t fun) { + const ggml_binary_op_f32_t fun) { return ggml_map_binary_impl_f32(ctx, a, b, fun, true); } @@ -7959,19 +4822,12 @@ static struct ggml_tensor * ggml_map_custom1_impl_f32( struct ggml_context * ctx, struct ggml_tensor * a, const ggml_custom1_op_f32_t fun, - bool inplace) { - bool is_node = false; - - if (!inplace && a->grad) { - is_node = true; - } - + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - result->op = GGML_OP_MAP_CUSTOM1_F32; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MAP_CUSTOM1_F32; result->src[0] = a; return result; @@ -7998,19 +4854,12 @@ static struct ggml_tensor * ggml_map_custom2_impl_f32( struct ggml_tensor * a, struct ggml_tensor * b, const ggml_custom2_op_f32_t fun, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - result->op = GGML_OP_MAP_CUSTOM2_F32; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MAP_CUSTOM2_F32; result->src[0] = a; result->src[1] = b; @@ -8041,19 +4890,12 @@ static struct ggml_tensor * ggml_map_custom3_impl_f32( struct ggml_tensor * b, struct ggml_tensor * c, const ggml_custom3_op_f32_t fun, - bool inplace) { - bool is_node = false; - - if (!inplace && (a->grad || b->grad || c->grad)) { - is_node = true; - } - + bool inplace) { struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); ggml_set_op_params(result, (const void *) &fun, sizeof(fun)); - result->op = GGML_OP_MAP_CUSTOM3_F32; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MAP_CUSTOM3_F32; result->src[0] = a; result->src[1] = b; result->src[2] = c; @@ -8080,27 +4922,16 @@ struct ggml_tensor * ggml_map_custom3_inplace_f32( } // ggml_map_custom1 -struct ggml_map_custom1_op_params { - ggml_custom1_op_t fun; - int n_tasks; - void * userdata; -}; static struct ggml_tensor * ggml_map_custom1_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_custom1_op_t fun, - int n_tasks, - void * userdata, - bool inplace) { + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0); - bool is_node = false; - - if (!inplace && a->grad) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); struct ggml_map_custom1_op_params params = { @@ -8110,55 +4941,42 @@ static struct ggml_tensor * ggml_map_custom1_impl( }; ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); - result->op = GGML_OP_MAP_CUSTOM1; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MAP_CUSTOM1; result->src[0] = a; return result; } struct ggml_tensor * ggml_map_custom1( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_custom1_op_t fun, - int n_tasks, - void * userdata) { + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_t fun, + int n_tasks, + void * userdata) { return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, false); } struct ggml_tensor * ggml_map_custom1_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - const ggml_custom1_op_t fun, - int n_tasks, - void * userdata) { + struct ggml_context * ctx, + struct ggml_tensor * a, + const ggml_custom1_op_t fun, + int n_tasks, + void * userdata) { return ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, true); } // ggml_map_custom2 -struct ggml_map_custom2_op_params { - ggml_custom2_op_t fun; - int n_tasks; - void * userdata; -}; - static struct ggml_tensor * ggml_map_custom2_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_custom2_op_t fun, - int n_tasks, - void * userdata, - bool inplace) { + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0); - bool is_node = false; - - if (!inplace && (a->grad || b->grad)) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); struct ggml_map_custom2_op_params params = { @@ -8168,8 +4986,7 @@ static struct ggml_tensor * ggml_map_custom2_impl( }; ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); - result->op = GGML_OP_MAP_CUSTOM2; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MAP_CUSTOM2; result->src[0] = a; result->src[1] = b; @@ -8177,50 +4994,38 @@ static struct ggml_tensor * ggml_map_custom2_impl( } struct ggml_tensor * ggml_map_custom2( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_custom2_op_t fun, - int n_tasks, - void * userdata) { + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_t fun, + int n_tasks, + void * userdata) { return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, false); } struct ggml_tensor * ggml_map_custom2_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const ggml_custom2_op_t fun, - int n_tasks, - void * userdata) { + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const ggml_custom2_op_t fun, + int n_tasks, + void * userdata) { return ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, true); } // ggml_map_custom3 -struct ggml_map_custom3_op_params { - ggml_custom3_op_t fun; - int n_tasks; - void * userdata; -}; - static struct ggml_tensor * ggml_map_custom3_impl( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - const ggml_custom3_op_t fun, - int n_tasks, - void * userdata, - bool inplace) { + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_t fun, + int n_tasks, + void * userdata, + bool inplace) { GGML_ASSERT(n_tasks == GGML_N_TASKS_MAX || n_tasks > 0); - bool is_node = false; - - if (!inplace && (a->grad || b->grad || c->grad)) { - is_node = true; - } - struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); struct ggml_map_custom3_op_params params = { @@ -8230,8 +5035,7 @@ static struct ggml_tensor * ggml_map_custom3_impl( }; ggml_set_op_params(result, (const void *) ¶ms, sizeof(params)); - result->op = GGML_OP_MAP_CUSTOM3; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_MAP_CUSTOM3; result->src[0] = a; result->src[1] = b; result->src[2] = c; @@ -8240,44 +5044,38 @@ static struct ggml_tensor * ggml_map_custom3_impl( } struct ggml_tensor * ggml_map_custom3( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - const ggml_custom3_op_t fun, - int n_tasks, - void * userdata) { + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_t fun, + int n_tasks, + void * userdata) { return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, false); } struct ggml_tensor * ggml_map_custom3_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c, - const ggml_custom3_op_t fun, - int n_tasks, - void * userdata) { + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + const ggml_custom3_op_t fun, + int n_tasks, + void * userdata) { return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true); } // ggml_cross_entropy_loss struct ggml_tensor * ggml_cross_entropy_loss( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b) { + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { GGML_ASSERT(ggml_are_same_shape(a, b)); - bool is_node = false; - - if (a->grad || b->grad) { - is_node = true; - } struct ggml_tensor * result = ggml_new_tensor_1d(ctx, a->type, 1); - result->op = GGML_OP_CROSS_ENTROPY_LOSS; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_CROSS_ENTROPY_LOSS; result->src[0] = a; result->src[1] = b; @@ -8287,17 +5085,16 @@ struct ggml_tensor * ggml_cross_entropy_loss( // ggml_cross_entropy_loss_back struct ggml_tensor * ggml_cross_entropy_loss_back( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_tensor * c) { - GGML_ASSERT(ggml_are_same_shape(a, b)); - GGML_ASSERT(ggml_is_scalar(c)); + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c) { + GGML_ASSERT(ggml_is_scalar(a)); + GGML_ASSERT(ggml_are_same_shape(b, c)); - struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + struct ggml_tensor * result = ggml_dup_tensor(ctx, b); - result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK; - result->grad = NULL; + result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK; result->src[0] = a; result->src[1] = b; result->src[2] = c; @@ -8305,9499 +5102,32 @@ struct ggml_tensor * ggml_cross_entropy_loss_back( return result; } -//////////////////////////////////////////////////////////////////////////////// +// opt_step_adamw -void ggml_set_param( +struct ggml_tensor * ggml_opt_step_adamw( struct ggml_context * ctx, - struct ggml_tensor * tensor) { - tensor->flags |= GGML_TENSOR_FLAG_PARAM; - - GGML_ASSERT(tensor->grad == NULL); - tensor->grad = ggml_dup_tensor(ctx, tensor); - ggml_format_name(tensor->grad, "%s (grad)", tensor->name); -} - -// ggml_compute_forward_dup - -static void ggml_compute_forward_dup_same_cont( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); - GGML_ASSERT(src0->type == dst->type); - - const size_t nb0 = ggml_type_size(src0->type); - - const int ith = params->ith; // thread index - const int nth = params->nth; // number of threads - - // parallelize by elements - const int ne = ggml_nelements(dst); - const int dr = (ne + nth - 1) / nth; - const int ie0 = dr * ith; - const int ie1 = MIN(ie0 + dr, ne); - - if (ie0 < ie1) { - memcpy( - ((char *) dst->data + ie0*nb0), - ((char *) src0->data + ie0*nb0), - (ie1 - ie0) * nb0); - } -} - -static void ggml_compute_forward_dup_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - - GGML_TENSOR_UNARY_OP_LOCALS - - const int ith = params->ith; // thread index - const int nth = params->nth; // number of threads - - // parallelize by rows - const int nr = ne01; - // number of rows per thread - const int dr = (nr + nth - 1) / nth; - // row range for this thread - const int ir0 = dr * ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (src0->type == dst->type && - ne00 == ne0 && - nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) { - // copy by rows - const size_t rs = ne00*nb00; - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ir0; i01 < ir1; i01++) { - memcpy( - ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), - ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), - rs); - } - } - } - return; - } - - // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy - - if (ggml_is_contiguous(dst)) { - if (nb00 == sizeof(ggml_fp16_t)) { - if (dst->type == GGML_TYPE_F16) { - size_t id = 0; - const size_t rs = ne00 * nb00; - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - memcpy(dst_ptr + id, src0_ptr, rs); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - for (int i00 = 0; i00 < ne00; i00++) { - dst_ptr[id] = GGML_FP16_TO_FP32(src0_ptr[i00]); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (type_traits[dst->type].from_float) { - ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float; - float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; - - size_t id = 0; - size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - - for (int i00 = 0; i00 < ne00; i00++) { - src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]); - } - - quantize_row_q(src0_f32, dst_ptr + id, ne00); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement - } - } else { - //printf("%s: this is not optimal - fix me\n", __func__); - - if (dst->type == GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_F16) { - size_t id = 0; - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = *src0_ptr; - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement - } - } - return; - } - - // dst counters - int64_t i10 = 0; - int64_t i11 = 0; - int64_t i12 = 0; - int64_t i13 = 0; - - if (dst->type == GGML_TYPE_F16) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t)); - - if (++i10 == ne00) { - i10 = 0; - if (++i11 == ne01) { - i11 = 0; - if (++i12 == ne02) { - i12 = 0; - if (++i13 == ne03) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else if (dst->type == GGML_TYPE_F32) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - *(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement - } -} - -static void ggml_compute_forward_dup_bf16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - - GGML_TENSOR_UNARY_OP_LOCALS - - const int ith = params->ith; // thread index - const int nth = params->nth; // number of threads - - // parallelize by rows - const int nr = ne01; - // number of rows per thread - const int dr = (nr + nth - 1) / nth; - // row range for this thread - const int ir0 = dr * ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (src0->type == dst->type && - ne00 == ne0 && - nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) { - // copy by rows - const size_t rs = ne00*nb00; - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ir0; i01 < ir1; i01++) { - memcpy( - ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), - ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), - rs); - } - } - } - return; - } - - // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy - - if (ggml_is_contiguous(dst)) { - if (nb00 == sizeof(ggml_bf16_t)) { - if (dst->type == GGML_TYPE_BF16) { - size_t id = 0; - const size_t rs = ne00 * nb00; - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - memcpy(dst_ptr + id, src0_ptr, rs); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_F16) { - size_t id = 0; - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - for (int i00 = 0; i00 < ne00; i00++) { - dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(src0_ptr[i00])); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - for (int i00 = 0; i00 < ne00; i00++) { - dst_ptr[id] = GGML_BF16_TO_FP32(src0_ptr[i00]); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (type_traits[dst->type].from_float) { - ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float; - float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; - - size_t id = 0; - size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - - for (int i00 = 0; i00 < ne00; i00++) { - src0_f32[i00] = GGML_BF16_TO_FP32(src0_ptr[i00]); - } - - quantize_row_q(src0_f32, dst_ptr + id, ne00); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement - } - } else { - //printf("%s: this is not optimal - fix me\n", __func__); - - if (dst->type == GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_BF16_TO_FP32(*src0_ptr); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_BF16) { - size_t id = 0; - ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = *src0_ptr; - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_F16) { - size_t id = 0; - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*src0_ptr)); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement - } - } - return; - } - - // dst counters - int64_t i10 = 0; - int64_t i11 = 0; - int64_t i12 = 0; - int64_t i13 = 0; - - if (dst->type == GGML_TYPE_BF16) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - memcpy(dst_ptr, src0_ptr, sizeof(ggml_bf16_t)); - - if (++i10 == ne00) { - i10 = 0; - if (++i11 == ne01) { - i11 = 0; - if (++i12 == ne02) { - i12 = 0; - if (++i13 == ne03) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else if (dst->type == GGML_TYPE_F16) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr)); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else if (dst->type == GGML_TYPE_F32) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - *(float *) dst_ptr = GGML_BF16_TO_FP32(*(const ggml_bf16_t *) src0_ptr); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement - } -} - -static void ggml_compute_forward_dup_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - - GGML_TENSOR_UNARY_OP_LOCALS - - const int ith = params->ith; // thread index - const int nth = params->nth; // number of threads - - // parallelize by rows - const int nr = ne01; - // number of rows per thread - const int dr = (nr + nth - 1) / nth; - // row range for this thread - const int ir0 = dr * ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (src0->type == dst->type && - ne00 == ne0 && - nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) { - // copy by rows - const size_t rs = ne00*nb00; - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ir0; i01 < ir1; i01++) { - memcpy( - ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), - ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), - rs); - } - } - } - return; - } - - if (ggml_is_contiguous(dst)) { - // TODO: simplify - if (nb00 == sizeof(float)) { - if (dst->type == GGML_TYPE_F32) { - size_t id = 0; - const size_t rs = ne00 * nb00; - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - memcpy(dst_ptr + id, src0_ptr, rs); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else if (type_traits[dst->type].from_float) { - ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float; - - size_t id = 0; - size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type)); - char * dst_ptr = (char *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - quantize_row_q(src0_ptr, dst_ptr + id, ne00); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement - } - } else { - //printf("%s: this is not optimal - fix me\n", __func__); - - if (dst->type == GGML_TYPE_F32) { - size_t id = 0; - float * dst_ptr = (float *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = *src0_ptr; - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_F16) { - size_t id = 0; - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else if (dst->type == GGML_TYPE_BF16) { - size_t id = 0; - ggml_bf16_t * dst_ptr = (ggml_bf16_t *) dst->data; - - for (int i03 = 0; i03 < ne03; i03++) { - for (int i02 = 0; i02 < ne02; i02++) { - id += ne00 * ir0; - for (int i01 = ir0; i01 < ir1; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - - dst_ptr[id] = GGML_FP32_TO_BF16(*src0_ptr); - id++; - } - } - id += ne00 * (ne01 - ir1); - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement - } - } - - return; - } - - // dst counters - - int64_t i10 = 0; - int64_t i11 = 0; - int64_t i12 = 0; - int64_t i13 = 0; - - if (dst->type == GGML_TYPE_F32) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - memcpy(dst_ptr, src0_ptr, sizeof(float)); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else if (dst->type == GGML_TYPE_F16) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - *(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else if (dst->type == GGML_TYPE_BF16) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - *(ggml_bf16_t *) dst_ptr = GGML_FP32_TO_BF16(*(const float *) src0_ptr); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - } else { - GGML_ABORT("fatal error"); // TODO: implement - } -} - -// A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy. -static void ggml_compute_forward_dup_bytes( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - GGML_ASSERT(src0->type == dst->type); - - GGML_TENSOR_UNARY_OP_LOCALS; - - if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) { - ggml_compute_forward_dup_same_cont(params, dst); - return; - } - - const size_t type_size = ggml_type_size(src0->type); - const int ith = params->ith; // thread index - const int nth = params->nth; // number of threads - - - // parallelize by rows - const int nr = ne01; - // number of rows per thread - const int dr = (nr + nth - 1) / nth; - // row range for this thread - const int ir0 = dr * ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (src0->type == dst->type && - ne00 == ne0 && - nb00 == type_size && nb0 == type_size) { - // copy by rows - const size_t rs = ne00 * type_size; - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ir0; i01 < ir1; i01++) { - memcpy( - ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), - ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), - rs); - } - } - } - return; - } - - if (ggml_is_contiguous(dst)) { - size_t id = 0; - char * dst_ptr = (char *) dst->data; - const size_t rs = ne00 * type_size; - - if (nb00 == type_size) { - // src0 is contigous on first dimension, copy by rows - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int64_t i01 = ir0; i01 < ir1; i01++) { - const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - memcpy(dst_ptr + id, src0_ptr, rs); - id += rs; - } - id += rs * (ne01 - ir1); - } - } - } else { - //printf("%s: this is not optimal - fix me\n", __func__); - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - id += rs * ir0; - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03; - memcpy(dst_ptr + id, src0_ptr, type_size); - - id += type_size; - } - } - id += rs * (ne01 - ir1); - } - } - } - - return; - } - - // dst counters - - int64_t i10 = 0; - int64_t i11 = 0; - int64_t i12 = 0; - int64_t i13 = 0; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - i10 += ne00 * ir0; - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - - memcpy(dst_ptr, src0_ptr, type_size); - - if (++i10 == ne0) { - i10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - i10 += ne00 * (ne01 - ir1); - while (i10 >= ne0) { - i10 -= ne0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } -} - -static void ggml_compute_forward_dup( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (src0->type == dst->type) { - ggml_compute_forward_dup_bytes(params, dst); - return; - } - - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_dup_f16(params, dst); - } break; - case GGML_TYPE_BF16: - { - ggml_compute_forward_dup_bf16(params, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_dup_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_add - -static void ggml_compute_forward_add_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT( nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (nb10 == sizeof(float)) { - for (int ir = ir0; ir < ir1; ++ir) { - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - const int64_t nr0 = ne00 / ne10; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - - for (int64_t r = 0; r < nr0; ++r) { -#ifdef GGML_USE_ACCELERATE - vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10); -#else - ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); -#endif - } - } - } else { - // src1 is not contiguous - for (int ir = ir0; ir < ir1; ++ir) { - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - - for (int64_t i0 = 0; i0 < ne0; ++i0) { - const int64_t i10 = i0 % ne10; - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); - - dst_ptr[i0] = src0_ptr[i0] + *src1_ptr; - } - } - } -} - -static void ggml_compute_forward_add_f16_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - if (dst->type == GGML_TYPE_F32) { - GGML_ASSERT( nb0 == sizeof(float)); - } - else { - GGML_ASSERT(dst->type == GGML_TYPE_F16); - GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); - } - - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (nb10 == sizeof(float)) { - if (dst->type == GGML_TYPE_F16) { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); - } - } - } else { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]; - } - } - } - } - else { - // src1 is not contiguous - GGML_ABORT("fatal error"); - } -} - -static void ggml_compute_forward_add_bf16_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT(src0->type == GGML_TYPE_BF16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - if (dst->type == GGML_TYPE_F32) { - GGML_ASSERT( nb0 == sizeof(float)); - } - else { - GGML_ASSERT(dst->type == GGML_TYPE_BF16); - GGML_ASSERT( nb0 == sizeof(ggml_bf16_t)); - } - - GGML_ASSERT(nb00 == sizeof(ggml_bf16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (nb10 == sizeof(float)) { - if (dst->type == GGML_TYPE_BF16) { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); - } - } - } else { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_BF16_TO_FP32(src0_ptr[i]) + src1_ptr[i]; - } - } - } - } - else { - // src1 is not contiguous - GGML_ABORT("fatal error"); - } -} - -static void ggml_compute_forward_add_f16_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F16); - - GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (nb10 == sizeof(ggml_fp16_t)) { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(src1_ptr[i])); - } - } - } - else { - // src1 is not contiguous - GGML_ABORT("fatal error"); - } -} - -static void ggml_compute_forward_add_bf16_bf16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT(src0->type == GGML_TYPE_BF16); - GGML_ASSERT(src1->type == GGML_TYPE_BF16); - GGML_ASSERT(dst->type == GGML_TYPE_BF16); - - GGML_ASSERT( nb0 == sizeof(ggml_bf16_t)); - GGML_ASSERT(nb00 == sizeof(ggml_bf16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (nb10 == sizeof(ggml_bf16_t)) { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - ggml_bf16_t * src1_ptr = (ggml_bf16_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + GGML_BF16_TO_FP32(src1_ptr[i])); - } - } - } - else { - // src1 is not contiguous - GGML_ABORT("fatal error"); - } -} - -static void ggml_compute_forward_add_q_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const enum ggml_type type = src0->type; - const enum ggml_type dtype = dst->type; - ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; - ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float; - - // we don't support permuted src0 or src1 - GGML_ASSERT(nb00 == ggml_type_size(type)); - GGML_ASSERT(nb10 == sizeof(float)); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - GGML_ASSERT(ggml_is_quantized(src0->type)); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith; - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 indices - const int i03 = ir/(ne02*ne01); - const int i02 = (ir - i03*ne02*ne01)/ne01; - const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - - // src1 and dst are same shape as src0 => same indices - const int i13 = i03; - const int i12 = i02; - const int i11 = i01; - - const int i3 = i03; - const int i2 = i02; - const int i1 = i01; - - void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); - float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)); - void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); - - assert(ne00 % 32 == 0); - - // unquantize row from src0 to temp buffer - dequantize_row_q(src0_row, wdata, ne00); - // add src1 - ggml_vec_acc_f32(ne00, wdata, src1_row); - // quantize row to dst - if (quantize_row_q != NULL) { - quantize_row_q(wdata, dst_row, ne00); - } else { - memcpy(dst_row, wdata, ne0*nb0); - } - } -} - -static void ggml_compute_forward_add( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - if (src1->type == GGML_TYPE_F32) { - ggml_compute_forward_add_f32(params, dst); - } - else { - GGML_ABORT("fatal error"); - } - } break; - case GGML_TYPE_F16: - { - if (src1->type == GGML_TYPE_F16) { - ggml_compute_forward_add_f16_f16(params, dst); - } - else if (src1->type == GGML_TYPE_F32) { - ggml_compute_forward_add_f16_f32(params, dst); - } - else { - GGML_ABORT("fatal error"); - } - } break; - case GGML_TYPE_BF16: - { - if (src1->type == GGML_TYPE_BF16) { - ggml_compute_forward_add_bf16_bf16(params, dst); - } - else if (src1->type == GGML_TYPE_F32) { - ggml_compute_forward_add_bf16_f32(params, dst); - } - else { - GGML_ABORT("fatal error"); - } - } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - case GGML_TYPE_TQ1_0: - case GGML_TYPE_TQ2_0: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_Q4_0_4_4: - case GGML_TYPE_Q4_0_4_8: - case GGML_TYPE_Q4_0_8_8: - { - ggml_compute_forward_add_q_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_add1 - -static void ggml_compute_forward_add1_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_scalar(src1)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_UNARY_OP_LOCALS - - GGML_ASSERT( nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - -#ifdef GGML_USE_ACCELERATE - UNUSED(ggml_vec_add1_f32); - - vDSP_vadd( - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, - (float *) ((char *) src1->data), 0, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, - ne0); -#else - ggml_vec_add1_f32(ne0, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), - *(float *) src1->data); -#endif - } -} - -static void ggml_compute_forward_add1_f16_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_scalar(src1)); - - // scalar to add - const float v = *(float *) src1->data; - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_UNARY_OP_LOCALS - - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F16); - - GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v); - } - } -} - -static void ggml_compute_forward_add1_f16_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_scalar(src1)); - - // scalar to add - const float v = GGML_FP16_TO_FP32(*(ggml_fp16_t *) src1->data); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_UNARY_OP_LOCALS - - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F16); - - GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + v); - } - } -} - -static void ggml_compute_forward_add1_q_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_scalar(src1)); - - // scalar to add - const float v = *(float *) src1->data; - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_UNARY_OP_LOCALS - - const enum ggml_type type = src0->type; - ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; - ggml_from_float_t const quantize_row_q = type_traits[type].from_float; - - // we don't support permuted src0 - GGML_ASSERT(nb00 == ggml_type_size(type)); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - GGML_ASSERT(ggml_is_quantized(src0->type)); - GGML_ASSERT(dst->type == src0->type); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith; - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03)); - void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb0 )); - - assert(ne0 % 32 == 0); - - // unquantize row from src0 to temp buffer - dequantize_row_q(src0_row, wdata, ne0); - // add src1 - ggml_vec_acc1_f32(ne0, wdata, v); - // quantize row to dst - quantize_row_q(wdata, dst_row, ne0); - } -} - -static void ggml_compute_forward_add1_bf16_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_scalar(src1)); - - // scalar to add - const float v = *(float *) src1->data; - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_UNARY_OP_LOCALS - - GGML_ASSERT(src0->type == GGML_TYPE_BF16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_BF16); - - GGML_ASSERT( nb0 == sizeof(ggml_bf16_t)); - GGML_ASSERT(nb00 == sizeof(ggml_bf16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v); - } - } -} - -static void ggml_compute_forward_add1_bf16_bf16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_scalar(src1)); - - // scalar to add - const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_UNARY_OP_LOCALS - - GGML_ASSERT(src0->type == GGML_TYPE_BF16); - GGML_ASSERT(src1->type == GGML_TYPE_BF16); - GGML_ASSERT(dst->type == GGML_TYPE_BF16); - - GGML_ASSERT( nb0 == sizeof(ggml_bf16_t)); - GGML_ASSERT(nb00 == sizeof(ggml_bf16_t)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v); - } - } -} - -static void ggml_compute_forward_add1( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_add1_f32(params, dst); - } break; - case GGML_TYPE_F16: - { - if (src1->type == GGML_TYPE_F16) { - ggml_compute_forward_add1_f16_f16(params, dst); - } - else if (src1->type == GGML_TYPE_F32) { - ggml_compute_forward_add1_f16_f32(params, dst); - } - else { - GGML_ABORT("fatal error"); - } - } break; - case GGML_TYPE_BF16: - { - if (src1->type == GGML_TYPE_BF16) { - ggml_compute_forward_add1_bf16_bf16(params, dst); - } - else if (src1->type == GGML_TYPE_F32) { - ggml_compute_forward_add1_bf16_f32(params, dst); - } - else { - GGML_ABORT("fatal error"); - } - } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q8_1: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - case GGML_TYPE_TQ1_0: - case GGML_TYPE_TQ2_0: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_Q4_0_4_4: - case GGML_TYPE_Q4_0_4_8: - case GGML_TYPE_Q4_0_8_8: - { - ggml_compute_forward_add1_q_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_acc - -static void ggml_compute_forward_acc_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); - - // view src0 and dst with these strides and data offset inbytes during acc - // nb0 is implicitly element_size because src0 and dst are contiguous - size_t nb1 = ((int32_t *) dst->op_params)[0]; - size_t nb2 = ((int32_t *) dst->op_params)[1]; - size_t nb3 = ((int32_t *) dst->op_params)[2]; - size_t offset = ((int32_t *) dst->op_params)[3]; - bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - - if (!inplace) { - if (params->ith == 0) { - // memcpy needs to be synchronized across threads to avoid race conditions. - // => do it in INIT phase - memcpy( - ((char *) dst->data), - ((char *) src0->data), - ggml_nbytes(dst)); - } - ggml_barrier(params->threadpool); - } - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src1); - const int nc = src1->ne[0]; - - GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) - GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) - - // src0 and dst as viewed during acc - const size_t nb0 = ggml_element_size(src0); - - const size_t nb00 = nb0; - const size_t nb01 = nb1; - const size_t nb02 = nb2; - const size_t nb03 = nb3; - - GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0 + (ne11 == 0 ? 0 : ne11-1)*nb1 + (ne12 == 0 ? 0 : ne12-1)*nb2 + (ne13 == 0 ? 0 : ne13-1)*nb3 < ggml_nbytes(dst)); - GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0)); - - GGML_ASSERT(nb10 == sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are viewed with shape of src1 and offset - // => same indices - const int i3 = ir/(ne12*ne11); - const int i2 = (ir - i3*ne12*ne11)/ne11; - const int i1 = (ir - i3*ne12*ne11 - i2*ne11); - -#ifdef GGML_USE_ACCELERATE - vDSP_vadd( - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1, - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), 1, nc); -#else - ggml_vec_add_f32(nc, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); -#endif - } -} - -static void ggml_compute_forward_acc( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_acc_f32(params, dst); - } break; - case GGML_TYPE_F16: - case GGML_TYPE_BF16: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q8_1: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - case GGML_TYPE_TQ1_0: - case GGML_TYPE_TQ2_0: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_Q4_0_4_4: - case GGML_TYPE_Q4_0_4_8: - case GGML_TYPE_Q4_0_8_8: - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_sub - -static void ggml_compute_forward_sub_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT( nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - if (nb10 == sizeof(float)) { - for (int ir = ir0; ir < ir1; ++ir) { - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - const int64_t nr0 = ne00 / ne10; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - - for (int64_t r = 0; r < nr0; ++r) { -#ifdef GGML_USE_ACCELERATE - vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10); -#else - ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); -#endif - } - } - } else { - // src1 is not contiguous - for (int ir = ir0; ir < ir1; ++ir) { - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - - for (int64_t i0 = 0; i0 < ne0; ++i0) { - const int64_t i10 = i0 % ne10; - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); - - dst_ptr[i0] = src0_ptr[i0] - *src1_ptr; - } - } - } -} - -static void ggml_compute_forward_sub( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sub_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_mul - -static void ggml_compute_forward_mul_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT( nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - if (nb10 == sizeof(float)) { - for (int64_t ir = ith; ir < nr; ir += nth) { - // src0 and dst are same shape => same indices - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - const int64_t nr0 = ne00 / ne10; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - - for (int64_t r = 0 ; r < nr0; ++r) { -#ifdef GGML_USE_ACCELERATE - UNUSED(ggml_vec_mul_f32); - - vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10); -#else - ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); -#endif - } - } - } else { - // src1 is not contiguous - for (int64_t ir = ith; ir < nr; ir += nth) { - // src0 and dst are same shape => same indices - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - - for (int64_t i0 = 0; i0 < ne00; ++i0) { - const int64_t i10 = i0 % ne10; - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); - - dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr); - } - } - } -} - -static void ggml_compute_forward_mul( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(src1->type == GGML_TYPE_F32 && "only f32 src1 supported for now"); - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_mul_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_div - -static void ggml_compute_forward_div_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t nr = ggml_nrows(src0); - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT( nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - if (nb10 == sizeof(float)) { - for (int64_t ir = ith; ir < nr; ir += nth) { - // src0 and dst are same shape => same indices - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - const int64_t nr0 = ne00 / ne10; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - - for (int64_t r = 0; r < nr0; ++r) { -#ifdef GGML_USE_ACCELERATE - UNUSED(ggml_vec_div_f32); - - vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10); -#else - ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); -#endif - } - } - } else { - // src1 is not contiguous - for (int64_t ir = ith; ir < nr; ir += nth) { - // src0 and dst are same shape => same indices - // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - - for (int64_t i0 = 0; i0 < ne00; ++i0) { - const int64_t i10 = i0 % ne10; - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); - - dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr); - } - } - } -} - -static void ggml_compute_forward_div( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_div_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_sqr - -static void ggml_compute_forward_sqr_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_sqr_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_sqr( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sqr_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_sqrt - -static void ggml_compute_forward_sqrt_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_sqrt_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_sqrt( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sqrt_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_log - -static void ggml_compute_forward_log_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - GGML_ASSERT( dst->nb[0] == sizeof(float)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_log_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_log( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_log_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_sin - -static void ggml_compute_forward_sin_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - GGML_ASSERT( dst->nb[0] == sizeof(float)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_sin_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_sin( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sin_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_cos - -static void ggml_compute_forward_cos_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - GGML_ASSERT( dst->nb[0] == sizeof(float)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_cos_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_cos( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_cos_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_sum - -static void ggml_compute_forward_sum_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_scalar(dst)); - assert(src0->nb[0] == sizeof(float)); - - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) - GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) - - ggml_float sum = 0; - ggml_float row_sum = 0; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - ggml_vec_sum_f32_ggf(ne00, - &row_sum, - (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); - sum += row_sum; - } - } - } - ((float *) dst->data)[0] = sum; -} - -static void ggml_compute_forward_sum_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_scalar(dst)); - - assert(src0->nb[0] == sizeof(ggml_fp16_t)); - - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) - GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) - - float sum = 0; - float row_sum = 0; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - ggml_vec_sum_f16_ggf(ne00, - &row_sum, - (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03)); - sum += row_sum; - } - } - } - ((ggml_fp16_t *) dst->data)[0] = GGML_FP32_TO_FP16(sum); -} - -static void ggml_compute_forward_sum_bf16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_scalar(dst)); - - assert(src0->nb[0] == sizeof(ggml_bf16_t)); - - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) - GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) - - float sum = 0; - float row_sum = 0; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - ggml_vec_sum_bf16_ggf(ne00, - &row_sum, - (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03)); - sum += row_sum; - } - } - } - ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum); -} - -static void ggml_compute_forward_sum( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sum_f32(params, dst); - } break; - case GGML_TYPE_F16: - { - ggml_compute_forward_sum_f16(params, dst); - } break; - case GGML_TYPE_BF16: - { - ggml_compute_forward_sum_bf16(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_sum_rows - -static void ggml_compute_forward_sum_rows_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - GGML_ASSERT(src0->nb[0] == sizeof(float)); - GGML_ASSERT(dst->nb[0] == sizeof(float)); - - GGML_TENSOR_UNARY_OP_LOCALS - - GGML_ASSERT(ne0 == 1); - GGML_ASSERT(ne1 == ne01); - GGML_ASSERT(ne2 == ne02); - GGML_ASSERT(ne3 == ne03); - - for (int64_t i3 = 0; i3 < ne03; i3++) { - for (int64_t i2 = 0; i2 < ne02; i2++) { - for (int64_t i1 = 0; i1 < ne01; i1++) { - float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03); - float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3); - float row_sum = 0; - ggml_vec_sum_f32(ne00, &row_sum, src_row); - dst_row[0] = row_sum; - } - } - } -} - -static void ggml_compute_forward_sum_rows( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sum_rows_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_mean - -static void ggml_compute_forward_mean_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(src0->nb[0] == sizeof(float)); - - GGML_TENSOR_UNARY_OP_LOCALS - - assert(ne0 == 1); - assert(ne1 == ne01); - assert(ne2 == ne02); - assert(ne3 == ne03); - - UNUSED(ne0); - UNUSED(ne1); - UNUSED(ne2); - UNUSED(ne3); - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - ggml_vec_sum_f32(ne00, - (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), - (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03)); - - *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00; - } - } - } -} - -static void ggml_compute_forward_mean( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_mean_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_argmax - -static void ggml_compute_forward_argmax_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(src0->nb[0] == sizeof(float)); - assert(dst->nb[0] == sizeof(float)); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - - const size_t nb01 = src0->nb[1]; - const size_t nb0 = dst->nb[0]; - - for (int64_t i1 = 0; i1 < ne01; i1++) { - float * src = (float *) ((char *) src0->data + i1*nb01); - int32_t * dst_ = (int32_t *) ((char *) dst->data + i1*nb0); - int v = 0; - ggml_vec_argmax_f32(ne00, &v, src); - dst_[0] = v; - } -} - -static void ggml_compute_forward_argmax( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_argmax_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_repeat - -static void ggml_compute_forward_repeat_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - GGML_ASSERT(ggml_can_repeat(src0, dst)); - - GGML_TENSOR_UNARY_OP_LOCALS - - // guaranteed to be an integer due to the check in ggml_can_repeat - const int nr0 = (int)(ne0/ne00); - const int nr1 = (int)(ne1/ne01); - const int nr2 = (int)(ne2/ne02); - const int nr3 = (int)(ne3/ne03); - - // TODO: support for transposed / permuted tensors - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - // TODO: maybe this is not optimal? - for (int i3 = 0; i3 < nr3; i3++) { - for (int k3 = 0; k3 < ne03; k3++) { - for (int i2 = 0; i2 < nr2; i2++) { - for (int k2 = 0; k2 < ne02; k2++) { - for (int i1 = 0; i1 < nr1; i1++) { - for (int k1 = 0; k1 < ne01; k1++) { - for (int i0 = 0; i0 < nr0; i0++) { - ggml_vec_cpy_f32(ne00, - (float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0), - (float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01)); - } - } - } - } - } - } - } -} - -static void ggml_compute_forward_repeat_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - GGML_ASSERT(ggml_can_repeat(src0, dst)); - - GGML_TENSOR_UNARY_OP_LOCALS - - // guaranteed to be an integer due to the check in ggml_can_repeat - const int nr0 = (int)(ne0/ne00); - const int nr1 = (int)(ne1/ne01); - const int nr2 = (int)(ne2/ne02); - const int nr3 = (int)(ne3/ne03); - - // TODO: support for transposed / permuted tensors - GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - - // TODO: maybe this is not optimal? - for (int i3 = 0; i3 < nr3; i3++) { - for (int k3 = 0; k3 < ne03; k3++) { - for (int i2 = 0; i2 < nr2; i2++) { - for (int k2 = 0; k2 < ne02; k2++) { - for (int i1 = 0; i1 < nr1; i1++) { - for (int k1 = 0; k1 < ne01; k1++) { - for (int i0 = 0; i0 < nr0; i0++) { - ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0); - ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01); - // ggml_vec_cpy_f16(ne00, y, x) - for (int i = 0; i < ne00; ++i) { - y[i] = x[i]; - } - } - } - } - } - } - } - } -} - -static void ggml_compute_forward_repeat( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F16: - case GGML_TYPE_BF16: - case GGML_TYPE_I16: - { - ggml_compute_forward_repeat_f16(params, dst); - } break; - case GGML_TYPE_F32: - case GGML_TYPE_I32: - { - ggml_compute_forward_repeat_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_repeat_back - -static void ggml_compute_forward_repeat_back_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - GGML_ASSERT(ggml_can_repeat(dst, src0)); - - GGML_TENSOR_UNARY_OP_LOCALS - - // guaranteed to be an integer due to the check in ggml_can_repeat - const int nr0 = (int)(ne00/ne0); - const int nr1 = (int)(ne01/ne1); - const int nr2 = (int)(ne02/ne2); - const int nr3 = (int)(ne03/ne3); - - // TODO: support for transposed / permuted tensors - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - if (ggml_is_contiguous(dst)) { - ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); - } else { - for (int k3 = 0; k3 < ne3; k3++) { - for (int k2 = 0; k2 < ne2; k2++) { - for (int k1 = 0; k1 < ne1; k1++) { - ggml_vec_set_f32(ne0, - (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3), - 0); - } - } - } - } - - // TODO: maybe this is not optimal? - for (int i3 = 0; i3 < nr3; i3++) { - for (int k3 = 0; k3 < ne3; k3++) { - for (int i2 = 0; i2 < nr2; i2++) { - for (int k2 = 0; k2 < ne2; k2++) { - for (int i1 = 0; i1 < nr1; i1++) { - for (int k1 = 0; k1 < ne1; k1++) { - for (int i0 = 0; i0 < nr0; i0++) { - ggml_vec_acc_f32(ne0, - (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1), - (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00)); - } - } - } - } - } - } - } -} - -static void ggml_compute_forward_repeat_back( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_repeat_back_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_concat - -static void ggml_compute_forward_concat_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_BINARY_OP_LOCALS - - const int32_t dim = ggml_get_op_params_i32(dst, 0); - - GGML_ASSERT(dim >= 0 && dim < 4); - - int64_t o[4] = {0, 0, 0, 0}; - o[dim] = src0->ne[dim]; - - const float * x; - - // TODO: smarter multi-theading - for (int i3 = 0; i3 < ne3; i3++) { - for (int i2 = ith; i2 < ne2; i2 += nth) { - for (int i1 = 0; i1 < ne1; i1++) { - for (int i0 = 0; i0 < ne0; i0++) { - if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03); - } else { - x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13); - } - - float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); - - *y = *x; - } - } - } - } -} - -static void ggml_compute_forward_concat( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - case GGML_TYPE_I32: - { - ggml_compute_forward_concat_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_abs - -static void ggml_compute_forward_abs_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_abs_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_abs( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_abs_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_sgn - -static void ggml_compute_forward_sgn_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_sgn_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_sgn( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sgn_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_neg - -static void ggml_compute_forward_neg_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_neg_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_neg( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_neg_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_step - -static void ggml_compute_forward_step_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_step_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_step( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_step_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_tanh - -static void ggml_compute_forward_tanh_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_tanh_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_tanh( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_tanh_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_elu - -static void ggml_compute_forward_elu_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_elu_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_elu( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_elu_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_relu - -static void ggml_compute_forward_relu_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_relu_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_relu( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_relu_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_sigmoid - -static void ggml_compute_forward_sigmoid_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_sigmoid_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_sigmoid( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_sigmoid_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_gelu - -static void ggml_compute_forward_gelu_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_gelu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); - -#ifndef NDEBUG - for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; - UNUSED(x); - assert(!isnan(x)); - assert(!isinf(x)); - } -#endif - } -} - -static void ggml_compute_forward_gelu( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_gelu_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_gelu_quick - -static void ggml_compute_forward_gelu_quick_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_gelu_quick_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); - -#ifndef NDEBUG - for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; - UNUSED(x); - assert(!isnan(x)); - assert(!isinf(x)); - } -#endif - } -} - -static void ggml_compute_forward_gelu_quick( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_gelu_quick_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_silu - -static void ggml_compute_forward_silu_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_silu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); - -#ifndef NDEBUG - for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k]; - UNUSED(x); - assert(!isnan(x)); - assert(!isinf(x)); - } -#endif - } -} - -static void ggml_compute_forward_silu( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_silu_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} -// ggml_compute_forward_leaky_relu - -static void ggml_compute_forward_leaky_relu_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - float negative_slope; - memcpy(&negative_slope, dst->op_params, sizeof(float)); - - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < n; i++) { - ggml_vec_leaky_relu_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope); - } -} - -static void ggml_compute_forward_leaky_relu( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_leaky_relu_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_silu_back - -static void ggml_compute_forward_silu_back_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * grad = dst->src[1]; - - assert(ggml_is_contiguous_1(grad)); - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - assert(ggml_are_same_shape(src0, grad)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - ggml_vec_silu_backward_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1])), - (float *) ((char *) grad->data + i1*(grad->nb[1]))); - -#ifndef NDEBUG - for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; - UNUSED(x); - assert(!isnan(x)); - assert(!isinf(x)); - } -#endif - } -} - -static void ggml_compute_forward_silu_back( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_silu_back_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - - -static void ggml_compute_forward_hardswish_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_hardswish_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} -static void ggml_compute_forward_hardswish( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_hardswish_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -static void ggml_compute_forward_hardsigmoid_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_hardsigmoid_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_hardsigmoid( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_hardsigmoid_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -static void ggml_compute_forward_exp_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - ggml_vec_exp_f32(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_exp( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_exp_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - - -// ggml_compute_forward_norm - -static void ggml_compute_forward_norm_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_UNARY_OP_LOCALS - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - GGML_ASSERT(eps > 0.0f); - - // TODO: optimize - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ith; i01 < ne01; i01 += nth) { - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - - ggml_float sum = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - sum += (ggml_float)x[i00]; - } - - float mean = sum/ne00; - - float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - - ggml_float sum2 = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - float v = x[i00] - mean; - y[i00] = v; - sum2 += (ggml_float)(v*v); - } - - float variance = sum2/ne00; - const float scale = 1.0f/sqrtf(variance + eps); - - ggml_vec_scale_f32(ne00, y, scale); - } - } - } -} - -static void ggml_compute_forward_norm( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_norm_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_group_rms_norm - -static void ggml_compute_forward_rms_norm_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_UNARY_OP_LOCALS - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - GGML_ASSERT(eps > 0.0f); - - // TODO: optimize - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ith; i01 < ne01; i01 += nth) { - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - - ggml_float sum = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - sum += (ggml_float)(x[i00] * x[i00]); - } - - const float mean = sum/ne00; - - float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - - memcpy(y, x, ne00 * sizeof(float)); - // for (int i00 = 0; i00 < ne00; i00++) { - // y[i00] = x[i00]; - // } - - const float scale = 1.0f/sqrtf(mean + eps); - - ggml_vec_scale_f32(ne00, y, scale); - } - } - } -} - -static void ggml_compute_forward_rms_norm( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_rms_norm_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -static void ggml_compute_forward_rms_norm_back_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1)); - - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_BINARY_OP_LOCALS - - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - - // TODO: optimize - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ith; i01 < ne01; i01 += nth) { - // src1 is same shape as src0 => same indices - const int64_t i11 = i01; - const int64_t i12 = i02; - const int64_t i13 = i03; - - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13); - - ggml_float sum_xx = 0.0; - ggml_float sum_xdz = 0.0; - - for (int64_t i00 = 0; i00 < ne00; i00++) { - sum_xx += (ggml_float)(x[i00] * x[i00]); - sum_xdz += (ggml_float)(x[i00] * dz[i00]); - } - - //const float mean = (float)(sum_xx)/ne00; - const float mean_eps = (float)(sum_xx)/ne00 + eps; - const float sum_eps = (float)(sum_xx) + eps*ne00; - //const float mean_xdz = (float)(sum_xdz)/ne00; - // we could cache rms from forward pass to improve performance. - // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms. - //const float rms = sqrtf(mean_eps); - const float rrms = 1.0f / sqrtf(mean_eps); - //const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3) - - { - // z = rms_norm(x) - // - // rms_norm(src0) = - // scale( - // src0, - // div( - // 1, - // sqrt( - // add( - // scale( - // sum( - // sqr( - // src0)), - // (1.0/N)), - // eps)))); - - // postorder: - // ## op args grad - // 00 param src0 grad[#00] - // 01 const 1 - // 02 sqr (#00) grad[#02] - // 03 sum (#02) grad[#03] - // 04 const 1/N - // 05 scale (#03, #04) grad[#05] - // 06 const eps - // 07 add (#05, #06) grad[#07] - // 08 sqrt (#07) grad[#08] - // 09 div (#01,#08) grad[#09] - // 10 scale (#00,#09) grad[#10] - // - // backward pass, given grad[#10] - // #10: scale - // grad[#00] += scale(grad[#10],#09) - // grad[#09] += sum(mul(grad[#10],#00)) - // #09: div - // grad[#08] += neg(mul(grad[#09], div(#09,#08))) - // #08: sqrt - // grad[#07] += mul(grad[#08], div(0.5, #08)) - // #07: add - // grad[#05] += grad[#07] - // #05: scale - // grad[#03] += scale(grad[#05],#04) - // #03: sum - // grad[#02] += repeat(grad[#03], #02) - // #02: - // grad[#00] += scale(mul(#00, grad[#02]), 2.0) - // - // substitute and simplify: - // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0) - // grad[#02] = repeat(grad[#03], #02) - // grad[#02] = repeat(scale(grad[#05],#04), #02) - // grad[#02] = repeat(scale(grad[#07],#04), #02) - // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02) - // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02) - // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02) - // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02) - // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02) - // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02) - // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02) - // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0) - // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0) - // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0) - // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N))) - // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N)) - // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N)) - // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N)) - // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps)) - // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps))) - // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps)) - // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps)) - // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps)) - // a = b*c + d*e - // a = b*c*f/f + d*e*f/f - // a = (b*c*f + d*e*f)*(1/f) - // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c)) - // a = (b + d*e/c)*c - // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps) - // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms - // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms - // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms - // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms - // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms - // a = (dz + x*div(-mean_xdz,mean_eps))*rrms - // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms) - // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) - // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) - } - // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms) - // post-order: - // dx := x - // dx := scale(dx,-mean_xdz/mean_eps) - // dx := add(dx, dz) - // dx := scale(dx, rrms) - float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - - ggml_vec_cpy_f32 (ne00, dx, x); - // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps); - ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps); - ggml_vec_acc_f32 (ne00, dx, dz); - ggml_vec_scale_f32(ne00, dx, rrms); - } - } - } -} - -static void ggml_compute_forward_rms_norm_back( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_rms_norm_back_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_group_norm - -static void ggml_compute_forward_group_norm_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_UNARY_OP_LOCALS - - // TODO: optimize - - float eps; - memcpy(&eps, dst->op_params + 1, sizeof(float)); - - int n_channels = src0->ne[2]; - int n_groups = dst->op_params[0]; - int n_channels_per_group = (n_channels + n_groups - 1) / n_groups; - for (int i = ith; i < n_groups; i += nth) { - int start = i * n_channels_per_group; - int end = start + n_channels_per_group; - if (end > n_channels) { - end = n_channels; - } - int step = end - start; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - ggml_float sum = 0.0; - for (int64_t i02 = start; i02 < end; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03); - - ggml_float sumr = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - sumr += (ggml_float)x[i00]; - } - sum += sumr; - } - } - const float mean = sum / (ne00 * ne01 * step); - - ggml_float sum2 = 0.0; - for (int64_t i02 = start; i02 < end; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03); - - float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3); - - ggml_float sumr = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - float v = x[i00] - mean; - y[i00] = v; - sumr += (ggml_float)(v * v); - } - sum2 += sumr; - } - } - const float variance = sum2 / (ne00 * ne01 * step); - const float scale = 1.0f / sqrtf(variance + eps); - - for (int64_t i02 = start; i02 < end; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3); - ggml_vec_scale_f32(ne00, y, scale); - } - } - } - } -} - -static void ggml_compute_forward_group_norm( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_group_norm_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_mul_mat - -static void ggml_compute_forward_mul_mat_one_chunk( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const int64_t num_rows_per_vec_dot, - const int64_t ir0_start, - const int64_t ir0_end, - const int64_t ir1_start, - const int64_t ir1_end) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_TENSOR_BINARY_OP_LOCALS - - const enum ggml_type type = src0->type; - - const bool src1_cont = ggml_is_contiguous(src1); - - ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; - enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; - - // broadcast factors - const int64_t r2 = ne12 / ne02; - const int64_t r3 = ne13 / ne03; - - //printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end); - - // threads with no work simply yield (not sure if it helps) - if (ir0_start >= ir0_end || ir1_start >= ir1_end) { - return; - } - - const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; - const size_t row_size = ggml_row_size(vec_dot_type, ne10); - - assert(ne12 % ne02 == 0); - assert(ne13 % ne03 == 0); - - // block-tiling attempt - const int64_t blck_0 = 16; - const int64_t blck_1 = 16; - - const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11; - - // attempt to reduce false-sharing (does not seem to make a difference) - // 16 * 2, accounting for mmla kernels - float tmp[32]; - - for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { - for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { - for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) { - const int64_t i13 = (ir1 / (ne12 * ne1)); - const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1; - const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); - - // broadcast src0 into src1 - const int64_t i03 = i13 / r3; - const int64_t i02 = i12 / r2; - - const int64_t i1 = i11; - const int64_t i2 = i12; - const int64_t i3 = i13; - - const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03); - - // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides - // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using - // the original src1 data pointer, so we should index using the indices directly - // TODO: this is a bit of a hack, we should probably have a better way to handle this - const char * src1_col = (const char*)wdata + - (src1_cont || src1->type != vec_dot_type - ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size - : (i11 * nb11 + i12 * nb12 + i13 * nb13)); - float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); - - //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) { - // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); - //} - - for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { - vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); - } - - for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) { - memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float)); - } - } - } - } -} - -static void ggml_compute_forward_mul_mat( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const enum ggml_type type = src0->type; - - enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; - ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float; - ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat; - int64_t const vec_dot_num_rows = type_traits[type].nrows; - int64_t const matmul_num_cols = type_traits[type].ncols; - int64_t const blck_size_interleave = type_traits[type].blck_size_interleave; - ggml_gemv_t const gemv = type_traits[type].gemv; - ggml_gemm_t const gemm = type_traits[type].gemm; - - GGML_ASSERT(ne0 == ne01); - GGML_ASSERT(ne1 == ne11); - GGML_ASSERT(ne2 == ne12); - GGML_ASSERT(ne3 == ne13); - - // we don't support permuted src0 or src1 - GGML_ASSERT(nb00 == ggml_type_size(type)); - GGML_ASSERT(nb10 == ggml_type_size(src1->type)); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - // nb01 >= nb00 - src0 is not transposed - // compute by src0 rows - -#if GGML_USE_LLAMAFILE - // broadcast factors - const int64_t r2 = ne12 / ne02; - const int64_t r3 = ne13 / ne03; - - const bool src1_cont = ggml_is_contiguous(src1); - - if (src1_cont) { - for (int64_t i13 = 0; i13 < ne13; i13++) - for (int64_t i12 = 0; i12 < ne12; i12++) - if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), - (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, - nb01/ggml_type_size(src0->type), - (const char *)src1->data + i12*nb12 + i13*nb13, - nb11/ggml_type_size(src1->type), - (char *)dst->data + i12*nb2 + i13*nb3, - nb1/ggml_type_size(dst->type), - ith, nth, - src0->type, - src1->type, - dst->type)) - goto UseGgmlGemm1; - return; - } -UseGgmlGemm1:; -#endif - - if (src1->type != vec_dot_type) { - char * wdata = params->wdata; - - const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); - const size_t nbw2 = nbw1*ne11; - const size_t nbw3 = nbw2*ne12; - - assert(params->wsize >= ne13*nbw3); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - for (int64_t i13 = 0; i13 < ne13; ++i13) { - for (int64_t i12 = 0; i12 < ne12; ++i12) { - int64_t i11_processed = 0; - if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) { - for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { - from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), - (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), - 4, ne10, blck_size_interleave); - } - i11_processed = ne11 - ne11 % 4; - } - for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) { - from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), - (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), - ne10); - } - } - } - } - - if (ith == 0) { - // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. - atomic_store_explicit(¶ms->threadpool->current_chunk, nth, memory_order_relaxed); - } - - ggml_barrier(params->threadpool); - -#if GGML_USE_LLAMAFILE - if (src1->type != vec_dot_type) { - const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; - const size_t row_size = ggml_row_size(vec_dot_type, ne10); - - for (int64_t i13 = 0; i13 < ne13; i13++) - for (int64_t i12 = 0; i12 < ne12; i12++) - if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), - (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, - nb01/ggml_type_size(src0->type), - (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, - row_size/ggml_type_size(vec_dot_type), - (char *)dst->data + i12*nb2 + i13*nb3, - nb1/ggml_type_size(dst->type), - ith, nth, - src0->type, - vec_dot_type, - dst->type)) - goto UseGgmlGemm2; - return; - } -UseGgmlGemm2:; -#endif - - // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers) - const int64_t nr0 = ne0; - - // This is the size of the rest of the dimensions of the result - const int64_t nr1 = ne1 * ne2 * ne3; - - // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols - int64_t num_rows_per_vec_dot = vec_dot_num_rows; - // TODO: currently the mmla kernels support only even numbered rows/cols. - // this check can be removed once they are extended to support odd numbered rows/cols too - if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) { - num_rows_per_vec_dot = 1; - } - - // Now select a reasonable chunk size. - int chunk_size = 16; - - // We need to step up the size if it's small - if (nr0 == 1 || nr1 == 1) { - chunk_size = 64; - } - - // distribute the work across the inner or outer loop based on which one is larger - // The number of chunks in the 0/1 dim. - // CEIL(nr0/chunk_size) - int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size; - int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size; - - // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread. - // Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915 - // In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that. - if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) { - // distribute the thread work across the inner or outer loop based on which one is larger - nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows - nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows - } - - // The number of elements in each chunk - const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; - const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; - - if ((ggml_n_dims(src0) == 2) && gemv) { - const void * src1_wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; - const size_t src1_col_stride = ggml_is_contiguous(src1) || src1->type != vec_dot_type ? ggml_row_size(vec_dot_type, ne10) : nb11; - int64_t src0_start = (ith * ne01) / nth; - int64_t src0_end = ((ith + 1) * ne01) / nth; - src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start; - src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end; - if (src0_start >= src0_end) return; - - // If there are more than three rows in src1, use gemm; otherwise, use gemv. - if (gemm && (ne11 > 3)) { - gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01, - (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start); - } - for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) { - gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01, - (const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1, - src0_end - src0_start); - } - return; - } - - // The first chunk comes from our thread_id, the rest will get auto-assigned. - int current_chunk = ith; - - while (current_chunk < nchunk0 * nchunk1) { - const int64_t ith0 = current_chunk % nchunk0; - const int64_t ith1 = current_chunk / nchunk0; - - const int64_t ir0_start = dr0 * ith0; - const int64_t ir0_end = MIN(ir0_start + dr0, nr0); - - const int64_t ir1_start = dr1 * ith1; - const int64_t ir1_end = MIN(ir1_start + dr1, nr1); - - ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end); - - if (nth >= nchunk0 * nchunk1) { - break; - } - - current_chunk = atomic_fetch_add_explicit(¶ms->threadpool->current_chunk, 1, memory_order_relaxed); - } -} - -// ggml_compute_forward_mul_mat_id - -static void ggml_compute_forward_mul_mat_id( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - const struct ggml_tensor * ids = dst->src[2]; - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const enum ggml_type type = src0->type; - - const bool src1_cont = ggml_is_contiguous(src1); - - ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; - enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; - ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float; - int64_t const matmul_num_cols = type_traits[type].ncols; - ggml_gemv_t const gemv = type_traits[type].gemv; - - // we don't support permuted src0 or src1 - GGML_ASSERT(nb00 == ggml_type_size(type)); - GGML_ASSERT(nb10 == ggml_type_size(src1->type)); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - // row groups - const int n_ids = ids->ne[0]; // n_expert_used - const int n_as = ne02; // n_expert - - char * wdata_src1_end = (src1->type == vec_dot_type) ? - (char *) params->wdata : - (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t)); - - struct mmid_row_mapping { - int32_t i1; - int32_t i2; - }; - - int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] - struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11] - - if (src1->type != vec_dot_type) { - char * wdata = params->wdata; - - const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); - const size_t nbw2 = nbw1*ne11; - const size_t nbw3 = nbw2*ne12; - - assert(params->wsize >= ne13*nbw3); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - for (int64_t i13 = 0; i13 < ne13; ++i13) { - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = ith; i11 < ne11; i11 += nth) { - from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), - (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), - ne10); - } - } - } - } - -#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)] - - if (ith == 0) { - // initialize matrix_row_counts - memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); - - // group rows by src0 matrix - for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { - for (int id = 0; id < n_ids; ++id) { - const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]); - - assert(i02 >= 0 && i02 < n_as); - - MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1}; - matrix_row_counts[i02] += 1; - } - } - } - - ggml_barrier(params->threadpool); - - // compute each matrix multiplication in sequence - for (int cur_a = 0; cur_a < n_as; ++cur_a) { - const int64_t cne1 = matrix_row_counts[cur_a]; - - if (cne1 == 0) { - continue; - } - - const char * src0_cur = (const char *) src0->data + cur_a*nb02; - - const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; - const size_t row_size = ggml_row_size(vec_dot_type, ne10); - - const int64_t nr0 = ne01; // src0 rows - const int64_t nr1 = cne1; // src1 rows - - if (((ggml_n_dims(src0) - 1) == 2) && gemv) { - int64_t src0_cur_start = (ith * ne01) / nth; - int64_t src0_cur_end = ((ith + 1) * ne01) / nth; - src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start; - src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end; - if (src0_cur_start >= src0_cur_end) return; - - for (int ir1 = 0; ir1 < nr1; ir1++) { - struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1); - const int id = row_mapping.i1; // selected expert index - - const int64_t i11 = id % ne11; - const int64_t i12 = row_mapping.i2; // row index in src1 - - const int64_t i1 = id; // selected expert index - const int64_t i2 = i12; // row - - const char * src1_col = (const char *) wdata + - (src1_cont || src1->type != vec_dot_type - ? (i11 + i12 * ne11) * row_size - : (i11 * nb11 + i12 * nb12)); - - gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, - (const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start); - } - continue; - } - - // distribute the thread work across the inner or outer loop based on which one is larger - - const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows - const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows - - const int64_t ith0 = ith % nth0; - const int64_t ith1 = ith / nth0; - - const int64_t dr0 = (nr0 + nth0 - 1)/nth0; - const int64_t dr1 = (nr1 + nth1 - 1)/nth1; - - const int64_t ir010 = dr0*ith0; - const int64_t ir011 = MIN(ir010 + dr0, nr0); - - const int64_t ir110 = dr1*ith1; - const int64_t ir111 = MIN(ir110 + dr1, nr1); - - // threads with no work simply yield (not sure if it helps) - //if (ir010 >= ir011 || ir110 >= ir111) { - // sched_yield(); - // continue; - //} - - // block-tiling attempt - const int64_t blck_0 = 16; - const int64_t blck_1 = 16; - - // attempt to reduce false-sharing (does not seem to make a difference) - float tmp[16]; - - for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { - for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { - for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { - const int64_t _i12 = ir1; // logical row index for this expert - - struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12); - const int id = row_mapping.i1; // selected expert index - - const int64_t i11 = id % ne11; - const int64_t i12 = row_mapping.i2; // row index in src1 - - const int64_t i1 = id; // selected expert index - const int64_t i2 = i12; // row - - // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides - // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using - // the original src1 data pointer, so we should index using the indices directly - // TODO: this is a bit of a hack, we should probably have a better way to handle this - const char * src1_col = (const char *) wdata + - (src1_cont || src1->type != vec_dot_type - ? (i11 + i12*ne11)*row_size - : (i11*nb11 + i12*nb12)); - - float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2)); - - //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { - // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); - //} - - for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { - vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1); - } - - memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); - } - } - } - } - -#undef MMID_MATRIX_ROW -} - -// ggml_compute_forward_out_prod - -static void ggml_compute_forward_out_prod_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - GGML_ASSERT(ne0 == ne00); - GGML_ASSERT(ne1 == ne10); - GGML_ASSERT(ne2 == ne02); - GGML_ASSERT(ne02 == ne12); - GGML_ASSERT(ne3 == ne13); - GGML_ASSERT(ne03 == ne13); - - // we don't support permuted src0 or src1 - GGML_ASSERT(nb00 == sizeof(float)); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - // GGML_ASSERT(nb0 <= nb1); - // GGML_ASSERT(nb1 <= nb2); - // GGML_ASSERT(nb2 <= nb3); - - // nb01 >= nb00 - src0 is not transposed - // compute by src0 rows - - if (ith == 0) { - ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); - } - ggml_barrier(params->threadpool); - - // dst[:,:,:,:] = 0 - // for i2,i3: - // for i1: - // for i01: - // for i0: - // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] - - // parallelize by last three dimensions - - // total rows in dst - const int64_t nr = ne1*ne2*ne3; - - // rows per thread - const int64_t dr = (nr + nth - 1)/nth; - - // row range for this thread - const int64_t ir0 = dr*ith; - const int64_t ir1 = MIN(ir0 + dr, nr); - - // block-tiling attempt - const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32); - const int64_t blck_1 = 16; - - for (int64_t bir = ir0; bir < ir1; bir += blck_1) { - const int64_t bir1 = MIN(bir + blck_1, ir1); - for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) { - const int64_t bne01 = MIN(bi01 + blck_0, ne01); - for (int64_t ir = bir; ir < bir1; ++ir) { - // dst indices - const int64_t i3 = ir/(ne2*ne1); - const int64_t i2 = (ir - i3*ne2*ne1)/ne1; - const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); - - const int64_t i02 = i2; - const int64_t i03 = i3; - - //const int64_t i10 = i1; - const int64_t i12 = i2; - const int64_t i13 = i3; - -#if GGML_VEC_MAD_UNROLL > 2 - const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL); - for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) { - const int64_t i11 = i01; - - float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); - float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); - - ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1); - } - for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) { - const int64_t i11 = i01; - - float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); - float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); - - ggml_vec_mad_f32(ne0, d, s0, *s1); - } -#else - for (int64_t i01 = bi01; i01 < bne01; ++i01) { - const int64_t i11 = i01; - - float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); - float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); - - ggml_vec_mad_f32(ne0, d, s0, *s1); - } -#endif - } - } - } -} - -static void ggml_compute_forward_out_prod_q_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_TENSOR_BINARY_OP_LOCALS; - - const int ith = params->ith; - const int nth = params->nth; - - const enum ggml_type type = src0->type; - ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; - - GGML_ASSERT(ne02 == ne12); - GGML_ASSERT(ne03 == ne13); - GGML_ASSERT(ne2 == ne12); - GGML_ASSERT(ne3 == ne13); - - // we don't support permuted src0 dim0 - GGML_ASSERT(nb00 == ggml_type_size(type)); - - // dst dim0 cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - // GGML_ASSERT(nb0 <= nb1); - // GGML_ASSERT(nb1 <= nb2); - // GGML_ASSERT(nb2 <= nb3); - - GGML_ASSERT(ne0 == ne00); - GGML_ASSERT(ne1 == ne10); - GGML_ASSERT(ne2 == ne02); - GGML_ASSERT(ne3 == ne03); - - // nb01 >= nb00 - src0 is not transposed - // compute by src0 rows - - if (ith == 0) { - ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); - } - ggml_barrier(params->threadpool); - - // parallelize by last three dimensions - - // total rows in dst - const int64_t nr = ne1*ne2*ne3; - - // rows per thread - const int64_t dr = (nr + nth - 1)/nth; - - // row range for this thread - const int64_t ir0 = dr*ith; - const int64_t ir1 = MIN(ir0 + dr, nr); - - // dst[:,:,:,:] = 0 - // for i2,i3: - // for i1: - // for i01: - // for i0: - // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] - - float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith; - - for (int64_t ir = ir0; ir < ir1; ++ir) { - // dst indices - const int64_t i3 = ir/(ne2*ne1); - const int64_t i2 = (ir - i3*ne2*ne1)/ne1; - const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); - - const int64_t i02 = i2; - const int64_t i03 = i3; - - //const int64_t i10 = i1; - const int64_t i12 = i2; - const int64_t i13 = i3; - - for (int64_t i01 = 0; i01 < ne01; ++i01) { - const int64_t i11 = i01; - - float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); - float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); - float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); - - dequantize_row_q(s0, wdata, ne0); - ggml_vec_mad_f32(ne0, d, wdata, *s1); - } - } -} - -static void ggml_compute_forward_out_prod( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - case GGML_TYPE_TQ1_0: - case GGML_TYPE_TQ2_0: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_Q4_0_4_4: - case GGML_TYPE_Q4_0_4_8: - case GGML_TYPE_Q4_0_8_8: - { - ggml_compute_forward_out_prod_q_f32(params, dst); - } break; - case GGML_TYPE_F16: - { - GGML_ABORT("fatal error"); // todo - // ggml_compute_forward_out_prod_f16_f32(params, dst); - } - case GGML_TYPE_F32: - { - ggml_compute_forward_out_prod_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_scale - -static void ggml_compute_forward_scale_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - // scale factor - float v; - memcpy(&v, dst->op_params, sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - const size_t nb01 = src0->nb[1]; - - const size_t nb1 = dst->nb[1]; - - for (int i1 = ir0; i1 < ir1; i1++) { - if (dst->data != src0->data) { - // src0 is same shape as dst => same indices - memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); - } - ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v); - } -} - -static void ggml_compute_forward_scale( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_scale_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_set - -static void ggml_compute_forward_set_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); - - // view src0 and dst with these strides and data offset inbytes during set - // nb0 is implicitly element_size because src0 and dst are contiguous - size_t nb1 = ((int32_t *) dst->op_params)[0]; - size_t nb2 = ((int32_t *) dst->op_params)[1]; - size_t nb3 = ((int32_t *) dst->op_params)[2]; - size_t offset = ((int32_t *) dst->op_params)[3]; - bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - - if (!inplace) { - if (params->ith == 0) { - // memcpy needs to be synchronized across threads to avoid race conditions. - // => do it in INIT phase - memcpy( - ((char *) dst->data), - ((char *) src0->data), - ggml_nbytes(dst)); - } - ggml_barrier(params->threadpool); - } - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(src1); - const int nc = src1->ne[0]; - - GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) - GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) - - // src0 and dst as viewed during set - const size_t nb0 = ggml_element_size(src0); - - const int im0 = (ne10 == 0 ? 0 : ne10-1); - const int im1 = (ne11 == 0 ? 0 : ne11-1); - const int im2 = (ne12 == 0 ? 0 : ne12-1); - const int im3 = (ne13 == 0 ? 0 : ne13-1); - - GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst)); - - GGML_ASSERT(nb10 == sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // src0 and dst are viewed with shape of src1 and offset - // => same indices - const int i3 = ir/(ne12*ne11); - const int i2 = (ir - i3*ne12*ne11)/ne11; - const int i1 = (ir - i3*ne12*ne11 - i2*ne11); - - ggml_vec_cpy_f32(nc, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); - } -} - -static void ggml_compute_forward_set( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_set_f32(params, dst); - } break; - case GGML_TYPE_F16: - case GGML_TYPE_BF16: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q8_1: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - case GGML_TYPE_TQ1_0: - case GGML_TYPE_TQ2_0: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_Q4_0_4_4: - case GGML_TYPE_Q4_0_4_8: - case GGML_TYPE_Q4_0_8_8: - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_cpy - -static void ggml_compute_forward_cpy( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - ggml_compute_forward_dup(params, dst); -} - -// ggml_compute_forward_cont - -static void ggml_compute_forward_cont( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - ggml_compute_forward_dup(params, dst); -} - -// ggml_compute_forward_reshape - -static void ggml_compute_forward_reshape( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - // NOP - UNUSED(params); - UNUSED(dst); -} - -// ggml_compute_forward_view - -static void ggml_compute_forward_view( - const struct ggml_compute_params * params, - const struct ggml_tensor * dst) { - // NOP - UNUSED(params); - UNUSED(dst); -} - -// ggml_compute_forward_permute - -static void ggml_compute_forward_permute( - const struct ggml_compute_params * params, - const struct ggml_tensor * dst) { - // NOP - UNUSED(params); - UNUSED(dst); -} - -// ggml_compute_forward_transpose - -static void ggml_compute_forward_transpose( - const struct ggml_compute_params * params, - const struct ggml_tensor * dst) { - // NOP - UNUSED(params); - UNUSED(dst); -} - -// ggml_compute_forward_get_rows - -static void ggml_compute_forward_get_rows_q( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_TENSOR_BINARY_OP_LOCALS - - const int64_t nc = ne00; - const int64_t nr = ggml_nelements(src1); - - const enum ggml_type type = src0->type; - ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; - - assert(ne0 == nc); - assert(ne02 == ne11); - assert(nb00 == ggml_type_size(type)); - assert(ggml_nrows(dst) == nr); - - const int ith = params->ith; - const int nth = params->nth; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int64_t i = ir0; i < ir1; ++i) { - const int64_t i12 = i/(ne11*ne10); - const int64_t i11 = (i - i12*ne11*ne10)/ne10; - const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); - const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - - GGML_ASSERT(i01 >= 0 && i01 < ne01); - - dequantize_row_q( - (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); - } -} - -static void ggml_compute_forward_get_rows_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_TENSOR_BINARY_OP_LOCALS - - const int64_t nc = ne00; - const int64_t nr = ggml_nelements(src1); - - assert(ne0 == nc); - assert(ne02 == ne11); - assert(nb00 == sizeof(ggml_fp16_t)); - assert(ggml_nrows(dst) == nr); - - const int ith = params->ith; - const int nth = params->nth; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int64_t i = ir0; i < ir1; ++i) { - const int64_t i12 = i/(ne11*ne10); - const int64_t i11 = (i - i12*ne11*ne10)/ne10; - const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); - const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - - GGML_ASSERT(i01 >= 0 && i01 < ne01); - - ggml_fp16_to_fp32_row( - (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); - } -} - -static void ggml_compute_forward_get_rows_bf16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_TENSOR_BINARY_OP_LOCALS - - const int64_t nc = ne00; - const int64_t nr = ggml_nelements(src1); - - assert(ne0 == nc); - assert(ne02 == ne11); - assert(nb00 == sizeof(ggml_bf16_t)); - assert(ggml_nrows(dst) == nr); - - const int ith = params->ith; - const int nth = params->nth; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int64_t i = ir0; i < ir1; ++i) { - const int64_t i12 = i/(ne11*ne10); - const int64_t i11 = (i - i12*ne11*ne10)/ne10; - const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); - const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - - GGML_ASSERT(i01 >= 0 && i01 < ne01); - - ggml_bf16_to_fp32_row( - (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); - } -} - -static void ggml_compute_forward_get_rows_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_TENSOR_BINARY_OP_LOCALS - - const int64_t nc = ne00; - const int64_t nr = ggml_nelements(src1); - - assert(ne0 == nc); - assert(ne02 == ne11); - assert(nb00 == sizeof(float)); - assert(ggml_nrows(dst) == nr); - - const int ith = params->ith; - const int nth = params->nth; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int64_t i = ir0; i < ir1; ++i) { - const int64_t i12 = i/(ne11*ne10); - const int64_t i11 = (i - i12*ne11*ne10)/ne10; - const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); - const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - - GGML_ASSERT(i01 >= 0 && i01 < ne01); - - ggml_vec_cpy_f32(nc, - (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), - (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03)); - } -} - -static void ggml_compute_forward_get_rows( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q8_1: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - case GGML_TYPE_TQ1_0: - case GGML_TYPE_TQ2_0: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_Q4_0_4_4: - case GGML_TYPE_Q4_0_4_8: - case GGML_TYPE_Q4_0_8_8: - { - ggml_compute_forward_get_rows_q(params, dst); - } break; - case GGML_TYPE_F16: - { - ggml_compute_forward_get_rows_f16(params, dst); - } break; - case GGML_TYPE_BF16: - { - ggml_compute_forward_get_rows_bf16(params, dst); - } break; - case GGML_TYPE_F32: - case GGML_TYPE_I32: - { - ggml_compute_forward_get_rows_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } - - //static bool first = true; - //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); - //if (first) { - // first = false; - //} else { - // for (int k = 0; k < dst->ne[1]; ++k) { - // for (int j = 0; j < dst->ne[0]/16; ++j) { - // for (int i = 0; i < 16; ++i) { - // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); - // } - // printf("\n"); - // } - // printf("\n"); - // } - // printf("\n"); - // exit(0); - //} -} - -// ggml_compute_forward_get_rows_back - -static void ggml_compute_forward_get_rows_back_f32_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - if (params->ith != 0) { - return; - } - - GGML_ASSERT(ggml_is_contiguous(dst)); - - // ggml_compute_forward_dup_same_cont(params, opt0, dst); - - memset(dst->data, 0, ggml_nbytes(dst)); - - const int nc = src0->ne[0]; - const int nr = ggml_nelements(src1); - - GGML_ASSERT( dst->ne[0] == nc); - GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t)); - - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t *) src1->data)[i]; - - for (int j = 0; j < nc; ++j) { - ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j]; - ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_FP16_TO_FP32(v); - } - } -} - -static void ggml_compute_forward_get_rows_back_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - if (params->ith != 0) { - return; - } - - GGML_ASSERT(ggml_is_contiguous(dst)); - - // ggml_compute_forward_dup_same_cont(params, opt0, dst); - - memset(dst->data, 0, ggml_nbytes(dst)); - - const int nc = src0->ne[0]; - const int nr = ggml_nelements(src1); - - GGML_ASSERT( dst->ne[0] == nc); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - for (int i = 0; i < nr; ++i) { - const int r = ((int32_t *) src1->data)[i]; - - ggml_vec_add_f32(nc, - (float *) ((char *) dst->data + r*dst->nb[1]), - (float *) ((char *) dst->data + r*dst->nb[1]), - (float *) ((char *) src0->data + i*src0->nb[1])); - } -} - -static void ggml_compute_forward_get_rows_back( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_get_rows_back_f32_f16(params, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_get_rows_back_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } - - //static bool first = true; - //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]); - //if (first) { - // first = false; - //} else { - // for (int k = 0; k < dst->ne[1]; ++k) { - // for (int j = 0; j < dst->ne[0]/16; ++j) { - // for (int i = 0; i < 16; ++i) { - // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]); - // } - // printf("\n"); - // } - // printf("\n"); - // } - // printf("\n"); - // exit(0); - //} -} - -// ggml_compute_forward_diag - -static void ggml_compute_forward_diag_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - // TODO: handle transposed/permuted matrices - - GGML_TENSOR_UNARY_OP_LOCALS - - GGML_ASSERT(ne00 == ne0); - GGML_ASSERT(ne00 == ne1); - GGML_ASSERT(ne01 == 1); - GGML_ASSERT(ne02 == ne2); - GGML_ASSERT(ne03 == ne3); - - GGML_ASSERT(nb00 == sizeof(float)); - GGML_ASSERT(nb0 == sizeof(float)); - - for (int i3 = 0; i3 < ne3; i3++) { - for (int i2 = 0; i2 < ne2; i2++) { - for (int i1 = 0; i1 < ne1; i1++) { - float * d = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02); - for (int i0 = 0; i0 < i1; i0++) { - d[i0] = 0; - } - d[i1] = s[i1]; - for (int i0 = i1+1; i0 < ne0; i0++) { - d[i0] = 0; - } - } - } - } -} - -static void ggml_compute_forward_diag( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_diag_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_diag_mask_inf - -static void ggml_compute_forward_diag_mask_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const float value) { - - const struct ggml_tensor * src0 = dst->src[0]; - - const int ith = params->ith; - const int nth = params->nth; - - const int n_past = ((int32_t *) dst->op_params)[0]; - const bool inplace = src0->data == dst->data; - - GGML_ASSERT(n_past >= 0); - - if (!inplace) { - if (ith == 0) { - // memcpy needs to be synchronized across threads to avoid race conditions. - // => do it in INIT phase - GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); - memcpy( - ((char *) dst->data), - ((char *) src0->data), - ggml_nbytes(dst)); - } - ggml_barrier(params->threadpool); - } - - // TODO: handle transposed/permuted matrices - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - const int nr = src0->ne[1]; - const int nz = n/nr; - - GGML_ASSERT( dst->nb[0] == sizeof(float)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - for (int k = 0; k < nz; k++) { - for (int j = ith; j < nr; j += nth) { - for (int i = n_past; i < nc; i++) { - if (i > n_past + j) { - *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value; - } - } - } - } -} - -static void ggml_compute_forward_diag_mask_inf( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -static void ggml_compute_forward_diag_mask_zero( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_diag_mask_f32(params, dst, 0); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_soft_max - -static void ggml_compute_forward_soft_max_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - assert(ggml_is_contiguous(dst)); - assert(ggml_are_same_shape(src0, dst)); - - float scale = 1.0f; - float max_bias = 0.0f; - - memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - - // TODO: handle transposed/permuted matrices - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_UNARY_OP_LOCALS - - //const int64_t ne11 = src1 ? src1->ne[1] : 1; - - // TODO: is this supposed to be ceil instead of floor? - // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 - const uint32_t n_head = ne02; - const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith; - - const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); - - for (int i1 = ir0; i1 < ir1; i1++) { - // ALiBi - const uint32_t h = (i1/ne01)%ne02; // head - const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; - - float * sp = (float *)((char *) src0->data + i1*src0->nb[1]); - float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); - - // broadcast the mask across rows - ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; - - ggml_vec_cpy_f32 (nc, wp, sp); - ggml_vec_scale_f32(nc, wp, scale); - if (mp_f32) { - if (use_f16) { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]); - } - } else { - for (int i = 0; i < nc; ++i) { - wp[i] += slope*mp_f32[i]; - } - } - } - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(wp[i])); - } -#endif - - float max = -INFINITY; - ggml_vec_max_f32(nc, &max, wp); - - ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max); - assert(sum > 0.0); - - sum = 1.0/sum; - ggml_vec_scale_f32(nc, dp, sum); - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - assert(!isnan(dp[i])); - assert(!isinf(dp[i])); - } -#endif - } -} - -static void ggml_compute_forward_soft_max( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_soft_max_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - - -// ggml_compute_forward_soft_max_back - -static void ggml_compute_forward_soft_max_back_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_are_same_shape(src1, dst)); - - // TODO: handle transposed/permuted matrices - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - float *dy = (float *)((char *) src0->data + i1*src0->nb[1]); - float *y = (float *)((char *) src1->data + i1*src1->nb[1]); - float *dx = (float *)((char *) dst->data + i1*dst->nb[1]); - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(dy[i])); - assert(!isnan(y[i])); - } -#endif - // Jii = yi - yi*yi - // Jij = -yi*yj - // J = diag(y)-y.T*y - // dx = J * dy - // dxk = sum_i(Jki * dyi) - // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk - // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk - // dxk = sum_i(-yk*yi * dyi) + yk*dyk - // dxk = -yk * sum_i(yi * dyi) + yk*dyk - // dxk = -yk * dot(y, dy) + yk*dyk - // dxk = yk * (- dot(y, dy) + dyk) - // dxk = yk * (dyk - dot(y, dy)) - // - // post-order: - // dot_y_dy := dot(y, dy) - // dx := dy - // dx := dx - dot_y_dy - // dx := dx * y - - // linear runtime, no additional memory - float dot_y_dy = 0; - ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1); - ggml_vec_cpy_f32 (nc, dx, dy); - ggml_vec_acc1_f32(nc, dx, -dot_y_dy); - ggml_vec_mul_f32 (nc, dx, dx, y); - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - assert(!isnan(dx[i])); - assert(!isinf(dx[i])); - } -#endif - } -} - -static void ggml_compute_forward_soft_max_back( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_soft_max_back_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_clamp - -static void ggml_compute_forward_clamp_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - float min; - float max; - memcpy(&min, (float *) dst->op_params + 0, sizeof(float)); - memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - - GGML_ASSERT( nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - for (int j = ith; j < n; j += nth) { - float * dst_ptr = (float *) ((char *) dst->data + j*nb1); - float * src0_ptr = (float *) ((char *) src0->data + j*nb01); - - for (int i = 0; i < nc; i++) { - dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min); - } - } -} - -static void ggml_compute_forward_clamp( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_clamp_f32(params, dst); - } break; - case GGML_TYPE_F16: - case GGML_TYPE_BF16: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q8_1: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - case GGML_TYPE_TQ1_0: - case GGML_TYPE_TQ2_0: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_Q8_K: - case GGML_TYPE_Q4_0_4_4: - case GGML_TYPE_Q4_0_4_8: - case GGML_TYPE_Q4_0_8_8: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_I64: - case GGML_TYPE_F64: - case GGML_TYPE_COUNT: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_rope - -static float rope_yarn_ramp(const float low, const float high, const int i0) { - const float y = (i0 / 2 - low) / MAX(0.001f, high - low); - return 1 - MIN(1, MAX(0, y)); -} - -// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn -// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. -static void rope_yarn( - float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, - float * cos_theta, float * sin_theta) { - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = freq_scale * theta_extrap; - float theta = theta_interp; - if (ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; - theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - - // Get n-d magnitude scaling corrected for interpolation - mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); - } - *cos_theta = cosf(theta) * mscale; - *sin_theta = sinf(theta) * mscale; -} - -// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get -// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) { - return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); -} - -static void ggml_rope_cache_init( - float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, - float * cache, float sin_sign, float theta_scale) { - // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py - float theta = theta_base; - for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - const float ff = freq_factors ? freq_factors[i0/2] : 1.0f; - rope_yarn( - theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] - ); - cache[i0 + 1] *= sin_sign; - - theta *= theta_scale; - } -} - -GGML_CALL void ggml_rope_yarn_corr_dims( - int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2] -) { - // start and end correction dims - float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base)); - float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base)); - dims[0] = MAX(0, start); - dims[1] = MIN(n_dims - 1, end); -} - -static void ggml_compute_forward_rope_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const bool forward) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - const struct ggml_tensor * src2 = dst->src[2]; - - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - //const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; - - memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - - GGML_TENSOR_UNARY_OP_LOCALS - - //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); - //printf("n_past = %d, ne2 = %d\n", n_past, ne2); - - GGML_ASSERT(nb00 == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(dst); - - GGML_ASSERT(n_dims <= ne0); - GGML_ASSERT(n_dims % 2 == 0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - // row index used to determine which thread to use - int ir = 0; - - const float theta_scale = powf(freq_base, -2.0f/n_dims); - - float corr_dims[2]; - ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - - const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; - - const float * freq_factors = NULL; - if (src2 != NULL) { - GGML_ASSERT(src2->type == GGML_TYPE_F32); - GGML_ASSERT(src2->ne[0] >= n_dims / 2); - freq_factors = (const float *) src2->data; - } - - // backward process uses inverse rotation by cos and sin. - // cos and sin build a rotation matrix, where the inverse is the transpose. - // this essentially just switches the sign of sin. - const float sin_sign = forward ? 1.0f : -1.0f; - - const int32_t * pos = (const int32_t *) src1->data; - - for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = 0; i2 < ne2; i2++) { - const int64_t p = pos[i2]; - - float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; - ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); - - for (int64_t i1 = 0; i1 < ne1; i1++) { - if (ir++ < ir0) continue; - if (ir > ir1) break; - - if (!is_neox) { - for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { - const float cos_theta = cache[i0 + 0]; - const float sin_theta = cache[i0 + 1]; - - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[1]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[1] = x0*sin_theta + x1*cos_theta; - } - } else { - for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { - const int64_t ic = i0/2; - - const float cos_theta = cache[i0 + 0]; - const float sin_theta = cache[i0 + 1]; - - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - } - } - - for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } - } - } -} - -// TODO: deduplicate f16/f32 code -static void ggml_compute_forward_rope_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const bool forward) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - const struct ggml_tensor * src2 = dst->src[2]; - - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - //const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; - memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - - GGML_TENSOR_UNARY_OP_LOCALS - - //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); - //printf("n_past = %d, ne2 = %d\n", n_past, ne2); - - GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); - - const int ith = params->ith; - const int nth = params->nth; - - const int nr = ggml_nrows(dst); - - GGML_ASSERT(n_dims <= ne0); - GGML_ASSERT(n_dims % 2 == 0); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - // row index used to determine which thread to use - int ir = 0; - - const float theta_scale = powf(freq_base, -2.0f/n_dims); - - float corr_dims[2]; - ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - - const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; - - const float * freq_factors = NULL; - if (src2 != NULL) { - GGML_ASSERT(src2->type == GGML_TYPE_F32); - GGML_ASSERT(src2->ne[0] >= n_dims / 2); - freq_factors = (const float *) src2->data; - } - - // backward process uses inverse rotation by cos and sin. - // cos and sin build a rotation matrix, where the inverse is the transpose. - // this essentially just switches the sign of sin. - const float sin_sign = forward ? 1.0f : -1.0f; - - const int32_t * pos = (const int32_t *) src1->data; - - for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = 0; i2 < ne2; i2++) { - const int64_t p = pos[i2]; - - float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; - ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); - - for (int64_t i1 = 0; i1 < ne1; i1++) { - if (ir++ < ir0) continue; - if (ir > ir1) break; - - if (!is_neox) { - for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { - const float cos_theta = cache[i0 + 0]; - const float sin_theta = cache[i0 + 1]; - - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = GGML_FP16_TO_FP32(src[0]); - const float x1 = GGML_FP16_TO_FP32(src[1]); - - dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); - dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); - } - } else { - for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { - const int64_t ic = i0/2; - - const float cos_theta = cache[i0 + 0]; - const float sin_theta = cache[i0 + 1]; - - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = GGML_FP16_TO_FP32(src[0]); - const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); - - dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); - dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); - } - } - - for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } - } - } -} - -static void ggml_compute_forward_rope( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_rope_f16(params, dst, true); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_rope_f32(params, dst, true); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_rope_back - -static void ggml_compute_forward_rope_back( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_rope_f16(params, dst, false); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_rope_f32(params, dst, false); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_conv_transpose_1d - -static void ggml_compute_forward_conv_transpose_1d_f16_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00*ne01*ne02; - - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb10 == sizeof(float)); - - if (ith == 0) { - memset(params->wdata, 0, params->wsize); - - // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); - ggml_fp16_t * dst_data = wdata + i01*ne00*ne02; - for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ne02 + i02] = src[i00]; - } - } - } - } - - // permute source data (src1) from (L x Cin) to (Cin x L) - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk; - ggml_fp16_t * dst_data = wdata; - - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[i10*ne11 + i11] = GGML_FP32_TO_FP16(src[i10]); - } - } - } - - // need to zero dst since we are accumulating into it - memset(dst->data, 0, ggml_nbytes(dst)); - } - ggml_barrier(params->threadpool); - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - - // total rows in dst - const int nr = ne1; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - ggml_fp16_t * const wdata_src = wdata + nk; - - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00; - for (int i10 = 0; i10 < ne10; i10++) { - const int i1n = i10*ne11; - for (int i00 = 0; i00 < ne00; i00++) { - float v = 0; - ggml_vec_dot_f16(ne02, &v, 0, - (ggml_fp16_t *) wdata_src + i1n, 0, - (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1); - dst_data[i10*s0 + i00] += v; - } - } - } -} - -static void ggml_compute_forward_conv_transpose_1d_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00*ne01*ne02; - - GGML_ASSERT(nb00 == sizeof(float)); - GGML_ASSERT(nb10 == sizeof(float)); - - if (ith == 0) { - memset(params->wdata, 0, params->wsize); - - // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) - { - float * const wdata = (float *) params->wdata + 0; - - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); - float * dst_data = wdata + i01*ne00*ne02; - for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ne02 + i02] = src[i00]; - } - } - } - } - - // prepare source data (src1) - { - float * const wdata = (float *) params->wdata + nk; - float * dst_data = wdata; - - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[i10*ne11 + i11] = src[i10]; - } - } - } - - // need to zero dst since we are accumulating into it - memset(dst->data, 0, ggml_nbytes(dst)); - } - ggml_barrier(params->threadpool); - - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - - // total rows in dst - const int nr = ne1; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float * const wdata = (float *) params->wdata + 0; - float * const wdata_src = wdata + nk; - - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - float * wdata_kernel = wdata + i1*ne02*ne00; - for (int i10 = 0; i10 < ne10; i10++) { - const int i1n = i10*ne11; - for (int i00 = 0; i00 < ne00; i00++) { - float v = 0; - ggml_vec_dot_f32(ne02, &v, 0, - wdata_src + i1n, 0, - wdata_kernel + i00*ne02, 0, 1); - dst_data[i10*s0 + i00] += v; - } - } - } -} - -static void ggml_compute_forward_conv_transpose_1d( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_conv_transpose_1d_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_im2col_f32 -// src0: kernel [OC, IC, KH, KW] -// src1: image [N, IC, IH, IW] -// dst: result [N, OH, OW, IC*KH*KW] -static void ggml_compute_forward_im2col_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - GGML_TENSOR_BINARY_OP_LOCALS; - - const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; - const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t N = is_2D ? ne13 : ne12; - const int64_t IC = is_2D ? ne12 : ne11; - const int64_t IH = is_2D ? ne11 : 1; - const int64_t IW = ne10; - - const int64_t KH = is_2D ? ne01 : 1; - const int64_t KW = ne00; - - const int64_t OH = is_2D ? ne2 : 1; - const int64_t OW = ne1; - - int ofs0 = is_2D ? nb13 : nb12; - int ofs1 = is_2D ? nb12 : nb11; - - GGML_ASSERT(nb10 == sizeof(float)); - - // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] - { - float * const wdata = (float *) dst->data; - - for (int64_t in = 0; in < N; in++) { - for (int64_t ioh = 0; ioh < OH; ioh++) { // 1 - for (int64_t iow = 0; iow < OW; iow++) { - for (int64_t iic = ith; iic < IC; iic += nth) { - - // micro kernel - float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] - const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW] - - for (int64_t ikh = 0; ikh < KH; ikh++) { // 1 - for (int64_t ikw = 0; ikw < KW; ikw++) { - const int64_t iiw = iow*s0 + ikw*d0 - p0; - const int64_t iih = ioh*s1 + ikh*d1 - p1; - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0; - } else { - dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]); - } - } - } - } - } - } - } - } -} - - -// ggml_compute_forward_im2col_f16 -// src0: kernel [OC, IC, KH, KW] -// src1: image [N, IC, IH, IW] -// dst: result [N, OH, OW, IC*KH*KW] -static void ggml_compute_forward_im2col_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F16); - - GGML_TENSOR_BINARY_OP_LOCALS; - - const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; - const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t N = is_2D ? ne13 : ne12; - const int64_t IC = is_2D ? ne12 : ne11; - const int64_t IH = is_2D ? ne11 : 1; - const int64_t IW = ne10; - - const int64_t KH = is_2D ? ne01 : 1; - const int64_t KW = ne00; - - const int64_t OH = is_2D ? ne2 : 1; - const int64_t OW = ne1; - - int ofs0 = is_2D ? nb13 : nb12; - int ofs1 = is_2D ? nb12 : nb11; - - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb10 == sizeof(float)); - - // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data; - - for (int64_t in = 0; in < N; in++) { - for (int64_t ioh = 0; ioh < OH; ioh++) { // 1 - for (int64_t iow = 0; iow < OW; iow++) { - for (int64_t iic = ith; iic < IC; iic += nth) { - - // micro kernel - ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] - const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW] - - for (int64_t ikh = 0; ikh < KH; ikh++) { // 1 - for (int64_t ikw = 0; ikw < KW; ikw++) { - const int64_t iiw = iow*s0 + ikw*d0 - p0; - const int64_t iih = ioh*s1 + ikh*d1 - p1; - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0; - } else { - dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]); - } - } - } - } - } - } - } - } -} - -static void ggml_compute_forward_im2col( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - switch (dst->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_im2col_f16(params, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_im2col_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_im2col_back_f32 - -static void ggml_compute_forward_im2col_back_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - GGML_TENSOR_BINARY_OP_LOCALS; - - const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t *)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t *)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t *)(dst->op_params))[5]; - const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1; - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t N = is_2D ? ne3 : ne2; - const int64_t IC = is_2D ? ne2 : ne1; - const int64_t IH = is_2D ? ne1 : 1; - const int64_t IW = ne0; - - const int64_t KH = is_2D ? ne01 : 1; - const int64_t KW = ne00; - - const int64_t OH = is_2D ? ne12 : 1; - const int64_t OW = ne11; - - int ofs0 = is_2D ? nb3 : nb2; - int ofs1 = is_2D ? nb2 : nb1; - - GGML_ASSERT(nb0 == sizeof(float)); - - // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] - { - float * const wdata = (float *) dst->data; - - for (int64_t in = 0; in < N; in++) { - for (int64_t iic = ith; iic < IC; iic += nth) { - for (int64_t iih = 0; iih < IH; iih++) { - for (int64_t iiw = 0; iiw < IW; iiw++) { - - // micro kernel - float grad = 0.0f; - for (int64_t ikh = 0; ikh < KH; ikh++) { - for (int64_t ikw = 0; ikw < KW; ikw++) { - // For s0 > 1 some values were skipped over in the forward pass. - // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well. - const int64_t tmpw = (iiw + p0 - ikw*d0); - if (tmpw % s0 != 0) { - continue; - } - const int64_t iow = tmpw / s0; - - // Equivalent logic as above except for s1. - int64_t ioh; - if (is_2D) { - const int64_t tmph = iih + p1 - ikh*d1; - - if (tmph % s1 != 0) { - continue; - } - - ioh = tmph / s1; - } else { - ioh = 0; - } - - if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) { - continue; - } - - const float * const src_data = (const float *) src1->data - + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] - grad += src_data[iic*(KH*KW) + ikh*KW + ikw]; - } - } - float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW] - dst_data[iih*IW + iiw] = grad; - } - } - } - } - } -} - -// ggml_compute_forward_conv_transpose_2d - -static void ggml_compute_forward_conv_transpose_2d( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const int nk = ne00*ne01*ne02*ne03; - - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb10 == sizeof(float)); - - if (ith == 0) { - memset(params->wdata, 0, params->wsize); - - // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout) - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02); - ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03; - for (int64_t i01 = 0; i01 < ne01; i01++) { - for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00]; - } - } - } - } - } - - // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh) - { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk; - for (int i12 = 0; i12 < ne12; i12++) { - for (int i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11); - ggml_fp16_t * dst_data = wdata + i11*ne10*ne12; - for (int i10 = 0; i10 < ne10; i10++) { - dst_data[i10*ne12 + i12] = GGML_FP32_TO_FP16(src[i10]); - } - } - } - } - - memset(dst->data, 0, ggml_nbytes(dst)); - } - ggml_barrier(params->threadpool); - - const int32_t stride = ggml_get_op_params_i32(dst, 0); - - // total patches in dst - const int np = ne2; - - // patches per thread - const int dp = (np + nth - 1)/nth; - - // patch range for this thread - const int ip0 = dp*ith; - const int ip1 = MIN(ip0 + dp, np); - - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - ggml_fp16_t * const wdata_src = wdata + nk; - - for (int i2 = ip0; i2 < ip1; i2++) { // Cout - float * dst_data = (float *)((char *) dst->data + i2*nb2); - ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03; - for (int i11 = 0; i11 < ne11; i11++) { - for (int i10 = 0; i10 < ne10; i10++) { - const int i1n = i11*ne10*ne12 + i10*ne12; - for (int i01 = 0; i01 < ne01; i01++) { - for (int i00 = 0; i00 < ne00; i00++) { - float v = 0; - ggml_vec_dot_f16(ne03, &v, 0, - wdata_src + i1n, 0, - wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1); - dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v; - } - } - } - } - } -} - -// ggml_compute_forward_pool_1d_sk_p0 - -static void ggml_compute_forward_pool_1d_sk_p0( - const struct ggml_compute_params * params, - const enum ggml_op_pool op, - const int k, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src = dst->src[0]; - - assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16); - - if (params->ith != 0) { - return; - } - - const char * cdata = (const char *)src->data; - const char * const data_end = cdata + ggml_nbytes(src); - float * drow = (float *)dst->data; - - const int64_t rs = dst->ne[0]; - - while (cdata < data_end) { - const void * srow = (const void *)cdata; - int j = 0; - for (int64_t i = 0; i < rs; ++i) { - switch (op) { - case GGML_OP_POOL_AVG: drow[i] = 0; break; - case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break; - case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); - } - for (int ki = 0; ki < k; ++ki) { - const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]); - switch (op) { - case GGML_OP_POOL_AVG: drow[i] += srow_j; break; - case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break; - case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); - } - ++j; - } - switch (op) { - case GGML_OP_POOL_AVG: drow[i] /= k; break; - case GGML_OP_POOL_MAX: break; - case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); - } - } - - cdata += src->nb[1]; - drow += rs; - } -} - -// ggml_compute_forward_pool_1d - -static void ggml_compute_forward_pool_1d( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const int32_t * opts = (const int32_t *)dst->op_params; - enum ggml_op_pool op = opts[0]; - const int k0 = opts[1]; - const int s0 = opts[2]; - const int p0 = opts[3]; - GGML_ASSERT(p0 == 0); // padding not supported - GGML_ASSERT(k0 == s0); // only s = k supported - - ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst); -} - -// ggml_compute_forward_pool_2d - -static void ggml_compute_forward_pool_2d( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src = dst->src[0]; - - assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16); - - if (params->ith != 0) { - return; - } - - const int32_t * opts = (const int32_t *)dst->op_params; - enum ggml_op_pool op = opts[0]; - const int k0 = opts[1]; - const int k1 = opts[2]; - const int s0 = opts[3]; - const int s1 = opts[4]; - const int p0 = opts[5]; - const int p1 = opts[6]; - const char * cdata = (const char*)src->data; - const char * const data_end = cdata + ggml_nbytes(src); - - const int64_t px = dst->ne[0]; - const int64_t py = dst->ne[1]; - const int64_t pa = px * py; - - float * dplane = (float *)dst->data; - - const int ka = k0 * k1; - const int offset0 = -p0; - const int offset1 = -p1; - - while (cdata < data_end) { - for (int oy = 0; oy < py; ++oy) { - float * const drow = dplane + oy * px; - for (int ox = 0; ox < px; ++ox) { - float * const out = drow + ox; - switch (op) { - case GGML_OP_POOL_AVG: *out = 0; break; - case GGML_OP_POOL_MAX: *out = -FLT_MAX; break; - case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); - } - - const int ix = offset0 + ox * s0; - const int iy = offset1 + oy * s1; - - for (int ky = 0; ky < k1; ++ky) { - if (iy + ky < 0 || iy + ky >= src->ne[1]) continue; - const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky)); - for (int kx = 0; kx < k0; ++kx) { - int j = ix + kx; - if (j < 0 || j >= src->ne[0]) continue; - const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]); - switch (op) { - case GGML_OP_POOL_AVG: *out += srow_j; break; - case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break; - case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); - } - } - } - switch (op) { - case GGML_OP_POOL_AVG: *out /= ka; break; - case GGML_OP_POOL_MAX: break; - case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); - } - } - } - - cdata += src->nb[2]; - dplane += pa; - } -} - -// ggml_compute_forward_pool_2d_back - -static void ggml_compute_forward_pool_2d_back( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src = dst->src[0]; - const struct ggml_tensor * dstf = dst->src[1]; // forward tensor of dst - - assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); - - if (params->ith != 0) { - return; - } - - const int32_t * opts = (const int32_t *)dst->op_params; - enum ggml_op_pool op = opts[0]; - const int k0 = opts[1]; - const int k1 = opts[2]; - const int s0 = opts[3]; - const int s1 = opts[4]; - const int p0 = opts[5]; - const int p1 = opts[6]; - - char * cdata = (char *) dst->data; - const char * cdataf = (const char *) dstf->data; - const char * const data_end = cdata + ggml_nbytes(dst); - - GGML_ASSERT(params->ith == 0); - memset(cdata, 0, ggml_nbytes(dst)); - - const int64_t px = src->ne[0]; - const int64_t py = src->ne[1]; - const int64_t pa = px * py; - - const float * splane = (const float *) src->data; - - const int ka = k0 * k1; - const int offset0 = -p0; - const int offset1 = -p1; - - while (cdata < data_end) { - for (int oy = 0; oy < py; ++oy) { - const float * const srow = splane + oy * px; - for (int ox = 0; ox < px; ++ox) { - const float grad0 = srow[ox]; - - const int ix = offset0 + ox * s0; - const int iy = offset1 + oy * s1; - - if (op == GGML_OP_POOL_MAX) { - float maxval = -FLT_MAX; - int kxmax = -1; - int kymax = -1; - - for (int ky = 0; ky < k1; ++ky) { - if (iy + ky < 0 || iy + ky >= dst->ne[1]) { - continue; - } - const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky)); - for (int kx = 0; kx < k0; ++kx) { - int j = ix + kx; - if (j < 0 || j >= dst->ne[0]) { - continue; - } - - const float val = dst->type == GGML_TYPE_F32 ? - ((const float *) drowf)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]); - if (val <= maxval) { - continue; - } - - maxval = val; - kxmax = kx; - kymax = ky; - } - } - - if (kxmax == -1 || kymax == -1) { - continue; - } - - void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax)); - const int j = ix + kxmax; - if (dst->type == GGML_TYPE_F32) { - ((float *) drow)[j] += grad0; - } else { - ((ggml_fp16_t *) drow)[j] = GGML_FP32_TO_FP16(grad0 + GGML_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j])); - } - } else if (op == GGML_OP_POOL_AVG) { - const float grad = grad0 / ka; - - for (int ky = 0; ky < k1; ++ky) { - if (iy + ky < 0 || iy + ky >= dst->ne[1]) { - continue; - } - void * drow = (void *)(cdata + dst->nb[1] * (iy + ky)); - for (int kx = 0; kx < k0; ++kx) { - int j = ix + kx; - if (j < 0 || j >= dst->ne[0]) { - continue; - } - - if (dst->type == GGML_TYPE_F32) { - ((float *) drow)[j] += grad; - } else { - ((ggml_fp16_t *) drow)[j] += GGML_FP32_TO_FP16(grad); - } - } - } - } else { - GGML_ASSERT(false); - } - } - } - - cdata += dst->nb[2]; - cdataf += dst->nb[2]; - splane += pa; - } -} - -// ggml_compute_forward_upscale - -static void ggml_compute_forward_upscale_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_UNARY_OP_LOCALS - - const float sf0 = (float)ne0/src0->ne[0]; - const float sf1 = (float)ne1/src0->ne[1]; - const float sf2 = (float)ne2/src0->ne[2]; - const float sf3 = (float)ne3/src0->ne[3]; - - // TODO: optimize - - for (int64_t i3 = 0; i3 < ne3; i3++) { - const int64_t i03 = i3 / sf3; - for (int64_t i2 = ith; i2 < ne2; i2 += nth) { - const int64_t i02 = i2 / sf2; - for (int64_t i1 = 0; i1 < ne1; i1++) { - const int64_t i01 = i1 / sf1; - for (int64_t i0 = 0; i0 < ne0; i0++) { - const int64_t i00 = i0 / sf0; - - const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); - - *y = *x; - } - } - } - } -} - -static void ggml_compute_forward_upscale( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_upscale_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - - -// ggml_compute_forward_pad - -static void ggml_compute_forward_pad_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - GGML_ASSERT(src0->nb[0] == sizeof(float)); - GGML_ASSERT( dst->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_UNARY_OP_LOCALS - - float * dst_ptr = (float *) dst->data; - - // TODO: optimize - - for (int64_t i2 = 0; i2 < ne2; ++i2) { - for (int64_t i1 = ith; i1 < ne1; i1 += nth) { - for (int64_t i0 = 0; i0 < ne0; ++i0) { - for (int64_t i3 = 0; i3 < ne3; ++i3) { - const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; - - const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - - if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - dst_ptr[dst_idx] = *src_ptr; - } else { - dst_ptr[dst_idx] = 0; - } - } - } - } - } -} - -static void ggml_compute_forward_pad( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_pad_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - - -// ggml_compute_forward_arange - -static void ggml_compute_forward_arange_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - GGML_ASSERT(dst->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - const float start = ggml_get_op_params_f32(dst, 0); - const float stop = ggml_get_op_params_f32(dst, 1); - const float step = ggml_get_op_params_f32(dst, 2); - - const int64_t steps = (int64_t) ceilf((stop - start) / step); - - GGML_ASSERT(ggml_nelements(dst) == steps); - - for (int64_t i = ith; i < steps; i+= nth) { - float value = start + step * i; - ((float *)dst->data)[i] = value; - } -} - -static void ggml_compute_forward_arange( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - switch (dst->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_arange_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -static void ggml_compute_forward_timestep_embedding_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_UNARY_OP_LOCALS - - const int dim = ggml_get_op_params_i32(dst, 0); - const int max_period = ggml_get_op_params_i32(dst, 1); - - int half = dim / 2; - - for (int64_t i = 0; i < ne00; i++) { - float * embed_data = (float *)((char *) dst->data + i*nb1); - for (int64_t j = ith; j < half; j += nth) { - float timestep = ((float *)src0->data)[i]; - float freq = (float)expf(-logf(max_period) * j / half); - float arg = timestep * freq; - embed_data[j] = cosf(arg); - embed_data[j + half] = sinf(arg); - } - if (dim % 2 != 0 && ith == 0) { - embed_data[dim] = 0.f; - } - } -} - -static void ggml_compute_forward_timestep_embedding( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_timestep_embedding_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_argsort - -static void ggml_compute_forward_argsort_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - GGML_TENSOR_UNARY_OP_LOCALS - - GGML_ASSERT(nb0 == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t nr = ggml_nrows(src0); - - enum ggml_sort_order order = (enum ggml_sort_order) ggml_get_op_params_i32(dst, 0); - - for (int64_t i = ith; i < nr; i += nth) { - int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); - const float * src_data = (float *)((char *) src0->data + i*nb01); - - for (int64_t j = 0; j < ne0; j++) { - dst_data[j] = j; - } - - // C doesn't have a functional sort, so we do a bubble sort instead - for (int64_t j = 0; j < ne0; j++) { - for (int64_t k = j + 1; k < ne0; k++) { - if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) || - (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) { - int32_t tmp = dst_data[j]; - dst_data[j] = dst_data[k]; - dst_data[k] = tmp; - } - } - } - } -} - -static void ggml_compute_forward_argsort( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_argsort_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_flash_attn_ext - -static void ggml_compute_forward_flash_attn_ext_f16( - const struct ggml_compute_params * params, - const struct ggml_tensor * q, - const struct ggml_tensor * k, - const struct ggml_tensor * v, - const struct ggml_tensor * mask, - struct ggml_tensor * dst) { - - GGML_TENSOR_LOCALS(int64_t, neq, q, ne) - GGML_TENSOR_LOCALS(size_t, nbq, q, nb) - GGML_TENSOR_LOCALS(int64_t, nek, k, ne) - GGML_TENSOR_LOCALS(size_t, nbk, k, nb) - GGML_TENSOR_LOCALS(int64_t, nev, v, ne) - GGML_TENSOR_LOCALS(size_t, nbv, v, nb) - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - GGML_TENSOR_LOCALS(size_t, nb, dst, nb) - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t D = neq0; - const int64_t N = neq1; - - GGML_ASSERT(ne0 == D); - GGML_ASSERT(ne2 == N); - - // input tensor rows must be contiguous - GGML_ASSERT(nbq0 == ggml_type_size(q->type)); - GGML_ASSERT(nbk0 == ggml_type_size(k->type)); - GGML_ASSERT(nbv0 == ggml_type_size(v->type)); - - GGML_ASSERT(neq0 == D); - GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev0 == D); - - GGML_ASSERT(neq1 == N); - GGML_ASSERT(nev0 == D); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - // broadcast factors - const int64_t rk2 = neq2/nek2; - const int64_t rk3 = neq3/nek3; - - const int64_t rv2 = neq2/nev2; - const int64_t rv3 = neq3/nev3; - - // parallelize by q rows using ggml_vec_dot_f32 - - // total rows in q - const int nr = neq1*neq2*neq3; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float scale = 1.0f; - float max_bias = 0.0f; - float logit_softcap = 0.0f; - - memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); - - if (logit_softcap != 0) { - scale /= logit_softcap; - } - - const uint32_t n_head = neq2; - const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - enum ggml_type const k_vec_dot_type = type_traits[k->type].vec_dot_type; - ggml_from_float_t const q_to_vec_dot = type_traits[k_vec_dot_type].from_float; - ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot; - ggml_to_float_t const v_to_float = type_traits[v->type].to_float; - - // loop over n_batch and n_head - for (int ir = ir0; ir < ir1; ++ir) { - // q indices - const int iq3 = ir/(neq2*neq1); - const int iq2 = (ir - iq3*neq2*neq1)/neq1; - const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); - - const uint32_t h = iq2; // head index - const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; - - float S = 0.0f; // sum - float M = -INFINITY; // maximum KQ value - - float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator - float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer - ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator - ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16 - - if (v->type == GGML_TYPE_F16) { - memset(VKQ16, 0, D*sizeof(ggml_fp16_t)); - } else { - memset(VKQ32, 0, D*sizeof(float)); - } - - const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; - - // k indices - const int ik3 = iq3 / rk3; - const int ik2 = iq2 / rk2; - - // v indices - const int iv3 = iq3 / rv3; - const int iv2 = iq2 / rv2; - - const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); - q_to_vec_dot(pq, Q_q, D); - - // online softmax / attention - // loop over n_kv and n_head_kv - // ref: https://arxiv.org/pdf/2112.05682.pdf - for (int64_t ic = 0; ic < nek1; ++ic) { - const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f; - if (mv == -INFINITY) { - continue; - } - - float s; // KQ value - - const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); - kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1); - - s = s*scale; // scale KQ value - - if (logit_softcap != 0.0f) { - s = logit_softcap*tanhf(s); - } - - s += mv; // apply mask - - const float Mold = M; - - float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value - float vs = 1.0f; // post-softmax KQ value, expf(s - M) - - const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); - - if (v->type == GGML_TYPE_F16) { - if (s > M) { - // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f - M = s; - ms = expf(Mold - M); - - // V = V*expf(Mold - M) - ggml_vec_scale_f16(D, VKQ16, ms); - } else { - // no new maximum, ms == 1.0f, vs != 1.0f - vs = expf(s - M); - } - - // V += v*expf(s - M) - ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs); - } else { - if (s > M) { - // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f - M = s; - ms = expf(Mold - M); - - // V = V*expf(Mold - M) - ggml_vec_scale_f32(D, VKQ32, ms); - } else { - // no new maximum, ms == 1.0f, vs != 1.0f - vs = expf(s - M); - } - - v_to_float(v_data, V32, D); - - // V += v*expf(s - M) - ggml_vec_mad_f32(D, VKQ32, V32, vs); - } - - S = S*ms + vs; // scale and increment sum with partial sum - } - - if (v->type == GGML_TYPE_F16) { - for (int64_t d = 0; d < D; ++d) { - VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]); - } - } - - // V /= S - const float S_inv = 1.0f/S; - ggml_vec_scale_f32(D, VKQ32, S_inv); - - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - // original - //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); - - // permute(0, 2, 1, 3) - memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1); - } -} - -static void ggml_compute_forward_flash_attn_ext( - const struct ggml_compute_params * params, - const struct ggml_tensor * q, - const struct ggml_tensor * k, - const struct ggml_tensor * v, - const struct ggml_tensor * mask, - struct ggml_tensor * dst) { - switch (dst->op_params[3]) { - case GGML_PREC_DEFAULT: - case GGML_PREC_F32: - { - // uses F32 accumulators - ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_flash_attn_back - -static void ggml_compute_forward_flash_attn_back_f32( - const struct ggml_compute_params * params, - const bool masked, - struct ggml_tensor * dst) { - - const struct ggml_tensor * q = dst->src[0]; - const struct ggml_tensor * k = dst->src[1]; - const struct ggml_tensor * v = dst->src[2]; - const struct ggml_tensor * d = dst->src[3]; - - GGML_TENSOR_LOCALS(int64_t, neq, q, ne) - GGML_TENSOR_LOCALS(size_t, nbq, q, nb) - GGML_TENSOR_LOCALS(int64_t, nek, k, ne) - GGML_TENSOR_LOCALS(size_t, nbk, k, nb) - GGML_TENSOR_LOCALS(int64_t, nev, v, ne) - GGML_TENSOR_LOCALS(size_t, nbv, v, nb) - GGML_TENSOR_LOCALS(int64_t, ned, d, ne) - GGML_TENSOR_LOCALS(size_t, nbd, d, nb) - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - GGML_TENSOR_LOCALS(size_t, nb, dst, nb) - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t D = neq0; - const int64_t N = neq1; - const int64_t P = nek1 - N; - const int64_t M = P + N; - - const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); - const int mxDM = MAX(D, Mup); - - // GGML_ASSERT(ne0 == D); - // GGML_ASSERT(ne1 == N); - GGML_ASSERT(P >= 0); - - GGML_ASSERT(nbq0 == sizeof(float)); - GGML_ASSERT(nbk0 == sizeof(float)); - GGML_ASSERT(nbv0 == sizeof(float)); - - GGML_ASSERT(neq0 == D); - GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev1 == D); - GGML_ASSERT(ned0 == D); - - GGML_ASSERT(neq1 == N); - GGML_ASSERT(nek1 == N + P); - GGML_ASSERT(nev1 == D); - GGML_ASSERT(ned1 == N); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - if (ith == 0) { - memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3); - } - ggml_barrier(params->threadpool); - - const int64_t elem_q = ggml_nelements(q); - const int64_t elem_k = ggml_nelements(k); - - enum ggml_type result_type = dst->type; - GGML_ASSERT(ggml_blck_size(result_type) == 1); - const size_t tsize = ggml_type_size(result_type); - - const size_t offs_q = 0; - const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN); - const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN); - - void * grad_q = (char *) dst->data; - void * grad_k = (char *) dst->data + offs_k; - void * grad_v = (char *) dst->data + offs_v; - - const size_t nbgq1 = nb0*neq0; - const size_t nbgq2 = nb0*neq0*neq1; - const size_t nbgq3 = nb0*neq0*neq1*neq2; - - const size_t nbgk1 = nb0*nek0; - const size_t nbgk2 = nb0*nek0*nek1; - const size_t nbgk3 = nb0*nek0*nek1*neq2; - - const size_t nbgv1 = nb0*nev0; - const size_t nbgv2 = nb0*nev0*nev1; - const size_t nbgv3 = nb0*nev0*nev1*neq2; - - // parallelize by k rows using ggml_vec_dot_f32 - - // total rows in k - const int nr = nek2*nek3; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - const float scale = 1.0f/sqrtf(D); - - //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); - - // how often k2 (and v2) is repeated in q2 - int nrep = neq2/nek2; - - for (int ir = ir0; ir < ir1; ++ir) { - // q indices - const int ik3 = ir/(nek2); - const int ik2 = ir - ik3*nek2; - - const int iq3 = ik3; - const int id3 = ik3; - const int iv3 = ik3; - const int iv2 = ik2; - - for (int irep = 0; irep < nrep; ++irep) { - const int iq2 = ik2 + irep*nek2; - const int id2 = iq2; - - // (ik2 + irep*nek2) % nek2 == ik2 - for (int iq1 = 0; iq1 < neq1; ++iq1) { - const int id1 = iq1; - - // not sure about CACHE_LINE_SIZE_F32.. - // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset? - float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32); - float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32); - - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } - - const int64_t masked_begin = masked ? (P + iq1 + 1) : M; - for (int64_t ic = 0; ic < masked_begin; ++ic) { - // k indices - const int ik1 = ic; - - // S indices - const int i1 = ik1; - - ggml_vec_dot_f32(neq0, - S + i1, 0, - (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0, - (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1); - } - - // scale - ggml_vec_scale_f32(masked_begin, S, scale); - - for (int64_t i = masked_begin; i < M; i++) { - S[i] = -INFINITY; - } - - // softmax - // exclude known -INF S[..] values from max and loop - // dont forget to set their SM values to zero - { - float max = -INFINITY; - ggml_vec_max_f32(masked_begin, &max, S); - - ggml_float sum = 0.0; - { -#ifdef GGML_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(SM, 1, &max, SM, 1, Mup); - vvexpf(SM, SM, &Mup); - ggml_vec_sum_f32(Mup, &sum, SM); -#else - sum = ggml_vec_soft_max_f32(Mup, SM, S, max); -#endif - } - - assert(sum > 0.0); - - sum = 1.0/sum; - ggml_vec_scale_f32(masked_begin, SM, sum); - - } - - // step-by-step explanation - { - // forward-process shape grads from backward process - // parallel_for ik2,ik3: - // for irep: - // iq2 = ik2 + irep*nek2 - // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur] - // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur] - // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur] - // for iq1: - // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur - // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur - // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4 - // S0 = -Inf [D,1,1,1] - // ~S1[i] = dot(kcur[:D,i], qcur) - // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale - // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P) - // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) - // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur - // ~S5[i] = dot(vcur[:,i], S4) - // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3] - // ~dst[i,iq1,iq2,iq3] = S5[i] ^ - // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3] - // dst backward-/ grad[dst] = d - // - // output gradients with their dependencies: - // - // grad[kcur] = grad[S1].T @ qcur - // grad[S1] = diag_mask_zero(grad[S3], P) * scale - // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) - // grad[S4] = grad[S5] @ vcur - // grad[S4] = d[:D,id1,id2,id3] @ vcur - // grad[qcur] = grad[S1] @ kcur - // grad[vcur] = grad[S5].T @ S4 - // grad[vcur] = d[:D,id1,id2,id3].T @ S4 - // - // in post-order: - // - // S1 = qcur @ kcur.T - // S2 = S1 * scale - // S3 = diag_mask_inf(S2, P) - // S4 = softmax(S3) - // grad[S4] = d[:D,id1,id2,id3] @ vcur - // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) - // grad[S1] = diag_mask_zero(grad[S3], P) * scale - // grad[qcur] = grad[S1] @ kcur - // grad[kcur] = grad[S1].T @ qcur - // grad[vcur] = d[:D,id1,id2,id3].T @ S4 - // - // using less variables (SM=S4): - // - // S = diag_mask_inf(qcur @ kcur.T * scale, P) - // SM = softmax(S) - // S = d[:D,iq1,iq2,iq3] @ vcur - // dot_SM_gradSM = dot(SM, S) - // S = SM * (S - dot(SM, S)) - // S = diag_mask_zero(S, P) * scale - // - // grad[q][:D,iq1,iq2,iq3] += S @ kcur - // grad[k][:D,:M,ik2,ik3] += S.T @ qcur - // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM - } - - // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3] - // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3] - // for ic: - // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3] - // exclude known future zero S[..] values from operation - ggml_vec_set_f32(masked_begin, S, 0); - for (int64_t ic = 0; ic < D; ++ic) { - ggml_vec_mad_f32(masked_begin, - S, - (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), - *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3))); - } - - // S = SM * (S - dot(SM, S)) - float dot_SM_gradSM = 0; - ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1); - ggml_vec_acc1_f32(M, S, -dot_SM_gradSM); - ggml_vec_mul_f32 (masked_begin, S, S, SM); - - // S = diag_mask_zero(S, P) * scale - // already done by above ggml_vec_set_f32 - - // exclude known zero S[..] values from operation - ggml_vec_scale_f32(masked_begin, S, scale); - - // S shape [M,1] - // SM shape [M,1] - // kcur shape [D,M] - // qcur shape [D,1] - // vcur shape [M,D] - - // grad[q][:D,iq1,iq2,iq3] += S @ kcur - // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M] - // for ic: - // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3] - // exclude known zero S[..] values from loop - for (int64_t ic = 0; ic < masked_begin; ++ic) { - ggml_vec_mad_f32(D, - (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)), - (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)), - S[ic]); - } - - // grad[k][:D,:M,iq2,iq3] += S.T @ qcur - // for ic: - // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0] - // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0] - // exclude known zero S[..] values from loop - for (int64_t ic = 0; ic < masked_begin; ++ic) { - ggml_vec_mad_f32(D, - (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)), - (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), - S[ic]); - } - - // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM - // for ic: - // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M] - // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M] - // exclude known zero SM[..] values from mad - for (int64_t ic = 0; ic < D; ++ic) { - ggml_vec_mad_f32(masked_begin, - (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)), - SM, - *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3))); - } - } - } - } -} - -static void ggml_compute_forward_flash_attn_back( - const struct ggml_compute_params * params, - const bool masked, - struct ggml_tensor * dst) { - - const struct ggml_tensor * q = dst->src[0]; - - switch (q->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_flash_attn_back_f32(params, masked, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_ssm_conv - -static void ggml_compute_forward_ssm_conv_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; // conv_x - const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight - - const int ith = params->ith; - const int nth = params->nth; - - const int nc = src1->ne[0]; // d_conv - const int ncs = src0->ne[0]; // d_conv - 1 + n_t - const int nr = src0->ne[1]; // d_inner - const int n_t = dst->ne[1]; // tokens per sequence - const int n_s = dst->ne[2]; // number of sequences in the batch - - GGML_ASSERT( dst->ne[0] == nr); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - GGML_ASSERT(src1->nb[0] == sizeof(float)); - GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - const int ir = ir1 - ir0; - - for (int i3 = 0; i3 < n_s; ++i3) { - for (int i2 = 0; i2 < n_t; ++i2) { - // {d_conv - 1 + n_t, d_inner, n_seqs} - // sliding window - const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s} - const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner} - float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s} - - // TODO: transpose the output for smaller strides for big batches? - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // rowwise dot product - // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision - float sumf = 0.0f; - - // d_conv - for (int i0 = 0; i0 < nc; ++i0) { - sumf += s[i0 + i1*ncs] * c[i0 + i1*nc]; - } - x[i1] = sumf; - } - } - } -} - -static void ggml_compute_forward_ssm_conv( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - switch (dst->src[0]->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_ssm_conv_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_ssm_scan - -static void ggml_compute_forward_ssm_scan_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; // s - const struct ggml_tensor * src1 = dst->src[1]; // x - const struct ggml_tensor * src2 = dst->src[2]; // dt - const struct ggml_tensor * src3 = dst->src[3]; // A - const struct ggml_tensor * src4 = dst->src[4]; // B - const struct ggml_tensor * src5 = dst->src[5]; // C - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t nc = src0->ne[0]; // d_state - const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = src1->ne[1]; // number of tokens per sequence - const int64_t n_s = src0->ne[2]; // number of sequences in the batch - - GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - GGML_ASSERT(src1->nb[0] == sizeof(float)); - GGML_ASSERT(src2->nb[0] == sizeof(float)); - GGML_ASSERT(src3->nb[0] == sizeof(float)); - GGML_ASSERT(src4->nb[0] == sizeof(float)); - GGML_ASSERT(src5->nb[0] == sizeof(float)); - // required for the dot product between s and C - GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - // required for per-sequence offsets for states - GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); - // required to get correct offset for state destination (i.e. src1->nb[3]) - GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - const int ir = ir1 - ir0; - - for (int i3 = 0; i3 < n_s; ++i3) { - for (int i2 = 0; i2 < n_t; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} - const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} - const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} - const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} - float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} - - // use the output as the source for the next token-wise iterations - if (i2 > 0) { s0 = s; } - - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 - float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state * C[i0]; - s[i] = state; - } - y[i1] = sumf; - } - } - } -} - -static void ggml_compute_forward_ssm_scan( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - switch (dst->src[0]->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_ssm_scan_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_win_part - -static void ggml_compute_forward_win_part_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - UNUSED(params); - - const struct ggml_tensor * src0 = dst->src[0]; - - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - - const int32_t nep0 = ((const int32_t *)(dst->op_params))[0]; - const int32_t nep1 = ((const int32_t *)(dst->op_params))[1]; - const int32_t w = ((const int32_t *)(dst->op_params))[2]; - - assert(ne00 == ne0); - assert(ne3 == nep0*nep1); - - // TODO: optimize / multi-thread - for (int py = 0; py < nep1; ++py) { - for (int px = 0; px < nep0; ++px) { - const int64_t i3 = py*nep0 + px; - for (int64_t i2 = 0; i2 < ne2; ++i2) { - for (int64_t i1 = 0; i1 < ne1; ++i1) { - for (int64_t i0 = 0; i0 < ne0; ++i0) { - const int64_t i02 = py*w + i2; - const int64_t i01 = px*w + i1; - const int64_t i00 = i0; - - const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0; - const int64_t j = i02*ne01*ne00 + i01*ne00 + i00; - - if (py*w + i2 >= ne02 || px*w + i1 >= ne01) { - ((float *) dst->data)[i] = 0.0f; - } else { - ((float *) dst->data)[i] = ((float *) src0->data)[j]; - } - } - } - } - } - } -} - -static void ggml_compute_forward_win_part( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_win_part_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_win_unpart - -static void ggml_compute_forward_win_unpart_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - UNUSED(params); - - const struct ggml_tensor * src0 = dst->src[0]; - - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - - const int32_t w = ((const int32_t *)(dst->op_params))[0]; - - // padding - const int px = (w - ne1%w)%w; - //const int py = (w - ne2%w)%w; - - const int npx = (px + ne1)/w; - //const int npy = (py + ne2)/w; - - assert(ne0 == ne00); - - // TODO: optimize / multi-thread - for (int64_t i2 = 0; i2 < ne2; ++i2) { - for (int64_t i1 = 0; i1 < ne1; ++i1) { - for (int64_t i0 = 0; i0 < ne0; ++i0) { - const int ip2 = i2/w; - const int ip1 = i1/w; - - const int64_t i02 = i2%w; - const int64_t i01 = i1%w; - const int64_t i00 = i0; - - const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00; - const int64_t j = i2*ne1*ne0 + i1*ne0 + i0; - - ((float *) dst->data)[j] = ((float *) src0->data)[i]; - } - } - } -} - -static void ggml_compute_forward_win_unpart( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_win_unpart_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -//gmml_compute_forward_unary - -static void ggml_compute_forward_unary( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const enum ggml_unary_op op = ggml_get_unary_op(dst); - - switch (op) { - case GGML_UNARY_OP_ABS: - { - ggml_compute_forward_abs(params, dst); - } break; - case GGML_UNARY_OP_SGN: - { - ggml_compute_forward_sgn(params, dst); - } break; - case GGML_UNARY_OP_NEG: - { - ggml_compute_forward_neg(params, dst); - } break; - case GGML_UNARY_OP_STEP: - { - ggml_compute_forward_step(params, dst); - } break; - case GGML_UNARY_OP_TANH: - { - ggml_compute_forward_tanh(params, dst); - } break; - case GGML_UNARY_OP_ELU: - { - ggml_compute_forward_elu(params, dst); - } break; - case GGML_UNARY_OP_RELU: - { - ggml_compute_forward_relu(params, dst); - } break; - case GGML_UNARY_OP_SIGMOID: - { - ggml_compute_forward_sigmoid(params, dst); - } break; - case GGML_UNARY_OP_GELU: - { - ggml_compute_forward_gelu(params, dst); - } break; - case GGML_UNARY_OP_GELU_QUICK: - { - ggml_compute_forward_gelu_quick(params, dst); - } break; - case GGML_UNARY_OP_SILU: - { - ggml_compute_forward_silu(params, dst); - } break; - case GGML_UNARY_OP_HARDSWISH: - { - ggml_compute_forward_hardswish(params, dst); - } break; - case GGML_UNARY_OP_HARDSIGMOID: - { - ggml_compute_forward_hardsigmoid(params, dst); - } break; - case GGML_UNARY_OP_EXP: - { - ggml_compute_forward_exp(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_get_rel_pos - -static void ggml_compute_forward_get_rel_pos_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - UNUSED(params); - - const struct ggml_tensor * src0 = dst->src[0]; - - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322 - - GGML_TENSOR_UNARY_OP_LOCALS - - const int64_t w = ne1; - - ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data; - ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data; - - for (int64_t i2 = 0; i2 < ne2; ++i2) { - for (int64_t i1 = 0; i1 < ne1; ++i1) { - const int64_t pos = (w - i1 - 1) + i2; - for (int64_t i0 = 0; i0 < ne0; ++i0) { - dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0]; - } - } - } -} - -static void ggml_compute_forward_get_rel_pos( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F16: - case GGML_TYPE_BF16: - { - ggml_compute_forward_get_rel_pos_f16(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_add_rel_pos - -static void ggml_compute_forward_add_rel_pos_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - const struct ggml_tensor * src2 = dst->src[2]; - - const bool inplace = (bool) ((int32_t *) dst->op_params)[0]; - if (!inplace) { - if (params->ith == 0) { - memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst)); - } - ggml_barrier(params->threadpool); - } - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359 - - float * src1_data = (float *) src1->data; - float * src2_data = (float *) src2->data; - float * dst_data = (float *) dst->data; - - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - const int64_t ne12 = src1->ne[2]; - const int64_t ne13 = src1->ne[3]; - - const int ith = params->ith; - const int nth = params->nth; - - // total patches in dst - const int np = ne13; - - // patches per thread - const int dp = (np + nth - 1)/nth; - - // patch range for this thread - const int ip0 = dp*ith; - const int ip1 = MIN(ip0 + dp, np); - - for (int64_t i13 = ip0; i13 < ip1; ++i13) { - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = 0; i11 < ne11; ++i11) { - const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10; - for (int64_t i10 = 0; i10 < ne10; ++i10) { - const int64_t jp0 = jp1 + i10; - const float src1_e = src1_data[jp0]; - const float src2_e = src2_data[jp0]; - - const int64_t jdh = jp0 * ne10; - const int64_t jdw = jdh - (ne10 - 1) * i10; - - for (int64_t j = 0; j < ne10; ++j) { - dst_data[jdh + j ] += src2_e; - dst_data[jdw + j*ne10] += src1_e; - } - } - } - } - } -} - -static void ggml_compute_forward_add_rel_pos( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_add_rel_pos_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_rwkv_wkv - -static void ggml_compute_forward_rwkv_wkv_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - const size_t T = dst->src[1]->ne[3]; - const size_t C = dst->ne[0]; - const size_t H = dst->src[1]->ne[2]; - const size_t n_seqs = dst->src[5]->ne[1]; - - float * dst_data = (float *) dst->data; - float * state = ((float *) dst->data) + C * T; - - if (params->ith != 0) { - return; - } - - memset(dst_data, 0, T * C * sizeof(float)); - - float * k = (float *) dst->src[0]->data; - float * v = (float *) dst->src[1]->data; - float * r = (float *) dst->src[2]->data; - float * time_faaaa = (float *) dst->src[3]->data; - float * time_decay = (float *) dst->src[4]->data; - - size_t t_stride = H * (C / H); - - size_t h_stride = C / H; - size_t h_stride_2d = (C / H) * (C / H); - - // basically fused operations: - // dst = r @ (time_faaaa * (k @ v) + state), - // state = time_decay * state + (k @ v), - // recursive through each token - for (size_t t = 0; t < T; t++) { - size_t t_offset = t * t_stride; - size_t state_offset = (C / H) * C * (t / (T / n_seqs)); - float * state_cur = state + state_offset; - float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset; - - for (size_t h = 0; h < H; h++) { - size_t h_offset = h * h_stride; - size_t t_h_offset = t_offset + h_offset; - size_t h_2d_offset = h * h_stride_2d; - - for (size_t i = 0; i < C / H; i++) { - size_t t_h_i_offset = t_h_offset + i; - size_t h_i_offset = h_offset + i; - size_t h_2d_i_offset = h_2d_offset + i * h_stride; - - float k_val = k[t_h_i_offset]; - float r_val = r[t_h_i_offset]; - float time_faaaa_val = time_faaaa[h_i_offset]; - // RWKV v6: different time_decay for each token. - float time_decay_val = time_decay[t_h_i_offset]; - - for (size_t j = 0; j < C / H; j ++) { - size_t t_h_j_offset = t_h_offset + j; - size_t h_2d_i_j_offset = h_2d_i_offset + j; - - float v_val = v[t_h_j_offset]; - float kv_val = v_val * k_val; - float prev_state_val = state_prev[h_2d_i_j_offset]; - float temp_val = kv_val * time_faaaa_val + prev_state_val; - dst_data[t_h_j_offset] += temp_val * r_val; - state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val; - } - } - } - } -} - -static void ggml_compute_forward_rwkv_wkv( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_rwkv_wkv_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_map_unary - -static void ggml_compute_forward_map_unary_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const ggml_unary_op_f32_t fun) { - - const struct ggml_tensor * src0 = dst->src[0]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - fun(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1]))); - } -} - -static void ggml_compute_forward_map_unary( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const ggml_unary_op_f32_t fun) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_map_unary_f32(params, dst, fun); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_map_binary - -static void ggml_compute_forward_map_binary_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const ggml_binary_op_f32_t fun) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - if (params->ith != 0) { - return; - } - - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(src1)); - assert(ggml_is_contiguous_1(dst)); - assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - const int n = ggml_nrows(src0); - const int nc = src0->ne[0]; - - for (int i = 0; i < n; i++) { - fun(nc, - (float *) ((char *) dst->data + i*( dst->nb[1])), - (float *) ((char *) src0->data + i*(src0->nb[1])), - (float *) ((char *) src1->data + i*(src1->nb[1]))); - } -} - -static void ggml_compute_forward_map_binary( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const ggml_binary_op_f32_t fun) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_map_binary_f32(params, dst, fun); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_map_custom1 - -static void ggml_compute_forward_map_custom1_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const ggml_custom1_op_f32_t fun) { - - const struct ggml_tensor * a = dst->src[0]; - - if (params->ith != 0) { - return; - } - - fun(dst, a); -} - -// ggml_compute_forward_map_custom2 - -static void ggml_compute_forward_map_custom2_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const ggml_custom2_op_f32_t fun) { - - const struct ggml_tensor * a = dst->src[0]; - const struct ggml_tensor * b = dst->src[1]; - - if (params->ith != 0) { - return; - } - - fun(dst, a, b); -} - -// ggml_compute_forward_map_custom3 - -static void ggml_compute_forward_map_custom3_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - const ggml_custom3_op_f32_t fun) { - - const struct ggml_tensor * a = dst->src[0]; - const struct ggml_tensor * b = dst->src[1]; - const struct ggml_tensor * c = dst->src[1]; - - if (params->ith != 0) { - return; - } - - fun(dst, a, b, c); -} - -// ggml_compute_forward_map_custom1 - -static void ggml_compute_forward_map_custom1( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * a = dst->src[0]; - - struct ggml_map_custom1_op_params p; - memcpy(&p, dst->op_params, sizeof(p)); - - p.fun(dst, a, params->ith, params->nth, p.userdata); -} - -// ggml_compute_forward_map_custom2 - -static void ggml_compute_forward_map_custom2( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * a = dst->src[0]; - const struct ggml_tensor * b = dst->src[1]; - - struct ggml_map_custom2_op_params p; - memcpy(&p, dst->op_params, sizeof(p)); - - p.fun(dst, a, b, params->ith, params->nth, p.userdata); -} - -// ggml_compute_forward_map_custom3 - -static void ggml_compute_forward_map_custom3( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * a = dst->src[0]; - const struct ggml_tensor * b = dst->src[1]; - const struct ggml_tensor * c = dst->src[2]; - - struct ggml_map_custom3_op_params p; - memcpy(&p, dst->op_params, sizeof(p)); - - p.fun(dst, a, b, c, params->ith, params->nth, p.userdata); -} - -// ggml_compute_forward_cross_entropy_loss - -static void ggml_compute_forward_cross_entropy_loss_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(ggml_is_scalar(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, src1)); - - const int ith = params->ith; - const int nth = params->nth; - - float * sums = (float *) params->wdata; - - // TODO: handle transposed/permuted matrices - const int nc = src0->ne[0]; - const int nr = ggml_nrows(src0); - - GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc)); - - if (ith == 0) { - memset(sums, 0, sizeof(float) * (nth + nth * nc)); - } - ggml_barrier(params->threadpool); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int i1 = ir0; i1 < ir1; i1++) { - float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); - float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); - float * st = ((float *) params->wdata) + nth + ith*nc; - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(s0[i])); - assert(!isnan(s1[i])); - } -#endif - - float max = -INFINITY; - ggml_vec_max_f32(nc, &max, s0); - ggml_float sum = ggml_vec_log_soft_max_f32(nc, st, s0, max); - assert(sum >= 0.0); - - ggml_vec_add1_f32(nc, st, st, -sum); - ggml_vec_mul_f32(nc, st, st, s1); - - float st_sum = 0.0f; - ggml_vec_sum_f32(nc, &st_sum, st); - sums[ith] += st_sum; - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - assert(!isnan(st[i])); - assert(!isinf(st[i])); - } -#endif - } - ggml_barrier(params->threadpool); - - if (ith == 0) { - float * dp = (float *) dst->data; - ggml_vec_sum_f32(nth, dp, sums); - dp[0] *= -1.0f / (float) nr; - } -} - -static void ggml_compute_forward_cross_entropy_loss( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_cross_entropy_loss_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -// ggml_compute_forward_cross_entropy_loss_back - -static void ggml_compute_forward_cross_entropy_loss_back_f32( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - const struct ggml_tensor * opt0 = dst->src[2]; - - GGML_ASSERT(ggml_is_contiguous(dst)); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - GGML_ASSERT(ggml_is_contiguous(opt0)); - GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - const int64_t ith = params->ith; - const int64_t nth = params->nth; - - // TODO: handle transposed/permuted matrices - const int64_t nc = src0->ne[0]; - const int64_t nr = ggml_nrows(src0); - - // rows per thread - const int64_t dr = (nr + nth - 1)/nth; - - // row range for this thread - const int64_t ir0 = dr*ith; - const int64_t ir1 = MIN(ir0 + dr, nr); - - float * d = (float *) opt0->data; - - for (int64_t i1 = ir0; i1 < ir1; i1++) { - float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); - float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]); - float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]); - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - //printf("p[%d] = %f\n", i, p[i]); - assert(!isnan(s0[i])); - assert(!isnan(s1[i])); - } -#endif - - // soft_max - float max = -INFINITY; - ggml_vec_max_f32(nc, &max, s0); - ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max); - assert(sum > 0.0); - ggml_vec_scale_f32(nc, ds0, 1.0/sum); - - // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr - ggml_vec_sub_f32(nc, ds0, ds0, s1); - ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr); - -#ifndef NDEBUG - for (int i = 0; i < nc; ++i) { - assert(!isnan(ds0[i])); - assert(!isinf(ds0[i])); - } -#endif - } -} - -static void ggml_compute_forward_cross_entropy_loss_back( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * src0 = dst->src[0]; - - switch (src0->type) { - case GGML_TYPE_F32: - { - ggml_compute_forward_cross_entropy_loss_back_f32(params, dst); - } break; - default: - { - GGML_ABORT("fatal error"); - } - } -} - -///////////////////////////////// - -static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { - GGML_ASSERT(params); - - if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) { - return; - } - - switch (tensor->op) { - case GGML_OP_DUP: - { - ggml_compute_forward_dup(params, tensor); - } break; - case GGML_OP_ADD: - { - ggml_compute_forward_add(params, tensor); - } break; - case GGML_OP_ADD1: - { - ggml_compute_forward_add1(params, tensor); - } break; - case GGML_OP_ACC: - { - ggml_compute_forward_acc(params, tensor); - } break; - case GGML_OP_SUB: - { - ggml_compute_forward_sub(params, tensor); - } break; - case GGML_OP_MUL: - { - ggml_compute_forward_mul(params, tensor); - } break; - case GGML_OP_DIV: - { - ggml_compute_forward_div(params, tensor); - } break; - case GGML_OP_SQR: - { - ggml_compute_forward_sqr(params, tensor); - } break; - case GGML_OP_SQRT: - { - ggml_compute_forward_sqrt(params, tensor); - } break; - case GGML_OP_LOG: - { - ggml_compute_forward_log(params, tensor); - } break; - case GGML_OP_SIN: - { - ggml_compute_forward_sin(params, tensor); - } break; - case GGML_OP_COS: - { - ggml_compute_forward_cos(params, tensor); - } break; - case GGML_OP_SUM: - { - ggml_compute_forward_sum(params, tensor); - } break; - case GGML_OP_SUM_ROWS: - { - ggml_compute_forward_sum_rows(params, tensor); - } break; - case GGML_OP_MEAN: - { - ggml_compute_forward_mean(params, tensor); - } break; - case GGML_OP_ARGMAX: - { - ggml_compute_forward_argmax(params, tensor); - } break; - case GGML_OP_REPEAT: - { - ggml_compute_forward_repeat(params, tensor); - } break; - case GGML_OP_REPEAT_BACK: - { - ggml_compute_forward_repeat_back(params, tensor); - } break; - case GGML_OP_CONCAT: - { - ggml_compute_forward_concat(params, tensor); - } break; - case GGML_OP_SILU_BACK: - { - ggml_compute_forward_silu_back(params, tensor); - } break; - case GGML_OP_NORM: - { - ggml_compute_forward_norm(params, tensor); - } break; - case GGML_OP_RMS_NORM: - { - ggml_compute_forward_rms_norm(params, tensor); - } break; - case GGML_OP_RMS_NORM_BACK: - { - ggml_compute_forward_rms_norm_back(params, tensor); - } break; - case GGML_OP_GROUP_NORM: - { - ggml_compute_forward_group_norm(params, tensor); - } break; - case GGML_OP_MUL_MAT: - { - ggml_compute_forward_mul_mat(params, tensor); - } break; - case GGML_OP_MUL_MAT_ID: - { - ggml_compute_forward_mul_mat_id(params, tensor); - } break; - case GGML_OP_OUT_PROD: - { - ggml_compute_forward_out_prod(params, tensor); - } break; - case GGML_OP_SCALE: - { - ggml_compute_forward_scale(params, tensor); - } break; - case GGML_OP_SET: - { - ggml_compute_forward_set(params, tensor); - } break; - case GGML_OP_CPY: - { - ggml_compute_forward_cpy(params, tensor); - } break; - case GGML_OP_CONT: - { - ggml_compute_forward_cont(params, tensor); - } break; - case GGML_OP_RESHAPE: - { - ggml_compute_forward_reshape(params, tensor); - } break; - case GGML_OP_VIEW: - { - ggml_compute_forward_view(params, tensor); - } break; - case GGML_OP_PERMUTE: - { - ggml_compute_forward_permute(params, tensor); - } break; - case GGML_OP_TRANSPOSE: - { - ggml_compute_forward_transpose(params, tensor); - } break; - case GGML_OP_GET_ROWS: - { - ggml_compute_forward_get_rows(params, tensor); - } break; - case GGML_OP_GET_ROWS_BACK: - { - ggml_compute_forward_get_rows_back(params, tensor); - } break; - case GGML_OP_DIAG: - { - ggml_compute_forward_diag(params, tensor); - } break; - case GGML_OP_DIAG_MASK_INF: - { - ggml_compute_forward_diag_mask_inf(params, tensor); - } break; - case GGML_OP_DIAG_MASK_ZERO: - { - ggml_compute_forward_diag_mask_zero(params, tensor); - } break; - case GGML_OP_SOFT_MAX: - { - ggml_compute_forward_soft_max(params, tensor); - } break; - case GGML_OP_SOFT_MAX_BACK: - { - ggml_compute_forward_soft_max_back(params, tensor); - } break; - case GGML_OP_ROPE: - { - ggml_compute_forward_rope(params, tensor); - } break; - case GGML_OP_ROPE_BACK: - { - ggml_compute_forward_rope_back(params, tensor); - } break; - case GGML_OP_CLAMP: - { - ggml_compute_forward_clamp(params, tensor); - } break; - case GGML_OP_CONV_TRANSPOSE_1D: - { - ggml_compute_forward_conv_transpose_1d(params, tensor); - } break; - case GGML_OP_IM2COL: - { - ggml_compute_forward_im2col(params, tensor); - } break; - case GGML_OP_IM2COL_BACK: - { - ggml_compute_forward_im2col_back_f32(params, tensor); - } break; - case GGML_OP_CONV_TRANSPOSE_2D: - { - ggml_compute_forward_conv_transpose_2d(params, tensor); - } break; - case GGML_OP_POOL_1D: - { - ggml_compute_forward_pool_1d(params, tensor); - } break; - case GGML_OP_POOL_2D: - { - ggml_compute_forward_pool_2d(params, tensor); - } break; - case GGML_OP_POOL_2D_BACK: - { - ggml_compute_forward_pool_2d_back(params, tensor); - } break; - case GGML_OP_UPSCALE: - { - ggml_compute_forward_upscale(params, tensor); - } break; - case GGML_OP_PAD: - { - ggml_compute_forward_pad(params, tensor); - } break; - case GGML_OP_ARANGE: - { - ggml_compute_forward_arange(params, tensor); - } break; - case GGML_OP_TIMESTEP_EMBEDDING: - { - ggml_compute_forward_timestep_embedding(params, tensor); - } break; - case GGML_OP_ARGSORT: - { - ggml_compute_forward_argsort(params, tensor); - } break; - case GGML_OP_LEAKY_RELU: - { - ggml_compute_forward_leaky_relu(params, tensor); - } break; - case GGML_OP_FLASH_ATTN_EXT: - { - ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); - } break; - case GGML_OP_FLASH_ATTN_BACK: - { - int32_t t = ggml_get_op_params_i32(tensor, 0); - GGML_ASSERT(t == 0 || t == 1); - bool masked = t != 0; - ggml_compute_forward_flash_attn_back(params, masked, tensor); - } break; - case GGML_OP_SSM_CONV: - { - ggml_compute_forward_ssm_conv(params, tensor); - } break; - case GGML_OP_SSM_SCAN: - { - ggml_compute_forward_ssm_scan(params, tensor); - } break; - case GGML_OP_WIN_PART: - { - ggml_compute_forward_win_part(params, tensor); - } break; - case GGML_OP_WIN_UNPART: - { - ggml_compute_forward_win_unpart(params, tensor); - } break; - case GGML_OP_UNARY: - { - ggml_compute_forward_unary(params, tensor); - } break; - case GGML_OP_GET_REL_POS: - { - ggml_compute_forward_get_rel_pos(params, tensor); - } break; - case GGML_OP_ADD_REL_POS: - { - ggml_compute_forward_add_rel_pos(params, tensor); - } break; - case GGML_OP_RWKV_WKV: - { - ggml_compute_forward_rwkv_wkv(params, tensor); - } break; - case GGML_OP_MAP_UNARY: - { - ggml_unary_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - ggml_compute_forward_map_unary(params, tensor, fun); - } - break; - case GGML_OP_MAP_BINARY: - { - ggml_binary_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - ggml_compute_forward_map_binary(params, tensor, fun); - } - break; - case GGML_OP_MAP_CUSTOM1_F32: - { - ggml_custom1_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - ggml_compute_forward_map_custom1_f32(params, tensor, fun); - } - break; - case GGML_OP_MAP_CUSTOM2_F32: - { - ggml_custom2_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - ggml_compute_forward_map_custom2_f32(params, tensor, fun); - } - break; - case GGML_OP_MAP_CUSTOM3_F32: - { - ggml_custom3_op_f32_t fun; - memcpy(&fun, tensor->op_params, sizeof(fun)); - ggml_compute_forward_map_custom3_f32(params, tensor, fun); - } - break; - case GGML_OP_MAP_CUSTOM1: - { - ggml_compute_forward_map_custom1(params, tensor); - } - break; - case GGML_OP_MAP_CUSTOM2: - { - ggml_compute_forward_map_custom2(params, tensor); - } - break; - case GGML_OP_MAP_CUSTOM3: - { - ggml_compute_forward_map_custom3(params, tensor); - } - break; - case GGML_OP_CROSS_ENTROPY_LOSS: - { - ggml_compute_forward_cross_entropy_loss(params, tensor); - } - break; - case GGML_OP_CROSS_ENTROPY_LOSS_BACK: - { - ggml_compute_forward_cross_entropy_loss_back(params, tensor); - } - break; - case GGML_OP_NONE: - { - // nop - } break; - case GGML_OP_COUNT: - { - GGML_ABORT("fatal error"); - } - } + struct ggml_tensor * a, + struct ggml_tensor * grad, + struct ggml_tensor * m, + struct ggml_tensor * v, + struct ggml_tensor * adamw_params) { + GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM); + GGML_ASSERT(ggml_are_same_shape(a, grad)); + GGML_ASSERT(ggml_are_same_shape(a, m)); + GGML_ASSERT(ggml_are_same_shape(a, v)); + GGML_ASSERT(adamw_params->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_nelements(adamw_params) == 7); + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + result->op = GGML_OP_OPT_STEP_ADAMW; + result->src[0] = a; + result->src[1] = grad; + result->src[2] = m; + result->src[3] = v; + result->src[4] = adamw_params; + + return result; } //////////////////////////////////////////////////////////////////////////////// @@ -17864,1053 +5194,543 @@ static void ggml_hash_map_free(struct hash_map * map) { GGML_FREE(map); } -// gradient checkpointing +// utility functions to change gradients +// isrc is the index of tensor in cgraph->visited_has_set.keys +// the corresponding gradient (accumulators) are also at position isrc +// if tensor has a gradient accumulator, modify that accumulator in-place +// else if there is no gradient for tensor, set the corresponding value +// else, just add/subtract/etc. the gradients -static struct ggml_tensor * ggml_recompute_graph_node( +static void ggml_add_or_set( struct ggml_context * ctx, - struct ggml_cgraph * graph, - struct hash_map * replacements, - struct ggml_tensor * node) { - - if (node == NULL) { - return NULL; + struct ggml_cgraph * cgraph, + size_t isrc, + struct ggml_tensor * tensor) { + struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc]; + GGML_ASSERT(src); + if (cgraph->grads[isrc]) { + cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, /*inplace =*/ cgraph->grad_accs[isrc]); + } else { + cgraph->grads[isrc] = tensor; } - - if (node->flags & GGML_TENSOR_FLAG_PARAM) { - return node; - } - - if (!ggml_hash_contains(&graph->visited_hash_set, node)) { - return node; - } - - int count_children = 0; - for (int k = 0; k < GGML_MAX_SRC; ++k) { - if (node->src[k]) { - ++count_children; - } - } - - if (count_children == 0) { - return node; - } - - size_t i = ggml_hash_find(&replacements->set, node); - GGML_ASSERT(i != GGML_HASHSET_FULL); // assert that not full - if (replacements->set.keys[i] == node) { - return replacements->vals[i]; - } - - struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, GGML_MAX_DIMS, node->ne); - - // insert clone into replacements - GGML_ASSERT(replacements->set.keys[i] == NULL); // assert that we don't overwrite - replacements->set.keys[i] = node; - replacements->vals[i] = clone; - - clone->op = node->op; - clone->grad = node->grad; - clone->flags = node->flags; - clone->extra = node->extra; - for (int k = 0; k < GGML_MAX_DIMS; ++k) { - clone->nb[k] = node->nb[k]; - } - for (int k = 0; k < GGML_MAX_SRC; ++k) { - clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]); - } - if (node->view_src != NULL) { - clone->data = (node->view_src->data == NULL) - ? NULL // view_src not yet allocated - : (char *) node->view_src->data // view_src already allocated - + node->view_offs; - clone->view_src = node->view_src; - clone->view_offs = node->view_offs; - } - - GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t))); - GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME); - memcpy(clone->op_params, node->op_params, sizeof(node->op_params)); - ggml_format_name(clone, "%s (clone)", ggml_get_name(node)); - - return clone; + ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name); + ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); } -void ggml_build_backward_gradient_checkpointing( - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - struct ggml_cgraph * gb_tmp, - struct ggml_tensor * * checkpoints, - int n_checkpoints) { - ggml_graph_cpy(gf, gb_tmp); - ggml_build_backward_expand(ctx, gf, gb_tmp, true); +static void ggml_acc_or_set( + struct ggml_context * ctx, + struct ggml_cgraph * cgraph, + size_t isrc, + struct ggml_tensor * tensor, + const size_t nb1, + const size_t nb2, + const size_t nb3, + const size_t offset) { + struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc]; + GGML_ASSERT(src); + if (cgraph->grads[isrc]) { + cgraph->grads[isrc] = ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]); + } else { + struct ggml_tensor * a_zero = ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN + cgraph->grads[isrc] = ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false); + } + ggml_format_name(cgraph->grads[isrc], "grad for %s", cgraph->visited_hash_set.keys[isrc]->name); + ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); +} - if (n_checkpoints <= 0) { - ggml_graph_cpy(gb_tmp, gb); +static void ggml_add1_or_set( + struct ggml_context * ctx, + struct ggml_cgraph * cgraph, + size_t isrc, + struct ggml_tensor * tensor) { + struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc]; + GGML_ASSERT(src); + if (cgraph->grads[isrc]) { + cgraph->grads[isrc] = ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]); + } else { + cgraph->grads[isrc] = ggml_repeat(ctx, tensor, src); + } + ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name); + ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); +} + +static void ggml_sub_or_set( + struct ggml_context * ctx, + struct ggml_cgraph * cgraph, + size_t isrc, + struct ggml_tensor * tensor) { + struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc]; + GGML_ASSERT(src); + if (cgraph->grads[isrc]) { + cgraph->grads[isrc] = ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]); + } else { + cgraph->grads[isrc] = ggml_neg(ctx, tensor); + } + ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name); + ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); +} + +static void ggml_compute_backward( + struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) { + struct ggml_tensor * tensor = cgraph->nodes[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor); + + if (!grad) { return; } - struct hash_map * replacements = ggml_new_hash_map(gf->n_nodes + gf->n_leafs + n_checkpoints); - - // insert checkpoints in replacements - for (int i = 0; i < n_checkpoints; ++i) { - size_t k = ggml_hash_find(&replacements->set, checkpoints[i]); - GGML_ASSERT(k != GGML_HASHSET_FULL); // assert that not full - GGML_ASSERT(replacements->set.keys[k] == NULL); // assert that we don't overwrite - replacements->set.keys[k] = checkpoints[i]; - replacements->vals[k] = checkpoints[i]; - } - - ggml_graph_cpy(gf, gb); - // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes], - // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]), - // by recomputing them from checkpoints - for (int i = gf->n_nodes; in_nodes; ++i) { - struct ggml_tensor * node = gb_tmp->nodes[i]; - for (int k = 0; k < GGML_MAX_SRC; ++k) { - // insert new tensors recomputing src, reusing already made replacements, - // remember replacements: remember new tensors with mapping from corresponding gf nodes - // recurse for input tensors, - // unless (i.e. terminating when) input tensors are replacements (like checkpoints) - node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]); - } - // insert rewritten backward node with replacements made into resulting backward graph gb - ggml_build_forward_expand(gb, node); - } - - ggml_hash_map_free(replacements); -} - -// functions to change gradients considering the case that input a might be initial gradient with zero value - -static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) { - if (ggml_hash_contains(zero_table, a)) { - return b; - } else { - return ggml_add_impl(ctx, a, b, false); - } -} - -static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set * zero_table) { - if (ggml_hash_contains(zero_table, a)) { - struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f); - return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false); - } else { - return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); - } -} - -static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) { - if (ggml_hash_contains(zero_table, a)) { - return ggml_repeat(ctx, b, a); - } else { - return ggml_add1_impl(ctx, a, b, false); - } -} - -static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) { - if (ggml_hash_contains(zero_table, a)) { - return ggml_neg(ctx, b); - } else { - return ggml_sub_impl(ctx, a, b, false); - } -} - -static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table) { struct ggml_tensor * src0 = tensor->src[0]; struct ggml_tensor * src1 = tensor->src[1]; struct ggml_tensor * src2 = tensor->src[2]; + struct ggml_hash_set * hash_set = &cgraph->visited_hash_set; + const size_t isrc0 = src0 ? ggml_hash_find(hash_set, src0) : (size_t) -1; + const size_t isrc1 = src1 ? ggml_hash_find(hash_set, src1) : (size_t) -1; + const size_t isrc2 = src2 ? ggml_hash_find(hash_set, src2) : (size_t) -1; + const bool src0_needs_grads = src0 && isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0]; + const bool src1_needs_grads = src1 && isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1]; + const bool src2_needs_grads = src2 && isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2]; switch (tensor->op) { - case GGML_OP_DUP: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); - } - } break; - case GGML_OP_ADD: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); - } - if (src1->grad) { - if (ggml_are_same_shape(src0, src1)) { - src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table); - } else { - src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table); - } - } - } break; - case GGML_OP_ADD1: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); - } - if (src1->grad) { - src1->grad = ggml_add_or_set(ctx, - src1->grad, - ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean - zero_table); - } - } break; - case GGML_OP_ACC: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); - } - if (src1->grad) { - const size_t nb1 = ((int32_t *) tensor->op_params)[0]; - const size_t nb2 = ((int32_t *) tensor->op_params)[1]; - const size_t nb3 = ((int32_t *) tensor->op_params)[2]; - const size_t offset = ((int32_t *) tensor->op_params)[3]; - - struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx, - tensor->grad, - src1->grad->ne[0], - src1->grad->ne[1], - src1->grad->ne[2], - src1->grad->ne[3], - nb1, nb2, nb3, offset); - - src1->grad = - ggml_add_or_set(ctx, - src1->grad, - ggml_reshape(ctx, - ggml_cont(ctx, tensor_grad_view), - src1->grad), - zero_table); - } - } break; - case GGML_OP_SUB: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); - } - if (src1->grad) { - src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table); - } - } break; - case GGML_OP_MUL: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_mul(ctx, src1, tensor->grad), - zero_table); - } - if (src1->grad) { - src1->grad = - ggml_add_or_set(ctx, - src1->grad, - ggml_mul(ctx, src0, tensor->grad), - zero_table); - } - } break; - case GGML_OP_DIV: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_div(ctx, tensor->grad, src1), - zero_table); - } - if (src1->grad) { - src1->grad = - ggml_sub_or_set(ctx, - src1->grad, - ggml_mul(ctx, - tensor->grad, - ggml_div(ctx, tensor, src1)), - zero_table); - } - } break; - case GGML_OP_SQR: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_scale(ctx, - ggml_mul(ctx, src0, tensor->grad), - 2.0f), - zero_table); - } - } break; - case GGML_OP_SQRT: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_scale(ctx, - ggml_div(ctx, - tensor->grad, - tensor), - 0.5f), - zero_table); - } - } break; - case GGML_OP_LOG: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_div(ctx, - tensor->grad, - src0), - zero_table); - } - } break; - case GGML_OP_SIN: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_mul(ctx, - tensor->grad, - ggml_cos(ctx, src0)), - zero_table); - } - } break; - case GGML_OP_COS: - { - if (src0->grad) { - src0->grad = - ggml_sub_or_set(ctx, - src0->grad, - ggml_mul(ctx, - tensor->grad, - ggml_sin(ctx, src0)), - zero_table); - } - } break; - case GGML_OP_SUM: - { - if (src0->grad) { - src0->grad = - ggml_add1_or_set(ctx, - src0->grad, - tensor->grad, - zero_table); - } - } break; - case GGML_OP_SUM_ROWS: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_repeat(ctx, - tensor->grad, - src0->grad), - zero_table); - } - } break; - case GGML_OP_MEAN: - case GGML_OP_ARGMAX: - { - GGML_ABORT("fatal error"); // TODO: implement + case GGML_OP_DUP: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case GGML_OP_REPEAT: - { - // necessary for llama - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_repeat_back(ctx, tensor->grad, src0->grad), - zero_table); - } - } break; - case GGML_OP_REPEAT_BACK: - { - if (src0->grad) { - // TODO: test this - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_repeat(ctx, tensor->grad, src0->grad), - zero_table); - } - } break; - case GGML_OP_CONCAT: - { - GGML_ABORT("fatal error"); // TODO: implement + } break; + case GGML_OP_ADD: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case GGML_OP_SILU_BACK: - { - GGML_ABORT("fatal error"); // TODO: not implemented + if (src1_needs_grads) { + struct ggml_tensor * tmp = grad; + if (!ggml_are_same_shape(src0, src1)) { + tmp = ggml_repeat_back(ctx, tmp, src1); + } + ggml_add_or_set(ctx, cgraph, isrc1, tmp); } - case GGML_OP_NORM: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_ADD1: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case GGML_OP_RMS_NORM: - { - // necessary for llama - if (src0->grad) { - float eps; - memcpy(&eps, tensor->op_params, sizeof(float)); - - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_rms_norm_back(ctx, src0, tensor->grad, eps), - zero_table); - } - } break; - case GGML_OP_RMS_NORM_BACK: - { - GGML_ABORT("fatal error"); // TODO: not implemented + if (src1_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc1, ggml_mean(ctx, grad)); // TODO: should probably be sum instead of mean } - case GGML_OP_GROUP_NORM: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_ACC: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case GGML_OP_MUL_MAT: - { - // https://cs231n.github.io/optimization-2/#staged - // # forward pass - // s0 = np.random.randn(5, 10) - // s1 = np.random.randn(10, 3) - // t = s0.dot(s1) + if (src1_needs_grads) { + const size_t nb1 = ((int32_t *) tensor->op_params)[0]; + const size_t nb2 = ((int32_t *) tensor->op_params)[1]; + const size_t nb3 = ((int32_t *) tensor->op_params)[2]; + const size_t offset = ((int32_t *) tensor->op_params)[3]; - // # now suppose we had the gradient on t from above in the circuit - // dt = np.random.randn(*t.shape) # same shape as t - // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix - // ds1 = t.T.dot(dt) + struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx, + grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + nb1, nb2, nb3, offset); - // tensor.shape [m,p,qq,rr] - // src0.shape [n,m,q1,r1] - // src1.shape [n,p,qq,rr] - - // necessary for llama - if (src0->grad) { - struct ggml_tensor * s1_tg = - ggml_out_prod(ctx, // [n,m,qq,rr] - src1, // [n,p,qq,rr] - tensor->grad); // [m,p,qq,rr] - const int64_t qq = s1_tg->ne[2]; - const int64_t rr = s1_tg->ne[3]; - const int64_t q1 = src0->ne[2]; - const int64_t r1 = src0->ne[3]; - const bool ne2_broadcasted = qq > q1; - const bool ne3_broadcasted = rr > r1; - if (ne2_broadcasted || ne3_broadcasted) { - // sum broadcast repetitions of s1_tg into shape of src0 - s1_tg = ggml_repeat_back(ctx, s1_tg, src0); - } - src0->grad = - ggml_add_or_set(ctx, - src0->grad, // [n,m,q1,r1] - s1_tg, // [n,m,q1,r1] - zero_table); - } - if (src1->grad) { - src1->grad = - ggml_add_or_set(ctx, - src1->grad, // [n,p,qq,rr] - // ggml_mul_mat(ctx, // [n,p,qq,rr] - // ggml_cont(ctx, // [m,n,q1,r1] - // ggml_transpose(ctx, src0)), // [m,n,q1,r1] - // tensor->grad), // [m,p,qq,rr] - - // // when src0 is bigger than tensor->grad (this is mostly the case in llama), - // // avoid transpose of src0, rather transpose smaller tensor->grad - // // and then use ggml_out_prod - ggml_out_prod(ctx, // [n,p,qq,rr] - src0, // [n,m,q1,r1] - ggml_transpose(ctx, // [p,m,qq,rr] - tensor->grad)), // [m,p,qq,rr] - zero_table); - } - } break; - case GGML_OP_MUL_MAT_ID: - { - GGML_ABORT("fatal error"); // TODO: not implemented + ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1)); } - case GGML_OP_OUT_PROD: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_SUB: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case GGML_OP_SCALE: - { - // necessary for llama - if (src0->grad) { - float s; - memcpy(&s, tensor->op_params, sizeof(float)); - - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_scale_impl(ctx, tensor->grad, s, false), - zero_table); - } - } break; - case GGML_OP_SET: - { - const size_t nb1 = ((int32_t *) tensor->op_params)[0]; - const size_t nb2 = ((int32_t *) tensor->op_params)[1]; - const size_t nb3 = ((int32_t *) tensor->op_params)[2]; - const size_t offset = ((int32_t *) tensor->op_params)[3]; - - struct ggml_tensor * tensor_grad_view = NULL; - - if (src0->grad || src1->grad) { - GGML_ASSERT(src0->type == tensor->type); - GGML_ASSERT(tensor->grad->type == tensor->type); - GGML_ASSERT(!src1->grad || src1->grad->type == tensor->grad->type); - - tensor_grad_view = ggml_view_4d(ctx, - tensor->grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], - nb1, nb2, nb3, offset); - } - - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_acc_impl(ctx, - tensor->grad, - ggml_neg(ctx, tensor_grad_view), - nb1, nb2, nb3, offset, false), - zero_table); - } - - if (src1->grad) { - src1->grad = - ggml_add_or_set(ctx, - src1->grad, - ggml_reshape(ctx, - ggml_cont(ctx, tensor_grad_view), - src1->grad), - zero_table); - } - } break; - case GGML_OP_CPY: - { - // necessary for llama - // cpy overwrites value of src1 by src0 and returns view(src1) - // the overwriting is mathematically equivalent to: - // tensor = src0 * 1 + src1 * 0 - if (src0->grad) { - // dsrc0 = dtensor * 1 - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); - } - if (src1->grad) { - // dsrc1 = dtensor * 0 -> noop - } - } break; - case GGML_OP_CONT: - { - // same as cpy - if (src0->grad) { - GGML_ASSERT(ggml_is_contiguous(src0->grad)); - GGML_ASSERT(ggml_is_contiguous(tensor->grad)); - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); - } - } break; - case GGML_OP_RESHAPE: - { - // necessary for llama - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_reshape(ctx, - ggml_is_contiguous(tensor->grad) - ? tensor->grad - : ggml_cont(ctx, tensor->grad), - src0->grad), - zero_table); - } - } break; - case GGML_OP_VIEW: - { - // necessary for llama - if (src0->grad) { - size_t offset; - - memcpy(&offset, tensor->op_params, sizeof(offset)); - - size_t nb1 = tensor->nb[1]; - size_t nb2 = tensor->nb[2]; - size_t nb3 = tensor->nb[3]; - - if (src0->type != src0->grad->type) { - // gradient is typically F32, but src0 could be other type - size_t ng = ggml_element_size(src0->grad); - size_t n0 = ggml_element_size(src0); - GGML_ASSERT(offset % n0 == 0); - GGML_ASSERT(nb1 % n0 == 0); - GGML_ASSERT(nb2 % n0 == 0); - GGML_ASSERT(nb3 % n0 == 0); - offset = (offset / n0) * ng; - nb1 = (nb1 / n0) * ng; - nb2 = (nb2 / n0) * ng; - nb3 = (nb3 / n0) * ng; - } - - src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table); - } - } break; - case GGML_OP_PERMUTE: - { - // necessary for llama - if (src0->grad) { - int32_t * axes = (int32_t *) tensor->op_params; - int axis0 = axes[0] & 0x3; - int axis1 = axes[1] & 0x3; - int axis2 = axes[2] & 0x3; - int axis3 = axes[3] & 0x3; - int axes_backward[4] = {0,0,0,0}; - axes_backward[axis0] = 0; - axes_backward[axis1] = 1; - axes_backward[axis2] = 2; - axes_backward[axis3] = 3; - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_permute(ctx, - tensor->grad, - axes_backward[0], - axes_backward[1], - axes_backward[2], - axes_backward[3]), - zero_table); - } - } break; - case GGML_OP_TRANSPOSE: - { - // necessary for llama - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_transpose(ctx, tensor->grad), - zero_table); - } - } break; - case GGML_OP_GET_ROWS: - { - // necessary for llama (only for tokenizer) - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, src0->grad, - // last ggml_get_rows_back argument src0->grad is only - // necessary to setup correct output shape - ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad), - zero_table); - } - if (src1->grad) { - // noop - } - } break; - case GGML_OP_GET_ROWS_BACK: - { - GGML_ABORT("fatal error"); // TODO: not implemented + if (src1_needs_grads) { + ggml_sub_or_set(ctx, cgraph, isrc1, grad); } - case GGML_OP_DIAG: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_MUL: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1)); } - case GGML_OP_DIAG_MASK_INF: - { - // necessary for llama - if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; - src0->grad = - ggml_add_or_set(ctx, src0->grad, - /* ggml_diag_mask_inf_impl() shouldn't be here */ - /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */ - ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - zero_table); + if (src1_needs_grads) { + struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad); + if (!ggml_are_same_shape(src0, src1)) { + tmp = ggml_repeat_back(ctx, tmp, src1); } - } break; - case GGML_OP_DIAG_MASK_ZERO: - { - // necessary for llama - if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - zero_table); + ggml_add_or_set(ctx, cgraph, isrc1, tmp); + } + } break; + case GGML_OP_DIV: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src1)); + } + if (src1_needs_grads) { + ggml_sub_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, grad, ggml_div(ctx, tensor, src1))); + } + } break; + case GGML_OP_SQR: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_mul(ctx, src0, grad), 2.0f)); + } + } break; + case GGML_OP_SQRT: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_div(ctx, grad, tensor), 0.5f)); + } + } break; + case GGML_OP_LOG: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src0)); + } + } break; + case GGML_OP_SIN: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_cos(ctx, src0))); + } + } break; + case GGML_OP_COS: { + if (src0_needs_grads) { + ggml_sub_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_sin(ctx, src0))); + } + } break; + case GGML_OP_SUM: { + if (src0_needs_grads) { + ggml_add1_or_set(ctx, cgraph, isrc0, grad); + } + } break; + case GGML_OP_SUM_ROWS: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0)); + } + } break; + case GGML_OP_MEAN: { + if (src0_needs_grads) { + ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false)); + } + } break; + case GGML_OP_REPEAT: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat_back(ctx, grad, src0)); + } + } break; + case GGML_OP_REPEAT_BACK: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0)); + } + } break; + case GGML_OP_RMS_NORM: { + if (src0_needs_grads) { + float eps; + memcpy(&eps, tensor->op_params, sizeof(float)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps)); + } + } break; + case GGML_OP_MUL_MAT: { + // https://cs231n.github.io/optimization-2/#staged + // # forward pass + // s0 = np.random.randn(5, 10) + // s1 = np.random.randn(10, 3) + // t = s0.dot(s1) + + // # now suppose we had the gradient on t from above in the circuit + // dt = np.random.randn(*t.shape) # same shape as t + // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix + // ds1 = t.T.dot(dt) + + // tensor.shape [m,p,qq,rr] + // src0.shape [n,m,q1,r1] + // src1.shape [n,p,qq,rr] + + if (src0_needs_grads) { + GGML_ASSERT(grad->ne[2] == src1->ne[2]); + GGML_ASSERT(grad->ne[3] == src1->ne[3]); + struct ggml_tensor * tmp = + ggml_out_prod(ctx, // [n,m,qq,rr] + src1, // [n,p,qq,rr] + grad); // [m,p,qq,rr] + if (!ggml_are_same_shape(tmp, src0)) { + GGML_ASSERT(tmp->ne[0] == src0->ne[0]); + GGML_ASSERT(tmp->ne[1] == src0->ne[1]); + GGML_ASSERT(tmp->ne[3] == 1); + + const int64_t nr2 = tmp->ne[2] / src0->ne[2]; + const size_t nb2 = tmp->nb[2] * nr2; + const size_t nb3 = tmp->nb[2]; + + tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0); + tmp = ggml_repeat_back(ctx, tmp, src0); } - } break; - case GGML_OP_SOFT_MAX: - { - // necessary for llama - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_soft_max_back(ctx, tensor->grad, tensor), - zero_table); + ggml_add_or_set(ctx, cgraph, isrc0, tmp); + } + if (src1_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc1, + // ggml_mul_mat(ctx, // [n,p,qq,rr] + // ggml_cont(ctx, // [m,n,q1,r1] + // ggml_transpose(ctx, src0)), // [m,n,q1,r1] + // grad), // [m,p,qq,rr] + + // when src0 is bigger than tensor->grad (this is mostly the case in llama), + // avoid transpose of src0, rather transpose smaller tensor->grad + // and then use ggml_out_prod + ggml_out_prod(ctx, // [n,p,qq,rr] + src0, // [n,m,q1,r1] + ggml_transpose(ctx, // [p,m,qq,rr] + grad))); // [m,p,qq,rr] + } + } break; + case GGML_OP_SCALE: { + if (src0_needs_grads) { + float s; + memcpy(&s, tensor->op_params, sizeof(float)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false)); + } + } break; + case GGML_OP_SET: { + const size_t nb1 = ((const int32_t *) tensor->op_params)[0]; + const size_t nb2 = ((const int32_t *) tensor->op_params)[1]; + const size_t nb3 = ((const int32_t *) tensor->op_params)[2]; + const size_t offset = ((const int32_t *) tensor->op_params)[3]; + + struct ggml_tensor * tensor_grad_view = NULL; + + if (src0_needs_grads || src1_needs_grads) { + GGML_ASSERT(src0->type == tensor->type); + GGML_ASSERT(!cgraph->grads[isrc0] || cgraph->grads[isrc0]->type == grad->type); + GGML_ASSERT(!cgraph->grads[isrc1] || !src1_needs_grads || cgraph->grads[isrc1]->type == grad->type); + + tensor_grad_view = ggml_view_4d(ctx, + grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + nb1, nb2, nb3, offset); + } + + if (src0_needs_grads) { + struct ggml_tensor * tmp = ggml_neg(ctx, tensor_grad_view); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_acc_impl(ctx, grad, tmp, nb1, nb2, nb3, offset, false)); + } + + if (src1_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1)); + } + } break; + case GGML_OP_CPY: { + // cpy overwrites value of src1 by src0 and returns view(src1) + // the overwriting is mathematically equivalent to: + // tensor = src0 * 1 + src1 * 0 + if (src0_needs_grads) { + // dsrc0 = dtensor * 1 + ggml_add_or_set(ctx, cgraph, isrc0, grad); + } + if (src1_needs_grads) { + // dsrc1 = dtensor * 0 -> noop + } + } break; + case GGML_OP_CONT: { + // same as cpy + if (src0_needs_grads) { + GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0])); + GGML_ASSERT(ggml_is_contiguous(grad)); + GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0)); + ggml_add_or_set(ctx, cgraph, isrc0, + ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0)); + } + } break; + case GGML_OP_RESHAPE: { + if (src0_needs_grads) { + struct ggml_tensor * grad_cont = ggml_is_contiguous(grad) ? grad : ggml_cont(ctx, grad); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad_cont, src0)); + } + } break; + case GGML_OP_VIEW: { + if (src0_needs_grads) { + size_t offset; + + memcpy(&offset, tensor->op_params, sizeof(offset)); + + size_t nb1 = tensor->nb[1]; + size_t nb2 = tensor->nb[2]; + size_t nb3 = tensor->nb[3]; + + if (cgraph->grads[isrc0] && src0->type != cgraph->grads[isrc0]->type) { + // gradient is typically F32, but src0 could be other type + size_t ng = ggml_element_size(cgraph->grads[isrc0]); + size_t n0 = ggml_element_size(src0); + GGML_ASSERT(offset % n0 == 0); + GGML_ASSERT(nb1 % n0 == 0); + GGML_ASSERT(nb2 % n0 == 0); + GGML_ASSERT(nb3 % n0 == 0); + offset = (offset / n0) * ng; + nb1 = (nb1 / n0) * ng; + nb2 = (nb2 / n0) * ng; + nb3 = (nb3 / n0) * ng; } - } break; - case GGML_OP_SOFT_MAX_BACK: - { - GGML_ABORT("fatal error"); // TODO: not implemented + ggml_acc_or_set(ctx, cgraph, isrc0, grad, nb1, nb2, nb3, offset); } - case GGML_OP_ROPE: - { - // necessary for llama - if (src0->grad) { - //const int n_past = ((int32_t *) tensor->op_params)[0]; - const int n_dims = ((int32_t *) tensor->op_params)[1]; - const int mode = ((int32_t *) tensor->op_params)[2]; - //const int n_ctx = ((int32_t *) tensor->op_params)[3]; - const int n_ctx_orig = ((int32_t *) tensor->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + } break; + case GGML_OP_PERMUTE: { + if (src0_needs_grads) { + const int32_t * axes = (const int32_t *) tensor->op_params; + const int axis0 = axes[0] & 0x3; + const int axis1 = axes[1] & 0x3; + const int axis2 = axes[2] & 0x3; + const int axis3 = axes[3] & 0x3; + int axb[4] = {0,0,0,0}; // axes backward + axb[axis0] = 0; + axb[axis1] = 1; + axb[axis2] = 2; + axb[axis3] = 3; + ggml_add_or_set(ctx, cgraph, isrc0, ggml_permute(ctx, grad, axb[0], axb[1], axb[2], axb[3])); + } + } break; + case GGML_OP_TRANSPOSE: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_transpose(ctx, grad)); + } + } break; + case GGML_OP_GET_ROWS: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_get_rows_back(ctx, grad, src1, src0)); + } + if (src1_needs_grads) { + // noop + } + } break; + case GGML_OP_DIAG_MASK_INF: { + if (src0_needs_grads) { + /* ggml_diag_mask_inf_impl() shouldn't be here */ + /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */ + const int n_past = ((const int32_t *) tensor->op_params)[0]; + ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false)); + } + } break; + case GGML_OP_DIAG_MASK_ZERO: { + if (src0_needs_grads) { + const int n_past = ((const int32_t *) tensor->op_params)[0]; + ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false)); + } + } break; + case GGML_OP_SOFT_MAX: { + if (src0_needs_grads) { + float scale = 1.0f; + float max_bias = 0.0f; - memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); + memcpy(&scale, (const float *) tensor->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float)); - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_rope_back(ctx, - tensor->grad, - src1, - src2, - n_dims, - mode, - n_ctx_orig, - freq_base, - freq_scale, - ext_factor, - attn_factor, - beta_fast, - beta_slow), - zero_table); - } - } break; - case GGML_OP_ROPE_BACK: - { - if (src0->grad) { - //const int n_past = ((int32_t *) tensor->op_params)[0]; - const int n_dims = ((int32_t *) tensor->op_params)[1]; - const int mode = ((int32_t *) tensor->op_params)[2]; - //const int n_ctx = ((int32_t *) tensor->op_params)[3]; - const int n_ctx_orig = ((int32_t *) tensor->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias)); + } + GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented"); + } break; + case GGML_OP_ROPE: { + if (src0_needs_grads) { + //const int n_past = ((int32_t *) tensor->op_params)[0]; + const int n_dims = ((const int32_t *) tensor->op_params)[1]; + const int mode = ((const int32_t *) tensor->op_params)[2]; + //const int n_ctx = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4]; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + int sections[4] = {0, 0, 0, 0}; - memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); + memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (const float *) tensor->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (const float *) tensor->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float)); + memcpy(§ions, tensor->op_params + 11, sizeof(sections)); - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_rope_impl(ctx, - tensor->grad, - src1, - src2, - n_dims, - mode, - n_ctx_orig, - freq_base, - freq_scale, - ext_factor, - attn_factor, - beta_fast, - beta_slow, - false), - zero_table); - } - } break; - case GGML_OP_CLAMP: - { - GGML_ABORT("fatal error"); // TODO: not implemented + struct ggml_tensor * rope_back = grad->ne[2] == src1->ne[0] ? + ggml_rope_ext_back(ctx, grad, src1, src2, n_dims, + mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) : + ggml_rope_multi_back(ctx, grad, src1, src2, n_dims, sections, + mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + ggml_add_or_set(ctx, cgraph, isrc0, rope_back); } - case GGML_OP_CONV_TRANSPOSE_1D: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_OP_IM2COL: - { - if (src1->grad) { - const int32_t s0 = ggml_get_op_params_i32(tensor, 0); - const int32_t s1 = ggml_get_op_params_i32(tensor, 1); - const int32_t p0 = ggml_get_op_params_i32(tensor, 2); - const int32_t p1 = ggml_get_op_params_i32(tensor, 3); - const int32_t d0 = ggml_get_op_params_i32(tensor, 4); - const int32_t d1 = ggml_get_op_params_i32(tensor, 5); - const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1; + GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented"); + } break; + case GGML_OP_IM2COL: { + if (src1_needs_grads) { + const int32_t s0 = ggml_get_op_params_i32(tensor, 0); + const int32_t s1 = ggml_get_op_params_i32(tensor, 1); + const int32_t p0 = ggml_get_op_params_i32(tensor, 2); + const int32_t p1 = ggml_get_op_params_i32(tensor, 3); + const int32_t d0 = ggml_get_op_params_i32(tensor, 4); + const int32_t d1 = ggml_get_op_params_i32(tensor, 5); + const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1; - src1->grad = ggml_add_or_set(ctx, - src1->grad, - ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D), - zero_table); - } - } break; - case GGML_OP_IM2COL_BACK: - { - GGML_ABORT("fatal error"); // TODO: not implemented + ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D)); } - case GGML_OP_CONV_TRANSPOSE_2D: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_OP_POOL_1D: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_OP_POOL_2D: - { - if (src0->grad) { - const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0); - const int32_t k0 = ggml_get_op_params_i32(tensor, 1); - const int32_t k1 = ggml_get_op_params_i32(tensor, 2); - const int32_t s0 = ggml_get_op_params_i32(tensor, 3); - const int32_t s1 = ggml_get_op_params_i32(tensor, 4); - const int32_t p0 = ggml_get_op_params_i32(tensor, 5); - const int32_t p1 = ggml_get_op_params_i32(tensor, 6); + } break; + case GGML_OP_POOL_2D: { + if (src0_needs_grads) { + const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0); + const int32_t k0 = ggml_get_op_params_i32(tensor, 1); + const int32_t k1 = ggml_get_op_params_i32(tensor, 2); + const int32_t s0 = ggml_get_op_params_i32(tensor, 3); + const int32_t s1 = ggml_get_op_params_i32(tensor, 4); + const int32_t p0 = ggml_get_op_params_i32(tensor, 5); + const int32_t p1 = ggml_get_op_params_i32(tensor, 6); - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1), - zero_table); - } - } break; - case GGML_OP_POOL_2D_BACK: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_OP_UPSCALE: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_OP_PAD: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_OP_ARANGE: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_OP_TIMESTEP_EMBEDDING: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_OP_ARGSORT: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_OP_LEAKY_RELU: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_OP_FLASH_ATTN_EXT: - { - struct ggml_tensor * flash_grad = NULL; - if (src0->grad || src1->grad || tensor->src[2]->grad) { - int32_t t = ggml_get_op_params_i32(tensor, 0); - GGML_ASSERT(t == 0 || t == 1); - bool masked = t != 0; - flash_grad = - ggml_flash_attn_back(ctx, - src0, - src1, - tensor->src[2], - tensor->grad, - masked); - } - - const int64_t elem_q = ggml_nelements(src0); - const int64_t elem_k = ggml_nelements(src1); - const int64_t elem_v = ggml_nelements(src2); - - enum ggml_type result_type = flash_grad->type; - GGML_ASSERT(ggml_blck_size(result_type) == 1); - const size_t tsize = ggml_type_size(result_type); - - const size_t offs_q = 0; - const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN); - const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN); - - if (src0->grad) { - struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q); - struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0); - src0->grad = ggml_add_or_set(ctx, - src0->grad, - grad_q, - zero_table); - } - if (src1->grad) { - struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k); - struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1); - src1->grad = ggml_add_or_set(ctx, - src1->grad, - grad_k, - zero_table); - } - if (src2->grad) { - struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v); - struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2); - src2->grad = ggml_add_or_set(ctx, - src2->grad, - grad_v, - zero_table); - } - } break; - case GGML_OP_FLASH_ATTN_BACK: - { - GGML_ABORT("fatal error"); // not supported - } - case GGML_OP_SSM_CONV: - case GGML_OP_SSM_SCAN: - { - GGML_ABORT("fatal error"); // TODO: not implemented + ggml_add_or_set(ctx, cgraph, isrc0, ggml_pool_2d_back(ctx, grad, src0, op, k0, k1, s0, s1, p0, p1)); } + } break; case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: - case GGML_OP_UNARY: - { - switch (ggml_get_unary_op(tensor)) { - case GGML_UNARY_OP_ABS: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_mul(ctx, - ggml_sgn(ctx, src0), - tensor->grad), - zero_table); - } - } break; - case GGML_UNARY_OP_SGN: - { - if (src0->grad) { - // noop - } - } break; - case GGML_UNARY_OP_NEG: - { - if (src0->grad) { - src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table); - } - } break; - case GGML_UNARY_OP_STEP: - { - if (src0->grad) { - // noop - } - } break; - case GGML_UNARY_OP_TANH: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_UNARY_OP_ELU: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_UNARY_OP_RELU: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_mul(ctx, - ggml_step(ctx, src0), - tensor->grad), - zero_table); - } - } break; - case GGML_UNARY_OP_SIGMOID: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_UNARY_OP_GELU: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_UNARY_OP_GELU_QUICK: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_UNARY_OP_SILU: - { - // necessary for llama - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_silu_back(ctx, src0, tensor->grad), - zero_table); - } - } break; - case GGML_UNARY_OP_EXP: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_mul(ctx, tensor, tensor->grad), - zero_table); - } - } break; - default: - GGML_ABORT("fatal error"); - } - } break; - case GGML_OP_GET_REL_POS: - case GGML_OP_ADD_REL_POS: - case GGML_OP_RWKV_WKV: - case GGML_OP_MAP_UNARY: - case GGML_OP_MAP_BINARY: - case GGML_OP_MAP_CUSTOM1_F32: - case GGML_OP_MAP_CUSTOM2_F32: - case GGML_OP_MAP_CUSTOM3_F32: - case GGML_OP_MAP_CUSTOM1: - case GGML_OP_MAP_CUSTOM2: - case GGML_OP_MAP_CUSTOM3: - { - GGML_ABORT("fatal error"); // not supported + case GGML_OP_UNARY: { + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_ABS: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_sgn(ctx, src0), grad)); + } + } break; + case GGML_UNARY_OP_SGN: { + // noop + } break; + case GGML_UNARY_OP_NEG: { + if (src0_needs_grads) { + ggml_sub_or_set(ctx, cgraph, isrc0, grad); + } + } break; + case GGML_UNARY_OP_STEP: { + // noop + } break; + case GGML_UNARY_OP_RELU: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_step(ctx, src0), grad)); + } + } break; + case GGML_UNARY_OP_SILU: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0)); + } + } break; + case GGML_UNARY_OP_EXP: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad)); + } + } break; + default: { + fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n", + __func__, ggml_unary_op_name(ggml_get_unary_op(tensor))); + GGML_ABORT("fatal error"); + } //break; } - case GGML_OP_CROSS_ENTROPY_LOSS: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_cross_entropy_loss_back(ctx, - src0, - src1, - tensor->grad), - zero_table); - } - } break; - case GGML_OP_CROSS_ENTROPY_LOSS_BACK: - { - GGML_ABORT("fatal error"); // not supported + } break; + case GGML_OP_CROSS_ENTROPY_LOSS: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, grad, src0, src1)); } - case GGML_OP_NONE: - { - // nop - } break; + GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented"); + } break; + case GGML_OP_NONE: { + // noop + } break; case GGML_OP_COUNT: - { - GGML_ABORT("fatal error"); - } + default: { + fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op)); + GGML_ABORT("fatal error"); + } //break; } - for (int i = 0; i < GGML_MAX_SRC; ++i) { - if (tensor->src[i] && tensor->src[i]->grad) { - GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad)); - } - } + GGML_ASSERT(!src0_needs_grads || ggml_are_same_shape(src0, cgraph->grads[isrc0])); + GGML_ASSERT(!src1_needs_grads || ggml_are_same_shape(src1, cgraph->grads[isrc1])); + GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2])); } static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { - if (node->grad == NULL) { - // this usually happens when we generate intermediate nodes from constants in the backward pass - // it can also happen during forward pass, if the user performs computations with constants - if (node->op != GGML_OP_NONE) { - //GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op); - } - } - // check if already visited if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) { return; @@ -18926,7 +5746,7 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * } } - if (node->op == GGML_OP_NONE && node->grad == NULL) { + if (node->op == GGML_OP_NONE && !(node->flags & GGML_TENSOR_FLAG_PARAM)) { // reached a leaf node, not part of the gradient graph (e.g. a constant) GGML_ASSERT(cgraph->n_leafs < cgraph->size); @@ -18944,9 +5764,6 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * } cgraph->nodes[cgraph->n_nodes] = node; - if (cgraph->grads) { - cgraph->grads[cgraph->n_nodes] = node->grad; - } cgraph->n_nodes++; } } @@ -18974,50 +5791,101 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * ggml_build_forward_impl(cgraph, tensor, true); } -void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) { - GGML_ASSERT(gf->n_nodes > 0); - GGML_ASSERT(gf->grads); +void ggml_build_backward_expand( + struct ggml_context * ctx_static, + struct ggml_context * ctx_compute, + struct ggml_cgraph * cgraph, + bool accumulate) { + GGML_ASSERT(cgraph->n_nodes > 0); + GGML_ASSERT(cgraph->grads); + GGML_ASSERT(cgraph->grad_accs); - // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph - if (keep) { - for (int i = 0; i < gf->n_nodes; i++) { - struct ggml_tensor * node = gf->nodes[i]; + const int n_nodes_f = cgraph->n_nodes; - if (node->grad) { - node->grad = ggml_dup_tensor(ctx, node); - gf->grads[i] = node->grad; + memset(cgraph->grads, 0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *)); + memset(cgraph->grad_accs, 0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *)); + bool * grads_needed = calloc(cgraph->visited_hash_set.size, sizeof(bool)); + + { + bool any_params = false; + bool any_loss = false; + for (int i = 0; i < n_nodes_f; ++i) { + struct ggml_tensor * node = cgraph->nodes[i]; + any_params = any_params || (node->flags & GGML_TENSOR_FLAG_PARAM); + any_loss = any_loss || (node->flags & GGML_TENSOR_FLAG_LOSS); + } + GGML_ASSERT(any_params && "no trainable parameters found, did you forget to call ggml_set_param?"); + GGML_ASSERT(any_loss && "no training loss found, did you forget to call ggml_set_loss?"); + } + + for (int i = 0; i < n_nodes_f; ++i) { + struct ggml_tensor * node = cgraph->nodes[i]; + + if (node->type == GGML_TYPE_I32) { + continue; + } + + bool node_needs_grad = (node->flags & GGML_TENSOR_FLAG_PARAM) || (node->flags & GGML_TENSOR_FLAG_LOSS); + bool ignore_src[GGML_MAX_SRC] = {false}; + switch (node->op) { + // gradients in node->src[0] for one reason or another have no effect on output gradients + case GGML_OP_IM2COL: // only used for its shape + case GGML_OP_IM2COL_BACK: // same as IM2COL + ignore_src[0] = true; + break; + case GGML_OP_UNARY: { + const enum ggml_unary_op uop = ggml_get_unary_op(node); + // SGN and STEP unary ops are piecewise constant + if (uop == GGML_UNARY_OP_SGN || uop == GGML_UNARY_OP_STEP) { + ignore_src[0] = true; + } + } break; + + // gradients in node->src[1] for one reason or another have no effect on output gradients + case GGML_OP_CPY: // gradients in CPY target are irrelevant + case GGML_OP_GET_ROWS: // row indices not differentiable + case GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS + case GGML_OP_ROPE: // positions not differentiable + ignore_src[1] = true; + break; + + default: + break; + } + for (int j = 0; j < GGML_MAX_SRC; ++j) { + if (!node->src[j] || ignore_src[j] || !grads_needed[ggml_hash_find(&cgraph->visited_hash_set, node->src[j])]) { + continue; } + GGML_ASSERT(node->src[j]->type == GGML_TYPE_F32 || node->src[j]->type == GGML_TYPE_F16); + node_needs_grad = true; + break; } + if (!node_needs_grad) { + continue; + } + + // inplace operations are currently not supported + GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW || + node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE); + + const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node); + GGML_ASSERT(igrad != GGML_HASHSET_FULL); + GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, igrad)); + if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) { + cgraph->grad_accs[igrad] = ggml_dup_tensor(ctx_static, node); + cgraph->grads[igrad] = cgraph->grad_accs[igrad]; + ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name); + } + grads_needed[igrad] = true; } - // remember original gradients which start with zero values - struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size); - for (int i = 0; i < gf->n_nodes; i++) { - if (gf->grads[i]) { - ggml_hash_insert(&zero_table, gf->grads[i]); - } - } - - for (int i = gf->n_nodes - 1; i >= 0; i--) { - struct ggml_tensor * node = gf->nodes[i]; - - // inplace operations to add gradients are not created by ggml_compute_backward + for (int i = n_nodes_f - 1; i >= 0; --i) { + // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation // use allocator to automatically make inplace operations - if (node->grad) { - ggml_compute_backward(ctx, node, &zero_table); - } + ggml_compute_backward(ctx_compute, cgraph, i, grads_needed); } - for (int i = 0; i < gf->n_nodes; i++) { - struct ggml_tensor * node = gf->nodes[i]; - - if (node->flags & GGML_TENSOR_FLAG_PARAM) { - GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); - ggml_build_forward_expand(gb, node->grad); - } - } - - ggml_hash_set_free(&zero_table); + free(grads_needed); } static void * incr_ptr_aligned(void ** p, size_t size, size_t align) { @@ -19035,7 +5903,8 @@ static size_t ggml_graph_nbytes(size_t size, bool grads) { incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys if (grads) { - incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads + incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads + incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grad_accs } incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t)); @@ -19061,10 +5930,12 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz void * p = cgraph + 1; - struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); - struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); - struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); - struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; + struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; + struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; + ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t)); // check that we allocated the correct amount of memory @@ -19076,12 +5947,17 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz /*.n_leafs =*/ 0, /*.nodes =*/ nodes_ptr, /*.grads =*/ grads_ptr, + /*.grad_accs =*/ grad_accs_ptr, /*.leafs =*/ leafs_ptr, /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr }, /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, }; ggml_hash_set_reset(&cgraph->visited_hash_set); + if (grads) { + memset(cgraph->grads, 0, hash_size*sizeof(struct ggml_tensor *)); + memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *)); + } return cgraph; } @@ -19092,14 +5968,15 @@ struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) { struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) { struct ggml_cgraph cgraph = { - /*.size =*/ 0, - /*.n_nodes =*/ i1 - i0, - /*.n_leafs =*/ 0, - /*.nodes =*/ cgraph0->nodes + i0, - /*.grads =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL, - /*.leafs =*/ NULL, - /*.hash_table =*/ { 0, NULL, NULL }, - /*.order =*/ cgraph0->order, + /*.size =*/ 0, + /*.n_nodes =*/ i1 - i0, + /*.n_leafs =*/ 0, + /*.nodes =*/ cgraph0->nodes + i0, + /*.grads =*/ NULL, // gradients would need visited_hash_set + /*.grad_accs =*/ NULL, + /*.leafs =*/ NULL, + /*.visited_hash_set =*/ { 0, NULL, NULL }, + /*.order =*/ cgraph0->order, }; return cgraph; @@ -19122,19 +5999,33 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) { dst->nodes[i] = src->nodes[i]; } - if (src->grads) { - GGML_ASSERT(dst->grads != NULL); - for (int i = 0; i < src->n_nodes; ++i) { - dst->grads[i] = src->grads[i]; - } - } - for (size_t i = 0; i < src->visited_hash_set.size; ++i) { // copy all hashset keys (tensors) that are in use if (ggml_bitset_get(src->visited_hash_set.used, i)) { ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]); } } + + if (dst->grads) { + memset(dst->grads, 0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *)); + memset(dst->grad_accs, 0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *)); + } + if (src->grads) { + GGML_ASSERT(dst->grads != NULL); + GGML_ASSERT(dst->grad_accs != NULL); + for (int i = 0; i < src->n_nodes; ++i) { + const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]); + const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]); + + GGML_ASSERT(igrad_src != GGML_HASHSET_FULL); + GGML_ASSERT(ggml_bitset_get(src->visited_hash_set.used, igrad_src)); + GGML_ASSERT(igrad_dst != GGML_HASHSET_FULL); + GGML_ASSERT(ggml_bitset_get(dst->visited_hash_set.used, igrad_dst)); + + dst->grads[igrad_dst] = src->grads[igrad_src]; + dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src]; + } + } } struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { @@ -19143,14 +6034,48 @@ struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgrap return result; } +struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) { + if (ggml_is_empty(tensor)) { + return tensor; + } + if (tensor->buffer) { + ggml_backend_tensor_memset(tensor, 0, 0, ggml_nbytes(tensor)); + } else { + GGML_ASSERT(tensor->data); + memset(tensor->data, 0, ggml_nbytes(tensor)); + } + return tensor; +} + void ggml_graph_reset(struct ggml_cgraph * cgraph) { GGML_ASSERT(cgraph->grads != NULL); for (int i = 0; i < cgraph->n_nodes; i++) { - struct ggml_tensor * grad = cgraph->grads[i]; + struct ggml_tensor * node = cgraph->nodes[i]; + struct ggml_tensor * grad_acc = ggml_graph_get_grad_acc(cgraph, node); - if (grad) { - ggml_set_zero(grad); + if (node->op == GGML_OP_OPT_STEP_ADAMW) { + // clear momenta + ggml_set_zero(node->src[2]); + ggml_set_zero(node->src[3]); + } + + // initial gradients of loss should be 1, 0 otherwise + if (grad_acc) { + if (node->flags & GGML_TENSOR_FLAG_LOSS) { + GGML_ASSERT(grad_acc->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_scalar(grad_acc)); + + const float onef = 1.0f; + if (grad_acc->buffer) { + ggml_backend_tensor_set(grad_acc, &onef, 0, sizeof(float)); + } else { + GGML_ASSERT(grad_acc->data); + *((float *) grad_acc->data) = onef; + } + } else { + ggml_set_zero(grad_acc); + } } } } @@ -19161,1049 +6086,35 @@ void ggml_graph_clear(struct ggml_cgraph * cgraph) { ggml_hash_set_reset(&cgraph->visited_hash_set); } -// Android's libc implementation "bionic" does not support setting affinity -#if defined(__gnu_linux__) -static void set_numa_thread_affinity(int thread_n) { - if (!ggml_is_numa()) { - return; - } - - int node_num; - int rv; - size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus); - - switch(g_state.numa.numa_strategy) { - case GGML_NUMA_STRATEGY_DISTRIBUTE: - // run thread on node_num thread_n / (threads per node) - node_num = thread_n % g_state.numa.n_nodes; - break; - case GGML_NUMA_STRATEGY_ISOLATE: - // run thread on current_node - node_num = g_state.numa.current_node; - break; - case GGML_NUMA_STRATEGY_NUMACTL: - // use the cpuset that numactl gave us - rv = pthread_setaffinity_np(pthread_self(), setsize, &g_state.numa.cpuset); - if (rv) { - fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n",strerror(rv)); - } - return; - default: - return; - } - - struct ggml_numa_node * node = &g_state.numa.nodes[node_num]; - - cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus); - CPU_ZERO_S(setsize, cpus); - for (size_t i = 0; i < node->n_cpus; ++i) { - CPU_SET_S(node->cpus[i], setsize, cpus); - } - - rv = pthread_setaffinity_np(pthread_self(), setsize, cpus); - if (rv) { - fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv)); - } - - CPU_FREE(cpus); +int ggml_graph_size(struct ggml_cgraph * cgraph) { + return cgraph->size; } -static void clear_numa_thread_affinity(void) { - if (!ggml_is_numa()) { - return; +struct ggml_tensor * ggml_graph_node(struct ggml_cgraph * cgraph, int i) { + if (i < 0) { + GGML_ASSERT(cgraph->n_nodes + i >= 0); + return cgraph->nodes[cgraph->n_nodes + i]; } - size_t setsize = CPU_ALLOC_SIZE(g_state.numa.total_cpus); - - cpu_set_t * cpus = CPU_ALLOC(g_state.numa.total_cpus); - CPU_ZERO_S(setsize, cpus); - for (unsigned i = 0; i < g_state.numa.total_cpus; ++i) { - CPU_SET_S(i, setsize, cpus); - } - - int rv = pthread_setaffinity_np(pthread_self(), setsize, cpus); - if (rv) { - fprintf(stderr, "warning: pthread_setaffinity_np() failed: %s\n", strerror(rv)); - } - - CPU_FREE(cpus); -} -#else -// TODO: Windows etc. -// (the linux implementation may also work on BSD, someone should test) -static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n); } -static void clear_numa_thread_affinity(void) {} -#endif - -static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { - int n_tasks = 0; - - if (ggml_is_empty(node)) { - // no need to multi-thread a no-op - n_tasks = 1; - return n_tasks; - } - - switch (node->op) { - case GGML_OP_CPY: - case GGML_OP_DUP: - case GGML_OP_CONT: - case GGML_OP_ADD: - case GGML_OP_ADD1: - case GGML_OP_ACC: - { - n_tasks = n_threads; - } break; - case GGML_OP_SUB: - case GGML_OP_SQR: - case GGML_OP_SQRT: - case GGML_OP_LOG: - case GGML_OP_SIN: - case GGML_OP_COS: - case GGML_OP_SUM: - case GGML_OP_SUM_ROWS: - case GGML_OP_MEAN: - case GGML_OP_ARGMAX: - case GGML_OP_REPEAT: - case GGML_OP_REPEAT_BACK: - case GGML_OP_LEAKY_RELU: - { - n_tasks = 1; - } break; - case GGML_OP_UNARY: - switch (ggml_get_unary_op(node)) { - case GGML_UNARY_OP_ABS: - case GGML_UNARY_OP_SGN: - case GGML_UNARY_OP_NEG: - case GGML_UNARY_OP_STEP: - case GGML_UNARY_OP_TANH: - case GGML_UNARY_OP_ELU: - case GGML_UNARY_OP_RELU: - case GGML_UNARY_OP_SIGMOID: - case GGML_UNARY_OP_HARDSWISH: - case GGML_UNARY_OP_HARDSIGMOID: - case GGML_UNARY_OP_EXP: - { - n_tasks = 1; - } break; - - case GGML_UNARY_OP_GELU: - case GGML_UNARY_OP_GELU_QUICK: - case GGML_UNARY_OP_SILU: - { - n_tasks = n_threads; - } break; - default: - GGML_ABORT("fatal error"); - } - break; - case GGML_OP_SILU_BACK: - case GGML_OP_MUL: - case GGML_OP_DIV: - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - case GGML_OP_RMS_NORM_BACK: - case GGML_OP_GROUP_NORM: - case GGML_OP_CONCAT: - case GGML_OP_MUL_MAT: - case GGML_OP_MUL_MAT_ID: - case GGML_OP_OUT_PROD: - { - n_tasks = n_threads; - } break; - case GGML_OP_GET_ROWS: - { - // FIXME: get_rows can use additional threads, but the cost of launching additional threads - // decreases performance with GPU offloading - //n_tasks = n_threads; - n_tasks = 1; - } break; - case GGML_OP_SCALE: - case GGML_OP_SET: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_PERMUTE: - case GGML_OP_TRANSPOSE: - case GGML_OP_GET_ROWS_BACK: - case GGML_OP_DIAG: - { - n_tasks = 1; - } break; - case GGML_OP_DIAG_MASK_ZERO: - case GGML_OP_DIAG_MASK_INF: - case GGML_OP_SOFT_MAX_BACK: - case GGML_OP_ROPE: - case GGML_OP_ROPE_BACK: - case GGML_OP_ADD_REL_POS: - { - n_tasks = n_threads; - } break; - case GGML_OP_CLAMP: - { - n_tasks = 1; //TODO - } break; - case GGML_OP_SOFT_MAX: - { - n_tasks = MIN(n_threads, ggml_nrows(node->src[0])); - } break; - case GGML_OP_IM2COL: - case GGML_OP_IM2COL_BACK: - case GGML_OP_CONV_TRANSPOSE_1D: - case GGML_OP_CONV_TRANSPOSE_2D: - { - n_tasks = n_threads; - } break; - case GGML_OP_POOL_1D: - case GGML_OP_POOL_2D: - case GGML_OP_POOL_2D_BACK: - { - n_tasks = 1; - } break; - case GGML_OP_UPSCALE: - case GGML_OP_PAD: - case GGML_OP_ARANGE: - case GGML_OP_TIMESTEP_EMBEDDING: - case GGML_OP_ARGSORT: - case GGML_OP_FLASH_ATTN_EXT: - case GGML_OP_FLASH_ATTN_BACK: - case GGML_OP_SSM_CONV: - case GGML_OP_SSM_SCAN: - { - n_tasks = n_threads; - } break; - case GGML_OP_WIN_PART: - case GGML_OP_WIN_UNPART: - case GGML_OP_GET_REL_POS: - case GGML_OP_RWKV_WKV: - case GGML_OP_MAP_UNARY: - case GGML_OP_MAP_BINARY: - case GGML_OP_MAP_CUSTOM1_F32: - case GGML_OP_MAP_CUSTOM2_F32: - case GGML_OP_MAP_CUSTOM3_F32: - { - n_tasks = 1; - } break; - case GGML_OP_MAP_CUSTOM1: - { - struct ggml_map_custom1_op_params p; - memcpy(&p, node->op_params, sizeof(p)); - if (p.n_tasks == GGML_N_TASKS_MAX) { - n_tasks = n_threads; - } else { - n_tasks = MIN(p.n_tasks, n_threads); - } - } break; - case GGML_OP_MAP_CUSTOM2: - { - struct ggml_map_custom2_op_params p; - memcpy(&p, node->op_params, sizeof(p)); - if (p.n_tasks == GGML_N_TASKS_MAX) { - n_tasks = n_threads; - } else { - n_tasks = MIN(p.n_tasks, n_threads); - } - } break; - case GGML_OP_MAP_CUSTOM3: - { - struct ggml_map_custom3_op_params p; - memcpy(&p, node->op_params, sizeof(p)); - if (p.n_tasks == GGML_N_TASKS_MAX) { - n_tasks = n_threads; - } else { - n_tasks = MIN(p.n_tasks, n_threads); - } - } break; - case GGML_OP_CROSS_ENTROPY_LOSS: - case GGML_OP_CROSS_ENTROPY_LOSS_BACK: - { - n_tasks = n_threads; - } break; - case GGML_OP_NONE: - { - n_tasks = 1; - } break; - case GGML_OP_COUNT: - { - GGML_ABORT("fatal error"); - } - default: - { - fprintf(stderr, "%s: op not implemented: ", __func__); - if (node->op < GGML_OP_COUNT) { - fprintf(stderr, "%s\n", ggml_op_name(node->op)); - } else { - fprintf(stderr, "%d\n", node->op); - } - GGML_ABORT("fatal error"); - } - } - - assert(n_tasks > 0); - - return n_tasks; + GGML_ASSERT(i < cgraph->n_nodes); + return cgraph->nodes[i]; } -static thread_ret_t ggml_graph_compute_secondary_thread(void* data); - -#if defined(_WIN32) -#include "windows.h" - -// TODO: support > 64 CPUs -bool ggml_thread_apply_affinity(bool * mask) { - HANDLE h = GetCurrentThread(); - uint64_t bitmask = 0ULL; - - assert(GGML_MAX_N_THREADS >= 64); - - for (int32_t i = 0; i < 8; i++) { - int32_t idx = i * 8; - uint8_t val = 0; - val |= mask[idx + 0] << 0; - val |= mask[idx + 1] << 1; - val |= mask[idx + 2] << 2; - val |= mask[idx + 3] << 3; - val |= mask[idx + 4] << 4; - val |= mask[idx + 5] << 5; - val |= mask[idx + 6] << 6; - val |= mask[idx + 7] << 7; - bitmask |= (uint64_t)val << idx; - } - - for (int32_t i = 64; i < GGML_MAX_N_THREADS; i++) { - if (mask[i]) { - fprintf(stderr, "warn: setting thread-affinity for > 64 CPUs isn't supported on windows!\n"); - break; - } - } - - DWORD_PTR m = (DWORD_PTR)bitmask; - - m = SetThreadAffinityMask(h, m); - - return m != 0; +struct ggml_tensor ** ggml_graph_nodes(struct ggml_cgraph * cgraph) { + return cgraph->nodes; } -static bool ggml_thread_apply_priority(int32_t prio) { - // Note that on Windows the Process Priority Class must be updated in order to set Thread priority. - // This is up to the applications. - DWORD p = THREAD_PRIORITY_NORMAL; - switch (prio) { - case GGML_SCHED_PRIO_NORMAL: p = THREAD_PRIORITY_NORMAL; break; - case GGML_SCHED_PRIO_MEDIUM: p = THREAD_PRIORITY_ABOVE_NORMAL; break; - case GGML_SCHED_PRIO_HIGH: p = THREAD_PRIORITY_HIGHEST; break; - case GGML_SCHED_PRIO_REALTIME: p = THREAD_PRIORITY_TIME_CRITICAL; break; - } - - if (prio == GGML_SCHED_PRIO_NORMAL) { - // Keep inherited policy/priority - return true; - } - - if (!SetThreadPriority(GetCurrentThread(), p)) { - fprintf(stderr, "warn: failed to set thread priority %d : (%d)\n", prio, (int) GetLastError()); - return false; - } - - return true; +int ggml_graph_n_nodes(struct ggml_cgraph * cgraph) { + return cgraph->n_nodes; } -#elif defined(__APPLE__) -#include -#include - -static bool ggml_thread_apply_affinity(const bool * mask) { - // Not supported on Apple platforms - UNUSED(mask); - return true; +void ggml_graph_add_node(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) { + GGML_ASSERT(cgraph->size > cgraph->n_nodes); + cgraph->nodes[cgraph->n_nodes] = tensor; + cgraph->n_nodes++; } -static bool ggml_thread_apply_priority(int32_t prio) { - struct sched_param p; - int32_t policy = SCHED_OTHER; - switch (prio) { - case GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break; - case GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break; - case GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break; - case GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO; p.sched_priority = 90; break; - } - - if (prio == GGML_SCHED_PRIO_NORMAL) { - // Keep inherited policy/priority - return true; - } - - int32_t err = pthread_setschedparam(pthread_self(), policy, &p); - if (err != 0) { - fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err); - return false; - } - - return true; -} - -#elif defined(__gnu_linux__) -// TODO: this may not work on BSD, to be verified - -static bool ggml_thread_apply_affinity(const bool * mask) { - cpu_set_t cpuset; - int err; - - CPU_ZERO(&cpuset); - - for (uint32_t i = 0; i < GGML_MAX_N_THREADS; i++) { - if (mask[i]) { - GGML_PRINT_DEBUG("Thread %lx: adding %d to cpuset\n", pthread_self(), i); - CPU_SET(i, &cpuset); - } - } - -#ifdef __ANDROID__ - err = sched_setaffinity(0, sizeof(cpuset), &cpuset); - if (err < 0) { - err = errno; - } -#else - err = pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset); -#endif - if (err != 0) { - fprintf(stderr, "warn: failed to set affinity mask 0x%llx : %s (%d)\n", (unsigned long long)mask, strerror(err), err); - return false; - } - - return true; -} - -static bool ggml_thread_apply_priority(int32_t prio) { - struct sched_param p; - int32_t policy = SCHED_OTHER; - switch (prio) { - case GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break; - case GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break; - case GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break; - case GGML_SCHED_PRIO_REALTIME: policy = SCHED_FIFO; p.sched_priority = 90; break; - } - - if (prio == GGML_SCHED_PRIO_NORMAL) { - // Keep inherited policy/priority - return true; - } - - int32_t err = pthread_setschedparam(pthread_self(), policy, &p); - if (err != 0) { - fprintf(stderr, "warn: failed to set thread priority %d : %s (%d)\n", prio, strerror(err), err); - return false; - } - - return true; -} - -#else // unsupported platforms - -static bool ggml_thread_apply_affinity(const bool * mask) { - UNUSED(mask); - return true; -} - -static bool ggml_thread_apply_priority(int32_t prio) { - UNUSED(prio); - return true; -} - -#endif - -static bool ggml_thread_cpumask_is_valid(const bool * mask) { - for (int i = 0; i < GGML_MAX_N_THREADS; i++) { - if (mask[i]) { return true; } - } - return false; -} - -static void ggml_thread_cpumask_next(const bool * global_mask, bool * local_mask, bool strict, int32_t* iter) { - if (!strict) { - memcpy(local_mask, global_mask, GGML_MAX_N_THREADS); - return; - } else { - memset(local_mask, 0, GGML_MAX_N_THREADS); - int32_t base_idx = *iter; - for (int32_t i = 0; i < GGML_MAX_N_THREADS; i++) { - int32_t idx = base_idx + i; - if (idx >= GGML_MAX_N_THREADS) { - // Just a cheaper modulo - idx -= GGML_MAX_N_THREADS; - } - if (global_mask[idx]) { - local_mask[idx] = 1; - *iter = idx + 1; - return; - } - } - } -} - -void ggml_threadpool_free(struct ggml_threadpool* threadpool) { - if (!threadpool) return; - -#ifndef GGML_USE_OPENMP - struct ggml_compute_state* workers = threadpool->workers; - const int n_threads = threadpool->n_threads_max; - - ggml_mutex_lock(&threadpool->mutex); - - threadpool->stop = true; - threadpool->pause = false; - - ggml_cond_broadcast(&threadpool->cond); - ggml_mutex_unlock(&threadpool->mutex); - - for (int j = 1; j < n_threads; j++) { - int32_t rc = ggml_thread_join(workers[j].thrd, NULL); - GGML_ASSERT(rc == GGML_EXIT_SUCCESS || rc == GGML_EXIT_ABORTED); - UNUSED(rc); - } - - ggml_mutex_destroy(&threadpool->mutex); - ggml_cond_destroy(&threadpool->cond); -#endif // GGML_USE_OPENMP - - GGML_ALIGNED_FREE(threadpool->workers); - GGML_ALIGNED_FREE(threadpool); -} - -#ifndef GGML_USE_OPENMP -// pause/resume must be called under mutex -static void ggml_threadpool_pause_locked(struct ggml_threadpool * threadpool) { - GGML_PRINT_DEBUG("Pausing threadpool\n"); - threadpool->pause = true; - ggml_cond_broadcast(&threadpool->cond); -} - -static void ggml_threadpool_resume_locked(struct ggml_threadpool * threadpool) { - GGML_PRINT_DEBUG("Resuming threadpool\n"); - threadpool->pause = false; - ggml_cond_broadcast(&threadpool->cond); -} -#endif - -void ggml_threadpool_pause(struct ggml_threadpool * threadpool) { -#ifndef GGML_USE_OPENMP - ggml_mutex_lock(&threadpool->mutex); - if (!threadpool->pause) { - ggml_threadpool_pause_locked(threadpool); - } - ggml_mutex_unlock(&threadpool->mutex); -#else - UNUSED(threadpool); -#endif -} - -void ggml_threadpool_resume(struct ggml_threadpool * threadpool) { -#ifndef GGML_USE_OPENMP - ggml_mutex_lock(&threadpool->mutex); - if (threadpool->pause) { - ggml_threadpool_resume_locked(threadpool); - } - ggml_mutex_unlock(&threadpool->mutex); -#else - UNUSED(threadpool); -#endif -} - -struct ggml_cplan ggml_graph_plan( - const struct ggml_cgraph * cgraph, - int n_threads, - struct ggml_threadpool * threadpool) { - - if (threadpool == NULL) { - GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads); - } - if (n_threads <= 0) { - n_threads = threadpool ? threadpool->n_threads_max : GGML_DEFAULT_N_THREADS; - } - - size_t work_size = 0; - - struct ggml_cplan cplan; - memset(&cplan, 0, sizeof(struct ggml_cplan)); - - int max_tasks = 1; - - // thread scheduling for the different operations + work buffer size estimation - for (int i = 0; i < cgraph->n_nodes; i++) { - struct ggml_tensor * node = cgraph->nodes[i]; - - const int n_tasks = ggml_get_n_tasks(node, n_threads); - - max_tasks = MAX(max_tasks, n_tasks); - - size_t cur = 0; - - switch (node->op) { - case GGML_OP_CPY: - case GGML_OP_DUP: - { - if (ggml_is_quantized(node->type) || - // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32 - (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) || - (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) { - cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; - } - } break; - case GGML_OP_ADD: - case GGML_OP_ADD1: - { - if (ggml_is_quantized(node->src[0]->type)) { - cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; - } - } break; - case GGML_OP_ACC: - { - if (ggml_is_quantized(node->src[0]->type)) { - cur = ggml_type_size(GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks; - } - } break; - case GGML_OP_MUL_MAT: - { - const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type; - - if (node->src[1]->type != vec_dot_type) { - cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1])); - } - } break; - case GGML_OP_MUL_MAT_ID: - { - cur = 0; - const struct ggml_tensor * src0 = node->src[0]; - const struct ggml_tensor * src1 = node->src[1]; - const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type; - if (src1->type != vec_dot_type) { - cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)); - } - const int n_as = src0->ne[2]; - cur += GGML_PAD(cur, sizeof(int64_t)); // align - cur += n_as * sizeof(int64_t); // matrix_row_counts - cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows - } break; - case GGML_OP_OUT_PROD: - { - if (ggml_is_quantized(node->src[0]->type)) { - cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; - } - } break; - case GGML_OP_SOFT_MAX: - case GGML_OP_ROPE: - { - cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; - } break; - case GGML_OP_CONV_TRANSPOSE_1D: - { - GGML_ASSERT(node->src[0]->ne[3] == 1); - GGML_ASSERT(node->src[1]->ne[2] == 1); - GGML_ASSERT(node->src[1]->ne[3] == 1); - - const int64_t ne00 = node->src[0]->ne[0]; // K - const int64_t ne01 = node->src[0]->ne[1]; // Cout - const int64_t ne02 = node->src[0]->ne[2]; // Cin - - const int64_t ne10 = node->src[1]->ne[0]; // L - const int64_t ne11 = node->src[1]->ne[1]; // Cin - - if ((node->src[0]->type == GGML_TYPE_F16 || - node->src[0]->type == GGML_TYPE_BF16) && - node->src[1]->type == GGML_TYPE_F32) { - cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02; - cur += sizeof(ggml_fp16_t)*ne10*ne11; - } else if (node->src[0]->type == GGML_TYPE_F32 && - node->src[1]->type == GGML_TYPE_F32) { - cur += sizeof(float)*ne00*ne01*ne02; - cur += sizeof(float)*ne10*ne11; - } else { - GGML_ABORT("fatal error"); - } - } break; - case GGML_OP_CONV_TRANSPOSE_2D: - { - const int64_t ne00 = node->src[0]->ne[0]; // W - const int64_t ne01 = node->src[0]->ne[1]; // H - const int64_t ne02 = node->src[0]->ne[2]; // Channels Out - const int64_t ne03 = node->src[0]->ne[3]; // Channels In - - const int64_t ne10 = node->src[1]->ne[0]; // W - const int64_t ne11 = node->src[1]->ne[1]; // H - const int64_t ne12 = node->src[1]->ne[2]; // Channels In - - cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03; - cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; - } break; - case GGML_OP_FLASH_ATTN_EXT: - { - const int64_t ne00 = node->src[0]->ne[0]; // D - - cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread - } break; - case GGML_OP_FLASH_ATTN_BACK: - { - const int64_t D = node->src[0]->ne[0]; - const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); - const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back - if (node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 - } else if (node->src[1]->type == GGML_TYPE_F16) { - cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 - } else if (node->src[1]->type == GGML_TYPE_BF16) { - cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 - } - } break; - - case GGML_OP_CROSS_ENTROPY_LOSS: - { - cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); - } break; - case GGML_OP_COUNT: - { - GGML_ABORT("fatal error"); - } - default: - break; - } - - work_size = MAX(work_size, cur); - } - - if (work_size > 0) { - work_size += CACHE_LINE_SIZE*(n_threads); - } - - cplan.threadpool = threadpool; - cplan.n_threads = MIN(max_tasks, n_threads); - cplan.work_size = work_size; - cplan.work_data = NULL; - - return cplan; -} - -static thread_ret_t ggml_graph_compute_thread(void * data) { - struct ggml_compute_state * state = (struct ggml_compute_state *) data; - - const struct ggml_cgraph * cgraph = state->threadpool->cgraph; - const struct ggml_cplan * cplan = state->threadpool->cplan; - - set_numa_thread_affinity(state->ith); - - struct ggml_compute_params params = { - /*.ith =*/ state->ith, - /*.nth =*/ state->threadpool->n_threads_cur, - /*.wsize =*/ cplan->work_size, - /*.wdata =*/ cplan->work_data, - /*.threadpool=*/ state->threadpool, - }; - - for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) { - struct ggml_tensor * node = cgraph->nodes[node_n]; - - ggml_compute_forward(¶ms, node); - - if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { - state->threadpool->ec = GGML_STATUS_ABORTED; - } - - ggml_barrier(state->threadpool); - - if (state->threadpool->ec != GGML_STATUS_SUCCESS) { - break; - } - } - - return 0; -} - -#ifndef GGML_USE_OPENMP - -static inline bool ggml_graph_compute_ready(struct ggml_compute_state * state) { - struct ggml_threadpool * threadpool = state->threadpool; - - if (state->pending || threadpool->stop || threadpool->pause) { return true; } - - // check for new graph/work - int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed); - if (new_graph != state->last_graph) { - state->pending = (state->ith < threadpool->n_threads_cur); - state->last_graph = new_graph; - } - - return state->pending; -} - -static inline bool ggml_graph_compute_poll_for_work(struct ggml_compute_state * state) { - struct ggml_threadpool * threadpool = state->threadpool; - - // This seems to make 0 ... 100 a decent range for polling level across modern processors. - // Perhaps, we can adjust it dynamically based on load and things. - const uint64_t n_rounds = 1024UL * 128 * threadpool->poll; - - for (uint64_t i=0; !ggml_graph_compute_ready(state) && ipending; -} - -static inline bool ggml_graph_compute_check_for_work(struct ggml_compute_state * state) { - struct ggml_threadpool * threadpool = state->threadpool; - - if (ggml_graph_compute_poll_for_work(state)) { - return state->pending; - } - - ggml_mutex_lock_shared(&threadpool->mutex); - while (!ggml_graph_compute_ready(state)) { - // No new work. Wait for the signal. - GGML_PRINT_DEBUG("thread #%d waiting for work\n", state->ith); - ggml_cond_wait(&threadpool->cond, &threadpool->mutex); - } - ggml_mutex_unlock_shared(&threadpool->mutex); - - return state->pending; -} - -static thread_ret_t ggml_graph_compute_secondary_thread(void* data) { - struct ggml_compute_state * state = (struct ggml_compute_state *) data; - struct ggml_threadpool * threadpool = state->threadpool; - - ggml_thread_apply_priority(threadpool->prio); - if (ggml_thread_cpumask_is_valid(state->cpumask)) { - ggml_thread_apply_affinity(state->cpumask); - } - - while (true) { - // Check if we need to sleep - while (threadpool->pause) { - GGML_PRINT_DEBUG("thread #%d inside pause loop\n", state->ith); - ggml_mutex_lock_shared(&threadpool->mutex); - if (threadpool->pause) { - ggml_cond_wait(&threadpool->cond, &threadpool->mutex); - } - GGML_PRINT_DEBUG("thread #%d resuming after wait\n", state->ith); - ggml_mutex_unlock_shared(&threadpool->mutex); - } - - // This needs to be checked for after the cond_wait - if (threadpool->stop) break; - - // Check if there is new work - // The main thread is the only one that can dispatch new work - - ggml_graph_compute_check_for_work(state); - if (state->pending) { - state->pending = false; - - ggml_graph_compute_thread(state); - } - } - - return (thread_ret_t) 0; -} - -// Start processing new graph -static void ggml_graph_compute_kickoff(struct ggml_threadpool * threadpool) -{ - // always take the mutex here because the worker threads are doing hybrid poll/wait - - ggml_mutex_lock(&threadpool->mutex); - - atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_relaxed); - - if (threadpool->pause) { - // Update main thread prio and affinity to match the threadpool settings - ggml_thread_apply_priority(threadpool->prio); - if (ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) { - ggml_thread_apply_affinity(threadpool->workers[0].cpumask); - } - - // resume does cond broadcast - ggml_threadpool_resume_locked(threadpool); - } else { - ggml_cond_broadcast(&threadpool->cond); - } - - ggml_mutex_unlock(&threadpool->mutex); -} - -#endif // GGML_USE_OPENMP - -void ggml_threadpool_params_init(struct ggml_threadpool_params * p, int n_threads) { - p->n_threads = n_threads; - p->prio = 0; // default priority (usually means normal or inherited) - p->poll = 50; // hybrid-polling enabled - p->strict_cpu = false; // no strict placement (all threads share same cpumask) - p->paused = false; // threads are ready to go - memset(p->cpumask, 0, GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited) -} - -struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads) { - struct ggml_threadpool_params p; - ggml_threadpool_params_init(&p, n_threads); - return p; -} - -bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) { - if (p0->n_threads != p1->n_threads ) return false; - if (p0->prio != p1->prio ) return false; - if (p0->poll != p1->poll ) return false; - if (p0->strict_cpu != p1->strict_cpu ) return false; - return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0; -} - -static struct ggml_threadpool * ggml_threadpool_new_impl( - struct ggml_threadpool_params * tpp, - struct ggml_cgraph * cgraph, - struct ggml_cplan * cplan) { - - struct ggml_threadpool * threadpool = - GGML_ALIGNED_MALLOC(sizeof(struct ggml_threadpool)); - { - threadpool->cgraph = cgraph; - threadpool->cplan = cplan; - threadpool->n_graph = 0; - threadpool->n_barrier = 0; - threadpool->n_barrier_passed = 0; - threadpool->current_chunk = 0; - threadpool->stop = false; - threadpool->pause = tpp->paused; - threadpool->workers = NULL; - threadpool->n_threads_max = tpp->n_threads; - threadpool->n_threads_cur = tpp->n_threads; - threadpool->poll = tpp->poll; - threadpool->prio = tpp->prio; - threadpool->ec = GGML_STATUS_SUCCESS; - } - - // Allocate and init workers state - const size_t workers_size = sizeof(struct ggml_compute_state) * tpp->n_threads; - struct ggml_compute_state * workers = GGML_ALIGNED_MALLOC(workers_size); - - memset(workers, 0, workers_size); - for (int j = 0; j < tpp->n_threads; j++) { - workers[j].threadpool = threadpool; - workers[j].ith = j; - } - - threadpool->workers = workers; - -#ifndef GGML_USE_OPENMP - ggml_mutex_init(&threadpool->mutex); - ggml_cond_init(&threadpool->cond); - - // Spin the threads for all workers, and update CPU placements. - // Place the main thread last (towards the higher numbered CPU cores). - - int32_t cpumask_iter = 0; - - for (int j = 1; j < tpp->n_threads; j++) { - ggml_thread_cpumask_next(tpp->cpumask, workers[j].cpumask, tpp->strict_cpu, &cpumask_iter); - - int32_t rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_secondary_thread, &workers[j]); - GGML_ASSERT(rc == 0); - } - - ggml_thread_cpumask_next(tpp->cpumask, workers[0].cpumask, tpp->strict_cpu, &cpumask_iter); - - if (!threadpool->pause) { - // Update main thread prio and affinity at the start, otherwise we'll do it in resume - ggml_thread_apply_priority(threadpool->prio); - if (ggml_thread_cpumask_is_valid(threadpool->workers[0].cpumask)) { - ggml_thread_apply_affinity(threadpool->workers[0].cpumask); - } - } -#endif // GGML_USE_OPENMP - - return threadpool; -} - -struct ggml_threadpool * ggml_threadpool_new(struct ggml_threadpool_params * tpp) { - return ggml_threadpool_new_impl(tpp, NULL, NULL); -} - -enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) { - GGML_ASSERT(cplan); - GGML_ASSERT(cplan->n_threads > 0); - GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL); - - int n_threads = cplan->n_threads; - struct ggml_threadpool * threadpool = cplan->threadpool; - - bool disposable_threadpool = false; - - if (threadpool == NULL) { - GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads); - disposable_threadpool = true; - - struct ggml_threadpool_params ttp = ggml_threadpool_params_default(n_threads); - threadpool = ggml_threadpool_new_impl(&ttp, cgraph, cplan); - } else { - // Reset some of the parameters that need resetting - // No worker threads should be accessing the parameters below at this stage - threadpool->cgraph = cgraph; - threadpool->cplan = cplan; - threadpool->n_threads_cur = n_threads; - threadpool->current_chunk = 0; - threadpool->ec = GGML_STATUS_SUCCESS; - } - - if (n_threads > threadpool->n_threads_max) { - GGML_PRINT("WARNING: cplan is requesting more threads than the threadpool contains. Expect a bad time!\n"); - } - -#ifdef GGML_USE_OPENMP - if (n_threads > 1) { - #pragma omp parallel num_threads(n_threads) - { - #pragma omp single - { - // update the number of threads from the actual number of threads that we got from OpenMP - n_threads = omp_get_num_threads(); - threadpool->n_threads_cur = n_threads; - } - - ggml_graph_compute_thread(&threadpool->workers[omp_get_thread_num()]); - } - } else { - ggml_graph_compute_thread(&threadpool->workers[0]); - } -#else - // Kick all threads to start the new graph - ggml_graph_compute_kickoff(threadpool); - - // This is a work thread too - ggml_graph_compute_thread(&threadpool->workers[0]); -#endif - - // don't leave affinity set on the main thread - clear_numa_thread_affinity(); - - enum ggml_status ret = threadpool->ec; - - if (disposable_threadpool) { - ggml_threadpool_free(threadpool); - } - - return ret; -} - -enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) { - struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads, NULL); - - struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size); - - cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; - - return ggml_graph_compute(cgraph, &cplan); -} - -struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) { +struct ggml_tensor * ggml_graph_get_tensor(const struct ggml_cgraph * cgraph, const char * name) { for (int i = 0; i < cgraph->n_leafs; i++) { struct ggml_tensor * leaf = cgraph->leafs[i]; @@ -20223,516 +6134,42 @@ struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const ch return NULL; } -static void ggml_graph_export_leaf(const struct ggml_tensor * tensor, FILE * fout) { - const int64_t * ne = tensor->ne; - const size_t * nb = tensor->nb; - - fprintf(fout, "%-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n", - ggml_type_name(tensor->type), - ggml_op_name (tensor->op), - ggml_n_dims(tensor), - ne[0], ne[1], ne[2], ne[3], - nb[0], nb[1], nb[2], nb[3], - tensor->data, - tensor->name); +struct ggml_tensor * ggml_graph_get_grad(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { + const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node); + return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) && cgraph->grads ? cgraph->grads[igrad] : NULL; } -static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char * arg, FILE * fout) { - const int64_t * ne = tensor->ne; - const size_t * nb = tensor->nb; - - fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n", - arg, - ggml_type_name(tensor->type), - ggml_op_name (tensor->op), - ggml_n_dims(tensor), - ne[0], ne[1], ne[2], ne[3], - nb[0], nb[1], nb[2], nb[3], - tensor->data, - tensor->name); -} - -void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { - uint64_t size_eval = 0; - - // compute size of intermediate results - // TODO: does not take into account scratch buffers !!!! - for (int i = 0; i < cgraph->n_nodes; ++i) { - size_eval += ggml_nbytes_pad(cgraph->nodes[i]); - } - - // print - { - FILE * fout = stdout; - - fprintf(fout, "\n"); - fprintf(fout, "%-16s %8x\n", "magic", GGML_FILE_MAGIC); - fprintf(fout, "%-16s %8d\n", "version", GGML_FILE_VERSION); - fprintf(fout, "%-16s %8d\n", "leafs", cgraph->n_leafs); - fprintf(fout, "%-16s %8d\n", "nodes", cgraph->n_nodes); - fprintf(fout, "%-16s %" PRIu64 "\n", "eval", size_eval); - - // header - fprintf(fout, "\n"); - fprintf(fout, "%-6s %-12s %8s %8s %8s %8s %8s %16s %16s %16s %16s %16s %16s\n", - "TYPE", "OP", "NDIMS", "NE0", "NE1", "NE2", "NE3", "NB0", "NB1", "NB2", "NB3", "DATA", "NAME"); - - for (int i = 0; i < cgraph->n_leafs; ++i) { - ggml_graph_export_leaf(cgraph->leafs[i], fout); - - GGML_ASSERT(cgraph->leafs[i]->op == GGML_OP_NONE); - GGML_ASSERT(cgraph->leafs[i]->src[0] == NULL); - GGML_ASSERT(cgraph->leafs[i]->src[1] == NULL); - } - - // header - fprintf(fout, "\n"); - fprintf(fout, "%-6s %-6s %-12s %8s %8s %8s %8s %8s %16s %16s %16s %16s %8s %16s %16s\n", - "ARG", "TYPE", "OP", "NDIMS", "NE0", "NE1", "NE2", "NE3", "NB0", "NB1", "NB2", "NB3", "NTASKS", "DATA", "NAME"); - - for (int i = 0; i < cgraph->n_nodes; ++i) { - ggml_graph_export_node(cgraph->nodes[i], "DST", fout); - - for (int j = 0; j < GGML_MAX_SRC; ++j) { - if (cgraph->nodes[i]->src[j]) { - ggml_graph_export_node(cgraph->nodes[i]->src[j], "SRC", fout); - } - } - - fprintf(fout, "\n"); - } - - fprintf(fout, "\n"); - } - - // write binary data - { - FILE * fout = ggml_fopen(fname, "wb"); - - if (!fout) { - fprintf(stderr, "%s: failed to open %s: %s\n", __func__, fname, strerror(errno)); - return; - } - - // header - { - const uint32_t magic = GGML_FILE_MAGIC; - const uint32_t version = GGML_FILE_VERSION; - const uint32_t n_leafs = cgraph->n_leafs; - const uint32_t n_nodes = cgraph->n_nodes; - - fwrite(&magic, sizeof(uint32_t), 1, fout); - fwrite(&version, sizeof(uint32_t), 1, fout); - fwrite(&n_leafs, sizeof(uint32_t), 1, fout); - fwrite(&n_nodes, sizeof(uint32_t), 1, fout); - fwrite(&size_eval, sizeof(uint64_t), 1, fout); - } - - // leafs - { - for (int i = 0; i < cgraph->n_leafs; ++i) { - const struct ggml_tensor * tensor = cgraph->leafs[i]; - - const uint32_t type = tensor->type; - const uint32_t op = tensor->op; - const int32_t flags = tensor->flags; - - fwrite(&type, sizeof(uint32_t), 1, fout); - fwrite(&op, sizeof(uint32_t), 1, fout); - fwrite(&flags, sizeof(int32_t), 1, fout); - - for (int j = 0; j < GGML_MAX_DIMS; ++j) { - const uint64_t ne = tensor->ne[j]; - const uint64_t nb = tensor->nb[j]; - - fwrite(&ne, sizeof(uint64_t), 1, fout); - fwrite(&nb, sizeof(uint64_t), 1, fout); - } - - fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout); - fwrite(tensor->op_params, sizeof(char), GGML_MAX_OP_PARAMS, fout); - - // dump the data - // TODO: pad this to 32 byte boundary - { - const size_t size = ggml_nbytes(tensor); - - fwrite(tensor->data, sizeof(char), size, fout); - } - } - } - - // nodes - { - for (int i = 0; i < cgraph->n_nodes; ++i) { - const struct ggml_tensor * tensor = cgraph->nodes[i]; - - const uint32_t type = tensor->type; - const uint32_t op = tensor->op; - const int32_t flags = tensor->flags; - - fwrite(&type, sizeof(uint32_t), 1, fout); - fwrite(&op, sizeof(uint32_t), 1, fout); - fwrite(&flags, sizeof(int32_t), 1, fout); - - for (int j = 0; j < GGML_MAX_DIMS; ++j) { - const uint64_t ne = tensor->ne[j]; - const uint64_t nb = tensor->nb[j]; - - fwrite(&ne, sizeof(uint64_t), 1, fout); - fwrite(&nb, sizeof(uint64_t), 1, fout); - } - - fwrite(tensor->name, sizeof(char), GGML_MAX_NAME, fout); - fwrite(tensor->op_params, sizeof(char), GGML_MAX_OP_PARAMS, fout); - - // output the op arguments - { - struct ggml_tensor * args[GGML_MAX_SRC] = { NULL }; - - for (int j = 0; j < GGML_MAX_SRC; ++j) { - args[j] = tensor->src[j]; - } - - for (int j = 0; j < GGML_MAX_SRC; ++j) { - if (args[j]) { - int32_t idx = -1; - - // check if leaf - { - for (int k = 0; k < cgraph->n_leafs; ++k) { - if (args[j] == cgraph->leafs[k]) { - idx = k; - break; - } - } - } - - // check if node - if (idx == -1) { - for (int k = 0; k < cgraph->n_nodes; ++k) { - if (args[j] == cgraph->nodes[k]) { - idx = cgraph->n_leafs + k; - break; - } - } - } - - if (idx == -1) { - fprintf(stderr, "%s: failed to find tensor, arg = %d, node = %d\n", __func__, j, i); - fclose(fout); - return; - } - - fwrite(&idx, sizeof(int32_t), 1, fout); - } else { - const int32_t nul = -1; - - fwrite(&nul, sizeof(int32_t), 1, fout); - } - } - } - - // dump the data - // TODO: pad this to 32 byte boundary - if ((flags & GGML_TENSOR_FLAG_PARAM)) { - const size_t size = ggml_nbytes(tensor); - - fwrite(tensor->data, sizeof(char), size, fout); - } - } - } - - fclose(fout); - } -} - -struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval) { - assert(*ctx_data == NULL); - assert(*ctx_eval == NULL); - - struct ggml_cgraph * result = NULL; - - struct ggml_tensor * data = NULL; - - // read file into data - { - FILE * fin = ggml_fopen(fname, "rb"); - if (!fin) { - fprintf(stderr, "%s: failed to open %s: %s\n", __func__, fname, strerror(errno)); - return result; - } - - size_t fsize = 0; - - fseek(fin, 0, SEEK_END); - fsize = ftell(fin); - fseek(fin, 0, SEEK_SET); - - // create the data context - { - const size_t overhead = 1*ggml_tensor_overhead(); - - struct ggml_init_params params = { - .mem_size = fsize + overhead, - .mem_buffer = NULL, - .no_alloc = false, - }; - - *ctx_data = ggml_init(params); - - if (!*ctx_data) { - fprintf(stderr, "%s: failed to create ggml context\n", __func__); - fclose(fin); - return result; - } - } - - data = ggml_new_tensor_1d(*ctx_data, GGML_TYPE_I8, fsize); - - { - const size_t ret = fread(data->data, sizeof(char), fsize, fin); - if (ret != fsize) { - fprintf(stderr, "%s: failed to read %s\n", __func__, fname); - fclose(fin); - return result; - } - } - - fclose(fin); - } - - // populate result - { - char * ptr = (char *) data->data; - - const uint32_t magic = *(const uint32_t *) ptr; ptr += sizeof(magic); - - if (magic != GGML_FILE_MAGIC) { - fprintf(stderr, "%s: invalid magic number, got %08x\n", __func__, magic); - return result; - } - - const uint32_t version = *(const uint32_t *) ptr; ptr += sizeof(version); - - if (version != GGML_FILE_VERSION) { - fprintf(stderr, "%s: invalid version number\n", __func__); - return result; - } - - const uint32_t n_leafs = *(const uint32_t *) ptr; ptr += sizeof(n_leafs); - const uint32_t n_nodes = *(const uint32_t *) ptr; ptr += sizeof(n_nodes); - const uint64_t size_eval = *(const uint64_t *) ptr; ptr += sizeof(size_eval); - const int graph_size = MAX(n_leafs, n_nodes); - - // create the data context - { - const size_t overhead = (n_leafs + n_nodes)*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph_size, false); - - struct ggml_init_params params = { - .mem_size = size_eval + overhead, - .mem_buffer = NULL, - .no_alloc = true, - }; - - *ctx_eval = ggml_init(params); - - if (!*ctx_eval) { - fprintf(stderr, "%s: failed to create ggml context\n", __func__); - return result; - } - } - - result = ggml_new_graph_custom(*ctx_eval, graph_size, false); - - result->n_leafs = n_leafs; - result->n_nodes = n_nodes; - - - // leafs - { - uint32_t type; - uint32_t op; - int32_t flags; - - for (uint32_t i = 0; i < n_leafs; ++i) { - type = *(const uint32_t *) ptr; ptr += sizeof(type); - op = *(const uint32_t *) ptr; ptr += sizeof(op); - flags = *(const int32_t *) ptr; ptr += sizeof(flags); - - int64_t ne[GGML_MAX_DIMS]; - size_t nb[GGML_MAX_DIMS]; - - for (int j = 0; j < GGML_MAX_DIMS; ++j) { - uint64_t ne_cur; - uint64_t nb_cur; - - ne_cur = *(const uint64_t *) ptr; ptr += sizeof(ne_cur); - nb_cur = *(const uint64_t *) ptr; ptr += sizeof(nb_cur); - - ne[j] = ne_cur; - nb[j] = nb_cur; - } - - struct ggml_tensor * tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, GGML_MAX_DIMS, ne); - - tensor->op = (enum ggml_op) op; - tensor->flags = flags; - - memcpy(tensor->name, ptr, GGML_MAX_NAME); ptr += GGML_MAX_NAME; - memcpy(tensor->op_params, ptr, GGML_MAX_OP_PARAMS); ptr += GGML_MAX_OP_PARAMS; - - for (int j = 0; j < GGML_MAX_DIMS; ++j) { - tensor->nb[j] = nb[j]; - } - - tensor->data = (void *) ptr; ptr += ggml_nbytes(tensor); - - result->leafs[i] = tensor; - - fprintf(stderr, "%s: loaded leaf %u: '%16s', %9zu bytes\n", __func__, i, tensor->name, ggml_nbytes(tensor)); - } - } - - ggml_set_no_alloc(*ctx_eval, false); - - // nodes - { - uint32_t type; - uint32_t op; - int32_t flags; - - for (uint32_t i = 0; i < n_nodes; ++i) { - type = *(const uint32_t *) ptr; ptr += sizeof(type); - op = *(const uint32_t *) ptr; ptr += sizeof(op); - flags = *(const int32_t *) ptr; ptr += sizeof(flags); - - enum ggml_op eop = (enum ggml_op) op; - - int64_t ne[GGML_MAX_DIMS]; - size_t nb[GGML_MAX_DIMS]; - - for (int j = 0; j < GGML_MAX_DIMS; ++j) { - uint64_t ne_cur; - uint64_t nb_cur; - - ne_cur = *(const uint64_t *) ptr; ptr += sizeof(ne_cur); - nb_cur = *(const uint64_t *) ptr; ptr += sizeof(nb_cur); - - ne[j] = ne_cur; - nb[j] = nb_cur; - } - - const char * ptr_name = ptr; ptr += GGML_MAX_NAME; - const char * ptr_op_params = ptr; ptr += GGML_MAX_OP_PARAMS; - - const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += GGML_MAX_SRC*sizeof(int32_t); - - struct ggml_tensor * args[GGML_MAX_SRC] = { NULL }; - - // parse args - for (int j = 0; j < GGML_MAX_SRC; ++j) { - const int32_t arg_idx = ptr_arg_idx[j]; - - if (arg_idx == -1) { - continue; - } - - if (arg_idx < result->n_leafs) { - args[j] = result->leafs[arg_idx]; - } else { - args[j] = result->nodes[arg_idx - result->n_leafs]; - } - } - - // create the tensor - // "view" operations are handled differently - // TODO: handle inplace ops - currently a copy is always made - - struct ggml_tensor * tensor = NULL; - - switch (eop) { - // TODO: implement other view ops - case GGML_OP_RESHAPE: - { - tensor = ggml_reshape_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3]); - } break; - case GGML_OP_VIEW: - { - tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0); - - size_t offs; - memcpy(&offs, ptr_op_params, sizeof(offs)); - - tensor->data = ((char *) tensor->data) + offs; - } break; - case GGML_OP_TRANSPOSE: - { - tensor = ggml_transpose(*ctx_eval, args[0]); - } break; - case GGML_OP_PERMUTE: - { - tensor = ggml_view_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3], 0, 0, 0, 0); - } break; - default: - { - tensor = ggml_new_tensor(*ctx_eval, (enum ggml_type) type, GGML_MAX_DIMS, ne); - - tensor->op = eop; - } break; - } - - memcpy(tensor->name, ptr_name, GGML_MAX_NAME); - memcpy(tensor->op_params, ptr_op_params, GGML_MAX_OP_PARAMS); - - for (int j = 0; j < GGML_MAX_DIMS; ++j) { - tensor->nb[j] = nb[j]; - } - - for (int j = 0; j < GGML_MAX_SRC; ++j) { - tensor->src[j] = args[j]; - } - - result->nodes[i] = tensor; - - // TODO tensor data is be duplicated due to ggml_new_tensor call above - if (flags & GGML_TENSOR_FLAG_PARAM) { - tensor->data = (void *) ptr; ptr += ggml_nbytes(tensor); - } - - fprintf(stderr, "%s: loaded node %u: '%16s', %9zu bytes\n", __func__, i, tensor->name, ggml_nbytes(tensor)); - } - } - } - - return result; +struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { + const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node); + return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) && cgraph->grad_accs ? cgraph->grad_accs[igrad] : NULL; } void ggml_graph_print(const struct ggml_cgraph * cgraph) { - GGML_PRINT("=== GRAPH ===\n"); + GGML_LOG_INFO("=== GRAPH ===\n"); - GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes); + GGML_LOG_INFO("n_nodes = %d\n", cgraph->n_nodes); for (int i = 0; i < cgraph->n_nodes; i++) { struct ggml_tensor * node = cgraph->nodes[i]; - GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n", + GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n", i, node->ne[0], node->ne[1], node->ne[2], - ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " "); + ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" : + ggml_graph_get_grad(cgraph, node) ? "g" : " "); } - GGML_PRINT("n_leafs = %d\n", cgraph->n_leafs); + GGML_LOG_INFO("n_leafs = %d\n", cgraph->n_leafs); for (int i = 0; i < cgraph->n_leafs; i++) { struct ggml_tensor * node = cgraph->leafs[i]; - GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n", + GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n", i, node->ne[0], node->ne[1], ggml_op_name(node->op), ggml_get_name(node)); } - GGML_PRINT("========================================\n"); + GGML_LOG_INFO("========================================\n"); } // check if node is part of the graph @@ -20753,8 +6190,9 @@ static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { for (int i = 0; i < cgraph->n_nodes; i++) { struct ggml_tensor * parent = cgraph->nodes[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, parent); - if (parent->grad == node) { + if (grad == node) { return parent; } } @@ -20794,6 +6232,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph for (int i = 0; i < gb->n_nodes; i++) { struct ggml_tensor * node = gb->nodes[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(gb, node); if (ggml_graph_get_parent(gb, node) != NULL) { continue; @@ -20801,7 +6240,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph if (node->flags & GGML_TENSOR_FLAG_PARAM) { snprintf(color, sizeof(color), "yellow"); - } else if (node->grad) { + } else if (grad) { if (ggml_graph_find(gf, node)) { snprintf(color, sizeof(color), "green"); } else { @@ -20828,8 +6267,8 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], node->ne[2], ggml_op_symbol(node->op)); } - if (node->grad) { - fprintf(fp, " | %s\"; ]\n", ggml_op_symbol(node->grad->op)); + if (grad) { + fprintf(fp, " | %s\"; ]\n", ggml_op_symbol(grad->op)); } else { fprintf(fp, "\"; ]\n"); } @@ -20855,15 +6294,17 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph if (ggml_nelements(node) < 5 && node->data != NULL) { fprintf(fp, " | ("); for (int j = 0; j < ggml_nelements(node); j++) { - if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) { - fprintf(fp, "%d", ggml_get_i32_1d(node, j)); - } - else if (node->type == GGML_TYPE_F32 || - node->type == GGML_TYPE_F16 || - node->type == GGML_TYPE_BF16) { - fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j)); - } - else { + // FIXME: use ggml-backend to obtain the tensor data + //if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) { + // fprintf(fp, "%d", ggml_get_i32_1d(node, j)); + //} + //else if (node->type == GGML_TYPE_F32 || + // node->type == GGML_TYPE_F16 || + // node->type == GGML_TYPE_BF16) { + // fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, j)); + //} + //else + { fprintf(fp, "#"); } if (j < ggml_nelements(node) - 1) { @@ -20903,921 +6344,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph fclose(fp); - GGML_PRINT("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename); -} - -//////////////////////////////////////////////////////////////////////////////// - -static void ggml_opt_set_params(int np, struct ggml_tensor * const ps[], const float * x) { - int i = 0; - for (int p = 0; p < np; ++p) { - const int64_t ne = ggml_nelements(ps[p]) ; - // TODO: add function to set tensor from array - for (int64_t j = 0; j < ne; ++j) { - ggml_set_f32_1d(ps[p], j, x[i++]); - } - } -} - -static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) { - int i = 0; - for (int p = 0; p < np; ++p) { - const int64_t ne = ggml_nelements(ps[p]) ; - // TODO: add function to get all elements at once - for (int64_t j = 0; j < ne; ++j) { - x[i++] = ggml_get_f32_1d(ps[p], j); - } - } -} - -static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) { - int64_t i = 0; - for (int p = 0; p < np; ++p) { - const int64_t ne = ggml_nelements(ps[p]) ; - // TODO: add function to get all elements at once - for (int64_t j = 0; j < ne; ++j) { - g[i++] = ggml_get_f32_1d(ps[p]->grad, j); - } - } -} - -static void ggml_opt_acc_grad(int np, struct ggml_tensor * const ps[], float * g, float scale) { - int64_t i = 0; - for (int p = 0; p < np; ++p) { - const int64_t ne = ggml_nelements(ps[p]) ; - // TODO: add function to get all elements at once - for (int64_t j = 0; j < ne; ++j) { - g[i++] += ggml_get_f32_1d(ps[p]->grad, j) * scale; - } - } -} - -// -// Using AdamW - ref: https://arxiv.org/pdf/1711.05101v3.pdf -// -// (Original Adam - ref: https://arxiv.org/pdf/1412.6980.pdf) -// - -static enum ggml_opt_result ggml_opt_adam( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_opt_params params, - struct ggml_tensor * f, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - ggml_opt_callback callback, - void * callback_data) { - GGML_ASSERT(ggml_is_scalar(f)); - GGML_ASSERT(f->type == GGML_TYPE_F32); - - // these will store the parameters we want to optimize - struct ggml_tensor * ps[GGML_MAX_PARAMS]; - - int np = 0; - int64_t nx = 0; - for (int i = 0; i < gf->n_nodes; ++i) { - if (gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) { - GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); - - GGML_ASSERT(np < GGML_MAX_PARAMS); - - ps[np++] = gf->nodes[i]; - nx += ggml_nelements(gf->nodes[i]); - } - } - - if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past)) { - int iter = opt->iter; - ggml_opt_init(opt->ctx, opt, params, nx); - opt->iter = iter; - } - - // constants - float sched = params.adam.sched; - const float alpha = params.adam.alpha; - const float decay = params.adam.decay * alpha; - const float beta1 = params.adam.beta1; - const float beta2 = params.adam.beta2; - const float eps = params.adam.eps; - const float gclip = params.adam.gclip; - const int decay_min_ndim = params.adam.decay_min_ndim; - const int n_accum = MAX(1, params.n_gradient_accumulation); - const float accum_norm = 1.0f / (float) n_accum; - - float * g = opt->adam.g->data; // gradients - float * m = opt->adam.m->data; // first moment - float * v = opt->adam.v->data; // second moment - - float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values - - struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads, NULL); - struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size); - cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; - - bool cancel = false; - - // compute the function value - float fx = 0; - ggml_set_zero(opt->adam.g); - for (int accum_step = 0; accum_step < n_accum; ++accum_step) { - if (callback) { - callback(callback_data, accum_step, &sched, &cancel); - if (cancel) { - return GGML_OPT_RESULT_CANCEL; - } - } - // ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute(gb, &cplan); - ggml_opt_acc_grad(np, ps, g, accum_norm); - fx += ggml_get_f32_1d(f, 0); - } - fx *= accum_norm; - - opt->adam.fx_prev = fx; - opt->adam.fx_best = opt->adam.fx_prev; - if (pf) { - pf[opt->iter % params.past] = opt->adam.fx_prev; - } - - opt->loss_before = opt->adam.fx_prev; - opt->loss_after = opt->adam.fx_prev; - - // initialize - if (opt->just_initialized) { - opt->adam.n_no_improvement = 0; - opt->just_initialized = false; - } - - float * fx_best = &opt->adam.fx_best; - float * fx_prev = &opt->adam.fx_prev; - int * n_no_improvement = &opt->adam.n_no_improvement; - - int iter0 = opt->iter; - - // run the optimizer - for (int t = 0; t < params.adam.n_iter; ++t) { - opt->iter = iter0 + t + 1; - GGML_PRINT_DEBUG ("=== iter %d ===\n", t); - - GGML_PRINT_DEBUG ("f = %10.6f\n", ggml_get_f32_1d(f, 0)); - GGML_PRINT_DEBUG_5("df/dx0 = %10.6f\n", ggml_get_f32_1d(ps[0]->grad, 0)); - GGML_PRINT_DEBUG_5("df/dx1 = %10.6f\n", ggml_get_f32_1d(ps[1]->grad, 0)); - - for (int i = 0; i < np; ++i) { - GGML_PRINT_DEBUG("param %d: %10.6f, g = %10.6f\n", i, - ggml_get_f32_1d(ps[i], 0), ggml_get_f32_1d(ps[i]->grad, 0)); - } - - const int64_t t_start_wall = ggml_time_us(); - const int64_t t_start_cpu = ggml_cycles(); - UNUSED(t_start_wall); - UNUSED(t_start_cpu); - - { - float gnorm = 1.0f; - if (gclip > 0.0f) { - // gradient clipping - ggml_float sum = 0.0; - for (int64_t i = 0; i < nx; ++i) { - sum += (ggml_float)(g[i]*g[i]); - } - ggml_float norm = sqrt(sum); - if (norm > (ggml_float) gclip) { - gnorm = (float) ((ggml_float) gclip / norm); - } - } - const float beta1h = alpha*sched/(1.0f - powf(beta1, opt->iter)); - const float beta2h = 1.0f/(1.0f - powf(beta2, opt->iter)); - int64_t i = 0; - for (int p = 0; p < np; ++p) { - const int64_t ne = ggml_nelements(ps[p]); - const float p_decay = ((ggml_n_dims(ps[p]) >= decay_min_ndim) ? decay : 0.0f) * sched; - for (int64_t j = 0; j < ne; ++j) { - float x = ggml_get_f32_1d(ps[p], j); - float g_ = g[i]*gnorm; - m[i] = m[i]*beta1 + g_*(1.0f - beta1); - v[i] = v[i]*beta2 + g_*g_*(1.0f - beta2); - float mh = m[i]*beta1h; - float vh = v[i]*beta2h; - vh = sqrtf(vh) + eps; - x = x*(1.0f - p_decay) - mh/vh; - ggml_set_f32_1d(ps[p], j, x); - ++i; - } - } - } - - fx = 0; - ggml_set_zero(opt->adam.g); - for (int accum_step = 0; accum_step < n_accum; ++accum_step) { - if (callback) { - callback(callback_data, accum_step, &sched, &cancel); - if (cancel) { - return GGML_OPT_RESULT_CANCEL;; - } - } - // ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute(gb, &cplan); - ggml_opt_acc_grad(np, ps, g, accum_norm); - fx += ggml_get_f32_1d(f, 0); - } - fx *= accum_norm; - - opt->loss_after = fx; - - // check convergence - if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) { - GGML_PRINT_DEBUG("converged\n"); - - return GGML_OPT_RESULT_OK; - } - - // delta-based convergence test - if (pf != NULL) { - // need at least params.past iterations to start checking for convergence - if (params.past <= iter0 + t) { - const float rate = (pf[(iter0 + t)%params.past] - fx)/fx; - - if (fabsf(rate) < params.delta) { - return GGML_OPT_RESULT_OK; - } - } - - pf[(iter0 + t)%params.past] = fx; - } - - // check for improvement - if (params.max_no_improvement > 0) { - if (fx_best[0] > fx) { - fx_best[0] = fx; - n_no_improvement[0] = 0; - } else { - ++n_no_improvement[0]; - - if (n_no_improvement[0] >= params.max_no_improvement) { - return GGML_OPT_RESULT_OK; - } - } - } - - fx_prev[0] = fx; - - { - const int64_t t_end_cpu = ggml_cycles(); - GGML_PRINT_DEBUG("time iter: %5.3f s\n", ((float)(t_end_cpu - t_start_cpu))/CLOCKS_PER_SEC); - UNUSED(t_end_cpu); - - const int64_t t_end_wall = ggml_time_us(); - GGML_PRINT_DEBUG("wall time iter: %5.3f s\n", (t_end_wall - t_start_wall)/1e6); - UNUSED(t_end_wall); - } - } - - return GGML_OPT_RESULT_DID_NOT_CONVERGE; -} - -// -// L-BFGS -// -// the L-BFGS implementation below is based on the following implementation: -// -// https://github.com/chokkan/liblbfgs -// - -struct ggml_lbfgs_iteration_data { - float alpha; - float ys; - float * s; - float * y; -}; - -static enum ggml_opt_result linesearch_backtracking( - const struct ggml_opt_params * params, - int nx, - float * x, - float * fx, - float * g, - float * d, - float * step, - const float * xp, - struct ggml_tensor * f, - struct ggml_cgraph * gb, - struct ggml_cplan * cplan, - const int np, - struct ggml_tensor * ps[], - bool * cancel, - ggml_opt_callback callback, - void * callback_data) { - int count = 0; - - float width = 0.0f; - float dg = 0.0f; - float finit = 0.0f; - float dginit = 0.0f; - float dgtest = 0.0f; - - const float dec = 0.5f; - const float inc = 2.1f; - - const int n_accum = MAX(1, params->n_gradient_accumulation); - const float accum_norm = 1.0f / (float) n_accum; - - if (*step <= 0.f) { - return GGML_LINESEARCH_INVALID_PARAMETERS; - } - - // compute the initial gradient in the search direction - ggml_vec_dot_f32(nx, &dginit, 0, g, 0, d, 0, 1); - - // make sure that d points to a descent direction - if (0 < dginit) { - return GGML_LINESEARCH_FAIL; - } - - // initialize local variables - finit = *fx; - dgtest = params->lbfgs.ftol*dginit; - - while (true) { - ggml_vec_cpy_f32(nx, x, xp); - ggml_vec_mad_f32(nx, x, d, *step); - - // evaluate the function and gradient values - { - ggml_opt_set_params(np, ps, x); - - *fx = 0; - memset(g, 0, sizeof(float)*nx); - for (int accum_step = 0; accum_step < n_accum; ++accum_step) { - if (callback) { - // LBFG-S does not support learning rate -> ignore learning schedule - float sched = 0; - callback(callback_data, accum_step, &sched, cancel); - if (*cancel) { - return GGML_OPT_RESULT_CANCEL; - } - } - // ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute(gb, cplan); - ggml_opt_acc_grad(np, ps, g, accum_norm); - *fx += ggml_get_f32_1d(f, 0); - } - *fx *= accum_norm; - - } - - ++count; - - if (*fx > finit + (*step)*dgtest) { - width = dec; - } else { - // Armijo condition is satisfied - if (params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_ARMIJO) { - return count; - } - - ggml_vec_dot_f32(nx, &dg, 0, g, 0, d, 0, 1); - - // check the Wolfe condition - if (dg < params->lbfgs.wolfe * dginit) { - width = inc; - } else { - if(params->lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE) { - // regular Wolfe conditions - return count; - } - - if(dg > -params->lbfgs.wolfe*dginit) { - width = dec; - } else { - // strong Wolfe condition (GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) - return count; - } - } - } - - if (*step < params->lbfgs.min_step) { - return GGML_LINESEARCH_MINIMUM_STEP; - } - if (*step > params->lbfgs.max_step) { - return GGML_LINESEARCH_MAXIMUM_STEP; - } - if (params->lbfgs.max_linesearch <= count) { - return GGML_LINESEARCH_MAXIMUM_ITERATIONS; - } - - (*step) *= width; - } - - GGML_ABORT("line search failed"); - - //return GGML_LINESEARCH_FAIL; -} - -static enum ggml_opt_result ggml_opt_lbfgs( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_opt_params params, - struct ggml_tensor * f, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - ggml_opt_callback callback, - void * callback_data) { - if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE || - params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) { - if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) { - return GGML_OPT_RESULT_INVALID_WOLFE; - } - } - - const int m = params.lbfgs.m; - - // these will store the parameters we want to optimize - struct ggml_tensor * ps[GGML_MAX_PARAMS]; - - int np = 0; - int nx = 0; - for (int i = 0; i < gf->n_nodes; ++i) { - if (gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) { - GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); - - GGML_ASSERT(np < GGML_MAX_PARAMS); - - ps[np++] = gf->nodes[i]; - nx += ggml_nelements(gf->nodes[i]); - } - } - - if ((opt->params.type != params.type) || (opt->nx != nx) || (opt->params.past != params.past) || (opt->params.lbfgs.m != params.lbfgs.m)) { - int iter = opt->iter; - ggml_opt_init(ctx, opt, params, nx); - opt->iter = iter; - } - - struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads, NULL); - struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_WORK_BUFFER, cplan.work_size); - cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; - - float * x = opt->lbfgs.x->data; // current parameters - float * xp = opt->lbfgs.xp->data; // previous parameters - float * g = opt->lbfgs.g->data; // current gradient - float * gp = opt->lbfgs.gp->data; // previous gradient - float * d = opt->lbfgs.d->data; // search direction - - float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values - - const int n_accum = MAX(1, params.n_gradient_accumulation); - const float accum_norm = 1.0f / (float) n_accum; - - float fx = 0.0f; // cost function value - float xnorm = 0.0f; // ||x|| - float gnorm = 0.0f; // ||g|| - - // initialize x from the graph nodes - ggml_opt_get_params(np, ps, x); - - // the L-BFGS memory - float * lm_alpha = opt->lbfgs.lmal->data; - float * lm_ys = opt->lbfgs.lmys->data; - float * lm_s = opt->lbfgs.lms->data; - float * lm_y = opt->lbfgs.lmy->data; - - bool cancel = false; - - // evaluate the function value and its gradient - { - ggml_opt_set_params(np, ps, x); - - fx = 0; - memset(g, 0, sizeof(float)*nx); - for (int accum_step = 0; accum_step < n_accum; ++accum_step) { - if (callback) { - // LBFG-S does not support learning rate -> ignore learning schedule - float sched = 0; - callback(callback_data, accum_step, &sched, &cancel); - if (cancel) { - return GGML_OPT_RESULT_CANCEL; - } - } - // ggml_graph_reset (gf); - ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute(gb, &cplan); - ggml_opt_acc_grad(np, ps, g, accum_norm); - fx += ggml_get_f32_1d(f, 0); - } - fx *= accum_norm; - - opt->loss_before = fx; - opt->loss_after = fx; - } - - // search direction = -gradient - ggml_vec_neg_f32(nx, d, g); - - // ||x||, ||g|| - ggml_vec_norm_f32(nx, &xnorm, x); - ggml_vec_norm_f32(nx, &gnorm, g); - - if (xnorm < 1.0f) { - xnorm = 1.0f; - } - - // already optimized - if (gnorm/xnorm <= params.lbfgs.eps) { - return GGML_OPT_RESULT_OK; - } - - if (opt->just_initialized) { - if (pf) { - pf[0] = fx; - } - opt->lbfgs.fx_best = fx; - - // initial step - ggml_vec_norm_inv_f32(nx, &opt->lbfgs.step, d); - opt->lbfgs.j = 0; - opt->lbfgs.k = 1; - opt->lbfgs.end = 0; - opt->lbfgs.n_no_improvement = 0; - opt->just_initialized = false; - } - - float * fx_best = &opt->lbfgs.fx_best; - float * step = &opt->lbfgs.step; - int * j = &opt->lbfgs.j; - int * k = &opt->lbfgs.k; - int * end = &opt->lbfgs.end; - int * n_no_improvement = &opt->lbfgs.n_no_improvement; - - int ls = 0; - int bound = 0; - - float ys = 0.0f; - float yy = 0.0f; - float beta = 0.0f; - - int it = 0; - - while (true) { - // store the current position and gradient vectors - ggml_vec_cpy_f32(nx, xp, x); - ggml_vec_cpy_f32(nx, gp, g); - - // TODO: instead of passing &cancel here, use the return code of the linesearch - // to determine if the optimization should be cancelled - // this is a simple change, but not doing this atm, since I don't have a nice - // way to test and don't want to break something with so many changes lined up - ls = linesearch_backtracking(¶ms, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data); - if (cancel) { - return GGML_OPT_RESULT_CANCEL; - } - - if (ls < 0) { - // linesearch failed - go back to the previous point and return - ggml_vec_cpy_f32(nx, x, xp); - ggml_vec_cpy_f32(nx, g, gp); - - return ls; - } - - opt->loss_after = fx; - - ggml_vec_norm_f32(nx, &xnorm, x); - ggml_vec_norm_f32(nx, &gnorm, g); - - GGML_PRINT_DEBUG("f = %10.6f\n", ggml_get_f32_1d(f, 0)); - - if (xnorm < 1.0f) { - xnorm = 1.0f; - } - if (gnorm/xnorm <= params.lbfgs.eps) { - // converged - return GGML_OPT_RESULT_OK; - } - - // delta-based convergence test - if (pf != NULL) { - // need at least params.past iterations to start checking for convergence - if (params.past <= k[0]) { - const float rate = (pf[k[0]%params.past] - fx)/fx; - - if (fabsf(rate) < params.delta) { - return GGML_OPT_RESULT_OK; - } - } - - pf[k[0]%params.past] = fx; - } - - // check for improvement - if (params.max_no_improvement > 0) { - if (fx < fx_best[0]) { - fx_best[0] = fx; - n_no_improvement[0] = 0; - } else { - n_no_improvement[0]++; - - if (n_no_improvement[0] >= params.max_no_improvement) { - return GGML_OPT_RESULT_OK; - } - } - } - - if (params.lbfgs.n_iter != 0 && params.lbfgs.n_iter < it + 1) { - // reached the maximum number of iterations - return GGML_OPT_RESULT_DID_NOT_CONVERGE; - } - - // update vectors s and y: - // s_{k+1} = x_{k+1} - x_{k} = \step * d_{k}. - // y_{k+1} = g_{k+1} - g_{k}. - // - ggml_vec_sub_f32(nx, &lm_s[end[0]*nx], x, xp); - ggml_vec_sub_f32(nx, &lm_y[end[0]*nx], g, gp); - - // compute scalars ys and yy: - // ys = y^t \cdot s -> 1 / \rho. - // yy = y^t \cdot y. - // - ggml_vec_dot_f32(nx, &ys, 0, &lm_y[end[0]*nx], 0, &lm_s[end[0]*nx], 0, 1); - ggml_vec_dot_f32(nx, &yy, 0, &lm_y[end[0]*nx], 0, &lm_y[end[0]*nx], 0, 1); - - lm_ys[end[0]] = ys; - - // find new search direction - // ref: https://en.wikipedia.org/wiki/Limited-memory_BFGS - - bound = (m <= k[0]) ? m : k[0]; - k[0]++; - it++; - end[0] = (end[0] + 1)%m; - - // initialize search direction with -g - ggml_vec_neg_f32(nx, d, g); - - j[0] = end[0]; - for (int i = 0; i < bound; ++i) { - j[0] = (j[0] + m - 1) % m; - // \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1} - ggml_vec_dot_f32(nx, &lm_alpha[j[0]], 0, &lm_s[j[0]*nx], 0, d, 0, 1); - lm_alpha[j[0]] /= lm_ys[j[0]]; - // q_{i} = q_{i+1} - \alpha_{i} y_{i} - ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]); - } - - ggml_vec_scale_f32(nx, d, ys/yy); - - for (int i = 0; i < bound; ++i) { - // \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i} - ggml_vec_dot_f32(nx, &beta, 0, &lm_y[j[0]*nx], 0, d, 0, 1); - beta /= lm_ys[j[0]]; - // \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j} - ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta); - j[0] = (j[0] + 1)%m; - } - - step[0] = 1.0; - } - - GGML_ABORT("lbfgs failed"); - - //return GGML_OPT_RESULT_DID_NOT_CONVERGE; -} - -struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) { - struct ggml_opt_params result; - - switch (type) { - case GGML_OPT_TYPE_ADAM: - { - result = (struct ggml_opt_params) { - .type = GGML_OPT_TYPE_ADAM, - .graph_size = GGML_DEFAULT_GRAPH_SIZE, - .n_threads = 1, // FIXME: GGML_DEFAULT_N_THREADS ? - .past = 0, - .delta = 1e-5f, - - .max_no_improvement = 100, - - .print_forward_graph = true, - .print_backward_graph = true, - - .n_gradient_accumulation = 1, - - .adam = { - .n_iter = 10000, - .sched = 1.000f, - .decay = 0.0f, - .decay_min_ndim = 2, - .alpha = 0.001f, - .beta1 = 0.9f, - .beta2 = 0.999f, - .eps = 1e-8f, - .eps_f = 1e-5f, - .eps_g = 1e-3f, - .gclip = 0.0f, - }, - }; - } break; - case GGML_OPT_TYPE_LBFGS: - { - result = (struct ggml_opt_params) { - .type = GGML_OPT_TYPE_LBFGS, - .graph_size = GGML_DEFAULT_GRAPH_SIZE, - .n_threads = 1, - .past = 0, - .delta = 1e-5f, - - .max_no_improvement = 0, - - .print_forward_graph = true, - .print_backward_graph = true, - - .n_gradient_accumulation = 1, - - .lbfgs = { - .m = 6, - .n_iter = 100, - .max_linesearch = 20, - - .eps = 1e-5f, - .ftol = 1e-4f, - .wolfe = 0.9f, - .min_step = 1e-20f, - .max_step = 1e+20f, - - .linesearch = GGML_LINESEARCH_DEFAULT, - }, - }; - } break; - } - - return result; -} - -GGML_API void ggml_opt_init( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_opt_params params, - int64_t nx) { - opt->ctx = ctx; - opt->params = params; - opt->iter = 0; - opt->nx = nx; - opt->just_initialized = true; - if (opt->ctx == NULL) { - struct ggml_init_params ctx_opt_params; - if (opt->params.type == GGML_OPT_TYPE_ADAM) { - ctx_opt_params.mem_size = GGML_MEM_ALIGN*3 + ggml_tensor_overhead()*3 + ggml_type_size(GGML_TYPE_F32)*nx*3; - if (opt->params.past > 0) { - ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past; - } - } else if (opt->params.type == GGML_OPT_TYPE_LBFGS) { - ctx_opt_params.mem_size = GGML_MEM_ALIGN*9 + ggml_tensor_overhead()*9 + ggml_type_size(GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2); - if (opt->params.past > 0) { - ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past; - } - } - ctx_opt_params.mem_buffer = NULL; - ctx_opt_params.no_alloc = false; - - opt->ctx = ggml_init(ctx_opt_params); - } - switch (opt->params.type) { - case GGML_OPT_TYPE_ADAM: - { - opt->adam.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); - opt->adam.m = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); - opt->adam.v = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); - opt->adam.pf = params.past > 0 - ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past) - : NULL; - ggml_set_zero(opt->adam.m); - ggml_set_zero(opt->adam.v); - if (opt->adam.pf) { - ggml_set_zero(opt->adam.pf); - } - } break; - case GGML_OPT_TYPE_LBFGS: - { - opt->lbfgs.x = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); - opt->lbfgs.xp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); - opt->lbfgs.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); - opt->lbfgs.gp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); - opt->lbfgs.d = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx); - opt->lbfgs.pf = params.past > 0 - ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past) - : NULL; - opt->lbfgs.lmal = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m); - opt->lbfgs.lmys = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m); - opt->lbfgs.lms = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m); - opt->lbfgs.lmy = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m); - ggml_set_zero(opt->lbfgs.x); - ggml_set_zero(opt->lbfgs.xp); - ggml_set_zero(opt->lbfgs.g); - ggml_set_zero(opt->lbfgs.gp); - ggml_set_zero(opt->lbfgs.d); - if (opt->lbfgs.pf) { - ggml_set_zero(opt->lbfgs.pf); - } - ggml_set_zero(opt->lbfgs.lmal); - ggml_set_zero(opt->lbfgs.lmys); - ggml_set_zero(opt->lbfgs.lms); - ggml_set_zero(opt->lbfgs.lmy); - } break; - } -} - -enum ggml_opt_result ggml_opt( - struct ggml_context * ctx, - struct ggml_opt_params params, - struct ggml_tensor * f) { - GGML_ASSERT(f->grad && "ggml_set_param called for at least one parent tensor."); - - bool free_ctx = false; - if (ctx == NULL) { - struct ggml_init_params params_ctx = { - .mem_size = 16*1024*1024, - .mem_buffer = NULL, - .no_alloc = false, - }; - - ctx = ggml_init(params_ctx); - if (ctx == NULL) { - return GGML_OPT_RESULT_NO_CONTEXT; - } - - free_ctx = true; - } - - enum ggml_opt_result result = GGML_OPT_RESULT_OK; - - struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context)); - - ggml_opt_init(ctx, opt, params, 0); - result = ggml_opt_resume(ctx, opt, f); - - if (free_ctx) { - ggml_free(ctx); - } - - return result; -} - -enum ggml_opt_result ggml_opt_resume( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_tensor * f) { - - // build forward + backward compute graphs - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, opt->params.graph_size, true); - ggml_build_forward_expand(gf, f); - - struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf); - ggml_build_backward_expand(ctx, gf, gb, true); - - return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL); -} - -enum ggml_opt_result ggml_opt_resume_g( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_tensor * f, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - ggml_opt_callback callback, - void * callback_data) { - - GGML_ASSERT(f->grad && "ggml_set_param must be called for at least one ancestor"); - - // build forward + backward compute graphs - enum ggml_opt_result result = GGML_OPT_RESULT_OK; - - switch (opt->params.type) { - case GGML_OPT_TYPE_ADAM: - { - result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb, callback, callback_data); - } break; - case GGML_OPT_TYPE_LBFGS: - { - result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb, callback, callback_data); - } break; - } - - if (opt->params.print_forward_graph) { - ggml_graph_print (gf); - ggml_graph_dump_dot(gf, NULL, "opt-forward.dot"); - } - - if (opt->params.print_backward_graph) { - ggml_graph_print (gb); - ggml_graph_dump_dot(gb, gf, "opt-backward.dot"); - } - - return result; + GGML_LOG_INFO("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename); } //////////////////////////////////////////////////////////////////////////////// @@ -21830,6 +6357,17 @@ void ggml_set_output(struct ggml_tensor * tensor) { tensor->flags |= GGML_TENSOR_FLAG_OUTPUT; } +void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) { + GGML_UNUSED(ctx); // TODO: remove this parameter + tensor->flags |= GGML_TENSOR_FLAG_PARAM; +} + +void ggml_set_loss(struct ggml_tensor * tensor) { + GGML_ASSERT(ggml_is_scalar(tensor)); + GGML_ASSERT(tensor->type == GGML_TYPE_F32); + tensor->flags |= GGML_TENSOR_FLAG_LOSS; +} + //////////////////////////////////////////////////////////////////////////////// void ggml_quantize_init(enum ggml_type type) { @@ -21915,9 +6453,6 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q4_0_4_8: result = quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q4_0_8_8: result = quantize_q4_0_8x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_F16: { size_t elemsize = sizeof(ggml_fp16_t); @@ -21947,1431 +6482,30 @@ size_t ggml_quantize_chunk( //////////////////////////////////////////////////////////////////////////////// -struct gguf_str { - uint64_t n; // GGUFv2 - char * data; -}; - -static const size_t GGUF_TYPE_SIZE[GGUF_TYPE_COUNT] = { - [GGUF_TYPE_UINT8] = sizeof(uint8_t), - [GGUF_TYPE_INT8] = sizeof(int8_t), - [GGUF_TYPE_UINT16] = sizeof(uint16_t), - [GGUF_TYPE_INT16] = sizeof(int16_t), - [GGUF_TYPE_UINT32] = sizeof(uint32_t), - [GGUF_TYPE_INT32] = sizeof(int32_t), - [GGUF_TYPE_FLOAT32] = sizeof(float), - [GGUF_TYPE_BOOL] = sizeof(bool), - [GGUF_TYPE_STRING] = sizeof(struct gguf_str), - [GGUF_TYPE_UINT64] = sizeof(uint64_t), - [GGUF_TYPE_INT64] = sizeof(int64_t), - [GGUF_TYPE_FLOAT64] = sizeof(double), - [GGUF_TYPE_ARRAY] = 0, // undefined -}; -static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13"); - -static const char * GGUF_TYPE_NAME[GGUF_TYPE_COUNT] = { - [GGUF_TYPE_UINT8] = "u8", - [GGUF_TYPE_INT8] = "i8", - [GGUF_TYPE_UINT16] = "u16", - [GGUF_TYPE_INT16] = "i16", - [GGUF_TYPE_UINT32] = "u32", - [GGUF_TYPE_INT32] = "i32", - [GGUF_TYPE_FLOAT32] = "f32", - [GGUF_TYPE_BOOL] = "bool", - [GGUF_TYPE_STRING] = "str", - [GGUF_TYPE_ARRAY] = "arr", - [GGUF_TYPE_UINT64] = "u64", - [GGUF_TYPE_INT64] = "i64", - [GGUF_TYPE_FLOAT64] = "f64", -}; -static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13"); - -union gguf_value { - uint8_t uint8; - int8_t int8; - uint16_t uint16; - int16_t int16; - uint32_t uint32; - int32_t int32; - float float32; - uint64_t uint64; - int64_t int64; - double float64; - bool bool_; - - struct gguf_str str; - - struct { - enum gguf_type type; - - uint64_t n; // GGUFv2 - void * data; - } arr; -}; - -struct gguf_kv { - struct gguf_str key; - - enum gguf_type type; - union gguf_value value; -}; - -struct gguf_header { - char magic[4]; - - uint32_t version; - uint64_t n_tensors; // GGUFv2 - uint64_t n_kv; // GGUFv2 -}; - -struct gguf_tensor_info { - struct gguf_str name; - - uint32_t n_dims; - uint64_t ne[GGML_MAX_DIMS]; - - enum ggml_type type; - - uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT` - - // for writing API - const void * data; - size_t size; -}; - -struct gguf_context { - struct gguf_header header; - - struct gguf_kv * kv; - struct gguf_tensor_info * infos; - - size_t alignment; - size_t offset; // offset of `data` from beginning of file - size_t size; // size of `data` in bytes - - //uint8_t * padding; - void * data; -}; - -static size_t gguf_type_size(enum gguf_type type) { - GGML_ASSERT(0 <= type && type < GGUF_TYPE_COUNT); - return GGUF_TYPE_SIZE[type]; -} - -static void gguf_tensor_info_sanitize(struct gguf_tensor_info * info) { - GGML_ASSERT(info->n_dims <= GGML_MAX_DIMS); - GGML_ASSERT(0 <= info->type && info->type < GGML_TYPE_COUNT); - - for (uint32_t i = 0; i < info->n_dims; ++i) { - GGML_ASSERT(info->ne[i] > 0); - } - - // prevent overflow for total number of elements - GGML_ASSERT(INT64_MAX/info->ne[1] > info->ne[0]); - GGML_ASSERT(INT64_MAX/info->ne[2] > info->ne[0]*info->ne[1]); - GGML_ASSERT(INT64_MAX/info->ne[3] > info->ne[0]*info->ne[1]*info->ne[2]); -} - -static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) { - const size_t n = fread(dst, 1, size, file); - *offset += n; - return n == size; -} - -static bool gguf_fread_str(FILE * file, struct gguf_str * p, size_t * offset) { - p->n = 0; - p->data = NULL; - - bool ok = true; - - ok = ok && gguf_fread_el(file, &p->n, sizeof(p->n), offset); - - // early exit if string length is invalid, prevents from integer overflow - if (p->n == SIZE_MAX) { - fprintf(stderr, "%s: invalid string length (%" PRIu64 ")\n", __func__, p->n); - return false; - } - - p->data = GGML_CALLOC(p->n + 1, 1); - - ok = ok && gguf_fread_el(file, p->data, p->n, offset); - - return ok; -} - -static void gguf_free_kv(struct gguf_kv * kv) { - if (kv->key.data) { - GGML_FREE(kv->key.data); - } - - if (kv->type == GGUF_TYPE_STRING) { - if (kv->value.str.data) { - GGML_FREE(kv->value.str.data); - } - } - - if (kv->type == GGUF_TYPE_ARRAY) { - if (kv->value.arr.data) { - if (kv->value.arr.type == GGUF_TYPE_STRING) { - for (uint64_t j = 0; j < kv->value.arr.n; ++j) { - struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[j]; - if (str->data) { - GGML_FREE(str->data); - } - } - } - GGML_FREE(kv->value.arr.data); - } - } -} - -struct gguf_context * gguf_init_empty(void) { - struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context)); - - memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic)); - ctx->header.version = GGUF_VERSION; - ctx->header.n_tensors = 0; - ctx->header.n_kv = 0; - - ctx->kv = NULL; - ctx->infos = NULL; - - ctx->alignment = GGUF_DEFAULT_ALIGNMENT; - ctx->offset = 0; - ctx->size = 0; - - ctx->data = NULL; - - return ctx; -} - -struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) { - FILE * file = ggml_fopen(fname, "rb"); - if (!file) { - fprintf(stderr, "%s: failed to open '%s': '%s'\n", __func__, fname, strerror(errno)); - return NULL; - } - - // offset from start of file - size_t offset = 0; - - char magic[4]; - - // check the magic before making allocations - { - gguf_fread_el(file, &magic, sizeof(magic), &offset); - - for (uint32_t i = 0; i < sizeof(magic); i++) { - if (magic[i] != GGUF_MAGIC[i]) { - fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]); - fclose(file); - return NULL; - } - } - } - - bool ok = true; - - struct gguf_context * ctx = GGML_CALLOC(1, sizeof(struct gguf_context)); - - // read the header - { - strncpy(ctx->header.magic, magic, 4); - - ctx->kv = NULL; - ctx->infos = NULL; - ctx->data = NULL; - - ok = ok && gguf_fread_el(file, &ctx->header.version, sizeof(ctx->header.version), &offset); - ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset); - ok = ok && gguf_fread_el(file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset); - - if (ctx->header.version == 1) { - fprintf(stderr, "%s: GGUFv1 is no longer supported. please use a more up-to-date version\n", __func__); - fclose(file); - gguf_free(ctx); - return NULL; - } - - // sanity-checks to prevent from integer/buffer overflows - - ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/sizeof(struct gguf_tensor_info)); - ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/ggml_tensor_overhead()); - ok = ok && (ctx->header.n_kv < (SIZE_MAX/2)/sizeof(struct gguf_kv)); - - if (!ok) { - fprintf(stderr, "%s: failed to read header\n", __func__); - fclose(file); - gguf_free(ctx); - return NULL; - } - } - - // read the kv pairs - { - const uint64_t n_kv = ctx->header.n_kv; - - // header.n_kv will hold the actual value of pairs that were successfully read in the loop below - ctx->header.n_kv = 0; - ctx->kv = GGML_CALLOC(n_kv, sizeof(struct gguf_kv)); - - for (uint64_t i = 0; i < n_kv; ++i) { - struct gguf_kv * kv = &ctx->kv[i]; - - //fprintf(stderr, "%s: reading kv %d\n", __func__, i); - - ok = ok && gguf_fread_str(file, &kv->key, &offset); - ok = ok && gguf_fread_el (file, &kv->type, sizeof(kv->type), &offset); - - //fprintf(stderr, "%s: reading kv with key %s\n", __func__, kv->key.data); - - switch (kv->type) { - case GGUF_TYPE_UINT8: ok = ok && gguf_fread_el (file, &kv->value.uint8, sizeof(kv->value.uint8), &offset); break; - case GGUF_TYPE_INT8: ok = ok && gguf_fread_el (file, &kv->value.int8, sizeof(kv->value.int8), &offset); break; - case GGUF_TYPE_UINT16: ok = ok && gguf_fread_el (file, &kv->value.uint16, sizeof(kv->value.uint16), &offset); break; - case GGUF_TYPE_INT16: ok = ok && gguf_fread_el (file, &kv->value.int16, sizeof(kv->value.int16), &offset); break; - case GGUF_TYPE_UINT32: ok = ok && gguf_fread_el (file, &kv->value.uint32, sizeof(kv->value.uint32), &offset); break; - case GGUF_TYPE_INT32: ok = ok && gguf_fread_el (file, &kv->value.int32, sizeof(kv->value.int32), &offset); break; - case GGUF_TYPE_FLOAT32: ok = ok && gguf_fread_el (file, &kv->value.float32, sizeof(kv->value.float32), &offset); break; - case GGUF_TYPE_UINT64: ok = ok && gguf_fread_el (file, &kv->value.uint64, sizeof(kv->value.uint64), &offset); break; - case GGUF_TYPE_INT64: ok = ok && gguf_fread_el (file, &kv->value.int64, sizeof(kv->value.int64), &offset); break; - case GGUF_TYPE_FLOAT64: ok = ok && gguf_fread_el (file, &kv->value.float64, sizeof(kv->value.float64), &offset); break; - case GGUF_TYPE_BOOL: ok = ok && gguf_fread_el (file, &kv->value.bool_, sizeof(kv->value.bool_), &offset); break; - case GGUF_TYPE_STRING: ok = ok && gguf_fread_str(file, &kv->value.str, &offset); break; - case GGUF_TYPE_ARRAY: - { - ok = ok && gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset); - ok = ok && gguf_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset); - - switch (kv->value.arr.type) { - case GGUF_TYPE_UINT8: - case GGUF_TYPE_INT8: - case GGUF_TYPE_UINT16: - case GGUF_TYPE_INT16: - case GGUF_TYPE_UINT32: - case GGUF_TYPE_INT32: - case GGUF_TYPE_FLOAT32: - case GGUF_TYPE_UINT64: - case GGUF_TYPE_INT64: - case GGUF_TYPE_FLOAT64: - case GGUF_TYPE_BOOL: - { - // prevent from integer overflow in the malloc below - if (kv->value.arr.n >= SIZE_MAX/gguf_type_size(kv->value.arr.type)) { - fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n); - fclose(file); - gguf_free(ctx); - return NULL; - } - - kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, gguf_type_size(kv->value.arr.type)); - - ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type), &offset); - } break; - case GGUF_TYPE_STRING: - { - // prevent from integer overflow in the malloc below - if (kv->value.arr.n >= SIZE_MAX/sizeof(struct gguf_str)) { - fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n); - fclose(file); - gguf_free(ctx); - return NULL; - } - - kv->value.arr.data = GGML_CALLOC(kv->value.arr.n, sizeof(struct gguf_str)); - - for (uint64_t j = 0; j < kv->value.arr.n; ++j) { - ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset); - } - } break; - case GGUF_TYPE_ARRAY: - default: GGML_ABORT("invalid type"); - } - } break; - default: GGML_ABORT("invalid type"); - } - - if (!ok) { - break; - } - - ctx->header.n_kv++; - } - - if (!ok) { - fprintf(stderr, "%s: failed to read key-value pairs\n", __func__); - fclose(file); - gguf_free(ctx); - return NULL; - } - } - - // read the tensor infos - if (ctx->header.n_tensors > 0) { - ctx->infos = GGML_CALLOC(ctx->header.n_tensors, sizeof(struct gguf_tensor_info)); - - for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) { - struct gguf_tensor_info * info = &ctx->infos[i]; - - for (int j = 0; j < GGML_MAX_DIMS; ++j) { - info->ne[j] = 1; - } - - ok = ok && gguf_fread_str(file, &info->name, &offset); - ok = ok && gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims), &offset); - - ok = ok && (info->n_dims <= GGML_MAX_DIMS); - - for (uint32_t j = 0; j < info->n_dims; ++j) { - ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset); - } - - ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset); - ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset); - - // TODO: return an error instead of crashing with GGML_ASSERT - gguf_tensor_info_sanitize(info); - - // make sure there is no duplicated tensor names - for (uint64_t j = 0; j < i && ok; ++j) { - if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) { - fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data); - ok = false; - } - } - - if (!ok) { - fprintf(stderr, "%s: failed to read tensor info\n", __func__); - fclose(file); - gguf_free(ctx); - return NULL; - } - } - } - - ctx->alignment = GGUF_DEFAULT_ALIGNMENT; - - int alignment_idx = gguf_find_key(ctx, "general.alignment"); - if (alignment_idx != -1) { - ctx->alignment = gguf_get_val_u32(ctx, alignment_idx); - } - - // we require the data section to be aligned, so take into account any padding - { - const size_t offset_pad = offset % ctx->alignment; - - if (offset_pad != 0) { - offset += ctx->alignment - offset_pad; - fseek(file, offset, SEEK_SET); - } - } - - // store the current file offset - this is where the data section starts - ctx->offset = offset; - - // compute the total size of the data section, taking into account the alignment - { - ctx->size = 0; - for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) { - struct gguf_tensor_info * info = &ctx->infos[i]; - - const int64_t ne = - (int64_t) info->ne[0] * - (int64_t) info->ne[1] * - (int64_t) info->ne[2] * - (int64_t) info->ne[3]; - - if (ggml_blck_size(info->type) == 0 || ne % ggml_blck_size(info->type) != 0) { - fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n", - __func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type)); - fclose(file); - gguf_free(ctx); - return NULL; - } - - const size_t size_cur = ggml_row_size(info->type, ne); - - ctx->size += GGML_PAD(size_cur, ctx->alignment); - } - } - - // load the tensor data only if requested - if (params.ctx != NULL) { - // if the provided gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob - // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of - // the ggml_tensor structs to the appropriate locations in the binary blob - - // compute the exact size needed for the new ggml_context - const size_t mem_size = - params.no_alloc ? - (ctx->header.n_tensors )*ggml_tensor_overhead() : - (ctx->header.n_tensors + 1)*ggml_tensor_overhead() + ctx->size; - - struct ggml_init_params pdata = { - .mem_size = mem_size, - .mem_buffer = NULL, - .no_alloc = params.no_alloc, - }; - - *params.ctx = ggml_init(pdata); - if (*params.ctx == NULL) { - fprintf(stderr, "%s: failed to initialize context\n", __func__); - fclose(file); - gguf_free(ctx); - return NULL; - } - - struct ggml_context * ctx_data = *params.ctx; - - struct ggml_tensor * data = NULL; - - if (!params.no_alloc) { - data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size); - - ok = ok && data != NULL; - - // read the binary blob with the tensor data - ok = ok && gguf_fread_el(file, data->data, ctx->size, &offset); - - if (!ok) { - fprintf(stderr, "%s: failed to read tensor data\n", __func__); - fclose(file); - ggml_free(ctx_data); - gguf_free(ctx); - return NULL; - } - - ctx->data = data->data; - } - - ggml_set_no_alloc(ctx_data, true); - - // create the tensors - for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) { - const int64_t ne[GGML_MAX_DIMS] = { - ctx->infos[i].ne[0], - ctx->infos[i].ne[1], - ctx->infos[i].ne[2], - ctx->infos[i].ne[3], - }; - - struct ggml_tensor * cur = ggml_new_tensor(ctx_data, ctx->infos[i].type, ctx->infos[i].n_dims, ne); - - ok = ok && cur != NULL; - - if (!ok) { - break; - } - - ggml_set_name(cur, ctx->infos[i].name.data); - - // point the data member to the appropriate location in the binary blob using the tensor infos - if (!params.no_alloc) { - //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file - cur->data = (char *) data->data + ctx->infos[i].offset; // offset from data - } - } - - if (!ok) { - fprintf(stderr, "%s: failed to read the tensor data\n", __func__); - fclose(file); - ggml_free(ctx_data); - gguf_free(ctx); - return NULL; - } - - ggml_set_no_alloc(ctx_data, params.no_alloc); - } - - fclose(file); - - return ctx; -} - -void gguf_free(struct gguf_context * ctx) { - if (ctx == NULL) { - return; - } - - if (ctx->kv) { - // free string memory - not great.. - for (uint64_t i = 0; i < ctx->header.n_kv; ++i) { - gguf_free_kv(&ctx->kv[i]); - } - - GGML_FREE(ctx->kv); - } - - if (ctx->infos) { - for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) { - struct gguf_tensor_info * info = &ctx->infos[i]; - - if (info->name.data) { - GGML_FREE(info->name.data); - } - } - - GGML_FREE(ctx->infos); - } - - GGML_FREE(ctx); -} - -const char * gguf_type_name(enum gguf_type type) { - return GGUF_TYPE_NAME[type]; -} - -int gguf_get_version(const struct gguf_context * ctx) { - return ctx->header.version; -} - -size_t gguf_get_alignment(const struct gguf_context * ctx) { - return ctx->alignment; -} - -size_t gguf_get_data_offset(const struct gguf_context * ctx) { - return ctx->offset; -} - -void * gguf_get_data(const struct gguf_context * ctx) { - return ctx->data; -} - -int gguf_get_n_kv(const struct gguf_context * ctx) { - return ctx->header.n_kv; -} - -int gguf_find_key(const struct gguf_context * ctx, const char * key) { - // return -1 if key not found - int keyfound = -1; - - const int n_kv = gguf_get_n_kv(ctx); - - for (int i = 0; i < n_kv; ++i) { - if (strcmp(key, gguf_get_key(ctx, i)) == 0) { - keyfound = i; - break; - } - } - - return keyfound; -} - -const char * gguf_get_key(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - return ctx->kv[key_id].key.data; -} - -enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - return ctx->kv[key_id].type; -} - -enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); - return ctx->kv[key_id].value.arr.type; -} - -const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); - return ctx->kv[key_id].value.arr.data; -} - -const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); - struct gguf_kv * kv = &ctx->kv[key_id]; - struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i]; - return str->data; -} - -int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY); - return ctx->kv[key_id].value.arr.n; -} - -uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8); - return ctx->kv[key_id].value.uint8; -} - -int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8); - return ctx->kv[key_id].value.int8; -} - -uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16); - return ctx->kv[key_id].value.uint16; -} - -int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16); - return ctx->kv[key_id].value.int16; -} - -uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32); - return ctx->kv[key_id].value.uint32; -} - -int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32); - return ctx->kv[key_id].value.int32; -} - -float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32); - return ctx->kv[key_id].value.float32; -} - -uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64); - return ctx->kv[key_id].value.uint64; -} - -int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64); - return ctx->kv[key_id].value.int64; -} - -double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64); - return ctx->kv[key_id].value.float64; -} - -bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL); - return ctx->kv[key_id].value.bool_; -} - -const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING); - return ctx->kv[key_id].value.str.data; -} - -const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) { - GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); - GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY); - GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_STRING); - return &ctx->kv[key_id].value; -} - -int gguf_get_n_tensors(const struct gguf_context * ctx) { - return ctx->header.n_tensors; -} - -int gguf_find_tensor(const struct gguf_context * ctx, const char * name) { - // return -1 if tensor not found - int tensorfound = -1; - - const int n_tensors = gguf_get_n_tensors(ctx); - - for (int i = 0; i < n_tensors; ++i) { - if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) { - tensorfound = i; - break; - } - } - - return tensorfound; -} - -size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i) { - return ctx->infos[i].offset; -} - -char * gguf_get_tensor_name(const struct gguf_context * ctx, int i) { - return ctx->infos[i].name.data; -} - -enum ggml_type gguf_get_tensor_type(const struct gguf_context * ctx, int i) { - return ctx->infos[i].type; -} - -// returns the index -static int gguf_get_or_add_key(struct gguf_context * ctx, const char * key) { - const int idx = gguf_find_key(ctx, key); - if (idx >= 0) { - return idx; - } - - const int n_kv = gguf_get_n_kv(ctx); - - ctx->kv = realloc(ctx->kv, (n_kv + 1) * sizeof(struct gguf_kv)); - ctx->kv[n_kv].key.n = strlen(key); - ctx->kv[n_kv].key.data = strdup(key); - ctx->header.n_kv++; - - return n_kv; -} - -void gguf_remove_key(struct gguf_context * ctx, const char * key) { - const int idx = gguf_find_key(ctx, key); - if (idx >= 0) { - const int n_kv = gguf_get_n_kv(ctx); - gguf_free_kv(&ctx->kv[idx]); - for (int i = idx; i < n_kv-1; ++i) { - ctx->kv[i] = ctx->kv[i+1]; - } - ctx->kv = realloc(ctx->kv, (n_kv - 1) * sizeof(struct gguf_kv)); - ctx->header.n_kv--; - } -} - -void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_UINT8; - ctx->kv[idx].value.uint8 = val; -} - -void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_INT8; - ctx->kv[idx].value.int8 = val; -} - -void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_UINT16; - ctx->kv[idx].value.uint16 = val; -} - -void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_INT16; - ctx->kv[idx].value.int16 = val; +void ggml_log_set(ggml_log_callback log_callback, void * user_data) { + g_logger_state.log_callback = log_callback ? log_callback : ggml_log_callback_default; + g_logger_state.log_callback_user_data = user_data; } -void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_UINT32; - ctx->kv[idx].value.uint32 = val; +void ggml_threadpool_params_init(struct ggml_threadpool_params * p, int n_threads) { + p->n_threads = n_threads; + p->prio = 0; // default priority (usually means normal or inherited) + p->poll = 50; // hybrid-polling enabled + p->strict_cpu = false; // no strict placement (all threads share same cpumask) + p->paused = false; // threads are ready to go + memset(p->cpumask, 0, GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited) } -void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_INT32; - ctx->kv[idx].value.int32 = val; +struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads) { + struct ggml_threadpool_params p; + ggml_threadpool_params_init(&p, n_threads); + return p; } -void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_FLOAT32; - ctx->kv[idx].value.float32 = val; -} - -void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_UINT64; - ctx->kv[idx].value.uint64 = val; -} - -void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_INT64; - ctx->kv[idx].value.int64 = val; -} - -void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_FLOAT64; - ctx->kv[idx].value.float64 = val; -} - -void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_BOOL; - ctx->kv[idx].value.bool_ = val; +bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) { + if (p0->n_threads != p1->n_threads ) return false; + if (p0->prio != p1->prio ) return false; + if (p0->poll != p1->poll ) return false; + if (p0->strict_cpu != p1->strict_cpu ) return false; + return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0; } - -void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_STRING; - ctx->kv[idx].value.str.n = strlen(val); - ctx->kv[idx].value.str.data = strdup(val); -} - -void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_ARRAY; - ctx->kv[idx].value.arr.type = type; - ctx->kv[idx].value.arr.n = n; - ctx->kv[idx].value.arr.data = GGML_CALLOC(n, gguf_type_size(type)); - memcpy(ctx->kv[idx].value.arr.data, data, n*gguf_type_size(type)); -} - -void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, int n) { - const int idx = gguf_get_or_add_key(ctx, key); - - ctx->kv[idx].type = GGUF_TYPE_ARRAY; - ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING; - ctx->kv[idx].value.arr.n = n; - ctx->kv[idx].value.arr.data = GGML_CALLOC(n, sizeof(struct gguf_str)); - for (int i = 0; i < n; i++) { - struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i]; - str->n = strlen(data[i]); - str->data = strdup(data[i]); - } -} - -// set or add KV pairs from another context -void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) { - for (uint32_t i = 0; i < src->header.n_kv; i++) { - switch (src->kv[i].type) { - case GGUF_TYPE_UINT8: gguf_set_val_u8 (ctx, src->kv[i].key.data, src->kv[i].value.uint8); break; - case GGUF_TYPE_INT8: gguf_set_val_i8 (ctx, src->kv[i].key.data, src->kv[i].value.int8); break; - case GGUF_TYPE_UINT16: gguf_set_val_u16 (ctx, src->kv[i].key.data, src->kv[i].value.uint16); break; - case GGUF_TYPE_INT16: gguf_set_val_i16 (ctx, src->kv[i].key.data, src->kv[i].value.int16); break; - case GGUF_TYPE_UINT32: gguf_set_val_u32 (ctx, src->kv[i].key.data, src->kv[i].value.uint32); break; - case GGUF_TYPE_INT32: gguf_set_val_i32 (ctx, src->kv[i].key.data, src->kv[i].value.int32); break; - case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, src->kv[i].key.data, src->kv[i].value.float32); break; - case GGUF_TYPE_UINT64: gguf_set_val_u64 (ctx, src->kv[i].key.data, src->kv[i].value.uint64); break; - case GGUF_TYPE_INT64: gguf_set_val_i64 (ctx, src->kv[i].key.data, src->kv[i].value.int64); break; - case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, src->kv[i].key.data, src->kv[i].value.float64); break; - case GGUF_TYPE_BOOL: gguf_set_val_bool(ctx, src->kv[i].key.data, src->kv[i].value.bool_); break; - case GGUF_TYPE_STRING: gguf_set_val_str (ctx, src->kv[i].key.data, src->kv[i].value.str.data); break; - case GGUF_TYPE_ARRAY: - { - if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) { - const char ** data = GGML_CALLOC(src->kv[i].value.arr.n, sizeof(char *)); - for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) { - data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data; - } - gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n); - GGML_FREE((void *)data); - } else if (src->kv[i].value.arr.type == GGUF_TYPE_ARRAY) { - GGML_ABORT("nested arrays not supported"); - } else { - gguf_set_arr_data(ctx, src->kv[i].key.data, src->kv[i].value.arr.type, src->kv[i].value.arr.data, src->kv[i].value.arr.n); - } - } break; - default: GGML_ABORT("invalid type"); - } - } -} - -void gguf_add_tensor( - struct gguf_context * ctx, - const struct ggml_tensor * tensor) { - GGML_ASSERT(tensor); - if (gguf_find_tensor(ctx, tensor->name) != -1) { - GGML_ABORT("duplicated tensor name"); - } - - const int idx = ctx->header.n_tensors; - ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info)); - - ctx->infos[idx].name.n = strlen(tensor->name); - ctx->infos[idx].name.data = strdup(tensor->name); - - for (int i = 0; i < GGML_MAX_DIMS; ++i) { - ctx->infos[idx].ne[i] = 1; - } - - ctx->infos[idx].n_dims = ggml_n_dims(tensor); - for (uint32_t i = 0; i < ctx->infos[idx].n_dims; i++) { - ctx->infos[idx].ne[i] = tensor->ne[i]; - } - - ctx->infos[idx].type = tensor->type; - ctx->infos[idx].offset = 0; - ctx->infos[idx].data = tensor->data; - ctx->infos[idx].size = ggml_nbytes(tensor); - - if (ctx->header.n_tensors > 0) { - ctx->infos[idx].offset = ctx->infos[idx - 1].offset + GGML_PAD(ctx->infos[idx - 1].size, ctx->alignment); - } - - ctx->header.n_tensors++; -} - -void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) { - const int idx = gguf_find_tensor(ctx, name); - if (idx < 0) { - GGML_ABORT("tensor not found"); - } - - ctx->infos[idx].type = type; -} - -void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size) { - const int idx = gguf_find_tensor(ctx, name); - if (idx < 0) { - GGML_ABORT("tensor not found"); - } - - ctx->infos[idx].data = data; - ctx->infos[idx].size = size; - - // update offsets - for (uint32_t i = idx + 1; i < ctx->header.n_tensors; ++i) { - ctx->infos[i].offset = ctx->infos[i - 1].offset + GGML_PAD(ctx->infos[i - 1].size, ctx->alignment); - } -} - -//static void gguf_fwrite_str(FILE * file, const struct gguf_str * val) { -// fwrite(&val->n, sizeof(val->n), 1, file); -// fwrite(val->data, sizeof(char), val->n, file); -//} -// -//static void gguf_fwrite_el(FILE * file, const void * val, size_t size) { -// fwrite(val, sizeof(char), size, file); -//} - -struct gguf_buf { - void * data; - size_t size; - size_t offset; -}; - -static struct gguf_buf gguf_buf_init(size_t size) { - struct gguf_buf buf = { - /*buf.data =*/ size == 0 ? NULL : GGML_CALLOC(1, size), - /*buf.size =*/ size, - /*buf.offset =*/ 0, - }; - - return buf; -} - -static void gguf_buf_free(struct gguf_buf buf) { - if (buf.data) { - GGML_FREE(buf.data); - } -} - -static void gguf_buf_grow(struct gguf_buf * buf, size_t size) { - if (buf->offset + size > buf->size) { - buf->size = 1.5*(buf->offset + size); - if (buf->data) { - buf->data = realloc(buf->data, buf->size); - } - } -} - -static void gguf_bwrite_str(struct gguf_buf * buf, const struct gguf_str * val) { - gguf_buf_grow(buf, sizeof(val->n) + val->n); - - if (buf->data) { - memcpy((char *) buf->data + buf->offset, &val->n, sizeof(val->n)); - } - buf->offset += sizeof(val->n); - - if (buf->data) { - memcpy((char *) buf->data + buf->offset, val->data, val->n); - } - buf->offset += val->n; -} - -static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_size) { - gguf_buf_grow(buf, el_size); - - if (buf->data) { - memcpy((char *) buf->data + buf->offset, val, el_size); - } - buf->offset += el_size; -} - -static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) { - // write header - gguf_bwrite_el(buf, &ctx->header.magic, sizeof(ctx->header.magic)); - gguf_bwrite_el(buf, &ctx->header.version, sizeof(ctx->header.version)); - gguf_bwrite_el(buf, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors)); - gguf_bwrite_el(buf, &ctx->header.n_kv, sizeof(ctx->header.n_kv)); - - // write key-value pairs - for (uint32_t i = 0; i < ctx->header.n_kv; ++i) { - struct gguf_kv * kv = &ctx->kv[i]; - - gguf_bwrite_str(buf, &kv->key); - gguf_bwrite_el (buf, &kv->type, sizeof(kv->type)); - - switch (kv->type) { - case GGUF_TYPE_UINT8: gguf_bwrite_el( buf, &kv->value.uint8, sizeof(kv->value.uint8) ); break; - case GGUF_TYPE_INT8: gguf_bwrite_el (buf, &kv->value.int8, sizeof(kv->value.int8) ); break; - case GGUF_TYPE_UINT16: gguf_bwrite_el (buf, &kv->value.uint16, sizeof(kv->value.uint16) ); break; - case GGUF_TYPE_INT16: gguf_bwrite_el (buf, &kv->value.int16, sizeof(kv->value.int16) ); break; - case GGUF_TYPE_UINT32: gguf_bwrite_el (buf, &kv->value.uint32, sizeof(kv->value.uint32) ); break; - case GGUF_TYPE_INT32: gguf_bwrite_el (buf, &kv->value.int32, sizeof(kv->value.int32) ); break; - case GGUF_TYPE_FLOAT32: gguf_bwrite_el (buf, &kv->value.float32, sizeof(kv->value.float32)); break; - case GGUF_TYPE_UINT64: gguf_bwrite_el (buf, &kv->value.uint64, sizeof(kv->value.uint64) ); break; - case GGUF_TYPE_INT64: gguf_bwrite_el (buf, &kv->value.int64, sizeof(kv->value.int64) ); break; - case GGUF_TYPE_FLOAT64: gguf_bwrite_el (buf, &kv->value.float64, sizeof(kv->value.float64)); break; - case GGUF_TYPE_BOOL: gguf_bwrite_el (buf, &kv->value.bool_, sizeof(kv->value.bool_) ); break; - case GGUF_TYPE_STRING: gguf_bwrite_str(buf, &kv->value.str ); break; - case GGUF_TYPE_ARRAY: - { - gguf_bwrite_el(buf, &kv->value.arr.type, sizeof(kv->value.arr.type)); - gguf_bwrite_el(buf, &kv->value.arr.n, sizeof(kv->value.arr.n) ); - - switch (kv->value.arr.type) { - case GGUF_TYPE_UINT8: - case GGUF_TYPE_INT8: - case GGUF_TYPE_UINT16: - case GGUF_TYPE_INT16: - case GGUF_TYPE_UINT32: - case GGUF_TYPE_INT32: - case GGUF_TYPE_FLOAT32: - case GGUF_TYPE_UINT64: - case GGUF_TYPE_INT64: - case GGUF_TYPE_FLOAT64: - case GGUF_TYPE_BOOL: - { - gguf_bwrite_el(buf, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type)); - } break; - case GGUF_TYPE_STRING: - { - for (uint32_t j = 0; j < kv->value.arr.n; ++j) { - gguf_bwrite_str(buf, &((struct gguf_str *) kv->value.arr.data)[j]); - } - } break; - case GGUF_TYPE_ARRAY: - default: GGML_ABORT("invalid type"); - } - } break; - default: GGML_ABORT("invalid type"); - } - } - - // write tensor infos - for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { - struct gguf_tensor_info * info = &ctx->infos[i]; - - gguf_bwrite_str(buf, &info->name); - gguf_bwrite_el (buf, &info->n_dims, sizeof(info->n_dims)); - for (uint32_t j = 0; j < info->n_dims; ++j) { - gguf_bwrite_el(buf, &info->ne[j], sizeof(info->ne[j])); - } - gguf_bwrite_el(buf, &info->type, sizeof(info->type)); - gguf_bwrite_el(buf, &info->offset, sizeof(info->offset)); - } - - // we require the data section to be aligned, so take into account any padding - { - const size_t offset = buf->offset; - const size_t offset_pad = GGML_PAD(offset, ctx->alignment); - - if (offset_pad != offset) { - uint8_t pad = 0; - for (size_t i = 0; i < offset_pad - offset; ++i) { - gguf_bwrite_el(buf, &pad, sizeof(pad)); - } - } - } - - if (only_meta) { - return; - } - - size_t offset = 0; - - // write tensor data - for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { - struct gguf_tensor_info * info = &ctx->infos[i]; - - const size_t size = info->size; - const size_t size_pad = GGML_PAD(size, ctx->alignment); - - gguf_bwrite_el(buf, info->data, size); - - if (size_pad != size) { - uint8_t pad = 0; - for (size_t j = 0; j < size_pad - size; ++j) { - gguf_bwrite_el(buf, &pad, sizeof(pad)); - } - } - - GGML_ASSERT(offset == info->offset); - - offset += size_pad; - } -} - -void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) { - FILE * file = ggml_fopen(fname, "wb"); - if (!file) { - GGML_ABORT("failed to open file for writing"); - } - - struct gguf_buf buf = gguf_buf_init(16*1024); - - gguf_write_to_buf(ctx, &buf, only_meta); - - fwrite(buf.data, 1, buf.offset, file); - - gguf_buf_free(buf); - - fclose(file); -} - -size_t gguf_get_meta_size(const struct gguf_context * ctx) { - // no allocs - only compute size - struct gguf_buf buf = gguf_buf_init(0); - - gguf_write_to_buf(ctx, &buf, true); - - return buf.offset; -} - -void gguf_get_meta_data(const struct gguf_context * ctx, void * data) { - struct gguf_buf buf = gguf_buf_init(16*1024); - - gguf_write_to_buf(ctx, &buf, true); - - memcpy(data, buf.data, buf.offset); - - gguf_buf_free(buf); -} - -//////////////////////////////////////////////////////////////////////////////// - -int ggml_cpu_has_avx(void) { -#if defined(__AVX__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_avx_vnni(void) { -#if defined(__AVXVNNI__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_avx2(void) { -#if defined(__AVX2__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_avx512(void) { -#if defined(__AVX512F__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_avx512_vbmi(void) { -#if defined(__AVX512VBMI__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_avx512_vnni(void) { -#if defined(__AVX512VNNI__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_avx512_bf16(void) { -#if defined(__AVX512BF16__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_fma(void) { -#if defined(__FMA__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_neon(void) { -#if defined(__ARM_NEON) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_sve(void) { -#if defined(__ARM_FEATURE_SVE) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_arm_fma(void) { -#if defined(__ARM_FEATURE_FMA) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_metal(void) { -#if defined(GGML_USE_METAL) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_f16c(void) { -#if defined(__F16C__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_fp16_va(void) { -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_wasm_simd(void) { -#if defined(__wasm_simd128__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_blas(void) { -#if defined(GGML_USE_BLAS) || defined(GGML_USE_CUDA) || defined(GGML_USE_VULKAN) || defined(GGML_USE_SYCL) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_cuda(void) { -#if defined(GGML_USE_CUDA) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_vulkan(void) { -#if defined(GGML_USE_VULKAN) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_kompute(void) { -#if defined(GGML_USE_KOMPUTE) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_sycl(void) { -#if defined(GGML_USE_SYCL) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_rpc(void) { -#if defined(GGML_USE_RPC) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_cann(void) { -#if defined(GGML_USE_CANN) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_llamafile(void) { -#if defined(GGML_USE_LLAMAFILE) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_gpublas(void) { - return ggml_cpu_has_cuda() || ggml_cpu_has_vulkan() || ggml_cpu_has_kompute() || ggml_cpu_has_sycl(); -} - -int ggml_cpu_has_sse3(void) { -#if defined(__SSE3__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_ssse3(void) { -#if defined(__SSSE3__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_vsx(void) { -#if defined(__POWER9_VECTOR__) - return 1; -#else - return 0; -#endif -} - -int ggml_cpu_has_matmul_int8(void) { -#if defined(__ARM_FEATURE_MATMUL_INT8) - return 1; -#else - return 0; -#endif -} - -//////////////////////////////////////////////////////////////////////////////// diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp new file mode 100644 index 000000000..ab13669c5 --- /dev/null +++ b/ggml/src/gguf.cpp @@ -0,0 +1,1329 @@ +#include "ggml.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include "gguf.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +template +struct type_to_gguf_type; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_UINT8; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_INT8; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_UINT16; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_INT16; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_UINT32; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_INT32; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_FLOAT32; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_BOOL; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_STRING; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_UINT64; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_INT64; +}; + +template <> +struct type_to_gguf_type { + static constexpr enum gguf_type value = GGUF_TYPE_FLOAT64; +}; + +static const std::map GGUF_TYPE_SIZE = { + {GGUF_TYPE_UINT8, sizeof(uint8_t)}, + {GGUF_TYPE_INT8, sizeof(int8_t)}, + {GGUF_TYPE_UINT16, sizeof(uint16_t)}, + {GGUF_TYPE_INT16, sizeof(int16_t)}, + {GGUF_TYPE_UINT32, sizeof(uint32_t)}, + {GGUF_TYPE_INT32, sizeof(int32_t)}, + {GGUF_TYPE_FLOAT32, sizeof(float)}, + {GGUF_TYPE_BOOL, sizeof(int8_t)}, + {GGUF_TYPE_STRING, 0}, // undefined + {GGUF_TYPE_ARRAY, 0}, // undefined + {GGUF_TYPE_UINT64, sizeof(uint64_t)}, + {GGUF_TYPE_INT64, sizeof(int64_t)}, + {GGUF_TYPE_FLOAT64, sizeof(double)}, +}; +static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13"); + +static const std::map GGUF_TYPE_NAME = { + {GGUF_TYPE_UINT8, "u8"}, + {GGUF_TYPE_INT8, "i8"}, + {GGUF_TYPE_UINT16, "u16"}, + {GGUF_TYPE_INT16, "i16"}, + {GGUF_TYPE_UINT32, "u32"}, + {GGUF_TYPE_INT32, "i32"}, + {GGUF_TYPE_FLOAT32, "f32"}, + {GGUF_TYPE_BOOL, "bool"}, + {GGUF_TYPE_STRING, "str"}, + {GGUF_TYPE_ARRAY, "arr"}, + {GGUF_TYPE_UINT64, "u64"}, + {GGUF_TYPE_INT64, "i64"}, + {GGUF_TYPE_FLOAT64, "f64"}, +}; +static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13"); + +size_t gguf_type_size(enum gguf_type type) { + auto it = GGUF_TYPE_SIZE.find(type); + return it == GGUF_TYPE_SIZE.end() ? 0 : it->second; +} + +struct gguf_kv { + std::string key; + + bool is_array; + enum gguf_type type; + + std::vector data; + std::vector data_string; + + template + gguf_kv(const std::string & key, const T value) + : key(key), is_array(false), type(type_to_gguf_type::value) { + GGML_ASSERT(!key.empty()); + data.resize(sizeof(T)); + memcpy(data.data(), &value, sizeof(T)); + } + + template + gguf_kv(const std::string & key, const std::vector & value) + : key(key), is_array(true), type(type_to_gguf_type::value) { + GGML_ASSERT(!key.empty()); + data.resize(value.size()*sizeof(T)); + for (size_t i = 0; i < value.size(); ++i) { + const T tmp = value[i]; + memcpy(data.data() + i*sizeof(T), &tmp, sizeof(T)); + } + } + + gguf_kv(const std::string & key, const std::string & value) + : key(key), is_array(false), type(GGUF_TYPE_STRING) { + GGML_ASSERT(!key.empty()); + data_string.push_back(value); + } + + gguf_kv(const std::string & key, const std::vector & value) + : key(key), is_array(true), type(GGUF_TYPE_STRING) { + GGML_ASSERT(!key.empty()); + data_string = value; + } + + const std::string & get_key() const { + return key; + } + + const enum gguf_type & get_type() const { + return type; + } + + size_t get_ne() const { + if (type == GGUF_TYPE_STRING) { + const size_t ne = data_string.size(); + GGML_ASSERT(is_array || ne == 1); + return ne; + } + const size_t type_size = gguf_type_size(type); + GGML_ASSERT(data.size() % type_size == 0); + const size_t ne = data.size() / type_size; + GGML_ASSERT(is_array || ne == 1); + return ne; + } + + template + const T & get_val(const size_t i = 0) const { + GGML_ASSERT(type_to_gguf_type::value == type); + if constexpr (std::is_same::value) { + GGML_ASSERT(data_string.size() >= i+1); + return data_string[i]; + } + const size_t type_size = gguf_type_size(type); + GGML_ASSERT(data.size() % type_size == 0); + GGML_ASSERT(data.size() >= (i+1)*type_size); + return reinterpret_cast(data.data())[i]; + } + + void cast(const enum gguf_type new_type) { + const size_t new_type_size = gguf_type_size(new_type); + GGML_ASSERT(data.size() % new_type_size == 0); + type = new_type; + } +}; + +struct gguf_tensor_info { + struct ggml_tensor t; // for holding the equivalent info + uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT` +}; + +struct gguf_context { + uint32_t version = GGUF_VERSION; + + std::vector kv; + std::vector info; + + size_t alignment = GGUF_DEFAULT_ALIGNMENT; + size_t offset = 0; // offset of `data` from beginning of file + size_t size = 0; // size of `data` in bytes + + void * data = nullptr; +}; + +struct gguf_reader { + FILE * file; + + gguf_reader(FILE * file) : file(file) {} + + template + bool read(T & dst) const { + return fread(&dst, 1, sizeof(dst), file) == sizeof(dst); + } + + template + bool read(std::vector & dst, const size_t n) const { + dst.resize(n); + for (size_t i = 0; i < dst.size(); ++i) { + if constexpr (std::is_same::value) { + bool tmp; + if (!read(tmp)) { + return false; + } + dst[i] = tmp; + } else { + if (!read(dst[i])) { + return false; + } + } + } + return true; + } + + bool read(bool & dst) const { + int8_t tmp = -1; + if (!read(tmp)) { + return false; + } + dst = tmp != 0; + return true; + } + + bool read(enum ggml_type & dst) const { + int32_t tmp = -1; + if (!read(tmp)) { + return false; + } + dst = ggml_type(tmp); + return true; + } + + bool read(enum gguf_type & dst) const { + int32_t tmp = -1; + if (!read(tmp)) { + return false; + } + dst = gguf_type(tmp); + return true; + } + + bool read(std::string & dst) const { + uint64_t size = -1; + if (!read(size)) { + return false; + } + dst.resize(size); + return fread(dst.data(), 1, dst.length(), file) == dst.length(); + } + + bool read(void * dst, const size_t size) const { + return fread(dst, 1, size, file) == size; + } +}; + +struct gguf_context * gguf_init_empty(void) { + return new gguf_context; +} + +template +bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector & kv, const std::string & key, const bool is_array, const size_t n) { + if (is_array) { + std::vector value; + try { + if (!gr.read(value, n)) { + return false; + } + } catch (std::length_error &) { + fprintf(stderr, "%s: encountered length_error while reading value for key '%s'\n", __func__, key.c_str()); + return false; + } catch (std::bad_alloc &) { + fprintf(stderr, "%s: encountered bad_alloc error while reading value for key '%s'\n", __func__, key.c_str()); + return false; + } + kv.emplace_back(key, value); + } else { + T value; + if (!gr.read(value)) { + return false; + } + kv.emplace_back(key, value); + } + return true; +} + +struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) { + const struct gguf_reader gr(file); + struct gguf_context * ctx = new gguf_context; + + bool ok = true; + + // file magic + { + std::vector magic; + ok = ok && gr.read(magic, 4); + + if (!ok) { + fprintf(stderr, "%s: failed to read magic\n", __func__); + gguf_free(ctx); + return nullptr; + } + + for (uint32_t i = 0; i < magic.size(); i++) { + if (magic[i] != GGUF_MAGIC[i]) { + fprintf(stderr, "%s: invalid magic characters: '%c%c%c%c', expected 'GGUF'\n", __func__, magic[0], magic[1], magic[2], magic[3]); + gguf_free(ctx); + return nullptr; + } + } + } + + // header + int64_t n_kv = 0; + int64_t n_tensors = 0; + + if (ok && gr.read(ctx->version)) { + if (ctx->version == 1) { + fprintf(stderr, "%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__); + ok = false; + } + if (ctx->version > GGUF_VERSION) { + fprintf(stderr, "%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n", + __func__, ctx->version, GGUF_VERSION); + ok = false; + } + } else { + ok = false; + } + + if (ok && gr.read(n_tensors)) { + static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing"); + if (n_tensors < 0 || n_tensors > int64_t(SIZE_MAX/sizeof(gguf_tensor_info))) { + fprintf(stderr, "%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n", + __func__, n_tensors, SIZE_MAX/sizeof(gguf_tensor_info)); + ok = false; + } + } else { + ok = false; + } + + if (ok && gr.read(n_kv)) { + static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing"); + if (n_kv < 0 || n_kv > int64_t(SIZE_MAX/sizeof(gguf_kv))) { + fprintf(stderr, "%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n", + __func__, n_kv, SIZE_MAX/sizeof(gguf_kv)); + ok = false; + } + } else { + ok = false; + } + + if (!ok) { + fprintf(stderr, "%s: failed to read header\n", __func__); + gguf_free(ctx); + return nullptr; + } + + // KV pairs + { + for (int64_t i = 0; ok && i < n_kv; ++i) { + std::string key; + gguf_type type = gguf_type(-1); + bool is_array = false; + uint64_t n = 1; + + try { + ok = ok && gr.read(key); + } catch (std::length_error &) { + fprintf(stderr, "%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i); + ok = false; + } catch (std::bad_alloc &) { + fprintf(stderr, "%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i); + ok = false; + } + for (size_t j = 0; ok && j < ctx->kv.size(); ++j) { + if (key == ctx->kv[j].key) { + fprintf(stderr, "%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i); + ok = false; + } + } + if (!ok) { + break; + } + + ok = ok && gr.read(type); + if (type == GGUF_TYPE_ARRAY) { + is_array = true; + ok = ok && gr.read(type); + ok = ok && gr.read(n); + } + if (!ok) { + break; + } + + switch (type) { + case GGUF_TYPE_UINT8: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_INT8: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_UINT16: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_INT16: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_UINT32: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_INT32: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_FLOAT32: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_BOOL: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_STRING: ok = ok && gguf_read_emplace_helper(gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_UINT64: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_INT64: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_FLOAT64: ok = ok && gguf_read_emplace_helper (gr, ctx->kv, key, is_array, n); break; + case GGUF_TYPE_ARRAY: + default: + { + fprintf(stderr, "%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type); + ok = false; + } break; + } + } + + if (!ok) { + fprintf(stderr, "%s: failed to read key-value pairs\n", __func__); + gguf_free(ctx); + return nullptr; + } + GGML_ASSERT(int64_t(ctx->kv.size()) == n_kv); + + const int alignment_idx = gguf_find_key(ctx, GGUF_KEY_GENERAL_ALIGNMENT); + ctx->alignment = alignment_idx == -1 ? GGUF_DEFAULT_ALIGNMENT : gguf_get_val_u32(ctx, alignment_idx); + + if (ctx->alignment == 0 || (ctx->alignment & (ctx->alignment - 1)) != 0) { + fprintf(stderr, "%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment); + gguf_free(ctx); + return nullptr; + } + } + + // read the tensor info + for (int64_t i = 0; ok && i < n_tensors; ++i) { + struct gguf_tensor_info info; + + // tensor name + { + std::string name; + try { + ok = ok && gr.read(name); + } catch (std::length_error &) { + fprintf(stderr, "%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i); + ok = false; + } catch (std::bad_alloc &) { + fprintf(stderr, "%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i); + ok = false; + } + if (name.length() >= GGML_MAX_NAME) { + fprintf(stderr, "%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME); + ok = false; + break; + } + ggml_set_name(&info.t, name.c_str()); + + // make sure there are no duplicate tensor names + for (int64_t j = 0; ok && j < i; ++j) { + if (strcmp(info.t.name, ctx->info[j].t.name) == 0) { + fprintf(stderr, "%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i); + ok = false; + break; + } + } + } + if (!ok) { + break; + } + + // tensor shape + { + uint32_t n_dims = -1; + ok = ok && gr.read(n_dims); + if (n_dims > GGML_MAX_DIMS) { + fprintf(stderr, "%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n", + __func__, info.t.name, n_dims, GGML_MAX_DIMS); + ok = false; + break; + } + for (uint32_t j = 0; ok && j < GGML_MAX_DIMS; ++j) { + info.t.ne[j] = 1; + if (j < n_dims) { + ok = ok && gr.read(info.t.ne[j]); + } + + // check that all ne are non-negative + if (info.t.ne[j] < 0) { + fprintf(stderr, "%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n", + __func__, info.t.name, j, info.t.ne[j]); + ok = false; + break; + } + } + + // check that the total number of elements is representable + if (ok && ((INT64_MAX/info.t.ne[1] <= info.t.ne[0]) || + (INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) || + (INT64_MAX/info.t.ne[3] <= info.t.ne[0]*info.t.ne[1]*info.t.ne[2]))) { + + fprintf(stderr, "%s: total number of elements in tensor '%s' with shape " + "(%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") is >= %" PRIi64 "\n", + __func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], INT64_MAX); + ok = false; + break; + } + } + if (!ok) { + break; + } + + // tensor type + { + ok = ok && gr.read(info.t.type); + + // check that tensor type is within defined range + if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) { + fprintf(stderr, "%s: tensor '%s' has invalid ggml type %d (%s)\n", + __func__, info.t.name, info.t.type, ggml_type_name(info.t.type)); + ok = false; + break; + } + const size_t type_size = ggml_type_size(info.t.type); + const int64_t blck_size = ggml_blck_size(info.t.type); + + // check that row size is divisible by block size + if (blck_size == 0 || info.t.ne[0] % blck_size != 0) { + fprintf(stderr, "%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, " + "not a multiple of block size (%" PRId64 ")\n", + __func__, info.t.name, (int) info.t.type, ggml_type_name(info.t.type), info.t.ne[0], blck_size); + ok = false; + break; + } + + // calculate byte offsets given the tensor shape and type + info.t.nb[0] = type_size; + info.t.nb[1] = info.t.nb[0]*(info.t.ne[0]/blck_size); + for (int j = 2; j < GGML_MAX_DIMS; ++j) { + info.t.nb[j] = info.t.nb[j - 1]*info.t.ne[j - 1]; + } + } + if (!ok) { + break; + } + + // tensor data offset within buffer + ok = ok && gr.read(info.offset); + + ctx->info.push_back(info); + } + + if (!ok) { + fprintf(stderr, "%s: failed to read tensor info\n", __func__); + gguf_free(ctx); + return nullptr; + } + GGML_ASSERT(int64_t(ctx->info.size()) == n_tensors); + + // we require the data section to be aligned, so take into account any padding + if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) { + fprintf(stderr, "%s: failed to seek to beginning of data section\n", __func__); + gguf_free(ctx); + return nullptr; + } + + // store the current file offset - this is where the data section starts + ctx->offset = ftell(file); + + // compute the total size of the data section, taking into account the alignment + { + ctx->size = 0; + for (size_t i = 0; i < ctx->info.size(); ++i) { + const gguf_tensor_info & ti = ctx->info[i]; + if (ti.offset != ctx->size) { + fprintf(stderr, "%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n", + __func__, ti.t.name, ti.offset, ctx->size); + fprintf(stderr, "%s: failed to read tensor data\n", __func__); + gguf_free(ctx); + return nullptr; + } + ctx->size += GGML_PAD(ggml_nbytes(&ti.t), ctx->alignment); + } + } + + // load the tensor data only if requested + if (params.ctx != nullptr) { + // if the provided gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob + // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of + // the ggml_tensor structs to the appropriate locations in the binary blob + + // compute the exact size needed for the new ggml_context + const size_t mem_size = + params.no_alloc ? + (n_tensors )*ggml_tensor_overhead() : + (n_tensors + 1)*ggml_tensor_overhead() + ctx->size; + + struct ggml_init_params pdata = { + /*mem_size =*/ mem_size, + /*mem_buffer =*/ nullptr, + /*no_alloc =*/ params.no_alloc, + }; + + *params.ctx = ggml_init(pdata); + if (*params.ctx == nullptr) { + fprintf(stderr, "%s: failed to initialize ggml context for storing tensors\n", __func__); + gguf_free(ctx); + return nullptr; + } + + struct ggml_context * ctx_data = *params.ctx; + + struct ggml_tensor * data = nullptr; + + if (!params.no_alloc) { + data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size); + + ok = ok && data != nullptr; + + if (ok) { + ggml_set_name(data, "GGUF tensor data binary blob"); + } + + // read the binary blob with the tensor data + ok = ok && gr.read(data->data, ctx->size); + + if (!ok) { + fprintf(stderr, "%s: failed to read tensor data binary blob\n", __func__); + ggml_free(ctx_data); + *params.ctx = nullptr; + gguf_free(ctx); + return nullptr; + } + + ctx->data = data->data; + } + + ggml_set_no_alloc(ctx_data, true); + + // create the tensors + for (size_t i = 0; i < ctx->info.size(); ++i) { + const struct gguf_tensor_info & info = ctx->info[i]; + + struct ggml_tensor * cur = ggml_new_tensor(ctx_data, info.t.type, GGML_MAX_DIMS, info.t.ne); + + ok = ok && cur != nullptr; + + if (!ok) { + break; + } + + ggml_set_name(cur, info.t.name); + + // point the data member to the appropriate location in the binary blob using the tensor info + if (!params.no_alloc) { + cur->data = (char *) data->data + info.offset; + } + } + + if (!ok) { + fprintf(stderr, "%s: failed to create tensors\n", __func__); + ggml_free(ctx_data); + *params.ctx = nullptr; + gguf_free(ctx); + return nullptr; + } + + ggml_set_no_alloc(ctx_data, params.no_alloc); + } + + return ctx; +} + +struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) { + FILE * file = ggml_fopen(fname, "rb"); + + if (!file) { + fprintf(stderr, "%s: failed to open GGUF file '%s'\n", __func__, fname); + return nullptr; + } + + struct gguf_context * result = gguf_init_from_file_impl(file, params); + fclose(file); + return result; +} + +void gguf_free(struct gguf_context * ctx) { + if (ctx == nullptr) { + return; + } + delete ctx; +} + +const char * gguf_type_name(enum gguf_type type) { + auto it = GGUF_TYPE_NAME.find(type); + return it == GGUF_TYPE_NAME.end() ? nullptr : it->second; +} + +uint32_t gguf_get_version(const struct gguf_context * ctx) { + return ctx->version; +} + +size_t gguf_get_alignment(const struct gguf_context * ctx) { + return ctx->alignment; +} + +size_t gguf_get_data_offset(const struct gguf_context * ctx) { + return ctx->offset; +} + +int64_t gguf_get_n_kv(const struct gguf_context * ctx) { + return ctx->kv.size(); +} + +int64_t gguf_find_key(const struct gguf_context * ctx, const char * key) { + // return -1 if key not found + int64_t keyfound = -1; + + const int64_t n_kv = gguf_get_n_kv(ctx); + + for (int64_t i = 0; i < n_kv; ++i) { + if (strcmp(key, gguf_get_key(ctx, i)) == 0) { + keyfound = i; + break; + } + } + + return keyfound; +} + +const char * gguf_get_key(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + return ctx->kv[key_id].get_key().c_str(); +} + +enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + return ctx->kv[key_id].is_array ? GGUF_TYPE_ARRAY : ctx->kv[key_id].get_type(); +} + +enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].is_array); + return ctx->kv[key_id].get_type(); +} + +const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING); + return ctx->kv[key_id].data.data(); +} + +const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING); + return ctx->kv[key_id].data_string[i].c_str(); +} + +size_t gguf_get_arr_n(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + + if (ctx->kv[key_id].type == GGUF_TYPE_STRING) { + return ctx->kv[key_id].data_string.size(); + } + + const size_t type_size = gguf_type_size(ctx->kv[key_id].type); + GGML_ASSERT(ctx->kv[key_id].data.size() % type_size == 0); + return ctx->kv[key_id].data.size() / type_size; +} + +uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +int8_t gguf_get_val_i8(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +int16_t gguf_get_val_i16(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +int32_t gguf_get_val_i32(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +float gguf_get_val_f32(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +int64_t gguf_get_val_i64(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +double gguf_get_val_f64(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +bool gguf_get_val_bool(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val(); +} + +const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + return ctx->kv[key_id].get_val().c_str(); +} + +const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) { + GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); + GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); + GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING); + return ctx->kv[key_id].data.data(); +} + +int64_t gguf_get_n_tensors(const struct gguf_context * ctx) { + return ctx->info.size(); +} + +int64_t gguf_find_tensor(const struct gguf_context * ctx, const char * name) { + // return -1 if tensor not found + int64_t tensor_id = -1; + + const int64_t n_tensors = gguf_get_n_tensors(ctx); + + for (int64_t i = 0; i < n_tensors; ++i) { + if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) { + tensor_id = i; + break; + } + } + + return tensor_id; +} + +size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int64_t tensor_id) { + GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx)); + return ctx->info[tensor_id].offset; +} + +const char * gguf_get_tensor_name(const struct gguf_context * ctx, int64_t tensor_id) { + GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx)); + return ctx->info[tensor_id].t.name; +} + +enum ggml_type gguf_get_tensor_type(const struct gguf_context * ctx, int64_t tensor_id) { + GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx)); + return ctx->info[tensor_id].t.type; +} + +size_t gguf_get_tensor_size(const struct gguf_context * ctx, int64_t tensor_id) { + GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx)); + return ggml_nbytes(&ctx->info[tensor_id].t); +} + +int64_t gguf_remove_key(struct gguf_context * ctx, const char * key) { + const int64_t key_id = gguf_find_key(ctx, key); + if (key_id >= 0) { + ctx->kv.erase(ctx->kv.begin() + key_id); + } + return key_id; +} + +template +static void gguf_check_reserved_keys(const std::string & key, const T val) { + if (key == GGUF_KEY_GENERAL_ALIGNMENT) { + if constexpr (std::is_same::value) { + GGML_ASSERT(val > 0 && (val & (val - 1)) == 0 && GGUF_KEY_GENERAL_ALIGNMENT " must be power of 2"); + } else { + GGML_ABORT(GGUF_KEY_GENERAL_ALIGNMENT " must be type u32"); + } + } +} + +void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, val); +} + +void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) { + gguf_check_reserved_keys(key, val); + gguf_remove_key(ctx, key); + ctx->kv.emplace_back(key, std::string(val)); +} + +void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, size_t n) { + gguf_check_reserved_keys(key, data); + gguf_remove_key(ctx, key); + + const size_t nbytes = n*gguf_type_size(type); + std::vector tmp(nbytes); + if (!tmp.empty()) { + memcpy(tmp.data(), data, nbytes); + } + ctx->kv.emplace_back(key, tmp); + ctx->kv.back().cast(type); +} + +void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, size_t n) { + gguf_check_reserved_keys(key, data); + gguf_remove_key(ctx, key); + + std::vector tmp(n); + for (size_t i = 0; i < n; ++i) { + tmp[i] = data[i]; + } + ctx->kv.emplace_back(key, tmp); +} + +// set or add KV pairs from another context +void gguf_set_kv(struct gguf_context * ctx, const struct gguf_context * src) { + const int64_t n_kv = gguf_get_n_kv(src); + for (int64_t i = 0; i < n_kv; ++i) { + const struct gguf_kv & kv = src->kv[i]; + + if (!kv.is_array) { + switch (kv.get_type()) { + case GGUF_TYPE_UINT8: gguf_set_val_u8 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_INT8: gguf_set_val_i8 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_UINT16: gguf_set_val_u16 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_INT16: gguf_set_val_i16 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_UINT32: gguf_set_val_u32 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_INT32: gguf_set_val_i32 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_UINT64: gguf_set_val_u64 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_INT64: gguf_set_val_i64 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_BOOL: gguf_set_val_bool(ctx, kv.get_key().c_str(), kv.get_val()); break; + case GGUF_TYPE_STRING: gguf_set_val_str (ctx, kv.get_key().c_str(), kv.get_val().c_str()); break; + case GGUF_TYPE_ARRAY: + default: GGML_ABORT("invalid type"); + } + continue; + } + + const size_t ne = kv.get_ne(); + + switch (kv.get_type()) { + case GGUF_TYPE_UINT8: + case GGUF_TYPE_INT8: + case GGUF_TYPE_UINT16: + case GGUF_TYPE_INT16: + case GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: + case GGUF_TYPE_FLOAT32: + case GGUF_TYPE_UINT64: + case GGUF_TYPE_INT64: + case GGUF_TYPE_FLOAT64: + case GGUF_TYPE_BOOL: { + gguf_set_arr_data(ctx, kv.get_key().c_str(), kv.get_type(), kv.data.data(), ne); + } break; + case GGUF_TYPE_STRING: { + std::vector tmp(ne); + for (size_t j = 0; j < ne; ++j) { + tmp[j] = kv.data_string[j].c_str(); + } + gguf_set_arr_str(ctx, kv.get_key().c_str(), tmp.data(), ne); + } break; + case GGUF_TYPE_ARRAY: + default: GGML_ABORT("invalid type"); + } + } +} + +void gguf_add_tensor( + struct gguf_context * ctx, + const struct ggml_tensor * tensor) { + GGML_ASSERT(tensor); + if (gguf_find_tensor(ctx, tensor->name) != -1) { + GGML_ABORT("duplicate tensor name: %s", tensor->name); + } + + struct gguf_tensor_info ti; + ti.t = *tensor; + ti.offset = ctx->info.empty() ? 0 : + ctx->info.back().offset + GGML_PAD(ggml_nbytes(&ctx->info.back().t), ctx->alignment); + ctx->info.push_back(ti); +} + +void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) { + const int64_t tensor_id = gguf_find_tensor(ctx, name); + if (tensor_id < 0) { + GGML_ABORT("tensor not found: %s", name); + } + struct ggml_tensor * tensor = &ctx->info[tensor_id].t; + const size_t type_size = ggml_type_size(type); + const int64_t blck_size = ggml_blck_size(type); + + tensor->type = type; + GGML_ASSERT(tensor->ne[0] % blck_size == 0 && "tensor row size not divisible by block size of new type"); + + tensor->nb[0] = type_size; + tensor->nb[1] = tensor->nb[0]*(tensor->ne[0]/blck_size); + for (int i = 2; i < GGML_MAX_DIMS; i++) { + tensor->nb[i] = tensor->nb[i - 1]*tensor->ne[i - 1]; + } + + // update offsets + const int64_t n_tensors = gguf_get_n_tensors(ctx); + for (int64_t i = tensor_id + 1; i < n_tensors; ++i) { + ctx->info[i].offset = ctx->info[i - 1].offset + GGML_PAD(ggml_nbytes(&ctx->info[i - 1].t), ctx->alignment); + } +} + +void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data) { + const int64_t tensor_id = gguf_find_tensor(ctx, name); + if (tensor_id < 0) { + GGML_ABORT("tensor not found: %s", name); + } + + ctx->info[tensor_id].t.data = (void *)(uintptr_t)data; // double cast suppresses warning about casting away const +} + +struct gguf_writer { + std::vector & buf; + + gguf_writer(std::vector & buf) : buf(buf) {} + + template + void write(const T & val) const { + for (size_t i = 0; i < sizeof(val); ++i) { + buf.push_back(reinterpret_cast(&val)[i]); + } + } + + void write(const std::vector & val) const { + buf.insert(buf.end(), val.begin(), val.end()); + } + + void write(const bool & val) const { + const int8_t val8 = val ? 1 : 0; + write(val8); + } + + void write(const std::string & val) const { + { + const uint64_t n = val.length(); + write(n); + } + for (size_t i = 0; i < val.length(); ++i) { + buf.push_back(reinterpret_cast(val.data())[i]); + } + } + + void write(const char * val) const { + write(std::string(val)); + } + + void write(const enum ggml_type & val) const { + write(int32_t(val)); + } + + void write(const enum gguf_type & val) const { + write(int32_t(val)); + } + + void write(const struct gguf_kv & kv) const { + const uint64_t ne = kv.get_ne(); + + write(kv.get_key()); + + if (kv.is_array) { + write(GGUF_TYPE_ARRAY); + write(kv.get_type()); + write(ne); + } else { + write(kv.get_type()); + } + + switch (kv.get_type()) { + case GGUF_TYPE_UINT8: + case GGUF_TYPE_INT8: + case GGUF_TYPE_UINT16: + case GGUF_TYPE_INT16: + case GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: + case GGUF_TYPE_FLOAT32: + case GGUF_TYPE_UINT64: + case GGUF_TYPE_INT64: + case GGUF_TYPE_FLOAT64: { + write(kv.data); + } break; + case GGUF_TYPE_BOOL: { + for (size_t i = 0; i < ne; ++i) { + write(kv.get_val(i)); + } + } break; + case GGUF_TYPE_STRING: { + for (size_t i = 0; i < ne; ++i) { + write(kv.get_val(i)); + } + } break; + case GGUF_TYPE_ARRAY: + default: GGML_ABORT("invalid type"); + } + } + + void write_tensor_meta(const struct gguf_tensor_info & info) const { + write(info.t.name); + + const uint32_t n_dims = ggml_n_dims(&info.t); + write(n_dims); + + for (uint32_t j = 0; j < n_dims; ++j) { + write(info.t.ne[j]); + } + write(info.t.type); + write(info.offset); + } + + void pad(const size_t alignment) const { + while (buf.size() % alignment != 0) { + const int8_t zero = 0; + write(zero); + } + } + + void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) const { + GGML_ASSERT(buf.size() - offset_data == info.offset); + + GGML_ASSERT(ggml_is_contiguous(&info.t)); + const size_t offset = buf.size(); + const size_t nbytes = ggml_nbytes(&info.t); + + buf.resize(offset + nbytes); + if (info.t.buffer) { + ggml_backend_tensor_get(&info.t, buf.data() + offset, 0, nbytes); + } else { + GGML_ASSERT(info.t.data); + memcpy(buf.data() + offset, info.t.data, nbytes); + } + + pad(alignment); + } +}; + +void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & buf, bool only_meta) { + const struct gguf_writer gw(buf); + + const int64_t n_kv = gguf_get_n_kv(ctx); + const int64_t n_tensors = gguf_get_n_tensors(ctx); + + // write header + gw.write(GGUF_MAGIC[0]); + gw.write(GGUF_MAGIC[1]); + gw.write(GGUF_MAGIC[2]); + gw.write(GGUF_MAGIC[3]); + gw.write(ctx->version); + gw.write(n_tensors); + gw.write(n_kv); + + // write key-value pairs + for (int64_t i = 0; i < n_kv; ++i) { + gw.write(ctx->kv[i]); + } + + // write tensor info + for (int64_t i = 0; i < n_tensors; ++i) { + gw.write_tensor_meta(ctx->info[i]); + } + + // we require the data section to be aligned + gw.pad(ctx->alignment); + + if (only_meta) { + return; + } + + const size_t offset_data = gw.buf.size(); + + // write tensor data + for (int64_t i = 0; i < n_tensors; ++i) { + gw.write_tensor_data(ctx->info[i], offset_data, ctx->alignment); + } +} + +bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) { + FILE * file = ggml_fopen(fname, "wb"); + + if (!file) { + fprintf(stderr, "%s: failed to open file '%s' for writing GGUF data\n", __func__, fname); + return false; + } + + std::vector buf; + gguf_write_to_buf(ctx, buf, only_meta); + const bool ok = fwrite(buf.data(), 1, buf.size(), file) == buf.size(); + fclose(file); + return ok; +} + +size_t gguf_get_meta_size(const struct gguf_context * ctx) { + // only return size + std::vector buf; + gguf_write_to_buf(ctx, buf, /*only_meta =*/ true); + return buf.size(); +} + +void gguf_get_meta_data(const struct gguf_context * ctx, void * data) { + std::vector buf; + gguf_write_to_buf(ctx, buf, /*only_meta =*/ true); + memcpy(data, buf.data(), buf.size()); +} diff --git a/ggml/src/kompute-shaders/op_rope_f16.comp b/ggml/src/kompute-shaders/op_rope_f16.comp deleted file mode 100644 index 0ecfb2eab..000000000 --- a/ggml/src/kompute-shaders/op_rope_f16.comp +++ /dev/null @@ -1,73 +0,0 @@ -#version 450 - -#include "rope_common.comp" - -layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; }; -layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; }; -layout(binding = 2) buffer restrict writeonly tensorOut { float16_t out_[]; }; - -void main() { - const uint i3 = gl_WorkGroupID.z; - const uint i2 = gl_WorkGroupID.y; - const uint i1 = gl_WorkGroupID.x; - - const bool is_neox = (pcs.mode & GGML_ROPE_TYPE_NEOX) != 0; - - float corr_dims[2]; - rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims); - - const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims); - - const int p = inB[pcs.inBOff + i2]; - - float theta = float(p); - - if (!is_neox) { - for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) { - float cos_theta, sin_theta; - rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta); - - theta *= theta_scale; - - const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in - const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_ - - const float x0 = float(inA[src]); - const float x1 = float(inA[src+1]); - - out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta); - out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta); - } - } else { - const float inv_ndims = -1.f/pcs.n_dims; - for (uint ic = 0; ic < pcs.n_dims; ic += 2) { - const uint cur_rot = ic; - - float cos_theta, sin_theta; - rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta); - - theta *= theta_scale; - - const uint i0 = ic/2; - - const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in - const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_ - - const float x0 = float(inA[src]); - const float x1 = float(inA[src+pcs.n_dims/2]); - - out_[dst_data] = float16_t(x0*cos_theta - x1*sin_theta); - out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta); - } - - for (uint ic = pcs.n_dims; ic < pcs.ne0; ic += 2) { - const uint i0 = ic; - - const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in - const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 2) + pcs.outOff; // Based from out_ - - out_[dst_data + 0] = inA[src + 0]; - out_[dst_data + 1] = inA[src + 1]; - } - } -} diff --git a/ggml/src/kompute-shaders/op_rope_f32.comp b/ggml/src/kompute-shaders/op_rope_f32.comp deleted file mode 100644 index cec0fd9a5..000000000 --- a/ggml/src/kompute-shaders/op_rope_f32.comp +++ /dev/null @@ -1,73 +0,0 @@ -#version 450 - -#include "rope_common.comp" - -layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; }; -layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; }; -layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; }; - -void main() { - const uint i3 = gl_WorkGroupID.z; - const uint i2 = gl_WorkGroupID.y; - const uint i1 = gl_WorkGroupID.x; - - const bool is_neox = (pcs.mode & GGML_ROPE_TYPE_NEOX) != 0; - - float corr_dims[2]; - rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims); - - const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims); - - const int p = inB[pcs.inBOff + i2]; - - float theta = float(p); - - if (!is_neox) { - for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) { - float cos_theta, sin_theta; - rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta); - - theta *= theta_scale; - - const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in - const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_ - - const float x0 = inA[src]; - const float x1 = inA[src+1]; - - out_[dst_data] = x0*cos_theta - x1*sin_theta; - out_[dst_data+1] = x0*sin_theta + x1*cos_theta; - } - } else { - const float inv_ndims = -1.f/pcs.n_dims; - for (uint ic = 0; ic < pcs.n_dims; ic += 2) { - const uint cur_rot = ic; - - float cos_theta, sin_theta; - rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta); - - theta *= theta_scale; - - const uint i0 = ic/2; - - const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in - const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_ - - const float x0 = inA[src]; - const float x1 = inA[src+pcs.n_dims/2]; - - out_[dst_data] = x0*cos_theta - x1*sin_theta; - out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta; - } - - for (uint ic = pcs.n_dims; ic < pcs.ne0; ic += 2) { - const uint i0 = ic; - - const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in - const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_ - - out_[dst_data + 0] = inA[src + 0]; - out_[dst_data + 1] = inA[src + 1]; - } - } -} diff --git a/ggml/src/llamafile/sgemm.cpp b/ggml/src/llamafile/sgemm.cpp deleted file mode 100644 index d0c2bb284..000000000 --- a/ggml/src/llamafile/sgemm.cpp +++ /dev/null @@ -1,1180 +0,0 @@ -// Copyright 2024 Mozilla Foundation -// -// Permission is hereby granted, free of charge, to any person obtaining -// a copy of this software and associated documentation files (the -// "Software"), to deal in the Software without restriction, including -// without limitation the rights to use, copy, modify, merge, publish, -// distribute, sublicense, and/or sell copies of the Software, and to -// permit persons to whom the Software is furnished to do so, subject to -// the following conditions: -// -// The above copyright notice and this permission notice shall be -// included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS -// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN -// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN -// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -// -// _ _ ___ _ _ ___ -// | |_(_)_ _ _ _| _ ) | /_\ / __| -// | _| | ' \ || | _ \ |__ / _ \\__ \. -// \__|_|_||_\_, |___/____/_/ \_\___/ -// |__/ -// -// BASIC LINEAR ALGEBRA SUBPROGRAMS -// -// -// This file implements multithreaded CPU matrix multiplication for the -// common contiguous use case C = Aᵀ * B. These kernels are designed to -// have excellent performance[1] for matrices that fit in the CPU cache -// without imposing any overhead such as cache filling or malloc calls. -// -// This implementation does not guarantee any upper bound with rounding -// errors, which grow along with k. Our goal's to maximally exploit the -// hardware for performance, and then use whatever resources remain for -// improving numerical accuracy. -// -// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online]. -// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024]. - -#if defined(__GNUC__) -#pragma GCC diagnostic ignored "-Wpedantic" -#pragma GCC diagnostic ignored "-Wignored-attributes" -#endif - -#include "sgemm.h" -#include "ggml-impl.h" -#include "ggml-quants.h" - -#ifdef _MSC_VER -#define NOINLINE __declspec(noinline) -#else -#define NOINLINE __attribute__((__noinline__)) -#endif - -#if defined(__ARM_NEON) || defined(__AVX512F__) -#define VECTOR_REGISTERS 32 -#else -#define VECTOR_REGISTERS 16 -#endif - -#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) - -namespace { - -inline float unhalf(ggml_fp16_t d) { - return GGML_FP16_TO_FP32(d); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// VECTORIZED ARITHMETIC OPERATIONS - -#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) -inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); } -inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); } -inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); } -#endif // __SSE__ - -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) -inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); } -inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); } -inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); } -#endif // __AVX__ - -#if defined(__AVX512F__) -inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); } -inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); } -inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); } -#endif // __AVX512F__ - -#if defined(__ARM_NEON) -inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); } -inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); } -inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); } -#endif // __ARM_NEON - -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) -inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); } -inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); } -inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); } -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// VECTORIZED FUSED MULTIPLY ADD - -/** - * Computes a * b + c. - */ -template -inline U madd(T a, T b, U c) { - return add(mul(a, b), c); -} - -#if defined(__FMA__) -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) -template <> -inline __m256 madd(__m256 a, __m256 b, __m256 c) { - return _mm256_fmadd_ps(a, b, c); -} -#endif -#if defined(__AVX512F__) -template <> -inline __m512 madd(__m512 a, __m512 b, __m512 c) { - return _mm512_fmadd_ps(a, b, c); -} -#endif -#endif - -#if defined(__ARM_FEATURE_FMA) -template <> -inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) { - return vfmaq_f32(c, b, a); -} -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) -template <> -inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) { - return vfmaq_f16(c, b, a); -} -#endif -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// VECTORIZED HORIZONTAL SUM - -#if defined(__ARM_NEON) -inline float hsum(float32x4_t x) { - return vaddvq_f32(x); -} -#endif // __ARM_NEON - -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) -inline float hsum(float16x8_t x) { - return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)), - vcvt_f32_f16(vget_high_f16(x)))); -} -#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - -#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) -inline float hsum(__m128 x) { -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) - x = _mm_add_ps(x, _mm_movehl_ps(x, x)); - x = _mm_add_ss(x, _mm_movehdup_ps(x)); -#else - __m128 t; - t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1)); - x = _mm_add_ps(x, t); - t = _mm_movehl_ps(t, x); - x = _mm_add_ss(x, t); -#endif - return _mm_cvtss_f32(x); -} -#endif - -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) -inline float hsum(__m256 x) { - return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1), - _mm256_castps256_ps128(x))); -} -#endif // __AVX__ - -#if defined(__AVX512F__) -inline float hsum(__m512 x) { - return _mm512_reduce_add_ps(x); -} -#endif // __AVX512F__ - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// VECTORIZED MEMORY LOADING - -template T load(const U *); - -#if defined(__ARM_NEON) -template <> inline float32x4_t load(const float *p) { - return vld1q_f32(p); -} -#if !defined(_MSC_VER) -template <> inline float16x8_t load(const ggml_fp16_t *p) { - return vld1q_f16((const float16_t *)p); -} -template <> inline float32x4_t load(const ggml_fp16_t *p) { - return vcvt_f32_f16(vld1_f16((const float16_t *)p)); -} -#endif // _MSC_VER -#endif // __ARM_NEON - -#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) -template <> inline __m128 load(const float *p) { - return _mm_loadu_ps(p); -} -#endif // __SSE__ - -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) -template <> inline __m256 load(const float *p) { - return _mm256_loadu_ps(p); -} -#endif // __AVX__ - -#if defined(__F16C__) -template <> inline __m256 load(const ggml_fp16_t *p) { - return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p)); -} -#endif // __F16C__ - -#if defined(__AVX512F__) -template <> inline __m512 load(const float *p) { - return _mm512_loadu_ps(p); -} -template <> inline __m512 load(const ggml_fp16_t *p) { - return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p)); -} -#endif // __AVX512F__ - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// FLOATING POINT MATRIX MULTIPLICATION - -template -class tinyBLAS { - public: - tinyBLAS(int64_t k, - const TA *A, int64_t lda, - const TB *B, int64_t ldb, - TC *C, int64_t ldc, - int ith, int nth) - : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { - } - - void matmul(int64_t m, int64_t n) { - mnpack(0, m, 0, n); - } - - private: - NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; - switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) { -#if VECTOR_REGISTERS == 32 - case 0x55: - mc = 5; - nc = 5; - gemm<5, 5>(m0, m, n0, n); - break; - case 0x45: - mc = 4; - nc = 5; - gemm<4, 5>(m0, m, n0, n); - break; - case 0x54: - mc = 5; - nc = 4; - gemm<5, 4>(m0, m, n0, n); - break; - case 0x44: - mc = 4; - nc = 4; - gemm<4, 4>(m0, m, n0, n); - break; - case 0x53: - mc = 5; - nc = 3; - gemm<5, 3>(m0, m, n0, n); - break; - case 0x35: - mc = 3; - nc = 5; - gemm<3, 5>(m0, m, n0, n); - break; - case 0x43: - mc = 4; - nc = 3; - gemm<4, 3>(m0, m, n0, n); - break; -#else - case 0x55: - case 0x54: - case 0x53: - case 0x45: - case 0x44: - case 0x43: - mc = 4; - nc = 3; - gemm<4, 3>(m0, m, n0, n); - break; - case 0x35: -#endif - case 0x34: - mc = 3; - nc = 4; - gemm<3, 4>(m0, m, n0, n); - break; - case 0x52: - mc = 5; - nc = 2; - gemm<5, 2>(m0, m, n0, n); - break; - case 0x33: - mc = 3; - nc = 3; - gemm<3, 3>(m0, m, n0, n); - break; - case 0x25: - mc = 2; - nc = 5; - gemm<2, 5>(m0, m, n0, n); - break; - case 0x42: - mc = 4; - nc = 2; - gemm<4, 2>(m0, m, n0, n); - break; - case 0x24: - mc = 2; - nc = 4; - gemm<2, 4>(m0, m, n0, n); - break; - case 0x32: - mc = 3; - nc = 2; - gemm<3, 2>(m0, m, n0, n); - break; - case 0x23: - mc = 2; - nc = 3; - gemm<2, 3>(m0, m, n0, n); - break; - case 0x51: - mc = 5; - nc = 1; - gemm<5, 1>(m0, m, n0, n); - break; - case 0x41: - mc = 4; - nc = 1; - gemm<4, 1>(m0, m, n0, n); - break; - case 0x22: - mc = 2; - nc = 2; - gemm<2, 2>(m0, m, n0, n); - break; - case 0x15: - mc = 1; - nc = 5; - gemm<1, 5>(m0, m, n0, n); - break; - case 0x14: - mc = 1; - nc = 4; - gemm<1, 4>(m0, m, n0, n); - break; - case 0x31: - mc = 3; - nc = 1; - gemm<3, 1>(m0, m, n0, n); - break; - case 0x13: - mc = 1; - nc = 3; - gemm<1, 3>(m0, m, n0, n); - break; - case 0x21: - mc = 2; - nc = 1; - gemm<2, 1>(m0, m, n0, n); - break; - case 0x12: - mc = 1; - nc = 2; - gemm<1, 2>(m0, m, n0, n); - break; - case 0x11: - mc = 1; - nc = 1; - gemm<1, 1>(m0, m, n0, n); - break; - default: - return; - } - mp = m0 + (m - m0) / mc * mc; - np = n0 + (n - n0) / nc * nc; - mnpack(mp, m, n0, np); - mnpack(m0, m, np, n); - } - - template - NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t ytiles = (m - m0) / RM; - int64_t xtiles = (n - n0) / RN; - int64_t tiles = xtiles * ytiles; - int64_t duty = (tiles + nth - 1) / nth; - int64_t start = duty * ith; - int64_t end = start + duty; - if (end > tiles) - end = tiles; - for (int64_t job = start; job < end; ++job) { - int64_t ii = m0 + job / xtiles * RM; - int64_t jj = n0 + job % xtiles * RN; - D Cv[RN][RM] = {}; - for (int64_t l = 0; l < k; l += KN) - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) - Cv[j][i] = madd(load(A + lda * (ii + i) + l), - load(B + ldb * (jj + j) + l), - Cv[j][i]); - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) - C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); - } - } - - const TA *const A; - const TB *const B; - TC *const C; - const int64_t k; - const int64_t lda; - const int64_t ldb; - const int64_t ldc; - const int ith; - const int nth; -}; - -////////////////////////////////////////////////////////////////////////////////////////// -// QUANT ZERO MATRIX MULTIPLICATION - -#if defined(__ARM_FEATURE_DOTPROD) -template -class tinyBLAS_Q0_ARM { - public: - tinyBLAS_Q0_ARM(int64_t k, - const TA *A, int64_t lda, - const block_q8_0 *B, int64_t ldb, - float *C, int64_t ldc, - int ith, int nth) - : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { - } - - void matmul(int64_t m, int64_t n) { - mnpack(0, m, 0, n); - } - - private: - NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; - switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) { - case 0x33: - mc = 3; - nc = 3; - gemm<3, 3>(m0, m, n0, n); - break; - case 0x32: - mc = 3; - nc = 2; - gemm<3, 2>(m0, m, n0, n); - break; - case 0x23: - mc = 2; - nc = 3; - gemm<2, 3>(m0, m, n0, n); - break; - case 0x22: - mc = 2; - nc = 2; - gemm<2, 2>(m0, m, n0, n); - break; - case 0x31: - mc = 3; - nc = 1; - gemm<3, 1>(m0, m, n0, n); - break; - case 0x13: - mc = 1; - nc = 3; - gemm<1, 3>(m0, m, n0, n); - break; - case 0x21: - mc = 2; - nc = 1; - gemm<2, 1>(m0, m, n0, n); - break; - case 0x12: - mc = 1; - nc = 2; - gemm<1, 2>(m0, m, n0, n); - break; - case 0x11: - mc = 1; - nc = 1; - gemm<1, 1>(m0, m, n0, n); - break; - default: - return; - } - mp = m0 + (m - m0) / mc * mc; - np = n0 + (n - n0) / nc * nc; - mnpack(mp, m, n0, np); - mnpack(m0, m, np, n); - } - - template - NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t ytiles = (m - m0) / RM; - int64_t xtiles = (n - n0) / RN; - int64_t tiles = xtiles * ytiles; - int64_t duty = (tiles + nth - 1) / nth; - int64_t start = duty * ith; - int64_t end = start + duty; - if (end > tiles) - end = tiles; - for (int64_t job = start; job < end; ++job) { - int64_t ii = m0 + job / xtiles * RM; - int64_t jj = n0 + job % xtiles * RN; - float32x4_t Cv[RN][RM] = {}; - for (int64_t l = 0; l < k; ++l) - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) - Cv[j][i] = vmlaq_n_f32(Cv[j][i], - vcvtq_f32_s32(vdotq_s32( - vdotq_s32(vdupq_n_s32(0), - load_lo(A + lda * (ii + i) + l), - load_lo(B + ldb * (jj + j) + l)), - load_hi(A + lda * (ii + i) + l), - load_hi(B + ldb * (jj + j) + l))), - unhalf(A[lda * (ii + i) + l].d) * - unhalf(B[ldb * (jj + j) + l].d)); - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) - C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); - } - } - - inline int8x16_t load_lo(const block_q8_0 *b) { - return vld1q_s8(b->qs); - } - - inline int8x16_t load_hi(const block_q8_0 *b) { - return vld1q_s8(b->qs + 16); - } - - inline int8x16_t load_lo(const block_q4_0 *b) { - return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs), - vdupq_n_u8(0x0f))), - vdupq_n_s8(0x8)); - } - - inline int8x16_t load_hi(const block_q4_0 *b) { - return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)), - vdupq_n_s8(0x8)); - } - - const TA *const A; - const block_q8_0 *const B; - float *const C; - const int64_t k; - const int64_t lda; - const int64_t ldb; - const int64_t ldc; - const int ith; - const int nth; -}; -#endif // __ARM_FEATURE_DOTPROD - -#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) -template -class tinyBLAS_Q0_AVX { - public: - tinyBLAS_Q0_AVX(int64_t k, - const TA *A, int64_t lda, - const TB *B, int64_t ldb, - TC *C, int64_t ldc, - int ith, int nth) - : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { - } - - void matmul(int64_t m, int64_t n) { - mnpack(0, m, 0, n); - } - - private: - void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; - switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) { -#if VECTOR_REGISTERS == 32 - case 0x44: - mc = 4; - nc = 4; -#if defined(__AVX2__) && defined(__F16C__) - gemm4xN<4>(m0, m, n0, n); -#else - gemm<4, 4>(m0, m, n0, n); -#endif - break; - case 0x43: - mc = 4; - nc = 3; -#if defined(__AVX2__) && defined(__F16C__) - gemm4xN<3>(m0, m, n0, n); -#else - gemm<4, 3>(m0, m, n0, n); -#endif - break; - case 0x34: - mc = 3; - nc = 4; -#if defined(__AVX2__) && defined(__F16C__) - gemmMx4<3>(m0, m, n0, n); -#else - gemm<3, 4>(m0, m, n0, n); -#endif - break; - case 0x33: - mc = 3; - nc = 3; - gemm<3, 3>(m0, m, n0, n); - break; - case 0x42: - mc = 4; - nc = 2; -#if defined(__AVX2__) && defined(__F16C__) - gemm4xN<2>(m0, m, n0, n); -#else - gemm<4, 2>(m0, m, n0, n); -#endif - break; - case 0x24: - mc = 2; - nc = 4; -#if defined(__AVX2__) && defined(__F16C__) - gemmMx4<2>(m0, m, n0, n); -#else - gemm<2, 4>(m0, m, n0, n); -#endif - break; -#else - case 0x44: - case 0x43: - case 0x42: - mc = 4; - nc = 2; -#if defined(__AVX2__) && defined(__F16C__) - gemm4xN<2>(m0, m, n0, n); -#else - gemm<4, 2>(m0, m, n0, n); -#endif - break; - case 0x34: - case 0x24: - mc = 2; - nc = 4; -#if defined(__AVX2__) && defined(__F16C__) - gemmMx4<2>(m0, m, n0, n); -#else - gemm<2, 4>(m0, m, n0, n); -#endif - break; - case 0x33: -#endif - case 0x32: - mc = 3; - nc = 2; - gemm<3, 2>(m0, m, n0, n); - break; - case 0x23: - mc = 2; - nc = 3; - gemm<2, 3>(m0, m, n0, n); - break; - case 0x41: - mc = 4; - nc = 1; -#if defined(__AVX2__) && defined(__F16C__) - gemm4xN<1>(m0, m, n0, n); -#else - gemm<4, 1>(m0, m, n0, n); -#endif - break; - case 0x22: - mc = 2; - nc = 2; - gemm<2, 2>(m0, m, n0, n); - break; - case 0x14: - mc = 1; - nc = 4; -#if defined(__AVX2__) && defined(__F16C__) - gemmMx4<1>(m0, m, n0, n); -#else - gemm<1, 4>(m0, m, n0, n); -#endif - break; - case 0x31: - mc = 3; - nc = 1; - gemm<3, 1>(m0, m, n0, n); - break; - case 0x13: - mc = 1; - nc = 3; - gemm<1, 3>(m0, m, n0, n); - break; - case 0x21: - mc = 2; - nc = 1; - gemm<2, 1>(m0, m, n0, n); - break; - case 0x12: - mc = 1; - nc = 2; - gemm<1, 2>(m0, m, n0, n); - break; - case 0x11: - mc = 1; - nc = 1; - gemm<1, 1>(m0, m, n0, n); - break; - default: - return; - } - mp = m0 + (m - m0) / mc * mc; - np = n0 + (n - n0) / nc * nc; - mnpack(mp, m, n0, np); - mnpack(m0, m, np, n); - } - -#if defined(__AVX2__) && defined(__F16C__) -// Templated functions for gemm of dimensions 4xN - template - NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t ytiles = (m - m0) / 4; - int64_t xtiles = (n - n0) / RN; - int64_t tiles = xtiles * ytiles; - int64_t duty = (tiles + nth - 1) / nth; - int64_t start = duty * ith; - int64_t end = start + duty; - if (end > tiles) - end = tiles; - for (int64_t job = start; job < end; ++job) { - int64_t ii = m0 + job / xtiles * 4; - int64_t jj = n0 + job % xtiles * RN; - __m256 Cv[RN][4] = {}; - for (int64_t l = 0; l < k; ++l) { - uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d); - // Convert delta values for four blocks to float values - __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta)); - __m256i avec0 = load(A + lda * (ii + 0) + l); - __m256i avec1 = load(A + lda * (ii + 1) + l); - __m256i avec2 = load(A + lda * (ii + 2) + l); - __m256i avec3 = load(A + lda * (ii + 3) + l); - for (int64_t j = 0; j < RN; ++j) { - __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d)); - // Computation of product of delta values for four blocks and replicate it across 256 bit lane - __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db)); - dvec = _mm256_permute2f128_ps(dvec ,dvec, 0); - // Computation of dot product and multiplication with appropriate delta value products - Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0), - updot(_mm256_sign_epi8(avec0, avec0), - _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)), - Cv[j][0]); - Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85), - updot(_mm256_sign_epi8(avec1, avec1), - _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)), - Cv[j][1]); - Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170), - updot(_mm256_sign_epi8(avec2, avec2), - _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)), - Cv[j][2]); - Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255), - updot(_mm256_sign_epi8(avec3, avec3), - _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)), - Cv[j][3]); - } - } - - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < 4; ++i) - C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); - } - } - - // Templated functions for gemm of dimensions Mx4 - template - NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t ytiles = (m - m0) / RM; - int64_t xtiles = (n - n0) / 4; - int64_t tiles = xtiles * ytiles; - int64_t duty = (tiles + nth - 1) / nth; - int64_t start = duty * ith; - int64_t end = start + duty; - if (end > tiles) - end = tiles; - for (int64_t job = start; job < end; ++job) { - int64_t ii = m0 + job / xtiles * RM; - int64_t jj = n0 + job % xtiles * 4; - __m256 Cv[4][RM] = {}; - for (int64_t l = 0; l < k; ++l) { - uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d); - // Convert delta values for four blocks to float values - __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta)); - __m256i bvec0 = load(B + ldb * (jj + 0) + l); - __m256i bvec1 = load(B + ldb * (jj + 1) + l); - __m256i bvec2 = load(B + ldb * (jj + 2) + l); - __m256i bvec3 = load(B + ldb * (jj + 3) + l); - for (int64_t i = 0; i < RM; ++i) { - __m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d))); - // Computation of product of delta values for four blocks and replicate it across 256 bit lane - __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db)); - dvec = _mm256_permute2f128_ps(dvec ,dvec, 0); - // Computation of dot product and multiplication with appropriate delta value products - Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0), - updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), - load(A + lda * (ii + i) + l)), - _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))), - Cv[0][i]); - Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85), - updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), - load(A + lda * (ii + i) + l)), - _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))), - Cv[1][i]); - Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170), - updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), - load(A + lda * (ii + i) + l)), - _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))), - Cv[2][i]); - Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255), - updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), - load(A + lda * (ii + i) + l)), - _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))), - Cv[3][i]); - } - } - for (int64_t j = 0; j < 4; ++j) - for (int64_t i = 0; i < RM; ++i) - C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); - } - } -#endif - - template - NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t ytiles = (m - m0) / RM; - int64_t xtiles = (n - n0) / RN; - int64_t tiles = xtiles * ytiles; - int64_t duty = (tiles + nth - 1) / nth; - int64_t start = duty * ith; - int64_t end = start + duty; - if (end > tiles) - end = tiles; - for (int64_t job = start; job < end; ++job) { - int64_t ii = m0 + job / xtiles * RM; - int64_t jj = n0 + job % xtiles * RN; - __m256 Cv[RN][RM] = {}; - for (int64_t l = 0; l < k; ++l) - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) { -#if defined(__AVX2__) - __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), - load(A + lda * (ii + i) + l)), - _mm256_sign_epi8(load(B + ldb * (jj + j) + l), - load(A + lda * (ii + i) + l))); -#else - __m128i ali0 = load0(A + lda * (ii + i) + l); - __m128i ali1 = load1(A + lda * (ii + i) + l); - __m128i blj0 = load0(B + ldb * (jj + j) + l); - __m128i blj1 = load1(B + ldb * (jj + j) + l); - - __m128i sepAA0 = _mm_sign_epi8(ali0, ali0); - __m128i sepAA1 = _mm_sign_epi8(ali1, ali1); - __m128i sepBA0 = _mm_sign_epi8(blj0, ali0); - __m128i sepBA1 = _mm_sign_epi8(blj1, ali1); - - // updot - const __m128i oneFill = _mm_set1_epi16(1); - __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0); - __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1); - __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0))); -#endif - Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * - unhalf(B[ldb * (jj + j) + l].d)), - udTmp, - Cv[j][i]); - } - for (int64_t j = 0; j < RN; ++j) - for (int64_t i = 0; i < RM; ++i) - C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); - } - } - - inline __m256i load(const block_q8_0 *b) { - return _mm256_loadu_si256((const __m256i *)b->qs); - } - - inline __m128i load0(const block_q8_0 *b) { - return _mm_loadu_si128((const __m128i *)b->qs); - } - - inline __m128i load1(const block_q8_0 *b) { - return _mm_loadu_si128(((const __m128i *)b->qs) + 1); - } - - inline __m256i load(const block_q4_0 *b) { - return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); - } - - inline __m128i load0(const block_q4_0 *b) { - const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); - return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8)); - } - - inline __m128i load1(const block_q4_0 *b) { - const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); - return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8)); - } - - inline __m256 updot(__m256i u, __m256i s) { - __m256i res; -#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) - res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s); -#else - res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s)); -#endif - return _mm256_cvtepi32_ps(res); - } - - static inline __m256i denibble(const uint8_t *p) { - __m128i x = _mm_loadu_si128((const __m128i *)p); - return _mm256_and_si256(_mm256_set1_epi8(15), - _mm256_insertf128_si256(_mm256_castsi128_si256(x), - _mm_srli_epi16(x, 4), 1)); - } - - const TA *const A; - const TB *const B; - TC *const C; - const int64_t k; - const int64_t lda; - const int64_t ldb; - const int64_t ldc; - const int ith; - const int nth; -}; -#endif // __AVX__ - -} // namespace - -/** - * Performs optimized matrix multiplication on CPU. - * - * This subroutine may compute C = Aᵀ * B with column major ordering. - * Despite its name, this isn't a generalized implementation. Work is - * only performed when a handwritten kernel is written and available. - * Otherwise the caller should fall back to a general matmul routine. - * - * For example, for single-threaded single-precision GEMM you can say - * - * llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, - * 0, 1, - * GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32); - * - * @param m is rows in `A` and `C` - * @param n is cols in `B` and `C` - * @param k is cols in `A` and rows in `B` - * @param A is first input matrix (always transposed) - * @param lda is row stride of `A` - * @param B is second input matrix (never transposed) - * @param ldb is row stride of `B` - * @param C is input/output array of output matrices - * @param ldc is row stride of `C` - * @param ith is thread id (must be less than `nth`) - * @param nth is number of threads (must be greater than zero) - * @param Atype is GGML data type of `A` - * @param Btype is GGML data type of `B` - * @param Ctype is GGML data type of `C` - * @return true if this function was able to service the matmul request - */ -bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C, - int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) { - - assert(m >= 0); - assert(n >= 0); - assert(k >= 0); - assert(lda >= k); - assert(ldb >= k); - assert(ldc >= m); - assert(nth > 0); - assert(ith < nth); - - // only enable sgemm for prompt processing - if (n < 2) - return false; - - if (Ctype != GGML_TYPE_F32) - return false; - - switch (Atype) { - - case GGML_TYPE_F32: { - if (Btype != GGML_TYPE_F32) - return false; -#if defined(__AVX512F__) - if (k % 16) - return false; - tinyBLAS<16, __m512, __m512, float, float, float> tb{ - k, (const float *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n); - return true; -#elif defined(__AVX__) || defined(__AVX2__) - if (k % 8) - return false; - tinyBLAS<8, __m256, __m256, float, float, float> tb{ - k, (const float *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n); - return true; -#elif defined(__ARM_NEON) - if (n < 4) - return false; - if (k % 4) - return false; - tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ - k, (const float *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n); - return true; -#else - return false; -#endif - } - - case GGML_TYPE_F16: { -#if defined(__AVX512F__) - if (k % 16) - return false; - if (Btype != GGML_TYPE_F32) - return false; - tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n); - return true; -#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__) - if (k % 8) - return false; - if (Btype != GGML_TYPE_F32) - return false; - tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n); - return true; -#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) - if (n < 8) - return false; - if (k % 8) - return false; - if (Btype != GGML_TYPE_F16) - return false; - tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const ggml_fp16_t *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n); - return true; -#elif defined(__ARM_NEON) && !defined(_MSC_VER) - if (k % 4) - return false; - if (Btype != GGML_TYPE_F32) - return false; - tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ - k, (const ggml_fp16_t *)A, lda, - (const float *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n); - return true; -#else - return false; -#endif - } - - case GGML_TYPE_Q8_0: { - if (Btype != GGML_TYPE_Q8_0) - return false; -#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) - tinyBLAS_Q0_AVX tb{ - k, (const block_q8_0 *)A, lda, - (const block_q8_0 *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n); - return true; -#elif defined(__ARM_FEATURE_DOTPROD) - tinyBLAS_Q0_ARM tb{ - k, (const block_q8_0 *)A, lda, - (const block_q8_0 *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n); - return true; -#else - return false; -#endif - } - - case GGML_TYPE_Q4_0: { - if (Btype != GGML_TYPE_Q8_0) - return false; -#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) - tinyBLAS_Q0_AVX tb{ - k, (const block_q4_0 *)A, lda, - (const block_q8_0 *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n); - return true; -#elif defined(__ARM_FEATURE_DOTPROD) - tinyBLAS_Q0_ARM tb{ - k, (const block_q4_0 *)A, lda, - (const block_q8_0 *)B, ldb, - (float *)C, ldc, - ith, nth}; - tb.matmul(m, n); - return true; -#else - return false; -#endif - } - - default: - return false; - } - - (void)m; - (void)n; - (void)k; - (void)A; - (void)lda; - (void)B; - (void)ldb; - (void)C; - (void)ldc; - (void)ith; - (void)nth; - (void)Atype; - (void)Btype; - (void)Ctype; -} diff --git a/ggml/src/llamafile/sgemm.h b/ggml/src/llamafile/sgemm.h deleted file mode 100644 index caf6dd556..000000000 --- a/ggml/src/llamafile/sgemm.h +++ /dev/null @@ -1,14 +0,0 @@ -#pragma once -#include -#include -#ifdef __cplusplus -extern "C" { -#endif - -bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t, - const void *, int64_t, void *, int64_t, int, int, - int, int, int); - -#ifdef __cplusplus -} -#endif diff --git a/ggml/src/vulkan-shaders/acc.comp b/ggml/src/vulkan-shaders/acc.comp deleted file mode 100644 index 4c8739efe..000000000 --- a/ggml/src/vulkan-shaders/acc.comp +++ /dev/null @@ -1,24 +0,0 @@ -#version 450 - -#include "types.comp" -#include "generic_binary_head.comp" - -void main() { - const uint idx = gl_GlobalInvocationID.x; - if (idx >= p.ne) { - return; - } - - const uint offset = p.param3; - const uint src1_i = idx - offset; - const uint oz = src1_i / p.nb02; - const uint oy = (src1_i - (oz * p.nb02)) / p.nb01; - const uint ox = src1_i % p.nb01; - - if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) { - data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11])); - } else { - data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)])); - } -} - diff --git a/ggml/src/vulkan-shaders/add.comp b/ggml/src/vulkan-shaders/add.comp deleted file mode 100644 index 3974845d6..000000000 --- a/ggml/src/vulkan-shaders/add.comp +++ /dev/null @@ -1,14 +0,0 @@ -#version 450 - -#include "types.comp" -#include "generic_binary_head.comp" - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[src1_idx(idx)])); -} diff --git a/ggml/src/vulkan-shaders/clamp.comp b/ggml/src/vulkan-shaders/clamp.comp deleted file mode 100644 index 7071302a4..000000000 --- a/ggml/src/vulkan-shaders/clamp.comp +++ /dev/null @@ -1,15 +0,0 @@ -#version 450 - -#include "types.comp" -#include "generic_unary_head.comp" - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]); - data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val)); -} diff --git a/ggml/src/vulkan-shaders/copy.comp b/ggml/src/vulkan-shaders/copy.comp deleted file mode 100644 index c26917c0f..000000000 --- a/ggml/src/vulkan-shaders/copy.comp +++ /dev/null @@ -1,18 +0,0 @@ -#version 450 - -#include "types.comp" -#include "generic_unary_head.comp" - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - -#ifndef OPTIMIZATION_ERROR_WORKAROUND - data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx(idx)]); -#else - data_d[p.d_offset + dst_idx(idx)] = data_a[src0_idx(idx)]; -#endif -} diff --git a/ggml/src/vulkan-shaders/cos.comp b/ggml/src/vulkan-shaders/cos.comp deleted file mode 100644 index f9a858cbf..000000000 --- a/ggml/src/vulkan-shaders/cos.comp +++ /dev/null @@ -1,15 +0,0 @@ -#version 450 - -#include "types.comp" -#include "generic_unary_head.comp" - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]); - data_d[p.d_offset + dst_idx(idx)] = D_TYPE(cos(val)); -} diff --git a/ggml/src/vulkan-shaders/dequant_funcs.comp b/ggml/src/vulkan-shaders/dequant_funcs.comp deleted file mode 100644 index d5b989735..000000000 --- a/ggml/src/vulkan-shaders/dequant_funcs.comp +++ /dev/null @@ -1,68 +0,0 @@ -#if !defined(DATA_A_F32) && !defined(DATA_A_F16) -#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require -#endif - -#if defined(DATA_A_F32) -vec2 dequantize(uint ib, uint iqs, uint a_offset) { - return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); -} -#endif - -#if defined(DATA_A_F16) -vec2 dequantize(uint ib, uint iqs, uint a_offset) { - return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); -} -#endif - -#if defined(DATA_A_Q4_0) -vec2 dequantize(uint ib, uint iqs, uint a_offset) { - const float d = float(data_a[a_offset + ib].d); - const uint vui = uint(data_a[a_offset + ib].qs[iqs]); - return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d; -} -#endif - -#if defined(DATA_A_Q4_1) -vec2 dequantize(uint ib, uint iqs, uint a_offset) { - const float d = float(data_a[a_offset + ib].d); - const float m = float(data_a[a_offset + ib].m); - const uint vui = uint(data_a[a_offset + ib].qs[iqs]); - return vec2(vui & 0xF, vui >> 4) * d + m; -} -#endif - -#if defined(DATA_A_Q5_0) -vec2 dequantize(uint ib, uint iqs, uint a_offset) { - const float d = float(data_a[a_offset + ib].d); - const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0]; - const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); - const uint vui = uint(data_a[a_offset + ib].qs[iqs]); - return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d; -} -#endif - -#if defined(DATA_A_Q5_1) -vec2 dequantize(uint ib, uint iqs, uint a_offset) { - const float d = float(data_a[a_offset + ib].d); - const float m = float(data_a[a_offset + ib].m); - const uint uint_qh = data_a[a_offset + ib].qh; - const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); - const uint vui = uint(data_a[a_offset + ib].qs[iqs]); - return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m; -} -#endif - -#if defined(DATA_A_Q8_0) -vec2 dequantize(uint ib, uint iqs, uint a_offset) { - const float d = float(data_a[a_offset + ib].d); - return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d; -} -#endif - -#if defined(DATA_A_IQ4_NL) -vec2 dequantize(uint ib, uint iqs, uint a_offset) { - const float d = float(data_a[a_offset + ib].d); - const uint vui = uint(data_a[a_offset + ib].qs[iqs]); - return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d; -} -#endif diff --git a/ggml/src/vulkan-shaders/dequant_q4_k.comp b/ggml/src/vulkan-shaders/dequant_q4_k.comp deleted file mode 100644 index 92acb7540..000000000 --- a/ggml/src/vulkan-shaders/dequant_q4_k.comp +++ /dev/null @@ -1,56 +0,0 @@ -#version 450 - -#include "dequant_head.comp" - -layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; - -void main() { - [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { - const uint i = gl_WorkGroupID.x * 256 + wgy; - if (i >= p.M * p.K / QUANT_K) { - return; - } - - const uint tid = gl_LocalInvocationID.x; - const uint il = tid / 8; - const uint ir = tid % 8; - const uint is = 2 * il; - const uint n = 4; - - const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); - - const uint y_idx = i * QUANT_K + 64 * il + n * ir; - const uint qs_idx = 32*il + n * ir; - - uint8_t sc; - uint8_t m; - if (is < 4) { - sc = uint8_t(data_a[i].scales[is] & 63); - m = uint8_t(data_a[i].scales[is + 4] & 63); - } else { - sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4)); - m = uint8_t((data_a[i].scales[is + 4] >> 4) | ((data_a[i].scales[is ] >> 6) << 4)); - } - const FLOAT_TYPE d1 = dall * sc; - const FLOAT_TYPE m1 = dmin * m; - - if (is < 4) { - sc = uint8_t(data_a[i].scales[is + 1] & 63); - m = uint8_t(data_a[i].scales[is + 5] & 63); - } else { - sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4)); - m = uint8_t((data_a[i].scales[is + 5] >> 4) | ((data_a[i].scales[is + 1] >> 6) << 4)); - } - const FLOAT_TYPE d2 = dall * sc; - const FLOAT_TYPE m2 = dmin * m; - - [[unroll]] for (uint l = 0; l < n; ++l) { - data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] & 0xF) - m1); - data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] >> 4) - m2); - } - } -} diff --git a/ggml/src/vulkan-shaders/dequant_q5_k.comp b/ggml/src/vulkan-shaders/dequant_q5_k.comp deleted file mode 100644 index f314a76d1..000000000 --- a/ggml/src/vulkan-shaders/dequant_q5_k.comp +++ /dev/null @@ -1,58 +0,0 @@ -#version 450 - -#include "dequant_head.comp" - -layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; - -void main() { - [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { - const uint i = gl_WorkGroupID.x * 256 + wgy; - if (i >= p.M * p.K / QUANT_K) { - return; - } - - const uint tid = gl_LocalInvocationID.x; - const uint il = tid / 16; - const uint ir = tid % 16; - const uint is = 2 * il; - - const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); - - const uint y_idx = i * QUANT_K + 64 * il + 2 * ir; - const uint qs_idx = 32*il + 2 * ir; - const uint qh_idx = 2 * ir; - - uint8_t sc; - uint8_t m; - if (is < 4) { - sc = uint8_t(data_a[i].scales[is] & 63); - m = uint8_t(data_a[i].scales[is + 4] & 63); - } else { - sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4)); - m = uint8_t((data_a[i].scales[is + 4] >> 4) | ((data_a[i].scales[is ] >> 6) << 4)); - } - const FLOAT_TYPE d1 = dall * sc; - const FLOAT_TYPE m1 = dmin * m; - - if (is < 4) { - sc = uint8_t(data_a[i].scales[is + 1] & 63); - m = uint8_t(data_a[i].scales[is + 5] & 63); - } else { - sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4)); - m = uint8_t((data_a[i].scales[is + 5] >> 4) | ((data_a[i].scales[is + 1] >> 6) << 4)); - } - const FLOAT_TYPE d2 = dall * sc; - const FLOAT_TYPE m2 = dmin * m; - - const uint8_t hm1 = uint8_t(1 << (2 * il )); - const uint8_t hm2 = uint8_t(1 << (2 * il + 1)); - data_b[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((data_a[i].qs[qs_idx ] & 0xF) + (((data_a[i].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1); - data_b[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((data_a[i].qs[qs_idx + 1] & 0xF) + (((data_a[i].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1); - data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[i].qs[qs_idx ] >> 4) + (((data_a[i].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2); - data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[i].qs[qs_idx + 1] >> 4) + (((data_a[i].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2); - } -} diff --git a/ggml/src/vulkan-shaders/div.comp b/ggml/src/vulkan-shaders/div.comp deleted file mode 100644 index 8cfce58b1..000000000 --- a/ggml/src/vulkan-shaders/div.comp +++ /dev/null @@ -1,14 +0,0 @@ -#version 450 - -#include "types.comp" -#include "generic_binary_head.comp" - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) / FLOAT_TYPE(data_b[src1_idx(idx)])); -} diff --git a/ggml/src/vulkan-shaders/generic_binary_head.comp b/ggml/src/vulkan-shaders/generic_binary_head.comp deleted file mode 100644 index b6beaff1c..000000000 --- a/ggml/src/vulkan-shaders/generic_binary_head.comp +++ /dev/null @@ -1,52 +0,0 @@ -#extension GL_EXT_shader_16bit_storage : require - -layout (push_constant) uniform parameter -{ - uint ne; - uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; - uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; - uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23; - uint d_offset; - float param1; float param2; int param3; -} p; - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; - -uint get_idx() { - return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; -} - -uint src0_idx(uint idx) { - const uint i03 = idx / (p.ne02*p.ne01*p.ne00); - const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; - const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00); - const uint i02_offset = i02*p.ne01*p.ne00; - const uint i01 = (idx - i03_offset - i02_offset) / p.ne00; - const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; - return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; -} - -uint src1_idx(uint idx) { - const uint i03 = idx / (p.ne02*p.ne01*p.ne00); - const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; - const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00); - const uint i02_offset = i02*p.ne01*p.ne00; - const uint i01 = (idx - i03_offset - i02_offset) / p.ne00; - const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; - - return (i03 % p.ne13)*p.nb13 + (i02 % p.ne12)*p.nb12 + (i01 % p.ne11)*p.nb11 + (i00 % p.ne10)*p.nb10; -} - -uint dst_idx(uint idx) { - const uint i23 = idx / (p.ne22*p.ne21*p.ne20); - const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20; - const uint i22 = (idx - i23_offset) / (p.ne21*p.ne20); - const uint i22_offset = i22*p.ne21*p.ne20; - const uint i21 = (idx - i23_offset - i22_offset) / p.ne20; - const uint i20 = idx - i23_offset - i22_offset - i21*p.ne20; - return i23*p.nb23 + i22*p.nb22 + i21*p.nb21 + i20*p.nb20; -} diff --git a/ggml/src/vulkan-shaders/generic_unary_head.comp b/ggml/src/vulkan-shaders/generic_unary_head.comp deleted file mode 100644 index eacdefc7d..000000000 --- a/ggml/src/vulkan-shaders/generic_unary_head.comp +++ /dev/null @@ -1,39 +0,0 @@ -#extension GL_EXT_shader_16bit_storage : require - -layout (push_constant) uniform parameter -{ - uint ne; - uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; - uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; - uint d_offset; - float param1; float param2; -} p; - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - -uint get_idx() { - return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; -} - -uint src0_idx(uint idx) { - const uint i03 = idx / (p.ne02*p.ne01*p.ne00); - const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; - const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00); - const uint i02_offset = i02*p.ne01*p.ne00; - const uint i01 = (idx - i03_offset - i02_offset) / p.ne00; - const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; - return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; -} - -uint dst_idx(uint idx) { - const uint i13 = idx / (p.ne12*p.ne11*p.ne10); - const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; - const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10); - const uint i12_offset = i12*p.ne11*p.ne10; - const uint i11 = (idx - i13_offset - i12_offset) / p.ne10; - const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; - return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; -} diff --git a/ggml/src/vulkan-shaders/im2col.comp b/ggml/src/vulkan-shaders/im2col.comp deleted file mode 100644 index 4d48610a3..000000000 --- a/ggml/src/vulkan-shaders/im2col.comp +++ /dev/null @@ -1,57 +0,0 @@ -#version 450 - -#extension GL_EXT_shader_16bit_storage : require - -layout (push_constant) uniform parameter -{ - uint batch_offset; uint offset_delta; - uint IC; - uint IW; uint IH; - uint OW; uint OH; - uint KW; uint KH; - uint pelements; - uint CHW; - int s0; int s1; - int p0; int p1; - int d0; int d1; -} p; - -#include "types.comp" - -#define BLOCK_SIZE 256 - -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - -void main() { - const uint i = gl_GlobalInvocationID.x; - if (i >= p.pelements) { - return; - } - - const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1); - const uint kx = i / ksize; - const uint kd = kx * ksize; - const uint ky = (i - kd) / p.OW; - const uint ix = i % p.OW; - - const uint oh = gl_GlobalInvocationID.y; - const uint batch = gl_GlobalInvocationID.z / p.IC; - const uint ic = gl_GlobalInvocationID.z % p.IC; - - const uint iiw = ix * p.s0 + kx * p.d0 - p.p0; - const uint iih = oh * p.s1 + ky * p.d1 - p.p1; - - const uint offset_dst = - ((batch * p.OH + oh) * p.OW + ix) * p.CHW + - (ic * (p.KW * p.KH) + ky * p.KW + kx); - - if (iih < 0 || iih >= p.IH || iiw < 0 || iiw >= p.IW) { - data_d[offset_dst] = D_TYPE(0.0f); - } else { - const uint offset_src = ic * p.offset_delta + batch * p.batch_offset; - data_d[offset_dst] = D_TYPE(data_a[offset_src + iih * p.IW + iiw]); - } -} diff --git a/ggml/src/vulkan-shaders/mul.comp b/ggml/src/vulkan-shaders/mul.comp deleted file mode 100644 index bfb61c92d..000000000 --- a/ggml/src/vulkan-shaders/mul.comp +++ /dev/null @@ -1,14 +0,0 @@ -#version 450 - -#include "types.comp" -#include "generic_binary_head.comp" - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) * FLOAT_TYPE(data_b[src1_idx(idx)])); -} diff --git a/ggml/src/vulkan-shaders/mul_mat_split_k_reduce.comp b/ggml/src/vulkan-shaders/mul_mat_split_k_reduce.comp deleted file mode 100644 index 825b91031..000000000 --- a/ggml/src/vulkan-shaders/mul_mat_split_k_reduce.comp +++ /dev/null @@ -1,29 +0,0 @@ -#version 450 - -#extension GL_EXT_control_flow_attributes : enable - -layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A {float data_a[];}; -layout (binding = 1) writeonly buffer D {float data_d[];}; - -layout (push_constant) uniform parameter { - uint ne; - uint k_num; -} p; - -void main() { - const uint idx = gl_GlobalInvocationID.x; - - if (idx >= p.ne) { - return; - } - - float result = 0.0f; - - [[unroll]] for (uint i = 0; i < p.k_num; i++) { - result += data_a[i * p.ne + idx]; - } - - data_d[idx] = result; -} diff --git a/ggml/src/vulkan-shaders/mul_mat_vec.comp b/ggml/src/vulkan-shaders/mul_mat_vec.comp deleted file mode 100644 index d3ccba7fc..000000000 --- a/ggml/src/vulkan-shaders/mul_mat_vec.comp +++ /dev/null @@ -1,56 +0,0 @@ -#version 450 - -#ifdef FLOAT16 -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require -#endif - -#include "mul_mat_vec_base.comp" - -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; - -layout (constant_id = 0) const uint BLOCK_SIZE = 32; - -shared FLOAT_TYPE tmp[BLOCK_SIZE]; - -void main() { - const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z; - const uint tid = gl_LocalInvocationID.x; - - // There are not enough cols to use all threads - if (tid >= p.ncols) { - return; - } - - const uint block_size = min(p.ncols, BLOCK_SIZE); - - uint a_offset, b_offset, d_offset; - get_offsets(a_offset, b_offset, d_offset); - - const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; - - tmp[tid] = FLOAT_TYPE(0.0f); - - [[unroll]] for (uint i = 0; i < p.ncols/block_size; i += 2) { - const uint col = i*block_size + 2*tid; - const uint ib = (row*p.ncols + col)/QUANT_K; // block index - const uint iqs = (col%QUANT_K)/QUANT_R; // quant index - const uint iybs = col - col%QUANT_K; // y block start index - - vec2 v = dequantize(ib, iqs, a_offset / QUANT_K); - - // matrix multiplication - tmp[tid] = fma(FLOAT_TYPE(v.x), FLOAT_TYPE(data_b[b_offset + iybs + iqs]), fma(FLOAT_TYPE(v.y), FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]), tmp[tid])); - } - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (uint s = block_size/2; s > 0; s >>= 1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(); - } - if (tid == 0) { - data_d[d_offset + row] = D_TYPE(tmp[0]); - } -} diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_base.comp b/ggml/src/vulkan-shaders/mul_mat_vec_base.comp deleted file mode 100644 index 5920bc936..000000000 --- a/ggml/src/vulkan-shaders/mul_mat_vec_base.comp +++ /dev/null @@ -1,81 +0,0 @@ -#extension GL_EXT_control_flow_attributes : enable -#extension GL_EXT_shader_16bit_storage : require -#extension GL_EXT_shader_8bit_storage : require - -#define K_QUANTS_PER_ITERATION 2 - -#ifdef MUL_MAT_ID -#define EXPERT_COUNT 8 -#endif - -#include "types.comp" - -layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; -#ifdef MUL_MAT_ID -layout (binding = 3) readonly buffer IDS {int data_ids[];}; -#endif - -#include "dequant_funcs.comp" - -layout (push_constant) uniform parameter -{ - uint ncols; - uint stride_a; - uint stride_b; - uint stride_d; - - uint batch_stride_a; - uint batch_stride_b; - uint batch_stride_d; - -#ifdef MUL_MAT_ID - uint nei0; - uint ne11; -#else - uint ne02; - uint ne12; - uint broadcast2; - uint broadcast3; -#endif -} p; - -void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { -#ifdef MUL_MAT_ID - const uint expert_idx = gl_GlobalInvocationID.y; -#else - const uint batch_idx = gl_GlobalInvocationID.y; -#endif - -#ifndef MUL_MAT_ID - const uint i13 = batch_idx / p.ne12; - const uint i12 = batch_idx % p.ne12; - - const uint i03 = i13 / p.broadcast3; - const uint i02 = i12 / p.broadcast2; - - const uint batch_idx_a = i03 * p.ne02 + i02; -#else - const uint expert_id = data_ids[expert_idx]; -#endif - - a_offset = -#ifdef MUL_MAT_ID - expert_id * p.batch_stride_a; -#else - batch_idx_a * p.batch_stride_a; -#endif - b_offset = -#ifdef MUL_MAT_ID - (expert_idx % p.ne11) * p.stride_b; -#else - batch_idx * p.batch_stride_b; -#endif - d_offset = -#ifdef MUL_MAT_ID - expert_idx * p.stride_d; -#else - batch_idx * p.batch_stride_d; -#endif -} diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp deleted file mode 100644 index ec8eadcd5..000000000 --- a/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp +++ /dev/null @@ -1,74 +0,0 @@ -#version 450 - -#include "mul_mat_vec_base.comp" - -layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; - -shared FLOAT_TYPE tmp[32]; - -void main() { - const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z; - - uint a_offset, b_offset, d_offset; - get_offsets(a_offset, b_offset, d_offset); - - const uint num_blocks_per_row = p.ncols / QUANT_K; - const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; - - const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 - - const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 - - const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const uint v_in = tid - step*v_im; // 0...15 or 0...7 - - const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15 - const uint q_offset = 32*v_im + l0; - const uint s_offset = 8*v_im; - const uint y_offset = 128*v_im + l0; - - tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp - - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - const uint y_idx = i * QUANT_K + y_offset; - - const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y); - - FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); - FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - sum1 = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3), sum1)))))))); - sum2 = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 2] >> 4) & 0xF), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]), FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF), sum2)))))))); - } - const uint tmp_idx = 16 * ix + tid; - tmp[tmp_idx] = fma(dall, sum1, fma(-dmin, sum2, tmp[tmp_idx])); - } - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (uint s = 16; s > 0; s >>= 1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(); - } - if (tid == 0) { - data_d[d_offset + row] = D_TYPE(tmp[0]); - } -} diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp deleted file mode 100644 index 3ca4ad85a..000000000 --- a/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp +++ /dev/null @@ -1,67 +0,0 @@ -#version 450 - -#include "mul_mat_vec_base.comp" - -layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; - -shared FLOAT_TYPE tmp[32]; - -void main() { - const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z; - - uint a_offset, b_offset, d_offset; - get_offsets(a_offset, b_offset, d_offset); - - const uint num_blocks_per_row = p.ncols / QUANT_K; - const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; - - const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 - - const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 - - const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const uint v_in = tid - step*v_im; // 0...15 or 0...7 - - const uint8_t m = uint8_t(1 << (4 * v_im)); - - const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15 - const uint q_offset = 32*v_im + l0; - const uint y_offset = 128*v_im + l0; - - tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp - - const uint s_shift = 4 * v_im; - - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - const uint y_idx = i * QUANT_K + y_offset; - - const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); - - FLOAT_TYPE sum = FLOAT_TYPE(0.0); - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 96]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[6] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[1] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 48]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[3] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 80]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[5] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l +112]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[7] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum)))))))); - } - const uint tmp_idx = 16 * ix + tid; - tmp[tmp_idx] = fma(d, sum, tmp[tmp_idx]); - } - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (uint s = 16; s > 0; s >>= 1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(); - } - if (tid == 0) { - data_d[d_offset + row] = D_TYPE(tmp[0]); - } -} diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp deleted file mode 100644 index d91e00e10..000000000 --- a/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp +++ /dev/null @@ -1,118 +0,0 @@ -#version 450 - -#include "mul_mat_vec_base.comp" - -layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; - -shared FLOAT_TYPE tmp[32]; - -void main() { - const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z; - - uint a_offset, b_offset, d_offset; - get_offsets(a_offset, b_offset, d_offset); - - const uint num_blocks_per_row = p.ncols / QUANT_K; - const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; - - const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 - - const uint step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 - - const uint il = tid/step; // 0...3 - const uint ir = tid - step*il; // 0...7 or 0...3 - const uint n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 - - const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const uint v_in = il % 2; - - const uint l0 = n * (2 * ir + v_in); // 0...15 - const uint q_offset = 32*v_im + l0; - const uint y_offset = 64*v_im + l0; - - tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp - - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - const uint y1_idx = i * QUANT_K + y_offset; - const uint y2_idx = y1_idx + 128; - - const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y); - - const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f); - const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f); - const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f); - const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f); - const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2)); - const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2)); - const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2)); - const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2)); - -#if K_QUANTS_PER_ITERATION == 2 - const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf); - const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf); - const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] & 0xf); - const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] & 0xf); - const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4); - const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4); - const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] >> 4); - const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] >> 4); - const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf); - const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf); - const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] & 0xf); - const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] & 0xf); - const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4); - const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4); - const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] >> 4); - const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] >> 4); - - const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx]), q4_0, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), q4_1, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 2]), q4_2, FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * q4_3))); - const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_4, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), q4_5, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 34]), q4_6, FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7))); - const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx]), q4_8, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), q4_9, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 2]), q4_10, FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * q4_11))); - const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_12, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), q4_13, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 34]), q4_14, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15))); - const FLOAT_TYPE smin = - fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7, - fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), sc7, - fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 2]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 34]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 2]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 34]), sc7, - fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 3]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 35]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 3]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7))))))))))))))); - const uint tmp_idx = 16 * ix + tid; - tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx])); -#else - const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf); - const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf); - const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4); - const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4); - const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf); - const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf); - const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4); - const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4); - - const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), q4_0, FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1); - const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_2, FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3); - const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), q4_4, FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_5); - const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7); - const FLOAT_TYPE smin = - fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7, - + fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7))))))); - - tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) + - sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin); - const uint tmp_idx = 16 * ix + tid; - tmp[tmp_idx] = fma(dall, (fma(sx, FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f), fma(sy, FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f), - fma(sz, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)), fma(sw, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))))))), fma(-dmin, smin, tmp[tmp_idx])); -#endif - } - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (uint s = 16; s > 0; s >>= 1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(); - } - if (tid == 0) { - data_d[d_offset + row] = D_TYPE(tmp[0]); - } -} diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q5_k.comp deleted file mode 100644 index 2306785af..000000000 --- a/ggml/src/vulkan-shaders/mul_mat_vec_q5_k.comp +++ /dev/null @@ -1,109 +0,0 @@ -#version 450 - -#include "mul_mat_vec_base.comp" - -layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; - -shared FLOAT_TYPE tmp[32]; - -void main() { - const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z; - - uint a_offset, b_offset, d_offset; - get_offsets(a_offset, b_offset, d_offset); - - const uint num_blocks_per_row = p.ncols / QUANT_K; - const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; - - const uint tid = gl_LocalInvocationID.x/2; // 0...31 or 0...16 - const uint ix = gl_LocalInvocationID.x%2; // 0 or 0, 1 - - const uint il = tid/4; // 0...3 - const uint ir = tid - 4*il; // 0...7 or 0...3 - - const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const uint v_in = il % 2; - - const uint l0 = 4*ir + 2*v_in; // 0...15 - const uint q_offset = 32*v_im + l0; - const uint y_offset = 64*v_im + l0; - - const uint8_t hm1 = uint8_t(1 << (2*v_im)); - const uint8_t hm2 = uint8_t(hm1 << 4); - - tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp - - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) { - const uint y1_idx = i * QUANT_K + y_offset; - const uint y2_idx = y1_idx + 128; - - const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y); - - const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f); - const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f); - const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f); - const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f); - const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2)); - const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2)); - const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2)); - const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2)); - - const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf); - const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf); - const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] & 0xf); - const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] & 0xf); - const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4); - const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4); - const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] >> 4); - const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] >> 4); - const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf); - const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf); - const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] & 0xf); - const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] & 0xf); - const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4); - const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4); - const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] >> 4); - const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] >> 4); - - const FLOAT_TYPE sx = - fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), (q4_0 + (((data_a[ib0 + i].qh[l0 ] & hm1) != 0) ? 16 : 0)), - fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), (q4_1 + (((data_a[ib0 + i].qh[l0 + 1] & hm1) != 0) ? 16 : 0)), - fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 16]), (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0)), - FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0))))); - const FLOAT_TYPE sy = - fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), (q4_4 + (((data_a[ib0 + i].qh[l0 ] & (hm1 << 1)) != 0) ? 16 : 0)), - fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), (q4_5 + (((data_a[ib0 + i].qh[l0 + 1] & (hm1 << 1)) != 0) ? 16 : 0)), - fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 48]), (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0)), - FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0))))); - const FLOAT_TYPE sz = - fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), (q4_8 + (((data_a[ib0 + i].qh[l0 ] & hm2) != 0) ? 16 : 0)), - fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), (q4_9 + (((data_a[ib0 + i].qh[l0 + 1] & hm2) != 0) ? 16 : 0)), - fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 16]), (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0)), - FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0))))); - const FLOAT_TYPE sw = - fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), (q4_12 + (((data_a[ib0 + i].qh[l0 ] & (hm2 << 1)) != 0) ? 16 : 0)), - fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), (q4_13 + (((data_a[ib0 + i].qh[l0 + 1] & (hm2 << 1)) != 0) ? 16 : 0)), - fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 48]), (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0)), - FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0))))); - const FLOAT_TYPE smin = - fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 1 ]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]), sc2, - fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]), sc3, - fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 1 ]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]), sc6, - (FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7))); - const uint tmp_idx = 16 * ix + tid; - tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx])); - } - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (uint s = 16; s > 0; s >>= 1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(); - } - if (tid == 0) { - data_d[d_offset + row] = D_TYPE(tmp[0]); - } -} diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp deleted file mode 100644 index 95c286eeb..000000000 --- a/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp +++ /dev/null @@ -1,79 +0,0 @@ -#version 450 - -#include "mul_mat_vec_base.comp" - -layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; - -shared FLOAT_TYPE tmp[32]; - -void main() { - const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z; - - uint a_offset, b_offset, d_offset; - get_offsets(a_offset, b_offset, d_offset); - - const uint num_blocks_per_row = p.ncols / QUANT_K; - const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; - - const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 - - const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 - - const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const uint v_in = tid - step*v_im; // 0...15 or 0...7 - -#if K_QUANTS_PER_ITERATION == 1 - const uint l0 = v_in; // 0...15 - const uint is = 0; -#else - const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28 - const uint is = v_in / 4; -#endif - - const uint ql_offset = 64*v_im + l0; - const uint qh_offset = 32*v_im + l0; - const uint s_offset = 8*v_im + is; - const uint y_offset = 128*v_im + l0; - - tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp - - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - const uint y_idx = i * QUANT_K + y_offset; - - const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); - -#if K_QUANTS_PER_ITERATION == 1 - const uint tmp_idx = 16 * ix + tid; - tmp[tmp_idx] = fma(FLOAT_TYPE(data_b[b_offset + y_idx + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32), - fma(FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32), tmp[tmp_idx])))))))); -#else - FLOAT_TYPE sum = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < 4; ++l) { - sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32), sum)))); - } - tmp[16 * ix + tid] += sum; -#endif - } - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (uint s = 16; s > 0; s >>= 1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(); - } - if (tid == 0) { - data_d[d_offset + row] = D_TYPE(tmp[0]); - } -} diff --git a/ggml/src/vulkan-shaders/scale.comp b/ggml/src/vulkan-shaders/scale.comp deleted file mode 100644 index 5cd2f668d..000000000 --- a/ggml/src/vulkan-shaders/scale.comp +++ /dev/null @@ -1,14 +0,0 @@ -#version 450 - -#include "types.comp" -#include "generic_unary_head.comp" - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) * FLOAT_TYPE(p.param1)); -} diff --git a/ggml/src/vulkan-shaders/sin.comp b/ggml/src/vulkan-shaders/sin.comp deleted file mode 100644 index 7faf9be93..000000000 --- a/ggml/src/vulkan-shaders/sin.comp +++ /dev/null @@ -1,15 +0,0 @@ -#version 450 - -#include "types.comp" -#include "generic_unary_head.comp" - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]); - data_d[p.d_offset + dst_idx(idx)] = D_TYPE(sin(val)); -} diff --git a/ggml/src/vulkan-shaders/soft_max.comp b/ggml/src/vulkan-shaders/soft_max.comp deleted file mode 100644 index 0bd51ecab..000000000 --- a/ggml/src/vulkan-shaders/soft_max.comp +++ /dev/null @@ -1,106 +0,0 @@ -#version 450 - -#extension GL_EXT_shader_16bit_storage : require - -layout (push_constant) uniform parameter -{ - uint KX; - uint KY; - float scale; - float max_bias; - float m0; - float m1; - uint n_head_log2; -} p; - -#include "types.comp" - -#extension GL_EXT_control_flow_attributes : enable -#define BLOCK_SIZE 512 - -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; -layout (binding = 2) buffer D {D_TYPE data_d[];}; - -shared FLOAT_TYPE vals[BLOCK_SIZE]; - -void main() { - const uint tid = gl_LocalInvocationID.x; - const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; - const uint rowy = rowx % p.KY; - - float slope = 1.0f; - - // ALiBi - if (p.max_bias > 0.0f) { - const uint h = rowx/p.KY; // head index - - const float base = h < p.n_head_log2 ? p.m0 : p.m1; - const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; - - slope = pow(base, exp); - } - - // Find max - FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000); - - [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) { - const uint col = col0 + tid; - - if (col >= p.KX) { - break; - } - - max_val = max(max_val, FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f))); - } - vals[tid] = max_val; - - barrier(); - [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { - if (tid < s) { - vals[tid] = max(vals[tid], vals[tid + s]); - } - barrier(); - } - - max_val = vals[0]; - barrier(); - - // Sum up values - vals[tid] = FLOAT_TYPE(0.0f); - - [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) { - const uint col = col0 + tid; - - if (col >= p.KX) { - break; - } - - const uint i = rowx * p.KX + col; - const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val); - vals[tid] += val; - data_d[i] = D_TYPE(val); - } - - barrier(); - [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { - if (tid < s) { - vals[tid] += vals[tid + s]; - } - barrier(); - } - - const D_TYPE divisor = D_TYPE(vals[0]); - - [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) { - const uint col = col0 + tid; - - if (col >= p.KX) { - break; - } - - data_d[rowx*p.KX + col] /= divisor; - } -} diff --git a/ggml/src/vulkan-shaders/square.comp b/ggml/src/vulkan-shaders/square.comp deleted file mode 100644 index 1fa118c99..000000000 --- a/ggml/src/vulkan-shaders/square.comp +++ /dev/null @@ -1,15 +0,0 @@ -#version 450 - -#include "types.comp" -#include "generic_unary_head.comp" - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]); - data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val * val); -} diff --git a/ggml/src/vulkan-shaders/types.comp b/ggml/src/vulkan-shaders/types.comp deleted file mode 100644 index 21dce72fc..000000000 --- a/ggml/src/vulkan-shaders/types.comp +++ /dev/null @@ -1,200 +0,0 @@ -#if !defined(DATA_A_F32) && !defined(DATA_A_F16) -#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require -#endif - -#if defined(DATA_A_F32) -#define QUANT_K 1 -#define QUANT_R 1 - -#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 -#define A_TYPE float -#elif LOAD_VEC_A == 4 -#define A_TYPE vec4 -#elif LOAD_VEC_A == 8 -#define A_TYPE mat2x4 -#endif -#endif - -#if defined(DATA_A_F16) -#define QUANT_K 1 -#define QUANT_R 1 - -#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 -#define A_TYPE float16_t -#elif LOAD_VEC_A == 4 -#define A_TYPE f16vec4 -#elif LOAD_VEC_A == 8 -#define A_TYPE f16mat2x4 -#endif -#endif - -#if defined(DATA_A_Q4_0) -#extension GL_EXT_shader_16bit_storage : require -#define QUANT_K 32 -#define QUANT_R 2 - -struct block_q4_0 -{ - float16_t d; - uint8_t qs[16]; -}; - -#define A_TYPE block_q4_0 -#endif - -#if defined(DATA_A_Q4_1) -#extension GL_EXT_shader_16bit_storage : require -#define QUANT_K 32 -#define QUANT_R 2 - -struct block_q4_1 -{ - float16_t d; - float16_t m; - uint8_t qs[16]; -}; - -#define A_TYPE block_q4_1 -#endif - -#if defined(DATA_A_Q5_0) -#extension GL_EXT_shader_16bit_storage : require -#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require -#define QUANT_K 32 -#define QUANT_R 2 - -struct block_q5_0 -{ - float16_t d; - uint16_t qh[2]; - uint8_t qs[16]; -}; - -#define A_TYPE block_q5_0 -#endif - -#if defined(DATA_A_Q5_1) -#extension GL_EXT_shader_16bit_storage : require -#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require -#define QUANT_K 32 -#define QUANT_R 2 - -struct block_q5_1 -{ - float16_t d; - float16_t m; - uint qh; - uint8_t qs[16]; -}; - -#define A_TYPE block_q5_1 -#endif - -#if defined(DATA_A_Q8_0) -#extension GL_EXT_shader_16bit_storage : require -#define QUANT_K 32 -#define QUANT_R 1 - -struct block_q8_0 -{ - float16_t d; - int8_t qs[32]; -}; - -#define A_TYPE block_q8_0 -#endif - -// K-quants -#if defined(DATA_A_Q2_K) -#extension GL_EXT_shader_16bit_storage : require -#define QUANT_K 256 - -struct block_q2_K -{ - uint8_t scales[QUANT_K/16]; - uint8_t qs[QUANT_K/4]; - f16vec2 d; -}; - -#define A_TYPE block_q2_K -#endif - -#if defined(DATA_A_Q3_K) -#extension GL_EXT_shader_16bit_storage : require -#define QUANT_K 256 - -struct block_q3_K -{ - uint8_t hmask[QUANT_K/8]; - uint8_t qs[QUANT_K/4]; - uint8_t scales[12]; - float16_t d; -}; - -#define A_TYPE block_q3_K -#endif - -#if defined(DATA_A_Q4_K) -#extension GL_EXT_shader_16bit_storage : require -#define QUANT_K 256 - -struct block_q4_K -{ - f16vec2 d; - uint8_t scales[3*QUANT_K/64]; - uint8_t qs[QUANT_K/2]; -}; - -#define A_TYPE block_q4_K -#endif - -#if defined(DATA_A_Q5_K) -#extension GL_EXT_shader_16bit_storage : require -#define QUANT_K 256 - -struct block_q5_K -{ - f16vec2 d; - uint8_t scales[12]; - uint8_t qh[QUANT_K/8]; - uint8_t qs[QUANT_K/2]; -}; - -#define A_TYPE block_q5_K -#endif - -#if defined(DATA_A_Q6_K) -#extension GL_EXT_shader_16bit_storage : require -#define QUANT_K 256 - -struct block_q6_K -{ - uint8_t ql[QUANT_K/2]; - uint8_t qh[QUANT_K/4]; - int8_t scales[QUANT_K/16]; - float16_t d; -}; - -#define A_TYPE block_q6_K -#endif - -// IQuants - -#if defined(DATA_A_IQ4_NL) -#extension GL_EXT_shader_16bit_storage : require -#define QUANT_K 32 -#define QUANT_R 2 - -struct block_iq4_nl -{ - float16_t d; - uint8_t qs[QUANT_K/2]; -}; - -#define A_TYPE block_iq4_nl - -const int8_t kvalues_iq4nl[16] = { - int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), - int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113) -}; -#endif diff --git a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp deleted file mode 100644 index 1bd1b6f67..000000000 --- a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +++ /dev/null @@ -1,600 +0,0 @@ - - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef _WIN32 - #include - #include // For _mkdir on Windows - #include // For std::replace on w64devkit -#else - #include - #include - #include -#endif - -#define ASYNCIO_CONCURRENCY 64 - -std::mutex lock; -std::vector> shader_fnames; - -std::string GLSLC = "glslc"; -std::string input_dir = "vulkan-shaders"; -std::string output_dir = "/tmp"; -std::string target_hpp = "ggml-vulkan-shaders.hpp"; -std::string target_cpp = "ggml-vulkan-shaders.cpp"; -bool no_clean = false; - -const std::vector type_names = { - "f32", - "f16", - "q4_0", - "q4_1", - "q5_0", - "q5_1", - "q8_0", - "q2_k", - "q3_k", - "q4_k", - "q5_k", - "q6_k", - "iq4_nl" -}; - -void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) { -#ifdef _WIN32 - HANDLE stdout_read, stdout_write; - HANDLE stderr_read, stderr_write; - SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE }; - - if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) || - !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) { - throw std::runtime_error("Failed to create stdout pipe"); - } - - if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) || - !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) { - throw std::runtime_error("Failed to create stderr pipe"); - } - - PROCESS_INFORMATION pi; - STARTUPINFOA si = { sizeof(STARTUPINFOA) }; - si.dwFlags = STARTF_USESTDHANDLES; - si.hStdOutput = stdout_write; - si.hStdError = stderr_write; - - std::vector cmd(command.begin(), command.end()); - cmd.push_back('\0'); - - if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) { - throw std::runtime_error("Failed to create process"); - } - - CloseHandle(stdout_write); - CloseHandle(stderr_write); - - std::array buffer; - DWORD bytes_read; - - while (ReadFile(stdout_read, buffer.data(), buffer.size(), &bytes_read, NULL) && bytes_read > 0) { - stdout_str.append(buffer.data(), bytes_read); - } - - while (ReadFile(stderr_read, buffer.data(), buffer.size(), &bytes_read, NULL) && bytes_read > 0) { - stderr_str.append(buffer.data(), bytes_read); - } - - CloseHandle(stdout_read); - CloseHandle(stderr_read); - WaitForSingleObject(pi.hProcess, INFINITE); - CloseHandle(pi.hProcess); - CloseHandle(pi.hThread); -#else -int stdout_pipe[2]; - int stderr_pipe[2]; - - if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) { - throw std::runtime_error("Failed to create pipes"); - } - - pid_t pid = fork(); - if (pid < 0) { - throw std::runtime_error("Failed to fork process"); - } - - if (pid == 0) { - close(stdout_pipe[0]); - close(stderr_pipe[0]); - dup2(stdout_pipe[1], STDOUT_FILENO); - dup2(stderr_pipe[1], STDERR_FILENO); - close(stdout_pipe[1]); - close(stderr_pipe[1]); - execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr); - _exit(EXIT_FAILURE); - } else { - close(stdout_pipe[1]); - close(stderr_pipe[1]); - - std::array buffer; - ssize_t bytes_read; - - while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) { - stdout_str.append(buffer.data(), bytes_read); - } - - while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) { - stderr_str.append(buffer.data(), bytes_read); - } - - close(stdout_pipe[0]); - close(stderr_pipe[0]); - waitpid(pid, nullptr, 0); - } -#endif -} - -bool directory_exists(const std::string& path) { - struct stat info; - if (stat(path.c_str(), &info) != 0) { - return false; // Path doesn't exist or can't be accessed - } - return (info.st_mode & S_IFDIR) != 0; // Check if it is a directory -} - -bool create_directory(const std::string& path) { -#ifdef _WIN32 - return _mkdir(path.c_str()) == 0 || errno == EEXIST; // EEXIST means the directory already exists -#else - return mkdir(path.c_str(), 0755) == 0 || errno == EEXIST; // 0755 is the directory permissions -#endif -} - -std::string to_uppercase(const std::string& input) { - std::string result = input; - for (char& c : result) { - c = std::toupper(c); - } - return result; -} - -bool string_ends_with(const std::string& str, const std::string& suffix) { - if (suffix.size() > str.size()) { - return false; - } - return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); -} - -static const char path_separator = '/'; - -std::string join_paths(const std::string& path1, const std::string& path2) { - return path1 + path_separator + path2; -} - -std::string basename(const std::string &path) { - return path.substr(path.find_last_of("/\\") + 1); -} - -void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true) { - std::string name = _name + (fp16 ? "" : "_fp32"); - std::string out_fname = join_paths(output_dir, name + ".spv"); - std::string in_path = join_paths(input_dir, in_fname); - - #ifdef _WIN32 - std::vector cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""}; - #else - std::vector cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o", out_fname}; - #endif - - #ifdef GGML_VULKAN_SHADER_DEBUG_INFO - cmd.push_back("-g"); - #endif - - for (const auto& define : defines) { - cmd.push_back("-D" + define.first + "=" + define.second); - } - - std::string command; - for (const auto& part : cmd) { - command += part + " "; - } - - std::string stdout_str, stderr_str; - try { - // std::cout << "Executing command: "; - // for (const auto& part : cmd) { - // std::cout << part << " "; - // } - // std::cout << std::endl; - - execute_command(command, stdout_str, stderr_str); - if (!stderr_str.empty()) { - std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl; - return; - } - - std::lock_guard guard(lock); - shader_fnames.push_back(std::make_pair(name, out_fname)); - } catch (const std::exception& e) { - std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl; - } -} - -std::map merge_maps(const std::map& a, const std::map& b) { - std::map result = a; - result.insert(b.begin(), b.end()); - return result; -} - -void matmul_shaders(std::vector>& tasks, bool fp16, bool matmul_id) { - std::string load_vec = fp16 ? "8" : "4"; - std::string aligned_b_type_f32 = fp16 ? "mat2x4" : "vec4"; - std::string aligned_b_type_f16 = fp16 ? "f16mat2x4" : "f16vec4"; - - std::map base_dict = {{"FLOAT_TYPE", fp16 ? "float16_t" : "float"}}; - std::string shader_name = "matmul"; - - if (matmul_id) { - base_dict["MUL_MAT_ID"] = "1"; - shader_name = "matmul_id"; - } - - if (fp16) { - base_dict["FLOAT16"] = "1"; - } - - // Shaders with f16 B_TYPE - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv(shader_name + "_f32_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16); - })); - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv(shader_name + "_f32_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16); - })); - - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv(shader_name + "_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16); - })); - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv(shader_name + "_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16); - })); - - for (const auto& tname : type_names) { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - // For unaligned, load one at a time for f32/f16, or two at a time for quants - std::string load_vec_a_unaligned = (tname == "f32" || tname == "f16") ? "1" : "2"; - // For aligned matmul loads - std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2"; - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16); - })); - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16); - })); - } -} - -void process_shaders(std::vector>& tasks) { - std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl; - std::map base_dict = {{"FLOAT_TYPE", "float"}}; - - for (const auto& fp16 : {false, true}) { - matmul_shaders(tasks, fp16, false); - matmul_shaders(tasks, fp16, true); - } - - for (const auto& tname : type_names) { - // mul mat vec - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; - - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); - })); - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); - })); - - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); - })); - - // Dequant shaders - if (tname != "f16") { - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}})); - })); - } - - if (!string_ends_with(tname, "_k")) { - shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp"; - - if (tname == "f16") { - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); - })); - } else { - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}); - })); - } - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv("get_rows_" + tname + "_f32", shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}); - })); - } - } - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - })); - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - })); - - // Norms - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); - })); - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); - })); - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); - })); - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - })); - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); - })); - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); - })); - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - })); - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); - })); - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - })); - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); - })); - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - })); - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - })); - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - })); - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - })); - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - })); - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - })); - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - })); - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - })); - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - })); - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - })); - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); - })); - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}); - })); - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - })); - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - })); - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - })); - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - })); - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - })); - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - })); - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - })); - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - })); - - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); - })); - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); - })); - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - })); - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - })); - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - })); - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); - })); - - tasks.push_back(std::async(std::launch::async, [] { - string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); - })); - - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); - })); - - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); - })); - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); - })); - - tasks.push_back(std::async(std::launch::async, [=] { - string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); - })); -} - -void write_output_files() { - FILE* hdr = fopen(target_hpp.c_str(), "w"); - FILE* src = fopen(target_cpp.c_str(), "w"); - - fprintf(hdr, "#include \n\n"); - fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str()); - - for (const auto& pair : shader_fnames) { - const std::string& name = pair.first; - #ifdef _WIN32 - std::string path = pair.second; - std::replace(path.begin(), path.end(), '/', '\\' ); - #else - const std::string& path = pair.second; - #endif - - FILE* spv = fopen(path.c_str(), "rb"); - if (!spv) { - std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n"; - continue; - } - - fseek(spv, 0, SEEK_END); - size_t size = ftell(spv); - fseek(spv, 0, SEEK_SET); - - std::vector data(size); - size_t read_size = fread(data.data(), 1, size, spv); - fclose(spv); - if (read_size != size) { - std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n"; - continue; - } - - fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size); - fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size); - - fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size); - for (size_t i = 0; i < size; ++i) { - fprintf(src, "0x%02x,", data[i]); - if ((i + 1) % 12 == 0) fprintf(src, "\n"); - } - fprintf(src, "\n};\n\n"); - - if (!no_clean) { - std::remove(path.c_str()); - } - } - - fclose(hdr); - fclose(src); -} - -int main(int argc, char** argv) { - std::map args; - for (int i = 1; i < argc; i += 2) { - if (i + 1 < argc) { - args[argv[i]] = argv[i + 1]; - } - } - - if (args.find("--glslc") != args.end()) { - GLSLC = args["--glslc"]; // Path to glslc - } - if (args.find("--input-dir") != args.end()) { - input_dir = args["--input-dir"]; // Directory containing shader sources - } - if (args.find("--output-dir") != args.end()) { - output_dir = args["--output-dir"]; // Directory for containing SPIR-V output - } - if (args.find("--target-hpp") != args.end()) { - target_hpp = args["--target-hpp"]; // Path to generated header file - } - if (args.find("--target-cpp") != args.end()) { - target_cpp = args["--target-cpp"]; // Path to generated cpp file - } - if (args.find("--no-clean") != args.end()) { - no_clean = true; // Keep temporary SPIR-V files in output-dir after build - } - - if (!directory_exists(input_dir)) { - std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl; - return EXIT_FAILURE; - } - - if (!directory_exists(output_dir)) { - if (!create_directory(output_dir)) { - std::cerr << "Error creating output directory: " << output_dir << "\n"; - return EXIT_FAILURE; - } - } - - std::vector> tasks; - process_shaders(tasks); - - for (auto& task : tasks) { - task.get(); - } - - write_output_files(); - - return EXIT_SUCCESS; -} diff --git a/gguf-py/README.md b/gguf-py/README.md index 24af96a17..2e513633d 100644 --- a/gguf-py/README.md +++ b/gguf-py/README.md @@ -15,13 +15,15 @@ pip install gguf [examples/writer.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/examples/writer.py) — Generates `example.gguf` in the current directory to demonstrate generating a GGUF file. Note that this file cannot be used as a model. -[scripts/gguf_dump.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf_dump.py) — Dumps a GGUF file's metadata to the console. +[examples/reader.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/examples/reader.py) — Extracts and displays key-value pairs and tensor details from a GGUF file in a readable format. -[scripts/gguf_set_metadata.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf_set_metadata.py) — Allows changing simple metadata values in a GGUF file by key. +[gguf/scripts/gguf_dump.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_dump.py) — Dumps a GGUF file's metadata to the console. -[scripts/gguf_convert_endian.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf_convert_endian.py) — Allows converting the endianness of GGUF files. +[gguf/scripts/gguf_set_metadata.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_set_metadata.py) — Allows changing simple metadata values in a GGUF file by key. -[scripts/gguf_new_metadata.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/scripts/gguf_new_metadata.py) — Copies a GGUF file with added/modified/removed metadata values. +[gguf/scripts/gguf_convert_endian.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_convert_endian.py) — Allows converting the endianness of GGUF files. + +[gguf/scripts/gguf_new_metadata.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_new_metadata.py) — Copies a GGUF file with added/modified/removed metadata values. ## Development Maintainers who participate in development of this package are advised to install it in editable mode: diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index ae90d70a6..b741ada81 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -64,20 +64,33 @@ class Keys: BASE_MODEL_AUTHOR = "general.base_model.{id}.author" BASE_MODEL_VERSION = "general.base_model.{id}.version" BASE_MODEL_ORGANIZATION = "general.base_model.{id}.organization" + BASE_MODEL_DESCRIPTION = "general.base_model.{id}.description" BASE_MODEL_URL = "general.base_model.{id}.url" # Model Website/Paper BASE_MODEL_DOI = "general.base_model.{id}.doi" BASE_MODEL_UUID = "general.base_model.{id}.uuid" BASE_MODEL_REPO_URL = "general.base_model.{id}.repo_url" # Model Source Repository (git/svn/etc...) + # Dataset Source + DATASET_COUNT = "general.dataset.count" + DATASET_NAME = "general.dataset.{id}.name" + DATASET_AUTHOR = "general.dataset.{id}.author" + DATASET_VERSION = "general.dataset.{id}.version" + DATASET_ORGANIZATION = "general.dataset.{id}.organization" + DATASET_DESCRIPTION = "general.dataset.{id}.description" + DATASET_URL = "general.dataset.{id}.url" # Model Website/Paper + DATASET_DOI = "general.dataset.{id}.doi" + DATASET_UUID = "general.dataset.{id}.uuid" + DATASET_REPO_URL = "general.dataset.{id}.repo_url" # Model Source Repository (git/svn/etc...) + # Array based KV stores TAGS = "general.tags" LANGUAGES = "general.languages" - DATASETS = "general.datasets" class LLM: VOCAB_SIZE = "{arch}.vocab_size" CONTEXT_LENGTH = "{arch}.context_length" EMBEDDING_LENGTH = "{arch}.embedding_length" + FEATURES_LENGTH = "{arch}.features_length" BLOCK_COUNT = "{arch}.block_count" LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count" FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" @@ -89,14 +102,20 @@ class Keys: EXPERT_USED_COUNT = "{arch}.expert_used_count" EXPERT_SHARED_COUNT = "{arch}.expert_shared_count" EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale" + EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm" + EXPERT_GATING_FUNC = "{arch}.expert_gating_func" POOLING_TYPE = "{arch}.pooling_type" LOGIT_SCALE = "{arch}.logit_scale" DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping" FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping" + SWIN_NORM = "{arch}.swin_norm" RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers" TIME_MIX_EXTRA_DIM = "{arch}.time_mix_extra_dim" TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim" + RESIDUAL_SCALE = "{arch}.residual_scale" + EMBEDDING_SCALE = "{arch}.embedding_scale" + TOKEN_SHIFT_COUNT = "{arch}.token_shift_count" class Attention: HEAD_COUNT = "{arch}.attention.head_count" @@ -107,14 +126,18 @@ class Keys: VALUE_LENGTH = "{arch}.attention.value_length" LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon" LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon" + GROUPNORM_EPS = "{arch}.attention.group_norm_epsilon" + GROUPNORM_GROUPS = "{arch}.attention.group_norm_groups" CAUSAL = "{arch}.attention.causal" Q_LORA_RANK = "{arch}.attention.q_lora_rank" KV_LORA_RANK = "{arch}.attention.kv_lora_rank" REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count" SLIDING_WINDOW = "{arch}.attention.sliding_window" + SCALE = "{arch}.attention.scale" class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" + DIMENSION_SECTIONS = "{arch}.rope.dimension_sections" FREQ_BASE = "{arch}.rope.freq_base" SCALING_TYPE = "{arch}.rope.scaling.type" SCALING_FACTOR = "{arch}.rope.scaling.factor" @@ -138,6 +161,14 @@ class Keys: class WKV: HEAD_SIZE = "{arch}.wkv.head_size" + class PosNet: + EMBEDDING_LENGTH = "{arch}.posnet.embedding_length" + BLOCK_COUNT = "{arch}.posnet.block_count" + + class ConvNext: + EMBEDDING_LENGTH = "{arch}.convnext.embedding_length" + BLOCK_COUNT = "{arch}.convnext.block_count" + class Tokenizer: MODEL = "tokenizer.ggml.model" PRE = "tokenizer.ggml.pre" @@ -148,10 +179,11 @@ class Keys: MERGES = "tokenizer.ggml.merges" BOS_ID = "tokenizer.ggml.bos_token_id" EOS_ID = "tokenizer.ggml.eos_token_id" + EOT_ID = "tokenizer.ggml.eot_token_id" + EOM_ID = "tokenizer.ggml.eom_token_id" UNK_ID = "tokenizer.ggml.unknown_token_id" SEP_ID = "tokenizer.ggml.seperator_token_id" PAD_ID = "tokenizer.ggml.padding_token_id" - CLS_ID = "tokenizer.ggml.cls_token_id" MASK_ID = "tokenizer.ggml.mask_token_id" ADD_BOS = "tokenizer.ggml.add_bos_token" ADD_EOS = "tokenizer.ggml.add_eos_token" @@ -164,11 +196,16 @@ class Keys: CHAT_TEMPLATE_N = "tokenizer.chat_template.{name}" CHAT_TEMPLATES = "tokenizer.chat_templates" # FIM/Infill special tokens constants + FIM_PRE_ID = "tokenizer.ggml.fim_pre_token_id" + FIM_SUF_ID = "tokenizer.ggml.fim_suf_token_id" + FIM_MID_ID = "tokenizer.ggml.fim_mid_token_id" + FIM_PAD_ID = "tokenizer.ggml.fim_pad_token_id" + FIM_REP_ID = "tokenizer.ggml.fim_rep_token_id" + FIM_SEP_ID = "tokenizer.ggml.fim_sep_token_id" + # deprecated: PREFIX_ID = "tokenizer.ggml.prefix_token_id" SUFFIX_ID = "tokenizer.ggml.suffix_token_id" MIDDLE_ID = "tokenizer.ggml.middle_token_id" - EOT_ID = "tokenizer.ggml.eot_token_id" - EOM_ID = "tokenizer.ggml.eom_token_id" class Adapter: TYPE = "adapter.type" @@ -192,50 +229,63 @@ class GGUFType: class MODEL_ARCH(IntEnum): - LLAMA = auto() - FALCON = auto() - BAICHUAN = auto() - GROK = auto() - GPT2 = auto() - GPTJ = auto() - GPTNEOX = auto() - MPT = auto() - STARCODER = auto() - REFACT = auto() - BERT = auto() - NOMIC_BERT = auto() - JINA_BERT_V2 = auto() - BLOOM = auto() - STABLELM = auto() - QWEN = auto() - QWEN2 = auto() - QWEN2MOE = auto() - PHI2 = auto() - PHI3 = auto() - PLAMO = auto() - CODESHELL = auto() - ORION = auto() - INTERNLM2 = auto() - MINICPM = auto() - GEMMA = auto() - GEMMA2 = auto() - STARCODER2 = auto() - RWKV6 = auto() - MAMBA = auto() - XVERSE = auto() - COMMAND_R = auto() - DBRX = auto() - OLMO = auto() - OPENELM = auto() - ARCTIC = auto() - DEEPSEEK2 = auto() - CHATGLM = auto() - BITNET = auto() - T5 = auto() - T5ENCODER = auto() - JAIS = auto() - NEMOTRON = auto() - EXAONE = auto() + LLAMA = auto() + DECI = auto() + FALCON = auto() + BAICHUAN = auto() + GROK = auto() + GPT2 = auto() + GPTJ = auto() + GPTNEOX = auto() + MPT = auto() + STARCODER = auto() + REFACT = auto() + BERT = auto() + NOMIC_BERT = auto() + JINA_BERT_V2 = auto() + BLOOM = auto() + STABLELM = auto() + QWEN = auto() + QWEN2 = auto() + QWEN2MOE = auto() + QWEN2VL = auto() + PHI2 = auto() + PHI3 = auto() + PHIMOE = auto() + PLAMO = auto() + CODESHELL = auto() + ORION = auto() + INTERNLM2 = auto() + MINICPM = auto() + MINICPM3 = auto() + GEMMA = auto() + GEMMA2 = auto() + STARCODER2 = auto() + RWKV6 = auto() + RWKV6QWEN2 = auto() + MAMBA = auto() + XVERSE = auto() + COMMAND_R = auto() + COHERE2 = auto() + DBRX = auto() + OLMO = auto() + OLMO2 = auto() + OLMOE = auto() + OPENELM = auto() + ARCTIC = auto() + DEEPSEEK = auto() + DEEPSEEK2 = auto() + CHATGLM = auto() + BITNET = auto() + T5 = auto() + T5ENCODER = auto() + JAIS = auto() + NEMOTRON = auto() + EXAONE = auto() + GRANITE = auto() + GRANITE_MOE = auto() + CHAMELEON = auto() + WAVTOKENIZER_DEC = auto() class MODEL_TENSOR(IntEnum): @@ -274,6 +324,7 @@ class MODEL_TENSOR(IntEnum): FFN_GATE_SHEXP = auto() FFN_DOWN_SHEXP = auto() FFN_UP_SHEXP = auto() + FFN_EXP_PROBS_B = auto() ATTN_Q_NORM = auto() ATTN_K_NORM = auto() LAYER_OUT_NORM = auto() @@ -291,6 +342,7 @@ class MODEL_TENSOR(IntEnum): TIME_MIX_LERP_V = auto() TIME_MIX_LERP_R = auto() TIME_MIX_LERP_G = auto() + TIME_MIX_LERP_FUSED = auto() TIME_MIX_LERP_W = auto() TIME_MIX_FIRST = auto() TIME_MIX_DECAY = auto() @@ -343,53 +395,84 @@ class MODEL_TENSOR(IntEnum): ENC_FFN_DOWN = auto() ENC_FFN_UP = auto() ENC_OUTPUT_NORM = auto() + CLS = auto() # classifier + CLS_OUT = auto() # classifier output projection + CONV1D = auto() + CONVNEXT_DW = auto() + CONVNEXT_NORM = auto() + CONVNEXT_PW1 = auto() + CONVNEXT_PW2 = auto() + CONVNEXT_GAMMA = auto() + POSNET_CONV1 = auto() + POSNET_CONV2 = auto() + POSNET_NORM = auto() + POSNET_NORM1 = auto() + POSNET_NORM2 = auto() + POSNET_ATTN_NORM = auto() + POSNET_ATTN_Q = auto() + POSNET_ATTN_K = auto() + POSNET_ATTN_V = auto() + POSNET_ATTN_OUT = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { - MODEL_ARCH.LLAMA: "llama", - MODEL_ARCH.FALCON: "falcon", - MODEL_ARCH.BAICHUAN: "baichuan", - MODEL_ARCH.GROK: "grok", - MODEL_ARCH.GPT2: "gpt2", - MODEL_ARCH.GPTJ: "gptj", - MODEL_ARCH.GPTNEOX: "gptneox", - MODEL_ARCH.MPT: "mpt", - MODEL_ARCH.STARCODER: "starcoder", - MODEL_ARCH.REFACT: "refact", - MODEL_ARCH.BERT: "bert", - MODEL_ARCH.NOMIC_BERT: "nomic-bert", - MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2", - MODEL_ARCH.BLOOM: "bloom", - MODEL_ARCH.STABLELM: "stablelm", - MODEL_ARCH.QWEN: "qwen", - MODEL_ARCH.QWEN2: "qwen2", - MODEL_ARCH.QWEN2MOE: "qwen2moe", - MODEL_ARCH.PHI2: "phi2", - MODEL_ARCH.PHI3: "phi3", - MODEL_ARCH.PLAMO: "plamo", - MODEL_ARCH.CODESHELL: "codeshell", - MODEL_ARCH.ORION: "orion", - MODEL_ARCH.INTERNLM2: "internlm2", - MODEL_ARCH.MINICPM: "minicpm", - MODEL_ARCH.GEMMA: "gemma", - MODEL_ARCH.GEMMA2: "gemma2", - MODEL_ARCH.STARCODER2: "starcoder2", - MODEL_ARCH.RWKV6: "rwkv6", - MODEL_ARCH.MAMBA: "mamba", - MODEL_ARCH.XVERSE: "xverse", - MODEL_ARCH.COMMAND_R: "command-r", - MODEL_ARCH.DBRX: "dbrx", - MODEL_ARCH.OLMO: "olmo", - MODEL_ARCH.OPENELM: "openelm", - MODEL_ARCH.ARCTIC: "arctic", - MODEL_ARCH.DEEPSEEK2: "deepseek2", - MODEL_ARCH.CHATGLM: "chatglm", - MODEL_ARCH.BITNET: "bitnet", - MODEL_ARCH.T5: "t5", - MODEL_ARCH.T5ENCODER: "t5encoder", - MODEL_ARCH.JAIS: "jais", - MODEL_ARCH.NEMOTRON: "nemotron", - MODEL_ARCH.EXAONE: "exaone", + MODEL_ARCH.LLAMA: "llama", + MODEL_ARCH.DECI: "deci", + MODEL_ARCH.FALCON: "falcon", + MODEL_ARCH.BAICHUAN: "baichuan", + MODEL_ARCH.GROK: "grok", + MODEL_ARCH.GPT2: "gpt2", + MODEL_ARCH.GPTJ: "gptj", + MODEL_ARCH.GPTNEOX: "gptneox", + MODEL_ARCH.MPT: "mpt", + MODEL_ARCH.STARCODER: "starcoder", + MODEL_ARCH.REFACT: "refact", + MODEL_ARCH.BERT: "bert", + MODEL_ARCH.NOMIC_BERT: "nomic-bert", + MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2", + MODEL_ARCH.BLOOM: "bloom", + MODEL_ARCH.STABLELM: "stablelm", + MODEL_ARCH.QWEN: "qwen", + MODEL_ARCH.QWEN2: "qwen2", + MODEL_ARCH.QWEN2MOE: "qwen2moe", + MODEL_ARCH.QWEN2VL: "qwen2vl", + MODEL_ARCH.PHI2: "phi2", + MODEL_ARCH.PHI3: "phi3", + MODEL_ARCH.PHIMOE: "phimoe", + MODEL_ARCH.PLAMO: "plamo", + MODEL_ARCH.CODESHELL: "codeshell", + MODEL_ARCH.ORION: "orion", + MODEL_ARCH.INTERNLM2: "internlm2", + MODEL_ARCH.MINICPM: "minicpm", + MODEL_ARCH.MINICPM3: "minicpm3", + MODEL_ARCH.GEMMA: "gemma", + MODEL_ARCH.GEMMA2: "gemma2", + MODEL_ARCH.STARCODER2: "starcoder2", + MODEL_ARCH.RWKV6: "rwkv6", + MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2", + MODEL_ARCH.MAMBA: "mamba", + MODEL_ARCH.XVERSE: "xverse", + MODEL_ARCH.COMMAND_R: "command-r", + MODEL_ARCH.COHERE2: "cohere2", + MODEL_ARCH.DBRX: "dbrx", + MODEL_ARCH.OLMO: "olmo", + MODEL_ARCH.OLMO2: "olmo2", + MODEL_ARCH.OLMOE: "olmoe", + MODEL_ARCH.OPENELM: "openelm", + MODEL_ARCH.ARCTIC: "arctic", + MODEL_ARCH.DEEPSEEK: "deepseek", + MODEL_ARCH.DEEPSEEK2: "deepseek2", + MODEL_ARCH.CHATGLM: "chatglm", + MODEL_ARCH.BITNET: "bitnet", + MODEL_ARCH.T5: "t5", + MODEL_ARCH.T5ENCODER: "t5encoder", + MODEL_ARCH.JAIS: "jais", + MODEL_ARCH.NEMOTRON: "nemotron", + MODEL_ARCH.EXAONE: "exaone", + MODEL_ARCH.GRANITE: "granite", + MODEL_ARCH.GRANITE_MOE: "granitemoe", + MODEL_ARCH.CHAMELEON: "chameleon", + MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -430,6 +513,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", + MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b", MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", @@ -445,6 +529,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v", MODEL_TENSOR.TIME_MIX_LERP_R: "blk.{bid}.time_mix_lerp_r", MODEL_TENSOR.TIME_MIX_LERP_G: "blk.{bid}.time_mix_lerp_g", + MODEL_TENSOR.TIME_MIX_LERP_FUSED: "blk.{bid}.time_mix_lerp_fused", MODEL_TENSOR.TIME_MIX_LERP_W: "blk.{bid}.time_mix_lerp_w", MODEL_TENSOR.TIME_MIX_FIRST: "blk.{bid}.time_mix_first", MODEL_TENSOR.TIME_MIX_DECAY: "blk.{bid}.time_mix_decay", @@ -497,6 +582,24 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down", MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up", MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm", + MODEL_TENSOR.CLS: "cls", + MODEL_TENSOR.CLS_OUT: "cls.output", + MODEL_TENSOR.CONV1D: "conv1d", + MODEL_TENSOR.CONVNEXT_DW: "convnext.{bid}.dw", + MODEL_TENSOR.CONVNEXT_NORM: "convnext.{bid}.norm", + MODEL_TENSOR.CONVNEXT_PW1: "convnext.{bid}.pw1", + MODEL_TENSOR.CONVNEXT_PW2: "convnext.{bid}.pw2", + MODEL_TENSOR.CONVNEXT_GAMMA: "convnext.{bid}.gamma", + MODEL_TENSOR.POSNET_CONV1: "posnet.{bid}.conv1", + MODEL_TENSOR.POSNET_CONV2: "posnet.{bid}.conv2", + MODEL_TENSOR.POSNET_NORM: "posnet.{bid}.norm", + MODEL_TENSOR.POSNET_NORM1: "posnet.{bid}.norm1", + MODEL_TENSOR.POSNET_NORM2: "posnet.{bid}.norm2", + MODEL_TENSOR.POSNET_ATTN_NORM: "posnet.{bid}.attn_norm", + MODEL_TENSOR.POSNET_ATTN_Q: "posnet.{bid}.attn_q", + MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k", + MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v", + MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -520,6 +623,26 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.DECI: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.GROK: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -606,6 +729,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, MODEL_TENSOR.LAYER_OUT_NORM, + MODEL_TENSOR.CLS, + MODEL_TENSOR.CLS_OUT, ], MODEL_ARCH.NOMIC_BERT: [ MODEL_TENSOR.TOKEN_EMBD, @@ -637,6 +762,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.LAYER_OUT_NORM, + MODEL_TENSOR.CLS, ], MODEL_ARCH.MPT: [ MODEL_TENSOR.TOKEN_EMBD, @@ -723,6 +849,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_UP, ], MODEL_ARCH.QWEN2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.QWEN2VL: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, @@ -800,6 +941,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FACTORS_LONG, + MODEL_TENSOR.ROPE_FACTORS_SHORT, MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, @@ -810,6 +953,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.PHIMOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FACTORS_LONG, + MODEL_TENSOR.ROPE_FACTORS_SHORT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.CODESHELL: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.POS_EMBD, @@ -859,6 +1020,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.OUTPUT, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ROPE_FACTORS_LONG, + MODEL_TENSOR.ROPE_FACTORS_SHORT, MODEL_TENSOR.ATTN_NORM, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, @@ -874,6 +1037,25 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.MINICPM3: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FACTORS_LONG, + MODEL_TENSOR.ROPE_FACTORS_SHORT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q_A, + MODEL_TENSOR.ATTN_Q_B, + MODEL_TENSOR.ATTN_KV_A_MQA, + MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_Q_A_NORM, + MODEL_TENSOR.ATTN_KV_A_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.GEMMA: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -932,6 +1114,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.TIME_MIX_LERP_R, MODEL_TENSOR.TIME_MIX_LERP_G, MODEL_TENSOR.TIME_MIX_LERP_W, + MODEL_TENSOR.TIME_MIX_LERP_FUSED, MODEL_TENSOR.TIME_MIX_FIRST, MODEL_TENSOR.TIME_MIX_DECAY, MODEL_TENSOR.TIME_MIX_DECAY_W1, @@ -948,6 +1131,35 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE, MODEL_TENSOR.CHANNEL_MIX_VALUE, ], + MODEL_ARCH.RWKV6QWEN2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.TIME_MIX_W1, + MODEL_TENSOR.TIME_MIX_W2, + MODEL_TENSOR.TIME_MIX_LERP_X, + MODEL_TENSOR.TIME_MIX_LERP_K, + MODEL_TENSOR.TIME_MIX_LERP_V, + MODEL_TENSOR.TIME_MIX_LERP_R, + MODEL_TENSOR.TIME_MIX_LERP_G, + MODEL_TENSOR.TIME_MIX_LERP_W, + MODEL_TENSOR.TIME_MIX_LERP_FUSED, + MODEL_TENSOR.TIME_MIX_FIRST, + MODEL_TENSOR.TIME_MIX_DECAY, + MODEL_TENSOR.TIME_MIX_DECAY_W1, + MODEL_TENSOR.TIME_MIX_DECAY_W2, + MODEL_TENSOR.TIME_MIX_KEY, + MODEL_TENSOR.TIME_MIX_VALUE, + MODEL_TENSOR.TIME_MIX_RECEPTANCE, + MODEL_TENSOR.TIME_MIX_GATE, + MODEL_TENSOR.TIME_MIX_LN, + MODEL_TENSOR.TIME_MIX_OUTPUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.MAMBA: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -991,6 +1203,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_K_NORM, MODEL_TENSOR.ATTN_Q_NORM, ], + MODEL_ARCH.COHERE2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.DBRX: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -1015,6 +1239,39 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.OLMO2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.FFN_POST_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.OLMOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + ], MODEL_ARCH.OPENELM: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -1049,6 +1306,29 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.DEEPSEEK: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], MODEL_ARCH.DEEPSEEK2: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -1075,6 +1355,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_GATE_SHEXP, MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, ], MODEL_ARCH.CHATGLM : [ MODEL_TENSOR.TOKEN_EMBD, @@ -1193,6 +1474,73 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.GRANITE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.GRANITE_MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], + MODEL_ARCH.CHAMELEON: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], + MODEL_ARCH.WAVTOKENIZER_DEC: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.CONV1D, + MODEL_TENSOR.CONVNEXT_DW, + MODEL_TENSOR.CONVNEXT_NORM, + MODEL_TENSOR.CONVNEXT_PW1, + MODEL_TENSOR.CONVNEXT_PW2, + MODEL_TENSOR.CONVNEXT_GAMMA, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.POSNET_CONV1, + MODEL_TENSOR.POSNET_CONV2, + MODEL_TENSOR.POSNET_NORM, + MODEL_TENSOR.POSNET_NORM1, + MODEL_TENSOR.POSNET_NORM2, + MODEL_TENSOR.POSNET_ATTN_NORM, + MODEL_TENSOR.POSNET_ATTN_Q, + MODEL_TENSOR.POSNET_ATTN_K, + MODEL_TENSOR.POSNET_ATTN_V, + MODEL_TENSOR.POSNET_ATTN_OUT, + ], # TODO } @@ -1202,6 +1550,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, ], + MODEL_ARCH.DECI: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], MODEL_ARCH.BAICHUAN: [ MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, @@ -1226,6 +1578,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, ], + MODEL_ARCH.DEEPSEEK: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], MODEL_ARCH.DEEPSEEK2: [ MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, @@ -1254,9 +1610,10 @@ class TokenType(IntEnum): class RopeScalingType(Enum): - NONE = 'none' - LINEAR = 'linear' - YARN = 'yarn' + NONE = 'none' + LINEAR = 'linear' + YARN = 'yarn' + LONGROPE = 'longrope' class PoolingType(IntEnum): @@ -1295,13 +1652,15 @@ class GGMLQuantizationType(IntEnum): F64 = 28 IQ1_M = 29 BF16 = 30 - Q4_0_4_4 = 31 - Q4_0_4_8 = 32 - Q4_0_8_8 = 33 TQ1_0 = 34 TQ2_0 = 35 +class ExpertGatingFuncType(IntEnum): + SOFTMAX = 1 + SIGMOID = 2 + + # TODO: add GGMLFileType from ggml_ftype in ggml.h @@ -1341,9 +1700,9 @@ class LlamaFileType(IntEnum): MOSTLY_IQ4_XS = 30 # except 1d tensors MOSTLY_IQ1_M = 31 # except 1d tensors MOSTLY_BF16 = 32 # except 1d tensors - MOSTLY_Q4_0_4_4 = 33 # except 1d tensors - MOSTLY_Q4_0_4_8 = 34 # except 1d tensors - MOSTLY_Q4_0_8_8 = 35 # except 1d tensors + # MOSTLY_Q4_0_4_4 = 33 # removed from gguf files, use Q4_0 and runtime repack + # MOSTLY_Q4_0_4_8 = 34 # removed from gguf files, use Q4_0 and runtime repack + # MOSTLY_Q4_0_8_8 = 35 # removed from gguf files, use Q4_0 and runtime repack MOSTLY_TQ1_0 = 36 # except 1d tensors MOSTLY_TQ2_0 = 37 # except 1d tensors @@ -1419,9 +1778,6 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { GGMLQuantizationType.F64: (1, 8), GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32), GGMLQuantizationType.BF16: (1, 2), - GGMLQuantizationType.Q4_0_4_4:(32, 2 + 16), - GGMLQuantizationType.Q4_0_4_8:(32, 2 + 16), - GGMLQuantizationType.Q4_0_8_8:(32, 2 + 16), GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13), GGMLQuantizationType.TQ2_0: (256, 2 + 64), } @@ -1482,15 +1838,23 @@ KEY_TOKENIZER_SCORES = Keys.Tokenizer.SCORES KEY_TOKENIZER_MERGES = Keys.Tokenizer.MERGES KEY_TOKENIZER_BOS_ID = Keys.Tokenizer.BOS_ID KEY_TOKENIZER_EOS_ID = Keys.Tokenizer.EOS_ID +KEY_TOKENIZER_EOT_ID = Keys.Tokenizer.EOT_ID +KEY_TOKENIZER_EOM_ID = Keys.Tokenizer.EOM_ID KEY_TOKENIZER_UNK_ID = Keys.Tokenizer.UNK_ID KEY_TOKENIZER_SEP_ID = Keys.Tokenizer.SEP_ID KEY_TOKENIZER_PAD_ID = Keys.Tokenizer.PAD_ID -KEY_TOKENIZER_CLS_ID = Keys.Tokenizer.CLS_ID KEY_TOKENIZER_MASK_ID = Keys.Tokenizer.MASK_ID KEY_TOKENIZER_HF_JSON = Keys.Tokenizer.HF_JSON KEY_TOKENIZER_RWKV = Keys.Tokenizer.RWKV -KEY_TOKENIZER_PRIFIX_ID = Keys.Tokenizer.PREFIX_ID + +KEY_TOKENIZER_FIM_PRE_ID = Keys.Tokenizer.FIM_PRE_ID +KEY_TOKENIZER_FIM_SUF_ID = Keys.Tokenizer.FIM_SUF_ID +KEY_TOKENIZER_FIM_MID_ID = Keys.Tokenizer.FIM_MID_ID +KEY_TOKENIZER_FIM_PAD_ID = Keys.Tokenizer.FIM_PAD_ID +KEY_TOKENIZER_FIM_REP_ID = Keys.Tokenizer.FIM_REP_ID +KEY_TOKENIZER_FIM_SEP_ID = Keys.Tokenizer.FIM_SEP_ID + +# deprecated +KEY_TOKENIZER_PREFIX_ID = Keys.Tokenizer.PREFIX_ID KEY_TOKENIZER_SUFFIX_ID = Keys.Tokenizer.SUFFIX_ID KEY_TOKENIZER_MIDDLE_ID = Keys.Tokenizer.MIDDLE_ID -KEY_TOKENIZER_EOT_ID = Keys.Tokenizer.EOT_ID -KEY_TOKENIZER_EOM_ID = Keys.Tokenizer.EOM_ID diff --git a/gguf-py/gguf/gguf_reader.py b/gguf-py/gguf/gguf_reader.py index e8e61abf8..e17a4e831 100644 --- a/gguf-py/gguf/gguf_reader.py +++ b/gguf-py/gguf/gguf_reader.py @@ -145,11 +145,10 @@ class GGUFReader: count = int(count) itemsize = int(np.empty([], dtype = dtype).itemsize) end_offs = offset + itemsize * count - return ( - self.data[offset:end_offs] - .view(dtype = dtype)[:count] - .newbyteorder(override_order or self.byte_order) - ) + arr = self.data[offset:end_offs].view(dtype=dtype)[:count] + if override_order is None: + return arr + return arr.view(arr.dtype.newbyteorder(override_order)) def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int: if field.name in self.fields: diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 3c95c2673..080d2b9dc 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -26,6 +26,7 @@ from .constants import ( RopeScalingType, PoolingType, TokenType, + ExpertGatingFuncType, ) from .quants import quant_shape_from_byte_shape @@ -568,6 +569,9 @@ class GGUFWriter: def add_base_model_organization(self, source_id: int, organization: str) -> None: self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization) + def add_base_model_description(self, source_id: int, description: str) -> None: + self.add_string(Keys.General.BASE_MODEL_DESCRIPTION.format(id=source_id), description) + def add_base_model_url(self, source_id: int, url: str) -> None: self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url) @@ -580,15 +584,42 @@ class GGUFWriter: def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None: self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url) + def add_dataset_count(self, source_count: int) -> None: + self.add_uint32(Keys.General.DATASET_COUNT, source_count) + + def add_dataset_name(self, source_id: int, name: str) -> None: + self.add_string(Keys.General.DATASET_NAME.format(id=source_id), name) + + def add_dataset_author(self, source_id: int, author: str) -> None: + self.add_string(Keys.General.DATASET_AUTHOR.format(id=source_id), author) + + def add_dataset_version(self, source_id: int, version: str) -> None: + self.add_string(Keys.General.DATASET_VERSION.format(id=source_id), version) + + def add_dataset_organization(self, source_id: int, organization: str) -> None: + self.add_string(Keys.General.DATASET_ORGANIZATION.format(id=source_id), organization) + + def add_dataset_description(self, source_id: int, description: str) -> None: + self.add_string(Keys.General.DATASET_DESCRIPTION.format(id=source_id), description) + + def add_dataset_url(self, source_id: int, url: str) -> None: + self.add_string(Keys.General.DATASET_URL.format(id=source_id), url) + + def add_dataset_doi(self, source_id: int, doi: str) -> None: + self.add_string(Keys.General.DATASET_DOI.format(id=source_id), doi) + + def add_dataset_uuid(self, source_id: int, uuid: str) -> None: + self.add_string(Keys.General.DATASET_UUID.format(id=source_id), uuid) + + def add_dataset_repo_url(self, source_id: int, repo_url: str) -> None: + self.add_string(Keys.General.DATASET_REPO_URL.format(id=source_id), repo_url) + def add_tags(self, tags: Sequence[str]) -> None: self.add_array(Keys.General.TAGS, tags) def add_languages(self, languages: Sequence[str]) -> None: self.add_array(Keys.General.LANGUAGES, languages) - def add_datasets(self, datasets: Sequence[str]) -> None: - self.add_array(Keys.General.DATASETS, datasets) - def add_tensor_data_layout(self, layout: str) -> None: self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout) @@ -601,6 +632,21 @@ class GGUFWriter: def add_embedding_length(self, length: int) -> None: self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length) + def add_features_length(self, length: int) -> None: + self.add_uint32(Keys.LLM.FEATURES_LENGTH.format(arch=self.arch), length) + + def add_posnet_embedding_length(self, length: int) -> None: + self.add_uint32(Keys.PosNet.EMBEDDING_LENGTH.format(arch=self.arch), length) + + def add_posnet_block_count(self, length: int) -> None: + self.add_uint32(Keys.PosNet.BLOCK_COUNT.format(arch=self.arch), length) + + def add_convnext_embedding_length(self, length: int) -> None: + self.add_uint32(Keys.ConvNext.EMBEDDING_LENGTH.format(arch=self.arch), length) + + def add_convnext_block_count(self, length: int) -> None: + self.add_uint32(Keys.ConvNext.BLOCK_COUNT.format(arch=self.arch), length) + def add_block_count(self, length: int) -> None: self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length) @@ -670,6 +716,15 @@ class GGUFWriter: def add_expert_weights_scale(self, value: float) -> None: self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value) + def add_expert_weights_norm(self, value: bool) -> None: + self.add_bool(Keys.LLM.EXPERT_WEIGHTS_NORM.format(arch=self.arch), value) + + def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None: + self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value) + + def add_swin_norm(self, value: bool) -> None: + self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value) + def add_rescale_every_n_layers(self, count: int) -> None: self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count) @@ -679,15 +734,30 @@ class GGUFWriter: def add_time_decay_extra_dim(self, dim: int) -> None: self.add_uint32(Keys.LLM.TIME_DECAY_EXTRA_DIM.format(arch=self.arch), dim) + def add_residual_scale(self, value: float) -> None: + self.add_float32(Keys.LLM.RESIDUAL_SCALE.format(arch=self.arch), value) + + def add_embedding_scale(self, value: float) -> None: + self.add_float32(Keys.LLM.EMBEDDING_SCALE.format(arch=self.arch), value) + def add_wkv_head_size(self, size: int) -> None: self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size) + def add_token_shift_count(self, count: int) -> None: + self.add_uint32(Keys.LLM.TOKEN_SHIFT_COUNT.format(arch=self.arch), count) + def add_layer_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) def add_layer_norm_rms_eps(self, value: float) -> None: self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value) + def add_group_norm_eps(self, value: float) -> None: + self.add_float32(Keys.Attention.GROUPNORM_EPS.format(arch=self.arch), value) + + def add_group_norm_groups(self, value: int) -> None: + self.add_uint32(Keys.Attention.GROUPNORM_GROUPS.format(arch=self.arch), value) + def add_causal_attention(self, value: bool) -> None: self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value) @@ -703,12 +773,18 @@ class GGUFWriter: def add_sliding_window(self, value: int) -> None: self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value) + def add_attention_scale(self, value: float) -> None: + self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value) + def add_pooling_type(self, value: PoolingType) -> None: self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value) def add_rope_dimension_count(self, count: int) -> None: self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count) + def add_rope_dimension_sections(self, dims: Sequence[int]) -> None: + self.add_array(Keys.Rope.DIMENSION_SECTIONS.format(arch=self.arch), dims) + def add_rope_freq_base(self, value: float) -> None: self.add_float32(Keys.Rope.FREQ_BASE.format(arch=self.arch), value) @@ -781,9 +857,6 @@ class GGUFWriter: def add_pad_token_id(self, id: int) -> None: self.add_uint32(Keys.Tokenizer.PAD_ID, id) - def add_cls_token_id(self, id: int) -> None: - self.add_uint32(Keys.Tokenizer.CLS_ID, id) - def add_mask_token_id(self, id: int) -> None: self.add_uint32(Keys.Tokenizer.MASK_ID, id) @@ -831,15 +904,6 @@ class GGUFWriter: self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value) - def add_prefix_token_id(self, id: int) -> None: - self.add_uint32(Keys.Tokenizer.PREFIX_ID, id) - - def add_suffix_token_id(self, id: int) -> None: - self.add_uint32(Keys.Tokenizer.SUFFIX_ID, id) - - def add_middle_token_id(self, id: int) -> None: - self.add_uint32(Keys.Tokenizer.MIDDLE_ID, id) - def add_eot_token_id(self, id: int) -> None: self.add_uint32(Keys.Tokenizer.EOT_ID, id) diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py index db318542a..962c27b20 100644 --- a/gguf-py/gguf/metadata.py +++ b/gguf-py/gguf/metadata.py @@ -41,7 +41,7 @@ class Metadata: base_models: Optional[list[dict]] = None tags: Optional[list[str]] = None languages: Optional[list[str]] = None - datasets: Optional[list[str]] = None + datasets: Optional[list[dict]] = None @staticmethod def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata: @@ -91,9 +91,11 @@ class Metadata: # Base Models is received here as an array of models metadata.base_models = metadata_override.get("general.base_models", metadata.base_models) + # Datasets is received here as an array of datasets + metadata.datasets = metadata_override.get("general.datasets", metadata.datasets) + metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags) metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages) - metadata.datasets = metadata_override.get(Keys.General.DATASETS, metadata.datasets) # Direct Metadata Override (via direct cli argument) if model_name is not None: @@ -346,12 +348,12 @@ class Metadata: use_model_card_metadata("author", "model_creator") use_model_card_metadata("basename", "model_type") - if "base_model" in model_card: + if "base_model" in model_card or "base_models" in model_card or "base_model_sources" in model_card: # This represents the parent models that this is based on # Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges) # Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md metadata_base_models = [] - base_model_value = model_card.get("base_model", None) + base_model_value = model_card.get("base_model", model_card.get("base_models", model_card.get("base_model_sources", None))) if base_model_value is not None: if isinstance(base_model_value, str): @@ -364,18 +366,106 @@ class Metadata: for model_id in metadata_base_models: # NOTE: model size of base model is assumed to be similar to the size of the current model - model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params) base_model = {} - if model_full_name_component is not None: - base_model["name"] = Metadata.id_to_title(model_full_name_component) - if org_component is not None: - base_model["organization"] = Metadata.id_to_title(org_component) - if version is not None: - base_model["version"] = version - if org_component is not None and model_full_name_component is not None: - base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}" + if isinstance(model_id, str): + if model_id.startswith("http://") or model_id.startswith("https://") or model_id.startswith("ssh://"): + base_model["repo_url"] = model_id + + # Check if Hugging Face ID is present in URL + if "huggingface.co" in model_id: + match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", model_id) + if match: + model_id_component = match.group(1) + model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id_component, total_params) + + # Populate model dictionary with extracted components + if model_full_name_component is not None: + base_model["name"] = Metadata.id_to_title(model_full_name_component) + if org_component is not None: + base_model["organization"] = Metadata.id_to_title(org_component) + if version is not None: + base_model["version"] = version + + else: + # Likely a Hugging Face ID + model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params) + + # Populate model dictionary with extracted components + if model_full_name_component is not None: + base_model["name"] = Metadata.id_to_title(model_full_name_component) + if org_component is not None: + base_model["organization"] = Metadata.id_to_title(org_component) + if version is not None: + base_model["version"] = version + if org_component is not None and model_full_name_component is not None: + base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}" + + elif isinstance(model_id, dict): + base_model = model_id + + else: + logger.error(f"base model entry '{str(model_id)}' not in a known format") + metadata.base_models.append(base_model) + if "datasets" in model_card or "dataset" in model_card or "dataset_sources" in model_card: + # This represents the datasets that this was trained from + metadata_datasets = [] + dataset_value = model_card.get("datasets", model_card.get("dataset", model_card.get("dataset_sources", None))) + + if dataset_value is not None: + if isinstance(dataset_value, str): + metadata_datasets.append(dataset_value) + elif isinstance(dataset_value, list): + metadata_datasets.extend(dataset_value) + + if metadata.datasets is None: + metadata.datasets = [] + + for dataset_id in metadata_datasets: + # NOTE: model size of base model is assumed to be similar to the size of the current model + dataset = {} + if isinstance(dataset_id, str): + if dataset_id.startswith(("http://", "https://", "ssh://")): + dataset["repo_url"] = dataset_id + + # Check if Hugging Face ID is present in URL + if "huggingface.co" in dataset_id: + match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", dataset_id) + if match: + dataset_id_component = match.group(1) + dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id_component, total_params) + + # Populate dataset dictionary with extracted components + if dataset_name_component is not None: + dataset["name"] = Metadata.id_to_title(dataset_name_component) + if org_component is not None: + dataset["organization"] = Metadata.id_to_title(org_component) + if version is not None: + dataset["version"] = version + + else: + # Likely a Hugging Face ID + dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id, total_params) + + # Populate dataset dictionary with extracted components + if dataset_name_component is not None: + dataset["name"] = Metadata.id_to_title(dataset_name_component) + if org_component is not None: + dataset["organization"] = Metadata.id_to_title(org_component) + if version is not None: + dataset["version"] = version + if org_component is not None and dataset_name_component is not None: + dataset["repo_url"] = f"https://huggingface.co/{org_component}/{dataset_name_component}" + + elif isinstance(dataset_id, dict): + dataset = dataset_id + + else: + logger.error(f"dataset entry '{str(dataset_id)}' not in a known format") + + metadata.datasets.append(dataset) + use_model_card_metadata("license", "license") use_model_card_metadata("license_name", "license_name") use_model_card_metadata("license_link", "license_link") @@ -386,9 +476,6 @@ class Metadata: use_array_model_card_metadata("languages", "languages") use_array_model_card_metadata("languages", "language") - use_array_model_card_metadata("datasets", "datasets") - use_array_model_card_metadata("datasets", "dataset") - # Hugging Face Parameter Heuristics #################################### @@ -458,7 +545,10 @@ class Metadata: gguf_writer.add_size_label(self.size_label) if self.license is not None: - gguf_writer.add_license(self.license) + if isinstance(self.license, list): + gguf_writer.add_license(",".join(self.license)) + else: + gguf_writer.add_license(self.license) if self.license_name is not None: gguf_writer.add_license_name(self.license_name) if self.license_link is not None: @@ -493,6 +583,8 @@ class Metadata: gguf_writer.add_base_model_version(key, base_model_entry["version"]) if "organization" in base_model_entry: gguf_writer.add_base_model_organization(key, base_model_entry["organization"]) + if "description" in base_model_entry: + gguf_writer.add_base_model_description(key, base_model_entry["description"]) if "url" in base_model_entry: gguf_writer.add_base_model_url(key, base_model_entry["url"]) if "doi" in base_model_entry: @@ -502,9 +594,29 @@ class Metadata: if "repo_url" in base_model_entry: gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"]) + if self.datasets is not None: + gguf_writer.add_dataset_count(len(self.datasets)) + for key, dataset_entry in enumerate(self.datasets): + if "name" in dataset_entry: + gguf_writer.add_dataset_name(key, dataset_entry["name"]) + if "author" in dataset_entry: + gguf_writer.add_dataset_author(key, dataset_entry["author"]) + if "version" in dataset_entry: + gguf_writer.add_dataset_version(key, dataset_entry["version"]) + if "organization" in dataset_entry: + gguf_writer.add_dataset_organization(key, dataset_entry["organization"]) + if "description" in dataset_entry: + gguf_writer.add_dataset_description(key, dataset_entry["description"]) + if "url" in dataset_entry: + gguf_writer.add_dataset_url(key, dataset_entry["url"]) + if "doi" in dataset_entry: + gguf_writer.add_dataset_doi(key, dataset_entry["doi"]) + if "uuid" in dataset_entry: + gguf_writer.add_dataset_uuid(key, dataset_entry["uuid"]) + if "repo_url" in dataset_entry: + gguf_writer.add_dataset_repo_url(key, dataset_entry["repo_url"]) + if self.tags is not None: gguf_writer.add_tags(self.tags) if self.languages is not None: gguf_writer.add_languages(self.languages) - if self.datasets is not None: - gguf_writer.add_datasets(self.datasets) diff --git a/gguf-py/scripts/__init__.py b/gguf-py/gguf/scripts/__init__.py similarity index 100% rename from gguf-py/scripts/__init__.py rename to gguf-py/gguf/scripts/__init__.py diff --git a/gguf-py/scripts/gguf_convert_endian.py b/gguf-py/gguf/scripts/gguf_convert_endian.py similarity index 97% rename from gguf-py/scripts/gguf_convert_endian.py rename to gguf-py/gguf/scripts/gguf_convert_endian.py index b698af0fe..f97e91bd4 100755 --- a/gguf-py/scripts/gguf_convert_endian.py +++ b/gguf-py/gguf/scripts/gguf_convert_endian.py @@ -11,8 +11,8 @@ from pathlib import Path import numpy as np # Necessary to load the local gguf package -if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists(): - sys.path.insert(0, str(Path(__file__).parent.parent)) +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists(): + sys.path.insert(0, str(Path(__file__).parent.parent.parent)) import gguf diff --git a/gguf-py/scripts/gguf_dump.py b/gguf-py/gguf/scripts/gguf_dump.py similarity index 99% rename from gguf-py/scripts/gguf_dump.py rename to gguf-py/gguf/scripts/gguf_dump.py index 1b6546541..f95b4fd48 100755 --- a/gguf-py/scripts/gguf_dump.py +++ b/gguf-py/gguf/scripts/gguf_dump.py @@ -12,8 +12,8 @@ from typing import Any import numpy as np # Necessary to load the local gguf package -if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists(): - sys.path.insert(0, str(Path(__file__).parent.parent)) +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists(): + sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from gguf import GGUFReader, GGUFValueType, ReaderTensor # noqa: E402 diff --git a/gguf-py/scripts/gguf_hash.py b/gguf-py/gguf/scripts/gguf_hash.py similarity index 97% rename from gguf-py/scripts/gguf_hash.py rename to gguf-py/gguf/scripts/gguf_hash.py index ee34d09bf..3ef989921 100755 --- a/gguf-py/scripts/gguf_hash.py +++ b/gguf-py/gguf/scripts/gguf_hash.py @@ -13,8 +13,8 @@ from pathlib import Path from tqdm import tqdm # Necessary to load the local gguf package -if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists(): - sys.path.insert(0, str(Path(__file__).parent.parent)) +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists(): + sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from gguf import GGUFReader # noqa: E402 diff --git a/gguf-py/scripts/gguf_new_metadata.py b/gguf-py/gguf/scripts/gguf_new_metadata.py similarity index 98% rename from gguf-py/scripts/gguf_new_metadata.py rename to gguf-py/gguf/scripts/gguf_new_metadata.py index fce52a8c1..a8cfc9d58 100755 --- a/gguf-py/scripts/gguf_new_metadata.py +++ b/gguf-py/gguf/scripts/gguf_new_metadata.py @@ -13,8 +13,8 @@ from tqdm import tqdm from typing import Any, Sequence, NamedTuple # Necessary to load the local gguf package -if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists(): - sys.path.insert(0, str(Path(__file__).parent.parent)) +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists(): + sys.path.insert(0, str(Path(__file__).parent.parent.parent)) import gguf diff --git a/gguf-py/scripts/gguf_set_metadata.py b/gguf-py/gguf/scripts/gguf_set_metadata.py similarity index 97% rename from gguf-py/scripts/gguf_set_metadata.py rename to gguf-py/gguf/scripts/gguf_set_metadata.py index e35b651b8..f5809c35c 100755 --- a/gguf-py/scripts/gguf_set_metadata.py +++ b/gguf-py/gguf/scripts/gguf_set_metadata.py @@ -6,8 +6,8 @@ import sys from pathlib import Path # Necessary to load the local gguf package -if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists(): - sys.path.insert(0, str(Path(__file__).parent.parent)) +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists(): + sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from gguf import GGUFReader # noqa: E402 diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index bc9a13ee5..617791e24 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -13,7 +13,7 @@ class TensorNameMap: "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone "transformer.word_embeddings", # falcon "word_embeddings", # bloom - "model.embed_tokens", # llama-hf nemotron + "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 "tok_embeddings", # llama-pth "embeddings.word_embeddings", # bert nomic-bert "language_model.embedding.word_embeddings", # persimmon @@ -42,6 +42,7 @@ class TensorNameMap: "emb_ln", # nomic-bert "transformer.norm", # openelm "rwkv.blocks.0.pre_ln", # rwkv + "backbone.norm", # wavtokenizer ), # Position embeddings @@ -54,19 +55,20 @@ class TensorNameMap: # Output MODEL_TENSOR.OUTPUT: ( "embed_out", # gptneox - "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone + "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe "output", # llama-pth bloom internlm2 "word_embeddings_for_head", # persimmon "lm_head.linear", # phi2 "output_layer", # chatglm "head", # rwkv + "head.out", # wavtokenizer ), # Output norm MODEL_TENSOR.OUTPUT_NORM: ( "gpt_neox.final_layer_norm", # gptneox "transformer.ln_f", # gpt2 gpt-j falcon jais exaone - "model.norm", # llama-hf baichuan internlm2 + "model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe "norm", # llama-pth "transformer.norm_f", # mpt dbrx "ln_f", # refact bloom qwen gpt2 @@ -80,6 +82,7 @@ class TensorNameMap: "transformer.norm", # openelm "model.norm", # nemotron "rwkv.ln_out", # rwkv + "backbone.final_layer_norm", # wavtokenizer ), # Rope frequencies @@ -87,6 +90,13 @@ class TensorNameMap: "rope.freqs", # llama-pth "rotary_pos_emb.inv_freq", # chatglm ), + + MODEL_TENSOR.ROPE_FACTORS_LONG: (), + MODEL_TENSOR.ROPE_FACTORS_SHORT: (), + + MODEL_TENSOR.CONV1D: ( + "backbone.embed", # roberta + ), } block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { @@ -98,7 +108,7 @@ class TensorNameMap: "transformer.h.{bid}.input_layernorm", # falcon7b "h.{bid}.input_layernorm", # bloom "transformer.h.{bid}.ln_mlp", # falcon40b - "model.layers.{bid}.input_layernorm", # llama-hf nemotron + "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe "layers.{bid}.attention_norm", # llama-pth "language_model.encoder.layers.{bid}.input_layernorm", # persimmon "model.layers.{bid}.ln1", # yi @@ -142,7 +152,8 @@ class TensorNameMap: # Attention query MODEL_TENSOR.ATTN_Q: ( - "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron + "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 phimoe + "model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom "layers.{bid}.attention.wq", # llama-pth "encoder.layer.{bid}.attention.self.query", # bert "transformer.h.{bid}.attn.q_proj", # gpt-j @@ -154,7 +165,8 @@ class TensorNameMap: # Attention key MODEL_TENSOR.ATTN_K: ( - "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron + "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 phimoe + "model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom "layers.{bid}.attention.wk", # llama-pth "encoder.layer.{bid}.attention.self.key", # bert "transformer.h.{bid}.attn.k_proj", # gpt-j @@ -167,7 +179,7 @@ class TensorNameMap: # Attention value MODEL_TENSOR.ATTN_V: ( - "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron + "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe "layers.{bid}.attention.wv", # llama-pth "encoder.layer.{bid}.attention.self.value", # bert "transformer.h.{bid}.attn.v_proj", # gpt-j @@ -185,7 +197,8 @@ class TensorNameMap: "transformer.blocks.{bid}.attn.out_proj", # mpt "transformer.h.{bid}.self_attention.dense", # falcon "h.{bid}.self_attention.dense", # bloom - "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron + "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe + "model.layers.{bid}.self_attn.linear_attn", # deci "layers.{bid}.attention.wo", # llama-pth "encoder.layer.{bid}.attention.output.dense", # bert "transformer.h.{bid}.attn.out_proj", # gpt-j @@ -212,7 +225,7 @@ class TensorNameMap: ), MODEL_TENSOR.ATTN_POST_NORM: ( - "model.layers.{bid}.post_attention_layernorm", # gemma2 + "model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 ), # Rotary embeddings @@ -229,7 +242,7 @@ class TensorNameMap: "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone "h.{bid}.post_attention_layernorm", # bloom "transformer.blocks.{bid}.norm_2", # mpt - "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron + "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe phimoe "layers.{bid}.ffn_norm", # llama-pth "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon "model.layers.{bid}.ln2", # yi @@ -247,21 +260,26 @@ class TensorNameMap: # Post feed-forward norm MODEL_TENSOR.FFN_POST_NORM: ( - "model.layers.{bid}.post_feedforward_layernorm", # gemma2 + "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2 ), MODEL_TENSOR.FFN_GATE_INP: ( - "layers.{bid}.feed_forward.gate", # mixtral - "model.layers.{bid}.block_sparse_moe.gate", # mixtral - "model.layers.{bid}.mlp.gate", # qwen2moe - "transformer.decoder_layer.{bid}.router", # Grok - "transformer.blocks.{bid}.ffn.router.layer", # dbrx + "layers.{bid}.feed_forward.gate", # mixtral + "model.layers.{bid}.block_sparse_moe.gate", # mixtral phimoe + "model.layers.{bid}.mlp.gate", # qwen2moe olmoe + "transformer.decoder_layer.{bid}.router", # Grok + "transformer.blocks.{bid}.ffn.router.layer", # dbrx + "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe ), MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( "model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe ), + MODEL_TENSOR.FFN_EXP_PROBS_B: ( + "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 + ), + # Feed-forward up MODEL_TENSOR.FFN_UP: ( "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox @@ -269,7 +287,7 @@ class TensorNameMap: "transformer.blocks.{bid}.ffn.up_proj", # mpt "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon "h.{bid}.mlp.dense_h_to_4h", # bloom - "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron + "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo2 "layers.{bid}.feed_forward.w3", # llama-pth "encoder.layer.{bid}.intermediate.dense", # bert "transformer.h.{bid}.mlp.fc_in", # gpt-j @@ -292,15 +310,16 @@ class TensorNameMap: ), MODEL_TENSOR.FFN_UP_EXP: ( - "layers.{bid}.feed_forward.experts.w3", # mixtral (merged) - "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged) - "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx - "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe (merged) + "layers.{bid}.feed_forward.experts.w3", # mixtral (merged) + "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx + "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) + "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged) ), MODEL_TENSOR.FFN_UP_SHEXP: ( "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe - "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek2 + "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2 ), # AWQ-activation gate @@ -310,7 +329,7 @@ class TensorNameMap: # Feed-forward gate MODEL_TENSOR.FFN_GATE: ( - "model.layers.{bid}.mlp.gate_proj", # llama-hf refact + "model.layers.{bid}.mlp.gate_proj", # llama-hf refact olmo2 "layers.{bid}.feed_forward.w1", # llama-pth "transformer.h.{bid}.mlp.w2", # qwen "transformer.h.{bid}.mlp.c_fc2", # jais @@ -324,15 +343,16 @@ class TensorNameMap: ), MODEL_TENSOR.FFN_GATE_EXP: ( - "layers.{bid}.feed_forward.experts.w1", # mixtral (merged) - "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged) - "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx - "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe (merged) + "layers.{bid}.feed_forward.experts.w1", # mixtral (merged) + "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx + "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) + "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged) ), MODEL_TENSOR.FFN_GATE_SHEXP: ( "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe - "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek2 + "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2 ), # Feed-forward down @@ -342,7 +362,7 @@ class TensorNameMap: "transformer.blocks.{bid}.ffn.down_proj", # mpt "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon "h.{bid}.mlp.dense_4h_to_h", # bloom - "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron + "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo2 "layers.{bid}.feed_forward.w2", # llama-pth "encoder.layer.{bid}.output.dense", # bert "transformer.h.{bid}.mlp.fc_out", # gpt-j @@ -364,21 +384,23 @@ class TensorNameMap: ), MODEL_TENSOR.FFN_DOWN_EXP: ( - "layers.{bid}.feed_forward.experts.w2", # mixtral (merged) - "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged) - "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx - "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe (merged) + "layers.{bid}.feed_forward.experts.w2", # mixtral (merged) + "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx + "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) + "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe + "model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged) ), MODEL_TENSOR.FFN_DOWN_SHEXP: ( "model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe - "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek2 + "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2 ), MODEL_TENSOR.ATTN_Q_NORM: ( "language_model.encoder.layers.{bid}.self_attention.q_layernorm", "model.layers.{bid}.self_attn.q_layernorm", # persimmon - "model.layers.{bid}.self_attn.q_norm", # cohere + "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2 "transformer.blocks.{bid}.attn.q_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2 "transformer.layers.{bid}.attn.q_norm", # openelm @@ -387,7 +409,7 @@ class TensorNameMap: MODEL_TENSOR.ATTN_K_NORM: ( "language_model.encoder.layers.{bid}.self_attention.k_layernorm", "model.layers.{bid}.self_attn.k_layernorm", # persimmon - "model.layers.{bid}.self_attn.k_norm", # cohere + "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2 "transformer.blocks.{bid}.attn.k_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2 "transformer.layers.{bid}.attn.k_norm", # openelm @@ -442,34 +464,42 @@ class TensorNameMap: MODEL_TENSOR.TIME_MIX_W1: ( "rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv v6 + "model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_W2: ( "rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv v6 + "model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_LERP_X: ( "rwkv.blocks.{bid}.attention.time_maa_x", # rwkv v6 + "model.layers.{bid}.self_attn.time_maa_x", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_LERP_K: ( "rwkv.blocks.{bid}.attention.time_maa_k", # rwkv v6 + "model.layers.{bid}.self_attn.time_maa_k", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_LERP_V: ( "rwkv.blocks.{bid}.attention.time_maa_v", # rwkv v6 + "model.layers.{bid}.self_attn.time_maa_v", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_LERP_R: ( "rwkv.blocks.{bid}.attention.time_maa_r", # rwkv v6 + "model.layers.{bid}.self_attn.time_maa_r", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_LERP_G: ( "rwkv.blocks.{bid}.attention.time_maa_g", # rwkv v6 + "model.layers.{bid}.self_attn.time_maa_g", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_LERP_W: ( "rwkv.blocks.{bid}.attention.time_maa_w", # rwkv v6 + "model.layers.{bid}.self_attn.time_maa_w", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_FIRST: ( @@ -478,30 +508,37 @@ class TensorNameMap: MODEL_TENSOR.TIME_MIX_DECAY: ( "rwkv.blocks.{bid}.attention.time_decay", # rwkv v6 + "model.layers.{bid}.self_attn.time_decay", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_DECAY_W1: ( "rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv v6 + "model.layers.{bid}.self_attn.time_decay_w1", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_DECAY_W2: ( "rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv v6 + "model.layers.{bid}.self_attn.time_decay_w2", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_KEY: ( - "rwkv.blocks.{bid}.attention.key", # rwkv + "rwkv.blocks.{bid}.attention.key", # rwkv + "model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_VALUE: ( - "rwkv.blocks.{bid}.attention.value", # rwkv + "rwkv.blocks.{bid}.attention.value", # rwkv + "model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_RECEPTANCE: ( "rwkv.blocks.{bid}.attention.receptance", # rwkv + "model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_GATE: ( - "rwkv.blocks.{bid}.attention.gate", # rwkv + "rwkv.blocks.{bid}.attention.gate", # rwkv + "model.layers.{bid}.self_attn.gate", # rwkv6qwen2 ), MODEL_TENSOR.TIME_MIX_LN: ( @@ -509,7 +546,8 @@ class TensorNameMap: ), MODEL_TENSOR.TIME_MIX_OUTPUT: ( - "rwkv.blocks.{bid}.attention.output", # rwkv + "rwkv.blocks.{bid}.attention.output", # rwkv + "model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2 ), MODEL_TENSOR.CHANNEL_MIX_LERP_K: ( @@ -674,9 +712,81 @@ class TensorNameMap: "encoder.block.{bid}.layer.1.DenseReluDense.wo", # t5 ), + ############################################################################ + # TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg MODEL_TENSOR.ENC_OUTPUT_NORM: ( "encoder.final_layer_norm", # t5 ), + + MODEL_TENSOR.CLS: ( + "classifier", # jina + "classifier.dense", # roberta + ), + + MODEL_TENSOR.CLS_OUT: ( + "classifier.out_proj", # roberta + ), + ############################################################################# + + MODEL_TENSOR.CONVNEXT_DW: ( + "backbone.convnext.{bid}.dwconv", # wavtokenizer + ), + + MODEL_TENSOR.CONVNEXT_NORM: ( + "backbone.convnext.{bid}.norm", # wavtokenizer + ), + + MODEL_TENSOR.CONVNEXT_PW1: ( + "backbone.convnext.{bid}.pwconv1", # wavtokenizer + ), + + MODEL_TENSOR.CONVNEXT_PW2: ( + "backbone.convnext.{bid}.pwconv2", # wavtokenizer + ), + + MODEL_TENSOR.CONVNEXT_GAMMA: ( + "backbone.convnext.{bid}.gamma", # wavtokenizer + ), + + MODEL_TENSOR.POSNET_CONV1: ( + "backbone.posnet.{bid}.conv1", # wavtokenizer + ), + + MODEL_TENSOR.POSNET_CONV2: ( + "backbone.posnet.{bid}.conv2", # wavtokenizer + ), + + MODEL_TENSOR.POSNET_NORM: ( + "backbone.posnet.{bid}.norm", # wavtokenizer + ), + + MODEL_TENSOR.POSNET_NORM1: ( + "backbone.posnet.{bid}.norm1", # wavtokenizer + ), + + MODEL_TENSOR.POSNET_NORM2: ( + "backbone.posnet.{bid}.norm2", # wavtokenizer + ), + + MODEL_TENSOR.POSNET_ATTN_NORM: ( + "backbone.posnet.{bid}.norm", # wavtokenizer + ), + + MODEL_TENSOR.POSNET_ATTN_Q: ( + "backbone.posnet.{bid}.q", # wavtokenizer + ), + + MODEL_TENSOR.POSNET_ATTN_K: ( + "backbone.posnet.{bid}.k", # wavtokenizer + ), + + MODEL_TENSOR.POSNET_ATTN_V: ( + "backbone.posnet.{bid}.v", # wavtokenizer + ), + + MODEL_TENSOR.POSNET_ATTN_OUT: ( + "backbone.posnet.{bid}.proj_out", # wavtokenizer + ), } # architecture-specific block mappings diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index dc5749913..f2645f921 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -122,8 +122,30 @@ class SpecialVocab: tokenizer = json.load(f) if self.load_merges: merges = tokenizer.get('model', {}).get('merges') - if isinstance(merges, list) and merges and isinstance(merges[0], str): - self.merges = merges + if isinstance(merges, list) and merges: + if isinstance(merges[0], str): + self.merges = merges + elif isinstance(merges[0], list) and len(merges[0]) == 2 and isinstance(merges[0][0], str): + # New format since transformers 4.45 to support spaces in merges + # ref: https://github.com/ggerganov/llama.cpp/issues/9692 + # TODO: internally store as the new format instead of converting to old + if any(' ' in s for pair in merges for s in pair): + logger.warning(f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}') + self.merges = [ + ' '.join( + [ + # ensure the spaces are properly encoded + ''.join( + chr(ord(c) + 256) if c == ' ' else c + for c in part + ) + for part in pair + ] + ) + for pair in merges + ] + else: + raise ValueError("Unknown tokenizer merges format") added_tokens = tokenizer.get('added_tokens', {}) else: added_tokens = {} diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml index 33cfe26b7..78c6baa64 100644 --- a/gguf-py/pyproject.toml +++ b/gguf-py/pyproject.toml @@ -1,12 +1,11 @@ [tool.poetry] name = "gguf" -version = "0.10.0" +version = "0.15.0" description = "Read and write ML models in GGUF for GGML" authors = ["GGML "] packages = [ {include = "gguf"}, {include = "gguf/py.typed"}, - {include = "scripts"}, ] readme = "README.md" homepage = "https://ggml.ai" @@ -33,7 +32,7 @@ requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" [tool.poetry.scripts] -gguf-convert-endian = "scripts:gguf_convert_endian_entrypoint" -gguf-dump = "scripts:gguf_dump_entrypoint" -gguf-set-metadata = "scripts:gguf_set_metadata_entrypoint" -gguf-new-metadata = "scripts:gguf_new_metadata_entrypoint" +gguf-convert-endian = "gguf.scripts:gguf_convert_endian_entrypoint" +gguf-dump = "gguf.scripts:gguf_dump_entrypoint" +gguf-set-metadata = "gguf.scripts:gguf_set_metadata_entrypoint" +gguf-new-metadata = "gguf.scripts:gguf_new_metadata_entrypoint" diff --git a/gguf-py/tests/test_metadata.py b/gguf-py/tests/test_metadata.py index 81a2a30ae..40d484f4e 100755 --- a/gguf-py/tests/test_metadata.py +++ b/gguf-py/tests/test_metadata.py @@ -182,8 +182,43 @@ class TestMetadataMethod(unittest.TestCase): expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': '14-v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}] expect.tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'] expect.languages=['en'] - expect.datasets=['teknium/OpenHermes-2.5'] + expect.datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}] + self.assertEqual(got, expect) + # Base Model spec is inferred from model id + model_card = {'base_models': 'teknium/OpenHermes-2.5'} + expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) + got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + self.assertEqual(got, expect) + + # Base Model spec is only url + model_card = {'base_models': ['https://huggingface.co/teknium/OpenHermes-2.5']} + expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) + got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + self.assertEqual(got, expect) + + # Base Model spec is given directly + model_card = {'base_models': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]} + expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) + got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + self.assertEqual(got, expect) + + # Dataset spec is inferred from model id + model_card = {'datasets': 'teknium/OpenHermes-2.5'} + expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) + got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + self.assertEqual(got, expect) + + # Dataset spec is only url + model_card = {'datasets': ['https://huggingface.co/teknium/OpenHermes-2.5']} + expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) + got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + self.assertEqual(got, expect) + + # Dataset spec is given directly + model_card = {'datasets': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]} + expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) + got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) self.assertEqual(got, expect) def test_apply_metadata_heuristic_from_hf_parameters(self): diff --git a/gguf-py/tests/test_quants.py b/gguf-py/tests/test_quants.py index 762067814..f04d5acce 100755 --- a/gguf-py/tests/test_quants.py +++ b/gguf-py/tests/test_quants.py @@ -136,7 +136,7 @@ def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType) logger.debug(f"Sample bad block ({diff_bits[bad_block_id]} differing bits):\n{t1[bad_block_id]}\nReference:\n{t2[bad_block_id]}") sum_diff_bits = np.sum(diff_bits) - logger.debug(f"{sum_diff_bits} bits differ ({100 * sum_diff_bits/(x.size * 8):.6f}%)") + logger.debug(f"{sum_diff_bits} bits differ ({100 * sum_diff_bits / (x.size * 8):.6f}%)") return False diff --git a/grammars/README.md b/grammars/README.md index 7ec815471..976954091 100644 --- a/grammars/README.md +++ b/grammars/README.md @@ -46,7 +46,7 @@ Terminals support the full range of Unicode. Unicode characters can be specified Character ranges can be negated with `^`: ``` -single-line ::= [^\n]+ "\n"` +single-line ::= [^\n]+ "\n" ``` ## Sequences and Alternatives @@ -120,11 +120,11 @@ You can use GBNF grammars: - In [llama-server](../examples/server): - For any completion endpoints, passed as the `json_schema` body field - - For the `/chat/completions` endpoint, passed inside the `response_format` body field (e.g. `{"type", "json_object", "schema": {"items": {}}}`) + - For the `/chat/completions` endpoint, passed inside the `response_format` body field (e.g. `{"type", "json_object", "schema": {"items": {}}}` or `{ type: "json_schema", json_schema: {"schema": ...} }`) - In [llama-cli](../examples/main), passed as the `--json` / `-j` flag - To convert to a grammar ahead of time: - in CLI, with [examples/json_schema_to_grammar.py](../examples/json_schema_to_grammar.py) - - in JavaScript with [json-schema-to-grammar.mjs](../examples/server/public/json-schema-to-grammar.mjs) (this is used by the [server](../examples/server)'s Web UI) + - in JavaScript with [json-schema-to-grammar.mjs](../examples/server/public_legacy/json-schema-to-grammar.mjs) (this is used by the [server](../examples/server)'s Web UI) Take a look at [tests](../tests/test-json-schema-to-grammar.cpp) to see which features are likely supported (you'll also find usage examples in https://github.com/ggerganov/llama.cpp/pull/5978, https://github.com/ggerganov/llama.cpp/pull/6659 & https://github.com/ggerganov/llama.cpp/pull/6555). diff --git a/grammars/english.gbnf b/grammars/english.gbnf new file mode 100644 index 000000000..2e53686c8 --- /dev/null +++ b/grammars/english.gbnf @@ -0,0 +1,6 @@ +# note: this might be incomplete, mostly an example +root ::= en-char+ ([ \t\n] en-char+)* +en-char ::= letter | digit | punctuation +letter ::= [a-zA-Z] +digit ::= [0-9] +punctuation ::= [!"#$%&'()*+,-./:;<=>?@[\\\]^_`{|}~] diff --git a/include/llama-cpp.h b/include/llama-cpp.h new file mode 100644 index 000000000..8f6368177 --- /dev/null +++ b/include/llama-cpp.h @@ -0,0 +1,30 @@ +#pragma once + +#ifndef __cplusplus +#error "This header is for C++ only" +#endif + +#include + +#include "llama.h" + +struct llama_model_deleter { + void operator()(llama_model * model) { llama_model_free(model); } +}; + +struct llama_context_deleter { + void operator()(llama_context * context) { llama_free(context); } +}; + +struct llama_sampler_deleter { + void operator()(llama_sampler * sampler) { llama_sampler_free(sampler); } +}; + +struct llama_adapter_lora_deleter { + void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); } +}; + +typedef std::unique_ptr llama_model_ptr; +typedef std::unique_ptr llama_context_ptr; +typedef std::unique_ptr llama_sampler_ptr; +typedef std::unique_ptr llama_adapter_lora_ptr; diff --git a/include/llama.h b/include/llama.h index 6334fc30d..61907ed40 100644 --- a/include/llama.h +++ b/include/llama.h @@ -2,6 +2,7 @@ #define LLAMA_H #include "ggml.h" +#include "ggml-cpu.h" #include "ggml-backend.h" #include @@ -33,7 +34,6 @@ #define LLAMA_DEFAULT_SEED 0xFFFFFFFF -// TODO: use everywhere in the implementation #define LLAMA_TOKEN_NULL -1 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla' @@ -56,7 +56,7 @@ extern "C" { // TODO: show sample usage // - // struct llama_vocab; // TODO: add in the future + struct llama_vocab; struct llama_model; struct llama_context; struct llama_sampler; @@ -102,12 +102,17 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, + LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, + LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, }; enum llama_rope_type { - LLAMA_ROPE_TYPE_NONE = -1, - LLAMA_ROPE_TYPE_NORM = 0, - LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX, + LLAMA_ROPE_TYPE_NONE = -1, + LLAMA_ROPE_TYPE_NORM = 0, + LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX, + LLAMA_ROPE_TYPE_MROPE = GGML_ROPE_TYPE_MROPE, + LLAMA_ROPE_TYPE_VISION = GGML_ROPE_TYPE_VISION, }; enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file @@ -169,9 +174,9 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors + //LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // removed from gguf files, use Q4_0 and runtime repack + //LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // removed from gguf files, use Q4_0 and runtime repack + //LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // removed from gguf files, use Q4_0 and runtime repack LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors @@ -183,7 +188,8 @@ extern "C" { LLAMA_ROPE_SCALING_TYPE_NONE = 0, LLAMA_ROPE_SCALING_TYPE_LINEAR = 1, LLAMA_ROPE_SCALING_TYPE_YARN = 2, - LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN, + LLAMA_ROPE_SCALING_TYPE_LONGROPE = 3, + LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_LONGROPE, }; enum llama_pooling_type { @@ -192,6 +198,7 @@ extern "C" { LLAMA_POOLING_TYPE_MEAN = 1, LLAMA_POOLING_TYPE_CLS = 2, LLAMA_POOLING_TYPE_LAST = 3, + LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph }; enum llama_attention_type { @@ -201,9 +208,9 @@ extern "C" { }; enum llama_split_mode { - LLAMA_SPLIT_MODE_NONE = 0, // single GPU - LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs - LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs + LLAMA_SPLIT_MODE_NONE = 0, // single GPU + LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs + LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported }; // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979) @@ -215,6 +222,7 @@ extern "C" { typedef struct llama_token_data_array { // TODO: consider SoA + // NOTE: this pointer can be modified by the samplers llama_token_data * data; size_t size; int64_t selected; // this is the index in the data array (i.e. not the token id) @@ -230,8 +238,11 @@ extern "C" { // - token : the token ids of the input (used when embd is NULL) // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) // - pos : the positions of the respective token in the sequence + // (if set to NULL, the token position will be tracked automatically by llama_decode) // - seq_id : the sequence to which the respective token belongs + // (if set to NULL, the sequence ID will be assumed to be 0) // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output + // (if set to NULL, only the logits for last token will be returned) // typedef struct llama_batch { int32_t n_tokens; @@ -242,15 +253,6 @@ extern "C" { int32_t * n_seq_id; llama_seq_id ** seq_id; int8_t * logits; // TODO: rename this to "output" - - // NOTE: helpers for smooth API transition - can be deprecated in the future - // for future-proof code, use the above fields instead and ignore everything below - // - // pos[i] = all_pos_0 + i*all_pos_1 - // - llama_pos all_pos_0; // used if pos == NULL - llama_pos all_pos_1; // used if pos == NULL - llama_seq_id all_seq_id; // used if seq_id == NULL } llama_batch; enum llama_model_kv_override_type { @@ -274,21 +276,18 @@ extern "C" { }; struct llama_model_params { + // NULL-terminated list of devices to use for offloading (if NULL, all available devices are used) + ggml_backend_dev_t * devices; + int32_t n_gpu_layers; // number of layers to store in VRAM enum llama_split_mode split_mode; // how to split the model across multiple GPUs - // main_gpu interpretation depends on split_mode: - // LLAMA_SPLIT_MODE_NONE: the GPU that is used for the entire model - // LLAMA_SPLIT_MODE_ROW: the GPU that is used for small tensors and intermediate results - // LLAMA_SPLIT_MODE_LAYER: ignored + // the GPU that is used for the entire model when split_mode is LLAMA_SPLIT_MODE_NONE int32_t main_gpu; // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices() const float * tensor_split; - // comma separated list of RPC servers to use for offloading - const char * rpc_servers; - // Called with a progress value between 0.0 and 1.0. Pass NULL to disable. // If the provided progress_callback returns true, model loading continues. // If it returns false, model loading is immediately aborted. @@ -343,7 +342,7 @@ extern "C" { bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] - //bool no_perf; // whether to measure performance timings, TODO: implement + bool no_perf; // whether to measure performance timings // Abort callback // if it returns true, execution of llama_decode() will be aborted @@ -383,7 +382,7 @@ extern "C" { } llama_chat_message; // lora adapter - struct llama_lora_adapter; + struct llama_adapter_lora; // Helpers for getting default parameters // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172) @@ -397,30 +396,53 @@ extern "C" { // Call once at the start of the program LLAMA_API void llama_backend_init(void); + // Call once at the end of the program - currently only used for MPI + LLAMA_API void llama_backend_free(void); + //optional: LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa); // Optional: an auto threadpool gets created in ggml if not passed explicitly LLAMA_API void llama_attach_threadpool( - struct llama_context * ctx, - ggml_threadpool_t threadpool, - ggml_threadpool_t threadpool_batch); + struct llama_context * ctx, + ggml_threadpool_t threadpool, + ggml_threadpool_t threadpool_batch); + LLAMA_API void llama_detach_threadpool(struct llama_context * ctx); - // Call once at the end of the program - currently only used for MPI - LLAMA_API void llama_backend_free(void); + DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file( + const char * path_model, + struct llama_model_params params), + "use llama_model_load_from_file instead"); - LLAMA_API struct llama_model * llama_load_model_from_file( + // Load the model from a file + // If the file is split into multiple parts, the file name must follow this pattern: -%05d-of-%05d.gguf + // If the split file name does not follow this pattern, use llama_model_load_from_splits + LLAMA_API struct llama_model * llama_model_load_from_file( const char * path_model, struct llama_model_params params); - LLAMA_API void llama_free_model(struct llama_model * model); + // Load the model from multiple splits (support custom naming scheme) + // The paths must be in the correct order + LLAMA_API struct llama_model * llama_model_load_from_splits( + const char ** paths, + size_t n_paths, + struct llama_model_params params); - // TODO: rename to llama_init_from_model - LLAMA_API struct llama_context * llama_new_context_with_model( + DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model), + "use llama_model_free instead"); + + LLAMA_API void llama_model_free(struct llama_model * model); + + LLAMA_API struct llama_context * llama_init_from_model( struct llama_model * model, struct llama_context_params params); + DEPRECATED(LLAMA_API struct llama_context * llama_new_context_with_model( + struct llama_model * model, + struct llama_context_params params), + "use llama_init_from_model instead"); + // Frees all allocated memory LLAMA_API void llama_free(struct llama_context * ctx); @@ -431,29 +453,42 @@ extern "C" { LLAMA_API bool llama_supports_mmap (void); LLAMA_API bool llama_supports_mlock (void); LLAMA_API bool llama_supports_gpu_offload(void); + LLAMA_API bool llama_supports_rpc (void); LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); - LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); - LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); - LLAMA_API int32_t llama_n_embd (const struct llama_model * model); - LLAMA_API int32_t llama_n_layer (const struct llama_model * model); + DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead"); + DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead"); + DEPRECATED(LLAMA_API int32_t llama_n_layer (const struct llama_model * model), "use llama_model_n_layer instead"); + DEPRECATED(LLAMA_API int32_t llama_n_head (const struct llama_model * model), "use llama_model_n_head instead"); - LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); + DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead"); - LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); - LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); - LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); + LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); + LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); + + LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); + LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); + + LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model); + LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); // Get the model's RoPE frequency scaling factor - LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); + LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); + + LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab); + + LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab); // Functions to access the model's GGUF metadata scalar values // - The functions return the length of the string on success, or -1 on failure // - The output string is always null-terminated and cleared on failure + // - When retrieving a string, an extra byte must be allocated to account for the null terminator // - GGUF array values are not supported by these functions // Get metadata value as a string by key name @@ -474,12 +509,13 @@ extern "C" { // Returns the total size of all the tensors in the model in bytes LLAMA_API uint64_t llama_model_size(const struct llama_model * model); + // Get the default chat template. Returns nullptr if not available + // If name is NULL, returns the default chat template + LLAMA_API const char * llama_model_chat_template(const struct llama_model * model, const char * name); + // Returns the total number of parameters in the model LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); - // Get a llama model tensor - LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name); - // Returns true if the model contains an encoder that requires llama_encode() call LLAMA_API bool llama_model_has_encoder(const struct llama_model * model); @@ -499,32 +535,36 @@ extern "C" { const char * fname_out, const llama_model_quantize_params * params); + // + // Adapters + // + // Load a LoRA adapter from file - // The loaded adapter will be associated to the given model, and will be free when the model is deleted - LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init( + LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init( struct llama_model * model, const char * path_lora); + // Manually free a LoRA adapter + // Note: loaded adapters will be free when the associated model is deleted + LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter); + + // The following functions operate on a llama_context, hence the naming: llama_verb_... + // Add a loaded LoRA adapter to given context // This will not modify model's weight - LLAMA_API int32_t llama_lora_adapter_set( + LLAMA_API int32_t llama_set_adapter_lora( struct llama_context * ctx, - struct llama_lora_adapter * adapter, + struct llama_adapter_lora * adapter, float scale); // Remove a specific LoRA adapter from given context // Return -1 if the adapter is not present in the context - LLAMA_API int32_t llama_lora_adapter_remove( + LLAMA_API int32_t llama_rm_adapter_lora( struct llama_context * ctx, - struct llama_lora_adapter * adapter); + struct llama_adapter_lora * adapter); // Remove all LoRA adapters from given context - LLAMA_API void llama_lora_adapter_clear( - struct llama_context * ctx); - - // Manually free a LoRA adapter - // Note: loaded adapters will be free when the associated model is deleted - LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter); + LLAMA_API void llama_clear_adapter_lora(struct llama_context * ctx); // Apply a loaded control vector to a llama_context, or if data is NULL, clear // the currently loaded vector. @@ -532,8 +572,8 @@ extern "C" { // to an n_embd x n_layers buffer starting from layer 1. // il_start and il_end are the layer range the vector should apply to (both inclusive) // See llama_control_vector_load in common to load a control vector. - LLAMA_API int32_t llama_control_vector_apply( - struct llama_context * lctx, + LLAMA_API int32_t llama_apply_adapter_cvec( + struct llama_context * ctx, const float * data, size_t len, int32_t n_embd, @@ -544,6 +584,8 @@ extern "C" { // KV cache // + // TODO: remove llama_kv_cache_view_* API + // Information associated with an individual cell in the KV cache view. struct llama_kv_cache_view_cell { // The position for this cell. Takes KV cache shifts into account. @@ -590,8 +632,11 @@ extern "C" { LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view); // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) + // TODO: change signature to llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_context * ctx) LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view); + /// + // Returns the number of tokens in the KV cache (slow, use only for debug) // If a KV cell has multiple sequences assigned to it, it will be counted multiple times LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx); @@ -661,6 +706,9 @@ extern "C" { struct llama_context * ctx, llama_seq_id seq_id); + // TODO: the llama_kv_cache_defrag and llama_kv_cache_update API tightly couples llama_context with llama_kv_cache + // how to avoid this? + // Defragment the KV cache // This will be applied: // - lazily on next llama_decode() @@ -670,6 +718,9 @@ extern "C" { // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) LLAMA_API void llama_kv_cache_update(struct llama_context * ctx); + // Check if the context supports KV cache shifting + LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx); + // // State / sessions // @@ -772,15 +823,15 @@ extern "C" { // Decoding // - // Return batch for single sequence of tokens starting at pos_0 + // Return batch for single sequence of tokens + // The sequence ID will be fixed to 0 + // The position of the tokens will be tracked automatically by llama_decode // // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it // LLAMA_API struct llama_batch llama_batch_get_one( llama_token * tokens, - int32_t n_tokens, - llama_pos pos_0, - llama_seq_id seq_id); + int32_t n_tokens); // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens // Each token can be assigned up to n_seq_max sequence ids @@ -800,7 +851,7 @@ extern "C" { // Processes a batch of tokens with the ecoder part of the encoder-decoder model. // Stores the encoder output internally for later use by the decoder cross-attention layers. // 0 - success - // < 0 - error + // < 0 - error. the KV cache state is restored to the state before this call LLAMA_API int32_t llama_encode( struct llama_context * ctx, struct llama_batch batch); @@ -808,7 +859,7 @@ extern "C" { // Positive return values does not mean a fatal error, but rather a warning. // 0 - success // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) - // < 0 - error + // < 0 - error. the KV cache state is restored to the state before this call LLAMA_API int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch); @@ -870,45 +921,74 @@ extern "C" { // Get the embeddings for a sequence id // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE - // shape: [n_embd] (1-dimensional) + // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence + // otherwise: float[n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); // // Vocab // - LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token); + LLAMA_API const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token); - LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token); + LLAMA_API float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token); - LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token); + LLAMA_API enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token); // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.) - LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token); + LLAMA_API bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token); // Identify if Token Id is a control token or a render-able token - LLAMA_API bool llama_token_is_control(const struct llama_model * model, llama_token token); + LLAMA_API bool llama_vocab_is_control(const struct llama_vocab * vocab, llama_token token); // Special tokens - LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence - LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence - LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification - LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator - LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line - LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding + LLAMA_API llama_token llama_vocab_bos(const struct llama_vocab * vocab); // beginning-of-sentence + LLAMA_API llama_token llama_vocab_eos(const struct llama_vocab * vocab); // end-of-sentence + LLAMA_API llama_token llama_vocab_eot(const struct llama_vocab * vocab); // end-of-turn + LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator + LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line + LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding - LLAMA_API bool llama_add_bos_token(const struct llama_model * model); - LLAMA_API bool llama_add_eos_token(const struct llama_model * model); + LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab); + LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab); - // Codellama infill tokens - LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix - LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle - LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix - LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle + LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab); + + DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_text instead"); + DEPRECATED(LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_score instead"); + DEPRECATED(LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_attr instead"); + DEPRECATED(LLAMA_API bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_eog instead"); + DEPRECATED(LLAMA_API bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_control instead"); + DEPRECATED(LLAMA_API llama_token llama_token_bos(const struct llama_vocab * vocab), "use llama_vocab_bos instead"); + DEPRECATED(LLAMA_API llama_token llama_token_eos(const struct llama_vocab * vocab), "use llama_vocab_eos instead"); + DEPRECATED(LLAMA_API llama_token llama_token_eot(const struct llama_vocab * vocab), "use llama_vocab_eot instead"); + DEPRECATED(LLAMA_API llama_token llama_token_cls(const struct llama_vocab * vocab), "use llama_vocab_cls instead"); + DEPRECATED(LLAMA_API llama_token llama_token_sep(const struct llama_vocab * vocab), "use llama_vocab_sep instead"); + DEPRECATED(LLAMA_API llama_token llama_token_nl (const struct llama_vocab * vocab), "use llama_vocab_nl instead"); + DEPRECATED(LLAMA_API llama_token llama_token_pad(const struct llama_vocab * vocab), "use llama_vocab_pad instead"); + DEPRECATED(LLAMA_API bool llama_add_bos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_bos instead"); + DEPRECATED(LLAMA_API bool llama_add_eos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_eos instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_pre(const struct llama_vocab * vocab), "use llama_vocab_fim_pre instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_suf(const struct llama_vocab * vocab), "use llama_vocab_fim_suf instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_mid(const struct llama_vocab * vocab), "use llama_vocab_fim_mid instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_pad(const struct llama_vocab * vocab), "use llama_vocab_fim_pad instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_rep(const struct llama_vocab * vocab), "use llama_vocab_fim_rep instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_sep(const struct llama_vocab * vocab), "use llama_vocab_fim_sep instead"); + + // CLS is equivalent to BOS + DEPRECATED(LLAMA_API llama_token llama_vocab_cls(const struct llama_vocab * vocab), // classification + "use llama_vocab_bos instead"); // // Tokenization // + // The API is thread-safe. + // /// @details Convert the provided text into tokens. /// @param tokens The tokens pointer must be large enough to hold the resulting tokens. @@ -918,7 +998,7 @@ extern "C" { /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated /// as plaintext. Does not insert a leading space. LLAMA_API int32_t llama_tokenize( - const struct llama_model * model, + const struct llama_vocab * vocab, const char * text, int32_t text_len, llama_token * tokens, @@ -932,7 +1012,7 @@ extern "C" { // User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix') // @param special If true, special tokens are rendered in the output. LLAMA_API int32_t llama_token_to_piece( - const struct llama_model * model, + const struct llama_vocab * vocab, llama_token token, char * buf, int32_t length, @@ -946,7 +1026,7 @@ extern "C" { /// @param remove_special Allow to remove BOS and EOS tokens if model is configured to do so. /// @param unparse_special If true, special tokens are rendered in the output. LLAMA_API int32_t llama_detokenize( - const struct llama_model * model, + const struct llama_vocab * vocab, const llama_token * tokens, int32_t n_tokens, char * text, @@ -969,7 +1049,6 @@ extern "C" { /// @param length The size of the allocated buffer /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template. LLAMA_API int32_t llama_chat_apply_template( - const struct llama_model * model, const char * tmpl, const struct llama_chat_message * chat, size_t n_msg, @@ -977,6 +1056,9 @@ extern "C" { char * buf, int32_t length); + // Get list of built-in chat templates + LLAMA_API int32_t llama_chat_builtin_templates(const char ** output, size_t len); + // // Sampling API // @@ -1014,7 +1096,6 @@ extern "C" { // llama_sampler_free(smpl); // // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). - // TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab // typedef void * llama_sampler_context_t; @@ -1056,13 +1137,18 @@ extern "C" { LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i); LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain); + // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed + LLAMA_API struct llama_sampler * llama_sampler_chain_remove( struct llama_sampler * chain, int32_t i); + // available samplers: - LLAMA_API struct llama_sampler * llama_sampler_init_greedy (void); - LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed); + LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void); + LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed); /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. - LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void); + /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first. + DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void), + "will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)"); /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); @@ -1073,16 +1159,18 @@ extern "C" { /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep); - /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. - LLAMA_API struct llama_sampler * llama_sampler_init_tail_free (float z, size_t min_keep); - /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep); + + /// #details Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original value, the rest are set to -inf LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t); /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772. LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent); + /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 + LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed); + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. @@ -1107,35 +1195,81 @@ extern "C" { float eta); LLAMA_API struct llama_sampler * llama_sampler_init_grammar( - const struct llama_model * model, + const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root); + /// @details Lazy grammar sampler, introduced in https://github.com/ggerganov/llama.cpp/pull/9639 + /// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in a near future. + /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. + LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens); + + /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first. LLAMA_API struct llama_sampler * llama_sampler_init_penalties( - int32_t n_vocab, // llama_n_vocab() - llama_token special_eos_id, // llama_token_eos() - llama_token linefeed_id, // llama_token_nl() - int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) - float penalty_repeat, // 1.0 = disabled - float penalty_freq, // 0.0 = disabled - float penalty_present, // 0.0 = disabled - bool penalize_nl, // consider newlines as a repeatable token - bool ignore_eos); // ignore the end-of-sequence token + int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat, // 1.0 = disabled + float penalty_freq, // 0.0 = disabled + float penalty_present); // 0.0 = disabled + + /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982 + LLAMA_API struct llama_sampler * llama_sampler_init_dry( + const struct llama_vocab * vocab, + int32_t n_ctx_train, + float dry_multiplier, + float dry_base, + int32_t dry_allowed_length, + int32_t dry_penalty_last_n, + const char ** seq_breakers, + size_t num_breakers); LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( int32_t n_vocab, int32_t n_logit_bias, const llama_logit_bias * logit_bias); - // Shorthand for: + // this sampler is meant to be used for fill-in-the-middle infilling + // it's supposed to be used after top_k + top_p sampling // + // 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG + // 2. combine probs of tokens that have the same prefix + // + // example: + // + // - before: + // "hel": 0.5 + // "hell": 0.2 + // "hello": 0.1 + // "dummy": 0.1 + // + // - after: + // "hel": 0.8 + // "dummy": 0.1 + // + // 3. discard non-EOG tokens with low prob + // 4. if no tokens are left -> pick EOT + // + LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab); + + // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise + LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl); + + /// @details Sample and accept a token from the idx-th output of the last evaluation + // + // Shorthand for: // const auto * logits = llama_get_logits_ith(ctx, idx); // llama_token_data_array cur_p = { ... init from logits ... }; // llama_sampler_apply(smpl, &cur_p); - // return cur_p.data[cur_p.selected].id; - // - // At this point, this is mostly a convenience function. - // + // auto token = cur_p.data[cur_p.selected].id; + // llama_sampler_accept(smpl, token); + // return token; + // Returns the sampled token LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx); // TODO: extend in the future @@ -1168,15 +1302,30 @@ extern "C" { // NOTE: Used by llama.cpp examples, avoid using in third-party apps. Instead, do your own performance measurements. // - enum llama_perf_type { - LLAMA_PERF_TYPE_CONTEXT = 0, - LLAMA_PERF_TYPE_SAMPLER_CHAIN = 1, + struct llama_perf_context_data { + double t_start_ms; + double t_load_ms; + double t_p_eval_ms; + double t_eval_ms; + + int32_t n_p_eval; + int32_t n_eval; }; - LLAMA_API void llama_perf_print(const void * ctx, enum llama_perf_type type); - LLAMA_API void llama_perf_reset( void * ctx, enum llama_perf_type type); + struct llama_perf_sampler_data { + double t_sample_ms; - LLAMA_API void llama_perf_dump_yaml(FILE * stream, const struct llama_context * ctx); + int32_t n_sample; + }; + + LLAMA_API struct llama_perf_context_data llama_perf_context (const struct llama_context * ctx); + LLAMA_API void llama_perf_context_print(const struct llama_context * ctx); + LLAMA_API void llama_perf_context_reset( struct llama_context * ctx); + + // NOTE: the following work only with samplers constructed via llama_sampler_chain_init + LLAMA_API struct llama_perf_sampler_data llama_perf_sampler (const struct llama_sampler * chain); + LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain); + LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); #ifdef __cplusplus } diff --git a/media/llama-leader.jpeg b/media/llama-leader.jpeg deleted file mode 100644 index 0b4e6e1cf..000000000 Binary files a/media/llama-leader.jpeg and /dev/null differ diff --git a/models/ggml-vocab-chameleon.gguf.inp b/models/ggml-vocab-chameleon.gguf.inp new file mode 100644 index 000000000..9baf7d77a --- /dev/null +++ b/models/ggml-vocab-chameleon.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Führer +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-chameleon.gguf.out b/models/ggml-vocab-chameleon.gguf.out new file mode 100644 index 000000000..7c5413fee --- /dev/null +++ b/models/ggml-vocab-chameleon.gguf.out @@ -0,0 +1,46 @@ + 17245 16604 16403 16604 33583 18355 + 16421 51153 + + 16604 + 16650 + 16650 16604 + 16581 + 16582 + 16582 16582 + 16582 16582 16582 + 16581 16582 + 31596 17394 + 34926 17394 + 31596 18671 + 34926 18671 + 34926 18671 16384 + 31596 16395 17394 16384 + 34926 16395 17394 16384 + 16811 16704 20410 16483 16631 16397 52854 + 16470 16399 16403 16407 16604 16406 35764 38185 51595 22592 26639 + 29479 23955 17012 20103 25527 27670 17408 19005 21473 24774 + 54254 42231 48084 29409 16617 61889 29409 16608 21954 16628 21954 16499 58445 29409 16607 58445 21954 16479 42231 21954 16611 21954 16607 21954 16633 21954 16611 29409 16607 21954 16615 + 52351 16604 16391 25825 16392 23686 16498 39161 18885 16618 16488 30853 16604 16391 54124 17153 25134 16656 18476 26169 16895 16392 62193 16611 16604 16391 24664 17153 57169 16721 16872 17073 17304 28729 16392 + 31596 + 34926 + 16650 31596 + 16650 34926 + 16696 31596 + 16696 31596 16582 16696 31596 + 16604 16391 + 16582 16604 16412 + 16390 22623 + 31596 16395 16712 16390 16828 16384 17674 16769 16732 23686 16607 16604 16414 24427 16623 41809 16495 28999 36469 45292 30197 16400 16402 16400 16403 16400 16404 16400 43969 65211 16636 + 16384 16384 16384 16384 16384 16384 + 16402 + 16402 16402 + 16402 16402 16402 + 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 16402 16402 16402 + 16402 16402 16402 16402 16402 16402 16402 16402 16402 + 16418 19038 16639 16448 24315 33727 16467 + 18765 17981 + 16582 16604 16582 16582 16604 16582 16582 16582 16604 16581 16604 16581 16581 16604 16581 16582 16650 16582 16650 16604 16582 16696 16582 16696 16604 16582 52351 16604 16391 25825 16392 23686 16498 39161 18885 16618 16488 30853 16604 16391 54124 17153 25134 16656 18476 26169 16895 16392 62193 16611 20410 16483 16631 18885 16483 16631 16604 16402 16604 16402 16402 16604 16402 16402 16402 16604 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16402 16402 16604 16402 16397 16402 16604 16402 16397 16397 16402 16604 16402 16397 16397 16397 16402 16604 54254 42231 48084 29409 16617 61889 29409 16608 21954 16628 21954 16499 58445 29409 16607 58445 21954 16479 42231 21954 16611 27683 16607 16604 16414 24427 16623 41809 16495 28999 36469 45292 30197 16400 16402 16400 16403 16400 16404 16400 43969 65211 16636 16604 16396 16396 16396 16396 16396 16396 16412 16412 16412 16412 16412 16412 16412 27268 23955 17012 20103 25527 27670 17408 19005 21473 24774 16604 16390 16390 16390 16390 16390 16390 16447 16447 16447 16447 16447 16447 16447 16385 16385 16385 16385 16397 16397 16397 16397 16397 16397 16384 16384 16384 16384 16384 16384 16414 16414 16414 16414 16414 16414 16687 16390 16690 16992 16604 16390 61797 16733 16390 16466 16986 16395 16604 16390 17879 16732 17811 16414 16604 16390 16428 16804 17811 16687 16390 16683 17190 16728 16395 16604 16390 16419 16732 16945 16991 25251 16414 17119 16390 38127 16641 16390 16459 16427 diff --git a/models/ggml-vocab-deepseek-r1-qwen.gguf.inp b/models/ggml-vocab-deepseek-r1-qwen.gguf.inp new file mode 100644 index 000000000..9baf7d77a --- /dev/null +++ b/models/ggml-vocab-deepseek-r1-qwen.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Führer +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-deepseek-r1-qwen.gguf.out b/models/ggml-vocab-deepseek-r1-qwen.gguf.out new file mode 100644 index 000000000..18b4b45cd --- /dev/null +++ b/models/ggml-vocab-deepseek-r1-qwen.gguf.out @@ -0,0 +1,46 @@ + 1122 220 19 220 26062 3951 + 37 50753 261 + + 220 + 256 + 262 + 197 + 198 + 271 + 1406 + 1572 + 9707 1879 + 21927 1879 + 9707 4337 + 21927 4337 + 21927 4337 0 + 9707 11 1879 0 + 21927 11 1879 0 + 419 374 11162 99 247 13 10821 + 86 15 19 23 220 22 83 1963 41808 11472 2940 16739 + 78762 14144 1456 13073 63471 33594 3038 133178 79012 + 146394 97529 241 44258 233 146568 44258 224 147603 20879 115 146280 44258 223 146280 147272 97529 227 147805 148301 147270 44258 223 146848 + 145836 320 8252 8 26525 114 378 235 149921 30543 320 35673 99066 97534 8 25521 227 320 3243 42365 429 702 1181 1828 3950 8 + 9707 + 21927 + 220 21927 + 256 21927 + 262 21927 + 262 21927 198 262 21927 + 320 + 198 284 + 6 11385 + 9707 11 379 64848 0 2585 525 498 26525 223 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216 + 17085 2928 + 18 + 18 18 + 18 18 18 + 18 18 18 18 + 18 18 18 18 18 + 18 18 18 18 18 18 + 18 18 18 18 18 18 18 + 18 18 18 18 18 18 18 18 + 18 18 18 18 18 18 18 18 18 + 34 90063 128324 + 2560 2347 + 198 4710 14731 65497 7847 1572 2303 78672 10947 145836 320 8252 8 26525 114 378 235 149921 30543 320 35673 99066 97534 8 25521 227 11162 99 247 149955 220 18 220 18 18 220 18 18 18 220 18 18 18 18 220 18 18 18 18 18 220 18 18 18 18 18 18 220 18 18 18 18 18 18 18 220 18 18 18 18 18 18 18 18 220 18 13 18 220 18 496 18 220 18 1112 18 220 146394 97529 241 44258 233 146568 44258 224 147603 20879 115 146280 44258 223 146280 147272 97529 227 144534 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216 55460 53237 18658 14144 1456 13073 63471 33594 3038 133178 79012 3355 4605 4605 13874 13874 73594 3014 3014 28149 17085 2928 26610 7646 358 3003 1012 364 83 813 566 594 1052 11 364 787 498 2704 30 364 44 537 2704 358 3278 1281 432 11 364 35 498 1075 1045 15243 30 1205 6 42612 264 63866 43 diff --git a/models/ggml-vocab-roberta-bpe.gguf.inp b/models/ggml-vocab-roberta-bpe.gguf.inp new file mode 100644 index 000000000..9baf7d77a --- /dev/null +++ b/models/ggml-vocab-roberta-bpe.gguf.inp @@ -0,0 +1,112 @@ +ied 4 ½ months +__ggml_vocab_test__ +Führer +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + +__ggml_vocab_test__ + + +__ggml_vocab_test__ + + + +__ggml_vocab_test__ + + + + +__ggml_vocab_test__ + + +__ggml_vocab_test__ +Hello world +__ggml_vocab_test__ + Hello world +__ggml_vocab_test__ +Hello World +__ggml_vocab_test__ + Hello World +__ggml_vocab_test__ + Hello World! +__ggml_vocab_test__ +Hello, world! +__ggml_vocab_test__ + Hello, world! +__ggml_vocab_test__ + this is 🦙.cpp +__ggml_vocab_test__ +w048 7tuijk dsdfhu +__ggml_vocab_test__ +нещо на Български +__ggml_vocab_test__ +កាន់តែពិសេសអាចខលចេញ +__ggml_vocab_test__ +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) +__ggml_vocab_test__ +Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello +__ggml_vocab_test__ + Hello + Hello +__ggml_vocab_test__ + ( +__ggml_vocab_test__ + + = +__ggml_vocab_test__ +' era +__ggml_vocab_test__ +Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ +__ggml_vocab_test__ +!!!!!! +__ggml_vocab_test__ +3 +__ggml_vocab_test__ +33 +__ggml_vocab_test__ +333 +__ggml_vocab_test__ +3333 +__ggml_vocab_test__ +33333 +__ggml_vocab_test__ +333333 +__ggml_vocab_test__ +3333333 +__ggml_vocab_test__ +33333333 +__ggml_vocab_test__ +333333333 +__ggml_vocab_test__ +Cửa Việt +__ggml_vocab_test__ + discards +__ggml_vocab_test__ + + + + + + + + + + + +🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL +__ggml_vocab_test__ diff --git a/models/ggml-vocab-roberta-bpe.gguf.out b/models/ggml-vocab-roberta-bpe.gguf.out new file mode 100644 index 000000000..f181ac3dc --- /dev/null +++ b/models/ggml-vocab-roberta-bpe.gguf.out @@ -0,0 +1,46 @@ + 2550 204 18430 377 + 597 2768 298 8564 + + 1437 + 1437 1437 + 1437 1437 1437 + 50117 + 50118 + 50140 + 50140 50118 + 50117 50118 + 31414 232 + 20920 232 + 31414 623 + 20920 623 + 20920 623 328 + 31414 6 232 328 + 20920 6 232 328 + 42 16 8103 18164 27 4 49317 + 605 40976 262 10109 18474 385 29 36807 6455 + 36765 25482 22063 23171 34251 18697 10809 26161 18697 3602 22063 27969 40966 25417 15264 26161 24269 36709 41171 35328 + 1376 17772 7471 1376 17772 19002 1376 17772 9085 1376 4333 13859 1376 17772 9357 1376 4333 9264 1376 17772 25448 1376 17772 18400 1376 17772 4333 1376 4333 10172 1376 17772 4333 1376 17772 7258 1376 17772 19002 1376 17772 5782 1376 17772 10172 1376 17772 3726 1376 17772 5782 1376 4333 10172 1376 17772 23171 + 6569 15113 7471 36 21113 43 17841 19002 17 8384 6569 14285 4958 12605 36 34654 2841 4203 354 10146 26511 1070 43 36174 5782 36 8338 21554 14 34 63 308 19233 43 + 31414 + 20920 + 1437 20920 + 1437 1437 20920 + 1437 1437 1437 20920 + 1437 1437 1437 20920 50118 1437 1437 1437 20920 + 36 + 50118 5457 + 108 3567 + 31414 6 1423 108 1250 328 1336 32 47 17841 10172 17487 47876 3602 48617 15264 46537 11423 27326 48494 8210 49233 1558 1570 27761 49429 43251 10809 17772 + 32376 12846 + 246 + 3103 + 25631 + 46152 + 3103 25631 + 46152 3103 + 46152 25631 + 46152 46152 + 46152 3103 25631 + 347 1376 2023 12410 102 16376 1376 2023 6382 90 + 9553 5954 + 50118 1437 50140 1437 50140 50118 1437 50117 1437 50117 50117 1437 50117 50118 1437 1437 50118 1437 1437 1437 50118 1437 1437 1437 1437 50118 1437 1437 1437 1437 1437 50118 6569 15113 7471 36 21113 43 17841 19002 17 8384 6569 14285 4958 12605 36 34654 2841 4203 354 10146 26511 1070 43 36174 5782 8103 18164 27 6569 18164 27 155 2357 30242 155 25631 30242 3103 30242 25631 30242 46152 30242 3103 25631 155 4 246 155 7586 246 155 734 246 25974 17772 7471 1376 17772 19002 1376 17772 9085 1376 4333 13859 1376 17772 9357 1376 4333 9264 1376 17772 25448 1376 17772 18400 1376 17772 4333 1376 4333 10172 1376 17772 4333 1376 17772 7258 1376 17772 19002 1376 17772 5782 18636 10172 17487 47876 3602 48617 15264 46537 11423 27326 48494 8210 49233 1558 1570 27761 49429 43251 10809 17772 36738 48332 47463 18697 10809 25482 22063 23171 34251 18697 10809 26161 18697 3602 22063 27969 40966 25417 15264 26161 24269 36709 41171 35328 128 49690 108 49972 49519 12905 48149 48149 43796 32376 12846 27282 28749 38 348 57 128 41042 37 18 89 6 128 4629 47 686 116 128 448 45 686 38 581 146 24 6 128 495 47 101 103 6845 116 166 108 30660 10 108 462 574 diff --git a/models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja b/models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja new file mode 100644 index 000000000..f5baef30b --- /dev/null +++ b/models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja @@ -0,0 +1,202 @@ + +{%- macro json_to_python_type(json_spec) %} +{%- set basic_type_map = { + "string": "str", + "number": "float", + "integer": "int", + "boolean": "bool" +} %} + +{%- if basic_type_map[json_spec.type] is defined %} + {{- basic_type_map[json_spec.type] }} +{%- elif json_spec.type == "array" %} + {{- "List[" + json_to_python_type(json_spec.items) + "]"}} +{%- elif json_spec.type == "object" %} + {{- "Dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']'}} +{%- elif json_spec.type is iterable %} + {{- "Union[" }} + {%- for t in json_spec.type %} + {{- json_to_python_type({"type": t}) }} + {%- if not loop.last %} + {{- "," }} + {%- endif %} + {%- endfor %} + {{- "]" }} +{%- else %} + {{- "Any" }} +{%- endif %} +{%- endmacro %} + +{%- macro old_tool_parser(tools) %} +{%- for tool in tools %} + {%- if loop.index0 != 0 %} + {{- '\n\n' }} + {%- endif %} + {{- '```python\ndef ' + tool.name + '(' }} + {%- for param_name, param_fields in tool.parameter_definitions|items %} + {%- if loop.index0 != 0 %} + {{- ', '}} + {%- endif %} + {{- param_name + ': ' }} + {%- if not param_fields.required %} + {{- 'Optional[' + param_fields.type + '] = None'}} + {%- else %} + {{- param_fields.type }} + {%- endif %} + {%- endfor %} + {{- ') -> List[Dict]:\n """'}} + {{- tool.description }} + {%- if tool.parameter_definitions|length != 0 %} + {{- '\n\n Args:\n '}} + {%- for param_name, param_fields in tool.parameter_definitions|items %} + {%- if loop.index0 != 0 %} + {{- '\n ' }} + {%- endif %} + {{- param_name + ' ('}} + {%- if not param_fields.required %} + {{- 'Optional[' + param_fields.type + ']'}} + {%- else %} + {{- param_fields.type }} + {%- endif %} + {{- '): ' + param_fields.description }} + {%- endfor %} + {%- endif %} + {{- '\n """\n pass\n```' }} +{%- endfor %} +{%- endmacro %} + +{%- macro new_tool_parser(tools) %} +{%- for tool in tools %} + {%- if loop.index0 != 0 %} + {{- '\n\n'}} + {%- endif %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{-'```python +def ' + tool.name + '('}} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {%- if loop.index0 != 0 %} + {{- ', '}} + {%- endif %} + {{-param_name + ": "}} + {%- if not param_name in tool.parameters.required %} + {{-'Optional[' + json_to_python_type(param_fields) + '] = None'}} + {%- else %} + {{- json_to_python_type(param_fields) }} + {%- endif %} + {%- endfor %} + {{- ') -> List[Dict]: + """'}} + {{- tool.description }} + {%- if tool.parameters.properties|length != 0 %} + {{- '\n\n Args:\n '}} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {%- if loop.index0 != 0 %} + {{- '\n ' }} + {%- endif %} + {{- param_name + ' ('}} + {%- if not param_name in tool.parameters.required %} + {{-'Optional[' + json_to_python_type(param_fields) + ']'}} + {%- else %} + {{- json_to_python_type(param_fields) }} + {%- endif %} + {{- '): ' + param_fields.description }} + {%- endfor %} + {%- endif %} + {{- '\n """\n pass\n```' }} +{%- endfor %} +{%- endmacro %} + +{{- bos_token }} +{%- if messages[0]['role'] == 'system' %} + {%- set loop_messages = messages[1:] %} + {%- set system_message = messages[0]['content'] %} +{%- else %} + {%- set loop_messages = messages %} + {%- set system_message = '## Task and Context\nYou help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You should focus on serving the user\'s needs as best you can, which will be wide-ranging.\n\n## Style Guide\nUnless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling.' %} +{%- endif %} +{{- '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' }} +{{- '# Safety Preamble' }} +{{- ' +The instructions in this section override those in the task description and style guide sections. Don\'t answer questions that are harmful or immoral.' }} +{{- ' + +# System Preamble' }} +{{- ' +## Basic Rules' }} +{{- ' +You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user\'s requests, you cite your sources in your answers, according to those instructions.' }} +{{- ' + +# User Preamble' }} +{{- ' +' + system_message }} +{{-' + +## Available Tools +Here is a list of tools that you have available to you: + +'}} +{%- set ns = namespace(new_tools=true) %} +{%- for tool in tools %} + {%- if tool.parameter_definitions is defined %} + {%- set ns.new_tools = false %} + {%- endif %} +{%- endfor %} +{%- if ns.new_tools %} + {{- new_tool_parser(tools) }} +{%- else %} + {{- old_tool_parser(tools) }} +{%- endif %} +{{- '<|END_OF_TURN_TOKEN|>'}} +{%- for message in loop_messages %} + {%- set content = message['content'] %} + {%- if message.role == 'user' %} + {{- '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content|trim + '<|END_OF_TURN_TOKEN|>' }} + {%- elif message.role == 'system' %} + {{- '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + content|trim + '<|END_OF_TURN_TOKEN|>' }} + {%- elif message.role == 'assistant' and message.tool_calls is defined %} + {{- '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }} + {%- if message.content is defined %} + {{- message.content|trim }} + {%- endif %} + {{- '\nAction:\n```json\n[\n' }} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '{\n'|indent(4, first=true) }} + {{- '"tool_name": "'|indent(8, first=true) + tool_call.name + '",\n' }} + {{- '"parameters": '|indent(8, first=true) }} + {%- if tool_call.arguments is defined and tool_call.arguments|length > 0 %} + {{- tool_call.arguments|tojson(indent=4)|indent(8) }} + {{- '\n' }} + {%- else %} + {{- '{}\n' }} + {%- endif %} + {{- '}'|indent(4, first=true) }} + {%- if not loop.last %} + {{- ',\n' }} + {%- endif %} + {%- endfor %} + {{- "\n]```\n" }} + {%- elif message.role == 'assistant' %} + {{- '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content|trim + '<|END_OF_TURN_TOKEN|>' }} + {%- elif message.role == 'tool' %} + {{- '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>\n' }} + {{- message.content|trim }} + {{- '<|END_OF_TURN_TOKEN|>' }} + {%- endif %} +{%- endfor %} +{{-'<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write \'Action:\' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user\'s last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example: +```json +[ + { + "tool_name": title of the tool in the specification, + "parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters + } +]```<|END_OF_TURN_TOKEN|>'}} +{%- if add_generation_prompt %} + {{- '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }} +{%- endif %} diff --git a/models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja b/models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja new file mode 100644 index 000000000..149250bd5 --- /dev/null +++ b/models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja @@ -0,0 +1,152 @@ +{%- macro json_to_python_type(json_spec) %} +{%- set basic_type_map = { + "string": "str", + "number": "float", + "integer": "int", + "boolean": "bool" +} %} + +{%- if basic_type_map[json_spec.type] is defined %} + {{- basic_type_map[json_spec.type] }} +{%- elif json_spec.type == "array" %} + {{- "list[" + json_to_python_type(json_spec|items) + "]"}} +{%- elif json_spec.type == "object" %} + {%- if json_spec.additionalProperties is defined %} + {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']'}} + {%- else %} + {{- "dict" }} + {%- endif %} +{%- elif json_spec.type is iterable %} + {{- "Union[" }} + {%- for t in json_spec.type %} + {{- json_to_python_type({"type": t}) }} + {%- if not loop.last %} + {{- "," }} + {%- endif %} + {%- endfor %} + {{- "]" }} +{%- else %} + {{- "Any" }} +{%- endif %} +{%- endmacro %} + + +{{- bos_token }} +{{- '<|im_start|>system +' }} +{{- "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }} +{%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- '{"type": "function", "function": ' }} + {{- '{"name": "' + tool.name + '", ' }} + {{- '"description": "' + tool.name + '(' }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- param_name + ": " + json_to_python_type(param_fields) }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- ")" }} + {%- if tool.return is defined %} + {{- " -> " + json_to_python_type(tool.return) }} + {%- endif %} + {{- " - " + tool.description + " + +" }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {%- if loop.first %} + {{- " Args: +" }} + {%- endif %} + {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }} + {%- endfor %} + {%- if tool.return is defined and tool.return.description is defined %} + {{- " + Returns: + " + tool.return.description }} + {%- endif %} + {{- '"' }} + {{- ', "parameters": ' }} + {%- if tool.parameters.properties | length == 0 %} + {{- "{}" }} + {%- else %} + {{- tool.parameters|tojson }} + {%- endif %} + {{- "}" }} + {%- if not loop.last %} + {{- " +" }} + {%- endif %} +{%- endfor %} +{{- " " }} +{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} +' }} +{{- "For each function call return a json object with function name and arguments within XML tags as follows: +" }} +{{- " +" }} +{{- '{"name": , "arguments": } +' }} +{{- '<|im_end|> +' }} +{%- for message in messages %} + {%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %} + {{- '<|im_start|>' + message.role + ' +' + message.content + '<|im_end|>' + ' +' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role }} + {%- for tool_call in message.tool_calls %} + {{- ' + +' }} {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '{' }} + {{- '"name": "' }} + {{- tool_call.name }} + {{- '"' }} + {{- ', '}} + {%- if tool_call.arguments is defined %} + {{- '"arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments|tojson }} + {%- endif %} + {%- endif %} + {{- '}' }} + {{- ' +' }} + {%- endfor %} + {{- '<|im_end|> +' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>tool +' }} + {%- endif %} + {{- ' +' }} + {{- message.content }} + {%- if not loop.last %} + {{- ' + +' }} + {%- else %} + {{- ' +' }} + {%- endif %} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>' }} + {%- elif loop.last %} + {{- '<|im_end|>' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant +' }} +{%- endif %} diff --git a/models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja b/models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja new file mode 100644 index 000000000..149250bd5 --- /dev/null +++ b/models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja @@ -0,0 +1,152 @@ +{%- macro json_to_python_type(json_spec) %} +{%- set basic_type_map = { + "string": "str", + "number": "float", + "integer": "int", + "boolean": "bool" +} %} + +{%- if basic_type_map[json_spec.type] is defined %} + {{- basic_type_map[json_spec.type] }} +{%- elif json_spec.type == "array" %} + {{- "list[" + json_to_python_type(json_spec|items) + "]"}} +{%- elif json_spec.type == "object" %} + {%- if json_spec.additionalProperties is defined %} + {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']'}} + {%- else %} + {{- "dict" }} + {%- endif %} +{%- elif json_spec.type is iterable %} + {{- "Union[" }} + {%- for t in json_spec.type %} + {{- json_to_python_type({"type": t}) }} + {%- if not loop.last %} + {{- "," }} + {%- endif %} + {%- endfor %} + {{- "]" }} +{%- else %} + {{- "Any" }} +{%- endif %} +{%- endmacro %} + + +{{- bos_token }} +{{- '<|im_start|>system +' }} +{{- "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }} +{%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- '{"type": "function", "function": ' }} + {{- '{"name": "' + tool.name + '", ' }} + {{- '"description": "' + tool.name + '(' }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- param_name + ": " + json_to_python_type(param_fields) }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- ")" }} + {%- if tool.return is defined %} + {{- " -> " + json_to_python_type(tool.return) }} + {%- endif %} + {{- " - " + tool.description + " + +" }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {%- if loop.first %} + {{- " Args: +" }} + {%- endif %} + {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }} + {%- endfor %} + {%- if tool.return is defined and tool.return.description is defined %} + {{- " + Returns: + " + tool.return.description }} + {%- endif %} + {{- '"' }} + {{- ', "parameters": ' }} + {%- if tool.parameters.properties | length == 0 %} + {{- "{}" }} + {%- else %} + {{- tool.parameters|tojson }} + {%- endif %} + {{- "}" }} + {%- if not loop.last %} + {{- " +" }} + {%- endif %} +{%- endfor %} +{{- " " }} +{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} +' }} +{{- "For each function call return a json object with function name and arguments within XML tags as follows: +" }} +{{- " +" }} +{{- '{"name": , "arguments": } +' }} +{{- '<|im_end|> +' }} +{%- for message in messages %} + {%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %} + {{- '<|im_start|>' + message.role + ' +' + message.content + '<|im_end|>' + ' +' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role }} + {%- for tool_call in message.tool_calls %} + {{- ' + +' }} {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '{' }} + {{- '"name": "' }} + {{- tool_call.name }} + {{- '"' }} + {{- ', '}} + {%- if tool_call.arguments is defined %} + {{- '"arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments|tojson }} + {%- endif %} + {%- endif %} + {{- '}' }} + {{- ' +' }} + {%- endfor %} + {{- '<|im_end|> +' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>tool +' }} + {%- endif %} + {{- ' +' }} + {{- message.content }} + {%- if not loop.last %} + {{- ' + +' }} + {%- else %} + {{- ' +' }} + {%- endif %} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>' }} + {%- elif loop.last %} + {{- '<|im_end|>' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant +' }} +{%- endif %} diff --git a/models/templates/Qwen-Qwen2.5-7B-Instruct.jinja b/models/templates/Qwen-Qwen2.5-7B-Instruct.jinja new file mode 100644 index 000000000..bdf7919a9 --- /dev/null +++ b/models/templates/Qwen-Qwen2.5-7B-Instruct.jinja @@ -0,0 +1,54 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0]['role'] == 'system' %} + {{- messages[0]['content'] }} + {%- else %} + {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }} + {%- endif %} + {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0]['role'] == 'system' %} + {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }} + {%- else %} + {{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {{- '<|im_start|>' + message.role }} + {%- if message.content %} + {{- '\n' + message.content }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {{- tool_call.arguments | tojson }} + {{- '}\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja new file mode 100644 index 000000000..02a1c3bce --- /dev/null +++ b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja @@ -0,0 +1 @@ +{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %} \ No newline at end of file diff --git a/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja new file mode 100644 index 000000000..2ebfe7c1e --- /dev/null +++ b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja @@ -0,0 +1,56 @@ +{% if not add_generation_prompt is defined %} +{% set add_generation_prompt = false %} +{% endif %} +{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %} +{%- for message in messages %} +{%- if message['role'] == 'system' %} +{% set ns.system_prompt = message['content'] %} +{%- endif %} +{%- endfor %} +{{bos_token}} +{{ns.system_prompt}} +{%- for message in messages %} +{%- if message['role'] == 'user' %} +{%- set ns.is_tool = false -%} +{{'<|User|>' + message['content']}} +{%- endif %} +{%- if message['role'] == 'assistant' and message['content'] is none %} +{%- set ns.is_tool = false -%} +{%- for tool in message['tool_calls']%} +{%- if not ns.is_first %} +{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}} +{%- set ns.is_first = true -%} +{%- else %} +{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}} +{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} +{%- endif %} +{%- endfor %} +{%- endif %} +{%- if message['role'] == 'assistant' and message['content'] is not none %} +{%- if ns.is_tool %} +{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}} +{%- set ns.is_tool = false -%} +{%- else %} +{% set content = message['content'] %} +{% if '' in content %} +{% set content = content.split('')[-1] %} +{% endif %} +{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}} +{%- endif %} +{%- endif %} +{%- if message['role'] == 'tool' %} +{%- set ns.is_tool = true -%} +{%- if ns.is_output_first %} +{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} +{%- set ns.is_output_first = false %} +{%- else %} +{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} +{%- endif %} +{%- endif %} +{%- endfor -%} +{% if ns.is_tool %} +{{'<|tool▁outputs▁end|>'}} +{% endif %} +{% if add_generation_prompt and not ns.is_tool %} +{{'<|Assistant|>'}} +{% endif %} \ No newline at end of file diff --git a/models/templates/fireworks-ai-llama-3-firefunction-v2.jinja b/models/templates/fireworks-ai-llama-3-firefunction-v2.jinja new file mode 100644 index 000000000..9b8136df7 --- /dev/null +++ b/models/templates/fireworks-ai-llama-3-firefunction-v2.jinja @@ -0,0 +1,57 @@ +{%- set loop_messages = messages -%} +{%- set message_roles = ['system', 'user', 'assistant', 'tool'] -%} +{%- set system_prompt_suffix -%} +{%- filter trim -%} +In addition to plain text responses, you can chose to call one or more of the provided functions. + +Use the following rule to decide when to call a function: + * if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so + * if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls + +If you decide to call functions: + * prefix function calls with functools marker (no closing marker required) + * all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...] + * follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples + * respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0 + * make sure you pick the right functions that match the user intent + +Available functions as JSON spec: +{%- endfilter -%} +{%- endset -%} +{%- set system_prompt_suffix = system_prompt_suffix + "\n" + functions -%} +{%- set system_prompt_suffix = system_prompt_suffix + '\nToday is ' + datetime + '.' -%} +{%- set ns = namespace(role='', content='') -%} +{#- Basic consistency checks -#} +{%- if not loop_messages -%} + {{ raise_exception('Expected non-empty messages') }} +{%- endif -%} +{%- for message in loop_messages -%} + {%- set ns.role = message['role'] | lower -%} + {%- if ns.role not in message_roles -%} + {%- set message_roles_string = message_roles | join(', ') -%} + {{ raise_exception('Invalid role ' + message['role'] + '. Only ' + message_roles_string + ' are supported.') }} + {%- endif -%} + {%- set msg_content = message['content'] | default('', true) | trim -%} + {%- if loop.index0 == 0 -%} + {%- if ns.role == 'system' -%} + {%- set system_prompt = '<|start_header_id|>' + 'system' + '<|end_header_id|>\n\n' + message['content'] | trim + '\n' + system_prompt_suffix + '<|eot_id|>' -%} + {%- else -%} + {%- set system_prompt = '<|start_header_id|>' + 'system' + '<|end_header_id|>\n\nYou are a helpful assistant with access to functions.\n' + system_prompt_suffix + '<|eot_id|>' -%} + {%- endif -%} + {%- set ns.content = bos_token + system_prompt -%} + {{- ns.content -}} + {%- endif -%} + {%- if loop.index0 > 0 or ns.role != 'system' -%} + {%- set ns.content = '<|start_header_id|>' + ns.role + '<|end_header_id|>\n\n' + msg_content -%} + {%- if 'tool_calls' in message and message['tool_calls'] -%} + {%- set tool = namespace(calls=[]) -%} + {%- for call in message['tool_calls'] -%} + {%- set tool.calls = tool.calls + ['{"name": "' + call['function']['name'] + '", "arguments": ' + call['function']['arguments'] + '}'] -%} + {%- endfor -%} + {%- set ns.content = ns.content + ' functools[' + tool.calls | join(', ') + ']' -%} + {%- endif -%} + {%- set ns.content = ns.content + '<|eot_id|>' -%} + {{- ns.content -}} + {%- endif -%} +{%- endfor -%} +{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} diff --git a/models/templates/google-gemma-2-2b-it.jinja b/models/templates/google-gemma-2-2b-it.jinja new file mode 100644 index 000000000..923ec253c --- /dev/null +++ b/models/templates/google-gemma-2-2b-it.jinja @@ -0,0 +1,4 @@ +{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + ' +' + message['content'] | trim + ' +' }}{% endfor %}{% if add_generation_prompt %}{{'model +'}}{% endif %} \ No newline at end of file diff --git a/models/templates/meetkai-functionary-medium-v3.1.jinja b/models/templates/meetkai-functionary-medium-v3.1.jinja new file mode 100644 index 000000000..29d64a215 --- /dev/null +++ b/models/templates/meetkai-functionary-medium-v3.1.jinja @@ -0,0 +1,58 @@ +{# version=v3-llama3.1 #}{%- if not tools is defined -%} + {%- set tools = none -%} +{%- endif -%} + +{%- set has_code_interpreter = tools | selectattr("type", "equalto", "code_interpreter") | list | length > 0 -%} +{%- if has_code_interpreter -%} + {%- set tools = tools | rejectattr("type", "equalto", "code_interpreter") | list -%} +{%- endif -%} + +{#- System message + builtin tools #} +{{- bos_token + "<|start_header_id|>system<|end_header_id|>\n\n" }} +{%- if has_code_interpreter %} + {{- "Environment: ipython\n\n" }} +{%- else -%} + {{ "\n"}} +{%- endif %} +{{- "Cutting Knowledge Date: December 2023\n\n" }} +{%- if tools %} + {{- "\nYou have access to the following functions:\n\n" }} + {%- for t in tools %} + {%- if "type" in t -%} + {{ "Use the function '"|safe + t["function"]["name"] + "' to '"|safe + t["function"]["description"] + "'\n"|safe + t["function"] | tojson() }} + {%- else -%} + {{ "Use the function '"|safe + t["name"] + "' to '"|safe + t["description"] + "'\n"|safe + t | tojson() }} + {%- endif -%} + {{- "\n\n" }} + {%- endfor %} + {{- '\nThink very carefully before calling functions.\nIf a you choose to call a function ONLY reply in the following format:\n<{start_tag}={function_name}>{parameters}{end_tag}\nwhere\n\nstart_tag => ` a JSON dict with the function argument name as key and function argument value as value.\nend_tag => ``\n\nHere is an example,\n{"example_name": "example_value"}\n\nReminder:\n- If looking for real time information use relevant functions before falling back to brave_search\n- Function calls MUST follow the specified format, start with \n- Required parameters MUST be specified\n- Only call one function at a time\n- Put the entire function call reply on one line\n\n' -}} +{%- endif %} +{{- "<|eot_id|>" -}} + +{%- for message in messages -%} + {%- if message['role'] == 'user' or message['role'] == 'system' -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }} + {%- elif message['role'] == 'tool' -%} + {{ '<|start_header_id|>ipython<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }} + {%- else -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}} + {%- if message['content'] -%} + {{ message['content'] }} + {%- endif -%} + {%- if 'tool_calls' in message and message['tool_calls'] -%} + {%- for tool_call in message['tool_calls'] -%} + {%- if tool_call["function"]["name"] == "python" -%} + {{ '<|python_tag|>' + tool_call['function']['arguments'] }} + {%- else -%} + {{ '' + tool_call['function']['arguments'] + '' }} + {%- endif -%} + {%- endfor -%} + {{ '<|eom_id|>' }} + {%- else -%} + {{ '<|eot_id|>' }} + {%- endif -%} + {%- endif -%} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif -%} \ No newline at end of file diff --git a/models/templates/meetkai-functionary-medium-v3.2.jinja b/models/templates/meetkai-functionary-medium-v3.2.jinja new file mode 100644 index 000000000..74fd1e7af --- /dev/null +++ b/models/templates/meetkai-functionary-medium-v3.2.jinja @@ -0,0 +1,287 @@ +{# version=v3.llama3 #}{%- macro append_new_param_info(param_declaration, comment_info, examples_info, depth) -%} + {%- set offset = "" -%} + {%- if depth >= 1 -%} + {%- set offset = " " * depth -%} + {%- endif -%} + {%- if comment_info != "<|NONE|>" -%} + {{ "\n" + offset + comment_info }} + {%- if examples_info | length > 0 -%} + {# Append each example info #} + {%- for example in examples_info -%} + {{ "\n" + offset + "// " + example|string|replace("'", '"') }} + {%- endfor -%} + {%- endif -%} + {%- endif -%} + {{ "\n" + offset + param_declaration }} +{%- endmacro -%} + +{%- macro convert_data_type(param_type) -%} + {%- if param_type == "integer" or param_type == "float" -%} + {{ "number" }} + {%- else -%} + {{ param_type }} + {%- endif -%} +{%- endmacro -%} + +{%- macro get_param_type(param) -%} + {%- set param_type = "any" -%} + + {%- if "type" in param -%} + {%- set raw_param_type = param["type"] -%} + {%- if raw_param_type is iterable and raw_param_type is not string -%} + {%- set param_type = raw_param_type | join(" | ") -%} + {%- else -%} + {%- set param_type = raw_param_type -%} + {%- endif -%} + {{ convert_data_type(param_type) }} + {%- elif "oneOf" in param -%} + {%- set one_of_types = param["oneOf"]|selectattr("type", "defined")|list -%} + {%- set one_of_types = one_of_types|map(attribute="type")|unique|list -%} + {{ convert_data_type(one_of_types | join(" | ")) }} + {%- endif -%} +{%- endmacro -%} + +{%- macro get_format_param(param) -%} + {%- if "format" in param -%} + {{ param["format"] }} + {%- elif "oneOf" in param -%} + {%- set formats = [] -%} + {%- for item in param["oneOf"] -%} + {%- if "format" in item -%} + {%- if item["format"] == param["oneOf"][-1]["format"] -%} + {{ item["format"] }} + {%- else -%} + {{ item["format"] + " or "}} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ "<|NONE|>" }} + {%- endif -%} +{%- endmacro -%} + +{%- macro get_param_info(param) -%} + {%- set param_type = param.get("type", "any") -%} + {%- set format_param = get_format_param(param) -%} + + {%- if "description" in param or "default" in param or format_param != "<|NONE|>" or param["maximum"] or param["minimum"] or param["maxLength"] or param["minLength"] -%} + {{ "//" }} + {%- if "description" in param -%} + {%- set desc = param["description"] -%} + {%- if not desc.endswith(".") -%} + {%- set desc = desc + "." -%} + {%- endif -%} + {{ " " + desc }} + {%- endif -%} + + {%- if "default" in param -%} + {%- set default_value = param["default"] -%} + {%- if param_type == "string" -%} + {%- set default_value = '"' ~ default_value ~ '"' -%} + {%- endif -%} + {{ " Default=" ~ default_value ~ "." }} + {%- endif -%} + + {%- set format_param = get_format_param(param) -%} + {%- if format_param != "<|NONE|>" -%} + {{ " Format=" ~ format_param }} + {%- endif -%} + + {%- for field, field_name in [("maximum", "Maximum"), ("minimum", "Minimum"), ("maxLength", "Maximum length"), ("minLength", "Minimum length")] -%} + {%- if field in param -%} + {{ " " + field_name ~ "=" ~ param[field] }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ "<|NONE|>"}} + {%- endif -%} +{%- endmacro -%} + +{%- macro get_enum_option_str(enum_options) -%} + {%- for v in enum_options -%} + {%- if v is string -%} + {{ '"' + v + '"' }} + {%- else -%} + {{ v }} + {%- endif -%} + {%- if enum_options|length > 0 and v != enum_options[-1] -%} + {{ " | " }} + {%- endif -%} + {%- endfor -%} +{%- endmacro -%} + +{%- macro get_array_typescript(param_name, param_dic, depth) -%} + {%- set offset = '' -%} + {%- if depth >= 1 -%} + {%- set offset = " " * depth -%} + {%- endif -%} + {%- set items_info = param_dic.get('items', {}) -%} + + {%- if items_info|length == 0 -%} + {%- if param_name -%} + {{ "\n" + offset + param_name + ": []" }} + {%- else -%} + {{ "\n" + offset + "[]" }} + {%- endif -%} + {%- else -%} + {%- set array_type = get_param_type(items_info) -%} + {%- if array_type == 'object' -%} + {%- if param_name -%} + {{ "\n" + offset + param_name + ": {" }} + {%- else -%} + {{ "\n" + offset + "{" }} + {%- endif -%} + {{ get_parameter_typescript(items_info.get('properties', {}), items_info.get('required', []), depth + 1) -}} + {{- "\n" + offset + "}[]" }} + {%- elif array_type == 'array' -%} + {%- set item_info = get_array_typescript(None, items_info, depth + 1) -%} + {%- if not param_name -%} + {{ "\n" + item_info + "[]" }} + {%- else -%} + {{ "\n" + offset + param_name + ": " + item_info|trim + "[]" }} + {%- endif -%} + {%- else -%} + {%- if 'enum' in items_info -%} + {%- set item_type = get_enum_option_str(items_info['enum']) -%} + {%- if param_name is none -%} + {{ "(" + item_type + ")[]"}} + {%- else -%} + {{ "\n" + offset + param_name + ": (" + item_type + ")[]" }} + {%- endif -%} + {%- else -%} + {%- if param_name is none -%} + {{ "\n" + array_type + "[]" }} + {%- else -%} + {{ "\n" + offset + param_name + ": " + array_type + "[]," }} + {%- endif -%} + {%- endif -%} + {%- endif -%} + {%- endif -%} +{%- endmacro -%} + +{%- macro get_parameter_typescript(properties, required_params, depth=0) -%} + {%- set res = "" -%} + {%- for param_name, param in properties.items() -%} + {%- if param is mapping -%} + {%- set comment_info = get_param_info(param) -%} + {# Param Examples #} + {%- set examples_info = [] -%} + {%- if "examples" in param -%} + {%- set examples_info = ["Example " + param_name + ":"] -%} + {%- set examples_info = examples_info + param["examples"] -%} + {%- endif -%} + + {# Param Name declaration #} + {%- set param_declaration = param_name -%} + {%- if required_params is iterable and param_name not in required_params -%} + {%- set param_declaration = param_declaration + "?" -%} + {%- endif -%} + + {%- set param_type = get_param_type(param) -%} + + {# Handle indentation based on depth #} + {%- set offset = "" -%} + {%- if depth >= 1 -%} + {%- set offset = " " * depth -%} + {%- endif -%} + + {%- if param_type == "object" -%} + {%- if comment_info != "<|NONE|>" -%} + {{ "\n" + offset + comment_info }} + {%- endif -%} + {%- if examples_info|length > 0 -%} + {%- for example in examples_info -%} + {{ "\n" + offset + "// " + example|string|replace("'", '"') }} + {%- endfor -%} + {%- endif -%} + {%- set param_declaration = param_declaration + ": {" -%} + {{ "\n" + offset + param_declaration -}} + {{- get_parameter_typescript(param.get("properties", {}), param.get("required", []), depth + 1) -}} + {{- "\n" + offset + "}," }} + {%- elif param_type == "array" -%} + {%- set item_info = param.get("items", {}) -%} + {%- if "type" not in item_info -%} + {%- set param_declaration = param_declaration + ": []," -%} + {{ append_new_param_info(param_declaration, comment_info, examples_info, depth) }} + {%- else -%} + {%- if comment_info != "<|NONE|>" -%} + {{ "\n" + offset + comment_info }} + {%- endif -%} + {%- if examples_info|length > 0 -%} + {%- for example in examples_info -%} + {{ "\n" + offset + "// " + example|string|replace("'", '"') }} + {%- endfor -%} + {%- endif -%} + {%- set array_declaration = get_array_typescript(param_declaration, param, depth) -%} + {%- if not array_declaration.endswith(",") -%} + {%- set array_declaration = array_declaration + "," -%} + {%- endif -%} + {{ array_declaration}} + {%- endif -%} + {%- else -%} + {%- if "enum" in param -%} + {%- set param_type = get_enum_option_str(param["enum"]) -%} + {%- endif -%} + {%- if "nullable" in param and param["nullable"] -%} + {%- set param_type = param_type + " | null" -%} + {%- endif -%} + {%- set param_declaration = param_declaration + ": " + param_type + "," -%} + {{ append_new_param_info(param_declaration, comment_info, examples_info, depth) }} + {%- endif -%} + {%- endif -%} + {%- endfor -%} +{%- endmacro -%} + +{%- macro generate_schema_from_functions(functions, namespace='functions') -%} + {{ "// Supported function definitions that should be called when necessary.\n" -}} + {{- "namespace " + namespace + " {\n\n" -}} + + {%- for function in functions -%} + {%- if function.get("function") -%} + {%- set function = function.get("function") -%} + {%- endif -%} + + {%- set function_name = function.get("name") -%} + {%- if function_name -%} + {%- set description = function.get('description', '') -%} + {%- set parameters = function.get('parameters', {}) -%} + {{- "// " + description + "\n" -}} + {{- "type " + function_name -}} + {%- if parameters and parameters.get("properties") -%} + {{- " = (_: {" -}} + {%- set required_params = parameters.get("required", []) -%} + {{ get_parameter_typescript(parameters.get("properties"), required_params, 0) -}} + {{- "\n}) => any;\n\n" }} + {%- else -%} + {{ " = () => any;\n\n" }} + {%- endif -%} + {%- endif -%} + {%- endfor -%} + {{ "} // namespace " + namespace }} +{%- endmacro -%} +{%- if not tools -%} + {%- set tools = [] -%} +{%- endif -%} +{{ bos_token + '<|start_header_id|>system<|end_header_id|>\n\nYou are capable of executing available function(s) if required.\nOnly execute function(s) when absolutely necessary.\nAsk for the required input to:recipient==all\nUse JSON for function arguments.\nRespond in this format:\n>>>${recipient}\n${content}\nAvailable functions:\n' + generate_schema_from_functions(tools) + '<|eot_id|>' -}} +{%- if tools|length > 0 and tools|selectattr("type", "equalto", "code_interpreter")|list|length > 0 -%} + {{ '<|start_header_id|>system<|end_header_id|>\n\nWhen you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at \'/mnt/data\' can be used to save and persist user files.<|eot_id|>' }} +{%- endif -%} +{%- for message in messages -%} + {%- if message['role'] == 'user' or message['role'] == 'system' -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }} + {%- elif message['role'] == 'tool' -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }} + {%- else -%} + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}} + {%- if message['content'] -%} + {{ '>>>all\n' + message['content'] }} + {%- endif -%} + {%- if 'tool_calls' in message and message['tool_calls'] -%} + {%- for tool_call in message['tool_calls'] -%} + {{ '>>>' + tool_call['function']['name'] + '\n' + tool_call['function']['arguments'] }} + {%- endfor -%} + {%- endif -%} + {{ '<|eot_id|>' }} + {%- endif -%} +{%- endfor -%} +{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n>>>' }}{% endif %} \ No newline at end of file diff --git a/models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja b/models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja new file mode 100644 index 000000000..33089ace1 --- /dev/null +++ b/models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja @@ -0,0 +1,109 @@ +{{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {%- set tools_in_user_message = true %} +{%- endif %} +{%- if not date_string is defined %} + {%- set date_string = "26 Jul 2024" %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "" %} +{%- endif %} + +{#- System message + builtin tools #} +{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} +{%- if builtin_tools is defined or tools is not none %} + {{- "Environment: ipython\n" }} +{%- endif %} +{%- if builtin_tools is defined %} + {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}} +{%- endif %} +{{- "Cutting Knowledge Date: December 2023\n" }} +{{- "Today Date: " + date_string + "\n\n" }} +{%- if tools is not none and not tools_in_user_message %} + {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %} +{{- system_message }} +{{- "<|eot_id|>" }} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} +{%- endif %} + {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} + {{- "Given the following functions, please respond with a JSON for a function call " }} + {{- "with its proper arguments that best answers the given prompt.\n\n" }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- first_user_message + "<|eot_id|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {%- elif 'tool_calls' in message %} + {%- if not message.tool_calls|length == 1 %} + {{- raise_exception("This model only supports single tool-calls at once!") }} + {%- endif %} + {%- set tool_call = message.tool_calls[0].function %} + {%- if builtin_tools is defined and tool_call.name in builtin_tools %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- "<|python_tag|>" + tool_call.name + ".call(" }} + {%- for arg_name, arg_val in tool_call.arguments | items %} + {{- arg_name + '="' + arg_val + '"' }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- ")" }} + {%- else %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- '{"name": "' + tool_call.name + '", ' }} + {{- '"parameters": ' }} + {{- tool_call.arguments | tojson }} + {{- "}" }} + {%- endif %} + {%- if builtin_tools is defined %} + {#- This means we're in ipython mode #} + {{- "<|eom_id|>" }} + {%- else %} + {{- "<|eot_id|>" }} + {%- endif %} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} + {%- if message.content is mapping or message.content is iterable %} + {{- message.content | tojson }} + {%- else %} + {{- message.content }} + {%- endif %} + {{- "<|eot_id|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} diff --git a/models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja b/models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja new file mode 100644 index 000000000..1bad6a0f6 --- /dev/null +++ b/models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja @@ -0,0 +1,93 @@ +{{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {%- set tools_in_user_message = true %} +{%- endif %} +{%- if not date_string is defined %} + {%- if strftime_now is defined %} + {%- set date_string = strftime_now("%d %b %Y") %} + {%- else %} + {%- set date_string = "26 Jul 2024" %} + {%- endif %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "" %} +{%- endif %} + +{#- System message #} +{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} +{%- if tools is not none %} + {{- "Environment: ipython\n" }} +{%- endif %} +{{- "Cutting Knowledge Date: December 2023\n" }} +{{- "Today Date: " + date_string + "\n\n" }} +{%- if tools is not none and not tools_in_user_message %} + {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %} +{{- system_message }} +{{- "<|eot_id|>" }} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} +{%- endif %} + {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} + {{- "Given the following functions, please respond with a JSON for a function call " }} + {{- "with its proper arguments that best answers the given prompt.\n\n" }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- first_user_message + "<|eot_id|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {%- elif 'tool_calls' in message %} + {%- if not message.tool_calls|length == 1 %} + {{- raise_exception("This model only supports single tool-calls at once!") }} + {%- endif %} + {%- set tool_call = message.tool_calls[0].function %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- '{"name": "' + tool_call.name + '", ' }} + {{- '"parameters": ' }} + {{- tool_call.arguments | tojson }} + {{- "}" }} + {{- "<|eot_id|>" }} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} + {%- if message.content is mapping or message.content is iterable %} + {{- message.content | tojson }} + {%- else %} + {{- message.content }} + {%- endif %} + {{- "<|eot_id|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} diff --git a/models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja b/models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja new file mode 100644 index 000000000..33089ace1 --- /dev/null +++ b/models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja @@ -0,0 +1,109 @@ +{{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {%- set tools_in_user_message = true %} +{%- endif %} +{%- if not date_string is defined %} + {%- set date_string = "26 Jul 2024" %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "" %} +{%- endif %} + +{#- System message + builtin tools #} +{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} +{%- if builtin_tools is defined or tools is not none %} + {{- "Environment: ipython\n" }} +{%- endif %} +{%- if builtin_tools is defined %} + {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}} +{%- endif %} +{{- "Cutting Knowledge Date: December 2023\n" }} +{{- "Today Date: " + date_string + "\n\n" }} +{%- if tools is not none and not tools_in_user_message %} + {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %} +{{- system_message }} +{{- "<|eot_id|>" }} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} +{%- endif %} + {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} + {{- "Given the following functions, please respond with a JSON for a function call " }} + {{- "with its proper arguments that best answers the given prompt.\n\n" }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- first_user_message + "<|eot_id|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {%- elif 'tool_calls' in message %} + {%- if not message.tool_calls|length == 1 %} + {{- raise_exception("This model only supports single tool-calls at once!") }} + {%- endif %} + {%- set tool_call = message.tool_calls[0].function %} + {%- if builtin_tools is defined and tool_call.name in builtin_tools %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- "<|python_tag|>" + tool_call.name + ".call(" }} + {%- for arg_name, arg_val in tool_call.arguments | items %} + {{- arg_name + '="' + arg_val + '"' }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- ")" }} + {%- else %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- '{"name": "' + tool_call.name + '", ' }} + {{- '"parameters": ' }} + {{- tool_call.arguments | tojson }} + {{- "}" }} + {%- endif %} + {%- if builtin_tools is defined %} + {#- This means we're in ipython mode #} + {{- "<|eom_id|>" }} + {%- else %} + {{- "<|eot_id|>" }} + {%- endif %} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} + {%- if message.content is mapping or message.content is iterable %} + {{- message.content | tojson }} + {%- else %} + {{- message.content }} + {%- endif %} + {{- "<|eot_id|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} diff --git a/models/templates/microsoft-Phi-3.5-mini-instruct.jinja b/models/templates/microsoft-Phi-3.5-mini-instruct.jinja new file mode 100644 index 000000000..d1533d152 --- /dev/null +++ b/models/templates/microsoft-Phi-3.5-mini-instruct.jinja @@ -0,0 +1,8 @@ +{% for message in messages %}{% if message['role'] == 'system' and message['content'] %}{{'<|system|> +' + message['content'] + '<|end|> +'}}{% elif message['role'] == 'user' %}{{'<|user|> +' + message['content'] + '<|end|> +'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|> +' + message['content'] + '<|end|> +'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|> +' }}{% else %}{{ eos_token }}{% endif %} \ No newline at end of file diff --git a/models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja b/models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja new file mode 100644 index 000000000..9c21a3f13 --- /dev/null +++ b/models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja @@ -0,0 +1,87 @@ +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} +{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} + +{#- This block checks for alternating user/assistant messages, skipping tool calling messages #} +{%- set ns = namespace() %} +{%- set ns.index = 0 %} +{%- for message in loop_messages %} + {%- if not (message.role == "tool" or message.role == "tool_results" or (message.tool_calls is defined and message.tool_calls is not none)) %} + {%- if (message["role"] == "user") != (ns.index % 2 == 0) %} + {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif %} + {%- set ns.index = ns.index + 1 %} + {%- endif %} +{%- endfor %} + +{{- bos_token }} +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {%- if tools is not none and (message == user_messages[-1]) %} + {{- "[AVAILABLE_TOOLS][" }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- '{"type": "function", "function": {' }} + {%- for key, val in tool.items() if key != "return" %} + {%- if val is string %} + {{- '"' + key + '": "' + val + '"' }} + {%- else %} + {{- '"' + key + '": ' + val|tojson }} + {%- endif %} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- "}}" }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "[/AVAILABLE_TOOLS]" }} + {%- endif %} + {%- if loop.last and system_message is defined %} + {{- "[INST]" + system_message + "\n\n" + message["content"] + "[/INST]" }} + {%- else %} + {{- "[INST]" + message["content"] + "[/INST]" }} + {%- endif %} + {%- elif (message.tool_calls is defined and message.tool_calls is not none) %} + {{- "[TOOL_CALLS][" }} + {%- for tool_call in message.tool_calls %} + {%- set out = tool_call.function|tojson %} + {{- out[:-1] }} + {%- if not tool_call.id is defined or tool_call.id|length != 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }} + {%- endif %} + {{- ', "id": "' + tool_call.id + '"}' }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" + eos_token }} + {%- endif %} + {%- endfor %} + {%- elif message["role"] == "assistant" %} + {{- message["content"] + eos_token}} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- '[TOOL_RESULTS]{"content": ' + content|string + ", " }} + {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }} + {%- endif %} + {{- '"call_id": "' + message.tool_call_id + '"}[/TOOL_RESULTS]' }} + {%- else %} + {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} diff --git a/pocs/CMakeLists.txt b/pocs/CMakeLists.txt index 03e1d2c04..d49d14dee 100644 --- a/pocs/CMakeLists.txt +++ b/pocs/CMakeLists.txt @@ -8,5 +8,7 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}) if (EMSCRIPTEN) else() - add_subdirectory(vdot) + if (NOT GGML_BACKEND_DL) + add_subdirectory(vdot) + endif() endif() diff --git a/pocs/vdot/CMakeLists.txt b/pocs/vdot/CMakeLists.txt index d5405ad29..6235aec1f 100644 --- a/pocs/vdot/CMakeLists.txt +++ b/pocs/vdot/CMakeLists.txt @@ -1,9 +1,9 @@ set(TARGET llama-vdot) add_executable(${TARGET} vdot.cpp) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) set(TARGET llama-q8dot) add_executable(${TARGET} q8dot.cpp) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_11) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/pocs/vdot/q8dot.cpp b/pocs/vdot/q8dot.cpp index 1a52ff5e9..3df6e1f42 100644 --- a/pocs/vdot/q8dot.cpp +++ b/pocs/vdot/q8dot.cpp @@ -11,6 +11,7 @@ #include #include +#include constexpr int kVecSize = 1 << 16; @@ -136,7 +137,7 @@ int main(int argc, char** argv) { auto ggml_type = type == 0 ? GGML_TYPE_Q4_0 : GGML_TYPE_Q4_1; - auto funcs = ggml_internal_get_type_traits(ggml_type); + const auto * funcs = ggml_get_type_traits_cpu(ggml_type); Stat simple, ggml; @@ -156,8 +157,8 @@ int main(int argc, char** argv) { t1 = std::chrono::high_resolution_clock::now(); float fs; - if (type == 0) funcs.vec_dot(kVecSize * QK4_1, &fs, 0, x40.data(), 0, y.data(), 0, 1); - else funcs.vec_dot(kVecSize * QK4_1, &fs, 0, x41.data(), 0, y.data(), 0, 1); + if (type == 0) funcs->vec_dot(kVecSize * QK4_1, &fs, 0, x40.data(), 0, y.data(), 0, 1); + else funcs->vec_dot(kVecSize * QK4_1, &fs, 0, x41.data(), 0, y.data(), 0, 1); t2 = std::chrono::high_resolution_clock::now(); t = 1e-3*std::chrono::duration_cast(t2-t1).count(); if (iloop > 3) ggml.addResult(fs, t); diff --git a/pocs/vdot/vdot.cpp b/pocs/vdot/vdot.cpp index 17e9e4482..2dca62848 100644 --- a/pocs/vdot/vdot.cpp +++ b/pocs/vdot/vdot.cpp @@ -9,6 +9,7 @@ #include #include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -236,7 +237,7 @@ int main(int argc, char** argv) { int n4 = useQ4_1 ? kVecSize / QK4_1 : kVecSize / QK4_0; n4 = 64*((n4 + 63)/64); int n8 = kVecSize / QK8_0; n8 = 64*((n8 + 63)/64); - auto funcs = useQ4_1 ? ggml_internal_get_type_traits(GGML_TYPE_Q4_1) : ggml_internal_get_type_traits(GGML_TYPE_Q4_0); + const auto * funcs_cpu = ggml_get_type_traits_cpu(useQ4_1 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q4_0); std::vector q40; std::vector q41; @@ -261,9 +262,9 @@ int main(int argc, char** argv) { // Note, we do not include this in the timing as in practical application // we already have the quantized model weights. if (useQ4_1) { - funcs.from_float(x1.data(), q41.data(), kVecSize); + funcs_cpu->from_float(x1.data(), q41.data(), kVecSize); } else { - funcs.from_float(x1.data(), q40.data(), kVecSize); + funcs_cpu->from_float(x1.data(), q40.data(), kVecSize); } // Now measure time the dot product needs using the "scalar" version above @@ -282,10 +283,10 @@ int main(int argc, char** argv) { dot_q4_q8(kVecSize, &result, q40.data(), q8.data()); } else { - auto vdot = ggml_internal_get_type_traits(funcs.vec_dot_type); - vdot.from_float(y1.data(), q8.data(), kVecSize); - if (useQ4_1) funcs.vec_dot(kVecSize, &result, 0, q41.data(), 0, q8.data(), 0, 1); - else funcs.vec_dot(kVecSize, &result, 0, q40.data(), 0, q8.data(), 0, 1); + const auto * vdot = ggml_get_type_traits_cpu(funcs_cpu->vec_dot_type); + vdot->from_float(y1.data(), q8.data(), kVecSize); + if (useQ4_1) funcs_cpu->vec_dot(kVecSize, &result, 0, q41.data(), 0, q8.data(), 0, 1); + else funcs_cpu->vec_dot(kVecSize, &result, 0, q40.data(), 0, q8.data(), 0, 1); } sumq += result; t2 = std::chrono::high_resolution_clock::now(); diff --git a/pyrightconfig.json b/pyrightconfig.json index 6016f4b6d..9acbbeb78 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -5,7 +5,8 @@ "reportUnusedImport": "warning", "reportDuplicateImport": "error", "reportDeprecated": "warning", - "reportUnnecessaryTypeIgnoreComment": "warning", + "reportUnnecessaryTypeIgnoreComment": "information", + "disableBytesTypePromotions": false, // TODO: change once Python 3.12 is the minimum "executionEnvironments": [ { // TODO: make this version override work correctly diff --git a/requirements/requirements-convert_legacy_llama.txt b/requirements/requirements-convert_legacy_llama.txt index 1d07b0952..859204b27 100644 --- a/requirements/requirements-convert_legacy_llama.txt +++ b/requirements/requirements-convert_legacy_llama.txt @@ -1,5 +1,5 @@ numpy~=1.26.4 sentencepiece~=0.2.0 -transformers>=4.40.1,<5.0.0 +transformers>=4.45.1,<5.0.0 gguf>=0.1.0 protobuf>=4.21.0,<5.0.0 diff --git a/scripts/compare-commits.sh b/scripts/compare-commits.sh index 70679f4e5..e40d1cc6d 100755 --- a/scripts/compare-commits.sh +++ b/scripts/compare-commits.sh @@ -8,20 +8,31 @@ fi set -e set -x +# verify at the start that the compare script has all the necessary dependencies installed +./scripts/compare-llama-bench.py --check + bench_args="${@:3}" rm -f llama-bench.sqlite > /dev/null # to test a backend, call the script with the corresponding environment variable (e.g. GGML_CUDA=1 ./scripts/compare-commits.sh ...) +if [ -n "$GGML_CUDA" ]; then + cmake_opts="-DGGML_CUDA=ON" +fi + +dir="build-bench" + +function run { + rm -fr ${dir} > /dev/null + cmake -B ${dir} -S . $cmake_opts > /dev/null + cmake --build ${dir} -t llama-bench > /dev/null + ${dir}/bin/llama-bench -o sql -oe md $bench_args | sqlite3 llama-bench.sqlite +} git checkout $1 > /dev/null -make clean > /dev/null -make -j$(nproc) $make_opts llama-bench > /dev/null -./llama-bench -o sql -oe md $bench_args | sqlite3 llama-bench.sqlite +run git checkout $2 > /dev/null -make clean > /dev/null -make -j$(nproc) $make_opts llama-bench > /dev/null -./llama-bench -o sql -oe md $bench_args | sqlite3 llama-bench.sqlite +run ./scripts/compare-llama-bench.py -b $1 -c $2 diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index 92b9e682a..239c458d8 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -19,22 +19,22 @@ logger = logging.getLogger("compare-llama-bench") # Properties by which to differentiate results per commit: KEY_PROPERTIES = [ - "cpu_info", "gpu_info", "n_gpu_layers", "cuda", "vulkan", "kompute", "metal", "sycl", "rpc", "gpu_blas", - "blas", "model_filename", "model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "embeddings", "n_threads", - "type_k", "type_v", "use_mmap", "no_kv_offload", "split_mode", "main_gpu", "tensor_split", "flash_attn", "n_prompt", "n_gen" + "cpu_info", "gpu_info", "backends", "n_gpu_layers", "model_filename", "model_type", "n_batch", "n_ubatch", + "embeddings", "cpu_mask", "cpu_strict", "poll", "n_threads", "type_k", "type_v", "use_mmap", "no_kv_offload", + "split_mode", "main_gpu", "tensor_split", "flash_attn", "n_prompt", "n_gen" ] # Properties that are boolean and are converted to Yes/No for the table: -BOOL_PROPERTIES = ["cuda", "vulkan", "kompute", "metal", "sycl", "gpu_blas", "blas", "embeddings", "use_mmap", "no_kv_offload", "flash_attn"] +BOOL_PROPERTIES = ["embeddings", "cpu_strict", "use_mmap", "no_kv_offload", "flash_attn"] # Header names for the table: PRETTY_NAMES = { - "cuda": "CUDA", "vulkan": "Vulkan", "kompute": "Kompute", "metal": "Metal", "sycl": "SYCL", "rpc": "RPC", - "gpu_blas": "GPU BLAS", "blas": "BLAS", "cpu_info": "CPU", "gpu_info": "GPU", "model_filename": "File", "model_type": "Model", - "model_size": "Model Size [GiB]", "model_n_params": "Num. of Par.", "n_batch": "Batch size", "n_ubatch": "Microbatch size", - "n_threads": "Threads", "type_k": "K type", "type_v": "V type", "n_gpu_layers": "GPU layers", "split_mode": "Split mode", - "main_gpu": "Main GPU", "no_kv_offload": "NKVO", "flash_attn": "FlashAttention", "tensor_split": "Tensor split", - "use_mmap": "Use mmap", "embeddings": "Embeddings", + "cpu_info": "CPU", "gpu_info": "GPU", "backends": "Backends", "n_gpu_layers": "GPU layers", + "model_filename": "File", "model_type": "Model", "model_size": "Model size [GiB]", + "model_n_params": "Num. of par.", "n_batch": "Batch size", "n_ubatch": "Microbatch size", + "embeddings": "Embeddings", "cpu_mask": "CPU mask", "cpu_strict": "CPU strict", "poll": "Poll", + "n_threads": "Threads", "type_k": "K type", "type_v": "V type", "split_mode": "Split mode", "main_gpu": "Main GPU", + "no_kv_offload": "NKVO", "flash_attn": "FlashAttention", "tensor_split": "Tensor split", "use_mmap": "Use mmap", } DEFAULT_SHOW = ["model_type"] # Always show these properties by default. @@ -92,6 +92,7 @@ help_s = ( "If the columns are manually specified, then the results for each unique combination of the " "specified values are averaged WITHOUT weighing by the --repetitions parameter of llama-bench." ) +parser.add_argument("--check", action="store_true", help="check if all required Python libraries are installed") parser.add_argument("-s", "--show", help=help_s) parser.add_argument("--verbose", action="store_true", help="increase output verbosity") @@ -99,6 +100,10 @@ known_args, unknown_args = parser.parse_known_args() logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO) +if known_args.check: + # Check if all required Python libraries are installed. Would have failed earlier if not. + sys.exit(0) + if unknown_args: logger.error(f"Received unknown args: {unknown_args}.\n") parser.print_help() @@ -121,6 +126,8 @@ connection = sqlite3.connect(input_file) cursor = connection.cursor() builds = cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall() +commit_short_len = len(builds[0][0]) + try: repo = git.Repo(".", search_parent_directories=True) except git.InvalidGitRepositoryError: @@ -133,11 +140,11 @@ def find_parent_in_data(commit: git.Commit): seen_hexsha8 = set() while heap: depth, current_commit = heapq.heappop(heap) - current_hexsha8 = commit.hexsha[:8] + current_hexsha8 = commit.hexsha[:commit_short_len] if (current_hexsha8,) in builds: return current_hexsha8 for parent in commit.parents: - parent_hexsha8 = parent.hexsha[:8] + parent_hexsha8 = parent.hexsha[:commit_short_len] if parent_hexsha8 not in seen_hexsha8: seen_hexsha8.add(parent_hexsha8) heapq.heappush(heap, (depth + 1, parent)) @@ -151,9 +158,9 @@ def get_all_parent_hexsha8s(commit: git.Commit): while unvisited: current_commit = unvisited.pop(0) - visited.append(current_commit.hexsha[:8]) + visited.append(current_commit.hexsha[:commit_short_len]) for parent in current_commit.parents: - if parent.hexsha[:8] not in visited: + if parent.hexsha[:commit_short_len] not in visited: unvisited.append(parent) return visited @@ -164,10 +171,10 @@ def get_commit_name(hexsha8): if repo is None: return hexsha8 for h in repo.heads: - if h.commit.hexsha[:8] == hexsha8: + if h.commit.hexsha[:commit_short_len] == hexsha8: return h.name for t in repo.tags: - if t.commit.hexsha[:8] == hexsha8: + if t.commit.hexsha[:commit_short_len] == hexsha8: return t.name return hexsha8 @@ -178,13 +185,13 @@ def get_commit_hexsha8(name): return None for h in repo.heads: if h.name == name: - return h.commit.hexsha[:8] + return h.commit.hexsha[:commit_short_len] for t in repo.tags: if t.name == name: - return t.commit.hexsha[:8] + return t.commit.hexsha[:commit_short_len] for c in repo.iter_commits("--all"): - if c.hexsha[:8] == name[:8]: - return c.hexsha[:8] + if c.hexsha[:commit_short_len] == name[:commit_short_len]: + return c.hexsha[:commit_short_len] return None @@ -298,14 +305,11 @@ else: show = [] # Show CPU and/or GPU by default even if the hardware for all results is the same: - if "gpu_blas" not in properties_different and "n_gpu_layers" not in properties_different: - gpu_blas = bool(rows_full[0][KEY_PROPERTIES.index("gpu_blas")]) + if "n_gpu_layers" not in properties_different: ngl = int(rows_full[0][KEY_PROPERTIES.index("n_gpu_layers")]) - if not gpu_blas or ngl != 99 and "cpu_info" not in properties_different: + if ngl != 99 and "cpu_info" not in properties_different: show.append("cpu_info") - if gpu_blas and "gpu_info" not in properties_different: - show.append("gpu_info") show += properties_different diff --git a/scripts/debug-test.sh b/scripts/debug-test.sh index 91946c514..c6c1e988a 100755 --- a/scripts/debug-test.sh +++ b/scripts/debug-test.sh @@ -110,7 +110,7 @@ rm -rf "$build_dir" && mkdir "$build_dir" || abort "Failed to make $build_dir" ########################################################### # Note: test-eval-callback requires -DLLAMA_CURL -cmake -B "./$build_dir" -DCMAKE_BUILD_TYPE=Debug -DGGML_CUDA=1 -DLLAMA_CURL=1 || abort "Failed to build enviroment" +cmake -B "./$build_dir" -DCMAKE_BUILD_TYPE=Debug -DGGML_CUDA=1 -DLLAMA_CURL=1 || abort "Failed to build environment" pushd "$build_dir" make -j || abort "Failed to compile" popd > /dev/null || exit 1 @@ -127,7 +127,7 @@ printf "\n\nGathering tests that fit REGEX: ${test_suite} ...\n" pushd "$build_dir" tests=($(ctest -R ${test_suite} -V -N | grep -E " +Test +#[0-9]+*" | cut -d':' -f2 | awk '{$1=$1};1')) if [ ${#tests[@]} -eq 0 ]; then - abort "No tests avaliable... check your compliation process..." + abort "No tests available... check your compilation process..." fi popd > /dev/null || exit 1 @@ -137,7 +137,7 @@ popd > /dev/null || exit 1 # Select test number if [ -z $test_number ]; then - # List out avaliable tests + # List out available tests printf "Which test would you like to debug?\n" id=0 for s in "${tests[@]}" diff --git a/scripts/fetch_server_test_models.py b/scripts/fetch_server_test_models.py new file mode 100755 index 000000000..05690b138 --- /dev/null +++ b/scripts/fetch_server_test_models.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python +''' + This script fetches all the models used in the server tests. + + This is useful for slow tests that use larger models, to avoid them timing out on the model downloads. + + It is meant to be run from the root of the repository. + + Example: + python scripts/fetch_server_test_models.py + ( cd examples/server/tests && ./tests.sh -v -x -m slow ) +''' +import ast +import glob +import logging +import os +from typing import Generator +from pydantic import BaseModel +from typing import Optional +import subprocess + + +class HuggingFaceModel(BaseModel): + hf_repo: str + hf_file: Optional[str] = None + + class Config: + frozen = True + + +def collect_hf_model_test_parameters(test_file) -> Generator[HuggingFaceModel, None, None]: + try: + with open(test_file) as f: + tree = ast.parse(f.read()) + except Exception as e: + logging.error(f'collect_hf_model_test_parameters failed on {test_file}: {e}') + return + + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + for dec in node.decorator_list: + if isinstance(dec, ast.Call) and isinstance(dec.func, ast.Attribute) and dec.func.attr == 'parametrize': + param_names = ast.literal_eval(dec.args[0]).split(",") + if "hf_repo" not in param_names: + continue + + raw_param_values = dec.args[1] + if not isinstance(raw_param_values, ast.List): + logging.warning(f'Skipping non-list parametrize entry at {test_file}:{node.lineno}') + continue + + hf_repo_idx = param_names.index("hf_repo") + hf_file_idx = param_names.index("hf_file") if "hf_file" in param_names else None + + for t in raw_param_values.elts: + if not isinstance(t, ast.Tuple): + logging.warning(f'Skipping non-tuple parametrize entry at {test_file}:{node.lineno}') + continue + yield HuggingFaceModel( + hf_repo=ast.literal_eval(t.elts[hf_repo_idx]), + hf_file=ast.literal_eval(t.elts[hf_file_idx]) if hf_file_idx is not None else None) + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') + + models = sorted(list(set([ + model + for test_file in glob.glob('examples/server/tests/unit/test_*.py') + for model in collect_hf_model_test_parameters(test_file) + ])), key=lambda m: (m.hf_repo, m.hf_file)) + + logging.info(f'Found {len(models)} models in parameterized tests:') + for m in models: + logging.info(f' - {m.hf_repo} / {m.hf_file}') + + cli_path = os.environ.get( + 'LLAMA_SERVER_BIN_PATH', + os.path.join( + os.path.dirname(__file__), + '../build/bin/Release/llama-cli.exe' if os.name == 'nt' else '../build/bin/llama-cli')) + + for m in models: + if '<' in m.hf_repo or (m.hf_file is not None and '<' in m.hf_file): + continue + if m.hf_file is not None and '-of-' in m.hf_file: + logging.warning(f'Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file') + continue + logging.info(f'Using llama-cli to ensure model {m.hf_repo}/{m.hf_file} was fetched') + cmd = [ + cli_path, + '-hfr', m.hf_repo, + *([] if m.hf_file is None else ['-hff', m.hf_file]), + '-n', '1', + '-p', 'Hey', + '--no-warmup', + '--log-disable', + '-no-cnv'] + if m.hf_file != 'tinyllamas/stories260K.gguf' and 'Mistral-Nemo' not in m.hf_repo: + cmd.append('-fa') + try: + subprocess.check_call(cmd) + except subprocess.CalledProcessError: + logging.error(f'Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n {" ".join(cmd)}') + exit(1) diff --git a/scripts/get_chat_template.py b/scripts/get_chat_template.py new file mode 100644 index 000000000..e8982d11a --- /dev/null +++ b/scripts/get_chat_template.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python +''' + Fetches the Jinja chat template of a HuggingFace model. + If a model has multiple chat templates, you can specify the variant name. + + Syntax: + ./scripts/get_chat_template.py model_id [variant] + + Examples: + ./scripts/get_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct + ./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use + ./scripts/get_chat_template.py meta-llama/Llama-3.2-3B-Instruct +''' + +import json +import re +import sys + + +def get_chat_template(model_id, variant=None): + try: + # Use huggingface_hub library if available. + # Allows access to gated models if the user has access and ran `huggingface-cli login`. + from huggingface_hub import hf_hub_download + with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f: + config_str = f.read() + except ImportError: + import requests + assert re.match(r"^[\w.-]+/[\w.-]+$", model_id), f"Invalid model ID: {model_id}" + response = requests.get(f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json") + if response.status_code == 401: + raise Exception('Access to this model is gated, please request access, authenticate with `huggingface-cli login` and make sure to run `pip install huggingface_hub`') + response.raise_for_status() + config_str = response.text + + try: + config = json.loads(config_str) + except json.JSONDecodeError: + # Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json + # (Remove extra '}' near the end of the file) + config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str)) + + chat_template = config['chat_template'] + if isinstance(chat_template, str): + return chat_template + else: + variants = { + ct['name']: ct['template'] + for ct in chat_template + } + + def format_variants(): + return ', '.join(f'"{v}"' for v in variants.keys()) + + if variant is None: + if 'default' not in variants: + raise Exception(f'Please specify a chat template variant (one of {format_variants()})') + variant = 'default' + sys.stderr.write(f'Note: picked "default" chat template variant (out of {format_variants()})\n') + elif variant not in variants: + raise Exception(f"Variant {variant} not found in chat template (found {format_variants()})") + + return variants[variant] + + +def main(args): + if len(args) < 1: + raise ValueError("Please provide a model ID and an optional variant name") + model_id = args[0] + variant = None if len(args) < 2 else args[1] + + template = get_chat_template(model_id, variant) + sys.stdout.write(template) + + +if __name__ == '__main__': + main(sys.argv[1:]) diff --git a/scripts/hf.sh b/scripts/hf.sh index 85c2c4d9a..b251925fa 100755 --- a/scripts/hf.sh +++ b/scripts/hf.sh @@ -26,7 +26,7 @@ function has_cmd { } if has_cmd wget; then - cmd="wget -q --show-progress -c -O %s/%s %s" + cmd="wget -q -c -O %s/%s %s" elif has_cmd curl; then cmd="curl -C - -f --output-dir %s -o %s -L %s" else diff --git a/scripts/pod-llama.sh b/scripts/pod-llama.sh deleted file mode 100644 index 6e56e1ed0..000000000 --- a/scripts/pod-llama.sh +++ /dev/null @@ -1,212 +0,0 @@ -#!/bin/bash -# -# Use this script only on fresh pods (runpod.io)! -# Otherwise, it can break your environment! -# - -if [ -z "$1" ]; then - echo "Usage: $0 " - echo " 0: no models" - echo " 1: tinyllama-1b" - echo " 2: codellama-7b" - echo " 3: codellama-13b" - echo " 4: codellama-34b" - echo " 5: codellama-7b-instruct" - echo " 6: codellama-13b-instruct" - echo " 7: codellama-34b-instruct" - - exit 1 -fi - -set -x - -# setup deps -apt-get update -apt-get install -y git-lfs cmake cmake-curses-gui vim ruby -git-lfs install - -if [ ! -d "/workspace" ]; then - ln -sfn $(pwd) /workspace -fi - -# download data -cd /workspace - -# this is useful to git clone repos without doubling the disk size due to .git -git clone https://github.com/iboB/git-lfs-download -ln -sfn /workspace/git-lfs-download/git-lfs-download /usr/local/bin/git-lfs-download - -# llama.cpp -cd /workspace -git clone https://github.com/ggerganov/llama.cpp - -cd llama.cpp - -GGML_CUDA=1 make -j - -ln -sfn /workspace/TinyLlama-1.1B-Chat-v0.3 ./models/tinyllama-1b -ln -sfn /workspace/CodeLlama-7b-hf ./models/codellama-7b -ln -sfn /workspace/CodeLlama-13b-hf ./models/codellama-13b -ln -sfn /workspace/CodeLlama-34b-hf ./models/codellama-34b -ln -sfn /workspace/CodeLlama-7b-Instruct-hf ./models/codellama-7b-instruct -ln -sfn /workspace/CodeLlama-13b-Instruct-hf ./models/codellama-13b-instruct -ln -sfn /workspace/CodeLlama-34b-Instruct-hf ./models/codellama-34b-instruct - -pip install -r requirements.txt - -# cmake -cd /workspace/llama.cpp - -mkdir build-cublas -cd build-cublas - -cmake -DGGML_CUDA=1 ../ -make -j - -if [ "$1" -eq "0" ]; then - exit 0 -fi - -# more models -if [ "$1" -eq "1" ]; then - cd /workspace - - git-lfs-download https://huggingface.co/PY007/TinyLlama-1.1B-Chat-v0.3 - - cd /workspace/llama.cpp - - python3 examples/convert_legacy_llama.py ./models/tinyllama-1b --outfile ./models/tinyllama-1b/ggml-model-f16.gguf --outtype f16 - - ./llama-quantize ./models/tinyllama-1b/ggml-model-f16.gguf ./models/tinyllama-1b/ggml-model-q4_0.gguf q4_0 - ./llama-quantize ./models/tinyllama-1b/ggml-model-f16.gguf ./models/tinyllama-1b/ggml-model-q4_k.gguf q4_k - ./llama-quantize ./models/tinyllama-1b/ggml-model-f16.gguf ./models/tinyllama-1b/ggml-model-q8_0.gguf q8_0 -fi - -if [ "$1" -eq "2" ]; then - cd /workspace - - git-lfs-download https://huggingface.co/codellama/CodeLlama-7b-hf --without *safetensors* - rm -v ./CodeLlama-7b-hf/*safetensors* - - cd /workspace/llama.cpp - - python3 examples/convert_legacy_llama.py ./models/codellama-7b --outfile ./models/codellama-7b/ggml-model-f16.gguf --outtype f16 - - ./llama-quantize ./models/codellama-7b/ggml-model-f16.gguf ./models/codellama-7b/ggml-model-q4_0.gguf q4_0 - ./llama-quantize ./models/codellama-7b/ggml-model-f16.gguf ./models/codellama-7b/ggml-model-q4_k.gguf q4_k - ./llama-quantize ./models/codellama-7b/ggml-model-f16.gguf ./models/codellama-7b/ggml-model-q8_0.gguf q8_0 -fi - -if [ "$1" -eq "3" ]; then - cd /workspace - - git-lfs-download https://huggingface.co/codellama/CodeLlama-13b-hf --without *safetensors* - rm -v ./CodeLlama-13b-hf/*safetensors* - - cd /workspace/llama.cpp - - python3 examples/convert_legacy_llama.py ./models/codellama-13b --outfile ./models/codellama-13b/ggml-model-f16.gguf --outtype f16 - - ./llama-quantize ./models/codellama-13b/ggml-model-f16.gguf ./models/codellama-13b/ggml-model-q4_0.gguf q4_0 - ./llama-quantize ./models/codellama-13b/ggml-model-f16.gguf ./models/codellama-13b/ggml-model-q4_k.gguf q4_k - ./llama-quantize ./models/codellama-13b/ggml-model-f16.gguf ./models/codellama-13b/ggml-model-q8_0.gguf q8_0 -fi - -if [ "$1" -eq "4" ]; then - cd /workspace - - git-lfs-download https://huggingface.co/codellama/CodeLlama-34b-hf --without *safetensors* - rm -v ./CodeLlama-34b-hf/*safetensors* - - cd /workspace/llama.cpp - - python3 examples/convert_legacy_llama.py ./models/codellama-34b --outfile ./models/codellama-34b/ggml-model-f16.gguf --outtype f16 - - ./llama-quantize ./models/codellama-34b/ggml-model-f16.gguf ./models/codellama-34b/ggml-model-q4_0.gguf q4_0 - ./llama-quantize ./models/codellama-34b/ggml-model-f16.gguf ./models/codellama-34b/ggml-model-q4_k.gguf q4_k - ./llama-quantize ./models/codellama-34b/ggml-model-f16.gguf ./models/codellama-34b/ggml-model-q8_0.gguf q8_0 -fi - -if [ "$1" -eq "5" ]; then - cd /workspace - - git-lfs-download https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf --without *safetensors* - rm -v ./CodeLlama-7b-Instruct-hf/*safetensors* - - cd /workspace/llama.cpp - - python3 examples/convert_legacy_llama.py ./models/codellama-7b-instruct --outfile ./models/codellama-7b-instruct/ggml-model-f16.gguf --outtype f16 - - ./llama-quantize ./models/codellama-7b-instruct/ggml-model-f16.gguf ./models/codellama-7b-instruct/ggml-model-q4_0.gguf q4_0 - ./llama-quantize ./models/codellama-7b-instruct/ggml-model-f16.gguf ./models/codellama-7b-instruct/ggml-model-q4_k.gguf q4_k - ./llama-quantize ./models/codellama-7b-instruct/ggml-model-f16.gguf ./models/codellama-7b-instruct/ggml-model-q8_0.gguf q8_0 -fi - -if [ "$1" -eq "6" ]; then - cd /workspace - - git-lfs-download https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf --without *safetensors* - rm -v ./CodeLlama-13b-Instruct-hf/*safetensors* - - cd /workspace/llama.cpp - - python3 examples/convert_legacy_llama.py ./models/codellama-13b-instruct --outfile ./models/codellama-13b-instruct/ggml-model-f16.gguf --outtype f16 - - ./llama-quantize ./models/codellama-13b-instruct/ggml-model-f16.gguf ./models/codellama-13b-instruct/ggml-model-q4_0.gguf q4_0 - ./llama-quantize ./models/codellama-13b-instruct/ggml-model-f16.gguf ./models/codellama-13b-instruct/ggml-model-q4_k.gguf q4_k - ./llama-quantize ./models/codellama-13b-instruct/ggml-model-f16.gguf ./models/codellama-13b-instruct/ggml-model-q8_0.gguf q8_0 -fi - -if [ "$1" -eq "7" ]; then - cd /workspace - - git-lfs-download https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf --without *safetensors* - rm -v ./CodeLlama-34b-Instruct-hf/*safetensors* - - cd /workspace/llama.cpp - - python3 examples/convert_legacy_llama.py ./models/codellama-34b-instruct --outfile ./models/codellama-34b-instruct/ggml-model-f16.gguf --outtype f16 - - ./llama-quantize ./models/codellama-34b-instruct/ggml-model-f16.gguf ./models/codellama-34b-instruct/ggml-model-q4_0.gguf q4_0 - ./llama-quantize ./models/codellama-34b-instruct/ggml-model-f16.gguf ./models/codellama-34b-instruct/ggml-model-q4_k.gguf q4_k - ./llama-quantize ./models/codellama-34b-instruct/ggml-model-f16.gguf ./models/codellama-34b-instruct/ggml-model-q8_0.gguf q8_0 -fi - -if [ "$1" -eq "1" ]; then - # perf + perplexity - cd /workspace/llama.cpp/build-cublas - - make -j && ../scripts/run-all-perf.sh tinyllama-1b "f16" "-ngl 99 -t 1 -p 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,64,128,256,512,1024,2048 -n 128" - - ../scripts/get-wikitext-2.sh - unzip wikitext-2-raw-v1.zip - - make -j && ./bin/llama-perplexity -m ../models/tinyllama-1b/ggml-model-f16.gguf -f ./wikitext-2-raw/wiki.test.raw -ngl 100 --chunks 32 - - # batched - cd /workspace/llama.cpp - - GGML_CUDA=1 make -j && ./llama-batched ./models/tinyllama-1b/ggml-model-f16.gguf "Hello, my name is" 8 128 999 - - # batched-bench - cd /workspace/llama.cpp - - GGML_CUDA=1 make -j && ./llama-batched-bench ./models/tinyllama-1b/ggml-model-f16.gguf 4608 1 99 0 512 128 1,2,3,4,5,6,7,8,16,32 - - # parallel - cd /workspace/llama.cpp - - GGML_CUDA=1 make -j && ./llama-parallel -m ./models/tinyllama-1b/ggml-model-f16.gguf -t 1 -ngl 100 -c 4096 -b 512 -s 1 -np 8 -ns 128 -n 100 -cb - -fi - -# speculative -#if [ "$1" -eq "7" ]; then -# cd /workspace/llama.cpp -# -# GGML_CUDA=1 make -j && ./llama-speculative -m ./models/codellama-34b-instruct/ggml-model-f16.gguf -md ./models/codellama-7b-instruct/ggml-model-q4_0.gguf -p "# Dijkstra's shortest path algorithm in Python (4 spaces indentation) + complexity analysis:\n\n" -e -ngl 999 -ngld 999 -t 4 -n 512 -c 4096 -s 21 --draft 16 -np 1 --temp 0.0 -#fi - -# more benches -#GGML_CUDA=1 make -j && ./llama-batched-bench ./models/codellama-7b/ggml-model-q4_k.gguf 4096 1 99 1 512,3200 128,128,800 1 -#GGML_CUDA=1 make -j && ./llama-batched-bench ./models/codellama-13b/ggml-model-q4_k.gguf 4096 1 99 1 512,3200 128,128,800 1 diff --git a/scripts/run-with-preset.py b/scripts/run-with-preset.py deleted file mode 100755 index ee21eab37..000000000 --- a/scripts/run-with-preset.py +++ /dev/null @@ -1,146 +0,0 @@ -#!/usr/bin/env python3 - -import logging -import argparse -import os -import subprocess -import sys - -import yaml - -logger = logging.getLogger("run-with-preset") - -CLI_ARGS_LLAMA_CLI_PERPLEXITY = [ - "batch-size", "cfg-negative-prompt", "cfg-scale", "chunks", "color", "ctx-size", "escape", - "export", "file", "frequency-penalty", "grammar", "grammar-file", "hellaswag", - "hellaswag-tasks", "ignore-eos", "in-prefix", "in-prefix-bos", "in-suffix", - "interactive", "interactive-first", "keep", "logdir", "logit-bias", "lora", "lora-base", - "low-vram", "main-gpu", "memory-f32", "mirostat", "mirostat-ent", "mirostat-lr", "mlock", - "model", "multiline-input", "n-gpu-layers", "n-predict", "no-mmap", "no-mul-mat-q", - "np-penalize-nl", "numa", "ppl-output-type", "ppl-stride", "presence-penalty", "prompt", - "prompt-cache", "prompt-cache-all", "prompt-cache-ro", "repeat-last-n", - "repeat-penalty", "reverse-prompt", "rope-freq-base", "rope-freq-scale", "rope-scale", "seed", - "simple-io", "tensor-split", "threads", "temp", "tfs", "top-k", "top-p", "typical", - "verbose-prompt" -] - -CLI_ARGS_LLAMA_BENCH = [ - "batch-size", "memory-f32", "low-vram", "model", "mul-mat-q", "n-gen", "n-gpu-layers", - "n-prompt", "output", "repetitions", "tensor-split", "threads", "verbose" -] - -CLI_ARGS_LLAMA_SERVER = [ - "alias", "batch-size", "ctx-size", "embedding", "host", "memory-f32", "lora", "lora-base", - "low-vram", "main-gpu", "mlock", "model", "n-gpu-layers", "n-probs", "no-mmap", "no-mul-mat-q", - "numa", "path", "port", "rope-freq-base", "timeout", "rope-freq-scale", "tensor-split", - "threads", "verbose" -] - -description = """Run llama.cpp binaries with presets from YAML file(s). -To specify which binary should be run, specify the "binary" property (llama-cli, llama-perplexity, llama-bench, and llama-server are supported). -To get a preset file template, run a llama.cpp binary with the "--logdir" CLI argument. - -Formatting considerations: -- The YAML property names are the same as the CLI argument names of the corresponding binary. -- Properties must use the long name of their corresponding llama.cpp CLI arguments. -- Like the llama.cpp binaries the property names do not differentiate between hyphens and underscores. -- Flags must be defined as ": true" to be effective. -- To define the logit_bias property, the expected format is ": " in the "logit_bias" namespace. -- To define multiple "reverse_prompt" properties simultaneously the expected format is a list of strings. -- To define a tensor split, pass a list of floats. -""" -usage = "run-with-preset.py [-h] [yaml_files ...] [-- ...]" -epilog = (" -- specify additional CLI ars to be passed to the binary (override all preset files). " - "Unknown args will be ignored.") - -parser = argparse.ArgumentParser( - description=description, usage=usage, epilog=epilog, formatter_class=argparse.RawTextHelpFormatter) -parser.add_argument("-bin", "--binary", help="The binary to run.") -parser.add_argument("yaml_files", nargs="*", - help="Arbitrary number of YAML files from which to read preset values. " - "If two files specify the same values the later one will be used.") -parser.add_argument("--verbose", action="store_true", help="increase output verbosity") - -known_args, unknown_args = parser.parse_known_args() - -if not known_args.yaml_files and not unknown_args: - parser.print_help() - sys.exit(0) - -logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO) - -props = dict() - -for yaml_file in known_args.yaml_files: - with open(yaml_file, "r") as f: - props.update(yaml.load(f, yaml.SafeLoader)) - -props = {prop.replace("_", "-"): val for prop, val in props.items()} - -binary = props.pop("binary", "llama-cli") -if known_args.binary: - binary = known_args.binary - -if os.path.exists(f"./{binary}"): - binary = f"./{binary}" - -if binary.lower().endswith("llama-cli") or binary.lower().endswith("llama-perplexity"): - cli_args = CLI_ARGS_LLAMA_CLI_PERPLEXITY -elif binary.lower().endswith("llama-bench"): - cli_args = CLI_ARGS_LLAMA_BENCH -elif binary.lower().endswith("llama-server"): - cli_args = CLI_ARGS_LLAMA_SERVER -else: - logger.error(f"Unknown binary: {binary}") - sys.exit(1) - -command_list = [binary] - -for cli_arg in cli_args: - value = props.pop(cli_arg, None) - - if not value or value == -1: - continue - - if cli_arg == "logit-bias": - for token, bias in value.items(): - command_list.append("--logit-bias") - command_list.append(f"{token}{bias:+}") - continue - - if cli_arg == "reverse-prompt" and not isinstance(value, str): - for rp in value: - command_list.append("--reverse-prompt") - command_list.append(str(rp)) - continue - - command_list.append(f"--{cli_arg}") - - if cli_arg == "tensor-split": - command_list.append(",".join([str(v) for v in value])) - continue - - value = str(value) - - if value != "True": - command_list.append(str(value)) - -num_unused = len(props) -if num_unused > 10: - logger.info(f"The preset file contained a total of {num_unused} unused properties.") -elif num_unused > 0: - logger.info("The preset file contained the following unused properties:") - for prop, value in props.items(): - logger.info(f" {prop}: {value}") - -command_list += unknown_args - -sp = subprocess.Popen(command_list) - -while sp.returncode is None: - try: - sp.wait() - except KeyboardInterrupt: - pass - -sys.exit(sp.returncode) diff --git a/scripts/server-llm.sh b/scripts/server-llm.sh deleted file mode 100644 index 802592a3e..000000000 --- a/scripts/server-llm.sh +++ /dev/null @@ -1,418 +0,0 @@ -#!/bin/bash -# -# Helper script for deploying llama.cpp server with a single Bash command -# -# - Works on Linux and macOS -# - Supports: CPU, CUDA, Metal -# - Can run all GGUF models from HuggingFace -# - Can serve requests in parallel -# - Always builds latest llama.cpp from GitHub -# -# Limitations -# -# - Chat templates are poorly supported (base models recommended) -# - Might be unstable! -# -# Usage: -# ./server-llm.sh [--port] [--repo] [--wtype] [--backend] [--gpu-id] [--n-parallel] [--n-kv] [--verbose] [-non-interactive] -# -# --port: port number, default is 8888 -# --repo: path to a repo containing GGUF model files -# --wtype: weights type (f16, q8_0, q4_0, q4_1), default is user-input -# --backend: cpu, cuda, metal, depends on the OS -# --gpu-id: gpu id, default is 0 -# --n-parallel: number of parallel requests, default is 8 -# --n-kv: KV cache size, default is 4096 -# --verbose: verbose output -# --non-interactive: run without asking a permission to run -# -# Example: -# -# bash -c "$(curl -s https://ggml.ai/server-llm.sh)" -# - -set -e - -# required utils: curl, git, make -if ! command -v curl &> /dev/null; then - printf "[-] curl not found\n" - exit 1 -fi -if ! command -v git &> /dev/null; then - printf "[-] git not found\n" - exit 1 -fi -if ! command -v make &> /dev/null; then - printf "[-] make not found\n" - exit 1 -fi - -# parse arguments -is_interactive=1 -port=8888 -repo="" -wtype="" -backend="cpu" - -# if macOS, use metal backend by default -if [[ "$OSTYPE" == "darwin"* ]]; then - backend="metal" -elif command -v nvcc &> /dev/null; then - backend="cuda" -fi - -gpu_id=0 -n_parallel=8 -n_kv=4096 -verbose=0 - -function print_usage { - printf "Usage:\n" - printf " ./server-llm.sh [--port] [--repo] [--wtype] [--backend] [--gpu-id] [--n-parallel] [--n-kv] [--verbose] [-non-interactive]\n\n" - printf " --port: port number, default is 8888\n" - printf " --repo: path to a repo containing GGUF model files\n" - printf " --wtype: weights type (f16, q8_0, q4_0, q4_1), default is user-input\n" - printf " --backend: cpu, cuda, metal, depends on the OS\n" - printf " --gpu-id: gpu id, default is 0\n" - printf " --n-parallel: number of parallel requests, default is 8\n" - printf " --n-kv: KV cache size, default is 4096\n" - printf " --verbose: verbose output\n\n" - printf " --non-interactive: run without asking a permission to run\n" - printf "Example:\n\n" - printf ' bash -c "$(curl -s https://ggml.ai/server-llm.sh)"\n\n' -} - -while [[ $# -gt 0 ]]; do - key="$1" - case $key in - --non-interactive) - is_interactive=0 - shift - ;; - --port) - port="$2" - shift - shift - ;; - --repo) - repo="$2" - shift - shift - ;; - --wtype) - wtype="$2" - shift - shift - ;; - --backend) - backend="$2" - shift - shift - ;; - --gpu-id) - gpu_id="$2" - shift - shift - ;; - --n-parallel) - n_parallel="$2" - shift - shift - ;; - --n-kv) - n_kv="$2" - shift - shift - ;; - --verbose) - verbose=1 - shift - ;; - --help) - print_usage - exit 0 - ;; - *) - echo "Unknown argument: $key" - print_usage - exit 1 - ;; - esac -done - -# available weights types -wtypes=("F16" "Q8_0" "Q4_0" "Q4_1" "Q5_0" "Q5_1" "Q6_K" "Q5_K_M" "Q5_K_S" "Q4_K_M" "Q4_K_S" "Q3_K_L" "Q3_K_M" "Q3_K_S" "Q2_K") - -wfiles=() -for wt in "${wtypes[@]}"; do - wfiles+=("") -done - -# map wtype input to index -if [[ ! -z "$wtype" ]]; then - iw=-1 - is=0 - for wt in "${wtypes[@]}"; do - # uppercase - uwt=$(echo "$wt" | tr '[:lower:]' '[:upper:]') - if [[ "$uwt" == "$wtype" ]]; then - iw=$is - break - fi - is=$((is+1)) - done - - if [[ $iw -eq -1 ]]; then - printf "[-] Invalid weight type: %s\n" "$wtype" - exit 1 - fi - - wtype="$iw" -fi - -# sample repos -repos=( - "https://huggingface.co/TheBloke/Llama-2-7B-GGUF" - "https://huggingface.co/TheBloke/Llama-2-13B-GGUF" - "https://huggingface.co/TheBloke/Llama-2-70B-GGUF" - "https://huggingface.co/TheBloke/CodeLlama-7B-GGUF" - "https://huggingface.co/TheBloke/CodeLlama-13B-GGUF" - "https://huggingface.co/TheBloke/CodeLlama-34B-GGUF" - "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF" - "https://huggingface.co/TheBloke/zephyr-7B-beta-GGUF" - "https://huggingface.co/TheBloke/OpenHermes-2-Mistral-7B-GGUF" - "https://huggingface.co/TheBloke/CausalLM-7B-GGUF" -) -if [ $is_interactive -eq 1 ]; then - printf "\n" - printf "[I] This is a helper script for deploying llama.cpp's server on this machine.\n\n" - printf " Based on the options that follow, the script might download a model file\n" - printf " from the internet, which can be a few GBs in size. The script will also\n" - printf " build the latest llama.cpp source code from GitHub, which can be unstable.\n" - printf "\n" - printf " Upon success, an HTTP server will be started and it will serve the selected\n" - printf " model using llama.cpp for demonstration purposes.\n" - printf "\n" - printf " Please note:\n" - printf "\n" - printf " - All new data will be stored in the current folder\n" - printf " - The server will be listening on all network interfaces\n" - printf " - The server will run with default settings which are not always optimal\n" - printf " - Do not judge the quality of a model based on the results from this script\n" - printf " - Do not use this script to benchmark llama.cpp\n" - printf " - Do not use this script in production\n" - printf " - This script is only for demonstration purposes\n" - printf "\n" - printf " If you don't know what you are doing, please press Ctrl-C to abort now\n" - printf "\n" - printf " Press Enter to continue ...\n\n" - - read -fi - -if [[ -z "$repo" ]]; then - printf "[+] No repo provided from the command line\n" - printf " Please select a number from the list below or enter an URL:\n\n" - - is=0 - for r in "${repos[@]}"; do - printf " %2d) %s\n" $is "$r" - is=$((is+1)) - done - - # ask for repo until index of sample repo is provided or an URL - while [[ -z "$repo" ]]; do - printf "\n Or choose one from: https://huggingface.co/models?sort=trending&search=gguf\n\n" - read -p "[+] Select repo: " repo - - # check if the input is a number - if [[ "$repo" =~ ^[0-9]+$ ]]; then - if [[ "$repo" -ge 0 && "$repo" -lt ${#repos[@]} ]]; then - repo="${repos[$repo]}" - else - printf "[-] Invalid repo index: %s\n" "$repo" - repo="" - fi - elif [[ "$repo" =~ ^https?:// ]]; then - repo="$repo" - else - printf "[-] Invalid repo URL: %s\n" "$repo" - repo="" - fi - done -fi - -# remove suffix -repo=$(echo "$repo" | sed -E 's/\/tree\/main$//g') - -printf "[+] Checking for GGUF model files in %s\n" "$repo" - -# find GGUF files in the source -# TODO: better logic -model_tree="${repo%/}/tree/main" -model_files=$(curl -s "$model_tree" | grep -i "\\.gguf" | sed -E 's/.*(.*)<\/span><\/a>/\1/g') - -# list all files in the provided git repo -printf "[+] Model files:\n\n" -for file in $model_files; do - # determine iw by grepping the filename with wtypes - iw=-1 - is=0 - for wt in "${wtypes[@]}"; do - # uppercase - ufile=$(echo "$file" | tr '[:lower:]' '[:upper:]') - if [[ "$ufile" =~ "$wt" ]]; then - iw=$is - break - fi - is=$((is+1)) - done - - if [[ $iw -eq -1 ]]; then - continue - fi - - wfiles[$iw]="$file" - - have=" " - if [[ -f "$file" ]]; then - have="*" - fi - - printf " %2d) %s %s\n" $iw "$have" "$file" -done - -wfile="${wfiles[$wtype]}" - -# ask for weights type until provided and available -while [[ -z "$wfile" ]]; do - printf "\n" - read -p "[+] Select weight type: " wtype - wfile="${wfiles[$wtype]}" - - if [[ -z "$wfile" ]]; then - printf "[-] Invalid weight type: %s\n" "$wtype" - wtype="" - fi -done - -printf "[+] Selected weight type: %s (%s)\n" "$wtype" "$wfile" - -url="${repo%/}/resolve/main/$wfile" - -# check file if the model has been downloaded before -chk="$wfile.chk" - -# check if we should download the file -# - if $wfile does not exist -# - if $wfile exists but $chk does not exist -# - if $wfile exists and $chk exists but $wfile is newer than $chk -# TODO: better logic using git lfs info - -do_download=0 - -if [[ ! -f "$wfile" ]]; then - do_download=1 -elif [[ ! -f "$chk" ]]; then - do_download=1 -elif [[ "$wfile" -nt "$chk" ]]; then - do_download=1 -fi - -if [[ $do_download -eq 1 ]]; then - printf "[+] Downloading weights from %s\n" "$url" - - # download the weights file - curl -o "$wfile" -# -L "$url" - - # create a check file if successful - if [[ $? -eq 0 ]]; then - printf "[+] Creating check file %s\n" "$chk" - touch "$chk" - fi -else - printf "[+] Using cached weights %s\n" "$wfile" -fi - -# get latest llama.cpp and build - -printf "[+] Downloading latest llama.cpp\n" - -llama_cpp_dir="__llama_cpp_port_${port}__" - -if [[ -d "$llama_cpp_dir" && ! -f "$llama_cpp_dir/__ggml_script__" ]]; then - # if the dir exists and there isn't a file "__ggml_script__" in it, abort - printf "[-] Directory %s already exists\n" "$llama_cpp_dir" - printf "[-] Please remove it and try again\n" - exit 1 -elif [[ -d "$llama_cpp_dir" ]]; then - printf "[+] Directory %s already exists\n" "$llama_cpp_dir" - printf "[+] Using cached llama.cpp\n" - - cd "$llama_cpp_dir" - git reset --hard - git fetch - git checkout origin/master - - cd .. -else - printf "[+] Cloning llama.cpp\n" - - git clone https://github.com/ggerganov/llama.cpp "$llama_cpp_dir" -fi - -# mark that that the directory is made by this script -touch "$llama_cpp_dir/__ggml_script__" - -if [[ $verbose -eq 1 ]]; then - set -x -fi - -# build -cd "$llama_cpp_dir" - -make clean - -log="--silent" -if [[ $verbose -eq 1 ]]; then - log="" -fi - -if [[ "$backend" == "cuda" ]]; then - printf "[+] Building with CUDA backend\n" - GGML_CUDA=1 make -j llama-server $log -elif [[ "$backend" == "cpu" ]]; then - printf "[+] Building with CPU backend\n" - make -j llama-server $log -elif [[ "$backend" == "metal" ]]; then - printf "[+] Building with Metal backend\n" - make -j llama-server $log -else - printf "[-] Unknown backend: %s\n" "$backend" - exit 1 -fi - -# run the server - -printf "[+] Running server\n" - -args="" -if [[ "$backend" == "cuda" ]]; then - export CUDA_VISIBLE_DEVICES=$gpu_id - args="-ngl 999" -elif [[ "$backend" == "cpu" ]]; then - args="-ngl 0" -elif [[ "$backend" == "metal" ]]; then - args="-ngl 999" -else - printf "[-] Unknown backend: %s\n" "$backend" - exit 1 -fi - -if [[ $verbose -eq 1 ]]; then - args="$args --verbose" -fi - -./llama-server -m "../$wfile" --host 0.0.0.0 --port "$port" -c $n_kv -np "$n_parallel" $args - -exit 0 diff --git a/scripts/sync-ggml-am.sh b/scripts/sync-ggml-am.sh index f16336594..ec4f4b0a2 100755 --- a/scripts/sync-ggml-am.sh +++ b/scripts/sync-ggml-am.sh @@ -73,16 +73,22 @@ while read c; do src/ggml*.h \ src/ggml*.c \ src/ggml*.cpp \ - src/ggml*.m \ - src/ggml*.metal \ - src/ggml*.cu \ + src/gguf*.cpp \ + src/ggml-blas/* \ src/ggml-cann/* \ + src/ggml-cpu/* \ src/ggml-cuda/* \ + src/ggml-hip/* \ + src/ggml-kompute/* \ + src/ggml-metal/* \ + src/ggml-musa/* \ + src/ggml-opencl/* \ + src/ggml-rpc/* \ src/ggml-sycl/* \ - src/vulkan-shaders/* \ + src/ggml-vulkan/* \ include/ggml*.h \ + include/gguf*.h \ tests/test-opt.cpp \ - tests/test-grad0.cpp \ tests/test-quantize-fns.cpp \ tests/test-quantize-perf.cpp \ tests/test-backend-ops.cpp \ @@ -113,49 +119,31 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then # replace filenames: # - # CMakelists.txt -> ggml/CMakeLists.txt - # src/CMakeLists.txt -> ggml/src/CMakeLists.txt - # cmake/FindSIMD.cmake -> ggml/cmake/FindSIMD.cmake + # CMakelists.txt -> ggml/CMakeLists.txt + # src/CMakeLists.txt -> ggml/src/CMakeLists.txt + # cmake/FindSIMD.cmake -> ggml/cmake/FindSIMD.cmake # - # src/ggml.c -> ggml/src/ggml.c - # src/ggml-aarch64.c -> ggml/src/ggml-aarch64.c - # src/ggml-aarch64.h -> ggml/src/ggml-aarch64.h - # src/ggml-alloc.c -> ggml/src/ggml-alloc.c - # src/ggml-backend-impl.h -> ggml/src/ggml-backend-impl.h - # src/ggml-backend.c -> ggml/src/ggml-backend.c - # src/ggml-cann/* -> ggml/src/ggml-cann/ - # src/ggml-cann.cpp -> ggml/src/ggml-cann.cpp - # src/ggml-common.h -> ggml/src/ggml-common.h - # src/ggml-cuda/* -> ggml/src/ggml-cuda/ - # src/ggml-cuda.cu -> ggml/src/ggml-cuda.cu - # src/ggml-impl.h -> ggml/src/ggml-impl.h - # src/ggml-kompute.cpp -> ggml/src/ggml-kompute.cpp - # src/ggml-metal.m -> ggml/src/ggml-metal.m - # src/ggml-quants.c -> ggml/src/ggml-quants.c - # src/ggml-quants.h -> ggml/src/ggml-quants.h - # src/ggml-rpc.cpp -> ggml/src/ggml-rpc.cpp - # src/ggml-sycl/* -> ggml/src/ggml-sycl/ - # src/ggml-sycl.cpp -> ggml/src/ggml-sycl.cpp - # src/ggml-vulkan.cpp -> ggml/src/ggml-vulkan.cpp - # src/vulkan-shaders/* -> ggml/src/vulkan-shaders/ + # src/ggml*.c -> ggml/src/ggml*.c + # src/ggml*.cpp -> ggml/src/ggml*.cpp + # src/ggml*.h -> ggml/src/ggml*.h + # src/gguf*.cpp -> ggml/src/gguf*.cpp + # src/ggml-blas/* -> ggml/src/ggml-blas/* + # src/ggml-cann/* -> ggml/src/ggml-cann/* + # src/ggml-cpu/* -> ggml/src/ggml-cpu/* + # src/ggml-cuda/* -> ggml/src/ggml-cuda/* + # src/ggml-hip/* -> ggml/src/ggml-hip/* + # src/ggml-kompute/* -> ggml/src/ggml-kompute/* + # src/ggml-metal/* -> ggml/src/ggml-metal/* + # src/ggml-musa/* -> ggml/src/ggml-musa/* + # src/ggml-opencl/* -> ggml/src/ggml-opencl/* + # src/ggml-rpc/* -> ggml/src/ggml-rpc/* + # src/ggml-sycl/* -> ggml/src/ggml-sycl/* + # src/ggml-vulkan/* -> ggml/src/ggml-vulkan/* # - # include/ggml.h -> ggml/include/ggml.h - # include/ggml-alloc.h -> ggml/include/ggml-alloc.h - # include/ggml-backend.h -> ggml/include/ggml-backend.h - # include/ggml-blas.h -> ggml/include/ggml-blas.h - # include/ggml-cann.h -> ggml/include/ggml-cann.h - # include/ggml-cuda.h -> ggml/include/ggml-cuda.h - # include/ggml-kompute.h -> ggml/include/ggml-kompute.h - # include/ggml-metal.h -> ggml/include/ggml-metal.h - # include/ggml-rpc.h -> ggml/include/ggml-rpc.h - # include/ggml-sycl.h -> ggml/include/ggml-sycl.h - # include/ggml-vulkan.h -> ggml/include/ggml-vulkan.h + # include/ggml*.h -> ggml/include/ggml*.h + # include/gguf*.h -> ggml/include/gguf*.h # - # tests/test-opt.cpp -> tests/test-opt.cpp - # tests/test-grad0.cpp -> tests/test-grad0.cpp - # tests/test-quantize-fns.cpp -> tests/test-quantize-fns.cpp - # tests/test-quantize-perf.cpp -> tests/test-quantize-perf.cpp - # tests/test-backend-ops.cpp -> tests/test-backend-ops.cpp + # tests/test*.cpp -> tests/ # # LICENSE -> LICENSE # scripts/gen-authors.sh -> scripts/gen-authors.sh @@ -164,42 +152,24 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then -e 's/([[:space:]]|[ab]\/)CMakeLists.txt/\1ggml\/CMakeLists.txt/g' \ -e 's/([[:space:]]|[ab]\/)src\/CMakeLists.txt/\1ggml\/src\/CMakeLists.txt/g' \ -e 's/([[:space:]]|[ab]\/)cmake\/FindSIMD.cmake/\1ggml\/cmake\/FindSIMD.cmake/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml\.c/\1ggml\/src\/ggml.c/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-aarch64\.c/\1ggml\/src\/ggml-aarch64.c/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-aarch64\.h/\1ggml\/src\/ggml-aarch64.h/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-alloc\.c/\1ggml\/src\/ggml-alloc.c/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-backend-impl\.h/\1ggml\/src\/ggml-backend-impl.h/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-backend\.c/\1ggml\/src\/ggml-backend.c/g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.c/\1ggml\/src\/ggml\2.c/g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.cpp/\1ggml\/src\/ggml\2.cpp/g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.h/\1ggml\/src\/ggml\2.h/g' \ + -e 's/([[:space:]]|[ab]\/)src\/gguf(.*)\.cpp/\1ggml\/src\/gguf\2.cpp/g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml-blas\//\1ggml\/src\/ggml-blas\//g' \ -e 's/([[:space:]]|[ab]\/)src\/ggml-cann\//\1ggml\/src\/ggml-cann\//g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-cann\.cpp/\1ggml\/src\/ggml-cann.cpp/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-common\.h/\1ggml\/src\/ggml-common.h/g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml-cpu\//\1ggml\/src\/ggml-cpu\//g' \ -e 's/([[:space:]]|[ab]\/)src\/ggml-cuda\//\1ggml\/src\/ggml-cuda\//g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-cuda\.cu/\1ggml\/src\/ggml-cuda.cu/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-impl\.h/\1ggml\/src\/ggml-impl.h/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-kompute\.cpp/\1ggml\/src\/ggml-kompute.cpp/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-metal\.m/\1ggml\/src\/ggml-metal.m/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-quants\.c/\1ggml\/src\/ggml-quants.c/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-quants\.h/\1ggml\/src\/ggml-quants.h/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-rpc\.cpp/\1ggml\/src\/ggml-rpc.cpp/g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml-hip\//\1ggml\/src\/ggml-hip\//g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml-kompute\//\1ggml\/src\/ggml-kompute\//g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml-metal\//\1ggml\/src\/ggml-metal\//g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml-opencl\//\1ggml\/src\/ggml-opencl\//g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml-rpc\//\1ggml\/src\/ggml-rpc\//g' \ -e 's/([[:space:]]|[ab]\/)src\/ggml-sycl\//\1ggml\/src\/ggml-sycl\//g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-sycl\.cpp/\1ggml\/src\/ggml-sycl.cpp/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml-vulkan\.cpp/\1ggml\/src\/ggml-vulkan.cpp/g' \ - -e 's/([[:space:]]|[ab]\/)src\/vulkan-shaders\//\1ggml\/src\/vulkan-shaders\//g' \ - -e 's/([[:space:]]|[ab]\/)include\/ggml\.h/\1ggml\/include\/ggml.h/g' \ - -e 's/([[:space:]]|[ab]\/)include\/ggml-alloc\.h/\1ggml\/include\/ggml-alloc.h/g' \ - -e 's/([[:space:]]|[ab]\/)include\/ggml-backend\.h/\1ggml\/include\/ggml-backend.h/g' \ - -e 's/([[:space:]]|[ab]\/)include\/ggml-blas\.h/\1ggml\/include\/ggml-blas.h/g' \ - -e 's/([[:space:]]|[ab]\/)include\/ggml-cann\.h/\1ggml\/include\/ggml-cann.h/g' \ - -e 's/([[:space:]]|[ab]\/)include\/ggml-cuda\.h/\1ggml\/include\/ggml-cuda.h/g' \ - -e 's/([[:space:]]|[ab]\/)include\/ggml-kompute\.h/\1ggml\/include\/ggml-kompute.h/g' \ - -e 's/([[:space:]]|[ab]\/)include\/ggml-metal\.h/\1ggml\/include\/ggml-metal.h/g' \ - -e 's/([[:space:]]|[ab]\/)include\/ggml-rpc\.h/\1ggml\/include\/ggml-rpc.h/g' \ - -e 's/([[:space:]]|[ab]\/)include\/ggml-sycl\.h/\1ggml\/include\/ggml-sycl.h/g' \ - -e 's/([[:space:]]|[ab]\/)include\/ggml-vulkan\.h/\1ggml\/include\/ggml-vulkan.h/g' \ - -e 's/([[:space:]]|[ab]\/)examples\/common\.h/\1examples\/common.h/g' \ - -e 's/([[:space:]]|[ab]\/)examples\/common\.cpp/\1examples\/common.cpp/g' \ - -e 's/([[:space:]]|[ab]\/)examples\/common-ggml\.h/\1examples\/common-ggml.h/g' \ - -e 's/([[:space:]]|[ab]\/)examples\/common-ggml\.cpp/\1examples\/common-ggml.cpp/g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml-vulkan\//\1ggml\/src\/ggml-vulkan\//g' \ + -e 's/([[:space:]]|[ab]\/)include\/ggml(.*)\.h/\1ggml\/include\/ggml\2.h/g' \ + -e 's/([[:space:]]|[ab]\/)include\/gguf(.*)\.h/\1ggml\/include\/gguf\2.h/g' \ + -e 's/([[:space:]]|[ab]\/)tests\/(.*)\.cpp/\1tests\/\2.cpp/g' \ -e 's/([[:space:]]|[ab]\/)LICENSE/\1LICENSE/g' \ -e 's/([[:space:]]|[ab]\/)scripts\/gen-authors\.sh/\1scripts\/gen-authors.sh/g' \ > ggml-src.patch.tmp diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 3d2dfb413..ddb9d817e 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -10e83a412717c20d57ba19f025248e18e43addf3 +32f0b85987396945afea2291d5f4c5862434292b diff --git a/scripts/sync-ggml.sh b/scripts/sync-ggml.sh index 30a62e088..e83d415c0 100755 --- a/scripts/sync-ggml.sh +++ b/scripts/sync-ggml.sh @@ -4,43 +4,27 @@ cp -rpv ../ggml/CMakeLists.txt ./ggml/CMakeLists.txt cp -rpv ../ggml/src/CMakeLists.txt ./ggml/src/CMakeLists.txt cp -rpv ../ggml/cmake/FindSIMD.cmake ./ggml/cmake/FindSIMD.cmake -cp -rpv ../ggml/src/ggml.c ./ggml/src/ggml.c -cp -rpv ../ggml/src/ggml-aarch64.c ./ggml/src/ggml-aarch64.c -cp -rpv ../ggml/src/ggml-aarch64.h ./ggml/src/ggml-aarch64.h -cp -rpv ../ggml/src/ggml-alloc.c ./ggml/src/ggml-alloc.c -cp -rpv ../ggml/src/ggml-backend-impl.h ./ggml/src/ggml-backend-impl.h -cp -rpv ../ggml/src/ggml-backend.c ./ggml/src/ggml-backend.c -cp -rpv ../ggml/src/ggml-cann/* ./ggml/src/ggml-cann/ -cp -rpv ../ggml/src/ggml-cann.cpp ./ggml/src/ggml-cann.cpp -cp -rpv ../ggml/src/ggml-common.h ./ggml/src/ggml-common.h -cp -rpv ../ggml/src/ggml-cuda/* ./ggml/src/ggml-cuda/ -cp -rpv ../ggml/src/ggml-cuda.cu ./ggml/src/ggml-cuda.cu -cp -rpv ../ggml/src/ggml-impl.h ./ggml/src/ggml-impl.h -cp -rpv ../ggml/src/ggml-kompute.cpp ./ggml/src/ggml-kompute.cpp -cp -rpv ../ggml/src/ggml-metal.m ./ggml/src/ggml-metal.m -cp -rpv ../ggml/src/ggml-metal.metal ./ggml/src/ggml-metal.metal -cp -rpv ../ggml/src/ggml-quants.c ./ggml/src/ggml-quants.c -cp -rpv ../ggml/src/ggml-quants.h ./ggml/src/ggml-quants.h -cp -rpv ../ggml/src/ggml-rpc.cpp ./ggml/src/ggml-rpc.cpp -cp -rpv ../ggml/src/ggml-sycl/* ./ggml/src/ggml-sycl/ -cp -rpv ../ggml/src/ggml-sycl.cpp ./ggml/src/ggml-sycl.cpp -cp -rpv ../ggml/src/ggml-vulkan.cpp ./ggml/src/ggml-vulkan.cpp -cp -rpv ../ggml/src/vulkan-shaders/* ./ggml/src/vulkan-shaders/ +cp -rpv ../ggml/src/ggml*.c ./ggml/src/ +cp -rpv ../ggml/src/ggml*.cpp ./ggml/src/ +cp -rpv ../ggml/src/ggml*.h ./ggml/src/ +cp -rpv ../ggml/src/gguf*.cpp ./ggml/src/ +cp -rpv ../ggml/src/ggml-blas/* ./ggml/src/ggml-blas/ +cp -rpv ../ggml/src/ggml-cann/* ./ggml/src/ggml-cann/ +cp -rpv ../ggml/src/ggml-cpu/* ./ggml/src/ggml-cpu/ +cp -rpv ../ggml/src/ggml-cuda/* ./ggml/src/ggml-cuda/ +cp -rpv ../ggml/src/ggml-hip/* ./ggml/src/ggml-hip/ +cp -rpv ../ggml/src/ggml-kompute/* ./ggml/src/ggml-kompute/ +cp -rpv ../ggml/src/ggml-metal/* ./ggml/src/ggml-metal/ +cp -rpv ../ggml/src/ggml-musa/* ./ggml/src/ggml-musa/ +cp -rpv ../ggml/src/ggml-opencl/* ./ggml/src/ggml-opencl/ +cp -rpv ../ggml/src/ggml-rpc/* ./ggml/src/ggml-rpc/ +cp -rpv ../ggml/src/ggml-sycl/* ./ggml/src/ggml-sycl/ +cp -rpv ../ggml/src/ggml-vulkan/* ./ggml/src/ggml-vulkan/ -cp -rpv ../ggml/include/ggml.h ./ggml/include/ggml.h -cp -rpv ../ggml/include/ggml-alloc.h ./ggml/include/ggml-alloc.h -cp -rpv ../ggml/include/ggml-backend.h ./ggml/include/ggml-backend.h -cp -rpv ../ggml/include/ggml-blas.h ./ggml/include/ggml-blas.h -cp -rpv ../ggml/include/ggml-cann.h ./ggml/include/ggml-cann.h -cp -rpv ../ggml/include/ggml-cuda.h ./ggml/include/ggml-cuda.h -cp -rpv ../ggml/include/ggml-kompute.h ./ggml/include/ggml-kompute.h -cp -rpv ../ggml/include/ggml-metal.h ./ggml/include/ggml-metal.h -cp -rpv ../ggml/include/ggml-rpc.h ./ggml/include/ggml-rpc.h -cp -rpv ../ggml/include/ggml-sycl.h ./ggml/include/ggml-sycl.h -cp -rpv ../ggml/include/ggml-vulkan.h ./ggml/include/ggml-vulkan.h +cp -rpv ../ggml/include/ggml*.h ./ggml/include/ +cp -rpv ../ggml/include/gguf*.h ./ggml/include/ cp -rpv ../ggml/tests/test-opt.cpp ./tests/test-opt.cpp -cp -rpv ../ggml/tests/test-grad0.cpp ./tests/test-grad0.cpp cp -rpv ../ggml/tests/test-quantize-fns.cpp ./tests/test-quantize-fns.cpp cp -rpv ../ggml/tests/test-quantize-perf.cpp ./tests/test-quantize-perf.cpp cp -rpv ../ggml/tests/test-backend-ops.cpp ./tests/test-backend-ops.cpp diff --git a/spm-headers/ggml-cpp.h b/spm-headers/ggml-cpp.h new file mode 120000 index 000000000..8a8604cc2 --- /dev/null +++ b/spm-headers/ggml-cpp.h @@ -0,0 +1 @@ +../ggml/include/ggml-cpp.h \ No newline at end of file diff --git a/spm-headers/ggml-cpu.h b/spm-headers/ggml-cpu.h new file mode 120000 index 000000000..66e629607 --- /dev/null +++ b/spm-headers/ggml-cpu.h @@ -0,0 +1 @@ +../ggml/include/ggml-cpu.h \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 46a6ad562..e1b02e4c0 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,9 +1,4 @@ -# TODO: should not use this -if (WIN32) - if (BUILD_SHARED_LIBS) - set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) - endif() -endif() +llama_add_compile_flags() # # libraries @@ -14,20 +9,33 @@ endif() add_library(llama ../include/llama.h llama.cpp - llama-vocab.cpp + llama-adapter.cpp + llama-arch.cpp + llama-batch.cpp + llama-chat.cpp + llama-context.cpp llama-grammar.cpp + llama-hparams.cpp + llama-impl.cpp + llama-kv-cache.cpp + llama-mmap.cpp + llama-model-loader.cpp + llama-model.cpp + llama-quant.cpp llama-sampling.cpp + llama-vocab.cpp unicode.h unicode.cpp unicode-data.cpp ) -target_include_directories(llama PUBLIC . ../include) -target_compile_features (llama PUBLIC cxx_std_11) # don't bump +target_include_directories(llama PUBLIC . ../include ../common) +target_compile_features (llama PUBLIC cxx_std_17) # don't bump target_link_libraries(llama PUBLIC ggml) if (BUILD_SHARED_LIBS) set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON) - target_compile_definitions(llama PRIVATE LLAMA_SHARED LLAMA_BUILD) + target_compile_definitions(llama PRIVATE LLAMA_BUILD) + target_compile_definitions(llama PUBLIC LLAMA_SHARED) endif() diff --git a/src/llama-adapter.cpp b/src/llama-adapter.cpp new file mode 100644 index 000000000..8a0800463 --- /dev/null +++ b/src/llama-adapter.cpp @@ -0,0 +1,347 @@ +#include "llama-adapter.h" + +#include "llama-impl.h" +#include "llama-mmap.h" +#include "llama-model.h" + +#include +#include +#include +#include + +// vec + +struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const { + if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) { + return nullptr; + } + + return tensors[il]; +} + +struct ggml_tensor * llama_adapter_cvec::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const { + ggml_tensor * layer_dir = tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx, cur, layer_dir); + } + + return cur; +} + +bool llama_adapter_cvec::init(const llama_model & model) { + const auto & hparams = model.hparams; + + GGML_ASSERT(tensors.empty()); + GGML_ASSERT(ctxs.empty()); + GGML_ASSERT(bufs.empty()); + + // create a context for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + struct ggml_init_params params = { + /*.mem_size =*/ hparams.n_layer*ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + + ctx_map[buft] = ctx; + ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + // make tensors + tensors.reserve(hparams.n_layer); + tensors.push_back(nullptr); // there's never a tensor for layer 0 + for (size_t il = 1; il < hparams.n_layer; il++) { + ggml_backend_buffer_type_t buft = model.select_buft(il); + ggml_context * ctx = ctx_for_buft(buft); + if (!ctx) { + LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__); + return false; + } + ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); + tensors.push_back(tensor); + } + + // allocate tensors / buffers and zero + bufs.reserve(ctx_map.size()); + for (auto it : ctx_map) { + ggml_backend_buffer_type_t buft = it.first; + ggml_context * ctx = it.second; + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + LLAMA_LOG_ERROR("%s: failed to allocate buffer for control vector\n", __func__); + return false; + } + ggml_backend_buffer_clear(buf, 0); + bufs.emplace_back(buf); + } + + return true; +} + +int32_t llama_adapter_cvec::apply( + const llama_model & model, + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end) { + const auto & hparams = model.hparams; + + if (data == nullptr) { + // disable the current control vector (but leave allocated for later) + layer_start = -1; + layer_end = -1; + return 0; + } + + if (n_embd != (int) hparams.n_embd) { + LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__); + return 1; + } + + if (tensors.empty()) { + if (!init(model)) { + return 1; + } + } + + layer_start = il_start; + layer_end = il_end; + + for (size_t il = 1; il < hparams.n_layer; il++) { + assert(tensors[il] != nullptr); + + const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present + if (off + n_embd <= len) { + ggml_backend_tensor_set(tensors[il], data + off, 0, n_embd * ggml_element_size(tensors[il])); + } + } + + return 0; +} + +// lora + +llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor * w) { + const std::string name(w->name); + + const auto pos = ab_map.find(name); + if (pos != ab_map.end()) { + return &pos->second; + } + + return nullptr; +} + +static void llama_adapter_lora_init_impl(struct llama_model & model, const char * path_lora, struct llama_adapter_lora & adapter) { + LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora); + + ggml_context * ctx_init; + struct gguf_init_params meta_gguf_params = { + /* .no_alloc = */ true, + /* .ctx = */ &ctx_init, + }; + + gguf_context_ptr ctx_gguf { gguf_init_from_file(path_lora, meta_gguf_params) }; + if (!ctx_gguf) { + throw std::runtime_error("failed to load lora adapter file from " + std::string(path_lora)); + } + + ggml_context_ptr ctx { ctx_init }; + + // check metadata + { + auto get_kv_str = [&](const std::string & key) -> std::string { + int id = gguf_find_key(ctx_gguf.get(), key.c_str()); + return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf.get(), id)); + }; + auto get_kv_f32 = [&](const std::string & key) -> float { + int id = gguf_find_key(ctx_gguf.get(), key.c_str()); + return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf.get(), id); + }; + LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); + + auto general_type = get_kv_str(llm_kv(LLM_KV_GENERAL_TYPE)); + if (general_type != "adapter") { + throw std::runtime_error("expect general.type to be 'adapter', but got: " + general_type); + } + + auto general_arch_str = get_kv_str(llm_kv(LLM_KV_GENERAL_ARCHITECTURE)); + auto general_arch = llm_arch_from_string(general_arch_str); + if (general_arch != model.arch) { + throw std::runtime_error("model arch and LoRA arch mismatch"); + } + + auto adapter_type = get_kv_str(llm_kv(LLM_KV_ADAPTER_TYPE)); + if (adapter_type != "lora") { + throw std::runtime_error("expect adapter.type to be 'lora', but got: " + adapter_type); + } + + adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA)); + } + + int n_tensors = gguf_get_n_tensors(ctx_gguf.get()); + + // contexts for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + // add a new context + struct ggml_init_params params = { + /*.mem_size =*/ n_tensors*ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context * buft_ctx = ggml_init(params); + if (!buft_ctx) { + return nullptr; + } + ctx_map[buft] = buft_ctx; + adapter.ctxs.emplace_back(buft_ctx); + return buft_ctx; + }; + return it->second; + }; + + // bundle lora_a and lora_b into pairs + std::map ab_map; + auto str_endswith = [](const std::string & str, const std::string & suffix) { + return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; + }; + + for (ggml_tensor * cur = ggml_get_first_tensor(ctx.get()); cur; cur = ggml_get_next_tensor(ctx.get(), cur)) { + std::string name(cur->name); + if (str_endswith(name, ".lora_a")) { + replace_all(name, ".lora_a", ""); + if (ab_map.find(name) == ab_map.end()) { + ab_map[name] = llama_adapter_lora_weight(cur, nullptr); + } else { + ab_map[name].a = cur; + } + } else if (str_endswith(name, ".lora_b")) { + replace_all(name, ".lora_b", ""); + if (ab_map.find(name) == ab_map.end()) { + ab_map[name] = llama_adapter_lora_weight(nullptr, cur); + } else { + ab_map[name].b = cur; + } + } else if (str_endswith(name, "_norm.weight")) { + // TODO: add support for norm vector + // for now, we don't really care because most adapters still work fine without it + continue; + } else { + throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix"); + } + } + + // add tensors + for (auto & it : ab_map) { + const std::string & name = it.first; + llama_adapter_lora_weight & w = it.second; + bool is_token_embd = str_endswith(name, "token_embd.weight"); + + if (!w.a || !w.b) { + throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component"); + } + + // device buft and device ctx + const auto * model_tensor = model.get_tensor(name.c_str()); + if (!model_tensor) { + throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)"); + } + + struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer)); + // validate tensor shape + if (is_token_embd) { + // expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd() + if (model_tensor->ne[0] != w.b->ne[1] || model_tensor->ne[1] != w.a->ne[1]) { + throw std::runtime_error("tensor '" + name + "' has incorrect shape (hint: maybe wrong base model?)"); + } + } else { + if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) { + throw std::runtime_error("tensor '" + name + "' has incorrect shape (hint: maybe wrong base model?)"); + } + if (w.a->ne[1] != w.b->ne[0]) { + throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)"); + } + } + + // save tensor to adapter + struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a); + struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b); + ggml_set_name(tensor_a, w.a->name); + ggml_set_name(tensor_b, w.b->name); + adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b); + } + + // allocate tensors / buffers and zero + { + adapter.ctxs.reserve(ctx_map.size()); + adapter.bufs.reserve(ctx_map.size()); + for (auto & it : ctx_map) { + ggml_backend_buffer_type_t buft = it.first; + ggml_context * ctx_dev = it.second; + ggml_backend_buffer_ptr buf { ggml_backend_alloc_ctx_tensors_from_buft(ctx_dev, buft) }; + if (!buf) { + throw std::runtime_error("failed to allocate buffer for lora adapter\n"); + } + LLAMA_LOG_INFO("%s: %10s LoRA buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get())/1024.0/1024.0); + adapter.bufs.emplace_back(std::move(buf)); + } + } + + // set tensor data + { + llama_file gguf_file(path_lora, "rb"); + std::vector read_buf; + auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) { + size_t offs = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), gguf_find_tensor(ctx_gguf.get(), orig->name)); + size_t size = ggml_nbytes(orig); + read_buf.resize(size); + gguf_file.seek(offs, SEEK_SET); + gguf_file.read_raw(read_buf.data(), size); + ggml_backend_tensor_set(dev, read_buf.data(), 0, size); + }; + for (auto & it : adapter.ab_map) { + auto orig = ab_map[it.first]; + auto dev = it.second; + set_tensor(orig.a, dev.a); + set_tensor(orig.b, dev.b); + } + } + + LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2); +} + +struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model, const char * path_lora) { + struct llama_adapter_lora * adapter = new llama_adapter_lora(); + + try { + llama_adapter_lora_init_impl(*model, path_lora, *adapter); + return adapter; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); + + delete adapter; + } + + return nullptr; +} + +void llama_adapter_lora_free(struct llama_adapter_lora * adapter) { + delete adapter; +} diff --git a/src/llama-adapter.h b/src/llama-adapter.h new file mode 100644 index 000000000..603fa08f6 --- /dev/null +++ b/src/llama-adapter.h @@ -0,0 +1,74 @@ +#pragma once + +#include "llama.h" + +#include "ggml-cpp.h" + +#include +#include +#include + +// TODO: pimpl + +// +// llama_adapter_cvec +// + +struct llama_adapter_cvec { + struct ggml_tensor * tensor_for(int il) const; + + struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const; + + int32_t apply( + const llama_model & model, + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end); + +private: + bool init(const llama_model & model); + + int32_t layer_start = -1; + int32_t layer_end = -1; + + std::vector ctxs; + std::vector bufs; + + std::vector tensors; // per layer +}; + +// +// llama_adapter_lora +// + +struct llama_adapter_lora_weight { + struct ggml_tensor * a = nullptr; + struct ggml_tensor * b = nullptr; + + // get actual scale based on rank and alpha + float get_scale(float alpha, float adapter_scale) const { + const float rank = (float) b->ne[0]; + const float scale = alpha ? adapter_scale * alpha / rank : adapter_scale; + return scale; + } + + llama_adapter_lora_weight() = default; + llama_adapter_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {} +}; + +struct llama_adapter_lora { + // map tensor name to lora_a_b + std::unordered_map ab_map; + + std::vector ctxs; + std::vector bufs; + + float alpha; + + llama_adapter_lora() = default; + ~llama_adapter_lora() = default; + + llama_adapter_lora_weight * get_weight(struct ggml_tensor * w); +}; diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp new file mode 100644 index 000000000..a7260f495 --- /dev/null +++ b/src/llama-arch.cpp @@ -0,0 +1,1489 @@ +#include "llama-arch.h" + +#include "llama-impl.h" + +#include + +static const std::map LLM_ARCH_NAMES = { + { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_DECI, "deci" }, + { LLM_ARCH_FALCON, "falcon" }, + { LLM_ARCH_GROK, "grok" }, + { LLM_ARCH_GPT2, "gpt2" }, + { LLM_ARCH_GPTJ, "gptj" }, + { LLM_ARCH_GPTNEOX, "gptneox" }, + { LLM_ARCH_MPT, "mpt" }, + { LLM_ARCH_BAICHUAN, "baichuan" }, + { LLM_ARCH_STARCODER, "starcoder" }, + { LLM_ARCH_REFACT, "refact" }, + { LLM_ARCH_BERT, "bert" }, + { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, + { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, + { LLM_ARCH_BLOOM, "bloom" }, + { LLM_ARCH_STABLELM, "stablelm" }, + { LLM_ARCH_QWEN, "qwen" }, + { LLM_ARCH_QWEN2, "qwen2" }, + { LLM_ARCH_QWEN2MOE, "qwen2moe" }, + { LLM_ARCH_QWEN2VL, "qwen2vl" }, + { LLM_ARCH_PHI2, "phi2" }, + { LLM_ARCH_PHI3, "phi3" }, + { LLM_ARCH_PHIMOE, "phimoe" }, + { LLM_ARCH_PLAMO, "plamo" }, + { LLM_ARCH_CODESHELL, "codeshell" }, + { LLM_ARCH_ORION, "orion" }, + { LLM_ARCH_INTERNLM2, "internlm2" }, + { LLM_ARCH_MINICPM, "minicpm" }, + { LLM_ARCH_MINICPM3, "minicpm3" }, + { LLM_ARCH_GEMMA, "gemma" }, + { LLM_ARCH_GEMMA2, "gemma2" }, + { LLM_ARCH_STARCODER2, "starcoder2" }, + { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_XVERSE, "xverse" }, + { LLM_ARCH_COMMAND_R, "command-r" }, + { LLM_ARCH_COHERE2, "cohere2" }, + { LLM_ARCH_DBRX, "dbrx" }, + { LLM_ARCH_OLMO, "olmo" }, + { LLM_ARCH_OLMO2, "olmo2" }, + { LLM_ARCH_OLMOE, "olmoe" }, + { LLM_ARCH_OPENELM, "openelm" }, + { LLM_ARCH_ARCTIC, "arctic" }, + { LLM_ARCH_DEEPSEEK, "deepseek" }, + { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_CHATGLM, "chatglm" }, + { LLM_ARCH_BITNET, "bitnet" }, + { LLM_ARCH_T5, "t5" }, + { LLM_ARCH_T5ENCODER, "t5encoder" }, + { LLM_ARCH_JAIS, "jais" }, + { LLM_ARCH_NEMOTRON, "nemotron" }, + { LLM_ARCH_EXAONE, "exaone" }, + { LLM_ARCH_RWKV6, "rwkv6" }, + { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, + { LLM_ARCH_GRANITE, "granite" }, + { LLM_ARCH_GRANITE_MOE, "granitemoe" }, + { LLM_ARCH_CHAMELEON, "chameleon" }, + { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, + { LLM_ARCH_UNKNOWN, "(unknown)" }, +}; + +static const std::map LLM_KV_NAMES = { + { LLM_KV_GENERAL_TYPE, "general.type" }, + { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, + { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, + { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, + { LLM_KV_GENERAL_NAME, "general.name" }, + { LLM_KV_GENERAL_AUTHOR, "general.author" }, + { LLM_KV_GENERAL_VERSION, "general.version" }, + { LLM_KV_GENERAL_URL, "general.url" }, + { LLM_KV_GENERAL_DESCRIPTION, "general.description" }, + { LLM_KV_GENERAL_LICENSE, "general.license" }, + { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, + { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, + + { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, + { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, + { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, + { LLM_KV_FEATURES_LENGTH, "%s.features_length" }, + { LLM_KV_BLOCK_COUNT, "%s.block_count" }, + { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, + { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, + { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, + { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" }, + { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, + { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, + { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, + { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, + { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, + { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, + { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" }, + { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" }, + { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, + { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, + { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, + { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" }, + { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, + { LLM_KV_SWIN_NORM, "%s.swin_norm" }, + { LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" }, + { LLM_KV_TIME_MIX_EXTRA_DIM, "%s.time_mix_extra_dim" }, + { LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" }, + { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" }, + { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, + { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" }, + + { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, + { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, + { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" }, + { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" }, + { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" }, + { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" }, + { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, + { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, + { LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" }, + { LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" }, + { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, + { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, + { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, + { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, + { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, + { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, + + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, + { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, + { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, + { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, + { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, + { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, + { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, + + { LLM_KV_SPLIT_NO, "split.no" }, + { LLM_KV_SPLIT_COUNT, "split.count" }, + { LLM_KV_SPLIT_TENSORS_COUNT, "split.tensors.count" }, + + { LLM_KV_SSM_CONV_KERNEL, "%s.ssm.conv_kernel" }, + { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, + { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, + { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, + { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + + { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, + + { LLM_KV_POSNET_EMBEDDING_LENGTH, "%s.posnet.embedding_length" }, + { LLM_KV_POSNET_BLOCK_COUNT, "%s.posnet.block_count" }, + + { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, "%s.convnext.embedding_length" }, + { LLM_KV_CONVNEXT_BLOCK_COUNT, "%s.convnext.block_count" }, + + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, + { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, + { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, + { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, + { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, + { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, + { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, + { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, + { LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" }, + { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, + { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, + { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, + { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, + { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, + { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, + { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, + { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, + { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" }, + { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, + { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, + { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, + { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" }, + { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" }, + { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, + { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, + { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, + { LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" }, + { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" }, + { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" }, + + { LLM_KV_ADAPTER_TYPE, "adapter.type" }, + { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, + + // deprecated + { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, + { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, + { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, +}; + +static const std::map> LLM_TENSOR_NAMES = { + { + LLM_ARCH_LLAMA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_DECI, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_BAICHUAN, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_FALCON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_GROK, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + }, + }, + { + LLM_ARCH_GPT2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_POS_EMBD, "position_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_GPTJ, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + }, + }, + { + LLM_ARCH_GPTNEOX, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_MPT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output"}, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act" }, + { LLM_TENSOR_POS_EMBD, "position_embd" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm"}, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm"}, + }, + }, + { + LLM_ARCH_STARCODER, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_POS_EMBD, "position_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_REFACT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_BERT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_POS_EMBD, "position_embd" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_CLS, "cls" }, + { LLM_TENSOR_CLS_OUT, "cls.output" }, + }, + }, + { + LLM_ARCH_NOMIC_BERT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_JINA_BERT_V2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_CLS, "cls" }, + }, + }, + { + LLM_ARCH_BLOOM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_STABLELM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + }, + }, + { + LLM_ARCH_QWEN, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_QWEN2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_QWEN2VL, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_QWEN2MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, + { + LLM_ARCH_PHI2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_PHI3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_PHIMOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_PLAMO, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_CODESHELL, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_ORION, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_INTERNLM2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_MINICPM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + }, + }, + { + LLM_ARCH_MINICPM3, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, + { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_GEMMA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_GEMMA2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, + { + LLM_ARCH_STARCODER2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_MAMBA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + }, + }, + { + LLM_ARCH_XVERSE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_COMMAND_R, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + }, + }, + { + LLM_ARCH_COHERE2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_DBRX, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_OLMO, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_OLMO2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_OLMOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_OPENELM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_ARCTIC, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_DEEPSEEK, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, + { + LLM_ARCH_DEEPSEEK2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, + { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + }, + }, + { + LLM_ARCH_CHATGLM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_BITNET, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, + }, + }, + { + LLM_ARCH_T5, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_DEC_OUTPUT_NORM, "dec.output_norm" }, + { LLM_TENSOR_DEC_ATTN_NORM, "dec.blk.%d.attn_norm" }, + { LLM_TENSOR_DEC_ATTN_Q, "dec.blk.%d.attn_q" }, + { LLM_TENSOR_DEC_ATTN_K, "dec.blk.%d.attn_k" }, + { LLM_TENSOR_DEC_ATTN_V, "dec.blk.%d.attn_v" }, + { LLM_TENSOR_DEC_ATTN_OUT, "dec.blk.%d.attn_o" }, + { LLM_TENSOR_DEC_ATTN_REL_B, "dec.blk.%d.attn_rel_b" }, + { LLM_TENSOR_DEC_CROSS_ATTN_NORM, "dec.blk.%d.cross_attn_norm" }, + { LLM_TENSOR_DEC_CROSS_ATTN_Q, "dec.blk.%d.cross_attn_q" }, + { LLM_TENSOR_DEC_CROSS_ATTN_K, "dec.blk.%d.cross_attn_k" }, + { LLM_TENSOR_DEC_CROSS_ATTN_V, "dec.blk.%d.cross_attn_v" }, + { LLM_TENSOR_DEC_CROSS_ATTN_OUT, "dec.blk.%d.cross_attn_o" }, + { LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "dec.blk.%d.cross_attn_rel_b" }, + { LLM_TENSOR_DEC_FFN_NORM, "dec.blk.%d.ffn_norm" }, + { LLM_TENSOR_DEC_FFN_GATE, "dec.blk.%d.ffn_gate" }, + { LLM_TENSOR_DEC_FFN_DOWN, "dec.blk.%d.ffn_down" }, + { LLM_TENSOR_DEC_FFN_UP, "dec.blk.%d.ffn_up" }, + { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, + { LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" }, + { LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" }, + { LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" }, + { LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" }, + { LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" }, + { LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" }, + { LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" }, + { LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" }, + { LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" }, + { LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_T5ENCODER, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, + { LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" }, + { LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" }, + { LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" }, + { LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" }, + { LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" }, + { LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" }, + { LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" }, + { LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" }, + { LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" }, + { LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_JAIS, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_NEMOTRON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_EXAONE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_RWKV6, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, + { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" }, + { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" }, + { LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" }, + { LLM_TENSOR_TIME_MIX_LERP_W, "blk.%d.time_mix_lerp_w" }, + { LLM_TENSOR_TIME_MIX_LERP_K, "blk.%d.time_mix_lerp_k" }, + { LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix_lerp_v" }, + { LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix_lerp_r" }, + { LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix_lerp_g" }, + { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" }, + { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" }, + { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" }, + { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" }, + { LLM_TENSOR_TIME_MIX_DECAY_W2, "blk.%d.time_mix_decay_w2" }, + { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" }, + { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" }, + { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" }, + { LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix_gate" }, + { LLM_TENSOR_TIME_MIX_LN, "blk.%d.time_mix_ln" }, + { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" }, + { LLM_TENSOR_CHANNEL_MIX_LERP_K, "blk.%d.channel_mix_lerp_k" }, + { LLM_TENSOR_CHANNEL_MIX_LERP_R, "blk.%d.channel_mix_lerp_r" }, + { LLM_TENSOR_CHANNEL_MIX_KEY, "blk.%d.channel_mix_key" }, + { LLM_TENSOR_CHANNEL_MIX_VALUE, "blk.%d.channel_mix_value" }, + { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" }, + }, + }, + { + LLM_ARCH_RWKV6QWEN2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" }, + { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" }, + { LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" }, + { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" }, + { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" }, + { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" }, + { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" }, + { LLM_TENSOR_TIME_MIX_DECAY_W2, "blk.%d.time_mix_decay_w2" }, + { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" }, + { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" }, + { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" }, + { LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix_gate" }, + { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_GRANITE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_GRANITE_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_CHAMELEON, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + }, + }, + { + LLM_ARCH_WAVTOKENIZER_DEC, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_CONV1D, "conv1d" }, + { LLM_TENSOR_CONVNEXT_DW, "convnext.%d.dw" }, + { LLM_TENSOR_CONVNEXT_NORM, "convnext.%d.norm" }, + { LLM_TENSOR_CONVNEXT_PW1, "convnext.%d.pw1" }, + { LLM_TENSOR_CONVNEXT_PW2, "convnext.%d.pw2" }, + { LLM_TENSOR_CONVNEXT_GAMMA, "convnext.%d.gamma" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_POS_NET_CONV1, "posnet.%d.conv1" }, + { LLM_TENSOR_POS_NET_CONV2, "posnet.%d.conv2" }, + { LLM_TENSOR_POS_NET_NORM, "posnet.%d.norm" }, + { LLM_TENSOR_POS_NET_NORM1, "posnet.%d.norm1" }, + { LLM_TENSOR_POS_NET_NORM2, "posnet.%d.norm2" }, + { LLM_TENSOR_POS_NET_ATTN_NORM, "posnet.%d.attn_norm" }, + { LLM_TENSOR_POS_NET_ATTN_Q, "posnet.%d.attn_q" }, + { LLM_TENSOR_POS_NET_ATTN_K, "posnet.%d.attn_k" }, + { LLM_TENSOR_POS_NET_ATTN_V, "posnet.%d.attn_v" }, + { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" }, + }, + }, + { + LLM_ARCH_UNKNOWN, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + }, + }, +}; + +static const std::map LLM_TENSOR_INFOS = { + {LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_ROPE_FREQS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, + {LLM_TENSOR_ROPE_FACTORS_LONG, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, + {LLM_TENSOR_ROPE_FACTORS_SHORT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, + {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_CROSS_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_CROSS_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_CROSS_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_CROSS_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DEC_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ENC_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_GATE_INP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_GATE_INP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_IN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_DECAY_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_DECAY_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_VALUE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_RECEPTANCE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_TIME_MIX_OUTPUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CHANNEL_MIX_KEY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CHANNEL_MIX_VALUE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_ACT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}}, + {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, + {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}}, + {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_CHANNEL_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_TIME_MIX_LERP_W, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}}, + {LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_NORM_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_LAYER_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_Q_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_KV_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_SUB_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_SUB_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_DEC_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_DEC_CROSS_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_DEC_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ENC_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ENC_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_DEC_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_ENC_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + // this tensor is loaded for T5, but never used + {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}}, + {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}}, + {LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_POS_NET_NORM2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_POS_NET_CONV1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}}, + {LLM_TENSOR_POS_NET_CONV2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}}, + {LLM_TENSOR_POS_NET_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_POS_NET_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_POS_NET_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_POS_NET_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_POS_NET_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CONVNEXT_DW, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}}, + {LLM_TENSOR_CONVNEXT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, +}; + +LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} + +std::string LLM_KV::operator()(llm_kv kv) const { + return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix) + : ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch)); +} + +std::string LLM_TN_IMPL::str() const { + if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { + return "__missing__"; + } + + std::string name = ::format(LLM_TENSOR_NAMES.at(arch).at(tensor), bid, xid); + + if (suffix != nullptr) { + name += "."; + name += suffix; + } + + return name; +} + +const char * llm_arch_name(llm_arch arch) { + auto it = LLM_ARCH_NAMES.find(arch); + if (it == LLM_ARCH_NAMES.end()) { + return "unknown"; + } + return it->second; +} + +llm_arch llm_arch_from_string(const std::string & name) { + for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT + if (kv.second == name) { + return kv.first; + } + } + + return LLM_ARCH_UNKNOWN; +} + +const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) { + return LLM_TENSOR_INFOS.at(tensor); +} diff --git a/src/llama-arch.h b/src/llama-arch.h new file mode 100644 index 000000000..122fdcebe --- /dev/null +++ b/src/llama-arch.h @@ -0,0 +1,402 @@ +#pragma once + +#include "ggml.h" // ggml_op + +#include + +// +// gguf constants (sync with gguf.py) +// + +enum llm_arch { + LLM_ARCH_LLAMA, + LLM_ARCH_DECI, + LLM_ARCH_FALCON, + LLM_ARCH_BAICHUAN, + LLM_ARCH_GROK, + LLM_ARCH_GPT2, + LLM_ARCH_GPTJ, + LLM_ARCH_GPTNEOX, + LLM_ARCH_MPT, + LLM_ARCH_STARCODER, + LLM_ARCH_REFACT, + LLM_ARCH_BERT, + LLM_ARCH_NOMIC_BERT, + LLM_ARCH_JINA_BERT_V2, + LLM_ARCH_BLOOM, + LLM_ARCH_STABLELM, + LLM_ARCH_QWEN, + LLM_ARCH_QWEN2, + LLM_ARCH_QWEN2MOE, + LLM_ARCH_QWEN2VL, + LLM_ARCH_PHI2, + LLM_ARCH_PHI3, + LLM_ARCH_PHIMOE, + LLM_ARCH_PLAMO, + LLM_ARCH_CODESHELL, + LLM_ARCH_ORION, + LLM_ARCH_INTERNLM2, + LLM_ARCH_MINICPM, + LLM_ARCH_MINICPM3, + LLM_ARCH_GEMMA, + LLM_ARCH_GEMMA2, + LLM_ARCH_STARCODER2, + LLM_ARCH_MAMBA, + LLM_ARCH_XVERSE, + LLM_ARCH_COMMAND_R, + LLM_ARCH_COHERE2, + LLM_ARCH_DBRX, + LLM_ARCH_OLMO, + LLM_ARCH_OLMO2, + LLM_ARCH_OLMOE, + LLM_ARCH_OPENELM, + LLM_ARCH_ARCTIC, + LLM_ARCH_DEEPSEEK, + LLM_ARCH_DEEPSEEK2, + LLM_ARCH_CHATGLM, + LLM_ARCH_BITNET, + LLM_ARCH_T5, + LLM_ARCH_T5ENCODER, + LLM_ARCH_JAIS, + LLM_ARCH_NEMOTRON, + LLM_ARCH_EXAONE, + LLM_ARCH_RWKV6, + LLM_ARCH_RWKV6QWEN2, + LLM_ARCH_GRANITE, + LLM_ARCH_GRANITE_MOE, + LLM_ARCH_CHAMELEON, + LLM_ARCH_WAVTOKENIZER_DEC, + LLM_ARCH_UNKNOWN, +}; + +enum llm_kv { + LLM_KV_GENERAL_TYPE, + LLM_KV_GENERAL_ARCHITECTURE, + LLM_KV_GENERAL_QUANTIZATION_VERSION, + LLM_KV_GENERAL_ALIGNMENT, + LLM_KV_GENERAL_NAME, + LLM_KV_GENERAL_AUTHOR, + LLM_KV_GENERAL_VERSION, + LLM_KV_GENERAL_URL, + LLM_KV_GENERAL_DESCRIPTION, + LLM_KV_GENERAL_LICENSE, + LLM_KV_GENERAL_SOURCE_URL, + LLM_KV_GENERAL_SOURCE_HF_REPO, + + LLM_KV_VOCAB_SIZE, + LLM_KV_CONTEXT_LENGTH, + LLM_KV_EMBEDDING_LENGTH, + LLM_KV_FEATURES_LENGTH, + LLM_KV_BLOCK_COUNT, + LLM_KV_LEADING_DENSE_BLOCK_COUNT, + LLM_KV_FEED_FORWARD_LENGTH, + LLM_KV_EXPERT_FEED_FORWARD_LENGTH, + LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, + LLM_KV_USE_PARALLEL_RESIDUAL, + LLM_KV_TENSOR_DATA_LAYOUT, + LLM_KV_EXPERT_COUNT, + LLM_KV_EXPERT_USED_COUNT, + LLM_KV_EXPERT_SHARED_COUNT, + LLM_KV_EXPERT_WEIGHTS_SCALE, + LLM_KV_EXPERT_WEIGHTS_NORM, + LLM_KV_EXPERT_GATING_FUNC, + LLM_KV_POOLING_TYPE, + LLM_KV_LOGIT_SCALE, + LLM_KV_DECODER_START_TOKEN_ID, + LLM_KV_ATTN_LOGIT_SOFTCAPPING, + LLM_KV_FINAL_LOGIT_SOFTCAPPING, + LLM_KV_SWIN_NORM, + LLM_KV_RESCALE_EVERY_N_LAYERS, + LLM_KV_TIME_MIX_EXTRA_DIM, + LLM_KV_TIME_DECAY_EXTRA_DIM, + LLM_KV_RESIDUAL_SCALE, + LLM_KV_EMBEDDING_SCALE, + LLM_KV_TOKEN_SHIFT_COUNT, + + LLM_KV_ATTENTION_HEAD_COUNT, + LLM_KV_ATTENTION_HEAD_COUNT_KV, + LLM_KV_ATTENTION_MAX_ALIBI_BIAS, + LLM_KV_ATTENTION_CLAMP_KQV, + LLM_KV_ATTENTION_KEY_LENGTH, + LLM_KV_ATTENTION_VALUE_LENGTH, + LLM_KV_ATTENTION_LAYERNORM_EPS, + LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, + LLM_KV_ATTENTION_GROUPNORM_EPS, + LLM_KV_ATTENTION_GROUPNORM_GROUPS, + LLM_KV_ATTENTION_CAUSAL, + LLM_KV_ATTENTION_Q_LORA_RANK, + LLM_KV_ATTENTION_KV_LORA_RANK, + LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, + LLM_KV_ATTENTION_SLIDING_WINDOW, + LLM_KV_ATTENTION_SCALE, + + LLM_KV_ROPE_DIMENSION_COUNT, + LLM_KV_ROPE_DIMENSION_SECTIONS, + LLM_KV_ROPE_FREQ_BASE, + LLM_KV_ROPE_SCALE_LINEAR, + LLM_KV_ROPE_SCALING_TYPE, + LLM_KV_ROPE_SCALING_FACTOR, + LLM_KV_ROPE_SCALING_ATTN_FACTOR, + LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, + LLM_KV_ROPE_SCALING_FINETUNED, + LLM_KV_ROPE_SCALING_YARN_LOG_MUL, + + LLM_KV_SPLIT_NO, + LLM_KV_SPLIT_COUNT, + LLM_KV_SPLIT_TENSORS_COUNT, + + LLM_KV_SSM_INNER_SIZE, + LLM_KV_SSM_CONV_KERNEL, + LLM_KV_SSM_STATE_SIZE, + LLM_KV_SSM_TIME_STEP_RANK, + LLM_KV_SSM_DT_B_C_RMS, + + LLM_KV_WKV_HEAD_SIZE, + + LLM_KV_TOKENIZER_MODEL, + LLM_KV_TOKENIZER_PRE, + LLM_KV_TOKENIZER_LIST, + LLM_KV_TOKENIZER_TOKEN_TYPE, + LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, + LLM_KV_TOKENIZER_SCORES, + LLM_KV_TOKENIZER_MERGES, + LLM_KV_TOKENIZER_BOS_ID, + LLM_KV_TOKENIZER_EOS_ID, + LLM_KV_TOKENIZER_EOT_ID, + LLM_KV_TOKENIZER_EOM_ID, + LLM_KV_TOKENIZER_UNK_ID, + LLM_KV_TOKENIZER_SEP_ID, + LLM_KV_TOKENIZER_PAD_ID, + LLM_KV_TOKENIZER_CLS_ID, + LLM_KV_TOKENIZER_MASK_ID, + LLM_KV_TOKENIZER_ADD_BOS, + LLM_KV_TOKENIZER_ADD_EOS, + LLM_KV_TOKENIZER_ADD_PREFIX, + LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, + LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, + LLM_KV_TOKENIZER_HF_JSON, + LLM_KV_TOKENIZER_RWKV, + LLM_KV_TOKENIZER_CHAT_TEMPLATE, + LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, + LLM_KV_TOKENIZER_FIM_PRE_ID, + LLM_KV_TOKENIZER_FIM_SUF_ID, + LLM_KV_TOKENIZER_FIM_MID_ID, + LLM_KV_TOKENIZER_FIM_PAD_ID, + LLM_KV_TOKENIZER_FIM_REP_ID, + LLM_KV_TOKENIZER_FIM_SEP_ID, + + LLM_KV_ADAPTER_TYPE, + LLM_KV_ADAPTER_LORA_ALPHA, + + LLM_KV_POSNET_EMBEDDING_LENGTH, + LLM_KV_POSNET_BLOCK_COUNT, + + LLM_KV_CONVNEXT_EMBEDDING_LENGTH, + LLM_KV_CONVNEXT_BLOCK_COUNT, + + // deprecated: + LLM_KV_TOKENIZER_PREFIX_ID, + LLM_KV_TOKENIZER_SUFFIX_ID, + LLM_KV_TOKENIZER_MIDDLE_ID, +}; + +enum llm_tensor { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_TOKEN_TYPES, + LLM_TENSOR_POS_EMBD, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ROPE_FACTORS_LONG, + LLM_TENSOR_ROPE_FACTORS_SHORT, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_NORM_2, + LLM_TENSOR_ATTN_OUT_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_ROT_EMBD, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_ACT, + LLM_TENSOR_FFN_DOWN_EXP, // split experts for backward compatibility + LLM_TENSOR_FFN_GATE_EXP, + LLM_TENSOR_FFN_UP_EXP, + LLM_TENSOR_FFN_NORM_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, // merged experts + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_LAYER_OUT_NORM, + LLM_TENSOR_SSM_IN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_X, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_OUT, + LLM_TENSOR_TIME_MIX_W1, + LLM_TENSOR_TIME_MIX_W2, + LLM_TENSOR_TIME_MIX_LERP_X, + LLM_TENSOR_TIME_MIX_LERP_W, + LLM_TENSOR_TIME_MIX_LERP_K, + LLM_TENSOR_TIME_MIX_LERP_V, + LLM_TENSOR_TIME_MIX_LERP_R, + LLM_TENSOR_TIME_MIX_LERP_G, + LLM_TENSOR_TIME_MIX_LERP_FUSED, + LLM_TENSOR_TIME_MIX_FIRST, + LLM_TENSOR_TIME_MIX_DECAY, + LLM_TENSOR_TIME_MIX_DECAY_W1, + LLM_TENSOR_TIME_MIX_DECAY_W2, + LLM_TENSOR_TIME_MIX_KEY, + LLM_TENSOR_TIME_MIX_VALUE, + LLM_TENSOR_TIME_MIX_RECEPTANCE, + LLM_TENSOR_TIME_MIX_GATE, + LLM_TENSOR_TIME_MIX_LN, + LLM_TENSOR_TIME_MIX_OUTPUT, + LLM_TENSOR_CHANNEL_MIX_LERP_K, + LLM_TENSOR_CHANNEL_MIX_LERP_R, + LLM_TENSOR_CHANNEL_MIX_KEY, + LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, + LLM_TENSOR_CHANNEL_MIX_VALUE, + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_ATTN_SUB_NORM, + LLM_TENSOR_FFN_SUB_NORM, + LLM_TENSOR_DEC_ATTN_NORM, + LLM_TENSOR_DEC_ATTN_Q, + LLM_TENSOR_DEC_ATTN_K, + LLM_TENSOR_DEC_ATTN_V, + LLM_TENSOR_DEC_ATTN_OUT, + LLM_TENSOR_DEC_ATTN_REL_B, + LLM_TENSOR_DEC_CROSS_ATTN_NORM, + LLM_TENSOR_DEC_CROSS_ATTN_Q, + LLM_TENSOR_DEC_CROSS_ATTN_K, + LLM_TENSOR_DEC_CROSS_ATTN_V, + LLM_TENSOR_DEC_CROSS_ATTN_OUT, + LLM_TENSOR_DEC_CROSS_ATTN_REL_B, + LLM_TENSOR_DEC_FFN_NORM, + LLM_TENSOR_DEC_FFN_GATE, + LLM_TENSOR_DEC_FFN_DOWN, + LLM_TENSOR_DEC_FFN_UP, + LLM_TENSOR_DEC_OUTPUT_NORM, + LLM_TENSOR_ENC_ATTN_NORM, + LLM_TENSOR_ENC_ATTN_Q, + LLM_TENSOR_ENC_ATTN_K, + LLM_TENSOR_ENC_ATTN_V, + LLM_TENSOR_ENC_ATTN_OUT, + LLM_TENSOR_ENC_ATTN_REL_B, + LLM_TENSOR_ENC_FFN_NORM, + LLM_TENSOR_ENC_FFN_GATE, + LLM_TENSOR_ENC_FFN_DOWN, + LLM_TENSOR_ENC_FFN_UP, + LLM_TENSOR_ENC_OUTPUT_NORM, + LLM_TENSOR_CLS, + LLM_TENSOR_CLS_OUT, + LLM_TENSOR_CONV1D, + LLM_TENSOR_CONVNEXT_DW, + LLM_TENSOR_CONVNEXT_NORM, + LLM_TENSOR_CONVNEXT_PW1, + LLM_TENSOR_CONVNEXT_PW2, + LLM_TENSOR_CONVNEXT_GAMMA, + LLM_TENSOR_POS_NET_CONV1, + LLM_TENSOR_POS_NET_CONV2, + LLM_TENSOR_POS_NET_NORM, + LLM_TENSOR_POS_NET_NORM1, + LLM_TENSOR_POS_NET_NORM2, + LLM_TENSOR_POS_NET_ATTN_NORM, + LLM_TENSOR_POS_NET_ATTN_Q, + LLM_TENSOR_POS_NET_ATTN_K, + LLM_TENSOR_POS_NET_ATTN_V, + LLM_TENSOR_POS_NET_ATTN_OUT, +}; + +enum llm_tensor_layer { + LLM_TENSOR_LAYER_INPUT, + LLM_TENSOR_LAYER_REPEATING, + LLM_TENSOR_LAYER_OUTPUT, +}; + +struct LLM_KV { + LLM_KV(llm_arch arch, const char * suffix = nullptr); + + llm_arch arch; + const char * suffix; + + std::string operator()(llm_kv kv) const; +}; + +// helper to handle gguf constants +// usage: +// +// const auto tn = LLM_TN(LLM_ARCH_LLAMA); +// +// std::string name = tn(LLM_TENSOR_OUTPUT); -> "output" +// std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias"); -> "token_embd.bias" +// std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight" +// +struct LLM_TN_IMPL { + const llm_arch arch; + const llm_tensor tensor; + const char * const suffix; + const int bid; + const int xid; + + std::string str() const; + + operator std::string() const { + return str(); + } + + friend bool operator==(const std::string & str, const LLM_TN_IMPL & tn) { + return str == tn.str(); + } + + friend bool operator!=(const std::string & str, const LLM_TN_IMPL & tn) { + return str != tn.str(); + } +}; + +struct LLM_TN { + LLM_TN(llm_arch arch) : arch(arch) {} + + llm_arch arch; + + LLM_TN_IMPL operator()(llm_tensor tensor, const char * suffix, int bid = -1, int xid = -1) const { + return { arch, tensor, suffix, bid, xid }; + } + + LLM_TN_IMPL operator()(llm_tensor tensor, int bid = -1, int xid = -1) const { + return { arch, tensor, nullptr, bid, xid }; + } +}; + + +struct llm_tensor_info { + llm_tensor_layer layer; + ggml_op op; +}; + +const char * llm_arch_name(llm_arch arch); + +llm_arch llm_arch_from_string(const std::string & name); + +const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor); diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp new file mode 100644 index 000000000..01d5ca57f --- /dev/null +++ b/src/llama-batch.cpp @@ -0,0 +1,368 @@ +#include "llama-batch.h" + +#include +#include + +llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) { + // clear empty sequences + // the previous ubatch is assumed to be gone, + // so nothing should refer to values in these sequences anymore. + for (size_t i = seq.size(); i-- > 0;) { + if (seq[i].length == 0) { + seq.pop_back(); + } else { + break; + } + } + ubatch_token.resize(!has_embd ? n_ubatch : 0); + ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0); + ubatch_pos.resize(n_ubatch); + ubatch_n_seq_id.resize(n_ubatch); + ubatch_seq_id.resize(n_ubatch); + ubatch_output.resize(n_ubatch); + llama_ubatch ubatch = { + /*equal_seqs =*/ true, + /*n_tokens =*/ 0, + /*n_seq_tokens =*/ 0, + /*n_seqs =*/ 0, + /*token =*/ !has_embd ? ubatch_token.data() : nullptr, + /*embd =*/ has_embd ? ubatch_embd.data() : nullptr, + /*pos =*/ ubatch_pos.data(), + /*n_seq_id =*/ ubatch_n_seq_id.data(), + /*seq_id =*/ ubatch_seq_id.data(), + /*output =*/ ubatch_output.data(), + }; + return ubatch; +} + +void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) { + GGML_ASSERT(batch != nullptr); + GGML_ASSERT(length <= seq.length); + // Can only add sequences of equal lengths to a batch, + // otherwise it isn't clear to which sequence a token belongs + GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs); + GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs); + // NOTE: loops are separated for cache-friendliness + if (batch->token) { + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; + } + } else { + // simple split + ubatch.token = batch->token + seq.offset; + } + } else { + ubatch.token = nullptr; + } + if (batch->embd) { + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + memcpy( + ubatch.embd + (n_embd * (ubatch.n_tokens + i)), + batch->embd + (n_embd * ids[seq.offset + i]), + n_embd * sizeof(float) + ); + } + } else { + // simple split + ubatch.embd = batch->embd + (n_embd * seq.offset); + } + } else { + ubatch.embd = nullptr; + } + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; + } + } else { + // simple split + ubatch.pos = batch->pos + seq.offset; + } + if (ubatch.equal_seqs) { + ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id; + if (seq.seq_id) { + ubatch.seq_id[ubatch.n_seqs] = seq.seq_id; + } + } else { + // simple split + if (batch->n_seq_id) { + ubatch.n_seq_id = batch->n_seq_id + seq.offset; + } else { + for (size_t i = 0; i < length; ++i) { + ubatch.n_seq_id[ubatch.n_seqs + i] = 1; + } + } + if (batch->seq_id) { + ubatch.seq_id = batch->seq_id + seq.offset; + } + } + if (logits_all) { + for (size_t i = 0; i < length; ++i) { + ubatch.output[ubatch.n_tokens + i] = 1; + out_ids.push_back(ids[seq.offset + i]); + } + } else if (batch->logits) { + if (ubatch.equal_seqs) { + for (size_t i = 0; i < length; ++i) { + size_t id = ids[seq.offset + i]; + int8_t is_output = batch->logits[id]; + ubatch.output[ubatch.n_tokens + i] = is_output; + if (is_output) { out_ids.push_back(id); } + } + } else { + // simple split + ubatch.output = batch->logits + seq.offset; + for (size_t i = 0; i < length; ++i) { + if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); } + } + } + } else { + // only get last output + for (size_t i = 0; i < length; ++i) { + size_t id = ids[seq.offset + i]; + int8_t is_last = id == ids.size() - 1; + ubatch.output[ubatch.n_tokens + i] = is_last; + if (is_last) { out_ids.push_back(id); } + } + } + if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) { + ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1; + } + ubatch.n_tokens += length; + ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits + seq.offset += length; + seq.length -= length; + n_tokens -= length; + GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs); +} + +llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + ubatch.equal_seqs = false; + if (!seq.empty()) { + llama_sbatch_seq & s = seq[0]; + size_t length = s.length < n_ubatch ? s.length : n_ubatch; + GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits + add_seq_to_ubatch(ubatch, s, length); + } + return ubatch; +} + +llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + if (!seq.empty()) { + size_t length = 0; + size_t n_tokens_in_ubatch = 0; + GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits + // smallest first, because it's easier to split this way; + // starting from the end to pop in constant time. + for (size_t i = seq.size(); i-- > 0;) { + llama_sbatch_seq & s = seq[i]; + GGML_ASSERT(s.length > 0); + if (length == 0) { + length = s.length < n_ubatch ? s.length : n_ubatch; + } + add_seq_to_ubatch(ubatch, s, length); + n_tokens_in_ubatch += length; + // shared prompts can't be mixed with any of their sequences, + // so it's safer to compute them in their own ubatch + if (s.n_seq_id > 1) { break; } + // stop when there isn't enough space for another sequence + if (length + n_tokens_in_ubatch > n_ubatch) { break; } + } + } + return ubatch; +} + +llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) { + n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + if (!seq.empty()) { + llama_sbatch_seq & s = seq[seq.size() - 1]; + size_t length = s.length < n_ubatch ? s.length : n_ubatch; + GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits + add_seq_to_ubatch(ubatch, s, length); + } + return ubatch; +} + +void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) { + GGML_ASSERT(batch.n_tokens >= 0); + this->batch = &batch; + this->n_embd = n_embd; + this->logits_all = logits_all; + + n_tokens = batch.n_tokens; + ids.resize(n_tokens); + out_ids.clear(); + // TODO: reserve out_ids and seq + + for (size_t i = 0; i < n_tokens; ++i) { + ids[i] = i; + } + if (simple_split) { + seq.resize(1); + llama_sbatch_seq & s = seq[0]; + s.n_seq_id = 0; + s.seq_id = nullptr; + s.offset = 0; + s.length = n_tokens; + return; + } + std::sort(ids.begin(), ids.end(), + [&batch](size_t a, size_t b) { + int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1; + int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1; + // sort by seq_id, then by pos + if (n_seq_a == n_seq_b) { + if (batch.seq_id) { + for (int32_t i = 0; i < n_seq_a; ++i) { + llama_seq_id seq_id_a = batch.seq_id[a][i]; + llama_seq_id seq_id_b = batch.seq_id[b][i]; + // smaller seq_ids go first + if (seq_id_a != seq_id_b) { + return seq_id_a < seq_id_b; + } + } + } + // when all else is equal, sort by pos + if (batch.pos) { + return batch.pos[a] < batch.pos[b]; + } + // no pos, sort by id + return a < b; + } + // shared prompts go first + return n_seq_a > n_seq_b; + } + ); + // init seq + llama_sbatch_seq * last_seq = nullptr; + + for (size_t i = 0; i < n_tokens; ++i) { + const size_t bi = ids[i]; + const int32_t n_seqs = batch.n_seq_id[bi]; + llama_seq_id * seq_ids = batch.seq_id[bi]; + if (last_seq != nullptr) { + bool same = n_seqs == last_seq->n_seq_id; + for (int32_t j = 0; same && j < n_seqs; ++j) { + if (seq_ids[j] != last_seq->seq_id[j]) { + same = false; + } + } + if (same) { + last_seq->length += 1; + continue; + } + } + llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1}; + seq.push_back(new_seq); + last_seq = &seq.back(); + } + // keep shared prompts first at the end, then sort by length descending. + std::sort(seq.begin(), seq.end(), + [](llama_sbatch_seq & a, llama_sbatch_seq & b) { + if (a.n_seq_id == b.n_seq_id) { + return a.length > b.length; + } + return a.n_seq_id < b.n_seq_id; + } + ); +} + +llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) { + batch = in_batch; + GGML_ASSERT(batch.n_tokens > 0); + if (!batch.pos) { + pos.resize(batch.n_tokens); + for (int32_t i = 0; i < batch.n_tokens; i++) { + pos[i] = i + p0; + } + batch.pos = pos.data(); + } + if (!batch.n_seq_id) { + n_seq_id.resize(batch.n_tokens); + for (int32_t i = 0; i < batch.n_tokens; i++) { + n_seq_id[i] = seq_id_0.size(); + } + batch.n_seq_id = n_seq_id.data(); + } + if (!batch.seq_id) { + seq_id.resize(batch.n_tokens + 1); + seq_id[batch.n_tokens] = NULL; + for (int32_t i = 0; i < batch.n_tokens; i++) { + seq_id[i] = seq_id_0.data(); + } + batch.seq_id = seq_id.data(); + } + if (!batch.logits) { + logits.resize(batch.n_tokens); + logits[logits.size() - 1] = true; + batch.logits = logits.data(); + } +} + +// +// interface implementation +// + +struct llama_batch llama_batch_get_one( + llama_token * tokens, + int32_t n_tokens) { + return { + /*n_tokens =*/ n_tokens, + /*tokens =*/ tokens, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + }; +} + +struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { + llama_batch batch = { + /*n_tokens =*/ 0, + /*tokens =*/ nullptr, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + }; + + if (embd) { + batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); + } else { + batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); + } + + batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc); + batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1)); + for (int i = 0; i < n_tokens_alloc; ++i) { + batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); + } + batch.seq_id[n_tokens_alloc] = nullptr; + + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc); + + return batch; +} + +void llama_batch_free(struct llama_batch batch) { + if (batch.token) free(batch.token); + if (batch.embd) free(batch.embd); + if (batch.pos) free(batch.pos); + if (batch.n_seq_id) free(batch.n_seq_id); + if (batch.seq_id) { + for (int i = 0; batch.seq_id[i] != nullptr; ++i) { + free(batch.seq_id[i]); + } + free(batch.seq_id); + } + if (batch.logits) free(batch.logits); +} diff --git a/src/llama-batch.h b/src/llama-batch.h new file mode 100644 index 000000000..773c3808b --- /dev/null +++ b/src/llama-batch.h @@ -0,0 +1,88 @@ +#pragma once + +#include "llama.h" + +#include +#include + +// very similar to llama_batch, +// but has more metadata about sequences +struct llama_ubatch { + bool equal_seqs; + // TODO: whole_seqs for embeddings? + + uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) + uint32_t n_seq_tokens; // tokens per sequence + uint32_t n_seqs; + + llama_token * token; // [n_tokens] + float * embd; // [n_embd, n_tokens] + llama_pos * pos; // [n_tokens] + int32_t * n_seq_id; // [n_seqs] + llama_seq_id ** seq_id; // [n_seqs] + int8_t * output; // [n_tokens] +}; + +struct llama_sbatch_seq { + int32_t n_seq_id; + + llama_seq_id * seq_id; + + size_t offset; + size_t length; +}; + +// sequence-length-aware batch splitting +struct llama_sbatch { + // tokens left in this batch + size_t n_tokens; + + size_t n_embd; + + bool logits_all; // TODO: remove once lctx.logits_all is removed too + + // sorted indices into the batch + std::vector ids; + // batch indices of the output + std::vector out_ids; + std::vector seq; + + const llama_batch * batch = nullptr; + + // buffers for the ubatch + std::vector ubatch_token; + std::vector ubatch_embd; + std::vector ubatch_pos; + std::vector ubatch_n_seq_id; + std::vector ubatch_seq_id; + std::vector ubatch_output; + + llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false); + + void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length); + + // simple split, unknown number of sequences of unequal lengths + llama_ubatch split_simple(size_t n_ubatch); + + // make batches of equal-length sequences + llama_ubatch split_equal(size_t n_ubatch); + + // sequence-wise split + llama_ubatch split_seq(size_t n_ubatch); + + void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); +}; + +// temporary allocate memory for the input batch if needed +struct llama_batch_allocr { + struct llama_batch batch; + + std::array seq_id_0 = { 0 }; // default sequence id + std::vector pos; + std::vector n_seq_id; + std::vector seq_id; + std::vector logits; + + // optionally fulfill the batch returned by llama_batch_get_one + llama_batch_allocr(struct llama_batch in_batch, llama_pos p0); +}; diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp new file mode 100644 index 000000000..5c19bab24 --- /dev/null +++ b/src/llama-chat.cpp @@ -0,0 +1,578 @@ +#include "llama-chat.h" + +#include "llama.h" + +#include +#include + +#if __cplusplus >= 202000L + #define LU8(x) (const char*)(u8##x) +#else + #define LU8(x) u8##x +#endif + +// trim whitespace from the beginning and end of a string +static std::string trim(const std::string & str) { + size_t start = 0; + size_t end = str.size(); + while (start < end && isspace(str[start])) { + start += 1; + } + while (end > start && isspace(str[end - 1])) { + end -= 1; + } + return str.substr(start, end - start); +} + +static const std::map LLM_CHAT_TEMPLATES = { + { "chatml", LLM_CHAT_TEMPLATE_CHATML }, + { "llama2", LLM_CHAT_TEMPLATE_LLAMA_2 }, + { "llama2-sys", LLM_CHAT_TEMPLATE_LLAMA_2_SYS }, + { "llama2-sys-bos", LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS }, + { "llama2-sys-strip", LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP }, + { "mistral-v1", LLM_CHAT_TEMPLATE_MISTRAL_V1 }, + { "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 }, + { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN }, + { "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 }, + { "phi3", LLM_CHAT_TEMPLATE_PHI_3 }, + { "phi4", LLM_CHAT_TEMPLATE_PHI_4 }, + { "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 }, + { "zephyr", LLM_CHAT_TEMPLATE_ZEPHYR }, + { "monarch", LLM_CHAT_TEMPLATE_MONARCH }, + { "gemma", LLM_CHAT_TEMPLATE_GEMMA }, + { "orion", LLM_CHAT_TEMPLATE_ORION }, + { "openchat", LLM_CHAT_TEMPLATE_OPENCHAT }, + { "vicuna", LLM_CHAT_TEMPLATE_VICUNA }, + { "vicuna-orca", LLM_CHAT_TEMPLATE_VICUNA_ORCA }, + { "deepseek", LLM_CHAT_TEMPLATE_DEEPSEEK }, + { "deepseek2", LLM_CHAT_TEMPLATE_DEEPSEEK_2 }, + { "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 }, + { "command-r", LLM_CHAT_TEMPLATE_COMMAND_R }, + { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 }, + { "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 }, + { "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 }, + { "minicpm", LLM_CHAT_TEMPLATE_MINICPM }, + { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 }, + { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, + { "granite", LLM_CHAT_TEMPLATE_GRANITE }, + { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, + { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, +}; + +llm_chat_template llm_chat_template_from_str(const std::string & name) { + return LLM_CHAT_TEMPLATES.at(name); +} + +llm_chat_template llm_chat_detect_template(const std::string & tmpl) { + try { + return llm_chat_template_from_str(tmpl); + } catch (const std::out_of_range &) { + // ignore + } + + auto tmpl_contains = [&tmpl](const char * haystack) -> bool { + return tmpl.find(haystack) != std::string::npos; + }; + if (tmpl_contains("<|im_start|>")) { + return tmpl_contains("<|im_sep|>") + ? LLM_CHAT_TEMPLATE_PHI_4 + : LLM_CHAT_TEMPLATE_CHATML; + } else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) { + if (tmpl_contains("[SYSTEM_PROMPT]")) { + return LLM_CHAT_TEMPLATE_MISTRAL_V7; + } else if ( + // catches official 'v1' template + tmpl_contains("' [INST] ' + system_message") + // catches official 'v3' and 'v3-tekken' templates + || tmpl_contains("[AVAILABLE_TOOLS]") + ) { + // Official mistral 'v1', 'v3' and 'v3-tekken' templates + // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md + // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md + if (tmpl_contains(" [INST]")) { + return LLM_CHAT_TEMPLATE_MISTRAL_V1; + } else if (tmpl_contains("\"[INST]\"")) { + return LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN; + } + return LLM_CHAT_TEMPLATE_MISTRAL_V3; + } else { + // llama2 template and its variants + // [variant] support system message + // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 + bool support_system_message = tmpl_contains("<>"); + bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]"); + bool strip_message = tmpl_contains("content.strip()"); + if (strip_message) { + return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP; + } else if (add_bos_inside_history) { + return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS; + } else if (support_system_message) { + return LLM_CHAT_TEMPLATE_LLAMA_2_SYS; + } else { + return LLM_CHAT_TEMPLATE_LLAMA_2; + } + } + } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) { + return LLM_CHAT_TEMPLATE_PHI_3; + } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) { + return LLM_CHAT_TEMPLATE_FALCON_3; + } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) { + return LLM_CHAT_TEMPLATE_ZEPHYR; + } else if (tmpl_contains("bos_token + message['role']")) { + return LLM_CHAT_TEMPLATE_MONARCH; + } else if (tmpl_contains("")) { + return LLM_CHAT_TEMPLATE_GEMMA; + } else if (tmpl_contains("'\\n\\nAssistant: ' + eos_token")) { + // OrionStarAI/Orion-14B-Chat + return LLM_CHAT_TEMPLATE_ORION; + } else if (tmpl_contains("GPT4 Correct ")) { + // openchat/openchat-3.5-0106 + return LLM_CHAT_TEMPLATE_OPENCHAT; + } else if (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: ")) { + // eachadea/vicuna-13b-1.1 (and Orca variant) + if (tmpl_contains("SYSTEM: ")) { + return LLM_CHAT_TEMPLATE_VICUNA_ORCA; + } + return LLM_CHAT_TEMPLATE_VICUNA; + } else if (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>")) { + // deepseek-ai/deepseek-coder-33b-instruct + return LLM_CHAT_TEMPLATE_DEEPSEEK; + } else if (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>")) { + // CohereForAI/c4ai-command-r-plus + return LLM_CHAT_TEMPLATE_COMMAND_R; + } else if (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>")) { + return LLM_CHAT_TEMPLATE_LLAMA_3; + } else if (tmpl_contains("[gMASK]sop")) { + // chatglm3-6b + return LLM_CHAT_TEMPLATE_CHATGML_3; + } else if (tmpl_contains("[gMASK]")) { + return LLM_CHAT_TEMPLATE_CHATGML_4; + } else if (tmpl_contains(LU8("<用户>"))) { + // MiniCPM-3B-OpenHermes-2.5-v2-GGUF + return LLM_CHAT_TEMPLATE_MINICPM; + } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) { + return LLM_CHAT_TEMPLATE_DEEPSEEK_2; + } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) { + return LLM_CHAT_TEMPLATE_DEEPSEEK_3; + } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) { + // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb + // EXAONE-3.0-7.8B-Instruct + return LLM_CHAT_TEMPLATE_EXAONE_3; + } else if (tmpl_contains("rwkv-world")) { + return LLM_CHAT_TEMPLATE_RWKV_WORLD; + } else if (tmpl_contains("<|start_of_role|>")) { + return LLM_CHAT_TEMPLATE_GRANITE; + } else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) { + return LLM_CHAT_TEMPLATE_GIGACHAT; + } else if (tmpl_contains("<|role_start|>")) { + return LLM_CHAT_TEMPLATE_MEGREZ; + } + return LLM_CHAT_TEMPLATE_UNKNOWN; +} + +// Simple version of "llama_apply_chat_template" that only works with strings +// This function uses heuristic checks to determine commonly used template. It is not a jinja parser. +int32_t llm_chat_apply_template( + llm_chat_template tmpl, + const std::vector & chat, + std::string & dest, bool add_ass) { + // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 + std::stringstream ss; + if (tmpl == LLM_CHAT_TEMPLATE_CHATML) { + // chatml template + for (auto message : chat) { + ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n"; + } + if (add_ass) { + ss << "<|im_start|>assistant\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) { + // Official mistral 'v7' template + // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7 + for (auto message : chat) { + std::string role(message->role); + std::string content(message->content); + if (role == "system") { + ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]"; + } else if (role == "user") { + ss << "[INST] " << content << "[/INST]"; + } + else { + ss << " " << content << "
"; + } + } + } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 + || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3 + || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN) { + // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md + // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md + std::string leading_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 ? " " : ""; + std::string trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN ? "" : " "; + bool trim_assistant_message = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3; + bool is_inside_turn = false; + for (auto message : chat) { + if (!is_inside_turn) { + ss << leading_space << "[INST]" << trailing_space; + is_inside_turn = true; + } + std::string role(message->role); + std::string content(message->content); + if (role == "system") { + ss << content << "\n\n"; + } else if (role == "user") { + ss << content << leading_space << "[/INST]"; + } else { + ss << trailing_space << (trim_assistant_message ? trim(content) : content) << "
"; + is_inside_turn = false; + } + } + } else if ( + tmpl == LLM_CHAT_TEMPLATE_LLAMA_2 + || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS + || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS + || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP) { + // llama2 template and its variants + // [variant] support system message + // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 + bool support_system_message = tmpl != LLM_CHAT_TEMPLATE_LLAMA_2; + // [variant] add BOS inside history + bool add_bos_inside_history = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS; + // [variant] trim spaces from the input message + bool strip_message = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP; + // construct the prompt + bool is_inside_turn = true; // skip BOS at the beginning + ss << "[INST] "; + for (auto message : chat) { + std::string content = strip_message ? trim(message->content) : message->content; + std::string role(message->role); + if (!is_inside_turn) { + is_inside_turn = true; + ss << (add_bos_inside_history ? "[INST] " : "[INST] "); + } + if (role == "system") { + if (support_system_message) { + ss << "<>\n" << content << "\n<>\n\n"; + } else { + // if the model does not support system message, we still include it in the first message, but without <> + ss << content << "\n"; + } + } else if (role == "user") { + ss << content << " [/INST]"; + } else { + ss << content << ""; + is_inside_turn = false; + } + } + } else if (tmpl == LLM_CHAT_TEMPLATE_PHI_3) { + // Phi 3 + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>\n" << message->content << "<|end|>\n"; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_PHI_4) { + // chatml template + for (auto message : chat) { + ss << "<|im_start|>" << message->role << "<|im_sep|>" << message->content << "<|im_end|>"; + } + if (add_ass) { + ss << "<|im_start|>assistant<|im_sep|>"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_FALCON_3) { + // Falcon 3 + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>\n" << message->content << "\n"; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_ZEPHYR) { + // zephyr template + for (auto message : chat) { + ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n"; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_MONARCH) { + // mlabonne/AlphaMonarch-7B template (the is included inside history) + for (auto message : chat) { + std::string bos = (message == chat.front()) ? "" : ""; // skip BOS for first message + ss << bos << message->role << "\n" << message->content << "\n"; + } + if (add_ass) { + ss << "assistant\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_GEMMA) { + // google/gemma-7b-it + std::string system_prompt = ""; + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken + system_prompt = trim(message->content); + continue; + } + // in gemma, "assistant" is "model" + role = role == "assistant" ? "model" : message->role; + ss << "" << role << "\n"; + if (!system_prompt.empty() && role != "model") { + ss << system_prompt << "\n\n"; + system_prompt = ""; + } + ss << trim(message->content) << "\n"; + } + if (add_ass) { + ss << "model\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_ORION) { + // OrionStarAI/Orion-14B-Chat + std::string system_prompt = ""; + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + // there is no system message support, we will merge it with user prompt + system_prompt = message->content; + continue; + } else if (role == "user") { + ss << "Human: "; + if (!system_prompt.empty()) { + ss << system_prompt << "\n\n"; + system_prompt = ""; + } + ss << message->content << "\n\nAssistant: "; + } else { + ss << message->content << ""; + } + } + } else if (tmpl == LLM_CHAT_TEMPLATE_OPENCHAT) { + // openchat/openchat-3.5-0106, + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content << "<|end_of_turn|>"; + } else { + role[0] = toupper(role[0]); + ss << "GPT4 Correct " << role << ": " << message->content << "<|end_of_turn|>"; + } + } + if (add_ass) { + ss << "GPT4 Correct Assistant:"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_VICUNA || tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) { + // eachadea/vicuna-13b-1.1 (and Orca variant) + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + // Orca-Vicuna variant uses a system prefix + if (tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) { + ss << "SYSTEM: " << message->content << "\n"; + } else { + ss << message->content << "\n\n"; + } + } else if (role == "user") { + ss << "USER: " << message->content << "\n"; + } else if (role == "assistant") { + ss << "ASSISTANT: " << message->content << "\n"; + } + } + if (add_ass) { + ss << "ASSISTANT:"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK) { + // deepseek-ai/deepseek-coder-33b-instruct + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content; + } else if (role == "user") { + ss << "### Instruction:\n" << message->content << "\n"; + } else if (role == "assistant") { + ss << "### Response:\n" << message->content << "\n<|EOT|>\n"; + } + } + if (add_ass) { + ss << "### Response:\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_COMMAND_R) { + // CohereForAI/c4ai-command-r-plus + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; + } else if (role == "user") { + ss << "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; + } else if (role == "assistant") { + ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>"; + } + } + if (add_ass) { + ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA_3) { + // Llama 3 + for (auto message : chat) { + std::string role(message->role); + ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n" << trim(message->content) << "<|eot_id|>"; + } + if (add_ass) { + ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_3) { + // chatglm3-6b + ss << "[gMASK]" << "sop"; + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>" << "\n " << message->content; + } + if (add_ass) { + ss << "<|assistant|>"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_4) { + ss << "[gMASK]" << ""; + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>" << "\n" << message->content; + } + if (add_ass) { + ss << "<|assistant|>"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) { + // MiniCPM-3B-OpenHermes-2.5-v2-GGUF + for (auto message : chat) { + std::string role(message->role); + if (role == "user") { + ss << LU8("<用户>"); + ss << trim(message->content); + ss << ""; + } else { + ss << trim(message->content); + } + } + } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_2) { + // DeepSeek-V2 + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content << "\n\n"; + } else if (role == "user") { + ss << "User: " << message->content << "\n\n"; + } else if (role == "assistant") { + ss << "Assistant: " << message->content << LU8("<|end▁of▁sentence|>"); + } + } + if (add_ass) { + ss << "Assistant:"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_3) { + // DeepSeek-V3 + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content << "\n\n"; + } else if (role == "user") { + ss << LU8("<|User|>") << message->content; + } else if (role == "assistant") { + ss << LU8("<|Assistant|>") << message->content << LU8("<|end▁of▁sentence|>"); + } + } + if (add_ass) { + ss << LU8("<|Assistant|>"); + } + } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) { + // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb + // EXAONE-3.0-7.8B-Instruct + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n"; + } else if (role == "user") { + ss << "[|user|]" << trim(message->content) << "\n"; + } else if (role == "assistant") { + ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n"; + } + } + if (add_ass) { + ss << "[|assistant|]"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) { + // this template requires the model to have "\n\n" as EOT token + for (auto message : chat) { + std::string role(message->role); + if (role == "user") { + ss << "User: " << message->content << "\n\nAssistant:"; + } else { + ss << message->content << "\n\n"; + } + } + } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) { + // IBM Granite template + for (const auto & message : chat) { + std::string role(message->role); + ss << "<|start_of_role|>" << role << "<|end_of_role|>"; + if (role == "assistant_tool_call") { + ss << "<|tool_call|>"; + } + ss << message->content << "<|end_of_text|>\n"; + } + if (add_ass) { + ss << "<|start_of_role|>assistant<|end_of_role|>\n"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) { + // GigaChat template + bool has_system = !chat.empty() && std::string(chat[0]->role) == "system"; + + // Handle system message if present + if (has_system) { + ss << "" << chat[0]->content << "<|message_sep|>"; + } else { + ss << ""; + } + + // Process remaining messages + for (size_t i = has_system ? 1 : 0; i < chat.size(); i++) { + std::string role(chat[i]->role); + if (role == "user") { + ss << "user<|role_sep|>" << chat[i]->content << "<|message_sep|>" + << "available functions<|role_sep|>[]<|message_sep|>"; + } else if (role == "assistant") { + ss << "assistant<|role_sep|>" << chat[i]->content << "<|message_sep|>"; + } + } + + // Add generation prompt if needed + if (add_ass) { + ss << "assistant<|role_sep|>"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_MEGREZ) { + // Megrez template + for (auto message : chat) { + std::string role(message->role); + ss << "<|role_start|>" << role << "<|role_end|>" << message->content << "<|turn_end|>"; + } + + if (add_ass) { + ss << "<|role_start|>assistant<|role_end|>"; + } + } else { + // template not supported + return -1; + } + dest = ss.str(); + return dest.size(); +} + +// public interface + +int32_t llama_chat_builtin_templates(const char ** output, size_t len) { + auto it = LLM_CHAT_TEMPLATES.begin(); + for (size_t i = 0; i < std::min(len, LLM_CHAT_TEMPLATES.size()); i++) { + output[i] = it->first.c_str(); + std::advance(it, 1); + } + return (int32_t) LLM_CHAT_TEMPLATES.size(); +} + diff --git a/src/llama-chat.h b/src/llama-chat.h new file mode 100644 index 000000000..3a4d07ce3 --- /dev/null +++ b/src/llama-chat.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include +#include + +enum llm_chat_template { + LLM_CHAT_TEMPLATE_CHATML, + LLM_CHAT_TEMPLATE_LLAMA_2, + LLM_CHAT_TEMPLATE_LLAMA_2_SYS, + LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS, + LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP, + LLM_CHAT_TEMPLATE_MISTRAL_V1, + LLM_CHAT_TEMPLATE_MISTRAL_V3, + LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN, + LLM_CHAT_TEMPLATE_MISTRAL_V7, + LLM_CHAT_TEMPLATE_PHI_3, + LLM_CHAT_TEMPLATE_PHI_4, + LLM_CHAT_TEMPLATE_FALCON_3, + LLM_CHAT_TEMPLATE_ZEPHYR, + LLM_CHAT_TEMPLATE_MONARCH, + LLM_CHAT_TEMPLATE_GEMMA, + LLM_CHAT_TEMPLATE_ORION, + LLM_CHAT_TEMPLATE_OPENCHAT, + LLM_CHAT_TEMPLATE_VICUNA, + LLM_CHAT_TEMPLATE_VICUNA_ORCA, + LLM_CHAT_TEMPLATE_DEEPSEEK, + LLM_CHAT_TEMPLATE_DEEPSEEK_2, + LLM_CHAT_TEMPLATE_DEEPSEEK_3, + LLM_CHAT_TEMPLATE_COMMAND_R, + LLM_CHAT_TEMPLATE_LLAMA_3, + LLM_CHAT_TEMPLATE_CHATGML_3, + LLM_CHAT_TEMPLATE_CHATGML_4, + LLM_CHAT_TEMPLATE_MINICPM, + LLM_CHAT_TEMPLATE_EXAONE_3, + LLM_CHAT_TEMPLATE_RWKV_WORLD, + LLM_CHAT_TEMPLATE_GRANITE, + LLM_CHAT_TEMPLATE_GIGACHAT, + LLM_CHAT_TEMPLATE_MEGREZ, + LLM_CHAT_TEMPLATE_UNKNOWN, +}; + +struct llama_chat_message; + +llm_chat_template llm_chat_template_from_str(const std::string & name); + +llm_chat_template llm_chat_detect_template(const std::string & tmpl); + +int32_t llm_chat_apply_template( + llm_chat_template tmpl, + const std::vector & chat, + std::string & dest, bool add_ass); diff --git a/src/llama-context.cpp b/src/llama-context.cpp new file mode 100644 index 000000000..671d2a81a --- /dev/null +++ b/src/llama-context.cpp @@ -0,0 +1,1775 @@ +#include "llama-context.h" + +#include "llama-impl.h" +#include "llama-mmap.h" + +#include +#include +#include +#include + +void llama_set_k_shift(struct llama_context & lctx) { + const int64_t kv_size = lctx.kv_self.size; + + assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer)); + + int32_t * data = (int32_t *) lctx.inp_K_shift->data; + + for (int i = 0; i < kv_size; ++i) { + data[i] = lctx.kv_self.cells[i].delta; + } +} + +void llama_set_s_copy(struct llama_context & lctx) { + const int64_t kv_size = lctx.kv_self.size; + + assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); + + int32_t * data = (int32_t *) lctx.inp_s_copy->data; + + for (int i = 0; i < kv_size; ++i) { + data[i] = lctx.kv_self.cells[i].src; + } +} + +// llama input + +static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { + // TODO move to hparams if a T5 variant appears that uses a different value + const int64_t max_distance = 128; + + if (bidirectional) { + n_buckets >>= 1; + } + + const int64_t max_exact = n_buckets >> 1; + + int32_t relative_position = x - y; + int32_t relative_bucket = 0; + if (bidirectional) { + relative_bucket += (relative_position > 0) * n_buckets; + relative_position = abs(relative_position); + } else { + relative_position = -std::min(relative_position, 0); + } + int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact)); + relative_position_if_large = std::min(relative_position_if_large, n_buckets - 1); + relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large); + return relative_bucket; +} + +void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) { + // + // set input data + // + + const auto & hparams = lctx.model.hparams; + const auto & cparams = lctx.cparams; + const auto & kv_self = lctx.kv_self; + + if (ubatch.token) { + const int64_t n_tokens = ubatch.n_tokens; + + ggml_backend_tensor_set(lctx.inp_tokens, ubatch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens)); + } + + if (ubatch.embd) { + const int64_t n_embd = hparams.n_embd; + const int64_t n_tokens = ubatch.n_tokens; + + ggml_backend_tensor_set(lctx.inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd)); + } + + if (ubatch.pos && lctx.inp_pos) { + const int64_t n_tokens = ubatch.n_tokens; + auto n_pos = lctx.n_pos_per_token; + ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*n_pos*ggml_element_size(lctx.inp_pos)); + } + + if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + //GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); + + if (!lctx.inp_out_ids) { + LLAMA_LOG_WARN("%s: 'lctx.inp_out_ids' is not created\n", __func__); + } else { + const int64_t n_tokens = ubatch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer)); + int32_t * data = (int32_t *) lctx.inp_out_ids->data; + + if (lctx.n_outputs == n_tokens) { + for (int i = 0; i < n_tokens; ++i) { + data[i] = i; + } + } else if (ubatch.output) { + int32_t n_outputs = 0; + for (int i = 0; i < n_tokens; ++i) { + if (ubatch.output[i]) { + data[n_outputs++] = i; + } + } + // the graph needs to have been passed the correct number of outputs + GGML_ASSERT(lctx.n_outputs == n_outputs); + } else if (lctx.n_outputs == 1) { + // only keep last output + data[0] = n_tokens - 1; + } else { + GGML_ASSERT(lctx.n_outputs == 0); + } + } + } + + GGML_ASSERT( + // (!a || b) is a logical implication (a -> b) + // !hparams.causal_attn -> !cparams.causal_attn + (hparams.causal_attn || !cparams.causal_attn) && + "causal attention is not supported by this model" + ); + + if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) { + // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. + if (cparams.causal_attn && !lctx.is_encoding) { + const int64_t n_kv = kv_self.n; + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + + + float * data = nullptr; + float * data_swa = nullptr; + + if (lctx.inp_KQ_mask) { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); + data = (float *) lctx.inp_KQ_mask->data; + } + + if (lctx.inp_KQ_mask_swa) { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer)); + data_swa = (float *) lctx.inp_KQ_mask_swa->data; + } + + // For causal attention, use only the previous KV cells + // of the correct sequence for each token of the ubatch. + // It's assumed that if a token in the batch has multiple sequences, they are equivalent. + for (int h = 0; h < 1; ++h) { + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const llama_pos pos = ubatch.pos[s*n_seq_tokens + j]; + + for (int i = 0; i < n_kv; ++i) { + float f; + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + f = -INFINITY; + } else { + if (hparams.use_alibi) { + f = -std::abs(kv_self.cells[i].pos - pos); + } else { + f = 0.0f; + } + } + + if (data) { + data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; + } + + // may need to cut off old tokens for sliding window + if (data_swa) { + if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) { + f = -INFINITY; + } + data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; + } + } + } + } + + if (data) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } + } + + if (data_swa) { + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } + } + } + } else { + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + // when using kv cache, the mask needs to match the kv cache size + const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer)); + + float * data = (float *) lctx.inp_KQ_mask->data; + + for (int h = 0; h < 1; ++h) { + for (int s1 = 0; s1 < n_seqs; ++s1) { + const llama_seq_id seq_id = ubatch.seq_id[s1][0]; + + for (int j = 0; j < n_seq_tokens; ++j) { + const int32_t tj = s1*n_seq_tokens + j; + + for (int s0 = 0; s0 < n_seqs; ++s0) { + for (int i = 0; i < n_seq_tokens; ++i) { + const int32_t ti = s0*n_seq_tokens + i; + float f = -INFINITY; + + for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) { + if (ubatch.seq_id[s0][s] == seq_id) { + if (hparams.use_alibi) { + f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]); + } else { + f = 0.0f; + } + break; + } + } + + data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f; + } + } + + for (int i = n_tokens; i < n_stride; ++i) { + data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY; + } + } + } + } + } + } + + if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + + GGML_ASSERT(lctx.inp_mean); + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer)); + + float * data = (float *) lctx.inp_mean->data; + memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean)); + + std::vector sum(n_tokens, 0); + + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + + // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN"); + + sum[seq_id] += ubatch.n_seq_tokens; + } + + std::vector div(n_tokens, 0.0f); + for (int i = 0; i < n_tokens; ++i) { + const uint64_t s = sum[i]; + if (s > 0) { + div[i] = 1.0f/float(s); + } + } + + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + + for (int i = 0; i < n_seq_tokens; ++i) { + data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id]; + } + } + } + + if (cparams.embeddings && ( + cparams.pooling_type == LLAMA_POOLING_TYPE_CLS || + cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) { + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + + GGML_ASSERT(lctx.inp_cls); + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); + + uint32_t * data = (uint32_t *) lctx.inp_cls->data; + memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls)); + + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + + // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK"); + + for (int i = 0; i < n_seq_tokens; ++i) { + const llama_pos pos = ubatch.pos[s*n_seq_tokens + i]; + + if (pos == 0) { + data[seq_id] = s*n_seq_tokens + i; + } + } + } + } + + if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + + GGML_ASSERT(lctx.inp_cls); + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); + + uint32_t * data = (uint32_t *) lctx.inp_cls->data; + memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls)); + + std::vector last_pos(n_tokens, -1); + std::vector last_row(n_tokens, -1); + + for (int s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + + // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST"); + + for (int i = 0; i < n_seq_tokens; ++i) { + const llama_pos pos = ubatch.pos[s*n_seq_tokens + i]; + + if (pos >= last_pos[seq_id]) { + last_pos[seq_id] = pos; + last_row[seq_id] = s*n_seq_tokens + i; + } + } + } + + for (int i = 0; i < n_tokens; ++i) { + if (last_row[i] >= 0) { + data[i] = last_row[i]; + } + } + } + + if (kv_self.recurrent) { + const int64_t n_kv = kv_self.n; + + if (lctx.inp_s_mask) { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); + float * data = (float *) lctx.inp_s_mask->data; + + // clear unused states + for (int i = 0; i < n_kv; ++i) { + const uint32_t cell_id = i + kv_self.head; + llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; + + data[i] = (float) (kv_cell.src >= 0); + + // only clear once + if (kv_cell.src < 0) { + kv_cell.src = cell_id; + } + } + } + + if (lctx.inp_s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); + int32_t * data = (int32_t *) lctx.inp_s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_kv; ++i) { + const uint32_t cell_id = i + kv_self.head; + llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; + + // prevent out-of-bound sources + if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) { + kv_cell.src = cell_id; + } + + data[i] = kv_cell.src; + + // ensure copy only happens once + if (kv_cell.src != (int32_t) cell_id) { + kv_cell.src = cell_id; + } + } + } + } + + if (lctx.inp_pos_bucket) { + const int64_t n_tokens = ubatch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer)); + GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing + + int32_t * data = (int32_t *) lctx.inp_pos_bucket->data; + + if (!lctx.is_encoding) { + const int64_t n_kv = kv_self.n; + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_kv; ++i) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding); + } + } + } + } else { + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_tokens; ++i) { + data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding); + } + } + } + } + } + + if (!lctx.is_encoding && lctx.inp_embd_enc) { + assert(lctx.inp_embd_enc->type == GGML_TYPE_F32); + assert((size_t) ggml_nelements(lctx.inp_embd_enc) == lctx.embd_enc.size()); + + ggml_backend_tensor_set(lctx.inp_embd_enc, lctx.embd_enc.data(), 0, ggml_nbytes(lctx.inp_embd_enc)); + } + + if (!lctx.is_encoding && lctx.inp_KQ_mask_cross) { + const int64_t n_output_enc = lctx.embd_enc.size() / hparams.n_embd; + const int64_t n_tokens = ubatch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer)); + GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing + + float * data = (float *) lctx.inp_KQ_mask_cross->data; + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_output_enc; ++i) { + float f = -INFINITY; + for (int s = 0; s < ubatch.n_seq_id[j]; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[j][s]; + if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) { + f = 0.0f; + } + } + data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f; + } + } + + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_output_enc; ++j) { + data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY; + } + } + } + } +} + +// llama output + +size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) { + const auto & cparams = lctx.cparams; + const auto & hparams = lctx.model.hparams; + const auto & vocab = lctx.model.vocab; + + const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max); + + const auto n_batch = cparams.n_batch; + const auto n_vocab = vocab.n_tokens(); + const auto n_embd = hparams.n_embd; + + // TODO: use a per-batch flag for logits presence instead + const bool has_logits = !cparams.embeddings; + const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); + + const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0; + const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0; + + if (lctx.output_ids.empty()) { + // init, never resized afterwards + lctx.output_ids.resize(n_batch); + } + + const size_t prev_size = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output.get()) : 0; + const size_t new_size = (logits_size + embd_size) * sizeof(float); + + // alloc only when more than the current capacity is required + // TODO: also consider shrinking the buffer + if (!lctx.buf_output || prev_size < new_size) { + if (lctx.buf_output) { +#ifndef NDEBUG + // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark) + LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); +#endif + lctx.buf_output = nullptr; + lctx.logits = nullptr; + lctx.embd = nullptr; + } + + auto * buft = ggml_backend_cpu_buffer_type(); + // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory + auto * output_dev = lctx.model.dev_output(); + auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr; + if (output_dev_host_buft) { + buft = output_dev_host_buft; + } + lctx.buf_output.reset(ggml_backend_buft_alloc_buffer(buft, new_size)); + if (lctx.buf_output == nullptr) { + LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0)); + return 0; + } + } + + float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output.get()); + + lctx.logits = has_logits ? output_base : nullptr; + lctx.embd = has_embd ? output_base + logits_size : nullptr; + + lctx.output_size = n_outputs_max; + lctx.logits_size = logits_size; + lctx.embd_size = embd_size; + + // set all ids as invalid (negative) + std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1); + + ggml_backend_buffer_clear(lctx.buf_output.get(), 0); + + lctx.n_outputs = 0; + + return n_outputs_max; +} + +void llama_output_reorder(struct llama_context & ctx) { + std::vector & out_ids = ctx.sbatch.out_ids; + if (!out_ids.empty()) { + const uint32_t n_vocab = ctx.model.vocab.n_tokens(); + const uint32_t n_embd = ctx.model.hparams.n_embd; + + const int32_t n_outputs = ctx.n_outputs; + GGML_ASSERT((size_t) n_outputs == out_ids.size()); + + // TODO: is there something more efficient which also minimizes swaps? + // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) + for (int32_t i = 0; i < n_outputs - 1; ++i) { + int32_t j_min = i; + for (int32_t j = i + 1; j < n_outputs; ++j) { + if (out_ids[j] < out_ids[j_min]) { + j_min = j; + } + } + if (j_min == i) { continue; } + std::swap(out_ids[i], out_ids[j_min]); + if (ctx.logits_size > 0) { + for (uint32_t k = 0; k < n_vocab; k++) { + std::swap(ctx.logits[i*n_vocab + k], ctx.logits[j_min*n_vocab + k]); + } + } + if (ctx.embd_size > 0) { + for (uint32_t k = 0; k < n_embd; k++) { + std::swap(ctx.embd[i*n_embd + k], ctx.embd[j_min*n_embd + k]); + } + } + } + std::fill(ctx.output_ids.begin(), ctx.output_ids.end(), -1); + for (int32_t i = 0; i < n_outputs; ++i) { + ctx.output_ids[out_ids[i]] = i; + } + out_ids.clear(); + } +} + +// +// interface implementation +// + +void llama_free(struct llama_context * ctx) { + delete ctx; +} + +uint32_t llama_n_ctx(const struct llama_context * ctx) { + return ctx->cparams.n_ctx; +} + +uint32_t llama_n_batch(const struct llama_context * ctx) { + return ctx->cparams.n_batch; +} + +uint32_t llama_n_ubatch(const struct llama_context * ctx) { + return ctx->cparams.n_ubatch; +} + +uint32_t llama_n_seq_max(const struct llama_context * ctx) { + return ctx->kv_self.size; +} + +const struct llama_model * llama_get_model(const struct llama_context * ctx) { + return &ctx->model; +} + +enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) { + return ctx->cparams.pooling_type; +} + +void llama_attach_threadpool( + struct llama_context * ctx, + ggml_threadpool_t threadpool, + ggml_threadpool_t threadpool_batch) { + ctx->threadpool = threadpool; + ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool; +} + +void llama_detach_threadpool(struct llama_context * ctx) { + ctx->threadpool = nullptr; + ctx->threadpool_batch = nullptr; +} + +void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) { + ctx->cparams.n_threads = n_threads; + ctx->cparams.n_threads_batch = n_threads_batch; +} + +int32_t llama_n_threads(struct llama_context * ctx) { + return ctx->cparams.n_threads; +} + +int32_t llama_n_threads_batch(struct llama_context * ctx) { + return ctx->cparams.n_threads_batch; +} + +void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) { + ctx->abort_callback = abort_callback; + ctx->abort_callback_data = abort_callback_data; + + for (auto & backend : ctx->backends) { + auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); + auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); + if (set_abort_callback_fn) { + set_abort_callback_fn(backend.get(), ctx->abort_callback, ctx->abort_callback_data); + } + } +} + +void llama_set_embeddings(struct llama_context * ctx, bool embeddings) { + ctx->cparams.embeddings = embeddings; +} + +void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) { + ctx->cparams.causal_attn = causal_attn; +} + +void llama_synchronize(struct llama_context * ctx) { + ggml_backend_sched_synchronize(ctx->sched.get()); + + // FIXME: if multiple single tokens are evaluated without a synchronization, + // the stats will be added to the prompt evaluation stats + // this should only happen when using batch size 1 to evaluate a batch + + // add the evaluation to the stats + if (ctx->n_queued_tokens == 1) { + if (!ctx->cparams.no_perf) { + ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us; + } + ctx->n_eval++; + } else if (ctx->n_queued_tokens > 1) { + if (!ctx->cparams.no_perf) { + ctx->t_p_eval_us += ggml_time_us() - ctx->t_compute_start_us; + } + ctx->n_p_eval += ctx->n_queued_tokens; + } + + // get a more accurate load time, upon first eval + if (ctx->n_queued_tokens > 0 && !ctx->has_evaluated_once) { + ctx->t_load_us = ggml_time_us() - ctx->t_start_us; + ctx->has_evaluated_once = true; + } + + ctx->n_queued_tokens = 0; + ctx->t_compute_start_us = 0; +} + +float * llama_get_logits(struct llama_context * ctx) { + llama_synchronize(ctx); + + // reorder logits for backward compatibility + // TODO: maybe deprecate this + llama_output_reorder(*ctx); + + return ctx->logits; +} + +float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { + int32_t j = -1; + + llama_synchronize(ctx); + + try { + if (ctx->logits == nullptr) { + throw std::runtime_error("no logits"); + } + + if (i < 0) { + j = ctx->n_outputs + i; + if (j < 0) { + throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs)); + } + } else if ((size_t) i >= ctx->output_ids.size()) { + throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size())); + } else { + j = ctx->output_ids[i]; + } + + if (j < 0) { + throw std::runtime_error(format("batch.logits[%d] != true", i)); + } + if (j >= ctx->n_outputs) { + // This should not happen + throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs)); + } + + return ctx->logits + j*ctx->model.vocab.n_tokens(); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ABORT("fatal error"); +#else + return nullptr; +#endif + } +} + +float * llama_get_embeddings(struct llama_context * ctx) { + llama_synchronize(ctx); + + // reorder embeddings for backward compatibility + // TODO: maybe deprecate this + llama_output_reorder(*ctx); + + return ctx->embd; +} + +float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) { + int32_t j = -1; + + llama_synchronize(ctx); + + try { + if (ctx->embd == nullptr) { + throw std::runtime_error("no embeddings"); + } + + if (i < 0) { + j = ctx->n_outputs + i; + if (j < 0) { + throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs)); + } + } else if ((size_t) i >= ctx->output_ids.size()) { + throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size())); + } else { + j = ctx->output_ids[i]; + } + + if (j < 0) { + throw std::runtime_error(format("batch.logits[%d] != true", i)); + } + if (j >= ctx->n_outputs) { + // This should not happen + throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs)); + } + + return ctx->embd + j*ctx->model.hparams.n_embd; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ABORT("fatal error"); +#else + return nullptr; +#endif + } +} + +float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) { + llama_synchronize(ctx); + + auto it = ctx->embd_seq.find(seq_id); + if (it == ctx->embd_seq.end()) { + return nullptr; + } + + return it->second.data(); +} + +// llama state API + +// deprecated +size_t llama_get_state_size(struct llama_context * ctx) { + return llama_state_get_size(ctx); +} + +// deprecated +size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { + return llama_state_get_data(ctx, dst, -1); +} + +// deprecated +size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { + return llama_state_set_data(ctx, src, -1); +} + +// deprecated +bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); +} + +// deprecated +bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { + return llama_state_save_file(ctx, path_session, tokens, n_token_count); +} + +// TODO: replace all non-fatal assertions with returned errors or exceptions +struct llama_data_write { + virtual void write(const void * src, size_t size) = 0; + virtual void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) = 0; + virtual size_t get_size_written() = 0; + virtual ~llama_data_write() = default; + + void write_string(const std::string & str) { + uint32_t str_size = str.size(); + + write(&str_size, sizeof(str_size)); + write(str.data(), str_size); + } + + void write_model_info(const struct llama_context * ctx) { + const std::string arch_str = llm_arch_name(ctx->model.arch); + write_string(arch_str); + // TODO: add more model-specific info which should prevent loading the session file if not identical + } + + //void write_rng(const std::mt19937 & rng) { + // std::ostringstream rng_ss; + // rng_ss << rng; + + // const std::string & rng_str = rng_ss.str(); + + // write_string(rng_str); + //} + + void write_output_ids(struct llama_context * ctx) { + llama_output_reorder(*ctx); + + const uint32_t n_outputs = ctx->n_outputs; + + std::vector output_pos; + + const size_t n_batch = ctx->cparams.n_batch; + const auto & output_ids = ctx->output_ids; + + GGML_ASSERT(n_outputs <= ctx->output_size); + + output_pos.resize(n_outputs); + + // build a more compact representation of the output ids + for (size_t i = 0; i < n_batch; ++i) { + // map an output id to a position in the batch + int32_t pos = output_ids[i]; + if (pos >= 0) { + GGML_ASSERT((uint32_t) pos < n_outputs); + output_pos[pos] = i; + } + } + + write(&n_outputs, sizeof(n_outputs)); + + if (n_outputs) { + write(output_pos.data(), n_outputs * sizeof(int32_t)); + } + } + + void write_logits(const struct llama_context * ctx) { + const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.vocab.n_tokens()); + + write(&logits_size, sizeof(logits_size)); + + if (logits_size) { + write(ctx->logits, logits_size * sizeof(float)); + } + } + + void write_embeddings(const struct llama_context * ctx) { + const uint64_t embeddings_size = std::min((uint64_t) ctx->embd_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_embd); + + write(&embeddings_size, sizeof(embeddings_size)); + + if (embeddings_size) { + write(ctx->embd, embeddings_size * sizeof(float)); + } + } + + void write_kv_cache_meta(const llama_kv_cache & kv_self, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) { + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + const auto & cell = kv_self.cells[i]; + const llama_pos pos = cell.pos; + const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0; + + write(&pos, sizeof(pos)); + write(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id) { + for (auto seq_id : cell.seq_id) { + write(&seq_id, sizeof(seq_id)); + } + } + } + } + } + + void write_kv_cache_data(const struct llama_context * ctx, const std::vector> & cell_ranges) { + const struct llama_kv_cache & kv_self = ctx->kv_self; + const struct llama_hparams & hparams = ctx->model.hparams; + + const uint32_t v_trans = kv_self.v_trans ? 1 : 0; + const uint32_t n_layer = hparams.n_layer; + + write(&v_trans, sizeof(v_trans)); + write(&n_layer, sizeof(n_layer)); + + std::vector tmp_buf; + + // Iterate and write all the keys first, each row is a cell + // Get whole range at a time + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Write key type + const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; + write(&k_type_i, sizeof(k_type_i)); + + // Write row size of key + const uint64_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); + write(&k_size_row, sizeof(k_size_row)); + + // Read each range of cells of k_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * k_size_row; + write_tensor_data(kv_self.k_l[il], range.first * k_size_row, buf_size); + } + } + + if (!kv_self.v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + write(&v_type_i, sizeof(v_type_i)); + + // Write row size of value + const uint64_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); + write(&v_size_row, sizeof(v_size_row)); + + // Read each range of cells of v_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * v_size_row; + write_tensor_data(kv_self.v_l[il], range.first * v_size_row, buf_size); + } + } + } else { + // When v is transposed, we also need the element size and get the element ranges from each row + const uint32_t kv_size = kv_self.size; + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const uint32_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + write(&v_size_el, sizeof(v_size_el)); + + // Write GQA embedding size + write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + const size_t buf_size = range_size * v_size_el; + write_tensor_data(kv_self.v_l[il], src_offset, buf_size); + } + } + } + } + } + + void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) { + const struct llama_kv_cache & kv_self = ctx->kv_self; + std::vector> cell_ranges; // ranges, from inclusive, to exclusive + uint32_t cell_count = 0; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = kv_self.size; + for (uint32_t i = 0; i < kv_self.size; ++i) { + const auto & cell = kv_self.cells[i]; + if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { + ++cell_count; + if (cell_range_begin == kv_self.size) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != kv_self.size) { + cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = kv_self.size; + } + } + } + if (cell_range_begin != kv_self.size) { + cell_ranges.emplace_back(cell_range_begin, kv_self.size); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + + write(&cell_count, sizeof(cell_count)); + + write_kv_cache_meta(kv_self, cell_ranges, seq_id); + write_kv_cache_data(ctx, cell_ranges); + } +}; + +struct llama_data_read { + virtual const uint8_t * read(size_t size) = 0; + virtual void read_to(void * dst, size_t size) = 0; + virtual size_t get_size_read() = 0; + virtual ~llama_data_read() = default; + + void read_string(std::string & str) { + uint32_t str_size; + read_to(&str_size, sizeof(str_size)); + + str.assign((const char *) read(str_size), str_size); + } + + // validate model information + void read_model_info(const struct llama_context * ctx) { + const std::string cur_arch_str = llm_arch_name(ctx->model.arch); + + std::string arch_str; + read_string(arch_str); + if (cur_arch_str != arch_str) { + throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str())); + } + // TODO: add more info which needs to be identical but which is not verified otherwise + } + + //void read_rng(std::mt19937 & rng) { + // std::string rng_str; + // read_string(rng_str); + + // std::istringstream rng_ss(rng_str); + // rng_ss >> rng; + + // if (rng_ss.fail()) { + // throw std::runtime_error("failed to load RNG state"); + // } + //} + + void read_output_ids(struct llama_context * ctx) { + std::vector output_pos; + + uint32_t n_outputs; + read_to(&n_outputs, sizeof(n_outputs)); + + if (n_outputs > llama_output_reserve(*ctx, n_outputs)) { + throw std::runtime_error("could not reserve outputs"); + } + + if (n_outputs) { + output_pos.resize(n_outputs); + read_to(output_pos.data(), n_outputs * sizeof(int32_t)); + + for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) { + int32_t id = output_pos[i]; + if ((uint32_t) id >= ctx->cparams.n_batch) { + throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->cparams.n_batch)); + } + ctx->output_ids[id] = i; + } + + ctx->n_outputs = n_outputs; + } + } + + void read_logits(struct llama_context * ctx) { + uint64_t logits_size; + read_to(&logits_size, sizeof(logits_size)); + + if (ctx->logits_size < logits_size) { + throw std::runtime_error("logits buffer too small"); + } + + if (logits_size) { + read_to(ctx->logits, logits_size * sizeof(float)); + } + } + + void read_embeddings(struct llama_context * ctx) { + uint64_t embeddings_size; + read_to(&embeddings_size, sizeof(embeddings_size)); + + if (ctx->embd_size < embeddings_size) { + throw std::runtime_error("embeddings buffer too small"); + } + + if (embeddings_size) { + read_to(ctx->embd, embeddings_size * sizeof(float)); + } + } + + bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) { + struct llama_kv_cache & kv_self = ctx->kv_self; + + if (dest_seq_id != -1) { + // single sequence + + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + + llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false); + batch.n_tokens = cell_count; + batch.n_seq_tokens = cell_count; + batch.n_seqs = 1; + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + read_to(&pos, sizeof(pos)); + read_to(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id != 0) { + LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + return false; + } + + batch.pos[i] = pos; + } + batch.n_seq_id[0] = 1; + batch.seq_id[0] = &dest_seq_id; + if (!llama_kv_cache_find_slot(kv_self, batch)) { + LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + return false; + } + + // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(kv_self.head + cell_count <= kv_self.size); + GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]); + GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]); + GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id)); + GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id)); + } else { + // whole KV cache restore + + if (cell_count > kv_self.size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); + return false; + } + + llama_kv_cache_clear(kv_self); + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_kv_cell & cell = kv_self.cells[i]; + + llama_pos pos; + uint32_t n_seq_id; + + read_to(&pos, sizeof(pos)); + read_to(&n_seq_id, sizeof(n_seq_id)); + + cell.pos = pos; + + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + read_to(&seq_id, sizeof(seq_id)); + + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); + return false; + } + + cell.seq_id.insert(seq_id); + + if (kv_self.recurrent) { + int32_t & tail = kv_self.cells[seq_id].tail; + if (tail != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); + return false; + } + tail = i; + } + } + } + + kv_self.head = 0; + kv_self.used = cell_count; + } + + if (kv_self.recurrent) { + for (uint32_t i = 0; i < cell_count; ++i) { + uint32_t cell_id = kv_self.head + i; + // make sure the recurrent states will keep their restored state + kv_self.cells[cell_id].src = cell_id; + } + } + + return true; + } + + bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) { + const struct llama_hparams & hparams = ctx->model.hparams; + struct llama_kv_cache & kv_self = ctx->kv_self; + uint32_t v_trans; + uint32_t n_layer; + read_to(&v_trans, sizeof(v_trans)); + read_to(&n_layer, sizeof(n_layer)); + + if (n_layer != hparams.n_layer) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); + return false; + } + if (cell_count > kv_self.size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size); + return false; + } + if (kv_self.v_trans != (bool) v_trans) { + LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); + return false; + } + + // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Read type of key + int32_t k_type_i_ref; + read_to(&k_type_i_ref, sizeof(k_type_i_ref)); + const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; + if (k_type_i != k_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); + return false; + } + + // Read row size of key + uint64_t k_size_row_ref; + read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); + if (k_size_row != k_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row); + } + } + + if (!kv_self.v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read row size of value + uint64_t v_size_row_ref; + read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the values for the whole cell range + ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row); + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read element size of value + uint32_t v_size_el_ref; + read_to(&v_size_el_ref, sizeof(v_size_el_ref)); + const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); + if (v_size_el != v_size_el_ref) { + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); + return false; + } + + // Read GQA embedding size + uint32_t n_embd_v_gqa_ref; + read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + if (n_embd_v_gqa != n_embd_v_gqa_ref) { + LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); + return false; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el; + ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + } + } + } + } + return true; + } + + void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) { + uint32_t cell_count; + read_to(&cell_count, sizeof(cell_count)); + + bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count); + + if (!res) { + if (seq_id == -1) { + llama_kv_cache_clear(ctx); + } else { + llama_kv_cache_seq_rm(ctx, seq_id, -1, -1); + } + throw std::runtime_error("failed to restore kv cache"); + } + } +}; + +struct llama_data_write_dummy : llama_data_write { + size_t size_written = 0; + + llama_data_write_dummy() {} + + void write(const void * /* src */, size_t size) override { + size_written += size; + } + + void write_tensor_data(const struct ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override { + size_written += size; + } + + size_t get_size_written() override { + return size_written; + } +}; + +struct llama_data_write_buffer : llama_data_write { + uint8_t * ptr; + size_t buf_size = 0; + size_t size_written = 0; + + llama_data_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {} + + void write(const void * src, size_t size) override { + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + memcpy(ptr, src, size); + ptr += size; + size_written += size; + buf_size -= size; + } + + void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override { + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + ggml_backend_tensor_get(tensor, ptr, offset, size); + ptr += size; + size_written += size; + buf_size -= size; + } + + size_t get_size_written() override { + return size_written; + } +}; + +struct llama_data_read_buffer : llama_data_read { + const uint8_t * ptr; + size_t buf_size = 0; + size_t size_read = 0; + + llama_data_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} + + const uint8_t * read(size_t size) override { + const uint8_t * base_ptr = ptr; + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + ptr += size; + size_read += size; + buf_size -= size; + return base_ptr; + } + + void read_to(void * dst, size_t size) override { + memcpy(dst, read(size), size); + } + + size_t get_size_read() override { + return size_read; + } +}; + +struct llama_data_write_file : llama_data_write { + llama_file * file; + size_t size_written = 0; + std::vector temp_buffer; + + llama_data_write_file(llama_file * f) : file(f) {} + + void write(const void * src, size_t size) override { + file->write_raw(src, size); + size_written += size; + } + + void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override { + temp_buffer.resize(size); + ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size); + write(temp_buffer.data(), temp_buffer.size()); + } + + size_t get_size_written() override { + return size_written; + } +}; + +struct llama_data_read_file : llama_data_read { + llama_file * file; + size_t size_read = 0; + std::vector temp_buffer; + + llama_data_read_file(llama_file * f) : file(f) {} + + void read_to(void * dst, size_t size) override { + file->read_raw(dst, size); + size_read += size; + } + + const uint8_t * read(size_t size) override { + temp_buffer.resize(size); + read_to(temp_buffer.data(), size); + return temp_buffer.data(); + } + + size_t get_size_read() override { + return size_read; + } +}; + +/** copy state data into either a buffer or file depending on the passed in context + * + * file context: + * llama_file file("/path", "wb"); + * llama_data_write_file data_ctx(&file); + * llama_state_get_data_internal(ctx, data_ctx); + * + * buffer context: + * std::vector buf(max_size, 0); + * llama_data_write_buffer data_ctx(buf.data(), max_size); + * llama_state_get_data_internal(ctx, data_ctx); + * +*/ +static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx) { + llama_synchronize(ctx); + + data_ctx.write_model_info(ctx); + + // copy outputs + data_ctx.write_output_ids(ctx); + data_ctx.write_logits(ctx); + data_ctx.write_embeddings(ctx); + + data_ctx.write_kv_cache(ctx); + + return data_ctx.get_size_written(); +} + +size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) { + llama_data_write_buffer data_ctx(dst, size); + try { + return llama_state_get_data_internal(ctx, data_ctx); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); + return 0; + } +} + +// Returns the *actual* size of the state. +// Intended to be used when saving to state to a buffer. +size_t llama_state_get_size(struct llama_context * ctx) { + llama_data_write_dummy data_ctx; + try { + return llama_state_get_data_internal(ctx, data_ctx); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); + return 0; + } +} + +static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx) { + llama_synchronize(ctx); + + data_ctx.read_model_info(ctx); + + // set outputs + data_ctx.read_output_ids(ctx); + data_ctx.read_logits(ctx); + data_ctx.read_embeddings(ctx); + + data_ctx.read_kv_cache(ctx); + + return data_ctx.get_size_read(); +} + +// Sets the state reading from the specified source address +size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) { + llama_data_read_buffer data_ctx(src, size); + try { + return llama_state_set_data_internal(ctx, data_ctx); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); + return 0; + } +} + +static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + llama_file file(path_session, "rb"); + + // sanity checks + { + const uint32_t magic = file.read_u32(); + const uint32_t version = file.read_u32(); + + if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) { + LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version); + return false; + } + } + + // load the prompt + { + const uint32_t n_token_count = file.read_u32(); + + if (n_token_count > n_token_capacity) { + LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); + return false; + } + + file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); + *n_token_count_out = n_token_count; + } + + // restore the context state + { + const size_t n_state_size_cur = file.size() - file.tell(); + + llama_data_read_file data_ctx(&file); + const size_t n_read = llama_state_set_data_internal(ctx, data_ctx); + + if (n_read != n_state_size_cur) { + LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read); + return false; + } + } + return true; +} + +bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + try { + return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what()); + return false; + } +} + +static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { + llama_file file(path_session, "wb"); + + file.write_u32(LLAMA_SESSION_MAGIC); + file.write_u32(LLAMA_SESSION_VERSION); + + // save the prompt + file.write_u32((uint32_t) n_token_count); + file.write_raw(tokens, sizeof(llama_token) * n_token_count); + + // save the context state using stream saving + llama_data_write_file data_ctx(&file); + llama_state_get_data_internal(ctx, data_ctx); + + return true; +} + +bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { + try { + return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what()); + return false; + } +} + +static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) { + llama_synchronize(ctx); + + data_ctx.write_kv_cache(ctx, seq_id); + + return data_ctx.get_size_written(); +} + +size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) { + llama_data_write_dummy data_ctx; + return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); +} + +size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) { + llama_data_write_buffer data_ctx(dst, size); + try { + return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what()); + return 0; + } +} + +static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) { + llama_synchronize(ctx); + + data_ctx.read_kv_cache(ctx, dest_seq_id); + + return data_ctx.get_size_read(); +} + +size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) { + llama_data_read_buffer data_ctx(src, size); + try { + return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what()); + return 0; + } +} + +static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { + llama_file file(filepath, "wb"); + + file.write_u32(LLAMA_STATE_SEQ_MAGIC); + file.write_u32(LLAMA_STATE_SEQ_VERSION); + + // save the prompt + file.write_u32((uint32_t) n_token_count); + file.write_raw(tokens, sizeof(llama_token) * n_token_count); + + // save the context state using stream saving + llama_data_write_file data_ctx(&file); + llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); + + const size_t res = file.tell(); + GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written()); + return res; +} + +static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + llama_file file(filepath, "rb"); + + // version checks + { + const uint32_t magic = file.read_u32(); + const uint32_t version = file.read_u32(); + + if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) { + LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version); + return 0; + } + } + + // load the prompt + { + const uint32_t n_token_count = file.read_u32(); + + if (n_token_count > n_token_capacity) { + LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); + return 0; + } + + file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); + *n_token_count_out = n_token_count; + } + + // restore the context state + { + const size_t state_size = file.size() - file.tell(); + llama_data_read_file data_ctx(&file); + const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id); + if (!nread) { + LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); + return 0; + } + GGML_ASSERT(nread <= state_size); + GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell()); + } + + return file.tell(); +} + +size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { + try { + return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what()); + return 0; + } +} + +size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + try { + return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what()); + return 0; + } +} + +const std::vector> & llama_internal_get_tensor_map( + struct llama_context * ctx +) { + return ctx->model.tensors_by_name; +} diff --git a/src/llama-context.h b/src/llama-context.h new file mode 100644 index 000000000..a9268b292 --- /dev/null +++ b/src/llama-context.h @@ -0,0 +1,128 @@ +#pragma once + +#include "llama.h" +#include "llama-batch.h" +#include "llama-cparams.h" +#include "llama-model.h" +#include "llama-kv-cache.h" +#include "llama-adapter.h" + +#include "ggml-cpp.h" + +#include +#include +#include +#include + +struct llama_context { + llama_context(const llama_model & model) + : model(model) + , t_start_us(model.t_start_us) + , t_load_us(model.t_load_us) {} + + const struct llama_model & model; + + struct llama_cparams cparams; + struct llama_sbatch sbatch; // TODO: revisit if needed + struct llama_kv_cache kv_self; + struct llama_adapter_cvec cvec; + + std::unordered_map lora; + + std::vector backends; + std::vector> set_n_threads_fns; + + ggml_backend_t backend_cpu = nullptr; + + ggml_threadpool_t threadpool = nullptr; + ggml_threadpool_t threadpool_batch = nullptr; + + bool has_evaluated_once = false; + + mutable int64_t t_start_us; + mutable int64_t t_load_us; + mutable int64_t t_p_eval_us = 0; + mutable int64_t t_eval_us = 0; + + mutable int64_t t_compute_start_us = 0; + mutable int64_t n_queued_tokens = 0; + + mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) + mutable int32_t n_eval = 0; // number of eval calls + + // host buffer for the model output (logits and embeddings) + ggml_backend_buffer_ptr buf_output; + + // decode output (2-dimensional array: [n_outputs][n_vocab]) + size_t logits_size = 0; // capacity (of floats) for logits + float * logits = nullptr; + + std::vector output_ids; // map batch token positions to ids of the logits and embd buffers + size_t output_size = 0; // capacity (of tokens positions) for the output buffers + int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch + + bool logits_all = false; + + // embeddings output (2-dimensional array: [n_outputs][n_embd]) + // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE + size_t embd_size = 0; // capacity (of floats) for embeddings + float * embd = nullptr; + + // sequence embeddings output (map of [n_embd] vectors) + // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE + std::map> embd_seq; + + // whether we are computing encoder output or decoder output + bool is_encoding = false; + + // TODO: find a better way to accommodate mutli-dimension position encoding methods + // number of position id each token get, 1 for each token in most cases. + // when using m-rope, it will be 3 position ids per token to representing 3 dimension coordinate. + int n_pos_per_token = 1; + + // output of the encoder part of the encoder-decoder models + std::vector embd_enc; + std::vector> seq_ids_enc; + + // memory buffers used to evaluate the model + std::vector buf_compute_meta; + ggml_backend_sched_ptr sched; + + ggml_abort_callback abort_callback = nullptr; + void * abort_callback_data = nullptr; + + // input tensors + struct ggml_tensor * inp_tokens; // I32 [n_batch] + struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] + struct ggml_tensor * inp_pos; // I32 [n_batch] + struct ggml_tensor * inp_out_ids; // I32 [n_outputs] + struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch] + struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch] + struct ggml_tensor * inp_K_shift; // I32 [kv_size] + struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] + struct ggml_tensor * inp_cls; // I32 [n_batch] + struct ggml_tensor * inp_s_copy; // I32 [kv_size] + struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] + struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] + struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] + struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] + struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] +}; + +// TODO: make these methods of llama_context +void llama_set_k_shift(struct llama_context & lctx); + +void llama_set_s_copy(struct llama_context & lctx); + +void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch); + +// Make sure enough space is available for outputs. +// Returns max number of outputs for which space was reserved. +size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs); + +// make the outputs have the same order they had in the user-provided batch +void llama_output_reorder(struct llama_context & ctx); + +// For internal test use +// TODO: remove +const std::vector> & llama_internal_get_tensor_map(struct llama_context * ctx); diff --git a/src/llama-cparams.cpp b/src/llama-cparams.cpp new file mode 100644 index 000000000..28369be36 --- /dev/null +++ b/src/llama-cparams.cpp @@ -0,0 +1 @@ +#include "llama-cparams.h" diff --git a/src/llama-cparams.h b/src/llama-cparams.h new file mode 100644 index 000000000..252012f3d --- /dev/null +++ b/src/llama-cparams.h @@ -0,0 +1,37 @@ +#pragma once + +#include "llama.h" + +#include + +struct llama_cparams { + uint32_t n_ctx; // context size used during inference + uint32_t n_batch; + uint32_t n_ubatch; + uint32_t n_seq_max; + int n_threads; // number of threads to use for generation + int n_threads_batch; // number of threads to use for batch processing + + float rope_freq_base; + float rope_freq_scale; + + uint32_t n_ctx_orig_yarn; + // These hyperparameters are not exposed in GGUF, because all + // existing YaRN models use the same values for them. + float yarn_ext_factor; + float yarn_attn_factor; + float yarn_beta_fast; + float yarn_beta_slow; + float defrag_thold; + + bool embeddings; + bool causal_attn; + bool offload_kqv; + bool flash_attn; + bool no_perf; + + enum llama_pooling_type pooling_type; + + ggml_backend_sched_eval_callback cb_eval; + void * cb_eval_user_data; +}; diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 74e9f64b3..6be5cbe0e 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1,5 +1,6 @@ #include "llama-grammar.h" +#include "llama-impl.h" #include "llama-vocab.h" #include "llama-sampling.h" @@ -559,7 +560,7 @@ bool llama_grammar_parser::parse(const char * src) { } } } catch (const std::exception & err) { - fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); + fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src); rules.clear(); return false; } @@ -822,15 +823,11 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) return grammar->stacks; } -void llama_grammar_accept( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - const uint32_t chr, - llama_grammar_stacks & stacks_new) { - stacks_new.clear(); - stacks_new.reserve(stacks.size()); +void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) { + llama_grammar_stacks stacks_new; + stacks_new.reserve(grammar->stacks.size()); - for (const auto & stack : stacks) { + for (const auto & stack : grammar->stacks) { if (stack.empty()) { continue; } @@ -844,9 +841,11 @@ void llama_grammar_accept( if (!llama_grammar_is_end_of_sequence(pos)) { new_stack.push_back(pos); } - llama_grammar_advance_stack(rules, new_stack, stacks_new); + llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new); } } + + grammar->stacks = std::move(stacks_new); } llama_grammar_candidates llama_grammar_reject_candidates_for_stack( @@ -961,10 +960,28 @@ struct llama_grammar * llama_grammar_init_impl( // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; + return new llama_grammar { + vocab, + std::move(vec_rules), + std::move(stacks), + /* .partial_utf8 = */ {}, + /* .lazy =*/ false, + /* .awaiting_trigger = */ false, + /* .trigger_buffer = */ "", + /* .trigger_tokens = */ {}, + /* .trigger_words = */ {}, + }; } -struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { +struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens) { llama_grammar_parser parser; // if there is a grammar, parse it @@ -1036,10 +1053,31 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, } } while (true); + std::vector vec_trigger_tokens; + std::vector vec_trigger_words; + for (size_t i = 0; i < num_trigger_tokens; i++) { + GGML_ASSERT(trigger_tokens != nullptr); + vec_trigger_tokens.push_back(trigger_tokens[i]); + } + for (size_t i = 0; i < num_trigger_words; i++) { + GGML_ASSERT(trigger_words != nullptr); + vec_trigger_words.push_back(trigger_words[i]); + } + // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; + return new llama_grammar { + vocab, + std::move(vec_rules), + std::move(stacks), + /* .partial_utf8 = */ {}, + /* .lazy = */ lazy, + /* .awaiting_trigger = */ lazy, + /* .trigger_buffer = */ "", + std::move(vec_trigger_tokens), + std::move(vec_trigger_words), + }; } void llama_grammar_free_impl(struct llama_grammar * grammar) { @@ -1051,7 +1089,17 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) { } struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) { - llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, }; + llama_grammar * result = new llama_grammar { + grammar.vocab, + grammar.rules, + grammar.stacks, + grammar.partial_utf8, + grammar.lazy, + grammar.awaiting_trigger, + grammar.trigger_buffer, + grammar.trigger_tokens, + grammar.trigger_words, + }; // redirect elements in stacks to point to new rules for (size_t is = 0; is < result->stacks.size(); is++) { @@ -1059,7 +1107,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) { for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) { if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) { - result->stacks[is][ie] = &result->rules[ir0][ir1]; + result->stacks[is][ie] = &result->rules[ir0][ir1]; } } } @@ -1072,6 +1120,10 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) { GGML_ASSERT(grammar.vocab != nullptr); + if (grammar.awaiting_trigger) { + return; + } + bool allow_eog = false; for (const auto & stack : grammar.stacks) { if (stack.empty()) { @@ -1088,9 +1140,9 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ for (size_t i = 0; i < cur_p->size; ++i) { const llama_token id = cur_p->data[i].id; - const std::string & piece = grammar.vocab->cache_token_to_piece.at(id); + const std::string & piece = grammar.vocab->token_to_piece(id); - if (llama_token_is_eog_impl(*grammar.vocab, id)) { + if (grammar.vocab->is_eog(id)) { if (!allow_eog) { cur_p->data[i].logit = -INFINITY; } @@ -1111,7 +1163,35 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { GGML_ASSERT(grammar.vocab != nullptr); - if (llama_token_is_eog_impl(*grammar.vocab, token)) { + const auto & piece = grammar.vocab->token_to_piece(token); + + if (grammar.awaiting_trigger) { + if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) { + grammar.awaiting_trigger = false; + grammar.trigger_buffer.clear(); + llama_grammar_accept_str(grammar, piece); + LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str()); + return; + } else { + // TODO: consider a smarter incremental substring search algorithm (store last position to search from). + grammar.trigger_buffer += piece; + for (const auto & word : grammar.trigger_words) { + auto pos = grammar.trigger_buffer.find(word); + if (pos != std::string::npos) { + grammar.awaiting_trigger = false; + auto constrained_str = grammar.trigger_buffer.substr(pos); + grammar.trigger_buffer.clear(); + llama_grammar_accept_str(grammar, constrained_str); + LLAMA_LOG_DEBUG("Grammar triggered on word `%s`", word.c_str()); + return; + } + } + LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`) (buffer: `%s`)\n", token, piece.c_str(), grammar.trigger_buffer.c_str()); + return; + } + } + + if (grammar.vocab->is_eog(token)) { for (const auto & stack : grammar.stacks) { if (stack.empty()) { return; @@ -1120,17 +1200,16 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token GGML_ABORT("fatal error"); } - const std::string & piece = grammar.vocab->cache_token_to_piece.at(token); + llama_grammar_accept_str(grammar, piece); +} +void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) { // Note terminating 0 in decoded string const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto & code_points = decoded.first; - llama_grammar_stacks stacks_new; - for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new); - grammar.stacks = std::move(stacks_new); + llama_grammar_accept(&grammar, *it); } grammar.partial_utf8 = decoded.second; diff --git a/src/llama-grammar.h b/src/llama-grammar.h index f529ce351..252d54d4c 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -1,8 +1,10 @@ #pragma once -#include "llama-impl.h" +#include "llama.h" #include +#include +#include struct llama_vocab; @@ -58,6 +60,7 @@ using llama_grammar_rules = std::vector; using llama_grammar_stacks = std::vector; using llama_grammar_candidates = std::vector; +// TODO: remove, needed for tests atm const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); @@ -65,11 +68,7 @@ const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar // be positioned at a character range (see `llama_grammar_advance_stack`), and // produces the N possible stacks if the given char is accepted at those // positions -void llama_grammar_accept( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - uint32_t chr, - llama_grammar_stacks & stacks_new); +void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr); std::vector llama_grammar_reject_candidates_for_stack( const llama_grammar_rules & rules, @@ -115,6 +114,15 @@ struct llama_grammar { // buffer for partially generated UTF-8 sequence from accepted tokens llama_partial_utf8 partial_utf8; + + // lazy grammars wait for trigger words or tokens before constraining the sampling. + // we still ahve trigger_tokens for non-lazy grammars to force printing of special trigger tokens. + // (useful e.g. for tool_choice=required) + bool lazy = false; + bool awaiting_trigger = false; // Initialized to true for lazy grammars only + std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found. + std::vector trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special). + std::vector trigger_words; }; // @@ -128,7 +136,15 @@ struct llama_grammar * llama_grammar_init_impl( size_t n_rules, size_t start_rule_index); -struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root); +struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens); void llama_grammar_free_impl(struct llama_grammar * grammar); @@ -142,3 +158,7 @@ void llama_grammar_apply_impl( void llama_grammar_accept_impl( struct llama_grammar & grammar, llama_token token); + +void llama_grammar_accept_str( + struct llama_grammar & grammar, + const std::string & piece); diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp new file mode 100644 index 000000000..ea87b2953 --- /dev/null +++ b/src/llama-hparams.cpp @@ -0,0 +1,71 @@ +#include "llama-hparams.h" + +#include "ggml.h" + +uint32_t llama_hparams::n_head(uint32_t il) const { + if (il < n_layer) { + return n_head_arr[il]; + } + + GGML_ABORT("fatal error"); +} + +uint32_t llama_hparams::n_head_kv(uint32_t il) const { + if (il < n_layer) { + return n_head_kv_arr[il]; + } + + GGML_ABORT("fatal error"); +} + +uint32_t llama_hparams::n_ff(uint32_t il) const { + if (il < n_layer) { + return n_ff_arr[il]; + } + + GGML_ABORT("fatal error"); +} + +uint32_t llama_hparams::n_gqa(uint32_t il) const { + const uint32_t n_head = this->n_head(il); + const uint32_t n_head_kv = this->n_head_kv(il); + + if (n_head_kv == 0) { + return 0; + } + + return n_head/n_head_kv; +} + +uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const { + const uint32_t n_head_kv = this->n_head_kv(il); + + return n_embd_head_k * n_head_kv; +} + +uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { + const uint32_t n_head_kv = this->n_head_kv(il); + + return n_embd_head_v * n_head_kv; +} + +uint32_t llama_hparams::n_embd_k_s() const { + if (wkv_head_size != 0) { + // for RWKV models + return token_shift_count * n_embd; + } + + // TODO: maybe support other convolution strides than 1 + // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed + return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; +} + +uint32_t llama_hparams::n_embd_v_s() const { + if (wkv_head_size != 0) { + // corresponds to RWKV's wkv_states size + return n_embd * wkv_head_size; + } + + // corresponds to Mamba's ssm_states size + return ssm_d_state * ssm_d_inner; +} diff --git a/src/llama-hparams.h b/src/llama-hparams.h new file mode 100644 index 000000000..1fe454103 --- /dev/null +++ b/src/llama-hparams.h @@ -0,0 +1,139 @@ +#pragma once + +#include "llama.h" + +#include + +// bump if necessary +#define LLAMA_MAX_LAYERS 512 +#define LLAMA_MAX_EXPERTS 256 // DeepSeekV3 + +enum llama_expert_gating_func_type { + LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1, + LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2, +}; + +struct llama_hparams_posnet { + uint32_t n_embd; + uint32_t n_layer; +}; + +struct llama_hparams_convnext { + uint32_t n_embd; + uint32_t n_layer; +}; + +struct llama_hparams { + bool vocab_only; + bool rope_finetuned; + bool use_par_res; + bool swin_norm; + + uint32_t n_ctx_train; // context size the model was trained on + uint32_t n_embd; + uint32_t n_embd_features = 0; + uint32_t n_layer; + uint32_t n_rot; + uint32_t n_swa = 0; // sliding window attention (SWA) + uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads + uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head + uint32_t n_expert = 0; + uint32_t n_expert_used = 0; + uint32_t n_rel_attn_bkts = 0; + + // for WavTokenizer + struct llama_hparams_posnet posnet; + struct llama_hparams_convnext convnext; + + std::array n_head_arr; + std::array n_head_kv_arr; + std::array n_ff_arr; + + uint32_t n_layer_dense_lead = 0; + uint32_t n_lora_q = 0; + uint32_t n_lora_kv = 0; + uint32_t n_ff_exp = 0; + uint32_t n_ff_shexp = 0; + uint32_t n_expert_shared = 0; + uint32_t n_norm_groups = 0; + + float expert_weights_scale = 0.0; + bool expert_weights_norm = false; + uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE; + + float f_norm_eps; + float f_norm_rms_eps; + float f_norm_group_eps; + + float f_attn_logit_softcapping = 50.0f; + float f_final_logit_softcapping = 30.0f; + + // for RWKV + uint32_t rescale_every_n_layers = 0; + uint32_t time_mix_extra_dim = 0; + uint32_t time_decay_extra_dim = 0; + uint32_t wkv_head_size = 0; + uint32_t token_shift_count = 2; + + float rope_attn_factor = 1.0f; + float rope_freq_base_train; + float rope_freq_scale_train; + uint32_t n_ctx_orig_yarn; + float rope_yarn_log_mul; + + std::array rope_sections; + + // for State Space Models + uint32_t ssm_d_conv = 0; + uint32_t ssm_d_inner = 0; + uint32_t ssm_d_state = 0; + uint32_t ssm_dt_rank = 0; + + bool ssm_dt_b_c_rms = false; + + float f_clamp_kqv = 0.0f; + float f_max_alibi_bias = 0.0f; + float f_logit_scale = 0.0f; + + // Additional scale factors (Granite/Granite MoE) + float f_residual_scale = 0.0f; + float f_embedding_scale = 0.0f; + float f_attention_scale = 0.0f; + + bool causal_attn = true; + bool use_alibi = false; + bool attn_soft_cap = false; + + // needed by encoder-decoder models (e.g. T5, FLAN-T5) + // ref: https://github.com/ggerganov/llama.cpp/pull/8141 + llama_token dec_start_token_id = LLAMA_TOKEN_NULL; + + enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; + enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; + enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + + uint32_t n_head(uint32_t il = 0) const; + + uint32_t n_head_kv(uint32_t il = 0) const; + + uint32_t n_ff(uint32_t il = 0) const; + + uint32_t n_gqa(uint32_t il = 0) const; + + // dimension of key embeddings across all k-v heads + uint32_t n_embd_k_gqa(uint32_t il = 0) const; + + // dimension of value embeddings across all k-v heads + uint32_t n_embd_v_gqa(uint32_t il = 0) const; + + // dimension of the rolling state embeddings + // corresponds to Mamba's conv_states size or RWKV's token_shift states size + uint32_t n_embd_k_s() const; + + // dimension of the recurrent state embeddings + uint32_t n_embd_v_s() const; +}; + +static_assert(std::is_trivially_copyable::value, "llama_hparams must be trivially copyable"); + diff --git a/src/llama-impl.cpp b/src/llama-impl.cpp new file mode 100644 index 000000000..6ec709dd3 --- /dev/null +++ b/src/llama-impl.cpp @@ -0,0 +1,167 @@ +#include "llama-impl.h" + +#include "gguf.h" +#include "llama.h" + +#include +#include +#include +#include +#include +#include + +struct llama_logger_state { + ggml_log_callback log_callback = llama_log_callback_default; + void * log_callback_user_data = nullptr; +}; + +static llama_logger_state g_logger_state; + +time_meas::time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {} + +time_meas::~time_meas() { + if (t_start_us >= 0) { + t_acc += ggml_time_us() - t_start_us; + } + } + +void llama_log_set(ggml_log_callback log_callback, void * user_data) { + ggml_log_set(log_callback, user_data); + g_logger_state.log_callback = log_callback ? log_callback : llama_log_callback_default; + g_logger_state.log_callback_user_data = user_data; +} + +static void llama_log_internal_v(ggml_log_level level, const char * format, va_list args) { + va_list args_copy; + va_copy(args_copy, args); + char buffer[128]; + int len = vsnprintf(buffer, 128, format, args); + if (len < 128) { + g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data); + } else { + char * buffer2 = new char[len + 1]; + vsnprintf(buffer2, len + 1, format, args_copy); + buffer2[len] = 0; + g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data); + delete[] buffer2; + } + va_end(args_copy); +} + +void llama_log_internal(ggml_log_level level, const char * format, ...) { + va_list args; + va_start(args, format); + llama_log_internal_v(level, format, args); + va_end(args); +} + +void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + +void replace_all(std::string & s, const std::string & search, const std::string & replace) { + if (search.empty()) { + return; + } + std::string builder; + builder.reserve(s.length()); + size_t pos = 0; + size_t last_pos = 0; + while ((pos = s.find(search, last_pos)) != std::string::npos) { + builder.append(s, last_pos, pos - last_pos); + builder.append(replace); + last_pos = pos + search.length(); + } + builder.append(s, last_pos, std::string::npos); + s = std::move(builder); +} + +std::string format(const char * fmt, ...) { + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), size); +} + +std::string llama_format_tensor_shape(const std::vector & ne) { + char buf[256]; + snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0)); + for (size_t i = 1; i < ne.size(); i++) { + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i)); + } + return buf; +} + +std::string llama_format_tensor_shape(const struct ggml_tensor * t) { + char buf[256]; + snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]); + } + return buf; +} + +static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) { + switch (type) { + case GGUF_TYPE_UINT8: return std::to_string(((const uint8_t *)data)[i]); + case GGUF_TYPE_INT8: return std::to_string(((const int8_t *)data)[i]); + case GGUF_TYPE_UINT16: return std::to_string(((const uint16_t *)data)[i]); + case GGUF_TYPE_INT16: return std::to_string(((const int16_t *)data)[i]); + case GGUF_TYPE_UINT32: return std::to_string(((const uint32_t *)data)[i]); + case GGUF_TYPE_INT32: return std::to_string(((const int32_t *)data)[i]); + case GGUF_TYPE_UINT64: return std::to_string(((const uint64_t *)data)[i]); + case GGUF_TYPE_INT64: return std::to_string(((const int64_t *)data)[i]); + case GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]); + case GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]); + case GGUF_TYPE_BOOL: return ((const bool *)data)[i] ? "true" : "false"; + default: return format("unknown type %d", type); + } +} + +std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { + const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i); + + switch (type) { + case GGUF_TYPE_STRING: + return gguf_get_val_str(ctx_gguf, i); + case GGUF_TYPE_ARRAY: + { + const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i); + int arr_n = gguf_get_arr_n(ctx_gguf, i); + const void * data = arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx_gguf, i); + std::stringstream ss; + ss << "["; + for (int j = 0; j < arr_n; j++) { + if (arr_type == GGUF_TYPE_STRING) { + std::string val = gguf_get_arr_str(ctx_gguf, i, j); + // escape quotes + replace_all(val, "\\", "\\\\"); + replace_all(val, "\"", "\\\""); + ss << '"' << val << '"'; + } else if (arr_type == GGUF_TYPE_ARRAY) { + ss << "???"; + } else { + ss << gguf_data_to_str(arr_type, data, j); + } + if (j < arr_n - 1) { + ss << ", "; + } + } + ss << "]"; + return ss.str(); + } + default: + return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0); + } +} diff --git a/src/llama-impl.h b/src/llama-impl.h index 87012617f..12d1fb082 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -1,10 +1,9 @@ #pragma once -#include "llama.h" +#include "ggml.h" // for ggml_log_level #include #include -#include #ifdef __GNUC__ #ifdef __MINGW32__ @@ -24,155 +23,39 @@ LLAMA_ATTRIBUTE_FORMAT(2, 3) void llama_log_internal (ggml_log_level level, const char * format, ...); void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data); +#define LLAMA_LOG(...) llama_log_internal(GGML_LOG_LEVEL_NONE , __VA_ARGS__) #define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) #define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) #define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +#define LLAMA_LOG_DEBUG(...) llama_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) +#define LLAMA_LOG_CONT(...) llama_log_internal(GGML_LOG_LEVEL_CONT , __VA_ARGS__) // // helpers // -struct time_meas { - time_meas(int64_t & t_acc, bool disable = false) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {} +template +struct no_init { + T value; + no_init() { /* do nothing */ } +}; - ~time_meas() { - if (t_start_us >= 0) { - t_acc += ggml_time_us() - t_start_us; - } - } +struct time_meas { + time_meas(int64_t & t_acc, bool disable = false); + ~time_meas(); const int64_t t_start_us; int64_t & t_acc; }; -static void replace_all(std::string & s, const std::string & search, const std::string & replace) { - if (search.empty()) { - return; - } - std::string builder; - builder.reserve(s.length()); - size_t pos = 0; - size_t last_pos = 0; - while ((pos = s.find(search, last_pos)) != std::string::npos) { - builder.append(s, last_pos, pos - last_pos); - builder.append(replace); - last_pos = pos + search.length(); - } - builder.append(s, last_pos, std::string::npos); - s = std::move(builder); -} +void replace_all(std::string & s, const std::string & search, const std::string & replace); -const std::vector> & llama_internal_get_tensor_map( - struct llama_context * ctx -); +// TODO: rename to llama_format ? +LLAMA_ATTRIBUTE_FORMAT(1, 2) +std::string format(const char * fmt, ...); -// the ring buffer works similarly to std::deque, but with a fixed capacity -template -struct ring_buffer { - ring_buffer(size_t cap) : capacity(cap), data(cap) {} +std::string llama_format_tensor_shape(const std::vector & ne); +std::string llama_format_tensor_shape(const struct ggml_tensor * t); - T & front() { - if (sz == 0) { - throw std::runtime_error("ring buffer is empty"); - } - return data[first]; - } - - const T & front() const { - if (sz == 0) { - throw std::runtime_error("ring buffer is empty"); - } - return data[first]; - } - - T & back() { - if (sz == 0) { - throw std::runtime_error("ring buffer is empty"); - } - return data[pos]; - } - - const T & back() const { - if (sz == 0) { - throw std::runtime_error("ring buffer is empty"); - } - return data[pos]; - } - - void push_back(const T & value) { - if (capacity == 0) { - throw std::runtime_error("ring buffer: capacity is zero"); - } - - if (sz == capacity) { - // advance the start when buffer is full - first = (first + 1) % capacity; - } else { - sz++; - } - data[pos] = value; - pos = (pos + 1) % capacity; - } - - T pop_front() { - if (sz == 0) { - throw std::runtime_error("ring buffer is empty"); - } - T value = data[first]; - first = (first + 1) % capacity; - sz--; - return value; - } - - //T & operator[](size_t i) { - // if (i >= sz) { - // throw std::runtime_error("ring buffer: index out of bounds"); - // } - // return data[(first + i) % capacity]; - //} - - //const T & at(size_t i) const { - // if (i >= sz) { - // throw std::runtime_error("ring buffer: index out of bounds"); - // } - // return data[(first + i) % capacity]; - //} - - const T & rat(size_t i) const { - if (i >= sz) { - throw std::runtime_error("ring buffer: index out of bounds"); - } - return data[(first + sz - i - 1) % capacity]; - } - - std::vector to_vector() const { - std::vector result; - result.reserve(sz); - for (size_t i = 0; i < sz; i++) { - result.push_back(data[(first + i) % capacity]); - } - return result; - } - - void clear() { - // here only reset the status of the buffer - sz = 0; - first = 0; - pos = 0; - } - - bool empty() const { - return sz == 0; - } - - size_t size() const { - return sz; - } - - size_t capacity = 0; - size_t sz = 0; - size_t first = 0; - size_t pos = 0; - std::vector data; -}; +std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i); diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp new file mode 100644 index 000000000..feffdf0de --- /dev/null +++ b/src/llama-kv-cache.cpp @@ -0,0 +1,718 @@ +#include "llama-kv-cache.h" + +#include "llama-impl.h" +#include "llama-batch.h" +#include "llama-cparams.h" +#include "llama-model.h" + +#include +#include +#include + +static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false}; + +uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) { + // the FA kernels require padding to avoid extra runtime boundary checks + return cparams.flash_attn ? 256u : 32u; +} + +bool llama_kv_cache_init( + struct llama_kv_cache & cache, + const llama_model & model, + const llama_cparams & cparams, + ggml_type type_k, + ggml_type type_v, + uint32_t kv_size, + bool offload) { + const struct llama_hparams & hparams = model.hparams; + + const int32_t n_layer = hparams.n_layer; + + cache.has_shift = false; + + cache.recurrent = llama_model_is_recurrent(&model); + cache.v_trans = !cache.recurrent && !cparams.flash_attn; + cache.can_shift = !cache.recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA + + LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n", + __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, cache.can_shift); + + cache.head = 0; + cache.size = kv_size; + cache.used = 0; + + cache.type_k = type_k; + cache.type_v = type_v; + + cache.cells.clear(); + cache.cells.resize(kv_size); + + // create a context for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + struct ggml_init_params params = { + /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + ctx_map[buft] = ctx; + cache.ctxs.emplace_back(ctx); + return ctx; + } + return it->second; + }; + + cache.k_l.reserve(n_layer); + cache.v_l.reserve(n_layer); + + for (int i = 0; i < n_layer; i++) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + + LLAMA_LOG_DEBUG("%s: layer %d: n_embd_k_gqa = %d, n_embd_v_gqa = %d\n", __func__, i, n_embd_k_gqa, n_embd_v_gqa); + + ggml_backend_buffer_type_t buft; + if (offload) { + auto * dev = model.dev_layer(i); + buft = ggml_backend_dev_buffer_type(dev); + } else { + buft = ggml_backend_cpu_buffer_type(); + } + ggml_context * ctx = ctx_for_buft(buft); + + if (!ctx) { + LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__); + return false; + } + + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); + cache.k_l.push_back(k); + cache.v_l.push_back(v); + } + + // allocate tensors and initialize the buffers to avoid NaNs in the padding + for (auto it : ctx_map) { + auto * buft = it.first; + auto * ctx = it.second; + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__); + return false; + } + ggml_backend_buffer_clear(buf, 0); + LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + cache.bufs.emplace_back(buf); + } + + return true; +} + +struct llama_kv_cache_slot_info llama_kv_cache_find_slot( + struct llama_kv_cache & cache, + const struct llama_ubatch & ubatch) { + const uint32_t n_tokens = ubatch.n_tokens; + const uint32_t n_seqs = ubatch.n_seqs; + const uint32_t n_seq_tokens = ubatch.n_seq_tokens; + + if (cache.recurrent) { + // For recurrent state architectures (like Mamba or RWKV), + // each cache cell can store the state for a whole sequence. + // A slot should be always be contiguous. + + // can only process batches with an equal number of new tokens in each sequence + GGML_ASSERT(ubatch.equal_seqs); + + int32_t min = cache.size - 1; + int32_t max = 0; + + // everything should fit if all seq_ids are smaller than the max + for (uint32_t s = 0; s < n_seqs; ++s) { + const uint32_t n_seq_id = ubatch.n_seq_id[s]; + for (uint32_t j = 0; j < n_seq_id; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + + if (seq_id < 0 || (uint32_t) seq_id >= cache.size) { + // too big seq_id + // TODO: would it be possible to resize the cache instead? + LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size); + return llama_kv_cache_slot_info_failed; + } + if (j > 0) { + llama_kv_cell & seq = cache.cells[seq_id]; + if (seq.tail >= 0) { + llama_kv_cell & cell = cache.cells[seq.tail]; + // clear cells from seq_ids that become shared + // (should not normally happen, but let's handle it anyway) + cell.seq_id.erase(seq_id); + seq.tail = -1; + if (cell.seq_id.empty()) { + cell.pos = -1; + cell.src = -1; + cache.used -= 1; + } + } + } + } + } + +#ifndef NDEBUG + { + std::vector tails_verif; + tails_verif.assign(cache.size, -1); + for (uint32_t i = 0; i < cache.size; ++i) { + llama_kv_cell & cell = cache.cells[i]; + for (llama_seq_id seq_id : cell.seq_id) { + if (tails_verif[seq_id] != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); + } + tails_verif[seq_id] = i; + } + } + for (uint32_t i = 0; i < cache.size; ++i) { + if (tails_verif[i] != cache.cells[i].tail) { + LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]); + } + } + } +#endif + + // find next empty cell + uint32_t next_empty_cell = cache.head; + + for (uint32_t i = 0; i < cache.size; ++i) { + if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; } + llama_kv_cell & cell = cache.cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + + // find usable cell range + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + llama_kv_cell & seq_meta = cache.cells[seq_id]; + bool has_cell = false; + if (seq_meta.tail >= 0) { + llama_kv_cell & cell = cache.cells[seq_meta.tail]; + GGML_ASSERT(cell.has_seq_id(seq_id)); + // does this seq_id "own" the cell? + if (cell.seq_id.size() == 1) { has_cell = true; } + } + if (!has_cell) { + llama_kv_cell & empty_cell = cache.cells[next_empty_cell]; + GGML_ASSERT(empty_cell.is_empty()); + // copy old tail into the empty cell + if (seq_meta.tail >= 0) { + llama_kv_cell & orig_cell = cache.cells[seq_meta.tail]; + empty_cell.pos = orig_cell.pos; + empty_cell.src = orig_cell.src; + orig_cell.seq_id.erase(seq_id); + empty_cell.seq_id.insert(seq_id); // will be overwritten + } + seq_meta.tail = next_empty_cell; + // find next empty cell + if (s + 1 < n_seqs) { + next_empty_cell += 1; + for (uint32_t i = 0; i < cache.size; ++i) { + if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; } + llama_kv_cell & cell = cache.cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + } + } + if (min > seq_meta.tail) { min = seq_meta.tail; } + if (max < seq_meta.tail) { max = seq_meta.tail; } + } + + // gather and re-order + for (uint32_t s = 0; s < n_seqs; ++s) { + int32_t dst_id = s + min; + int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail; + if (dst_id != src_id) { + llama_kv_cell & dst_cell = cache.cells[dst_id]; + llama_kv_cell & src_cell = cache.cells[src_id]; + + std::swap(dst_cell.pos, src_cell.pos); + std::swap(dst_cell.src, src_cell.src); + std::swap(dst_cell.seq_id, src_cell.seq_id); + + // swap tails (assuming they NEVER overlap) + for (const llama_seq_id seq_id : src_cell.seq_id) { + cache.cells[seq_id].tail = src_id; + } + for (const llama_seq_id seq_id : dst_cell.seq_id) { + cache.cells[seq_id].tail = dst_id; + } + } + } + + // update the pos of the used seqs + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; + int32_t cell_id = s + min; + llama_kv_cell & cell = cache.cells[cell_id]; + + if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", + __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); + } + cell.pos = last_pos; + cell.seq_id.clear(); + for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + cell.seq_id.insert(seq_id); + cache.cells[seq_id].tail = cell_id; + } + } + + // allow getting the range of used cells, from head to head + n + cache.head = min; + cache.n = max - min + 1; + cache.used = std::count_if(cache.cells.begin(), cache.cells.end(), + [](const llama_kv_cell& cell){ return !cell.is_empty(); }); + + // sanity check + return llama_kv_cache_slot_info(cache.n >= n_seqs); + } + // otherwise, one cell per token. + + if (n_tokens > cache.size) { + LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size); + return llama_kv_cache_slot_info_failed; + } + + uint32_t n_tested = 0; + + while (true) { + if (cache.head + n_tokens > cache.size) { + n_tested += cache.size - cache.head; + cache.head = 0; + continue; + } + + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + if (cache.cells[cache.head + i].pos >= 0) { + found = false; + cache.head += i + 1; + n_tested += i + 1; + break; + } + } + + if (found) { + break; + } + + if (n_tested >= cache.size) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return llama_kv_cache_slot_info_failed; + } + } + + for (uint32_t s = 0; s < n_seqs; s++) { + for (uint32_t i = 0; i < n_seq_tokens; ++i) { + uint32_t k = s*n_seq_tokens + i; + cache.cells[cache.head + k].pos = ubatch.pos[k]; + + for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) { + cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]); + } + } + } + + cache.used += n_tokens; + + return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens); +} + +uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { + for (uint32_t i = cache.size; i > 0; --i) { + const llama_kv_cell & cell = cache.cells[i - 1]; + + if (cell.pos >= 0 && !cell.is_empty()) { + return i; + } + } + + return 0; +} + +void llama_kv_cache_clear(struct llama_kv_cache & cache) { + for (int32_t i = 0; i < (int32_t) cache.size; ++i) { + cache.cells[i].pos = -1; + cache.cells[i].seq_id.clear(); + cache.cells[i].src = -1; + cache.cells[i].tail = -1; + } + cache.head = 0; + cache.used = 0; + + for (auto & buf : cache.bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } +} + +bool llama_kv_cache_seq_rm( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { + uint32_t new_head = cache.size; + + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + + // models like Mamba or RWKV can't have a state partially erased + if (cache.recurrent) { + if (seq_id >= (int64_t) cache.size) { + // could be fatal + return false; + } + if (0 <= seq_id) { + int32_t & tail_id = cache.cells[seq_id].tail; + if (tail_id >= 0) { + const llama_kv_cell & cell = cache.cells[tail_id]; + // partial intersection is invalid + if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + return false; + } + // invalidate tails which will be cleared + if (p0 <= cell.pos && cell.pos < p1) { + tail_id = -1; + } + } + } else { + // seq_id is negative, then the range should include everything or nothing + if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + return false; + } + } + } + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + if (seq_id < 0) { + cache.cells[i].seq_id.clear(); + } else if (cache.cells[i].has_seq_id(seq_id)) { + cache.cells[i].seq_id.erase(seq_id); + } else { + continue; + } + if (cache.cells[i].is_empty()) { + // keep count of the number of used cells + if (cache.cells[i].pos >= 0) cache.used--; + + cache.cells[i].pos = -1; + cache.cells[i].src = -1; + if (new_head == cache.size) new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.size && new_head < cache.head) cache.head = new_head; + + return true; +} + +void llama_kv_cache_seq_cp( + struct llama_kv_cache & cache, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + + if (cache.recurrent) { + if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) { + llama_kv_cell & tail_src = cache.cells[seq_id_src]; + llama_kv_cell & tail_dst = cache.cells[seq_id_dst]; + if (tail_dst.tail >= 0) { + // clear destination seq_id if it wasn't empty + llama_kv_cell & cell_dst = cache.cells[tail_dst.tail]; + + cell_dst.seq_id.erase(seq_id_dst); + tail_dst.tail = -1; + if (cell_dst.seq_id.empty()) { + cell_dst.pos = -1; + cell_dst.delta = -1; + cell_dst.src = -1; + cache.used -= 1; + } + } + if (tail_src.tail >= 0) { + llama_kv_cell & cell_src = cache.cells[tail_src.tail]; + + cell_src.seq_id.insert(seq_id_dst); + tail_dst.tail = tail_src.tail; + } + } + + return; + } + // otherwise, this is the KV cache of a Transformer-like model + + cache.head = 0; + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.cells[i].seq_id.insert(seq_id_dst); + } + } +} + +void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { + uint32_t new_head = cache.size; + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.recurrent && (llama_seq_id) i != seq_id) { + cache.cells[i].tail = -1; + } + if (!cache.cells[i].has_seq_id(seq_id)) { + if (cache.cells[i].pos >= 0) cache.used--; + cache.cells[i].pos = -1; + cache.cells[i].src = -1; + cache.cells[i].seq_id.clear(); + if (new_head == cache.size) new_head = i; + } else { + cache.cells[i].seq_id.clear(); + cache.cells[i].seq_id.insert(seq_id); + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.size && new_head < cache.head) cache.head = new_head; +} + +void llama_kv_cache_seq_add( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { + uint32_t new_head = cache.size; + + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + // If there is no range then return early to avoid looping over the cache. + if (p0 == p1) return; + + if (cache.recurrent) { + // for Mamba-like or RWKV models, only the pos needs to be shifted + if (0 <= seq_id && seq_id < (int64_t) cache.size) { + const int32_t tail_id = cache.cells[seq_id].tail; + if (tail_id >= 0) { + llama_kv_cell & cell = cache.cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos += delta; + } + } + } + return; + } + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.has_shift = true; + cache.cells[i].pos += delta; + cache.cells[i].delta += delta; + + if (cache.cells[i].pos < 0) { + if (!cache.cells[i].is_empty()) { + cache.used--; + } + cache.cells[i].pos = -1; + cache.cells[i].seq_id.clear(); + if (new_head == cache.size) { + new_head = i; + } + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + // Otherwise we just start the next search from the beginning. + cache.head = new_head != cache.size ? new_head : 0; +} + +void llama_kv_cache_seq_div( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + // If there is no range then return early to avoid looping over the cache. + if (p0 == p1) return; + + if (cache.recurrent) { + // for Mamba-like or RWKV models, only the pos needs to be changed + if (0 <= seq_id && seq_id < (int64_t) cache.size) { + const int32_t tail_id = cache.cells[seq_id].tail; + if (tail_id >= 0) { + llama_kv_cell & cell = cache.cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos /= d; + } + } + } + return; + } + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.has_shift = true; + + { + llama_pos p_old = cache.cells[i].pos; + cache.cells[i].pos /= d; + cache.cells[i].delta += cache.cells[i].pos - p_old; + } + } + } +} + +llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) { + llama_pos result = 0; + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id)) { + result = std::max(result, cache.cells[i].pos); + } + } + + return result; +} + +void llama_kv_cache_defrag(struct llama_kv_cache & cache) { + if (!cache.recurrent) { + cache.do_defrag = true; + } +} + +int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv) { + int result = 0; + + for (uint32_t i = 0; i < kv.size; i++) { + result += kv.cells[i].seq_id.size(); + } + + return result; +} + +int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv) { + return kv.used; +} + +bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv) { + return kv.can_shift; +} + +// +// kv cache view +// + +struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max) { + struct llama_kv_cache_view result = { + /*.n_cells = */ 0, + /*.n_seq_max = */ n_seq_max, + /*.token_count = */ 0, + /*.used_cells = */ llama_get_kv_cache_used_cells(kv), + /*.max_contiguous = */ 0, + /*.max_contiguous_idx = */ -1, + /*.cells = */ nullptr, + /*.cells_sequences = */ nullptr, + }; + + return result; +} + +void llama_kv_cache_view_free(struct llama_kv_cache_view * view) { + if (view->cells != nullptr) { + free(view->cells); + view->cells = nullptr; + } + if (view->cells_sequences != nullptr) { + free(view->cells_sequences); + view->cells_sequences = nullptr; + } +} + +void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv) { + if (uint32_t(view->n_cells) < kv.size || view->cells == nullptr) { + view->n_cells = int32_t(kv.size); + void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells); + GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells"); + view->cells = (struct llama_kv_cache_view_cell *)p; + p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells); + GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences"); + view->cells_sequences = (llama_seq_id *)p; + } + + const std::vector & kv_cells = kv.cells; + llama_kv_cache_view_cell * c_curr = view->cells; + llama_seq_id * cs_curr = view->cells_sequences; + int32_t used_cells = 0; + int32_t token_count = 0; + int32_t curr_contig_idx = -1; + uint32_t max_contig = 0; + int32_t max_contig_idx = -1; + + for (int32_t i = 0; i < int32_t(kv.size); i++, c_curr++, cs_curr += view->n_seq_max) { + const size_t curr_size = kv_cells[i].seq_id.size(); + token_count += curr_size; + c_curr->pos = kv_cells[i].pos + kv_cells[i].delta; + + if (curr_size > 0) { + if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) { + max_contig = i - curr_contig_idx; + max_contig_idx = curr_contig_idx; + } + curr_contig_idx = -1; + } else if (curr_contig_idx < 0) { + curr_contig_idx = i; + } + + int seq_idx = 0; + for (const llama_seq_id it : kv_cells[i].seq_id) { + if (seq_idx >= view->n_seq_max) { + break; + } + cs_curr[seq_idx] = it; + seq_idx++; + } + if (seq_idx != 0) { + used_cells++; + } + for (; seq_idx < view->n_seq_max; seq_idx++) { + cs_curr[seq_idx] = -1; + } + } + if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) { + max_contig_idx = curr_contig_idx; + max_contig = kv_cells.size() - curr_contig_idx; + } + view->max_contiguous = max_contig; + view->max_contiguous_idx = max_contig_idx; + view->token_count = token_count; + view->used_cells = used_cells; + if (uint32_t(used_cells) != kv.used) { + LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n", + __func__, kv.used, used_cells); + } +} diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h new file mode 100644 index 000000000..dca6f3998 --- /dev/null +++ b/src/llama-kv-cache.h @@ -0,0 +1,218 @@ +#pragma once + +#include "llama.h" + +#include "ggml-cpp.h" + +#include +#include + +struct llama_kv_cell { + llama_pos pos = -1; + llama_pos delta = 0; + int32_t src = -1; // used by recurrent state models to copy states + int32_t tail = -1; + + std::set seq_id; + + bool has_seq_id(const llama_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } + + bool is_empty() const { + return seq_id.empty(); + } + + bool is_same_seq(const llama_kv_cell & other) const { + return seq_id == other.seq_id; + } +}; + +// ring-buffer of cached KV data +struct llama_kv_cache { + bool has_shift = false; + bool do_defrag = false; + bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token + bool v_trans = true; // the value tensor is transposed + bool can_shift = false; + + // Note: The value of head isn't only used to optimize searching + // for a free KV slot. llama_decode_internal also uses it, so it + // cannot be freely changed after a slot has been allocated. + uint32_t head = 0; + uint32_t size = 0; + uint32_t used = 0; // used cells (i.e. at least one seq_id) + + // computed before each graph build + uint32_t n = 0; + + ggml_type type_k = GGML_TYPE_F16; + ggml_type type_v = GGML_TYPE_F16; + + std::vector cells; + + std::vector k_l; // per layer + std::vector v_l; + + std::vector ctxs; + std::vector bufs; + + size_t total_size() const { + size_t size = 0; + for (const auto & buf : bufs) { + size += ggml_backend_buffer_get_size(buf.get()); + } + + return size; + } + + // TODO: better data structures to reduce the cost of this operation + llama_pos max_pos() const { + llama_pos max_pos = -1; + for (const auto & cell : cells) { + max_pos = std::max(max_pos, cell.pos); + } + + return max_pos; + } +}; + +// a structure holds information about the slot found in llama_kv_cache_find_slot +struct llama_kv_cache_slot_info { + std::pair boundaries; // slot boundaries [begin, end) + bool found = false; // the slot was found + + explicit llama_kv_cache_slot_info(bool found_) : found{found_} {} + llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {} + + operator bool() const { return found; } +}; + +// TODO: maybe not needed +uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams); + +bool llama_kv_cache_init( + struct llama_kv_cache & cache, + const llama_model & model, + const llama_cparams & cparams, + ggml_type type_k, + ggml_type type_v, + uint32_t kv_size, + bool offload); + +// find an empty slot of size "n_tokens" in the cache +// updates the cache head +// returns a structure holding information about the slot found +// Note: On success, it's important that cache.head points +// to the first cell of the slot. +struct llama_kv_cache_slot_info llama_kv_cache_find_slot( + struct llama_kv_cache & cache, + const struct llama_ubatch & batch); + +// find how many cells are currently in use +uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache); + +void llama_kv_cache_clear(struct llama_kv_cache & cache); + +bool llama_kv_cache_seq_rm( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1); + +void llama_kv_cache_seq_cp( + struct llama_kv_cache & cache, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1); + +void llama_kv_cache_seq_keep( + struct llama_kv_cache & cache, + llama_seq_id seq_id); + +void llama_kv_cache_seq_add( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta); + +void llama_kv_cache_seq_div( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d); + +llama_pos llama_kv_cache_seq_pos_max( + struct llama_kv_cache & cache, + llama_seq_id seq_id); + +void llama_kv_cache_defrag(struct llama_kv_cache & cache); + +int32_t llama_get_kv_cache_token_count(const struct llama_kv_cache & kv); + +int32_t llama_get_kv_cache_used_cells(const struct llama_kv_cache & kv); + +bool llama_kv_cache_can_shift(const struct llama_kv_cache & kv); + +// +// kv cache view +// + +struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_kv_cache & kv, int32_t n_seq_max); + +void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct llama_kv_cache & kv); + +// +// kv cache restore +// + +// saves the kv_cache state for future recovery. +// used to rollback llama_kv_cache_find_slot changes. +struct llama_kv_slot_restorer { + struct llama_kv_cache_state { + uint32_t head = 0; + uint32_t n = 0; + } old_state; + + // for non-recurrent models only + // list of slots to restore + std::vector> slot_boundaries; + + bool do_restore = false; + + explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) { + old_state.head = cache.head; + old_state.n = cache.n; + } + + // saves a slot information for future restoration + void save(const struct llama_kv_cache_slot_info & slot) { + if (slot) { + do_restore = true; + if (slot.boundaries.first != slot.boundaries.second) { + slot_boundaries.push_back(slot.boundaries); + } + } + } + + // must be explicitly called to restore the kv_cache state + // and rollback changes from all llama_kv_cache_find_slot calls + void restore(struct llama_kv_cache & cache) { + if (do_restore) { + cache.head = old_state.head; + cache.n = old_state.n; + + if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased + llama_kv_cache_seq_rm(cache, -1, -1, -1); + } else { + for (auto & slot : slot_boundaries) { + llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second); + } + } + } + } +}; + diff --git a/src/llama-mmap.cpp b/src/llama-mmap.cpp new file mode 100644 index 000000000..b716630a8 --- /dev/null +++ b/src/llama-mmap.cpp @@ -0,0 +1,590 @@ +#include "llama-mmap.h" + +#include "llama-impl.h" + +#include "ggml.h" + +#include +#include +#include +#include + +#ifdef __has_include + #if __has_include() + #include + #if defined(_POSIX_MAPPED_FILES) + #include + #include + #endif + #if defined(_POSIX_MEMLOCK_RANGE) + #include + #endif + #endif +#endif + +#if defined(_WIN32) + #define WIN32_LEAN_AND_MEAN + #ifndef NOMINMAX + #define NOMINMAX + #endif + #include + #ifndef PATH_MAX + #define PATH_MAX MAX_PATH + #endif + #include +#endif + +// TODO: consider moving to llama-impl.h if needed in more places +#if defined(_WIN32) +static std::string llama_format_win_err(DWORD err) { + LPSTR buf; + size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL); + if (!size) { + return "FormatMessageA failed"; + } + std::string ret(buf, size); + LocalFree(buf); + return ret; +} +#endif + +// llama_file + +struct llama_file::impl { +#if defined(_WIN32) + HANDLE fp_win32; + std::string GetErrorMessageWin32(DWORD error_code) const { + std::string ret; + LPSTR lpMsgBuf = NULL; + DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL); + if (!bufLen) { + ret = format("Win32 error code: %lx", error_code); + } else { + ret = lpMsgBuf; + LocalFree(lpMsgBuf); + } + + return ret; + } + + impl(const char * fname, const char * mode) { + fp = ggml_fopen(fname, mode); + if (fp == NULL) { + throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); + } + fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp)); + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + + size_t tell() const { + LARGE_INTEGER li; + li.QuadPart = 0; + BOOL ret = SetFilePointerEx(fp_win32, li, &li, FILE_CURRENT); + if (!ret) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + + return li.QuadPart; + } + + void seek(size_t offset, int whence) const { + static_assert(SEEK_SET == FILE_BEGIN, "SEEK_SET != FILE_BEGIN"); + static_assert(SEEK_CUR == FILE_CURRENT, "SEEK_CUR != FILE_CURRENT"); + static_assert(SEEK_END == FILE_END, "SEEK_END != FILE_END"); + + LARGE_INTEGER li; + li.QuadPart = offset; + BOOL ret = SetFilePointerEx(fp_win32, li, NULL, whence); + if (!ret) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + } + + void read_raw(void * ptr, size_t len) const { + size_t bytes_read = 0; + while (bytes_read < len) { + size_t chunk_size = std::min(len - bytes_read, 64*1024*1024); + DWORD chunk_read = 0; + BOOL result = ReadFile(fp_win32, reinterpret_cast(ptr) + bytes_read, chunk_size, &chunk_read, NULL); + if (!result) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + if (chunk_read < chunk_size || chunk_read == 0) { + throw std::runtime_error("unexpectedly reached end of file"); + } + + bytes_read += chunk_read; + } + } + + uint32_t read_u32() const { + uint32_t val; + read_raw(&val, sizeof(val)); + return val; + } + + void write_raw(const void * ptr, size_t len) const { + size_t bytes_written = 0; + while (bytes_written < len) { + size_t chunk_size = std::min(len - bytes_written, 64*1024*1024); + DWORD chunk_written = 0; + BOOL result = WriteFile(fp_win32, reinterpret_cast(ptr) + bytes_written, chunk_size, &chunk_written, NULL); + if (!result) { + throw std::runtime_error(format("write error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + if (chunk_written < chunk_size || chunk_written == 0) { + throw std::runtime_error("unexpectedly failed to write bytes"); + } + + bytes_written += chunk_written; + } + } + + void write_u32(uint32_t val) const { + write_raw(&val, sizeof(val)); + } + + ~impl() { + if (fp) { + std::fclose(fp); + } + } +#else + impl(const char * fname, const char * mode) { + fp = ggml_fopen(fname, mode); + if (fp == NULL) { + throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); + } + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + + size_t tell() const { +// TODO: this ifdef is never true? +#ifdef _WIN32 + __int64 ret = _ftelli64(fp); +#else + long ret = std::ftell(fp); +#endif + if (ret == -1) { + throw std::runtime_error(format("ftell error: %s", strerror(errno))); + } + + return (size_t) ret; + } + + void seek(size_t offset, int whence) const { +// TODO: this ifdef is never true? +#ifdef _WIN32 + int ret = _fseeki64(fp, (__int64) offset, whence); +#else + int ret = std::fseek(fp, (long) offset, whence); +#endif + if (ret != 0) { + throw std::runtime_error(format("seek error: %s", strerror(errno))); + } + } + + void read_raw(void * ptr, size_t len) const { + if (len == 0) { + return; + } + errno = 0; + std::size_t ret = std::fread(ptr, len, 1, fp); + if (ferror(fp)) { + throw std::runtime_error(format("read error: %s", strerror(errno))); + } + if (ret != 1) { + throw std::runtime_error("unexpectedly reached end of file"); + } + } + + uint32_t read_u32() const { + uint32_t ret; + read_raw(&ret, sizeof(ret)); + return ret; + } + + void write_raw(const void * ptr, size_t len) const { + if (len == 0) { + return; + } + errno = 0; + size_t ret = std::fwrite(ptr, len, 1, fp); + if (ret != 1) { + throw std::runtime_error(format("write error: %s", strerror(errno))); + } + } + + void write_u32(uint32_t val) const { + write_raw(&val, sizeof(val)); + } + + ~impl() { + if (fp) { + std::fclose(fp); + } + } +#endif + + FILE * fp; + size_t size; +}; + +llama_file::llama_file(const char * fname, const char * mode) : pimpl(std::make_unique(fname, mode)) {} +llama_file::~llama_file() = default; + +size_t llama_file::tell() const { return pimpl->tell(); } +size_t llama_file::size() const { return pimpl->size; } + +int llama_file::file_id() const { +#ifdef _WIN32 + return _fileno(pimpl->fp); +#else +#if defined(fileno) + return fileno(pimpl->fp); +#else + return ::fileno(pimpl->fp); +#endif +#endif +} + +void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); } +void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); } + +uint32_t llama_file::read_u32() const { return pimpl->read_u32(); } + +void llama_file::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); } +void llama_file::write_u32(uint32_t val) const { pimpl->write_u32(val); } + +// llama_mmap + +struct llama_mmap::impl { +#ifdef _POSIX_MAPPED_FILES + std::vector> mapped_fragments; + + impl(struct llama_file * file, size_t prefetch, bool numa) { + size = file->size(); + int fd = file->file_id(); + int flags = MAP_SHARED; + if (numa) { prefetch = 0; } +#ifdef __linux__ + if (posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL)) { + LLAMA_LOG_WARN("warning: posix_fadvise(.., POSIX_FADV_SEQUENTIAL) failed: %s\n", + strerror(errno)); + } + if (prefetch) { flags |= MAP_POPULATE; } +#endif + addr = mmap(NULL, file->size(), PROT_READ, flags, fd, 0); + if (addr == MAP_FAILED) { + throw std::runtime_error(format("mmap failed: %s", strerror(errno))); + } + + if (prefetch > 0) { + if (posix_madvise(addr, std::min(file->size(), prefetch), POSIX_MADV_WILLNEED)) { + LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n", + strerror(errno)); + } + } + if (numa) { + if (posix_madvise(addr, file->size(), POSIX_MADV_RANDOM)) { + LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n", + strerror(errno)); + } + } + + mapped_fragments.emplace_back(0, file->size()); + } + + static void align_range(size_t * first, size_t * last, size_t page_size) { + size_t offset_in_page = *first & (page_size - 1); + size_t offset_to_page = offset_in_page == 0 ? 0 : page_size - offset_in_page; + *first += offset_to_page; + + *last = *last & ~(page_size - 1); + + if (*last <= *first) { + *last = *first; + } + } + + void unmap_fragment(size_t first, size_t last) { + int page_size = sysconf(_SC_PAGESIZE); + align_range(&first, &last, page_size); + size_t len = last - first; + + if (len == 0) { + return; + } + + GGML_ASSERT(first % page_size == 0); + GGML_ASSERT(last % page_size == 0); + GGML_ASSERT(last > first); + + void * next_page_start = (uint8_t *) addr + first; + + if (munmap(next_page_start, len)) { + LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno)); + } + + std::vector> new_mapped_fragments; + for (const auto & frag : mapped_fragments) { + if (frag.first < first && frag.second > last) { + new_mapped_fragments.emplace_back(frag.first, first); + new_mapped_fragments.emplace_back(last, frag.second); + } else if (frag.first < first && frag.second > first) { + new_mapped_fragments.emplace_back(frag.first, first); + } else if (frag.first < last && frag.second > last) { + new_mapped_fragments.emplace_back(last, frag.second); + } else if (frag.first >= first && frag.second <= last) { + } else { + new_mapped_fragments.push_back(frag); + } + } + mapped_fragments = std::move(new_mapped_fragments); + } + + ~impl() { + for (const auto & frag : mapped_fragments) { + if (munmap((char *) addr + frag.first, frag.second - frag.first)) { + LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno)); + } + } + } +#elif defined(_WIN32) + impl(struct llama_file * file, size_t prefetch, bool numa) { + GGML_UNUSED(numa); + + size = file->size(); + + HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id()); + + HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); + + if (hMapping == NULL) { + DWORD error = GetLastError(); + throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str())); + } + + addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); + DWORD error = GetLastError(); + CloseHandle(hMapping); + + if (addr == NULL) { + throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str())); + } + + if (prefetch > 0) { +#if _WIN32_WINNT >= 0x602 + BOOL (WINAPI *pPrefetchVirtualMemory) (HANDLE, ULONG_PTR, PWIN32_MEMORY_RANGE_ENTRY, ULONG); + HMODULE hKernel32 = GetModuleHandleW(L"kernel32.dll"); + + pPrefetchVirtualMemory = (decltype(pPrefetchVirtualMemory))(void *) GetProcAddress(hKernel32, "PrefetchVirtualMemory"); + + if (pPrefetchVirtualMemory) { + WIN32_MEMORY_RANGE_ENTRY range; + range.VirtualAddress = addr; + range.NumberOfBytes = (SIZE_T) std::min(size, prefetch); + if (!pPrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) { + LLAMA_LOG_WARN("warning: PrefetchVirtualMemory failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } +#else + throw std::runtime_error("PrefetchVirtualMemory unavailable"); +#endif + } + } + + void unmap_fragment(size_t first, size_t last) { + GGML_UNUSED(first); + GGML_UNUSED(last); + } + + ~impl() { + if (!UnmapViewOfFile(addr)) { + LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } +#else + impl(struct llama_file * file, size_t prefetch, bool numa) { + GGML_UNUSED(file); + GGML_UNUSED(prefetch); + GGML_UNUSED(numa); + + throw std::runtime_error("mmap not supported"); + } + + void unmap_fragment(size_t first, size_t last) { + GGML_UNUSED(first); + GGML_UNUSED(last); + + throw std::runtime_error("mmap not supported"); + } +#endif + + void * addr; + size_t size; +}; + +llama_mmap::llama_mmap(struct llama_file * file, size_t prefetch, bool numa) : pimpl(std::make_unique(file, prefetch, numa)) {} +llama_mmap::~llama_mmap() = default; + +size_t llama_mmap::size() const { return pimpl->size; } +void * llama_mmap::addr() const { return pimpl->addr; } + +void llama_mmap::unmap_fragment(size_t first, size_t last) { pimpl->unmap_fragment(first, last); } + +#if defined(_POSIX_MEMLOCK_RANGE) || defined(_WIN32) +const bool llama_mmap::SUPPORTED = true; +#else +const bool llama_mmap::SUPPORTED = false; +#endif + +// llama_mlock + +struct llama_mlock::impl { +#ifdef _POSIX_MEMLOCK_RANGE + static size_t lock_granularity() { + return (size_t) sysconf(_SC_PAGESIZE); + } + + bool raw_lock(const void * addr, size_t size) const { + if (!mlock(addr, size)) { + return true; + } + +#ifdef __APPLE__ +#define MLOCK_SUGGESTION \ + "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \ + "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MEMLOCK (ulimit -l).\n" +#else +#define MLOCK_SUGGESTION \ + "Try increasing RLIMIT_MEMLOCK ('ulimit -l' as root).\n" +#endif + + char* errmsg = std::strerror(errno); + bool suggest = (errno == ENOMEM); + + struct rlimit lock_limit; + if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) { + suggest = false; + } + if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) { + suggest = false; + } + + LLAMA_LOG_WARN("warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s", + size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : ""); + return false; + } + + static void raw_unlock(void * addr, size_t size) { + if (munlock(addr, size)) { + LLAMA_LOG_WARN("warning: failed to munlock buffer: %s\n", std::strerror(errno)); + } + } +#elif defined(_WIN32) + static size_t lock_granularity() { + SYSTEM_INFO si; + GetSystemInfo(&si); + return (size_t) si.dwPageSize; + } + + bool raw_lock(void * ptr, size_t len) const { + for (int tries = 1; ; tries++) { + if (VirtualLock(ptr, len)) { + return true; + } + if (tries == 2) { + LLAMA_LOG_WARN("warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n", + len, size, llama_format_win_err(GetLastError()).c_str()); + return false; + } + + SIZE_T min_ws_size, max_ws_size; + if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) { + LLAMA_LOG_WARN("warning: GetProcessWorkingSetSize failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + return false; + } + size_t increment = len + 1048576; + min_ws_size += increment; + max_ws_size += increment; + if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) { + LLAMA_LOG_WARN("warning: SetProcessWorkingSetSize failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + return false; + } + } + } + + static void raw_unlock(void * ptr, size_t len) { + if (!VirtualUnlock(ptr, len)) { + LLAMA_LOG_WARN("warning: failed to VirtualUnlock buffer: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } +#else + static size_t lock_granularity() { + return (size_t) 65536; + } + + bool raw_lock(const void * addr, size_t len) const { + LLAMA_LOG_WARN("warning: mlock not supported on this system\n"); + return false; + } + + static void raw_unlock(const void * addr, size_t len) {} +#endif + + impl() : addr(NULL), size(0), failed_already(false) {} + + void init(void * ptr) { + GGML_ASSERT(addr == NULL && size == 0); + addr = ptr; + } + + void grow_to(size_t target_size) { + GGML_ASSERT(addr); + if (failed_already) { + return; + } + size_t granularity = lock_granularity(); + target_size = (target_size + granularity - 1) & ~(granularity - 1); + if (target_size > size) { + if (raw_lock((uint8_t *) addr + size, target_size - size)) { + size = target_size; + } else { + failed_already = true; + } + } + } + + void * addr; + size_t size; + + bool failed_already; +}; + +llama_mlock::llama_mlock() : pimpl(std::make_unique()) {} +llama_mlock::~llama_mlock() = default; + +void llama_mlock::init(void * ptr) { pimpl->init(ptr); } +void llama_mlock::grow_to(size_t target_size) { pimpl->grow_to(target_size); } + +#if defined(_POSIX_MEMLOCK_RANGE) || defined(_WIN32) +const bool llama_mlock::SUPPORTED = true; +#else +const bool llama_mlock::SUPPORTED = false; +#endif + +size_t llama_path_max() { + return PATH_MAX; +} diff --git a/src/llama-mmap.h b/src/llama-mmap.h new file mode 100644 index 000000000..1da9ecb6b --- /dev/null +++ b/src/llama-mmap.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include + +struct llama_file; +struct llama_mmap; +struct llama_mlock; + +using llama_files = std::vector>; +using llama_mmaps = std::vector>; +using llama_mlocks = std::vector>; + +struct llama_file { + llama_file(const char * fname, const char * mode); + ~llama_file(); + + size_t tell() const; + size_t size() const; + + int file_id() const; // fileno overload + + void seek(size_t offset, int whence) const; + + void read_raw(void * ptr, size_t len) const; + uint32_t read_u32() const; + + void write_raw(const void * ptr, size_t len) const; + void write_u32(uint32_t val) const; + +private: + struct impl; + std::unique_ptr pimpl; +}; + +struct llama_mmap { + llama_mmap(const llama_mmap &) = delete; + llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1, bool numa = false); + ~llama_mmap(); + + size_t size() const; + void * addr() const; + + void unmap_fragment(size_t first, size_t last); + + static const bool SUPPORTED; + +private: + struct impl; + std::unique_ptr pimpl; +}; + +struct llama_mlock { + llama_mlock(); + ~llama_mlock(); + + void init(void * ptr); + void grow_to(size_t target_size); + + static const bool SUPPORTED; + +private: + struct impl; + std::unique_ptr pimpl; +}; + +size_t llama_path_max(); diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp new file mode 100644 index 000000000..05d58ad90 --- /dev/null +++ b/src/llama-model-loader.cpp @@ -0,0 +1,1124 @@ +#include "llama-model-loader.h" + +#include "ggml.h" + +#include +#include +#include +#include + +static const size_t kiB = 1024; +static const size_t MiB = 1024*kiB; +static const size_t GiB = 1024*MiB; + +const char * llama_file_version_name(llama_fver version) { + switch (version) { + case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)"; + case GGUF_FILE_VERSION_V2: return "GGUF V2"; + case GGUF_FILE_VERSION_V3: return "GGUF V3 (latest)"; + } + + return "unknown"; +} + +static std::string llama_model_ftype_name(llama_ftype ftype) { + if (ftype & LLAMA_FTYPE_GUESSED) { + return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)"; + } + + switch (ftype) { + case LLAMA_FTYPE_ALL_F32: return "all F32"; + case LLAMA_FTYPE_MOSTLY_F16: return "F16"; + case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; + case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; + case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; + case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; + case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; + case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; + case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large"; + case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; + case LLAMA_FTYPE_MOSTLY_TQ1_0: return "TQ1_0 - 1.69 bpw ternary"; + case LLAMA_FTYPE_MOSTLY_TQ2_0: return "TQ2_0 - 2.06 bpw ternary"; + case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S - 1.5625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; + + default: return "unknown, may not work"; + } +} + +// return a list of splits for a given path +// for example, given "-00002-of-00004.gguf", returns list of all 4 splits +static std::vector llama_get_list_splits(const std::string & path, const int idx, const int n_split) { + std::vector paths; + std::string split_prefix; + std::vector buf(llama_path_max(), 0); + + { + int ret = llama_split_prefix(buf.data(), buf.size(), path.c_str(), idx, n_split); + if (!ret) { + throw std::runtime_error(format("invalid split file name: %s", path.c_str())); + } + split_prefix = std::string(buf.data(), ret); + } + + if (split_prefix.empty()) { + throw std::runtime_error(format("invalid split file: %s", path.c_str())); + } + + for (int idx = 0; idx < n_split; ++idx) { + int ret = llama_split_path(buf.data(), buf.size(), split_prefix.c_str(), idx, n_split); + paths.push_back(std::string(buf.data(), ret)); + } + + return paths; +} + +namespace GGUFMeta { + template + struct GKV_Base_Type { + static constexpr gguf_type gt = gt_; + + static T getter(const gguf_context * ctx, const int kid) { + return gfun(ctx, kid); + } + }; + + template struct GKV_Base; + + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + template<> struct GKV_Base: GKV_Base_Type {}; + + template<> struct GKV_Base { + static constexpr gguf_type gt = GGUF_TYPE_STRING; + + static std::string getter(const gguf_context * ctx, const int kid) { + return gguf_get_val_str(ctx, kid); + } + }; + + struct ArrayInfo { + const gguf_type gt; + const size_t length; + const void * data; + }; + + template<> struct GKV_Base { + public: + static constexpr gguf_type gt = GGUF_TYPE_ARRAY; + static ArrayInfo getter(const gguf_context *ctx, const int k) { + const enum gguf_type arr_type = gguf_get_arr_type(ctx, k); + return ArrayInfo { + arr_type, + size_t(gguf_get_arr_n(ctx, k)), + arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx, k), + }; + } + }; + + template + class GKV : public GKV_Base { + GKV() = delete; + + public: + static T get_kv(const gguf_context * ctx, const int k) { + const enum gguf_type kt = gguf_get_kv_type(ctx, k); + + if (kt != GKV::gt) { + throw std::runtime_error(format("key %s has wrong type %s but expected type %s", + gguf_get_key(ctx, k), gguf_type_name(kt), gguf_type_name(GKV::gt))); + } + return GKV::getter(ctx, k); + } + + static const char * override_type_to_str(const llama_model_kv_override_type ty) { + switch (ty) { + case LLAMA_KV_OVERRIDE_TYPE_BOOL: return "bool"; + case LLAMA_KV_OVERRIDE_TYPE_INT: return "int"; + case LLAMA_KV_OVERRIDE_TYPE_FLOAT: return "float"; + case LLAMA_KV_OVERRIDE_TYPE_STR: return "str"; + } + return "unknown"; + } + + static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override * ovrd) { + if (!ovrd) { return false; } + if (ovrd->tag == expected_type) { + LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ", + __func__, override_type_to_str(ovrd->tag), ovrd->key); + switch (ovrd->tag) { + case LLAMA_KV_OVERRIDE_TYPE_BOOL: { + LLAMA_LOG_INFO("%s\n", ovrd->val_bool ? "true" : "false"); + } break; + case LLAMA_KV_OVERRIDE_TYPE_INT: { + LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->val_i64); + } break; + case LLAMA_KV_OVERRIDE_TYPE_FLOAT: { + LLAMA_LOG_INFO("%.6f\n", ovrd->val_f64); + } break; + case LLAMA_KV_OVERRIDE_TYPE_STR: { + LLAMA_LOG_INFO("%s\n", ovrd->val_str); + } break; + default: + // Shouldn't be possible to end up here, but just in case... + throw std::runtime_error( + format("Unsupported attempt to override %s type for metadata key %s\n", + override_type_to_str(ovrd->tag), ovrd->key)); + } + return true; + } + LLAMA_LOG_WARN("%s: Warning: Bad metadata override type for key '%s', expected %s but got %s\n", + __func__, ovrd->key, override_type_to_str(expected_type), override_type_to_str(ovrd->tag)); + return false; + } + + template + static typename std::enable_if::value, bool>::type + try_override(OT & target, const struct llama_model_kv_override * ovrd) { + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, ovrd)) { + target = ovrd->val_bool; + return true; + } + return false; + } + + template + static typename std::enable_if::value && std::is_integral::value, bool>::type + try_override(OT & target, const struct llama_model_kv_override * ovrd) { + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, ovrd)) { + target = ovrd->val_i64; + return true; + } + return false; + } + + template + static typename std::enable_if::value, bool>::type + try_override(T & target, const struct llama_model_kv_override * ovrd) { + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, ovrd)) { + target = ovrd->val_f64; + return true; + } + return false; + } + + template + static typename std::enable_if::value, bool>::type + try_override(T & target, const struct llama_model_kv_override * ovrd) { + if (validate_override(LLAMA_KV_OVERRIDE_TYPE_STR, ovrd)) { + target = ovrd->val_str; + return true; + } + return false; + } + + static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override * ovrd = nullptr) { + if (try_override(target, ovrd)) { + return true; + } + if (k < 0) { return false; } + target = get_kv(ctx, k); + return true; + } + + static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override * ovrd = nullptr) { + return set(ctx, gguf_find_key(ctx, key), target, ovrd); + } + + static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override * ovrd = nullptr) { + return set(ctx, key.c_str(), target, ovrd); + } + }; +} + + template + typename std::enable_if::value, bool>::type + llama_model_loader::get_arr_n(const std::string & key, T & result, bool required) { + const int kid = gguf_find_key(meta.get(), key.c_str()); + + if (kid < 0) { + if (required) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + return false; + } + + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta.get(), kid); + + + result = arr_info.length; + return true; + } + + template + typename std::enable_if::value, bool>::type + llama_model_loader::get_arr_n(enum llm_kv kid, T & result, bool required) { + return get_arr_n(llm_kv(kid), result, required); + } + + template bool llama_model_loader::get_arr_n(enum llm_kv kid, uint32_t & result, bool required); + + template + bool llama_model_loader::get_arr(const std::string & key, std::vector & result, bool required) { + const int kid = gguf_find_key(meta.get(), key.c_str()); + + if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) { + if (required) { + throw std::runtime_error(format("array key not found in model: %s", key.c_str())); + } + return false; + } + + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta.get(), kid); + + switch (arr_info.gt) { + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; + case GGUF_TYPE_INT32: GGML_ASSERT( + (std::is_same::value) || + (std::is_same::value)); break; + default: + throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); + } + + result.resize(arr_info.length); + result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); + + return true; + } + + template + bool llama_model_loader::get_arr(const std::string & key, std::array & result, bool required) { + const int kid = gguf_find_key(meta.get(), key.c_str()); + + if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) { + if (required) { + throw std::runtime_error(format("array key not found in model: %s", key.c_str())); + } + return false; + } + + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta.get(), kid); + + switch (arr_info.gt) { + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; + case GGUF_TYPE_INT32: GGML_ASSERT( + (std::is_same::value) || + (std::is_same::value)); break; + default: + throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); + } + + if (arr_info.length > N_MAX) { + throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX)); + } + + std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); + + return true; + } + + template + bool llama_model_loader::get_arr(enum llm_kv kid, T & result, bool required) { + return get_arr(llm_kv(kid), result, required); + } + + template + bool llama_model_loader::get_key(const std::string & key, T & result, bool required) { + auto it = kv_overrides.find(key); + + const struct llama_model_kv_override * override = + it != kv_overrides.end() ? &it->second : nullptr; + + const bool found = GGUFMeta::GKV::set(meta.get(), key, result, override); + + if (required && !found) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + + return found; + } + + template + bool llama_model_loader::get_key(enum llm_kv kid, T & result, bool required) { + return get_key(llm_kv(kid), result, required); + } + + template bool llama_model_loader::get_key (enum llm_kv kid, bool & result, bool required); + template bool llama_model_loader::get_key (enum llm_kv kid, float & result, bool required); + template bool llama_model_loader::get_key (enum llm_kv kid, uint32_t & result, bool required); + template bool llama_model_loader::get_key(enum llm_kv kid, std::string & result, bool required); + + template<> + bool llama_model_loader::get_key(enum llm_kv kid, enum llama_pooling_type & result, bool required) { + uint32_t tmp; + const bool found = get_key(kid, tmp, required); + if (found) { + result = (enum llama_pooling_type) tmp; + } else { + result = LLAMA_POOLING_TYPE_UNSPECIFIED; + } + return found; + } + + // get array of n <= N_MAX elements, or a single element repeated n times + template + bool llama_model_loader::get_key_or_arr(const std::string & key, std::array & result, uint32_t n, bool required) { + const int kid = gguf_find_key(meta.get(), key.c_str()); + + if (kid < 0) { + if (required) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + return false; + } + + if (n > N_MAX) { + throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str())); + } + + if (gguf_get_kv_type(meta.get(), kid) == GGUF_TYPE_ARRAY) { + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta.get(), kid); + + if (n != arr_info.length) { + throw std::runtime_error(format("key %s has wrong array length; expected %u, got %u", key.c_str(), n, (uint32_t) arr_info.length)); + } + + return get_arr(key, result, required); + } + + T value; + + bool ok = get_key(key, value, required); + if (!ok) { + return false; + } + + for (uint32_t i = 0; i < n; i++) { + result[i] = value; + } + + return true; + } + + template + bool llama_model_loader::get_key_or_arr(enum llm_kv kid, T & result, uint32_t n, bool required) { + return get_key_or_arr(llm_kv(kid), result, n, required); + } + + // TODO: this is not very clever - figure out something better + template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + +llama_model_loader::llama_model_loader( + const std::string & fname, + std::vector & splits, + bool use_mmap, + bool check_tensors, + const struct llama_model_kv_override * param_overrides_p) { + int trace = 0; + if (getenv("LLAMA_TRACE")) { + trace = atoi(getenv("LLAMA_TRACE")); + } + + if (param_overrides_p != nullptr) { + for (const struct llama_model_kv_override * p = param_overrides_p; p->key[0] != 0; p++) { + kv_overrides.insert({std::string(p->key), *p}); + } + } + + // Load the main GGUF + struct ggml_context * ctx = NULL; + struct gguf_init_params params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx, + }; + + meta.reset(gguf_init_from_file(fname.c_str(), params)); + if (!meta) { + throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str())); + } + + get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); + llm_kv = LLM_KV(llm_arch_from_string(arch_name)); + + files.emplace_back(new llama_file(fname.c_str(), "rb")); + contexts.emplace_back(ctx); + + // Save tensors data offset of the main file. + // For subsidiary files, `meta` tensor data offset must not be used, + // so we build a unified tensors index for weights. + for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + std::string tensor_name = std::string(cur->name); + // make sure there is no duplicated tensor names + if (weights_map.find(tensor_name) != weights_map.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); + } + n_elements += ggml_nelements(cur); + n_bytes += ggml_nbytes(cur); + weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), 0, meta.get(), cur)); + } + uint16_t n_split = 0; + get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false); + + // Load additional GGML contexts + if (n_split > 1) { + // make sure the main file is loaded first + uint16_t idx = 0; + const std::string kv_split_no = llm_kv(LLM_KV_SPLIT_NO); + get_key(kv_split_no, idx); + if (idx != 0) { + throw std::runtime_error(format("illegal split file idx: %d (file: %s), model must be loaded with the first split", idx, fname.c_str())); + } + + // generate list of splits if needed + if (splits.empty()) { + splits = llama_get_list_splits(fname, idx, n_split); + } + + // in case user give a custom list of splits, check if it matches the expected number + if (n_split != (uint16_t)splits.size()) { + throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split)); + } + + if (trace > 0) { + LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split); + } + + // load other splits + for (idx = 1; idx < n_split; idx++) { + const char * fname_split = splits[idx].c_str(); + + struct gguf_init_params split_params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx, + }; + gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) }; + if (!ctx_gguf) { + throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, fname_split)); + } + + // check idx + { + const int kid = gguf_find_key(ctx_gguf.get(), kv_split_no.c_str()); + if (kid < 0) { + throw std::runtime_error(format("missing key %s in GGUF split %s", kv_split_no.c_str(), fname_split)); + } + int idx_gguf = gguf_get_val_u16(ctx_gguf.get(), kid); + if (idx_gguf != idx) { + throw std::runtime_error(format("invalid split file idx: %d (file: %s), expected %d", idx_gguf, fname_split, idx)); + } + } + + files.emplace_back(new llama_file(fname_split, "rb")); + contexts.emplace_back(ctx); + + // Save tensors data offset info of the shard. + for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + std::string tensor_name = std::string(cur->name); + // make sure there is no duplicated tensor names + if (weights_map.find(tensor_name) != weights_map.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); + } + n_elements += ggml_nelements(cur); + n_bytes += ggml_nbytes(cur); + weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), idx, ctx_gguf.get(), cur)); + } + } + + get_key(llm_kv(LLM_KV_SPLIT_TENSORS_COUNT), n_tensors); + + // sanity check + { + const int n_tensors_loaded = (int) weights_map.size(); + if (n_tensors != n_tensors_loaded) { + throw std::runtime_error(format("corrupted model: %d tensors expected but %d found", n_tensors, n_tensors_loaded)); + } + } + + LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n", __func__, n_split - 1); + } + + n_kv = gguf_get_n_kv(meta.get()); + n_tensors = weights_map.size(); + + fver = (enum llama_fver) gguf_get_version(meta.get()); + + LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n", + __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver)); + + // determine file type based on the number of tensors for each quantization and print meta data + // TODO: make optional + { + std::map n_type; + + uint32_t n_type_max = 0; + enum ggml_type type_max = GGML_TYPE_F32; + + for (const auto & it : weights_map) { + const llama_tensor_weight & w = it.second; + const ggml_tensor * tensor = w.tensor; + + enum ggml_type type = tensor->type; + + n_type[type]++; + + if (n_type_max < n_type[type]) { + n_type_max = n_type[type]; + type_max = type; + } + + if (trace > 0) { + const uint16_t sid = w.idx; + LLAMA_LOG_INFO("%s: - tensor split %2d: %32s %-8s [ %s ]\n", __func__, sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str()); + } + } + + switch (type_max) { + case GGML_TYPE_F32: ftype = LLAMA_FTYPE_ALL_F32; break; + case GGML_TYPE_F16: ftype = LLAMA_FTYPE_MOSTLY_F16; break; + case GGML_TYPE_BF16: ftype = LLAMA_FTYPE_MOSTLY_BF16; break; + case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break; + case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break; + case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break; + case GGML_TYPE_Q5_1: ftype = LLAMA_FTYPE_MOSTLY_Q5_1; break; + case GGML_TYPE_Q8_0: ftype = LLAMA_FTYPE_MOSTLY_Q8_0; break; + case GGML_TYPE_Q2_K: ftype = LLAMA_FTYPE_MOSTLY_Q2_K; break; + case GGML_TYPE_Q3_K: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M; break; + case GGML_TYPE_Q4_K: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M; break; + case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break; + case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break; + case GGML_TYPE_TQ1_0: ftype = LLAMA_FTYPE_MOSTLY_TQ1_0; break; + case GGML_TYPE_TQ2_0: ftype = LLAMA_FTYPE_MOSTLY_TQ2_0; break; + case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break; + case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break; + case GGML_TYPE_IQ2_S: ftype = LLAMA_FTYPE_MOSTLY_IQ2_S; break; + case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break; + case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break; + case GGML_TYPE_IQ1_M: ftype = LLAMA_FTYPE_MOSTLY_IQ1_M; break; + case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; + case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; + case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; + default: + { + LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); + ftype = LLAMA_FTYPE_ALL_F32; + } break; + } + + // this is a way to mark that we have "guessed" the file type + ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED); + + { + const int kid = gguf_find_key(meta.get(), "general.file_type"); // TODO: use LLM_KV + if (kid >= 0) { + ftype = (llama_ftype) gguf_get_val_u32(meta.get(), kid); + } + } + + LLAMA_LOG_INFO("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__); + + for (int i = 0; i < n_kv; i++) { + const char * name = gguf_get_key(meta.get(), i); + const enum gguf_type type = gguf_get_kv_type(meta.get(), i); + const std::string type_name = + type == GGUF_TYPE_ARRAY + ? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta.get(), i)), gguf_get_arr_n(meta.get(), i)) + : gguf_type_name(type); + + std::string value = gguf_kv_to_str(meta.get(), i); + const size_t MAX_VALUE_LEN = 40; + if (value.size() > MAX_VALUE_LEN) { + value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()); + } + replace_all(value, "\n", "\\n"); + + LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str()); + } + + // print type counts + for (auto & kv : n_type) { + if (kv.second == 0) { + continue; + } + + LLAMA_LOG_INFO("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second); + } + } + + if (!llama_mmap::SUPPORTED) { + LLAMA_LOG_WARN("%s: mmap is not supported on this platform\n", __func__); + use_mmap = false; + } + + this->use_mmap = use_mmap; + this->check_tensors = check_tensors; +} + +std::string llama_model_loader::get_arch_name() const { + return arch_name; +} + +enum llm_arch llama_model_loader::get_arch() const { + return llm_kv.arch; +} + +const llama_model_loader::llama_tensor_weight * llama_model_loader::get_weight(const char * name) const { + auto pos = weights_map.find(name); + if (pos != weights_map.end()) { + return &pos->second; + } + + return nullptr; +} + +const llama_model_loader::llama_tensor_weight & llama_model_loader::require_weight(const char * name) const { + const llama_tensor_weight * weight = get_weight(name); + if (!weight) { + throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name)); + } + return *weight; +} + +struct ggml_tensor * llama_model_loader::get_tensor_meta(const char * name) const { + const auto * weight = get_weight(name); + if (!weight) { + return nullptr; + } + return weight->tensor; +} + +struct ggml_tensor * llama_model_loader::require_tensor_meta(const std::string & name) const { + struct ggml_tensor * tensor = get_tensor_meta(name.c_str()); + if (!tensor) { + throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str())); + } + return tensor; +} + +const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::string & name, const std::vector & ne, bool required) const { + const struct ggml_tensor * cur = get_tensor_meta(name.c_str()); + + if (cur == NULL) { + if (!required) { + return NULL; + } + throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str())); + } + + { + bool is_ok = true; + for (size_t i = 0; i < GGML_MAX_DIMS; ++i) { + if ((i < ne.size() && ne[i] != cur->ne[i]) || (i >= ne.size() && cur->ne[i] != 1)) { + is_ok = false; + break; + } + } + if (!is_ok) { + throw std::runtime_error( + format("%s: tensor '%s' has wrong shape; expected %s, got %s", + __func__, name.c_str(), + llama_format_tensor_shape(ne).c_str(), + llama_format_tensor_shape(cur).c_str())); + } + } + + return cur; +} + +struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list & ne, int flags) { + const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED)); + + if (cur == NULL) { + return NULL; + } + + bool duplicated = flags & TENSOR_DUPLICATED; + + struct ggml_tensor * tensor = ggml_dup_tensor(ctx, cur); + ggml_set_name(tensor, ggml_get_name(cur)); + + if (duplicated) { + size_data += ggml_nbytes(cur); + } else { + n_created++; + } + + return tensor; + +} + +struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required) { + const struct ggml_tensor * cur = check_tensor_dims(name, ne, required); + + if (cur == NULL) { + return NULL; + } + + if (cur->type != base->type) { + throw std::runtime_error(format("%s: tensor '%s' has wrong type; expected %s, got %s", __func__, name.c_str(), ggml_type_name(base->type), ggml_type_name(cur->type))); + } + + std::array dims; + for (size_t i = 0; i < GGML_MAX_DIMS; ++i) { + dims[i] = i < ne.size() ? ne.begin()[i] : 1; + } + + struct ggml_tensor * tensor = ggml_view_4d(ctx, base, + dims[0], dims[1], dims[2], dims[3], + cur->nb[1], cur->nb[2], cur->nb[3], + offset); + + ggml_set_name(tensor, name.c_str()); + + n_created++; + + return tensor; +} + +void llama_model_loader::done_getting_tensors() const { + if (n_created != n_tensors) { + throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); + } +} + +void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps) { + if (use_mmap) { + mappings.reserve(files.size()); + mmaps_used.reserve(files.size()); + for (const auto & file : files) { + auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU)); + auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa"); + std::unique_ptr mapping = std::make_unique(file.get(), prefetch ? -1 : 0, is_numa_fn()); + mmaps_used.emplace_back(mapping->size(), 0); + if (mlock_mmaps) { + std::unique_ptr mlock_mmap(new llama_mlock()); + mlock_mmap->init(mapping->addr()); + mlock_mmaps->emplace_back(std::move(mlock_mmap)); + } + mappings.emplace_back(std::move(mapping)); + } + } + + // compute the total size of all tensors for progress reporting + for (const auto & it : weights_map) { + size_data += ggml_nbytes(it.second.tensor); + } +} + +void llama_model_loader::get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const { + GGML_ASSERT(!mappings.empty()); + const auto & mapping = mappings.at(idx); + + *first = mapping->size(); + *last = 0; + *addr = mapping->addr(); + for (ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor; tensor = ggml_get_next_tensor(ctx, tensor)) { + const auto * weight = get_weight(ggml_get_name(tensor)); + if (!weight || weight->idx != idx) { + continue; + } + *first = std::min(*first, weight->offs); + *last = std::max(*last, weight->offs + ggml_nbytes(tensor)); + } +} + +void llama_model_loader::load_data_for(struct ggml_tensor * cur) const { + const auto & w = require_weight(ggml_get_name(cur)); + + if (use_mmap) { + const auto & mapping = mappings.at(w.idx); + if (cur->data == nullptr) { + cur->data = (uint8_t *)mapping->addr() + w.offs; + } else { + memcpy(cur->data, (uint8_t *)mapping->addr() + w.offs, ggml_nbytes(cur)); + } + } else { + GGML_ASSERT(cur->data != nullptr); + GGML_ASSERT(w.idx < files.size()); + const auto & file = files.at(w.idx); + file->seek(w.offs, SEEK_SET); + file->read_raw(cur->data, ggml_nbytes(cur)); + } + + if (check_tensors && !ggml_validate_row_data(cur->type, cur->data, ggml_nbytes(cur))) { + throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur))); + } +} + +bool llama_model_loader::load_all_data( + struct ggml_context * ctx, + llama_buf_map & bufs, + llama_mlocks * lmlocks, + llama_progress_callback progress_callback, + void * progress_callback_user_data) { + GGML_ASSERT(size_data != 0 && "call init_mappings() first"); + + std::vector> read_buf; + std::vector>> validation_result; + + // 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives. + // NVMe raid configurations might require more / larger buffers. + constexpr size_t n_buffers = 4; + constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB + + std::vector host_buffers; + std::vector events; + std::vector host_ptrs; + size_t buffer_idx = 0; // buffer to use for async loads + ggml_backend_t upload_backend = [&](const char * func) -> ggml_backend_t { + if (use_mmap || check_tensors) { + return nullptr; + } + // When not using mmaped io use async uploads from pinned memory to GPU memory. + // First determine if the backend supports the necessary features for async uploads. + auto * buf = bufs.count(0) ? bufs.at(0) : nullptr; + if (!buf) { + LLAMA_LOG_DEBUG("%s: no buffer found for async uploads\n", func); + return nullptr; + } + + auto * buft = ggml_backend_buffer_get_type(buf); + auto * dev = ggml_backend_buft_get_device(buft); + if (!dev) { + LLAMA_LOG_DEBUG("%s: no device found for buffer type %s for async uploads\n", func, + ggml_backend_buft_name(buft)); + return nullptr; + } + + if (buft != ggml_backend_dev_buffer_type(dev)) { + LLAMA_LOG_DEBUG("%s: buffer type %s is not the default buffer type for device %s for async uploads\n", func, + ggml_backend_buft_name(buft), ggml_backend_dev_name(dev)); + return nullptr; + } + + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); + if (!props.caps.async || !props.caps.host_buffer || !props.caps.events) { + LLAMA_LOG_DEBUG("%s: device %s does not support async, host buffers or events\n", func, + ggml_backend_dev_name(dev)); + return nullptr; + } + + auto * host_buft = ggml_backend_dev_host_buffer_type(dev); + if (!host_buft) { + LLAMA_LOG_DEBUG("%s: no host buffer type found for device %s\n", func, + ggml_backend_dev_name(dev)); + return nullptr; + } + + // If the backend is supported, create pinned memory buffers and events for synchronisation. + for (size_t idx = 0; idx < n_buffers; ++idx) { + auto * buf = ggml_backend_buft_alloc_buffer(host_buft, buffer_size); + if (!buf) { + LLAMA_LOG_DEBUG("%s: failed to allocate host buffer for async uploads for device %s\n", func, + ggml_backend_dev_name(dev)); + return nullptr; + } + + host_buffers.emplace_back(buf); + host_ptrs.emplace_back(ggml_backend_buffer_get_base(buf)); + + auto * event = ggml_backend_event_new(dev); + if (!event) { + LLAMA_LOG_DEBUG("%s: failed to create event for async uploads for device %s\n", func, + ggml_backend_dev_name(dev)); + return nullptr; + } + + events.emplace_back(event); + } + + ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); + if (!backend) { + LLAMA_LOG_DEBUG("%s: failed to initialize backend for device %s for async uploads\n", func, + ggml_backend_dev_name(dev)); + return nullptr; + } + + return backend; + }(__func__); + + if (upload_backend) { + LLAMA_LOG_DEBUG("%s: using async uploads for device %s, buffer type %s, backend %s\n", __func__, + ggml_backend_dev_name(ggml_backend_get_device(upload_backend)), + ggml_backend_buft_name(ggml_backend_buffer_get_type(bufs.at(0))), + ggml_backend_name(upload_backend)); + } + + for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) { + const auto * weight = get_weight(ggml_get_name(cur)); + if (weight == nullptr) { + // this can happen with split experts models + continue; + } + + if (progress_callback) { + if (!progress_callback((float) size_done / size_data, progress_callback_user_data)) { + return false; + } + } + + size_t n_size = ggml_nbytes(cur); + + if (use_mmap) { + const auto & mapping = mappings.at(weight->idx); + ggml_backend_buffer_t buf_mmap = nullptr; + if (bufs.count(weight->idx)) { + buf_mmap = bufs.at(weight->idx); + } + uint8_t * data = (uint8_t *) mapping->addr() + weight->offs; + + if (check_tensors) { + validation_result.emplace_back(std::async(std::launch::async, [cur, data, n_size] { + return std::make_pair(cur, ggml_validate_row_data(cur->type, data, n_size)); + })); + } + + GGML_ASSERT(buf_mmap || cur->data); // either we have a buffer to allocate the tensor in, or it is already allocated + if (buf_mmap && cur->data == nullptr) { + ggml_backend_tensor_alloc(buf_mmap, cur, data); + if (lmlocks) { + const auto & lmlock = lmlocks->at(weight->idx); + lmlock->grow_to(weight->offs + n_size); + } + + auto & mmap_used = mmaps_used[weight->idx]; + mmap_used.first = std::min(mmap_used.first, weight->offs); + mmap_used.second = std::max(mmap_used.second, weight->offs + n_size); + } else { + ggml_backend_tensor_set(cur, data, 0, n_size); + } + } else { + const auto & file = files.at(weight->idx); + if (ggml_backend_buffer_is_host(cur->buffer)) { + file->seek(weight->offs, SEEK_SET); + file->read_raw(cur->data, n_size); + if (check_tensors) { + validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] { + return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size)); + })); + } + } else { + // If upload_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU. + if (upload_backend) { + file->seek(weight->offs, SEEK_SET); + + size_t bytes_read = 0; + + while (bytes_read < n_size) { + size_t read_iteration = std::min(buffer_size, n_size - bytes_read); + + ggml_backend_event_synchronize(events[buffer_idx]); + file->read_raw(host_ptrs[buffer_idx], read_iteration); + ggml_backend_tensor_set_async(upload_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration); + ggml_backend_event_record(events[buffer_idx], upload_backend); + + bytes_read += read_iteration; + ++buffer_idx; + buffer_idx %= n_buffers; + } + } else { + read_buf.resize(n_size); + file->seek(weight->offs, SEEK_SET); + file->read_raw(read_buf.data(), n_size); + ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size); + if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) { + throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur))); + } + } + } + } + + size_done += n_size; + } + + // free temporary resources used for async uploads + for (auto * event : events) { + ggml_backend_event_synchronize(event); + ggml_backend_event_free(event); + } + for (auto * buf : host_buffers) { + ggml_backend_buffer_free(buf); + } + ggml_backend_free(upload_backend); + + // check validation results + bool validation_failed = false; + for (auto & future : validation_result) { + auto result = future.get(); + if (!result.second) { + LLAMA_LOG_ERROR("%s: tensor '%s' has invalid data\n", __func__, ggml_get_name(result.first)); + validation_failed = true; + } + } + if (validation_failed) { + throw std::runtime_error("found tensors with invalid data"); + } + + // check if this is the last call and do final cleanup + if (size_done >= size_data) { + // unmap offloaded tensors and metadata + if (use_mmap) { + for (uint32_t idx = 0; idx < mappings.size(); idx++) { + const auto & mmap_used = mmaps_used.at(idx); + auto & mapping = mappings.at(idx); + mapping->unmap_fragment(0, mmap_used.first); + if (mmap_used.second != 0) { + mapping->unmap_fragment(mmap_used.second, mapping->size()); + } + } + } + if (progress_callback) { + // Even though the model is done loading, we still honor + // cancellation since we need to free allocations. + return progress_callback(1.0f, progress_callback_user_data); + } + } + + return true; +} + +std::string llama_model_loader::ftype_name() const { + return llama_model_ftype_name(ftype); +} + +void llama_model_loader::print_info() const { + LLAMA_LOG_INFO("%s: file format = %s\n", __func__, llama_file_version_name(fver)); + LLAMA_LOG_INFO("%s: file type = %s\n", __func__, llama_model_ftype_name(ftype).c_str()); + if (n_bytes < GiB) { + LLAMA_LOG_INFO("%s: file size = %.2f MiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0, n_bytes*8.0/n_elements); + } else { + LLAMA_LOG_INFO("%s: file size = %.2f GiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0/1024.0, n_bytes*8.0/n_elements); + } +} diff --git a/src/llama-model-loader.h b/src/llama-model-loader.h new file mode 100644 index 000000000..fe35404b2 --- /dev/null +++ b/src/llama-model-loader.h @@ -0,0 +1,167 @@ +#pragma once + +#include "llama.h" + +#include "llama-impl.h" +#include "llama-arch.h" +#include "llama-mmap.h" + +#include "ggml-cpp.h" + +#include +#include +#include +#include + +using llama_buf_map = std::unordered_map; + +enum llama_fver { + GGUF_FILE_VERSION_V1 = 1, + GGUF_FILE_VERSION_V2 = 2, + GGUF_FILE_VERSION_V3 = 3, +}; + +const char * llama_file_version_name(llama_fver version); + +struct llama_model_loader { + // Holds information on a model weight + struct llama_tensor_weight { + uint16_t idx; // source file index + size_t offs; // tensor data offset in the original file + + ggml_tensor * tensor; + + llama_tensor_weight(const llama_file * file, uint16_t idx, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) { + const int tensor_idx = gguf_find_tensor(gguf_ctx, ggml_get_name(tensor)); + if (tensor_idx < 0) { + throw std::runtime_error(format("tensor '%s' not found in the model", ggml_get_name(tensor))); + } + + offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx); + if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size()) { + throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", ggml_get_name(tensor))); + } + } + }; + + // custom comparator to sort weights more nicely by layer + struct weight_name_comparer { + bool operator()(const std::string & a, const std::string & b) const { + int a_layer = -1; + int b_layer = -1; + sscanf(a.c_str(), "blk.%d.", &a_layer); + sscanf(b.c_str(), "blk.%d.", &b_layer); + if (a_layer != b_layer) { + return a_layer < b_layer; + } + return a < b; + } + }; + + static const int TENSOR_NOT_REQUIRED = 1; + static const int TENSOR_DUPLICATED = 2; + + int n_kv = 0; + int n_tensors = 0; + int n_created = 0; + + uint64_t n_elements = 0; + size_t n_bytes = 0; + + bool use_mmap = false; + bool check_tensors; + + llama_files files; + llama_ftype ftype; + llama_fver fver; + + llama_mmaps mappings; + + std::map weights_map; + std::unordered_map kv_overrides; + + gguf_context_ptr meta; + std::vector contexts; + + std::string arch_name; + LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN); + + size_t size_done = 0; + size_t size_data = 0; + std::vector> mmaps_used; + + llama_model_loader( + const std::string & fname, + std::vector & splits, // optional, only need if the split does not follow naming scheme + bool use_mmap, + bool check_tensors, + const struct llama_model_kv_override * param_overrides_p); + + template + typename std::enable_if::value, bool>::type + get_arr_n(const std::string & key, T & result, bool required = true); + + template + typename std::enable_if::value, bool>::type + get_arr_n(enum llm_kv kid, T & result, bool required = true); + + template + bool get_arr(const std::string & key, std::vector & result, bool required = true); + + template + bool get_arr(const std::string & key, std::array & result, bool required = true); + + template + bool get_arr(enum llm_kv kid, T & result, bool required = true); + + template + bool get_key(const std::string & key, T & result, bool required = true); + + template + bool get_key(enum llm_kv kid, T & result, bool required = true); + + template + bool get_key_or_arr(const std::string & key, std::array & result, uint32_t n, bool required = true); + + template + bool get_key_or_arr(enum llm_kv kid, T & result, uint32_t n, bool required = true); + + std::string get_arch_name() const; + + enum llm_arch get_arch() const; + + const llama_tensor_weight * get_weight(const char * name) const; + + const llama_tensor_weight & require_weight(const char * name) const; + + struct ggml_tensor * get_tensor_meta(const char * name) const; + + struct ggml_tensor * require_tensor_meta(const std::string & name) const; + + const struct ggml_tensor * check_tensor_dims(const std::string & name, const std::vector & ne, bool required) const; + + struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list & ne, int flags = 0); + + struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required = true); + + void done_getting_tensors() const; + + void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr); + + void get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const; + + // for backwards compatibility, does not support ggml-backend + void load_data_for(struct ggml_tensor * cur) const; + + // Returns false if cancelled by progress_callback + bool load_all_data( + struct ggml_context * ctx, + llama_buf_map & bufs, + llama_mlocks * lmlocks, + llama_progress_callback progress_callback, + void * progress_callback_user_data); + + std::string ftype_name() const; + + void print_info() const; +}; diff --git a/src/llama-model.cpp b/src/llama-model.cpp new file mode 100644 index 000000000..18bd0b071 --- /dev/null +++ b/src/llama-model.cpp @@ -0,0 +1,4001 @@ +#include "llama-model.h" + +#include "llama-impl.h" +#include "llama-mmap.h" +#include "llama-model-loader.h" + +#include "ggml-cpp.h" + +#include +#include +#include +#include +#include +#include +#include + +const char * llm_type_name(llm_type type) { + switch (type) { + case LLM_TYPE_14M: return "14M"; + case LLM_TYPE_17M: return "17M"; + case LLM_TYPE_22M: return "22M"; + case LLM_TYPE_33M: return "33M"; + case LLM_TYPE_60M: return "60M"; + case LLM_TYPE_70M: return "70M"; + case LLM_TYPE_80M: return "80M"; + case LLM_TYPE_109M: return "109M"; + case LLM_TYPE_137M: return "137M"; + case LLM_TYPE_160M: return "160M"; + case LLM_TYPE_220M: return "220M"; + case LLM_TYPE_250M: return "250M"; + case LLM_TYPE_270M: return "270M"; + case LLM_TYPE_335M: return "335M"; + case LLM_TYPE_410M: return "410M"; + case LLM_TYPE_450M: return "450M"; + case LLM_TYPE_770M: return "770M"; + case LLM_TYPE_780M: return "780M"; + case LLM_TYPE_0_5B: return "0.5B"; + case LLM_TYPE_1B: return "1B"; + case LLM_TYPE_1_3B: return "1.3B"; + case LLM_TYPE_1_4B: return "1.4B"; + case LLM_TYPE_1_5B: return "1.5B"; + case LLM_TYPE_1_6B: return "1.6B"; + case LLM_TYPE_2B: return "2B"; + case LLM_TYPE_2_8B: return "2.8B"; + case LLM_TYPE_3B: return "3B"; + case LLM_TYPE_4B: return "4B"; + case LLM_TYPE_6B: return "6B"; + case LLM_TYPE_6_9B: return "6.9B"; + case LLM_TYPE_7B: return "7B"; + case LLM_TYPE_8B: return "8B"; + case LLM_TYPE_9B: return "9B"; + case LLM_TYPE_11B: return "11B"; + case LLM_TYPE_12B: return "12B"; + case LLM_TYPE_13B: return "13B"; + case LLM_TYPE_14B: return "14B"; + case LLM_TYPE_15B: return "15B"; + case LLM_TYPE_16B: return "16B"; + case LLM_TYPE_20B: return "20B"; + case LLM_TYPE_30B: return "30B"; + case LLM_TYPE_32B: return "32B"; + case LLM_TYPE_34B: return "34B"; + case LLM_TYPE_35B: return "35B"; + case LLM_TYPE_40B: return "40B"; + case LLM_TYPE_65B: return "65B"; + case LLM_TYPE_70B: return "70B"; + case LLM_TYPE_236B: return "236B"; + case LLM_TYPE_314B: return "314B"; + case LLM_TYPE_671B: return "671B"; + case LLM_TYPE_SMALL: return "0.1B"; + case LLM_TYPE_MEDIUM: return "0.4B"; + case LLM_TYPE_LARGE: return "0.8B"; + case LLM_TYPE_XL: return "1.5B"; + case LLM_TYPE_A1_7B: return "A1.7B"; + case LLM_TYPE_A2_7B: return "A2.7B"; + case LLM_TYPE_8x7B: return "8x7B"; + case LLM_TYPE_8x22B: return "8x22B"; + case LLM_TYPE_16x12B: return "16x12B"; + case LLM_TYPE_16x3_8B: return "16x3.8B"; + case LLM_TYPE_10B_128x3_66B: return "10B+128x3.66B"; + case LLM_TYPE_57B_A14B: return "57B.A14B"; + case LLM_TYPE_27B: return "27B"; + default: return "?B"; + } +} + +static const char * llama_expert_gating_func_name(llama_expert_gating_func_type type) { + switch (type) { + case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX: return "softmax"; + case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID: return "sigmoid"; + default: return "unknown"; + } +} + +static const std::map LLAMA_ROPE_SCALING_TYPES = { + { LLAMA_ROPE_SCALING_TYPE_NONE, "none" }, + { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" }, + { LLAMA_ROPE_SCALING_TYPE_YARN, "yarn" }, + { LLAMA_ROPE_SCALING_TYPE_LONGROPE, "longrope" }, +}; + +static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) { + for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) { + if (kv.second == name) { + return (llama_rope_scaling_type) kv.first; + } + } + + return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; +} + +// checks if the weight tensor can be used with the specified buffer type and device +static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { + GGML_ASSERT(w != nullptr); + + if (op == GGML_OP_NONE) { + return true; + } + + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead()*8, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error(format("failed to create ggml context")); + } + ggml_context * ctx = ctx_ptr.get(); + + ggml_tensor * op_tensor = nullptr; + + switch (op) { + case GGML_OP_GET_ROWS: + { + ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); + op_tensor = ggml_get_rows(ctx, w, b); + } break; + case GGML_OP_MUL_MAT: + { + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]); + op_tensor = ggml_mul_mat(ctx, w, b); + } break; + case GGML_OP_MUL_MAT_ID: + { + int n_expert_used = hparams.n_expert_used; + ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); + ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); + op_tensor = ggml_mul_mat_id(ctx, w, b, ids); + } break; + case GGML_OP_ADD: + { + ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_add(ctx, a, w); + } break; + case GGML_OP_MUL: + { + ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_mul(ctx, a, w); + } break; + case GGML_OP_DIV: + { + ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, w->ne[0]); + op_tensor = ggml_div(ctx, a, w); + } break; + case GGML_OP_ROPE: + { + int n_embd_head = hparams.n_embd_head_v; + int n_head = hparams.n_head(); + ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512); + ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); + op_tensor = ggml_rope_ext( + ctx, a, b, w, + 0, 0, 0, 0, 0, + 0, 0, 0, 0 + ); + + } break; + case GGML_OP_SSM_CONV: + { + // FIXME + ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789); + op_tensor = ggml_ssm_conv(ctx, conv_x, w); + } break; + case GGML_OP_SSM_SCAN: + { + // FIXME + const int64_t d_state = w->ne[0]; + const int64_t d_inner = w->ne[1]; + const int64_t n_seq_tokens = 512; + const int64_t n_seqs = 1; + ggml_tensor * s = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs); + ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); + ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); + op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C); + } break; + case GGML_OP_RWKV_WKV6: + { + // FIXME + const int64_t S = 123; + const int64_t H = 123; + const int64_t n_tokens = 123; + const int64_t n_seqs = 123; + ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * tf = w; + ggml_tensor * td = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H); + op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state); + } break; + case GGML_OP_IM2COL: + { + const int n_embd = hparams.n_embd; + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1); + op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16); + } break; + default: + GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); + } + + // create a temporary dummy buffer for the weight so that supports_op can check the buffer type + GGML_ASSERT(w->buffer == nullptr); + w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); + bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + ggml_backend_buffer_free(w->buffer); + w->buffer = nullptr; + + return op_supported; +} + +// lists of buffer types used for each layer +using buft_list_t = std::vector>; + +// find the first buffer type in the list that can use the tensor +static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hparams, ggml_tensor * tensor, ggml_op op, const buft_list_t & buft_list) { + GGML_ASSERT(!buft_list.empty()); + for (const auto & cur : buft_list) { + ggml_backend_dev_t cur_dev = cur.first; + ggml_backend_buffer_type_t cur_buft = cur.second; + if (weight_buft_supported(hparams, tensor, op, cur_buft, cur_dev)) { + return cur_buft; + } + } + return nullptr; +} + +// CPU: ACCEL -> CPU extra -> GPU host -> CPU +static buft_list_t make_cpu_buft_list(const std::vector & devices) { + buft_list_t buft_list; + + // add ACCEL buffer types + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { + auto * buft = ggml_backend_dev_buffer_type(dev); + // skip + if (buft != ggml_backend_cpu_buffer_type()) { + buft_list.emplace_back(dev, buft); + } + } + } + + // add extra buffer types + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(cpu_dev, *extra_bufts); + ++extra_bufts; + } + } + + // add a host buffer type + // storing the tensors in a host buffer is useful when the processing of large batches + // is offloaded to a GPU device, since it reduces the time spent on data transfers + // generally, this will be done using the first device in the list + // a better approach would be to handle this on a weight-by-weight basis using the offload_op + // function of the device to determine if it would benefit from being stored in a host buffer + for (auto * dev : devices) { + ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); + if (buft) { + buft_list.emplace_back(dev, buft); + break; + } + } + + // add the CPU buffer type + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) { + buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev)); + } + } + + return buft_list; +} + +// GPU: split if LLAMA_SPLIT_MODE_ROW -> GPU +static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, enum llama_split_mode split_mode, const float * tensor_split) { + buft_list_t buft_list; + + // add the device split buffer type if requested and available + if (split_mode == LLAMA_SPLIT_MODE_ROW) { + ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); + auto ggml_backend_split_buffer_type_fn = (ggml_backend_split_buffer_type_t) + ggml_backend_reg_get_proc_address(reg, "ggml_backend_split_buffer_type"); + if (ggml_backend_split_buffer_type_fn) { + size_t dev_index = [&]() { + auto * reg = ggml_backend_dev_backend_reg(dev); + for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); ++i) { + if (ggml_backend_reg_dev_get(reg, i) == dev) { + return i; + } + } + throw std::runtime_error(format("device %s not found in its backend reg", ggml_backend_dev_name(dev))); + }(); + auto * buft = ggml_backend_split_buffer_type_fn(dev_index, tensor_split); + if (buft != nullptr) { + buft_list.emplace_back(dev, buft); + } + } + } + + // add the device default buffer type + buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev)); + + return buft_list; +} + +struct llama_model::impl { + impl() {} + ~impl() {} + + uint64_t n_elements = 0; + + size_t n_bytes = 0; + + std::string desc_str; + + // model memory mapped files + llama_mmaps mappings; + + // objects representing data potentially being locked in memory + llama_mlocks mlock_bufs; + llama_mlocks mlock_mmaps; + + // contexts where the model tensors metadata is stored + std::vector ctxs; + + // the model memory buffers for the tensor data + std::vector bufs; + + buft_list_t cpu_buft_list; + std::map gpu_buft_list; + + struct layer_dev { + ggml_backend_dev_t dev; + buft_list_t * buft_list; + }; + + layer_dev dev_input = {}; + layer_dev dev_output = {}; + std::vector dev_layer; +}; + +llama_model::llama_model(const struct llama_model_params & params) : params(params), pimpl(std::make_unique()) { +} + +llama_model::~llama_model() {} + +void llama_model::load_stats(llama_model_loader & ml) { + pimpl->n_elements = ml.n_elements; + pimpl->n_bytes = ml.n_bytes; +} + +void llama_model::load_arch(llama_model_loader & ml) { + arch = ml.get_arch(); + if (arch == LLM_ARCH_UNKNOWN) { + throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); + } +} + +void llama_model::load_hparams(llama_model_loader & ml) { + const gguf_context * ctx = ml.meta.get(); + + // get metadata as string + for (int i = 0; i < gguf_get_n_kv(ctx); i++) { + enum gguf_type type = gguf_get_kv_type(ctx, i); + if (type == GGUF_TYPE_ARRAY) { + continue; + } + const char * name = gguf_get_key(ctx, i); + const std::string value = gguf_kv_to_str(ctx, i); + gguf_kv.emplace(name, value); + } + + // get general kv + ml.get_key(LLM_KV_GENERAL_NAME, name, false); + + // everything past this point is not vocab-related + if (hparams.vocab_only) { + return; + } + + ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); + ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); + ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); + + if (arch == LLM_ARCH_WAVTOKENIZER_DEC) { + ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features); + + ml.get_key(LLM_KV_POSNET_EMBEDDING_LENGTH, hparams.posnet.n_embd); + ml.get_key(LLM_KV_POSNET_BLOCK_COUNT, hparams.posnet.n_layer); + + ml.get_key(LLM_KV_CONVNEXT_EMBEDDING_LENGTH, hparams.convnext.n_embd); + ml.get_key(LLM_KV_CONVNEXT_BLOCK_COUNT, hparams.convnext.n_layer); + } + + GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS); + GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert); + if (hparams.n_expert > 0) { + GGML_ASSERT(hparams.n_expert_used > 0); + } else { + GGML_ASSERT(hparams.n_expert_used == 0); + } + + // zero-out the array hparams + std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); + std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); + std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); + + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); + + // n_head_kv is optional, default to n_head + hparams.n_head_kv_arr = hparams.n_head_arr; + + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, hparams.n_layer, false); + + bool rope_finetuned = false; + ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); + hparams.rope_finetuned = rope_finetuned; + + hparams.n_ctx_orig_yarn = hparams.n_ctx_train; + ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn, false); + + // rope_freq_base (optional) + hparams.rope_freq_base_train = 10000.0f; + ml.get_key(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train, false); + + std::string rope_scaling("linear"); + ml.get_key(LLM_KV_ROPE_SCALING_TYPE, rope_scaling, false); + hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling); + GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED); + + // rope_freq_scale (inverse of the kv) is optional + float ropescale = 0.0f; + if (!ml.get_key(LLM_KV_ROPE_SCALING_FACTOR, ropescale, false)) { + // try the old key name + ml.get_key(LLM_KV_ROPE_SCALE_LINEAR, ropescale, false); + } + hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale; + + ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); + + // non-transformer models do not have attention heads + if (hparams.n_head() > 0) { + // gpt-neox n_rot = rotary_pct * (n_embd / n_head) + // gpt-j n_rot = rotary_dim + + hparams.n_embd_head_k = hparams.n_embd / hparams.n_head(); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); + + hparams.n_embd_head_v = hparams.n_embd / hparams.n_head(); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); + + // sanity check for n_rot (optional) + hparams.n_rot = hparams.n_embd_head_k; + + ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); + + if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) { + if (hparams.n_rot != hparams.n_embd_head_k) { + throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); + } + } + } else { + hparams.n_rot = 0; + hparams.n_embd_head_k = 0; + hparams.n_embd_head_v = 0; + } + + // for differentiating model types + uint32_t n_vocab = 0; + ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); + + // arch-specific KVs + switch (arch) { + case LLM_ARCH_LLAMA: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (hparams.n_expert == 8) { + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8x7B; break; + case 56: type = LLM_TYPE_8x22B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } else { + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_1B; break; // Llama 3.2 1B + case 22: type = LLM_TYPE_1B; break; + case 26: type = LLM_TYPE_3B; break; + case 28: type = LLM_TYPE_3B; break; // Llama 3.2 3B + // granite uses a vocab with len 49152 + case 32: type = n_vocab == 49152 ? LLM_TYPE_3B : (n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B); break; + case 36: type = LLM_TYPE_8B; break; // granite + case 40: type = LLM_TYPE_13B; break; + case 48: type = LLM_TYPE_34B; break; + case 60: type = LLM_TYPE_30B; break; + case 80: type = hparams.n_head() == hparams.n_head_kv() ? LLM_TYPE_65B : LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } + } break; + case LLM_ARCH_DECI: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 80: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_MINICPM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + + switch (hparams.n_layer) { + case 52: type = LLM_TYPE_1B; break; + case 40: type = LLM_TYPE_2B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_MINICPM3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + + switch (hparams.n_layer) { + case 62: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_GROK: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 64: type = LLM_TYPE_314B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_FALCON: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 60: type = LLM_TYPE_40B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_BAICHUAN: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + if (type == LLM_TYPE_13B) { + // TODO: become GGUF KV parameter + hparams.f_max_alibi_bias = 8.0f; + } + } break; + case LLM_ARCH_STARCODER: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1B; break; + case 36: type = LLM_TYPE_3B; break; + case 42: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_15B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_REFACT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + // TODO: become GGUF KV parameter + hparams.f_max_alibi_bias = 8.0f; + } break; + case LLM_ARCH_BERT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); + + switch (hparams.n_layer) { + case 3: + type = LLM_TYPE_17M; break; // bge-micro + case 6: + type = LLM_TYPE_22M; break; // MiniLM-L6 + case 12: + switch (hparams.n_embd) { + case 384: type = LLM_TYPE_33M; break; // MiniLM-L12, bge-small + case 768: type = LLM_TYPE_109M; break; // bge-base + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + type = LLM_TYPE_335M; break; // bge-large + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_JINA_BERT_V2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); + hparams.f_max_alibi_bias = 8.0f; + + switch (hparams.n_layer) { + case 4: type = LLM_TYPE_33M; break; // jina-embeddings-small + case 12: type = LLM_TYPE_137M; break; // jina-embeddings-base + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_NOMIC_BERT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + + if (hparams.n_layer == 12 && hparams.n_embd == 768) { + type = LLM_TYPE_137M; + } + } break; + case LLM_ARCH_BLOOM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1B; break; + case 30: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } + + // TODO: become GGUF KV parameter + hparams.f_max_alibi_bias = 8.0f; + } break; + case LLM_ARCH_MPT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); + ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_30B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_STABLELM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_12B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN2VL: + { + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + } + // fall through + case LLM_ARCH_QWEN2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break; + case 28: type = hparams.n_embd == 1536 ? LLM_TYPE_1_5B : LLM_TYPE_7B; break; + case 32: type = LLM_TYPE_7B; break; + case 36: type = LLM_TYPE_3B; break; + case 40: type = hparams.n_head() == 20 ? LLM_TYPE_4B : LLM_TYPE_13B; break; + case 48: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + case 80: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN2MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_A2_7B; break; + case 28: type = LLM_TYPE_57B_A14B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_PHI2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_PHI3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + // for backward compatibility ; see: https://github.com/ggerganov/llama.cpp/pull/8931 + if ((hparams.n_layer == 32 || hparams.n_layer == 40) && hparams.n_ctx_train == 4096) { + // default value for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct + hparams.n_swa = 2047; + } else if (hparams.n_layer == 32 && hparams.n_head_kv(0) == 32 && hparams.n_ctx_train == 131072) { + // default value for Phi-3-mini-128k-instruct + hparams.n_swa = 262144; + } else if (hparams.n_layer == 40 && hparams.n_ctx_train == 131072) { + // default value for Phi-3-medium-128k-instruct + hparams.n_swa = 131072; + } + bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (!found_swa && hparams.n_swa == 0) { + throw std::runtime_error("invalid value for sliding_window"); + } + } break; + case LLM_ARCH_PHIMOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_16x3_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_PLAMO: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_GPT2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 12: type = LLM_TYPE_SMALL; break; + case 24: type = LLM_TYPE_MEDIUM; break; + case 36: type = LLM_TYPE_LARGE; break; + case 48: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_CODESHELL: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 42: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_ORION: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_INTERNLM2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_GEMMA: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 18: type = LLM_TYPE_2B; break; + case 28: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_GEMMA2: + { + hparams.n_swa = 4096; // default value of gemma 2 + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + hparams.attn_soft_cap = true; + + switch (hparams.n_layer) { + case 26: type = LLM_TYPE_2B; break; + case 42: type = LLM_TYPE_9B; break; + case 46: type = LLM_TYPE_27B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_STARCODER2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 30: type = LLM_TYPE_3B; break; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_15B; break; + case 52: type = LLM_TYPE_20B; break; // granite + case 88: type = LLM_TYPE_34B; break; // granite + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_MAMBA: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_SMALL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 48: + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_MEDIUM; break; + case 1536: type = LLM_TYPE_LARGE; break; + case 2048: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 64: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_XVERSE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + case 80: type = LLM_TYPE_65B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_COMMAND_R: + { + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_35B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_COHERE2: + { + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_DBRX: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); + + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_16x12B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_OLMO: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); + + switch (hparams.n_layer) { + case 22: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_7B; break; + case 80: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_OLMO2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_OLMOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_A1_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_OPENELM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_270M; break; + case 20: type = LLM_TYPE_450M; break; + case 28: type = LLM_TYPE_1B; break; + case 36: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_GPTNEOX: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); + switch (hparams.n_layer) { + case 6: + switch (hparams.n_ff()) { + case 512: type = LLM_TYPE_14M; break; + case 2048: type = LLM_TYPE_70M; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 12: + switch (hparams.n_ff()) { + case 3072: type = LLM_TYPE_160M; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 16: + switch (hparams.n_ff()) { + case 8192: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + switch (hparams.n_ff()) { + case 4096: type = LLM_TYPE_410M; break; + case 8192: type = LLM_TYPE_1_4B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 32: + switch (hparams.n_ff()) { + case 10240: type = LLM_TYPE_2_8B; break; + case 16384: type = LLM_TYPE_6_9B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 36: + switch (hparams.n_ff()) { + case 20480: type = LLM_TYPE_12B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 44: + switch (hparams.n_ff()) { + case 24576: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_ARCTIC: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (hparams.n_expert == 128) { + switch (hparams.n_layer) { + case 35: type = LLM_TYPE_10B_128x3_66B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } else { + type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_DEEPSEEK: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + + switch (hparams.n_layer) { + case 28: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_DEEPSEEK2: + { + bool is_lite = (hparams.n_layer == 27); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + if (!is_lite) { + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + } + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + // for compatibility with existing DeepSeek V2 and V2.5 GGUFs + // that have no expert_gating_func model parameter set + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + } + ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); + + switch (hparams.n_layer) { + case 27: type = LLM_TYPE_16B; break; + case 60: type = LLM_TYPE_236B; break; + case 61: type = LLM_TYPE_671B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_CHATGLM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 28: type = LLM_TYPE_6B; break; + case 40: type = LLM_TYPE_9B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_BITNET: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 26: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_T5: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + + uint32_t dec_start_token_id; + if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, dec_start_token_id, false)) { + hparams.dec_start_token_id = dec_start_token_id; + } + + switch (hparams.n_layer) { + case 6: type = LLM_TYPE_60M; break; // t5-small + case 8: type = LLM_TYPE_80M; break; // flan-t5-small + case 12: + switch (hparams.n_ff()) { + case 3072: type = LLM_TYPE_220M; break; // t5-base + case 2048: type = LLM_TYPE_250M; break; // flan-t5-base + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + switch (hparams.n_ff()) { + case 4096: type = LLM_TYPE_770M; break; // t5-large + case 2816: type = LLM_TYPE_780M; break; // flan-t5-large + case 16384: type = LLM_TYPE_3B; break; // t5-3b + case 5120: type = LLM_TYPE_3B; break; // flan-t5-xl + case 65536: type = LLM_TYPE_11B; break; // t5-11b + case 10240: type = LLM_TYPE_11B; break; // flan-t5-xxl + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_T5ENCODER: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + type = LLM_TYPE_UNKNOWN; + } break; + case LLM_ARCH_JAIS: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1_3B; break; + case 40: type = LLM_TYPE_13B; break; + /* TODO: add variants */ + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_NEMOTRON: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_EXAONE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_RWKV6: + case LLM_ARCH_RWKV6QWEN2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); + ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); + ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); + ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1_6B; break; + case 32: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 61: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_GRANITE: + case LLM_ARCH_GRANITE_MOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_3B; break; + // Add additional layer/vocab/etc checks here for other model sizes + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_CHAMELEON: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default + ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_34B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_WAVTOKENIZER_DEC: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps); + ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + } break; + default: throw std::runtime_error("unsupported model architecture"); + } + + pimpl->n_bytes = ml.n_bytes; + + pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name(); + + if (hparams.f_max_alibi_bias > 0.0f) { + hparams.use_alibi = true; + } + + hparams.rope_type = llama_model_rope_type(this); +} + +void llama_model::load_vocab(llama_model_loader & ml) { + const auto kv = LLM_KV(arch); + + vocab.load(ml, kv); +} + +bool llama_model::load_tensors(llama_model_loader & ml) { + const auto & split_mode = params.split_mode; + const auto & n_gpu_layers = params.n_gpu_layers; + const auto & use_mlock = params.use_mlock; + const auto & tensor_split = params.tensor_split; + + const int n_layer = hparams.n_layer; + + const bool use_mmap_buffer = true; + + // build a list of buffer types for the CPU and GPU devices + pimpl->cpu_buft_list = make_cpu_buft_list(devices); + for (auto * dev : devices) { + buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); + // add CPU buffer types as a fallback + buft_list.insert(buft_list.end(), pimpl->cpu_buft_list.begin(), pimpl->cpu_buft_list.end()); + pimpl->gpu_buft_list.emplace(dev, std::move(buft_list)); + } + + // calculate the split points + bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; }); + std::vector splits(n_devices()); + if (all_zero) { + // default split, by free memory + for (size_t i = 0; i < n_devices(); ++i) { + ggml_backend_dev_t dev = devices[i]; + size_t total; + size_t free; + ggml_backend_dev_memory(dev, &free, &total); + splits[i] = free; + } + } else { + std::copy(tensor_split, tensor_split + n_devices(), splits.begin()); + } + + // sum and normalize the splits to get the split points + float split_sum = 0.0f; + for (size_t i = 0; i < n_devices(); ++i) { + split_sum += splits[i]; + splits[i] = split_sum; + } + for (size_t i = 0; i < n_devices(); ++i) { + splits[i] /= split_sum; + } + + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0); + const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1); + auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { + if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) { + LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(cpu_dev)); + return {cpu_dev, &pimpl->cpu_buft_list}; + } + const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin(); + auto * dev = devices.at(layer_gpu); + LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(dev)); + return {dev, &pimpl->gpu_buft_list.at(dev)}; + }; + + // assign the input layer + // there is very little benefit to offloading the input layer, so always keep it on the CPU + pimpl->dev_input = { cpu_dev, &pimpl->cpu_buft_list }; + + // assign the repeating layers to the devices according to the splits + pimpl->dev_layer.resize(n_layer); + for (int il = 0; il < n_layer; ++il) { + pimpl->dev_layer[il] = get_layer_buft_list(il); + } + + // assign the output layer + pimpl->dev_output = get_layer_buft_list(n_layer); + + // one ggml context per buffer type + int max_n_tensors = ml.n_tensors; + max_n_tensors += 1; // duplicated output tensor + max_n_tensors += n_layer*2; // duplicated rope freq tensors + const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors; + + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ ctx_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + throw std::runtime_error(format("failed to create ggml context")); + } + + ctx_map[buft] = ctx; + pimpl->ctxs.emplace_back(ctx); + + return ctx; + } + return it->second; + }; + + const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED; + const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED; + + // create tensors for the weights + { + // note: cast to int64_t since we will use these for the tensor dimensions + const int64_t n_head = hparams.n_head(); + const int64_t n_head_kv = hparams.n_head_kv(); + const int64_t n_embd = hparams.n_embd; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + const int64_t n_embd_head_k = hparams.n_embd_head_k; + const int64_t n_embd_head_v = hparams.n_embd_head_v; + const int64_t n_ff = hparams.n_ff(); + const int64_t n_embd_gqa = n_embd_v_gqa; + const int64_t n_vocab = vocab.n_tokens(); + const int64_t n_token_types = vocab.n_token_types(); + const int64_t n_rot = hparams.n_rot; + const int64_t n_expert = hparams.n_expert; + const int64_t n_expert_used = hparams.n_expert_used; + const int64_t n_ctx_train = hparams.n_ctx_train; + + if (n_expert > 0 && hparams.n_expert_used == 0) { + throw std::runtime_error("model has expert layers but no expert layers are used"); + } + + int n_moved_tensors = 0; + ggml_tensor * first_moved_tensor = nullptr; + ggml_backend_buffer_type_t first_moved_from_buft = nullptr; + ggml_backend_buffer_type_t first_moved_to_buft = nullptr; + + auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) -> ggml_tensor * { + ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str()); + + if (!t_meta) { + if (flags & TENSOR_NOT_REQUIRED) { + return nullptr; + } + throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str())); + } + + // some models use the token embedding tensor as the output, but since these are used in different layers and with different ops + // the tensor is duplicated + // to handle this, we check if the tensor is duplicated, and if so, we assume that it is being loaded as the output tensor + llm_tensor tn_tensor = tn.tensor; + if (tn.tensor == LLM_TENSOR_TOKEN_EMBD && flags & TENSOR_DUPLICATED) { + tn_tensor = LLM_TENSOR_OUTPUT; + } + + llm_tensor_info info; + try { + info = llm_tensor_info_for(tn_tensor); + } catch (const std::out_of_range & e) { + throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str())); + } + + // tensors with "bias" suffix are always used with GGML_OP_ADD + ggml_op op; + bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0; + if (bias) { + op = GGML_OP_ADD; + } else { + op = info.op; + } + + // sanity checks + if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) { + if (tn.bid != -1) { + GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str()); + } + } else { + if (tn.bid == -1) { + GGML_ABORT("repeating layer tensor %s used without a layer number", tn.str().c_str()); + } + } + + // select the buffer type for this tensor + buft_list_t * buft_list; + switch (info.layer) { + case LLM_TENSOR_LAYER_INPUT: + buft_list = pimpl->dev_input.buft_list; + break; + case LLM_TENSOR_LAYER_OUTPUT: + buft_list = pimpl->dev_output.buft_list; + break; + case LLM_TENSOR_LAYER_REPEATING: + buft_list = pimpl->dev_layer.at(tn.bid).buft_list; + break; + default: + GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str()); + } + + ggml_backend_buffer_type_t buft = select_weight_buft(hparams, t_meta, op, *buft_list); + if (!buft) { + throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str())); + } + + // avoid using a host buffer when using mmap + auto * buft_dev = ggml_backend_buft_get_device(buft); + if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) { + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + buft = ggml_backend_dev_buffer_type(cpu_dev); + } + + if (buft != buft_list->front().second) { + n_moved_tensors++; + if (!first_moved_tensor) { + first_moved_tensor = t_meta; + first_moved_from_buft = buft_list->front().second; + first_moved_to_buft = buft; + } + } + + ggml_context * ctx = ctx_for_buft(buft); + + // if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one + if (flags & TENSOR_DUPLICATED) { + ggml_tensor * t = ggml_get_tensor(ctx, tn.str().c_str()); + if (t) { + return t; + } + } + return ml.create_tensor(ctx, tn, ne, flags); + }; + + layers.resize(n_layer); + + // TODO: move to a separate function + const auto tn = LLM_TN(arch); + switch (arch) { + case LLM_ARCH_LLAMA: + case LLM_ARCH_REFACT: + case LLM_ARCH_MINICPM: + case LLM_ARCH_GRANITE: + case LLM_ARCH_GRANITE_MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } + } + } break; + case LLM_ARCH_DECI: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); + const int64_t n_ff = hparams.n_ff(i); + const int64_t n_head = hparams.n_head(i); + const int64_t n_head_kv = hparams.n_head_kv(i); + + if (n_head_kv == 0 && n_head > 0) { + // linear attention for DeciLMCausalModel + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + } + else if (n_head_kv > 0) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + } + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } + } break; + case LLM_ARCH_MINICPM3: + { + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); + + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + } break; + case LLM_ARCH_GROK: + { + if (n_expert == 0) { + throw std::runtime_error("Grok model cannot have zero experts"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + } + } break; + case LLM_ARCH_DBRX: + { + if (n_expert == 0) { + throw std::runtime_error("DBRX model cannot have zero experts"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } + } break; + case LLM_ARCH_BAICHUAN: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_FALCON: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_STARCODER: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + // needs to be on GPU + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; + case LLM_ARCH_BERT: + case LLM_ARCH_NOMIC_BERT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); + + if (arch == LLM_ARCH_BERT) { + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {1}, TENSOR_NOT_REQUIRED); + } + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + if (arch == LLM_ARCH_BERT) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + } else { + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + + if (arch == LLM_ARCH_BERT) { + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + } else { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } + } break; + case LLM_ARCH_JINA_BERT_V2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); // token_type_embeddings + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); // LayerNorm + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); //LayerNorm bias + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {1}, TENSOR_NOT_REQUIRED); + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; // JinaBertLayer + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); //output_dens + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); //output_norm + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } + } break; + case LLM_ARCH_BLOOM: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; + case LLM_ARCH_MPT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, TENSOR_NOT_REQUIRED); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + // AWQ ScaleActivation layer + layer.ffn_act = create_tensor(tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, TENSOR_NOT_REQUIRED); + } + } break; + case LLM_ARCH_STABLELM: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional bias tensors, present in Stable LM 2 1.6B + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + + // optional q and k layernorms, present in StableLM 2 12B + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); + + // optional FFN norm, not present in StableLM 2 12B which uses parallel residual + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_QWEN: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2}, 0); + } + } break; + case LLM_ARCH_QWEN2: + case LLM_ARCH_QWEN2VL: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_QWEN2MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for QWEN2MOE"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for QWEN2MOE"); + } + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; + + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + } + } break; + case LLM_ARCH_PHI2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; + case LLM_ARCH_PHI3: + { + const int64_t n_embd_head = n_embd / n_head; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + } break; + case LLM_ARCH_PHIMOE: + { + const int64_t n_embd_head = n_embd / n_head; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), { n_vocab }, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED); + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + } + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + } break; + case LLM_ARCH_PLAMO: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_GPT2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; + case LLM_ARCH_CODESHELL: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; + case LLM_ARCH_ORION: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_INTERNLM2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + // layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_GEMMA: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + } + } break; + case LLM_ARCH_GEMMA2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; + case LLM_ARCH_STARCODER2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional bias tensors + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff}, 0); + } + } break; + case LLM_ARCH_MAMBA: + { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + + // only an expansion factor of 2 is supported for now + if (2 * n_embd != d_inner) { + throw std::runtime_error("only an expansion factor of 2 is supported for now"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0); + + layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0); + + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } + } break; + case LLM_ARCH_XVERSE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_COMMAND_R: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // init output from the input tok embed + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (n_layer >= 64){ + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); + } + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_COHERE2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + // init output from the input tok embed + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, + TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + } + } + break; + case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_OLMO2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; + case LLM_ARCH_OLMOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } + } break; + case LLM_ARCH_OPENELM: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // init output from the input tok embed + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + const int64_t n_head = hparams.n_head(i); + const int64_t n_head_qkv = 2*hparams.n_head_kv(i) + n_head; + const int64_t n_ff = hparams.n_ff(i); + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_GPTNEOX: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; + case LLM_ARCH_ARCTIC: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_norm_exps = create_tensor(tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } + } break; + case LLM_ARCH_DEEPSEEK: + { + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } + } break; + case LLM_ARCH_DEEPSEEK2: + { + const bool is_lite = (hparams.n_layer == 27); + + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + if (!is_lite) { + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); + } + + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + if (!is_lite) { + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0); + } else { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + } + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } + } break; + case LLM_ARCH_BITNET: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_sub_norm = create_tensor(tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wq_scale = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wk_scale = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv_scale = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_scale = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_sub_norm = create_tensor(tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate_scale = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_scale = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_scale = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + } break; + case LLM_ARCH_T5: + { + const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm = create_tensor(tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); + + layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); + + layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); + + layer.wq = create_tensor(tn(LLM_TENSOR_DEC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_DEC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_DEC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); + + layer.attn_norm_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM, "weight", i), {n_embd}, 0); + // this tensor seems to be unused in HF transformers implementation + layer.attn_rel_b_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); + + layer.wq_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wk_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_T5ENCODER: + { + const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); + + layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); + + layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_JAIS: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; + case LLM_ARCH_CHATGLM: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + } + } break; + case LLM_ARCH_NEMOTRON: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } + } break; + case LLM_ARCH_EXAONE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_RWKV6: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // Block 0, LN0 + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int time_mix_extra_dim = hparams.time_mix_extra_dim; + const int time_decay_extra_dim = hparams.time_decay_extra_dim; + const int head_size = hparams.wkv_head_size; + const int attn_hidden_size = n_embd; + const int ffn_size = hparams.n_ff_arr[0]; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, 0); + + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0); + + layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0); + layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, llama_model_loader::TENSOR_NOT_REQUIRED); + GGML_ASSERT(!(layer.time_mix_lerp_fused == NULL && layer.time_mix_lerp_w == NULL)); + + layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0); + layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0); + layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0); + layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0); + layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0); + + layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0); + layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0); + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0); + layer.channel_mix_lerp_r = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0); + + layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0); + layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0); + layer.channel_mix_receptance = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd}, 0); + } + + } break; + case LLM_ARCH_RWKV6QWEN2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int time_mix_extra_dim = hparams.time_mix_extra_dim; + const int time_decay_extra_dim = hparams.time_decay_extra_dim; + const int head_size = hparams.wkv_head_size; + const int attn_hidden_size = n_embd; + const int n_head_kv = hparams.n_head_kv(); + int attn_key_value_size; + if (n_head_kv == 0 || attn_hidden_size / head_size == n_head_kv) { + attn_key_value_size = attn_hidden_size; + } else { + attn_key_value_size = n_head_kv * head_size; + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0); + + layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0); + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0); + + layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0); + layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0); + layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0); + layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {n_embd, attn_key_value_size}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {n_embd, attn_key_value_size}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0); + // optional bias tensors + layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_CHAMELEON: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_WAVTOKENIZER_DEC: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd_features, n_vocab}, 0); + + conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd_features, hparams.posnet.n_embd}, 0); + conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"), {1, hparams.posnet.n_embd}, 0); + + // posnet + { + const int64_t n_embd = hparams.posnet.n_embd; + + for (uint32_t i = 0; i < hparams.posnet.n_layer; ++i) { + auto & layer = layers[i].posnet; + + // posnet: + // + // - resnet + // - resnet + // - attn + // - resnet + // - resnet + // - norm + // + switch (i) { + case 0: + case 1: + case 3: + case 4: + { + layer.norm1 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", i), {1, n_embd}, 0); + layer.norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias", i), {1, n_embd}, 0); + + layer.conv1 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", i), {3, n_embd, n_embd}, 0); + layer.conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias", i), {1, n_embd}, 0); + + layer.norm2 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", i), {1, n_embd}, 0); + layer.norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias", i), {1, n_embd}, 0); + + layer.conv2 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", i), {3, n_embd, n_embd}, 0); + layer.conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias", i), {1, n_embd}, 0); + } break; + case 2: + { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", i), {1, n_embd}, 0); + + layer.attn_q = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_q_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "bias", i), {1, n_embd}, 0); + + layer.attn_k = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_k_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "bias", i), {1, n_embd}, 0); + + layer.attn_v = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_v_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "bias", i), {1, n_embd}, 0); + + layer.attn_o = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_o_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "bias", i), {1, n_embd}, 0); + } break; + case 5: + { + layer.norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0); + layer.norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", i), {1, n_embd}, 0); + } break; + default: GGML_ABORT("unknown posnet layer"); + }; + } + } + + GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd); + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {hparams.posnet.n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {hparams.posnet.n_embd}, 0); + + // convnext + { + const int64_t n_embd = hparams.convnext.n_embd; + + for (uint32_t i = 0; i < hparams.convnext.n_layer; ++i) { + auto & layer = layers[i].convnext; + + layer.dw = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW, "weight", i), {7, 1, n_embd}, 0); + layer.dw_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW, "bias", i), {1, n_embd}, 0); + + layer.norm = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM, "weight", i), {n_embd}, 0); + layer.norm_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM, "bias", i), {n_embd}, 0); + + layer.pw1 = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1, "weight", i), {n_embd, n_ff}, 0); + layer.pw1_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1, "bias", i), {n_ff}, 0); + + layer.pw2 = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2, "weight", i), {n_ff, n_embd}, 0); + layer.pw2_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2, "bias", i), {n_embd}, 0); + + layer.gamma = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", i), {n_embd}, 0); + } + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + } + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0); + } break; + default: + throw std::runtime_error("unknown architecture"); + } + + if (n_moved_tensors > 0) { + LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %d others) cannot be used with preferred buffer type %s, using %s instead\n", + __func__, first_moved_tensor->name, ggml_type_name(first_moved_tensor->type), n_moved_tensors - 1, + ggml_backend_buft_name(first_moved_from_buft), ggml_backend_buft_name(first_moved_to_buft)); + } + } + + ml.done_getting_tensors(); + + ml.init_mappings(true, use_mlock ? &pimpl->mlock_mmaps : nullptr); + pimpl->mappings.reserve(ml.mappings.size()); + + // create the backend buffers + std::vector> ctx_bufs; + ctx_bufs.reserve(ctx_map.size()); + + // Ensure we have enough capacity for the maximum backend buffer we will potentially create + const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size(); + pimpl->bufs.reserve(n_max_backend_buffer); + + for (auto & it : ctx_map) { + ggml_backend_buffer_type_t buft = it.first; + ggml_context * ctx = it.second; + + // skip contexts without tensors + if (ggml_get_first_tensor(ctx) == nullptr) { + continue; + } + + llama_buf_map buf_map; + buf_map.reserve(n_max_backend_buffer); + + // check if it is possible to use buffer_from_host_ptr with this buffer type + ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); + if (!dev) { + // FIXME: workaround for CPU backend buft having a NULL device + dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + } + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); + bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr; + bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev); + + if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) { + for (uint32_t idx = 0; idx < ml.files.size(); idx++) { + // only the mmap region containing the tensors in the model is mapped to the backend buffer + // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers + // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size + void * addr = nullptr; + size_t first, last; // NOLINT + ml.get_mapping_range(&first, &last, &addr, idx, ctx); + if (first >= last) { + continue; + } + const size_t max_size = ggml_get_max_tensor_size(ctx); + ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size); + if (buf == nullptr) { + throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); + } + pimpl->bufs.emplace_back(buf); + buf_map.emplace(idx, buf); + } + } + else { + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (buf == nullptr) { + throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); + } + pimpl->bufs.emplace_back(buf); + if (use_mlock && ggml_backend_buffer_is_host(buf)) { + pimpl->mlock_bufs.emplace_back(new llama_mlock); + auto & mlock_buf = pimpl->mlock_bufs.back(); + mlock_buf->init (ggml_backend_buffer_get_base(buf)); + mlock_buf->grow_to(ggml_backend_buffer_get_size(buf)); + } + for (uint32_t idx = 0; idx < ml.files.size(); idx++) { + buf_map.emplace(idx, buf); + } + } + + if (pimpl->bufs.empty()) { + throw std::runtime_error("failed to allocate buffer"); + } + + for (auto & buf : buf_map) { + // indicate that this buffer contains weights + // this is used by ggml_backend_sched to improve op scheduling: ops that use a weight are preferably scheduled to the backend that contains the weight + ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + } + + ctx_bufs.emplace_back(ctx, buf_map); + } + + if (llama_supports_gpu_offload()) { + const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); + + LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu); + if (n_gpu_layers > (int) hparams.n_layer) { + LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__); + } + + const int max_backend_supported_layers = hparams.n_layer + 1; + const int max_offloadable_layers = hparams.n_layer + 1; + + LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); + } + + // print memory requirements per buffer type + for (auto & buf : pimpl->bufs) { + LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0); + } + + // populate tensors_by_name + for (auto & ctx : pimpl->ctxs) { + for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) { + tensors_by_name.emplace_back(ggml_get_name(cur), cur); + } + } + + // load tensor data + for (auto & it : ctx_bufs) { + ggml_context * ctx = it.first; + auto & bufs = it.second; + if (!ml.load_all_data(ctx, bufs, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) { + return false; + } + } + + if (use_mmap_buffer) { + for (auto & mapping : ml.mappings) { + pimpl->mappings.emplace_back(std::move(mapping)); + } + } + + return true; +} + +std::string llama_model::arch_name() const { + return llm_arch_name(arch); +} + +std::string llama_model::type_name() const { + return llm_type_name(type); +} + +std::string llama_model::desc() const { + return pimpl->desc_str; +} + +size_t llama_model::size() const { + return pimpl->n_bytes; +} + +size_t llama_model::max_nodes() const { + return std::max(8192, tensors_by_name.size()*5); +} + +size_t llama_model::n_devices() const { + return devices.size(); +} + +uint64_t llama_model::n_elements() const { + return pimpl->n_elements; +} + +void llama_model::print_info() const { + const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train); + + auto print_f = [](const std::function & f, uint32_t n) { + bool is_var = false; + + std::vector v; + for (uint32_t i = 0; i < n; ++i) { + v.push_back(f(i)); + if (v[i] != v[0]) { + is_var = true; + } + } + + std::stringstream ss; + + if (is_var) { + ss << "["; + for (uint32_t i = 0; i < n; ++i) { + ss << v[i]; + if (i < n - 1) { + ss << ", "; + } + } + ss << "]"; + } else { + ss << v[0]; + } + + return ss.str(); + }; + + // hparams + LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str()); + LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only); + + if (!hparams.vocab_only) { + LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); + LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); + LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); + LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k); + LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v); + LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); + LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); + LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); + LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); + LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale); + LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); + LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); + LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn); + LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); + LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type); + LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type); + LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); + LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); + LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); + LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); + LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); + LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); + LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); + LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); + } + + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); + if (pimpl->n_elements >= 1e12) { + LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); + } else if (pimpl->n_elements >= 1e9) { + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); + } else if (pimpl->n_elements >= 1e6) { + LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); + } else { + LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); + } + + // general kv + LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); + + if (arch == LLM_ARCH_DEEPSEEK) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + } + + if (arch == LLM_ARCH_DEEPSEEK2) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); + LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((enum llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); + } + + if (arch == LLM_ARCH_QWEN2MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + } + + if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE) { + LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); + LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); + LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); + } + + vocab.print_info(); +} + +ggml_backend_dev_t llama_model::dev_layer(int il) const { + return pimpl->dev_layer.at(il).dev; +} + +ggml_backend_dev_t llama_model::dev_output() const { + return pimpl->dev_output.dev; +} + +template +static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) { + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead()*8, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context_ptr ctx { ggml_init(params) }; + if (!ctx) { + throw std::runtime_error(format("failed to create ggml context")); + } + + ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) }; + ggml_tensor * op_tensor = fn(ctx.get()); + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (op_tensor->src[i] != nullptr) { + assert(op_tensor->src[i]->buffer == nullptr); + op_tensor->src[i]->buffer = buf.get(); + } + } + + bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + + return op_supported; +} + +template +static ggml_backend_buffer_type_t select_buft(const buft_list_t & buft_list, const F & fn) { + for (const auto & cur : buft_list) { + ggml_backend_dev_t cur_dev = cur.first; + ggml_backend_buffer_type_t cur_buft = cur.second; + if (buft_supported(cur_buft, cur_dev, fn)) { + return cur_buft; + } + } + + throw std::runtime_error(format("no suitable buffer type found")); +} + +ggml_backend_buffer_type_t llama_model::select_buft(int il) const { + return ::select_buft( + *pimpl->dev_layer.at(il).buft_list, + [&](ggml_context * ctx) { + ggml_tensor * cur = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); + ggml_tensor * layer_dir = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); + return ggml_add(ctx, cur, layer_dir); + }); +} + +const struct ggml_tensor * llama_model::get_tensor(const char * name) const { + auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(), + [name](const std::pair & it) { + return it.first == name; + }); + if (it == tensors_by_name.end()) { + return nullptr; + } + + return it->second; +} + +// +// interface implementation +// + +struct llama_model_params llama_model_default_params() { + struct llama_model_params result = { + /*.devices =*/ nullptr, + /*.n_gpu_layers =*/ 0, + /*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER, + /*.main_gpu =*/ 0, + /*.tensor_split =*/ nullptr, + /*.progress_callback =*/ nullptr, + /*.progress_callback_user_data =*/ nullptr, + /*.kv_overrides =*/ nullptr, + /*.vocab_only =*/ false, + /*.use_mmap =*/ true, + /*.use_mlock =*/ false, + /*.check_tensors =*/ false, + }; + +#ifdef GGML_USE_METAL + // note: we usually have plenty of VRAM, so by default offload all layers to the GPU + result.n_gpu_layers = 999; +#endif + + return result; +} + +const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model) { + return &model->vocab; +} + +void llama_free_model(struct llama_model * model) { + llama_model_free(model); +} + +void llama_model_free(struct llama_model * model) { + delete model; +} + +int32_t llama_model_n_ctx_train(const struct llama_model * model) { + return model->hparams.n_ctx_train; +} + +int32_t llama_model_n_embd(const struct llama_model * model) { + return model->hparams.n_embd; +} + +int32_t llama_model_n_layer(const struct llama_model * model) { + return model->hparams.n_layer; +} + +int32_t llama_model_n_head(const struct llama_model * model) { + return model->hparams.n_head(); +} + +// deprecated +int32_t llama_n_ctx_train(const struct llama_model * model) { + return llama_model_n_ctx_train(model); +} + +// deprecated +int32_t llama_n_embd(const struct llama_model * model) { + return llama_model_n_embd(model); +} + +// deprecated +int32_t llama_n_layer(const struct llama_model * model) { + return llama_model_n_layer(model); +} + +// deprecated +int32_t llama_n_head(const struct llama_model * model) { + return llama_model_n_head(model); +} + +enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { + switch (model->arch) { + // these models do not use RoPE + case LLM_ARCH_GPT2: + case LLM_ARCH_GPTJ: + case LLM_ARCH_MPT: + case LLM_ARCH_REFACT: + case LLM_ARCH_BLOOM: + case LLM_ARCH_MAMBA: + case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_T5: + case LLM_ARCH_T5ENCODER: + case LLM_ARCH_JAIS: + case LLM_ARCH_RWKV6: + case LLM_ARCH_RWKV6QWEN2: + case LLM_ARCH_WAVTOKENIZER_DEC: + return LLAMA_ROPE_TYPE_NONE; + + // use what we call a normal RoPE, operating on pairs of consecutive head values + case LLM_ARCH_LLAMA: + case LLM_ARCH_DECI: + case LLM_ARCH_BAICHUAN: + case LLM_ARCH_STARCODER: + case LLM_ARCH_PLAMO: + case LLM_ARCH_ORION: + case LLM_ARCH_INTERNLM2: + case LLM_ARCH_MINICPM: + case LLM_ARCH_XVERSE: + case LLM_ARCH_COMMAND_R: + case LLM_ARCH_COHERE2: + case LLM_ARCH_OLMO: + case LLM_ARCH_ARCTIC: + case LLM_ARCH_DEEPSEEK: + case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_CHATGLM: + case LLM_ARCH_GRANITE: + case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_CHAMELEON: + return LLAMA_ROPE_TYPE_NORM; + + // the pairs of head values are offset by n_rot/2 + case LLM_ARCH_FALCON: + case LLM_ARCH_GROK: + case LLM_ARCH_DBRX: + case LLM_ARCH_BERT: + case LLM_ARCH_NOMIC_BERT: + case LLM_ARCH_STABLELM: + case LLM_ARCH_BITNET: + case LLM_ARCH_QWEN: + case LLM_ARCH_QWEN2: + case LLM_ARCH_QWEN2MOE: + case LLM_ARCH_OLMO2: + case LLM_ARCH_OLMOE: + case LLM_ARCH_PHI2: + case LLM_ARCH_PHI3: + case LLM_ARCH_PHIMOE: + case LLM_ARCH_GEMMA: + case LLM_ARCH_GEMMA2: + case LLM_ARCH_STARCODER2: + case LLM_ARCH_OPENELM: + case LLM_ARCH_GPTNEOX: + case LLM_ARCH_CODESHELL: + case LLM_ARCH_NEMOTRON: + case LLM_ARCH_EXAONE: + case LLM_ARCH_MINICPM3: + return LLAMA_ROPE_TYPE_NEOX; + + case LLM_ARCH_QWEN2VL: + return LLAMA_ROPE_TYPE_MROPE; + + // all model arches should be listed explicitly here + case LLM_ARCH_UNKNOWN: + GGML_ABORT("unknown architecture"); + } + + return LLAMA_ROPE_TYPE_NONE; +} + +float llama_model_rope_freq_scale_train(const struct llama_model * model) { + return model->hparams.rope_freq_scale_train; +} + +int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size) { + const auto & it = model->gguf_kv.find(key); + if (it == model->gguf_kv.end()) { + if (buf_size > 0) { + buf[0] = '\0'; + } + return -1; + } + return snprintf(buf, buf_size, "%s", it->second.c_str()); +} + +int32_t llama_model_meta_count(const struct llama_model * model) { + return (int)model->gguf_kv.size(); +} + +int32_t llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size) { + if (i < 0 || i >= (int)model->gguf_kv.size()) { + if (buf_size > 0) { + buf[0] = '\0'; + } + return -1; + } + auto it = model->gguf_kv.begin(); + std::advance(it, i); + return snprintf(buf, buf_size, "%s", it->first.c_str()); +} + +int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size) { + if (i < 0 || i >= (int)model->gguf_kv.size()) { + if (buf_size > 0) { + buf[0] = '\0'; + } + return -1; + } + auto it = model->gguf_kv.begin(); + std::advance(it, i); + return snprintf(buf, buf_size, "%s", it->second.c_str()); +} + +int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) { + return snprintf(buf, buf_size, "%s", model->desc().c_str()); +} + +uint64_t llama_model_size(const struct llama_model * model) { + return model->size(); +} + +const char * llama_model_chat_template(const struct llama_model * model, const char * name) { + const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N) + : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE); + const auto & it = model->gguf_kv.find(key); + if (it == model->gguf_kv.end()) { + return nullptr; + } + + return it->second.c_str(); +} + +uint64_t llama_model_n_params(const struct llama_model * model) { + return model->n_elements(); +} + +bool llama_model_has_encoder(const struct llama_model * model) { + switch (model->arch) { + case LLM_ARCH_T5: return true; + case LLM_ARCH_T5ENCODER: return true; + default: return false; + } +} + +bool llama_model_has_decoder(const struct llama_model * model) { + switch (model->arch) { + case LLM_ARCH_T5ENCODER: return false; + default: return true; + } +} + +llama_token llama_model_decoder_start_token(const struct llama_model * model) { + return model->hparams.dec_start_token_id; +} + +bool llama_model_is_recurrent(const struct llama_model * model) { + switch (model->arch) { + case LLM_ARCH_MAMBA: return true; + case LLM_ARCH_RWKV6: return true; + case LLM_ARCH_RWKV6QWEN2: return true; + default: return false; + } +} diff --git a/src/llama-model.h b/src/llama-model.h new file mode 100644 index 000000000..a7c304447 --- /dev/null +++ b/src/llama-model.h @@ -0,0 +1,370 @@ +#pragma once + +#include "llama.h" +#include "llama-arch.h" +#include "llama-hparams.h" +#include "llama-vocab.h" + +#include +#include +#include +#include + +struct llama_model_loader; + +// available models +enum llm_type { + LLM_TYPE_UNKNOWN, + LLM_TYPE_14M, + LLM_TYPE_17M, + LLM_TYPE_22M, + LLM_TYPE_33M, + LLM_TYPE_60M, + LLM_TYPE_70M, + LLM_TYPE_80M, + LLM_TYPE_109M, + LLM_TYPE_137M, + LLM_TYPE_160M, + LLM_TYPE_220M, + LLM_TYPE_250M, + LLM_TYPE_270M, + LLM_TYPE_335M, + LLM_TYPE_410M, + LLM_TYPE_450M, + LLM_TYPE_770M, + LLM_TYPE_780M, + LLM_TYPE_0_5B, + LLM_TYPE_1B, + LLM_TYPE_1_3B, + LLM_TYPE_1_4B, + LLM_TYPE_1_5B, + LLM_TYPE_1_6B, + LLM_TYPE_2B, + LLM_TYPE_2_8B, + LLM_TYPE_3B, + LLM_TYPE_4B, + LLM_TYPE_6B, + LLM_TYPE_6_9B, + LLM_TYPE_7B, + LLM_TYPE_8B, + LLM_TYPE_9B, + LLM_TYPE_11B, + LLM_TYPE_12B, + LLM_TYPE_13B, + LLM_TYPE_14B, + LLM_TYPE_15B, + LLM_TYPE_16B, + LLM_TYPE_20B, + LLM_TYPE_30B, + LLM_TYPE_32B, + LLM_TYPE_34B, + LLM_TYPE_35B, + LLM_TYPE_40B, + LLM_TYPE_65B, + LLM_TYPE_70B, + LLM_TYPE_236B, + LLM_TYPE_314B, + LLM_TYPE_671B, + LLM_TYPE_SMALL, + LLM_TYPE_MEDIUM, + LLM_TYPE_LARGE, + LLM_TYPE_XL, + LLM_TYPE_A1_7B, + LLM_TYPE_A2_7B, + LLM_TYPE_8x7B, + LLM_TYPE_8x22B, + LLM_TYPE_16x12B, + LLM_TYPE_16x3_8B, + LLM_TYPE_10B_128x3_66B, + LLM_TYPE_57B_A14B, + LLM_TYPE_27B, +}; + +struct llama_layer_posnet { + // resnet + struct ggml_tensor * norm1 = nullptr; + struct ggml_tensor * norm1_b = nullptr; + + struct ggml_tensor * conv1 = nullptr; + struct ggml_tensor * conv1_b = nullptr; + + struct ggml_tensor * norm2 = nullptr; + struct ggml_tensor * norm2_b = nullptr; + + struct ggml_tensor * conv2 = nullptr; + struct ggml_tensor * conv2_b = nullptr; + + // attention + struct ggml_tensor * attn_norm = nullptr; + struct ggml_tensor * attn_norm_b = nullptr; + + struct ggml_tensor * attn_q = nullptr; + struct ggml_tensor * attn_q_b = nullptr; + + struct ggml_tensor * attn_k = nullptr; + struct ggml_tensor * attn_k_b = nullptr; + + struct ggml_tensor * attn_v = nullptr; + struct ggml_tensor * attn_v_b = nullptr; + + struct ggml_tensor * attn_o = nullptr; + struct ggml_tensor * attn_o_b = nullptr; + + // normalize + struct ggml_tensor * norm = nullptr; + struct ggml_tensor * norm_b = nullptr; +}; + +struct llama_layer_convnext { + struct ggml_tensor * dw = nullptr; + struct ggml_tensor * dw_b = nullptr; + + struct ggml_tensor * norm = nullptr; + struct ggml_tensor * norm_b = nullptr; + + struct ggml_tensor * pw1 = nullptr; + struct ggml_tensor * pw1_b = nullptr; + + struct ggml_tensor * pw2 = nullptr; + struct ggml_tensor * pw2_b = nullptr; + + struct ggml_tensor * gamma = nullptr; +}; + +struct llama_layer { + // normalization + struct ggml_tensor * attn_norm = nullptr; + struct ggml_tensor * attn_norm_b = nullptr; + struct ggml_tensor * attn_norm_2 = nullptr; + struct ggml_tensor * attn_norm_2_b = nullptr; + struct ggml_tensor * attn_q_norm = nullptr; + struct ggml_tensor * attn_q_norm_b = nullptr; + struct ggml_tensor * attn_k_norm = nullptr; + struct ggml_tensor * attn_k_norm_b = nullptr; + struct ggml_tensor * attn_out_norm = nullptr; + struct ggml_tensor * attn_out_norm_b = nullptr; + struct ggml_tensor * attn_q_a_norm = nullptr; + struct ggml_tensor * attn_kv_a_norm = nullptr; + struct ggml_tensor * attn_sub_norm = nullptr; + struct ggml_tensor * attn_post_norm = nullptr; + struct ggml_tensor * ffn_sub_norm = nullptr; + struct ggml_tensor * attn_norm_cross = nullptr; + struct ggml_tensor * attn_norm_enc = nullptr; + + // attention + struct ggml_tensor * wq = nullptr; + struct ggml_tensor * wk = nullptr; + struct ggml_tensor * wv = nullptr; + struct ggml_tensor * wo = nullptr; + struct ggml_tensor * wqkv = nullptr; + struct ggml_tensor * wq_a = nullptr; + struct ggml_tensor * wq_b = nullptr; + struct ggml_tensor * wkv_a_mqa = nullptr; + struct ggml_tensor * wkv_b = nullptr; + struct ggml_tensor * wq_cross = nullptr; + struct ggml_tensor * wk_cross = nullptr; + struct ggml_tensor * wv_cross = nullptr; + struct ggml_tensor * wo_cross = nullptr; + struct ggml_tensor * wq_enc = nullptr; + struct ggml_tensor * wk_enc = nullptr; + struct ggml_tensor * wv_enc = nullptr; + struct ggml_tensor * wo_enc = nullptr; + + // attention bias + struct ggml_tensor * bq = nullptr; + struct ggml_tensor * bk = nullptr; + struct ggml_tensor * bv = nullptr; + struct ggml_tensor * bo = nullptr; + struct ggml_tensor * bqkv = nullptr; + + // relative position bias + struct ggml_tensor * attn_rel_b = nullptr; + struct ggml_tensor * attn_rel_b_enc = nullptr; + struct ggml_tensor * attn_rel_b_cross = nullptr; + + // normalization + struct ggml_tensor * ffn_norm = nullptr; + struct ggml_tensor * ffn_norm_b = nullptr; + struct ggml_tensor * ffn_post_norm = nullptr; + struct ggml_tensor * layer_out_norm = nullptr; + struct ggml_tensor * layer_out_norm_b = nullptr; + struct ggml_tensor * ffn_norm_exps = nullptr; + struct ggml_tensor * ffn_norm_enc = nullptr; + + // ff + struct ggml_tensor * ffn_gate = nullptr; // w1 + struct ggml_tensor * ffn_down = nullptr; // w2 + struct ggml_tensor * ffn_up = nullptr; // w3 + struct ggml_tensor * ffn_gate_enc = nullptr; + struct ggml_tensor * ffn_down_enc = nullptr; + struct ggml_tensor * ffn_up_enc = nullptr; + + // ff MoE + struct ggml_tensor * ffn_gate_inp = nullptr; + struct ggml_tensor * ffn_gate_exps = nullptr; + struct ggml_tensor * ffn_down_exps = nullptr; + struct ggml_tensor * ffn_up_exps = nullptr; + + // ff shared expert (shexp) + struct ggml_tensor * ffn_gate_inp_shexp = nullptr; + struct ggml_tensor * ffn_gate_shexp = nullptr; + struct ggml_tensor * ffn_down_shexp = nullptr; + struct ggml_tensor * ffn_up_shexp = nullptr; + + // ff bias + struct ggml_tensor * ffn_gate_b = nullptr; + struct ggml_tensor * ffn_down_b = nullptr; // b2 + struct ggml_tensor * ffn_up_b = nullptr; // b3 + struct ggml_tensor * ffn_act = nullptr; + struct ggml_tensor * ffn_exp_probs_b = nullptr; + + // mamba proj + struct ggml_tensor * ssm_in = nullptr; + struct ggml_tensor * ssm_x = nullptr; + struct ggml_tensor * ssm_dt = nullptr; + struct ggml_tensor * ssm_out = nullptr; + + // mamba + struct ggml_tensor * ssm_conv1d = nullptr; + struct ggml_tensor * ssm_a = nullptr; + struct ggml_tensor * ssm_d = nullptr; + + // mamba bias + struct ggml_tensor * ssm_conv1d_b = nullptr; + struct ggml_tensor * ssm_dt_b = nullptr; + + // rwkv + struct ggml_tensor * time_mix_w1 = nullptr; + struct ggml_tensor * time_mix_w2 = nullptr; + struct ggml_tensor * time_mix_lerp_x = nullptr; + struct ggml_tensor * time_mix_lerp_w = nullptr; + struct ggml_tensor * time_mix_lerp_k = nullptr; + struct ggml_tensor * time_mix_lerp_v = nullptr; + struct ggml_tensor * time_mix_lerp_r = nullptr; + struct ggml_tensor * time_mix_lerp_g = nullptr; + struct ggml_tensor * time_mix_lerp_fused = nullptr; + + struct ggml_tensor * time_mix_first = nullptr; + struct ggml_tensor * time_mix_decay = nullptr; + struct ggml_tensor * time_mix_decay_w1 = nullptr; + struct ggml_tensor * time_mix_decay_w2 = nullptr; + struct ggml_tensor * time_mix_key = nullptr; + struct ggml_tensor * time_mix_key_b = nullptr; + struct ggml_tensor * time_mix_value = nullptr; + struct ggml_tensor * time_mix_value_b = nullptr; + struct ggml_tensor * time_mix_receptance = nullptr; + struct ggml_tensor * time_mix_receptance_b = nullptr; + struct ggml_tensor * time_mix_gate = nullptr; + + struct ggml_tensor * time_mix_ln = nullptr; + struct ggml_tensor * time_mix_ln_b = nullptr; + struct ggml_tensor * time_mix_output = nullptr; + + struct ggml_tensor * channel_mix_lerp_k = nullptr; + struct ggml_tensor * channel_mix_lerp_r = nullptr; + + struct ggml_tensor * channel_mix_key = nullptr; + struct ggml_tensor * channel_mix_receptance = nullptr; + struct ggml_tensor * channel_mix_value = nullptr; + + // long rope factors + struct ggml_tensor * rope_long = nullptr; + struct ggml_tensor * rope_short = nullptr; + struct ggml_tensor * rope_freqs = nullptr; + + // bitnet scale + struct ggml_tensor * wq_scale = nullptr; + struct ggml_tensor * wk_scale = nullptr; + struct ggml_tensor * wv_scale = nullptr; + struct ggml_tensor * wo_scale = nullptr; + struct ggml_tensor * ffn_gate_scale = nullptr; + struct ggml_tensor * ffn_up_scale = nullptr; + struct ggml_tensor * ffn_down_scale = nullptr; + + struct llama_layer_posnet posnet; + + struct llama_layer_convnext convnext; +}; + +struct llama_model { + llm_type type = LLM_TYPE_UNKNOWN; + llm_arch arch = LLM_ARCH_UNKNOWN; + + std::string name = "n/a"; + + llama_hparams hparams = {}; + llama_vocab vocab; + + struct ggml_tensor * tok_embd = nullptr; + struct ggml_tensor * type_embd = nullptr; + struct ggml_tensor * pos_embd = nullptr; + struct ggml_tensor * tok_norm = nullptr; + struct ggml_tensor * tok_norm_b = nullptr; + + struct ggml_tensor * output_norm = nullptr; + struct ggml_tensor * output_norm_b = nullptr; + struct ggml_tensor * output = nullptr; + struct ggml_tensor * output_b = nullptr; + struct ggml_tensor * output_norm_enc = nullptr; + + // classifier + struct ggml_tensor * cls = nullptr; + struct ggml_tensor * cls_b = nullptr; + struct ggml_tensor * cls_out = nullptr; + struct ggml_tensor * cls_out_b = nullptr; + + struct ggml_tensor * conv1d = nullptr; + struct ggml_tensor * conv1d_b = nullptr; + + std::vector layers; + + llama_model_params params; + + // gguf metadata + std::unordered_map gguf_kv; + + // list of devices used in this model + std::vector devices; + + // for quantize-stats only + std::vector> tensors_by_name; + + int64_t t_load_us = 0; + int64_t t_start_us = 0; + + explicit llama_model(const struct llama_model_params & params); + ~llama_model(); + + void load_stats (llama_model_loader & ml); + void load_arch (llama_model_loader & ml); + void load_hparams(llama_model_loader & ml); + void load_vocab (llama_model_loader & ml); + bool load_tensors(llama_model_loader & ml); // returns false if cancelled by progress_callback + + std::string arch_name() const; + std::string type_name() const; + + std::string desc() const; + + size_t size() const; + size_t max_nodes() const; + size_t n_devices() const; + + // total number of parameters in the model + uint64_t n_elements() const; + + void print_info() const; + + ggml_backend_dev_t dev_layer(int il) const; + ggml_backend_dev_t dev_output() const; + + ggml_backend_buffer_type_t select_buft(int il) const; + + const struct ggml_tensor * get_tensor(const char * name) const; + +private: + struct impl; + std::unique_ptr pimpl; +}; + +const char * llm_type_name(llm_type type); diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp new file mode 100644 index 000000000..fb7982655 --- /dev/null +++ b/src/llama-quant.cpp @@ -0,0 +1,934 @@ +#include "llama-quant.h" + +#include "llama-impl.h" +#include "llama-model.h" +#include "llama-model-loader.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +static void zeros(std::ofstream & file, size_t n) { + char zero = 0; + for (size_t i = 0; i < n; ++i) { + file.write(&zero, 1); + } +} + +struct quantize_state_impl { + const llama_model & model; + const llama_model_quantize_params * params; + + int n_attention_wv = 0; + int n_ffn_down = 0; + int n_ffn_gate = 0; + int n_ffn_up = 0; + int i_attention_wv = 0; + int i_ffn_down = 0; + int i_ffn_gate = 0; + int i_ffn_up = 0; + + int n_k_quantized = 0; + int n_fallback = 0; + + bool has_imatrix = false; + + // used to figure out if a model shares tok_embd with the output weight + bool has_output = false; + + quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params) + : model(model) + , params(params) + {} +}; + +static void llama_tensor_dequantize_impl( + struct ggml_tensor * tensor, std::vector> & output, std::vector & workers, + const size_t nelements, const int nthread +) { + if (output.size() < nelements) { + output.resize(nelements); + } + float * f32_output = (float *) output.data(); + + const ggml_type_traits * qtype = ggml_get_type_traits(tensor->type); + if (ggml_is_quantized(tensor->type)) { + if (qtype->to_float == NULL) { + throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type))); + } + } else if (tensor->type != GGML_TYPE_F16 && + tensor->type != GGML_TYPE_BF16) { + throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type))); + } + + if (nthread < 2) { + if (tensor->type == GGML_TYPE_F16) { + ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements); + } else if (tensor->type == GGML_TYPE_BF16) { + ggml_bf16_to_fp32_row((ggml_bf16_t *)tensor->data, f32_output, nelements); + } else if (ggml_is_quantized(tensor->type)) { + qtype->to_float(tensor->data, f32_output, nelements); + } else { + GGML_ABORT("fatal error"); // unreachable + } + return; + } + + size_t block_size; + if (tensor->type == GGML_TYPE_F16 || + tensor->type == GGML_TYPE_BF16) { + block_size = 1; + } else { + block_size = (size_t)ggml_blck_size(tensor->type); + } + + size_t block_size_bytes = ggml_type_size(tensor->type); + + GGML_ASSERT(nelements % block_size == 0); + size_t nblocks = nelements / block_size; + size_t blocks_per_thread = nblocks / nthread; + size_t spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count + + size_t in_buff_offs = 0; + size_t out_buff_offs = 0; + + for (int tnum = 0; tnum < nthread; tnum++) { + size_t thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread + size_t thr_elems = thr_blocks * block_size; // number of elements for this thread + size_t thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread + + auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) { + if (typ == GGML_TYPE_F16) { + ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels); + } else if (typ == GGML_TYPE_BF16) { + ggml_bf16_to_fp32_row((ggml_bf16_t *)inbuf, outbuf, nels); + } else { + qtype->to_float(inbuf, outbuf, nels); + } + }; + workers.emplace_back(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems); + in_buff_offs += thr_block_bytes; + out_buff_offs += thr_elems; + } + for (auto & w : workers) { w.join(); } + workers.clear(); +} + +static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) { + const std::string name = ggml_get_name(tensor); + + // TODO: avoid hardcoded tensor names - use the TN_* constants + const llm_arch arch = qs.model.arch; + const auto tn = LLM_TN(arch); + + auto use_more_bits = [](int i_layer, int n_layers) -> bool { + return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2; + }; + const int n_expert = std::max(1, (int)qs.model.hparams.n_expert); + auto layer_info = [n_expert] (int i_layer, int n_layer, const char * name) { + if (n_expert > 1) { + // Believe it or not, "experts" in the FFN of Mixtral-8x7B are not consecutive, but occasionally randomly + // sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work + // for getting the current layer as I initially thought, and we need to resort to parsing the + // tensor name. + if (sscanf(name, "blk.%d.", &i_layer) != 1) { + throw std::runtime_error(format("Failed to determine layer for tensor %s", name)); + } + if (i_layer < 0 || i_layer >= n_layer) { + throw std::runtime_error(format("Bad layer %d for tensor %s. Must be in [0, %d)", i_layer, name, n_layer)); + } + } + return std::make_pair(i_layer, n_layer); + }; + + // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings + // with the quantization of the output tensor + if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) { + if (qs.params->output_tensor_type < GGML_TYPE_COUNT) { + new_type = qs.params->output_tensor_type; + } else { + const int64_t nx = tensor->ne[0]; + const int64_t qk_k = ggml_blck_size(new_type); + + if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) { + new_type = GGML_TYPE_Q8_0; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || + ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || + ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) { + new_type = GGML_TYPE_Q5_K; + } + else if (new_type != GGML_TYPE_Q8_0) { + new_type = GGML_TYPE_Q6_K; + } + } + } else if (name == "token_embd.weight") { + if (qs.params->token_embedding_type < GGML_TYPE_COUNT) { + new_type = qs.params->token_embedding_type; + } else { + if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || + ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) { + new_type = GGML_TYPE_Q2_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) { + new_type = GGML_TYPE_IQ3_S; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { + new_type = GGML_TYPE_IQ3_S; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) { + new_type = GGML_TYPE_Q4_K; + } + } + } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || + ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) { + if (name.find("attn_v.weight") != std::string::npos) { + if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K; + else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; + ++qs.i_attention_wv; + } + else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) { + new_type = GGML_TYPE_Q4_K; + } + else if (name.find("ffn_down") != std::string::npos) { + if (qs.i_ffn_down < qs.n_ffn_down/8) { + new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; + } + ++qs.i_ffn_down; + } + else if (name.find("attn_output.weight") != std::string::npos) { + if (qs.model.hparams.n_expert == 8) { + new_type = GGML_TYPE_Q5_K; + } else { + if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_XXS; + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S; + } + } + } else if (name.find("attn_v.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) { + new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) { + new_type = GGML_TYPE_Q4_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { + new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS; + } + else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S) && qs.model.hparams.n_gqa() >= 4) { + new_type = GGML_TYPE_Q4_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) { + new_type = GGML_TYPE_Q4_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { + new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; + else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 4) { + new_type = GGML_TYPE_Q5_K; + } + else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && + use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K; + if (qs.model.type == LLM_TYPE_70B) { + // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is + // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with + // nearly negligible increase in model size by quantizing this tensor with more bits: + if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K; + } + if (qs.model.hparams.n_expert == 8) { + // for the 8-expert model, bumping this to Q8_0 trades just ~128MB + // TODO: explore better strategies + new_type = GGML_TYPE_Q8_0; + } + ++qs.i_attention_wv; + } else if (name.find("attn_k.weight") != std::string::npos) { + if (qs.model.hparams.n_expert == 8) { + // for the 8-expert model, bumping this to Q8_0 trades just ~128MB + // TODO: explore better strategies + new_type = GGML_TYPE_Q8_0; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) { + new_type = GGML_TYPE_IQ3_XXS; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { + new_type = GGML_TYPE_IQ2_S; + } + } else if (name.find("attn_q.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) { + new_type = GGML_TYPE_IQ3_XXS; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { + new_type = GGML_TYPE_IQ2_S; + } + } else if (name.find("ffn_down") != std::string::npos) { + auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str()); + int i_layer = info.first, n_layer = info.second; + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) { + if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && !qs.has_imatrix) { + new_type = i_layer < n_layer/8 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) { + new_type = i_layer < n_layer/16 ? GGML_TYPE_Q5_K + : arch != LLM_ARCH_FALCON || use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q4_K + : GGML_TYPE_Q3_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M && (i_layer < n_layer/8 || + (qs.model.hparams.n_expert == 8 && use_more_bits(i_layer, n_layer)))) { + new_type = GGML_TYPE_Q4_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) { + new_type = arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) { + if (arch == LLM_ARCH_FALCON) { + new_type = i_layer < n_layer/16 ? GGML_TYPE_Q6_K : + use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K; + } else { + if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K; + } + } + else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && !qs.has_imatrix) { + new_type = GGML_TYPE_Q5_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) { + new_type = GGML_TYPE_Q5_K; + } + else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || ftype == LLAMA_FTYPE_MOSTLY_Q5_0) + && qs.has_imatrix && i_layer < n_layer/8) { + // Guard against craziness in the first few ffn_down layers that can happen even with imatrix for Q4_0/Q5_0. + // We only do it when an imatrix is provided because a) we want to make sure that one can always get the + // same quantization as before imatrix stuff, and b) Q4_1/Q5_1 do go crazy on ffn_down without an imatrix. + new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1; + } + ++qs.i_ffn_down; + } else if (name.find("attn_output.weight") != std::string::npos) { + if (arch != LLM_ARCH_FALCON) { + if (qs.model.hparams.n_expert == 8) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || + ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || + ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S || + ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) { + new_type = GGML_TYPE_Q5_K; + } + } else { + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K ) new_type = GGML_TYPE_Q3_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_IQ3_S; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M ) new_type = GGML_TYPE_Q4_K; + } + } else { + if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K; + } + } + else if (name.find("attn_qkv.weight") != std::string::npos) { + if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) { + new_type = GGML_TYPE_Q4_K; + } + else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K; + else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K; + } + else if (name.find("ffn_gate") != std::string::npos) { + auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str()); + int i_layer = info.first, n_layer = info.second; + if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) { + new_type = GGML_TYPE_IQ3_XXS; + } + ++qs.i_ffn_gate; + } + else if (name.find("ffn_up") != std::string::npos) { + auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str()); + int i_layer = info.first, n_layer = info.second; + if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) { + new_type = GGML_TYPE_IQ3_XXS; + } + ++qs.i_ffn_up; + } + + // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; + //} + // IK: let's remove this, else Q2_K is almost the same as Q3_K_S + //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) { + // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; + //} + // This can be used to reduce the size of the Q5_K_S model. + // The associated PPL increase is fully in line with the size reduction + //else { + // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K; + //} + bool convert_incompatible_tensor = false; + { + const int64_t nx = tensor->ne[0]; + const int64_t ny = tensor->ne[1]; + const int64_t qk_k = ggml_blck_size(new_type); + + if (nx % qk_k != 0) { + LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type)); + convert_incompatible_tensor = true; + } else { + ++qs.n_k_quantized; + } + } + + if (convert_incompatible_tensor) { + switch (new_type) { + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; + case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; + case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; + case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; + default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); + } + if (tensor->ne[0] % ggml_blck_size(new_type) != 0) { + new_type = GGML_TYPE_F16; + } + LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); + ++qs.n_fallback; + } + + return new_type; +} + +static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector & workers, const int nthread) { + if (nthread < 2) { + // single-thread + size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix); + if (!ggml_validate_row_data(new_type, new_data, new_size)) { + throw std::runtime_error("quantized data validation failed"); + } + return new_size; + } + + std::mutex mutex; + int64_t counter = 0; + size_t new_size = 0; + bool valid = true; + auto compute = [&mutex, &counter, &new_size, &valid, new_type, f32_data, new_data, chunk_size, + nrows, n_per_row, imatrix]() { + const int64_t nrows_per_chunk = chunk_size / n_per_row; + size_t local_size = 0; + while (true) { + std::unique_lock lock(mutex); + int64_t first_row = counter; counter += nrows_per_chunk; + if (first_row >= nrows) { + if (local_size > 0) { + new_size += local_size; + } + break; + } + lock.unlock(); + const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk); + size_t this_size = ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix); + local_size += this_size; + + // validate the quantized data + const size_t row_size = ggml_row_size(new_type, n_per_row); + void * this_data = (char *) new_data + first_row * row_size; + if (!ggml_validate_row_data(new_type, this_data, this_size)) { + std::unique_lock lock(mutex); + valid = false; + break; + } + } + }; + for (int it = 0; it < nthread - 1; ++it) { + workers.emplace_back(compute); + } + compute(); + for (auto & w : workers) { w.join(); } + workers.clear(); + if (!valid) { + throw std::runtime_error("quantized data validation failed"); + } + return new_size; +} + +static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { + ggml_type default_type; + llama_ftype ftype = params->ftype; + + switch (params->ftype) { + case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break; + case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break; + case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break; + case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break; + case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break; + case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; + case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; + case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break; + + // K-quants + case LLAMA_FTYPE_MOSTLY_Q2_K_S: + case LLAMA_FTYPE_MOSTLY_Q2_K: default_type = GGML_TYPE_Q2_K; break; + case LLAMA_FTYPE_MOSTLY_IQ3_XS: default_type = GGML_TYPE_IQ3_S; break; + case LLAMA_FTYPE_MOSTLY_Q3_K_S: + case LLAMA_FTYPE_MOSTLY_Q3_K_M: + case LLAMA_FTYPE_MOSTLY_Q3_K_L: default_type = GGML_TYPE_Q3_K; break; + case LLAMA_FTYPE_MOSTLY_Q4_K_S: + case LLAMA_FTYPE_MOSTLY_Q4_K_M: default_type = GGML_TYPE_Q4_K; break; + case LLAMA_FTYPE_MOSTLY_Q5_K_S: + case LLAMA_FTYPE_MOSTLY_Q5_K_M: default_type = GGML_TYPE_Q5_K; break; + case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break; + case LLAMA_FTYPE_MOSTLY_TQ1_0: default_type = GGML_TYPE_TQ1_0; break; + case LLAMA_FTYPE_MOSTLY_TQ2_0: default_type = GGML_TYPE_TQ2_0; break; + case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break; + case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break; + case LLAMA_FTYPE_MOSTLY_IQ2_S: default_type = GGML_TYPE_IQ2_XS; break; + case LLAMA_FTYPE_MOSTLY_IQ2_M: default_type = GGML_TYPE_IQ2_S; break; + case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break; + case LLAMA_FTYPE_MOSTLY_IQ1_S: default_type = GGML_TYPE_IQ1_S; break; + case LLAMA_FTYPE_MOSTLY_IQ1_M: default_type = GGML_TYPE_IQ1_M; break; + case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break; + case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; + case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break; + case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break; + + default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); + } + + int nthread = params->nthread; + + if (nthread <= 0) { + nthread = std::thread::hardware_concurrency(); + } + + // mmap consistently increases speed Linux, and also increases speed on Windows with + // hot cache. It may cause a slowdown on macOS, possibly related to free memory. +#if defined(__linux__) || defined(_WIN32) + constexpr bool use_mmap = true; +#else + constexpr bool use_mmap = false; +#endif + + llama_model_kv_override * kv_overrides = nullptr; + if (params->kv_overrides) { + auto v = (std::vector*)params->kv_overrides; + kv_overrides = v->data(); + } + + std::vector splits = {}; + llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides); + ml.init_mappings(false); // no prefetching + + llama_model model(llama_model_default_params()); + + model.load_arch (ml); + model.load_hparams(ml); + model.load_stats (ml); + + struct quantize_state_impl qs(model, params); + + if (params->only_copy) { + ftype = ml.ftype; + } + const std::unordered_map> * imatrix_data = nullptr; + if (params->imatrix) { + imatrix_data = static_cast>*>(params->imatrix); + if (imatrix_data) { + LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size())); + qs.has_imatrix = true; + // check imatrix for nans or infs + for (const auto & kv : *imatrix_data) { + for (float f : kv.second) { + if (!std::isfinite(f)) { + throw std::runtime_error(format("imatrix contains non-finite value %f\n", f)); + } + } + } + } + } + + const size_t align = GGUF_DEFAULT_ALIGNMENT; + gguf_context_ptr ctx_out { gguf_init_empty() }; + + // copy the KV pairs from the input file + gguf_set_kv (ctx_out.get(), ml.meta.get()); + gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV + gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV + + // Remove split metadata + gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str()); + gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str()); + gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str()); + + if (params->kv_overrides) { + const std::vector & overrides = *(const std::vector *)params->kv_overrides; + for (const auto & o : overrides) { + if (o.key[0] == 0) break; + if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) { + gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64); + } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) { + gguf_set_val_i32(ctx_out.get(), o.key, o.val_i64); + } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) { + gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool); + } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) { + gguf_set_val_str(ctx_out.get(), o.key, o.val_str); + } else { + LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key); + } + } + } + + // make a list of weights + std::vector tensors; + tensors.reserve(ml.weights_map.size()); + for (const auto & it : ml.weights_map) { + tensors.push_back(&it.second); + } + + // keep_split requires that the weights are sorted by split index + if (params->keep_split) { + std::sort(tensors.begin(), tensors.end(), [](const llama_model_loader::llama_tensor_weight * a, const llama_model_loader::llama_tensor_weight * b) { + if (a->idx == b->idx) { + return a->offs < b->offs; + } + return a->idx < b->idx; + }); + } + + for (const auto * it : tensors) { + const struct ggml_tensor * tensor = it->tensor; + + const std::string name = ggml_get_name(tensor); + + // TODO: avoid hardcoded tensor names - use the TN_* constants + if (name.find("attn_v.weight") != std::string::npos || + name.find("attn_qkv.weight") != std::string::npos || + name.find("attn_kv_b.weight")!= std::string::npos) { + ++qs.n_attention_wv; + } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) { + qs.has_output = true; + } + } + + qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer; + + // sanity checks for models that have attention layers + if (qs.n_attention_wv != 0) + { + const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin(); + // attention layers have a non-zero number of kv heads + int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0); + if (llama_model_has_encoder(&model)) { + n_attn_layer *= 3; + } + GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected"); + } + + size_t total_size_org = 0; + size_t total_size_new = 0; + + std::vector workers; + workers.reserve(nthread); + + int idx = 0; + + std::vector> read_data; + std::vector> work; + std::vector> f32_conv_buf; + + uint16_t n_split = 1; + + // Assume split index is continuous + if (params->keep_split) { + for (const auto * it : tensors) { + n_split = std::max(uint16_t(it->idx + 1), n_split); + } + } + std::vector ctx_outs(n_split); + ctx_outs[0] = std::move(ctx_out); + + // populate the original tensors so we get an initial meta data + for (const auto * it : tensors) { + uint16_t i_split = params->keep_split ? it->idx : 0; + struct ggml_tensor * tensor = it->tensor; + if (!ctx_outs[i_split]) { + ctx_outs[i_split].reset(gguf_init_empty()); + } + gguf_add_tensor(ctx_outs[i_split].get(), tensor); + } + + // Set split info if needed + if (n_split > 1) { + for (size_t i = 0; i < ctx_outs.size(); ++i) { + gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i); + gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split); + gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors); + } + } + + int cur_split = -1; + std::ofstream fout; + auto close_ofstream = [&]() { + // Write metadata and close file handler + if (fout.is_open()) { + fout.seekp(0); + std::vector data(gguf_get_meta_size(ctx_outs[cur_split].get())); + gguf_get_meta_data(ctx_outs[cur_split].get(), data.data()); + fout.write((const char *) data.data(), data.size()); + fout.close(); + } + }; + auto new_ofstream = [&](int index) { + cur_split = index; + GGML_ASSERT(ctx_outs[cur_split] && "Find uninitialized gguf_context"); + std::string fname = fname_out; + if (params->keep_split) { + std::vector split_path(llama_path_max(), 0); + llama_split_path(split_path.data(), split_path.size(), fname_out.c_str(), cur_split, n_split); + fname = std::string(split_path.data()); + } + + fout = std::ofstream(fname, std::ios::binary); + fout.exceptions(std::ofstream::failbit); // fail fast on write errors + const size_t meta_size = gguf_get_meta_size(ctx_outs[cur_split].get()); + // placeholder for the meta data + ::zeros(fout, meta_size); + }; + + const auto tn = LLM_TN(model.arch); + new_ofstream(0); + for (const auto * it : tensors) { + const auto & weight = *it; + struct ggml_tensor * tensor = weight.tensor; + if (weight.idx != cur_split && params->keep_split) { + close_ofstream(); + new_ofstream(weight.idx); + } + + const std::string name = ggml_get_name(tensor); + + if (!ml.use_mmap) { + if (read_data.size() < ggml_nbytes(tensor)) { + read_data.resize(ggml_nbytes(tensor)); + } + tensor->data = read_data.data(); + } + ml.load_data_for(tensor); + + LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ", + ++idx, ml.n_tensors, + ggml_get_name(tensor), + llama_format_tensor_shape(tensor).c_str(), + ggml_type_name(tensor->type)); + + // This used to be a regex, but has an extreme cost to compile times. + bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? + + // quantize only 2D and 3D tensors (experts) + quantize &= (ggml_n_dims(tensor) >= 2); + + // do not quantize norm tensors + quantize &= name.find("_norm.weight") == std::string::npos; + + quantize &= params->quantize_output_tensor || name != "output.weight"; + quantize &= !params->only_copy; + + // do not quantize expert gating tensors + // NOTE: can't use LLM_TN here because the layer number is not known + quantize &= name.find("ffn_gate_inp.weight") == std::string::npos; + + // do not quantize positional embeddings and token types (BERT) + quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight"); + quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight"); + + // do not quantize Mamba's small yet 2D weights + // NOTE: can't use LLM_TN here because the layer number is not known + quantize &= name.find("ssm_conv1d.weight") == std::string::npos; + + // do not quantize RWKV's time_mix_first tensors + quantize &= name.find("time_mix_first.weight") == std::string::npos; + quantize &= name.find("time_mix_w1.weight") == std::string::npos; + quantize &= name.find("time_mix_w2.weight") == std::string::npos; + quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos; + quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos; + quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos; + + // do not quantize relative position bias (T5) + quantize &= name.find("attn_rel_b.weight") == std::string::npos; + + enum ggml_type new_type; + void * new_data; + size_t new_size; + + if (quantize) { + new_type = default_type; + + // get more optimal quantization type based on the tensor shape, layer, etc. + if (!params->pure && ggml_is_quantized(default_type)) { + new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); + } + if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { + new_type = params->token_embedding_type; + } + if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) { + new_type = params->output_tensor_type; + } + + // If we've decided to quantize to the same type the tensor is already + // in then there's nothing to do. + quantize = tensor->type != new_type; + } + + if (!quantize) { + new_type = tensor->type; + new_data = tensor->data; + new_size = ggml_nbytes(tensor); + LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0); + } else { + const int64_t nelements = ggml_nelements(tensor); + + const float * imatrix = nullptr; + if (imatrix_data) { + auto it = imatrix_data->find(tensor->name); + if (it == imatrix_data->end()) { + LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name); + } else { + if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) { + imatrix = it->second.data(); + } else { + LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__, + int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name); + + // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix + // this is a significant error and it may be good idea to abort the process if this happens, + // since many people will miss the error and not realize that most of the model is being quantized without an imatrix + // tok_embd should be ignored in this case, since it always causes this warning + if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) { + throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s", + int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name)); + } + } + } + } + if ((new_type == GGML_TYPE_IQ2_XXS || + new_type == GGML_TYPE_IQ2_XS || + new_type == GGML_TYPE_IQ2_S || + new_type == GGML_TYPE_IQ1_S || + (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight")) || + (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) { + LLAMA_LOG_ERROR("\n\n============================================================\n"); + LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name); + LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n"); + LLAMA_LOG_ERROR("============================================================\n\n"); + throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name)); + } + + float * f32_data; + + if (tensor->type == GGML_TYPE_F32) { + f32_data = (float *) tensor->data; + } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) { + throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type))); + } else { + llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread); + f32_data = (float *) f32_conv_buf.data(); + } + + LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type)); + fflush(stdout); + + if (work.size() < (size_t)nelements * 4) { + work.resize(nelements * 4); // upper bound on size + } + new_data = work.data(); + + const int64_t n_per_row = tensor->ne[0]; + const int64_t nrows = tensor->ne[1]; + + static const int64_t min_chunk_size = 32 * 512; + const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)); + + const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1]; + const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size; + const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1; + + // quantize each expert separately since they have different importance matrices + new_size = 0; + for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) { + const float * f32_data_03 = f32_data + i03 * nelements_matrix; + void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows; + const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr; + + new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use); + } + LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); + } + total_size_org += ggml_nbytes(tensor); + total_size_new += new_size; + + // update the gguf meta data as we go + gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type); + GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size); + gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data); + + // write tensor data + padding + fout.write((const char *) new_data, new_size); + zeros(fout, GGML_PAD(new_size, align) - new_size); + } + close_ofstream(); + + LLAMA_LOG_INFO("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); + LLAMA_LOG_INFO("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0); + + if (qs.n_fallback > 0) { + LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n", + __func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback); + } +} + +// +// interface implementation +// + +struct llama_model_quantize_params llama_model_quantize_default_params() { + struct llama_model_quantize_params result = { + /*.nthread =*/ 0, + /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1, + /*.output_tensor_type =*/ GGML_TYPE_COUNT, + /*.token_embedding_type =*/ GGML_TYPE_COUNT, + /*.allow_requantize =*/ false, + /*.quantize_output_tensor =*/ true, + /*.only_copy =*/ false, + /*.pure =*/ false, + /*.keep_split =*/ false, + /*.imatrix =*/ nullptr, + /*.kv_overrides =*/ nullptr, + }; + + return result; +} + +uint32_t llama_model_quantize( + const char * fname_inp, + const char * fname_out, + const llama_model_quantize_params * params) { + try { + llama_model_quantize_impl(fname_inp, fname_out, params); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what()); + return 1; + } + + return 0; +} diff --git a/src/llama-quant.h b/src/llama-quant.h new file mode 100644 index 000000000..6f70f09be --- /dev/null +++ b/src/llama-quant.h @@ -0,0 +1 @@ +#pragma once diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 41f48ec28..26974f539 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1,56 +1,166 @@ #include "llama-sampling.h" +#include "llama-impl.h" #include "llama-vocab.h" #include "llama-grammar.h" -#include #include +#include +#include +#include +#include +#include #include #include -#include #include #include #include +#include -static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng, std::vector & probs) { -#if 1 - probs.resize(cur_p->size); - for (size_t i = 0; i < cur_p->size; ++i) { - probs[i] = cur_p->data[i].p; +// the ring buffer works similarly to std::deque, but with a fixed capacity +template +struct ring_buffer { + ring_buffer(size_t cap) : capacity(cap), data(cap) {} + + T & front() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[first]; } - std::discrete_distribution dist(probs.begin(), probs.end()); -#else - // avoid the copy with a custom iterator + const T & front() const { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[first]; + } + + T & back() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[pos]; + } + + const T & back() const { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[pos]; + } + + void push_back(const T & value) { + if (capacity == 0) { + throw std::runtime_error("ring buffer: capacity is zero"); + } + + if (sz == capacity) { + // advance the start when buffer is full + first = (first + 1) % capacity; + } else { + sz++; + } + data[pos] = value; + pos = (pos + 1) % capacity; + } + + T pop_front() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + T value = data[first]; + first = (first + 1) % capacity; + sz--; + return value; + } + + //T & operator[](size_t i) { + // if (i >= sz) { + // throw std::runtime_error("ring buffer: index out of bounds"); + // } + // return data[(first + i) % capacity]; + //} + + //const T & at(size_t i) const { + // if (i >= sz) { + // throw std::runtime_error("ring buffer: index out of bounds"); + // } + // return data[(first + i) % capacity]; + //} + + const T & rat(size_t i) const { + if (i >= sz) { + throw std::runtime_error("ring buffer: index out of bounds"); + } + return data[(first + sz - i - 1) % capacity]; + } + + std::vector to_vector() const { + std::vector result; + result.reserve(sz); + for (size_t i = 0; i < sz; i++) { + result.push_back(data[(first + i) % capacity]); + } + return result; + } + + void clear() { + // here only reset the status of the buffer + sz = 0; + first = 0; + pos = 0; + } + + bool empty() const { + return sz == 0; + } + + size_t size() const { + return sz; + } + + size_t capacity = 0; + size_t sz = 0; + size_t first = 0; + size_t pos = 0; + + std::vector data; +}; + +static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) { + // iterator for the probabilities +#ifdef __GNUC__ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-local-typedefs" +#endif struct probs_iterator { typedef std::input_iterator_tag iterator_category; typedef float value_type; typedef float * pointer; typedef float & reference; - typedef size_t difference_type; + typedef ptrdiff_t difference_type; - const llama_token_data_array * data; - size_t i; + const llama_token_data * data; - bool operator==(const probs_iterator & other) const { return data + i == other.data + other.i; } - bool operator!=(const probs_iterator & other) const { return data + i != other.data + other.i; } - float operator*() const { return data->data[i].p; } - probs_iterator & operator++() { ++i; return *this; } - probs_iterator operator++(int) { probs_iterator tmp = *this; ++i; return tmp; } + bool operator==(const probs_iterator & other) const { return data == other.data; } + bool operator!=(const probs_iterator & other) const { return data != other.data; } + const float & operator*() const { return data->p; } + probs_iterator & operator++() { ++data; return *this; } + probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; } }; + +#ifdef __GNUC__ #pragma GCC diagnostic pop - - std::discrete_distribution dist(probs_iterator{cur_p, 0}, probs_iterator{cur_p, cur_p->size}); - - GGML_UNUSED(probs); #endif + std::discrete_distribution dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size}); + return dist(rng); } +/* static void llama_log_softmax(float * array, size_t size) { float max_l = *std::max_element(array, array + size); float sum = 0.f; @@ -64,6 +174,31 @@ static void llama_log_softmax(float * array, size_t size) { array[i] = logf(array[i] / sum); } } +*/ + +static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) { + if (temp <= 0.0f) { + // find the token with the highest logit and set the rest to -inf + size_t max_i = 0; + float max_l = cur_p->data[0].logit; + + for (size_t i = 1; i < cur_p->size; ++i) { + if (cur_p->data[i ].logit > max_l) { + cur_p->data[max_i].logit = -INFINITY; + max_i = i; + max_l = cur_p->data[i].logit; + } else { + cur_p->data[i].logit = -INFINITY; + } + } + + return; + } + + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].logit /= temp; + } +} static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) { GGML_ASSERT(cur_p->size > 0); @@ -91,7 +226,7 @@ static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) { } static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) { - // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast + // TODO: move bucket sort to separate function so that top_p/typical/softmax first is equally fast // if (k >= (int32_t)cur_p->size) { // return; // } @@ -122,7 +257,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) for (int i = 0; i < (int)cur_p->size; ++i) { const float val = cur_p->data[i].logit; int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low); - ib = std::max(0, std::min(nbuckets-1, ib)); + ib = std::max(0, std::min(nbuckets - 1, ib)); bucket_idx[i] = ib; ++histo[ib]; } @@ -145,13 +280,13 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) for (int i = 0; i < (int)cur_p->size; ++i) { int j = bucket_idx[i]; if (j >= ib) { - *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i]; + *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i]; } } ptr = tmp_tokens.data(); int ndone = 0; - for (int j = nbuckets-1; j > ib; --j) { + for (int j = nbuckets - 1; j > ib; --j) { std::sort(ptr, ptr + histo[j], comp); ptr += histo[j]; ndone += histo[j]; @@ -166,6 +301,19 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) cur_p->size = k; } +static uint32_t get_rng_seed(uint32_t seed) { + if (seed == LLAMA_DEFAULT_SEED) { + // use system clock if std::random_device is not a true RNG + static bool is_rd_prng = std::random_device().entropy() == 0; + if (is_rd_prng) { + return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count(); + } + std::random_device rd; + return rd(); + } + return seed; +} + // llama_sampler API const char * llama_sampler_name(const struct llama_sampler * smpl) { @@ -223,75 +371,104 @@ void llama_sampler_free(struct llama_sampler * smpl) { llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { const auto * logits = llama_get_logits_ith(ctx, idx); - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_vocab = llama_vocab_n_tokens(vocab); // TODO: do not allocate each time - std::vector cur(n_vocab); + std::vector cur; + cur.reserve(n_vocab); for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); } - llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; + llama_token_data_array cur_p = { + /* .data = */ cur.data(), + /* .size = */ cur.size(), + /* .selected = */ -1, + /* .sorted = */ false, + }; llama_sampler_apply(smpl, &cur_p); - return cur_p.data[cur_p.selected].id; + GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size); + + auto token = cur_p.data[cur_p.selected].id; + + llama_sampler_accept(smpl, token); + + return token; } // sampler chain +static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) { + return "chain"; +} + +static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + time_meas tm(chain->t_sample_us, chain->params.no_perf); + + for (auto * smpl : chain->samplers) { + llama_sampler_accept(smpl, token); + } + + chain->n_sample++; +} + +static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + time_meas tm(chain->t_sample_us, chain->params.no_perf); + + for (auto * smpl : chain->samplers) { + llama_sampler_apply(smpl, cur_p); + } +} + +static void llama_sampler_chain_reset(struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + llama_sampler_reset(smpl); + } + + chain->t_sample_us = 0; + chain->n_sample = 0; +} + +static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) { + const auto * chain_src = (const llama_sampler_chain *) smpl->ctx; + + auto * result = llama_sampler_chain_init(chain_src->params); + + for (auto * smpl : chain_src->samplers) { + llama_sampler_chain_add(result, llama_sampler_clone(smpl)); + } + + return result; +} + +static void llama_sampler_chain_free(struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + llama_sampler_free(smpl); + } + + delete chain; +} + static struct llama_sampler_i llama_sampler_chain_i = { - /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; }, - /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { - auto * chain = (llama_sampler_chain *) smpl->ctx; - - time_meas tm(chain->t_sample_us, chain->params.no_perf); - - for (auto * smpl : chain->samplers) { - llama_sampler_accept(smpl, token); - } - - chain->n_sample++; - }, - /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - auto * chain = (llama_sampler_chain *) smpl->ctx; - - time_meas tm(chain->t_sample_us, chain->params.no_perf); - - for (auto * smpl : chain->samplers) { - llama_sampler_apply(smpl, cur_p); - } - }, - /* .reset = */ [](struct llama_sampler * smpl) { - auto * chain = (llama_sampler_chain *) smpl->ctx; - - for (auto * smpl : chain->samplers) { - llama_sampler_reset(smpl); - } - - chain->t_sample_us = 0; - chain->n_sample = 0; - }, - /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * chain_src = (const llama_sampler_chain *) smpl->ctx; - - auto * result = llama_sampler_chain_init(chain_src->params); - - for (auto * smpl : chain_src->samplers) { - llama_sampler_chain_add(result, llama_sampler_clone(smpl)); - } - - return result; - }, - /* .free = */ [](struct llama_sampler * smpl) { - auto * chain = (llama_sampler_chain *) smpl->ctx; - - for (auto * smpl : chain->samplers) { - llama_sampler_free(smpl); - } - - delete chain; - }, + /* .name = */ llama_sampler_chain_name, + /* .accept = */ llama_sampler_chain_accept, + /* .apply = */ llama_sampler_chain_apply, + /* .reset = */ llama_sampler_chain_reset, + /* .clone = */ llama_sampler_chain_clone, + /* .free = */ llama_sampler_chain_free, }; struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { @@ -314,13 +491,26 @@ void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) { const auto * p = (const llama_sampler_chain *) chain->ctx; - if (i < 0 || i >= (int32_t) p->samplers.size()) { + if (i < 0 || (size_t) i >= p->samplers.size()) { return nullptr; } return p->samplers[i]; } +struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) { + auto * p = (llama_sampler_chain *) chain->ctx; + + if (i < 0 || (size_t) i >= p->samplers.size()) { + return nullptr; + } + + auto * result = p->samplers[i]; + p->samplers.erase(p->samplers.begin() + i); + + return result; +} + int llama_sampler_chain_n(const struct llama_sampler * chain) { const auto * p = (const llama_sampler_chain *) chain->ctx; @@ -366,10 +556,9 @@ struct llama_sampler * llama_sampler_init_greedy() { struct llama_sampler_dist { const uint32_t seed; + uint32_t seed_cur; std::mt19937 rng; - - std::vector probs; // work array }; static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) { @@ -378,7 +567,10 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl* static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_dist *) smpl->ctx; - cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs); + + llama_sampler_softmax_impl(cur_p); + + cur_p->selected = llama_sample_dist(cur_p, ctx->rng); } static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) { @@ -397,7 +589,8 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample static void llama_sampler_dist_reset(struct llama_sampler * smpl) { auto * ctx = (llama_sampler_dist *) smpl->ctx; - ctx->rng = std::mt19937(ctx->seed); + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); } static void llama_sampler_dist_free(struct llama_sampler * smpl) { @@ -414,12 +607,13 @@ static struct llama_sampler_i llama_sampler_dist_i = { }; struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { + auto seed_cur = get_rng_seed(seed); return new llama_sampler { /* .iface = */ &llama_sampler_dist_i, /* .ctx = */ new llama_sampler_dist { - /* .seed = */ seed, - /* .rng = */ std::mt19937(seed), - /* .probs = */ {}, + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .rng = */ std::mt19937(seed_cur), }, }; } @@ -655,101 +849,6 @@ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { }; } -// tail-free - -struct llama_sampler_tail_free { - const float z; - const size_t min_keep; -}; - -static const char * llama_sampler_tail_free_name(const struct llama_sampler * /*smpl*/) { - return "tail-free"; -} - -static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_tail_free *) smpl->ctx; - - if (ctx->z >= 1.0f || cur_p->size <= 2) { - return; - } - - llama_sampler_softmax_impl(cur_p); - - // Compute the first and second derivatives - std::vector first_derivatives(cur_p->size - 1); - std::vector second_derivatives(cur_p->size - 2); - - for (size_t i = 0; i < first_derivatives.size(); ++i) { - first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p; - } - for (size_t i = 0; i < second_derivatives.size(); ++i) { - second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1]; - } - - // Calculate absolute value of second derivatives - for (size_t i = 0; i < second_derivatives.size(); ++i) { - second_derivatives[i] = std::abs(second_derivatives[i]); - } - - // Normalize the second derivatives - { - const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f); - - if (second_derivatives_sum > 1e-6f) { - for (float & value : second_derivatives) { - value /= second_derivatives_sum; - } - } else { - for (float & value : second_derivatives) { - value = 1.0f / second_derivatives.size(); - } - } - } - - float cum_sum = 0.0f; - size_t last_idx = cur_p->size; - for (size_t i = 0; i < second_derivatives.size(); ++i) { - cum_sum += second_derivatives[i]; - - // Check if the running sum is greater than z or if we have kept at least min_keep tokens - if (cum_sum > ctx->z && i >= ctx->min_keep) { - last_idx = i; - break; - } - } - - // Resize the output vector to keep only the tokens above the tail location - cur_p->size = last_idx; -} - -static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx; - return llama_sampler_init_tail_free(ctx->z, ctx->min_keep); -} - -static void llama_sampler_tail_free_free(struct llama_sampler * smpl) { - delete (llama_sampler_tail_free *) smpl->ctx; -} - -static struct llama_sampler_i llama_sampler_tail_free_i = { - /* .name = */ llama_sampler_tail_free_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_tail_free_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_tail_free_clone, - /* .free = */ llama_sampler_tail_free_free, -}; - -struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) { - return new llama_sampler { - /* .iface = */ &llama_sampler_tail_free_i, - /* .ctx = */ new llama_sampler_tail_free { - /* .z = */ z, - /*. min_keep = */ min_keep, - }, - }; -} - // typical struct llama_sampler_typical { @@ -861,9 +960,8 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl* static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { const auto * ctx = (llama_sampler_temp *) smpl->ctx; - for (size_t i = 0; i < cur_p->size; ++i) { - cur_p->data[i].logit /= ctx->temp; - } + + llama_sampler_temp_impl(cur_p, ctx->temp); } static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) { @@ -910,6 +1008,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke if (ctx->delta > 0) { const float min_temp = std::max(0.0f, ctx->temp - ctx->delta); const float max_temp = ctx->temp + ctx->delta; + float exponent_val = ctx->exponent; // no need to do anything if there is only one (or zero) candidates @@ -947,9 +1046,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke #endif // Apply the dynamically calculated temperature scaling - for (size_t i = 0; i < cur_p->size; ++i) { - cur_p->data[i].logit /= dyn_temp; - } + llama_sampler_temp_impl(cur_p, dyn_temp); // Re-compute softmax probabilities after scaling logits with dynamic temperature const double max_l_double = cur_p->data[0].logit; @@ -973,9 +1070,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke } #endif } else { - for (size_t i = 0; i < cur_p->size; ++i) { - cur_p->data[i].logit /= ctx->temp; - } + llama_sampler_temp_impl(cur_p, ctx->temp); } } @@ -1008,12 +1103,108 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa }; } +// xtc + +struct llama_sampler_xtc { + const float probability; + const float threshold; + const size_t min_keep; + + const uint32_t seed; + uint32_t seed_cur; + + std::mt19937 rng; +}; + +static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) { + return "xtc"; +} + +static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_xtc *) smpl->ctx; + + if (ctx->probability <= 0.0f + || ctx->threshold > 0.5f + || cur_p->size < 2) { + return; + } + + std::uniform_real_distribution distribution(0.0f, 1.0f); + float chance = distribution(ctx->rng); + if (chance > ctx->probability) return; + + // in case it's not sorted/recalculated yet + llama_sampler_softmax_impl(cur_p); + + int pos_last = 0; + + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].p >= ctx->threshold) { + pos_last = i; + } else break; + } + + if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) { + cur_p->data += pos_last; + cur_p->size -= pos_last; + } +} + +static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_xtc *) smpl->ctx; + auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed); + + // copy the state + { + auto * result_ctx = (llama_sampler_xtc *) result->ctx; + + result_ctx->rng = ctx->rng; + } + + return result; +} + +static void llama_sampler_xtc_free(struct llama_sampler * smpl) { + delete (llama_sampler_xtc *) smpl->ctx; +} + +static void llama_sampler_xtc_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_xtc *) smpl->ctx; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + +static struct llama_sampler_i llama_sampler_xtc_i = { + /* .name = */ llama_sampler_xtc_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sample_xtc_apply, + /* .reset = */ llama_sampler_xtc_reset, + /* .clone = */ llama_sampler_xtc_clone, + /* .free = */ llama_sampler_xtc_free, +}; + +struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { + auto seed_cur = get_rng_seed(seed); + return new llama_sampler { + /* .iface = */ &llama_sampler_xtc_i, + /* .ctx = */ new llama_sampler_xtc { + /* .probability = */ p, + /* .threshold = */ t, + /* .min_keep = */ min_keep, + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .rng = */ std::mt19937(seed_cur), + }, + }; +} + // mirostat struct llama_sampler_mirostat { const int32_t n_vocab; const uint32_t seed; + uint32_t seed_cur; const float tau; const float eta; @@ -1023,8 +1214,6 @@ struct llama_sampler_mirostat { float mu; std::mt19937 rng; - - std::vector probs; }; static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) { @@ -1055,7 +1244,7 @@ static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_toke llama_sampler_top_k_impl(cur_p, std::max(int(k), 1)); llama_sampler_softmax_impl(cur_p); - const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs); + const int idx = llama_sample_dist(cur_p, ctx->rng); cur_p->selected = idx; @@ -1084,7 +1273,8 @@ static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sa static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) { auto * ctx = (llama_sampler_mirostat *) smpl->ctx; ctx->mu = 2.0f*ctx->tau; - ctx->rng = std::mt19937(ctx->seed); + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); } static void llama_sampler_mirostat_free(struct llama_sampler * smpl) { @@ -1101,17 +1291,18 @@ static struct llama_sampler_i llama_sampler_mirostat_i = { }; struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { + auto seed_cur = get_rng_seed(seed); return new llama_sampler { /* .iface = */ &llama_sampler_mirostat_i, /* .ctx = */ new llama_sampler_mirostat { - /* .n_vocab = */ n_vocab, - /* .seed = */ seed, - /* .tau = */ tau, - /* .eta = */ eta, - /* .m = */ m, - /* .mu = */ 2.0f*tau, - /* .rng = */ std::mt19937(seed), - /* .probs = */ {}, + /* .n_vocab = */ n_vocab, + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .tau = */ tau, + /* .eta = */ eta, + /* .m = */ m, + /* .mu = */ 2.0f*tau, + /* .rng = */ std::mt19937(seed_cur), }, }; } @@ -1120,6 +1311,7 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see struct llama_sampler_mirostat_v2 { const uint32_t seed; + uint32_t seed_cur; const float tau; const float eta; @@ -1127,8 +1319,6 @@ struct llama_sampler_mirostat_v2 { float mu; std::mt19937 rng; - - std::vector probs; }; static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) { @@ -1152,7 +1342,7 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t // Normalize the probabilities of the remaining words llama_sampler_softmax_impl(cur_p); - const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs); + const int idx = llama_sample_dist(cur_p, ctx->rng); cur_p->selected = idx; @@ -1166,7 +1356,8 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) { auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx; ctx->mu = 2.0f*ctx->tau; - ctx->rng = std::mt19937(ctx->seed); + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); } static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) { @@ -1199,15 +1390,16 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = { }; struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { + auto seed_cur = get_rng_seed(seed); return new llama_sampler { /* .iface = */ &llama_sampler_mirostat_v2_i, /* .ctx = */ new llama_sampler_mirostat_v2 { - /* .seed = */ seed, - /* .tau = */ tau, - /* .eta = */ eta, - /* .mu = */ 2.0f*tau, - /* .rng = */ std::mt19937(seed), - /* .probs = */ {}, + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .tau = */ tau, + /* .eta = */ eta, + /* .mu = */ 2.0f*tau, + /* .rng = */ std::mt19937(seed_cur), }, }; } @@ -1241,13 +1433,30 @@ static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token } } +// Fwd declare to break reset --> init_impl --> llama_sampler_grammar_i --> reset cycle. +static struct llama_sampler * llama_sampler_init_grammar_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens); + static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { auto * ctx = (llama_sampler_grammar *) smpl->ctx; if (!ctx->grammar) { return; } - auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str()); + std::vector trigger_words; + for (auto & word : ctx->grammar->trigger_words) { + trigger_words.push_back(word.c_str()); + } + auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(), + ctx->grammar->lazy, trigger_words.data(), trigger_words.size(), + ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size()); llama_grammar_free_impl(ctx->grammar); ctx->grammar = grammar_new; @@ -1256,7 +1465,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_grammar *) smpl->ctx; - auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr); + auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0); // copy the state { @@ -1292,19 +1501,27 @@ static struct llama_sampler_i llama_sampler_grammar_i = { /* .free = */ llama_sampler_grammar_free, }; -struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) { +static struct llama_sampler * llama_sampler_init_grammar_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens) { auto * ctx = new llama_sampler_grammar; if (grammar_str != nullptr && grammar_str[0] != '\0') { *ctx = { - /* .vocab = */ &vocab, + /* .vocab = */ vocab, /* .grammar_str = */ grammar_str, /* .grammar_root = */ grammar_root, - /* .grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root), + /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens), }; } else { *ctx = { - /* .vocab = */ &vocab, + /* .vocab = */ vocab, /* .grammar_str = */ {}, /* .grammar_root = */ {}, /* .grammar = */ nullptr, @@ -1317,22 +1534,36 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab }; } +struct llama_sampler * llama_sampler_init_grammar( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root) { + return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0); +} + +struct llama_sampler * llama_sampler_init_grammar_lazy( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens) { + return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens); +} + // penalties struct llama_sampler_penalties { - const int32_t n_vocab; - const llama_token special_eos_id; - const llama_token linefeed_id; - const int32_t penalty_last_n; const float penalty_repeat; const float penalty_freq; const float penalty_present; - const bool penalize_nl; - const bool ignore_eos; - ring_buffer prev; + + // a frequency map to count token occurrences + std::unordered_map token_count; }; static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) { @@ -1345,76 +1576,50 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_to return; } + ctx->token_count[token]++; + + // if the ring buffer is full, remove the oldest token + if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) { + const auto old = ctx->prev.front(); + + ctx->token_count[old]--; + if (ctx->token_count[old] == 0) { + ctx->token_count.erase(old); + } + } + ctx->prev.push_back(token); + +#if 0 + // sanity check + std::unordered_map tmp; + for (int i = 0; i < std::min(ctx->penalty_last_n, ctx->prev.size()); ++i) { + tmp[ctx->prev.rat(i)]++; + } + + assert(ctx->token_count == tmp); +#endif } static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_penalties *) smpl->ctx; - if (ctx->ignore_eos) { - assert(ctx->special_eos_id >= 0); - - // optimistically check if the candidates are not yet sorted/shuffled/truncated - if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) { - cur_p->data[ctx->special_eos_id].logit = -INFINITY; - } else { - // else, search for the special EOS token - for (size_t i = 0; i < cur_p->size; ++i) { - if (cur_p->data[i].id == ctx->special_eos_id) { - cur_p->data[i].logit = -INFINITY; - break; - } - } - } - } - if ((ctx->penalty_last_n == 0) || (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) { return; } - bool nl_found = false; - size_t nl_idx = 0; - float nl_logit = -INFINITY; - if (!ctx->penalize_nl) { - assert(ctx->linefeed_id >= 0); - - // optimistically check if the candidates are not yet sorted/shuffled/truncated - if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) { - nl_found = true; - nl_idx = ctx->linefeed_id; - nl_logit = cur_p->data[ctx->linefeed_id].logit; - } else { - // else, search for the linefeed token - for (size_t i = 0; i < cur_p->size; ++i) { - if (cur_p->data[i].id == ctx->linefeed_id) { - nl_found = true; - nl_idx = i; - nl_logit = cur_p->data[i].logit; - break; - } - } - } - } - - // Create a frequency map to count occurrences of each token in last_tokens - // TODO: optimize this by maintaining the token count in the sampler context - using llama_token_cnt = std::unordered_map; - llama_token_cnt token_count; - - for (int i = 0; i < std::min(ctx->penalty_last_n, ctx->prev.size()); ++i) { - token_count[ctx->prev.rat(i)]++; - } - // Apply frequency and presence penalties to the cur_p for (size_t i = 0; i < cur_p->size; ++i) { - const auto token_iter = token_count.find(cur_p->data[i].id); - if (token_iter == token_count.end()) { + const auto token_iter = ctx->token_count.find(cur_p->data[i].id); + if (token_iter == ctx->token_count.end()) { continue; } const int count = token_iter->second; + assert(count > 0 && count <= ctx->penalty_last_n); + // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. // This is common fix for this problem, which is to multiply by the penalty instead of dividing. if (cur_p->data[i].logit <= 0) { @@ -1427,30 +1632,21 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok } cur_p->sorted = false; - - if (!ctx->penalize_nl && nl_found) { - // restore the logit of the newline token if it was penalized - cur_p->data[nl_idx].logit = nl_logit; - } } static void llama_sampler_penalties_reset(struct llama_sampler * smpl) { auto * ctx = (llama_sampler_penalties *) smpl->ctx; ctx->prev.clear(); + ctx->token_count.clear(); } static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_penalties *) smpl->ctx; auto * result = llama_sampler_init_penalties( - ctx->n_vocab, - ctx->special_eos_id, - ctx->linefeed_id, ctx->penalty_last_n, ctx->penalty_repeat, ctx->penalty_freq, - ctx->penalty_present, - ctx->penalize_nl, - ctx->ignore_eos); + ctx->penalty_present); // copy the state { @@ -1476,40 +1672,420 @@ static struct llama_sampler_i llama_sampler_penalties_i = { }; struct llama_sampler * llama_sampler_init_penalties( - int32_t n_vocab, - llama_token special_eos_id, - llama_token linefeed_id, int32_t penalty_last_n, float penalty_repeat, float penalty_freq, - float penalty_present, - bool penalize_nl, - bool ignore_eos) { - if (linefeed_id == LLAMA_TOKEN_NULL) { - penalize_nl = true; - } - - if (special_eos_id == LLAMA_TOKEN_NULL) { - ignore_eos = false; - } + float penalty_present) { + penalty_last_n = std::max(penalty_last_n, 0); return new llama_sampler { /* .iface = */ &llama_sampler_penalties_i, /* .ctx = */ new llama_sampler_penalties { - /* .n_vocab = */ n_vocab, - /* .special_eos_id = */ special_eos_id, - /* .linefeed_id = */ linefeed_id, /* .penalty_last_n = */ penalty_last_n, /* .penalty_repeat = */ penalty_repeat, /* .penalty_freq = */ penalty_freq, /* .penalty_present = */ penalty_present, - /* .penalize_nl = */ penalize_nl, - /* .ignore_eos = */ ignore_eos, /* .prev = */ ring_buffer(penalty_last_n), + /* .token_count = */ {}, }, }; } +// DRY + +struct llama_sampler_dry { + int32_t total_context_size; + + const float dry_multiplier; + const float dry_base; + const int32_t dry_allowed_length; + const int32_t dry_penalty_last_n; + + std::unordered_multimap> dry_processed_breakers; + std::vector dry_repeat_count; + std::unordered_map dry_max_token_repeat; + ring_buffer last_tokens; +}; + +// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am) +static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap>& token_sequences, int max_tail_len = -1) { + for (llama_token token_id = 0; token_id < (llama_token) vocab.n_tokens(); token_id++) { + std::string word = vocab.detokenize({token_id}, true); + if (word.find(str) != std::string::npos) { + token_sequences.emplace(token_id, std::vector()); + } else { + size_t word_len = word.size(); + size_t str_len = str.size(); + size_t pos = -1; + while ((pos = word.find(str[0], pos + 1)) != std::string::npos) { + bool match = true; + size_t i; + for (i = 1; i < str_len && i + pos < word_len; ++i) { + if (word[pos + i] != str[i]) { + match = false; + break; + } + } + if (match) { + std::vector tokenization = vocab.tokenize(str.substr(i), false, false); + if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) { + tokenization.resize(max_tail_len); + } + + // Ensure we don't already have a duplicate matching tokenization + auto its = token_sequences.equal_range(token_id); + bool found = false; + for (auto it = its.first; it != its.second; ++it) { + if (tokenization == it->second) { + found = true; + break; + } + } + if (!found) { + token_sequences.emplace(token_id, tokenization); + } + } + } + } + } +} + +static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) { + return "dry"; +} + +static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_dry *) smpl->ctx; + if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) { + return; + } + + ctx->last_tokens.push_back(token); +} + +// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am) +static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_dry *) smpl->ctx; + + if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) { + return; + } + + int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0); + int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size); + + if (last_n_repeat <= ctx->dry_allowed_length) { + return; + } + + ctx->dry_repeat_count.assign(last_n_repeat, 0); + ctx->dry_max_token_repeat.clear(); + + // Step 1: Look for restart sequences to limit the maximum repetition length. + // Work backwards through the context looking for any token that begins a restart sequence. + // + // The collection `restart_sequences` is a mapping from a "head" token to all "tail" + // sequences that together comprise a restart sequence. This allows us to quickly check + // whether each token is the head of a complete sequence. Most restart sequences are actually + // a single token, and for these the "tail" is an empty vector. + // + // If the token is a "head", test all restart sequences that begin with this token + // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and + // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The + // longest matching sequence (if any) is used to limit the maximum repetition length. + // + // Note that in the case case of a short sequence contained in a longer one, this might fail to + // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as + // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress + // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare. + // + // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we + // have already clamped the maximum tail sequence length when generating `restart_sequences`. + // With clamping, this scan is O(N) in the context length. + + int rep_limit = last_n_repeat; + for (int i = 0; i < last_n_repeat; ++i) { + llama_token token = ctx->last_tokens.rat(i); + auto its = ctx->dry_processed_breakers.equal_range(token); + if (its.first == ctx->dry_processed_breakers.end()) { + continue; + } + int longest_match = -1; + for (auto it = its.first; it != its.second; ++it) { + // Note that (*it) does not contain the head character, so seq_len will be + // the restart sequence length minus 1. + // In the common case of a single-token restart sequence, (*it) will be empty + // and we will trivially match. + int seq_len = (int)it->second.size(); + if (seq_len > longest_match && seq_len <= (int)i) { + bool match = true; + for (int offset = 0; offset < seq_len; ++offset) { + // The -1 when indexing `last_tokens` is because we already matched the head. + if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) { + match = false; + break; + } + } + if (match) { + longest_match = seq_len; + } + } + } + if (longest_match >= 0) { + // We found a restart sequence starting `i` tokens from the end and continuing for + // `longest_match` tokens. + rep_limit = i - longest_match; + break; + } + } + if (rep_limit < ctx->dry_allowed_length) { + return; + } + + // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in + // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing + // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences. + // + // This algorithm is not currently documented on Wikipedia, but there is a clear description here: + // https://ivanyu.me/blog/2014/10/15/z-algorithm/ + // + // The code below is adapted from the public domain implementation by the same author here: + // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py + // + // Example: + // Last N tokens: a b c c b c y a b c + // Repeat counts: 0 0 3 1 0 2 0 0 0 0 + // ^ + // This `3` means that the last three tokens of the context (a b c) also appear here. + // + // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested + // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each + // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables + // ensure that the inner while loops only examine each token in the context once as the outer + // for loop iterates over the context. + + { + const int last = last_n_repeat - 1; + int rt = 0, lt = 0; + + for (int k = 1; k < last_n_repeat; ++k) { + if (k > rt) { + // If k is outside the current Z-box, do naive computation. + int n = 0; + while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) { + ++n; + } + ctx->dry_repeat_count[last - k] = std::min(n, rep_limit); + if (n > 0) { + lt = k; + rt = k + n - 1; + } + } else { + // If k is inside the current Z-box, consider two cases. + + int p = k - lt; // Pair index. + int right_part_len = rt - k + 1; + + if (ctx->dry_repeat_count[last - p] < right_part_len) { + int n = std::min(ctx->dry_repeat_count[last - p], rep_limit); + ctx->dry_repeat_count[last - k] = n; + } else { + int i = rt + 1; + while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) { + i += 1; + } + + int n = std::min(i - k, rep_limit); + ctx->dry_repeat_count[last - k] = n; + lt = k; + rt = i - 1; + } + } + } + } + + // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length + // that would be generated by emitting each new token that would extend a sequence. + // + // Following the same example as above: + // Last N tokens: a b c c b c y a b c + // Repeat counts: 0 0 3 1 0 2 0 0 0 0 + // + // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition. + // c: 3 -> 4 (from `a b c` to `a b c c`) + // b: 1 -> 2 (from `c` to `c b`) + // y: 2 -> 3 (from `b c` to `b c y`) + + for (int i = 0; i < last_n_repeat - 1; ++i) { + int repeat_len = ctx->dry_repeat_count[i]; + if (repeat_len >= ctx->dry_allowed_length) { + // This token ends a repeat, so the next token would continue one. + // By convention, the value of `repeat_len` only includes the tokens currently + // in the context, not the new token that would be added. + llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i); + // Track the maximum sequence ending in this token. + const auto& it = ctx->dry_max_token_repeat.find(token); + if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) { + ctx->dry_max_token_repeat[token] = repeat_len; + } + } + } + + // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens. + + // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`. + // Compute it from `penalty_base` and the approximate log of `std::numeric_limits::max()` + const float FLOAT_MAX_LOG = 88.7228391f; + int max_exponent = 0; + if (ctx->dry_base > 1.000001f) { + max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base); + } + + for (size_t i = 0; i < cur_p->size; ++i) { + const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id); + if (af_kvp != ctx->dry_max_token_repeat.end()) { + // Check all sequence breakers starting with this token + auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id); + bool is_single_token_breaker = false; + + for (auto it = range.first; it != range.second; ++it) { + if (it->second.empty()) { + is_single_token_breaker = true; + break; + } + } + + // Apply penalty only if it's not a single-token sequence breaker + if (!is_single_token_breaker) { + int repeat_exp = af_kvp->second - ctx->dry_allowed_length; + if (max_exponent > 0 && repeat_exp > max_exponent) { + repeat_exp = max_exponent; + } + float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp); + cur_p->data[i].logit -= penalty; + } + } + } + + cur_p->sorted = false; +} + +static void llama_sampler_dry_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_dry *) smpl->ctx; + ctx->last_tokens.clear(); + ctx->dry_repeat_count.clear(); + ctx->dry_max_token_repeat.clear(); +} + +static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) { + const auto * ctx = (llama_sampler_dry *) smpl->ctx; + + llama_vocab dummy_vocab; + + // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying + auto * result = llama_sampler_init_dry(&dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0); + + // Copy the state, including the processed breakers + { + auto * result_ctx = (llama_sampler_dry *) result->ctx; + result_ctx->dry_processed_breakers = ctx->dry_processed_breakers; + result_ctx->dry_repeat_count = ctx->dry_repeat_count; + result_ctx->dry_max_token_repeat = ctx->dry_max_token_repeat; + result_ctx->last_tokens = ctx->last_tokens; + } + + return result; +} + +static void llama_sampler_dry_free(struct llama_sampler * smpl) { + delete (llama_sampler_dry *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_dry_i = { + /* .name = */ llama_sampler_dry_name, + /* .accept = */ llama_sampler_dry_accept, + /* .apply = */ llama_sampler_dry_apply, + /* .reset = */ llama_sampler_dry_reset, + /* .clone = */ llama_sampler_dry_clone, + /* .free = */ llama_sampler_dry_free, +}; + +struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { + int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0); + std::unordered_multimap> processed_breakers; + const int MAX_CHAR_LEN = 40; + const int MAX_SEQ_LEN = 20; + + const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0); + + if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) { + // Process sequence breakers + for (size_t i = 0; i < num_breakers; ++i) { + if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) { + LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i); + continue; + } + + std::string sequence_break(seq_breakers[i]); + if (sequence_break.empty()) { + LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n"); + continue; + } + + if (sequence_break.size() > MAX_CHAR_LEN) { + LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN); + sequence_break.resize(MAX_CHAR_LEN); + } + + get_overlapping_token_sequences(*vocab, sequence_break, processed_breakers, MAX_SEQ_LEN); + } + } + + return new llama_sampler { + /* .iface = */ &llama_sampler_dry_i, + /* .ctx = */ new llama_sampler_dry { + /* .total_context_size = */ context_size, + /* .dry_multiplier = */ dry_multiplier, + /* .dry_base = */ dry_base, + /* .dry_allowed_length = */ dry_allowed_length, + /* .dry_penalty_last_n = */ dry_penalty_last_n, + /* .dry_processed_breakers = */ std::move(processed_breakers), + /* .dry_repeat_count = */ dry_enabled ? std::vector(effective_dry_penalty_last_n, 0) : std::vector{}, + /* .dry_max_token_repeat = */ {}, + /* .last_tokens = */ dry_enabled ? ring_buffer(effective_dry_penalty_last_n) : ring_buffer(0), + }, + }; +} + +// wrapper for test-sampling.cpp +struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector>& seq_breakers) { + llama_vocab dummy_vocab; + auto * result = llama_sampler_init_dry(&dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0); + auto * ctx = (llama_sampler_dry *) result->ctx; + + // Process the token-based sequence breakers + ctx->dry_processed_breakers.clear(); + if (seq_breakers.empty()) { + LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n"); + } else { + for (const auto& breaker : seq_breakers) { + if (breaker.empty()) { + LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n"); + continue; + } + llama_token head_token = breaker[0]; + std::vector tail_tokens(breaker.begin() + 1, breaker.end()); + ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens)); + } + + if (ctx->dry_processed_breakers.empty()) { + LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n"); + } + } + + return result; +} + // logit-bias struct llama_sampler_logit_bias { @@ -1527,6 +2103,10 @@ static const char * llama_sampler_logit_bias_name(const struct llama_sampler * / static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_logit_bias *) smpl->ctx; + if (ctx->logit_bias.empty()) { + return; + } + ctx->to_search.clear(); // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id) @@ -1538,6 +2118,10 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to } } + if (ctx->to_search.empty()) { + return; + } + // search for the remaining candidates that were not found in the previous step for (size_t i = 0; i < cur_p->size; ++i) { for (const auto & lb : ctx->to_search) { @@ -1548,6 +2132,7 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to } } } + static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx; return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data()); @@ -1579,3 +2164,287 @@ struct llama_sampler * llama_sampler_init_logit_bias( }, }; } + +// infill + +//#define GGML_DEBUG_SAMPLER_INFILL + +struct llama_sampler_infill { + const struct llama_vocab * vocab; + + std::vector buf0; + std::vector buf1; +}; + +static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) { + return "infill"; +} + +static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_infill *) smpl->ctx; + + llama_sampler_softmax_impl(cur_p); + +#if defined(GGML_DEBUG_SAMPLER_INFILL) +#define LOG_DBG_CUR LLAMA_LOG_DEBUG +#else +#define LOG_DBG_CUR(...) +#endif + + for (size_t i = 0; i < cur_p->size; ++i) { + LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit); + } + + float p_txt_sum = 0.0f; + float p_eog_sum = 0.0f; + + for (size_t i = 0; i < cur_p->size; ++i) { + if (ctx->vocab->is_eog(cur_p->data[i].id)) { + p_eog_sum += cur_p->data[i].p; + } else { + p_txt_sum += cur_p->data[i].p; + } + } + + const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat); + + LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size); + + if (3*p_eog_sum*cur_p->size > p_txt_sum) { + LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum); + + // keep just the EOG tokens + const auto size_org = cur_p->size; + + cur_p->size = 0; + + float p_sum = 0.0f; + + for (size_t i = 0; i < size_org; ++i) { + if (ctx->vocab->is_eog(cur_p->data[i].id)) { + p_sum += cur_p->data[i].p; + + cur_p->data[cur_p->size++] = cur_p->data[i]; + } + } + + // normalize probs + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= p_sum; + } + + return; + } + + size_t n_combined = 0; GGML_UNUSED(n_combined); + + // combine tokens with common prefix + for (size_t i0 = 0; i0 < cur_p->size; ++i0) { + for (size_t i1 = 0; i1 < cur_p->size; ++i1) { + if (cur_p->data[i0].logit == -INFINITY) { + break; + } + + if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) { + continue; + } + + int len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false); + if (len0 < 0) { + ctx->buf0.resize(len0); + len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false); + assert(len0 > 0); + } + + int len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false); + if (len1 < 0) { + ctx->buf1.resize(len1); + len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false); + assert(len1 > 0); + } + + // token i0 is a prefix of token i1 + if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) { + int dst = i0; + int src = i1; + + // merge into the token with higher probability + if (cur_p->data[i1].p > cur_p->data[i0].p) { + std::swap(dst, src); + } + + cur_p->data[dst].p += cur_p->data[src].p; + cur_p->data[src].logit = -INFINITY; + cur_p->data[src].p = 0.0f; + + n_combined++; + } + } + } + + size_t n_non_eog = 0; + + size_t size_org = cur_p->size; + + float p_sum = 0.0f; + float thold = 0.2f; + + cur_p->size = 0; + + LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold); + + for (size_t i = 0; i < size_org; ++i) { + const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id); + + if (cur_p->data[i].p < thold && !is_eog) { + continue; + } + + if (!is_eog) { + ++n_non_eog; + } + + p_sum += cur_p->data[i].p; + + // keep this token + cur_p->data[cur_p->size++] = cur_p->data[i]; + } + + LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog); + + // if no non-EOG tokens are left -> reduce cur_p to single EOT token + if (n_non_eog == 0) { + cur_p->size = 1; + cur_p->data[0].id = ctx->vocab->token_eot(); + cur_p->data[0].logit = 1.0f; + + return; + } + + // normalize probs + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= p_sum; + + LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit); + } + + size_org = cur_p->size; + p_sum = 0.0f; + thold = 1.0/(n_non_eog + 1); + + cur_p->size = 0; + + LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold); + + for (size_t i = 0; i < size_org; ++i) { + const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id); + + if (cur_p->data[i].p < thold && !is_eog) { + continue; + } + + p_sum += cur_p->data[i].p; + + cur_p->data[cur_p->size++] = cur_p->data[i]; + } + + // normalize probs + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= p_sum; + + LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit); + } + +#undef LOG_DBG_CUR +} + +static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_infill *) smpl->ctx; + return llama_sampler_init_infill(ctx->vocab); +} + +static void llama_sampler_infill_free(struct llama_sampler * smpl) { + delete (llama_sampler_infill *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_infill_i = { + /* .name = */ llama_sampler_infill_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_infill_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_infill_clone, + /* .free = */ llama_sampler_infill_free, +}; + +struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) { + return new llama_sampler { + /* .iface = */ &llama_sampler_infill_i, + /* .ctx = */ new llama_sampler_infill { + /* .vocab = */ vocab, + /* .buf0 = */ std::vector(512), + /* .buf1 = */ std::vector(512), + }, + }; +} + +// utils + +uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) { + if (smpl->iface == &llama_sampler_dist_i) { + return ((const llama_sampler_dist *) smpl->ctx)->seed_cur; + } + + if (smpl->iface == &llama_sampler_mirostat_i) { + return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur; + } + + if (smpl->iface == &llama_sampler_mirostat_v2_i) { + return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur; + } + + if (smpl->iface == &llama_sampler_chain_i) { + const auto * ctx = (const llama_sampler_chain *) smpl->ctx; + for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) { + const uint32_t seed = llama_sampler_get_seed(*it); + if (seed != LLAMA_DEFAULT_SEED) { + return seed; + } + } + } + + return LLAMA_DEFAULT_SEED; +} + +// perf + +struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) { + struct llama_perf_sampler_data data = {}; + + if (chain == nullptr || chain->iface != &llama_sampler_chain_i) { + GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__); + } + + const auto * ctx = (const struct llama_sampler_chain *) chain->ctx; + + data.t_sample_ms = 1e-3 * ctx->t_sample_us; + data.n_sample = std::max(0, ctx->n_sample); + + return data; +} + +void llama_perf_sampler_print(const struct llama_sampler * chain) { + const auto data = llama_perf_sampler(chain); + + LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample); +} + +void llama_perf_sampler_reset(struct llama_sampler * chain) { + if (chain == nullptr || chain->iface != &llama_sampler_chain_i) { + GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__); + } + + auto * ctx = (struct llama_sampler_chain *) chain->ctx; + + ctx->t_sample_us = ctx->n_sample = 0; +} diff --git a/src/llama-sampling.h b/src/llama-sampling.h index d90b14713..759dd7dcb 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -2,9 +2,9 @@ // TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ? -#include "llama-grammar.h" +#include "llama.h" -#include +#include struct llama_vocab; struct llama_grammar; @@ -23,7 +23,10 @@ struct llama_sampler_chain { mutable int32_t n_sample; }; -struct llama_sampler * llama_sampler_init_grammar_impl( - const struct llama_vocab & vocab, - const char * grammar_str, - const char * grammar_root); +struct llama_sampler * llama_sampler_init_dry_testing( + int32_t context_size, + float dry_multiplier, + float dry_base, + int32_t dry_allowed_length, + int32_t dry_penalty_last_n, + const std::vector>& seq_breakers); diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 2c007477e..ad9ffe66a 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1,5 +1,8 @@ #include "llama-vocab.h" +#include "llama-impl.h" +#include "llama-model-loader.h" + #include "unicode.h" #include @@ -9,29 +12,15 @@ #include #include #include +#include #include -#include +#include +#include // // helpers // -LLAMA_ATTRIBUTE_FORMAT(1, 2) -static std::string format(const char * fmt, ...) { - va_list ap; - va_list ap2; - va_start(ap, fmt); - va_copy(ap2, ap); - int size = vsnprintf(NULL, 0, fmt, ap); - GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT - std::vector buf(size + 1); - int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); - GGML_ASSERT(size2 == size); - va_end(ap2); - va_end(ap); - return std::string(buf.data(), size); -} - struct naive_trie { naive_trie() : has_value(false), value(0) { } @@ -50,7 +39,7 @@ struct naive_trie { res.first->second.insert(key + 1, len - 1, value); } } - std::pair get_longest_prefix(const char * key, size_t len, size_t offset = 0) { + std::pair get_longest_prefix(const char * key, size_t len, size_t offset = 0) const { if (len == 0 || offset == len) { return std::make_pair(key, offset); } @@ -76,86 +65,13 @@ struct naive_trie { }; // -// impl +// tokenizers // -int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const { - GGML_ASSERT(token_left.find(' ') == std::string::npos); - GGML_ASSERT(token_left.find('\n') == std::string::npos); - GGML_ASSERT(token_right.find(' ') == std::string::npos); - GGML_ASSERT(token_right.find('\n') == std::string::npos); - - auto it = bpe_ranks.find(std::make_pair(token_left, token_right)); - if (it == bpe_ranks.end()) { - return -1; - } - - return it->second; -} - -static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) { - return vocab.type; -} - -static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL; -} - -static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN; -} - -static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL; -} - -static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE; -} - -static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED; -} - -static bool llama_is_unused_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED; -} - -static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE); - GGML_ASSERT(llama_is_byte_token(vocab, id)); - const auto & token_data = vocab.id_to_token.at(id); - switch (llama_vocab_get_type(vocab)) { - case LLAMA_VOCAB_TYPE_SPM: - case LLAMA_VOCAB_TYPE_UGM: { - auto buf = token_data.text.substr(3, 2); - return strtol(buf.c_str(), NULL, 16); - } - case LLAMA_VOCAB_TYPE_BPE: { - GGML_ABORT("fatal error"); - //return unicode_utf8_to_byte(token_data.text); // TODO: why is this here after GGML_ASSERT? - } - case LLAMA_VOCAB_TYPE_WPM: { - GGML_ABORT("fatal error"); - } - default: - GGML_ABORT("fatal error"); - } -} - -static void llama_escape_whitespace(std::string & text) { - replace_all(text, " ", "\xe2\x96\x81"); -} - -static void llama_unescape_whitespace(std::string & word) { - replace_all(word, "\xe2\x96\x81", " "); -} +struct llm_tokenizer { + llm_tokenizer() {} + virtual ~llm_tokenizer() = default; +}; struct llm_symbol { using index = int; @@ -187,10 +103,14 @@ struct llm_bigram_spm { size_t size; }; -struct llm_tokenizer_spm { - llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {} +struct llm_tokenizer_spm : llm_tokenizer { + llm_tokenizer_spm(const llama_vocab & /*vocab*/) {} +}; - void tokenize(const std::string & text, std::vector & output) { +struct llm_tokenizer_spm_session { + llm_tokenizer_spm_session(const llama_vocab & vocab) : vocab(vocab) {} + + void tokenize(const std::string & text, std::vector & output) { // split string into utf8 chars int index = 0; size_t offs = 0; @@ -207,7 +127,7 @@ struct llm_tokenizer_spm { } // seed the work queue with all possible 2-character tokens. - for (size_t i = 1; i < symbols.size(); ++i) { + for (int i = 1; i < (int) symbols.size(); ++i) { try_add_bigram(i - 1, i); } @@ -249,13 +169,13 @@ struct llm_tokenizer_spm { } private: - void resegment(llm_symbol & symbol, std::vector & output) { + void resegment(llm_symbol & symbol, std::vector & output) { auto text = std::string(symbol.text, symbol.n); - auto token = vocab.token_to_id.find(text); + auto token = vocab.text_to_token(text); // Do we need to support is_unused? - if (token != vocab.token_to_id.end()) { - output.push_back((*token).second); + if (token != LLAMA_TOKEN_NULL) { + output.push_back(token); return; } @@ -265,13 +185,13 @@ private: // output any symbols that did not form tokens as bytes. output.reserve(output.size() + symbol.n); for (int j = 0; j < (int)symbol.n; ++j) { - llama_vocab::id token_id = llama_byte_to_token_impl(vocab, symbol.text[j]); - output.push_back(token_id); + llama_token id = vocab.byte_to_token(symbol.text[j]); + output.push_back(id); } return; } - resegment(symbols[p->second.first], output); + resegment(symbols[p->second.first], output); resegment(symbols[p->second.second], output); } @@ -279,19 +199,18 @@ private: if (left == -1 || right == -1) { return; } - const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n); - auto token = vocab.token_to_id.find(text); + auto token = vocab.text_to_token(text); - if (token == vocab.token_to_id.end()) { + if (token == LLAMA_TOKEN_NULL) { return; } - if (static_cast((*token).second) >= vocab.id_to_token.size()) { + if (static_cast(token) >= vocab.n_tokens()) { return; } - const auto & tok_data = vocab.id_to_token[(*token).second]; + const auto & tok_data = vocab.get_token_data(token); llm_bigram_spm bigram; bigram.left = left; @@ -306,10 +225,11 @@ private: } const llama_vocab & vocab; + // currently unused + // const llm_tokenizer_spm * spm_tokenizer; std::vector symbols; llm_bigram_spm::queue work_queue; - std::map> rev_merge; }; @@ -352,10 +272,10 @@ struct llm_bigram_bpe { size_t size; }; -struct llm_tokenizer_bpe { - llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) { - GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE); - switch (vocab.type_pre) { +struct llm_tokenizer_bpe : llm_tokenizer { + llm_tokenizer_bpe(const llama_vocab & vocab) { + GGML_ASSERT(vocab.get_type() == LLAMA_VOCAB_TYPE_BPE); + switch (vocab.get_pre_type()) { case LLAMA_VOCAB_PRE_TYPE_LLAMA3: regex_exprs = { // original regex from tokenizer.json @@ -382,6 +302,13 @@ struct llm_tokenizer_bpe { "\\p{N}+", }; break; + case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM: + regex_exprs = { + "\\p{N}{1,3}", + "[一-龥぀-ゟ゠-ヿ]+", + "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER: regex_exprs = { "[\r\n]", @@ -404,6 +331,7 @@ struct llm_tokenizer_bpe { case LLAMA_VOCAB_PRE_TYPE_SMOLLM: case LLAMA_VOCAB_PRE_TYPE_CODESHELL: case LLAMA_VOCAB_PRE_TYPE_EXAONE: + case LLAMA_VOCAB_PRE_TYPE_MINERVA: regex_exprs = { "\\p{N}", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", @@ -450,6 +378,20 @@ struct llm_tokenizer_bpe { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_CHAMELEON: + // Note: in theory, the special token (sentinel and image token) regex_exprs below + // are unnecessary, as they are split in `tokenizer_st_partition` anyway. + // However, since the upstream pre-tokenizer uses them, they are also + // included here (see https://huggingface.co/facebook/chameleon-7b). + regex_exprs = { + "", // Sentinel tokens + "(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z", // Image tokens + "([\\t\\n]| | )", // directly from tokenizer.json + "\\p{N}", // Individual digits + "[\\p{P}!-/:-@\\[-`{-~]", // Punctuation, Isolated + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -462,36 +404,42 @@ struct llm_tokenizer_bpe { } } - void append(const llama_vocab::id token_id, std::vector & output) const { + std::vector regex_exprs; +}; + +struct llm_tokenizer_bpe_session { + llm_tokenizer_bpe_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : vocab(vocab), tokenizer(tokenizer) {} + + static void append(const llama_token token_id, std::vector & output) { output.push_back(token_id); } - bool append_bos(std::vector & output) const { - if (vocab.tokenizer_add_bos) { - GGML_ASSERT(vocab.special_bos_id != -1); - output.push_back(vocab.special_bos_id); + bool append_bos(std::vector & output) const { + if (vocab.get_add_bos()) { + GGML_ASSERT(vocab.token_bos() != LLAMA_TOKEN_NULL); + output.push_back(vocab.token_bos()); return true; } return false; } - bool append_eos(std::vector & output) const { - if (vocab.tokenizer_add_eos) { - GGML_ASSERT(vocab.special_eos_id != -1); - output.push_back(vocab.special_eos_id); + bool append_eos(std::vector & output) const { + if (vocab.get_add_eos()) { + GGML_ASSERT(vocab.token_eos() != LLAMA_TOKEN_NULL); + output.push_back(vocab.token_eos()); return true; } return false; } - void check_double_bos_eos(const std::vector & output) const { - if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) { + void check_double_bos_eos(const std::vector & output) const { + if (vocab.get_add_bos() && output.size() >= 2 && output[1] == vocab.token_bos()) { LLAMA_LOG_WARN( "%s: Added a BOS token to the prompt as specified by the model but the prompt " "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " "Are you sure this is what you want?\n", __FUNCTION__); } - if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) { + if (vocab.get_add_eos() && output.size() >= 2 && *(output.end()-2) == vocab.token_eos()) { LLAMA_LOG_WARN( "%s: Added a EOS token to the prompt as specified by the model but the prompt " "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. " @@ -499,21 +447,21 @@ struct llm_tokenizer_bpe { } } - void tokenize(const std::string & text, std::vector & output) { + void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; - - const auto word_collection = unicode_regex_split(text, regex_exprs); + const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs); symbols_final.clear(); - for (auto & word : word_collection) { + for (const auto & word : word_collection) { work_queue = llm_bigram_bpe::queue(); symbols.clear(); int index = 0; size_t offset = 0; - if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) { + //if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) { + if (vocab.get_ignore_merges() && vocab.text_to_token(word) != LLAMA_TOKEN_NULL) { symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()}); offset = word.size(); } @@ -529,7 +477,7 @@ struct llm_tokenizer_bpe { index++; symbols.emplace_back(sym); } - for (size_t i = 1; i < symbols.size(); ++i) { + for (int i = 1; i < (int) symbols.size(); ++i) { add_new_bigram(i - 1, i); } @@ -587,18 +535,18 @@ struct llm_tokenizer_bpe { } const std::string str = std::string(symbol.text, symbol.n); - const auto token = vocab.token_to_id.find(str); + const auto token = vocab.text_to_token(str); - if (token == vocab.token_to_id.end()) { + if (token == LLAMA_TOKEN_NULL) { for (auto j = str.begin(); j != str.end(); ++j) { std::string byte_str(1, *j); - auto token_multibyte = vocab.token_to_id.find(byte_str); - if (token_multibyte != vocab.token_to_id.end()) { - output.push_back(token_multibyte->second); + auto token_multibyte = vocab.text_to_token(byte_str); + if (token_multibyte != LLAMA_TOKEN_NULL) { + output.push_back(token_multibyte); } } } else { - output.push_back((*token).second); + output.push_back(token); } } } @@ -609,7 +557,6 @@ private: if (left == -1 || right == -1) { return; } - std::string left_token = std::string(symbols[left].text, symbols[left].n); std::string right_token = std::string(symbols[right].text, symbols[right].n); @@ -633,12 +580,10 @@ private: } const llama_vocab & vocab; - - std::vector regex_exprs; + const llm_tokenizer_bpe & tokenizer; std::vector symbols; std::vector symbols_final; - llm_bigram_bpe::queue work_queue; }; @@ -646,15 +591,16 @@ private: // WPM tokenizer // -struct llm_tokenizer_wpm { - llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {} +struct llm_tokenizer_wpm : llm_tokenizer { + llm_tokenizer_wpm(const llama_vocab & /*vocab*/) {} +}; - void tokenize(const std::string & text, std::vector & output) const { - const auto & token_map = vocab.token_to_id; +struct llm_tokenizer_wpm_session { + llm_tokenizer_wpm_session(const llama_vocab & vocab) : vocab(vocab) {} + void tokenize(const std::string & text, std::vector & output) { // normalize and split by whitespace std::vector words = preprocess(text); - // bos token prepended already // find the longest tokens that form the words @@ -675,10 +621,10 @@ struct llm_tokenizer_wpm { for (int i = 0; i < n; ++i) { // loop through possible match length bool match = false; - for (int j = std::min(n, i + vocab.max_token_len + 1); j > i; j--) { - auto it = token_map.find(word1.substr(i, j - i)); - if (it != token_map.end()) { - output.push_back(it->second); + for (int j = std::min(n, i + vocab.max_token_len() + 1); j > i; j--) { + auto id = vocab.text_to_token(word1.substr(i, j - i)); + if (id != LLAMA_TOKEN_NULL) { + output.push_back(id); match = true; i = j - 1; break; @@ -693,18 +639,18 @@ struct llm_tokenizer_wpm { // we didn't find any matches for this word if (current_tokens == output.size()) { - output.push_back(vocab.special_unk_id); + output.push_back(vocab.token_unk()); } } } // TODO: reduce string copies by using cpts_offs array - std::vector preprocess(const std::string & text) const { + static std::vector preprocess(const std::string & text) { const std::vector cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); std::vector words(1, ""); for (const uint32_t cpt : cpts_nfd) { - const auto flags = unicode_cpt_flags(cpt); + const auto flags = unicode_cpt_flags_from_cpt(cpt); if (flags.is_whitespace) { if (words.back().size()) { // finish previous word if any @@ -751,53 +697,56 @@ struct llm_tokenizer_wpm { //(cpt >= 0xFF00 && cpt <= 0xFFEF); } +private: const llama_vocab & vocab; + // currently unused + // const llm_tokenizer_wpm * wpm_tokenizer; }; // // UGM tokenizer // -struct llm_tokenizer_ugm { - llm_tokenizer_ugm(const llama_vocab & vocab) : vocab(vocab) { - if (vocab.precompiled_charsmap.size() > 0) { +struct llm_tokenizer_ugm : llm_tokenizer { + llm_tokenizer_ugm(const llama_vocab & vocab, const std::vector & precompiled_charsmap) { + if (precompiled_charsmap.size() > 0) { size_t charsmap_offset = 0; // First four bytes of precompiled_charsmap contains length of binary // blob containing XOR-compressed compact double array (XCDA) entries - uint32_t xcda_blob_size = *(const uint32_t *) &vocab.precompiled_charsmap[0]; + uint32_t xcda_blob_size = *(const uint32_t *) &precompiled_charsmap[0]; charsmap_offset += sizeof(xcda_blob_size); - if (xcda_blob_size + charsmap_offset >= vocab.precompiled_charsmap.size()) { + if (xcda_blob_size + charsmap_offset >= precompiled_charsmap.size()) { throw std::runtime_error("Index out of array bounds in precompiled charsmap!"); } // Next xcda_blob_size bytes contain entries of XOR-compressed compact // double array (XCDA). Each entry is bit-packed into a 32-bit integer. - xcda_array = (const uint32_t *) &vocab.precompiled_charsmap[charsmap_offset]; + xcda_array = (const uint32_t *) &precompiled_charsmap[charsmap_offset]; xcda_array_size = xcda_blob_size / sizeof(uint32_t); charsmap_offset += xcda_blob_size; // Remaining bytes of precompiled charsmap contain null-terminated // replacement strings for prefixes matched by the XCDA. - prefix_replacements = &vocab.precompiled_charsmap[charsmap_offset]; - prefix_replacements_size = vocab.precompiled_charsmap.size() - charsmap_offset; + prefix_replacements = &precompiled_charsmap[charsmap_offset]; + prefix_replacements_size = precompiled_charsmap.size() - charsmap_offset; } - for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) { - const auto &token_data = vocab.id_to_token[id]; + for (uint32_t id = 0; id < vocab.n_tokens(); ++id) { + const auto & token_data = vocab.get_token_data(id); - if (llama_is_normal_token(vocab, id)) { + if (vocab.is_normal(id)) { min_score = std::min(min_score, token_data.score); max_score = std::max(max_score, token_data.score); } - if (llama_is_normal_token(vocab, id) || - llama_is_user_defined_token(vocab, id) || - llama_is_unused_token(vocab, id)) { + if (vocab.is_normal(id) || + vocab.is_user_defined(id) || + vocab.is_unused(id)) { token_matcher.insert(token_data.text.data(), token_data.text.size(), id); } - if (llama_is_user_defined_token(vocab, id)) { + if (vocab.is_user_defined(id)) { user_defined_token_matcher.insert(token_data.text.data(), token_data.text.size()); } } @@ -805,6 +754,29 @@ struct llm_tokenizer_ugm { unknown_token_score = min_score - unknown_token_score_penalty; } + // escaped space symbol - U+2581 (Lower One Eighth Block) + const std::string escaped_space = "\xE2\x96\x81"; + + const char * prefix_replacements = NULL; + size_t prefix_replacements_size = 0; + + const uint32_t * xcda_array = NULL; + size_t xcda_array_size = 0; + + struct naive_trie user_defined_token_matcher; + + float min_score = FLT_MAX; + float max_score = -FLT_MAX; + + float unknown_token_score_penalty = 10.0; + float unknown_token_score; + + struct naive_trie token_matcher; +}; + +struct llm_tokenizer_ugm_session { + llm_tokenizer_ugm_session(const llama_vocab & vocab, const llm_tokenizer_ugm & tokenizer) : vocab(vocab), tokenizer(tokenizer) {} + /* This implementation is based on SentencePiece optimized Viterbi algorithm for * unigram language models. The general idea is to: * - move along the input sequence in steps of one UTF code point, @@ -818,7 +790,7 @@ struct llm_tokenizer_ugm { * After processing the whole sequence we backtrack from the end to get * the best tokenization. */ - void tokenize(const std::string & text, std::vector & output) { + void tokenize(const std::string & text, std::vector & output) { // get current size of output (for reversal later) size_t output_size = output.size(); @@ -831,9 +803,9 @@ struct llm_tokenizer_ugm { } // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores - std::vector tokenization_results(input_len + 1, {vocab.special_unk_id, 0, -FLT_MAX}); + std::vector tokenization_results(input_len + 1, {vocab.token_unk(), 0, -FLT_MAX}); // at the beginning tokenization score is zero - tokenization_results[0] = { vocab.special_unk_id, 0, 0 }; + tokenization_results[0] = { vocab.token_unk(), 0, 0 }; for (size_t input_offset = 0; input_offset < input_len;) { size_t prefix_offset = input_offset; @@ -843,7 +815,7 @@ struct llm_tokenizer_ugm { // traverse the token matcher trie to find a matching token bool single_codepoint_token_found = false; const struct best_tokenization & current_best = tokenization_results[input_offset]; - const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]); + const struct naive_trie * node = tokenizer.token_matcher.traverse(normalized[prefix_offset++]); while (prefix_offset <= input_len && node != NULL) { // check if we found valid token in prefix @@ -853,13 +825,13 @@ struct llm_tokenizer_ugm { single_codepoint_token_found = true; } llama_token token_id = node->value; - const auto & token_data = vocab.id_to_token[token_id]; + const auto & token_data = vocab.get_token_data(token_id); // we set the user-defined token scores to 0 to make them more likely to be selected // (normal token scores are log probabilities, so they are negative) // score type is double here to make tokenization results exactly // the same as in the HF tokenizer using SentencePiece - const double token_score = llama_is_user_defined_token(vocab, token_id) ? 0.0 : token_data.score; + const double token_score = vocab.is_user_defined(token_id) ? 0.0 : token_data.score; const double challenger_score = current_best.score_sum + token_score; struct best_tokenization & current_champ = tokenization_results[prefix_offset]; if (challenger_score > current_champ.score_sum) { @@ -873,11 +845,11 @@ struct llm_tokenizer_ugm { // if we didn't find a valid token corresponding to the whole UTF code point // then use unknown token as the tokenization of this UTF code point if (!single_codepoint_token_found) { - const double challenger_score = current_best.score_sum + unknown_token_score; + const double challenger_score = current_best.score_sum + tokenizer.unknown_token_score; prefix_offset = input_offset + n_utf8_code_units; struct best_tokenization & current_champ = tokenization_results[prefix_offset]; if (challenger_score > current_champ.score_sum) { - struct best_tokenization challenger = { vocab.special_unk_id, input_offset, (float) challenger_score }; + struct best_tokenization challenger = { vocab.token_unk(), input_offset, (float) challenger_score }; current_champ = challenger; } } @@ -890,7 +862,7 @@ struct llm_tokenizer_ugm { // merge sequences of consecutive unknown tokens into single unknown tokens bool is_prev_unknown = false; for (struct best_tokenization & tokenization = tokenization_results[input_len]; ; tokenization = tokenization_results[tokenization.input_offset]) { - bool is_unknown = tokenization.token_id == vocab.special_unk_id; + bool is_unknown = tokenization.token_id == vocab.token_unk(); if (!(is_prev_unknown && is_unknown)) { output.push_back(tokenization.token_id); } @@ -905,7 +877,6 @@ struct llm_tokenizer_ugm { } private: - const llama_vocab & vocab; // helper structure for returning normalization results struct normalization_result { @@ -918,11 +889,11 @@ private: normalized->clear(); normalized->reserve(input.size() * 3); - const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " "; + const std::string space = vocab.get_escape_whitespaces() ? tokenizer.escaped_space : " "; - bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix; - bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix; - bool shall_merge_spaces = vocab.tokenizer_remove_extra_whitespaces; + const bool shall_prepend_space = !vocab.get_treat_whitespace_as_suffix() && vocab.get_add_space_prefix(); + const bool shall_append_space = vocab.get_treat_whitespace_as_suffix() && vocab.get_add_space_prefix(); + const bool shall_merge_spaces = vocab.get_remove_extra_whitespaces(); bool is_space_prepended = false; bool processing_non_ws = false; @@ -1000,13 +971,21 @@ private: size_t xcda_array_size; }; + // this structure stores the best tokenization so far at input_offset + struct best_tokenization { + llama_token token_id; + size_t input_offset; + float score_sum; + }; + struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) { if (input_offset == input.size()) { return { &input[input_offset], 0, 0 }; } // if input prefix matches some user-defined token return this token as normalization result - auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset); + auto user_defined_token_match = + tokenizer.user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset); if (user_defined_token_match.second > 0) { return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second }; } @@ -1014,8 +993,8 @@ private: size_t longest_prefix_length = 0; size_t longest_prefix_offset = 0; - if (xcda_array_size > 0) { - struct xcda_array_view xcda_view(xcda_array, xcda_array_size); + if (tokenizer.xcda_array_size > 0) { + struct xcda_array_view xcda_view(tokenizer.xcda_array, tokenizer.xcda_array_size); // Find the longest normalized sequence matching the input prefix by walking // the XOR-compressed compact double array (XCDA) starting from the root node @@ -1051,50 +1030,27 @@ private: if (longest_prefix_length > 0) { // we have a match, so return the replacement sequence - if (longest_prefix_offset >= prefix_replacements_size) { + if (longest_prefix_offset >= tokenizer.prefix_replacements_size) { throw std::runtime_error("Index out of array bounds in precompiled charsmap!"); } - const char * prefix_replacement = &prefix_replacements[longest_prefix_offset]; + const char * prefix_replacement = &(tokenizer.prefix_replacements)[longest_prefix_offset]; return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length }; - } else { - // check if the input prefix contains a valid sequence of UTF-8 code units - try { - // if yes, return this sequence unmodified - size_t prefix_offset = input_offset; - unicode_cpt_from_utf8(input, prefix_offset); - return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset }; - } catch (std::invalid_argument & /*ex*/) { - // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER - return { "\xEF\xBF\xBD", 3, 1 }; - } + } + + // check if the input prefix contains a valid sequence of UTF-8 code units + try { + // if yes, return this sequence unmodified + size_t prefix_offset = input_offset; + unicode_cpt_from_utf8(input, prefix_offset); + return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset }; + } catch (std::invalid_argument & /*ex*/) { + // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER + return { "\xEF\xBF\xBD", 3, 1 }; } } - // escaped space symbol - U+2581 (Lower One Eighth Block) - const std::string escaped_space = "\xE2\x96\x81"; - - const char * prefix_replacements = NULL; - size_t prefix_replacements_size = 0; - - const uint32_t * xcda_array = NULL; - size_t xcda_array_size = 0; - - struct naive_trie user_defined_token_matcher; - - // this structure stores the best tokenization so far at input_offset - struct best_tokenization { - llama_token token_id; - size_t input_offset; - float score_sum; - }; - - float min_score = FLT_MAX; - float max_score = -FLT_MAX; - - float unknown_token_score_penalty = 10.0; - float unknown_token_score; - - struct naive_trie token_matcher; + const llama_vocab & vocab; + const llm_tokenizer_ugm & tokenizer; }; // @@ -1155,27 +1111,32 @@ static std::vector llama_unescape_rwkv_token(const std::string & escape return output; } -struct llm_tokenizer_rwkv { - llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) { +struct llm_tokenizer_rwkv : llm_tokenizer { + llm_tokenizer_rwkv(const llama_vocab & vocab) { // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens. // For now, we decode the vocab here into the lookup we'll use for tokenization. // build trie - for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) { - const auto & token = vocab.id_to_token[id]; - const auto data = llama_unescape_rwkv_token(token.text); - token_matcher.insert((const char *) data.data(), data.size(), id); + for (uint32_t id = 0; id < vocab.n_tokens(); ++id) { + const auto & data = vocab.get_token_data(id); + const auto text = llama_unescape_rwkv_token(data.text); + token_matcher.insert((const char *) text.data(), text.size(), id); } } - void tokenize(const std::string & text, std::vector & output) { - uint32_t position = 0; + struct naive_trie token_matcher; +}; +struct llm_tokenizer_rwkv_session { + llm_tokenizer_rwkv_session(const llama_vocab & vocab, const llm_tokenizer_rwkv & tokenizer) : vocab(vocab), tokenizer(tokenizer) {} + + void tokenize(const std::string & text, std::vector & output) { + uint32_t position = 0; while (position < text.size()) { - const struct naive_trie * node = token_matcher.traverse(text[position]); + const struct naive_trie * node = tokenizer.token_matcher.traverse(text[position]); if (node == NULL) { // no matching token found, add unknown token - output.push_back(vocab.special_unk_id); + output.push_back(vocab.token_unk()); position += 1; continue; } @@ -1197,13 +1158,13 @@ struct llm_tokenizer_rwkv { } } +private: const llama_vocab & vocab; - - struct naive_trie token_matcher; + const llm_tokenizer_rwkv & tokenizer; }; // -// (de-) tokenize +// impl // typedef enum FRAGMENT_BUFFER_VARIANT_TYPE { @@ -1212,7 +1173,7 @@ typedef enum FRAGMENT_BUFFER_VARIANT_TYPE { } FRAGMENT_BUFFER_VARIANT_TYPE; struct fragment_buffer_variant { - fragment_buffer_variant(llama_vocab::id _token) + fragment_buffer_variant(llama_token _token) : type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN), token(_token), @@ -1223,7 +1184,7 @@ struct fragment_buffer_variant { fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length) : type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT), - token((llama_vocab::id) - 1), + token((llama_token) - 1), raw_text(_raw_text), offset(_offset), length(_length){ @@ -1233,20 +1194,963 @@ struct fragment_buffer_variant { } const FRAGMENT_BUFFER_VARIANT_TYPE type; - const llama_vocab::id token; + const llama_token token; const std::string _dummy; const std::string & raw_text; const uint64_t offset; const uint64_t length; }; +struct llama_vocab::impl { + uint32_t n_token_types = 0; // for BERT-style token types + + enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; + enum llama_vocab_pre_type pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + + int max_token_len = 0; // used for optimizing longest token search + + // default LLaMA special tokens + // TODO: should we set all of these to LLAMA_TOKEN_NULL? + llama_token special_bos_id = 1; + llama_token special_eos_id = 2; + llama_token special_eot_id = LLAMA_TOKEN_NULL; + llama_token special_eom_id = LLAMA_TOKEN_NULL; + llama_token special_unk_id = 0; + llama_token special_sep_id = LLAMA_TOKEN_NULL; + llama_token special_pad_id = LLAMA_TOKEN_NULL; + llama_token special_mask_id = LLAMA_TOKEN_NULL; + + llama_token linefeed_id = 13; + + // fim tokens + llama_token special_fim_pre_id = LLAMA_TOKEN_NULL; + llama_token special_fim_suf_id = LLAMA_TOKEN_NULL; + llama_token special_fim_mid_id = LLAMA_TOKEN_NULL; + llama_token special_fim_pad_id = LLAMA_TOKEN_NULL; + llama_token special_fim_rep_id = LLAMA_TOKEN_NULL; // repo + llama_token special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator + + // tokenizer flags + bool add_space_prefix = false; + bool add_bos = false; + bool add_eos = false; + bool ignore_merges = false; + bool clean_spaces = false; // clean_up_tokenization_spaces + bool remove_extra_whitespaces = false; + bool escape_whitespaces = true; + bool treat_whitespace_as_suffix = false; + + std::unordered_map token_to_id; + std::vector id_to_token; + + std::vector cache_special_tokens; + std::vector cache_token_to_piece; // llama_token_to_piece(special = true); + struct pair_hash { + size_t operator()(const std::pair & p) const { + return std::hash{}(p.first) ^ //create some hash for pair + (std::hash{}(p.second) << 1); + } + }; + std::unordered_map, int, pair_hash> bpe_ranks; + + // set of all tokens that cause "end of generation" + std::set special_eog_ids; + + std::unique_ptr tokenizer; + + std::vector precompiled_charsmap; + + impl(const llama_vocab & vocab) : vocab(vocab) { + } + + ~impl() = default; + + void load(llama_model_loader & ml, const LLM_KV & kv); + + enum llama_vocab_type get_type() const; + + std::string type_name() const; + + bool is_normal (llama_token id) const; + bool is_unknown (llama_token id) const; + bool is_control (llama_token id) const; + bool is_byte (llama_token id) const; + bool is_user_defined(llama_token id) const; + bool is_unused (llama_token id) const; + bool is_eog (llama_token id) const; + + uint8_t token_to_byte(llama_token id) const; + + llama_token_attr token_get_attr(llama_token id) const; + + void init_tokenizer(enum llama_vocab_type type); + + void tokenizer_st_partition(std::forward_list & buffer, bool parse_special) const; + + std::string token_to_piece_for_cache( + llama_token token, + bool special) const; + + + std::vector tokenize( + const std::string & raw_text, + bool add_special, + bool parse_special = false) const; + + int32_t tokenize( + const char * text, + int32_t text_len, + llama_token * tokens, + int32_t n_tokens_max, + bool add_special, + bool parse_special) const; + + // does not write null-terminator to buf + int32_t token_to_piece( + llama_token token, + char * buf, + int32_t length, + int32_t lstrip, + bool special) const; + + // use cached data + const std::string & token_to_piece(llama_token token) const; + + int32_t detokenize( + const llama_token * tokens, + int32_t n_tokens, + char * text, + int32_t text_len_max, + bool remove_special, + bool unparse_special) const; + + std::string detokenize( + const std::vector & tokens, + bool special) const; + + void print_info() const; + +private: + const llama_vocab & vocab; +}; + +void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { + struct gguf_context * ctx = ml.meta.get(); + + // determine vocab type + { + std::string tokenizer_model; + std::string tokenizer_pre; + + ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model); + ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false); + + ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, n_token_types, false); + + if (tokenizer_model == "no_vocab" || tokenizer_model == "none") { + type = LLAMA_VOCAB_TYPE_NONE; + + // default special tokens + special_bos_id = LLAMA_TOKEN_NULL; + special_eos_id = LLAMA_TOKEN_NULL; + special_unk_id = LLAMA_TOKEN_NULL; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = LLAMA_TOKEN_NULL; + special_mask_id = LLAMA_TOKEN_NULL; + linefeed_id = LLAMA_TOKEN_NULL; + + // read vocab size from metadata + uint32_t n_tokens = 0; + if (ml.get_key(LLM_KV_VOCAB_SIZE, n_tokens, false)) { + LLAMA_LOG_WARN("%s: adding %u dummy tokens\n", __func__, n_tokens); + id_to_token.resize(n_tokens); + } + + return; + } + + if (tokenizer_model == "llama") { + type = LLAMA_VOCAB_TYPE_SPM; + + // default special tokens + special_bos_id = 1; + special_eos_id = 2; + special_unk_id = 0; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = LLAMA_TOKEN_NULL; + special_mask_id = LLAMA_TOKEN_NULL; + } else if (tokenizer_model == "bert") { + type = LLAMA_VOCAB_TYPE_WPM; + + // default special tokens + special_bos_id = 101; + special_eos_id = LLAMA_TOKEN_NULL; + special_unk_id = 100; + special_sep_id = 102; + special_pad_id = 0; + special_mask_id = 103; + } else if (tokenizer_model == "gpt2") { + type = LLAMA_VOCAB_TYPE_BPE; + + // read bpe merges and populate bpe ranks + const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); + if (merges_keyidx == -1) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); + } + + const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); + for (int i = 0; i < n_merges; i++) { + const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); + + std::string first; + std::string second; + + const size_t pos = word.find(' ', 1); + + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } + + bpe_ranks.emplace(std::make_pair(first, second), i); + } + + // default special tokens + special_bos_id = 11; + special_eos_id = 11; + special_unk_id = LLAMA_TOKEN_NULL; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = LLAMA_TOKEN_NULL; + special_mask_id = LLAMA_TOKEN_NULL; + } else if (tokenizer_model == "t5") { + type = LLAMA_VOCAB_TYPE_UGM; + + // default special tokens + special_bos_id = LLAMA_TOKEN_NULL; + special_eos_id = 1; + special_unk_id = 2; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = 0; + special_mask_id = LLAMA_TOKEN_NULL; + + const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str()); + if (precompiled_charsmap_keyidx != -1) { + size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx); + const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); + precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap); +#ifdef IS_BIG_ENDIAN + // correct endiannes of data in precompiled_charsmap binary blob + uint32_t * xcda_blob_size = (uint32_t *) &precompiled_charsmap[0]; + *xcda_blob_size = __builtin_bswap32(*xcda_blob_size); + assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap); + size_t xcda_array_size = *xcda_blob_size / sizeof(uint32_t); + uint32_t * xcda_array = (uint32_t *) &precompiled_charsmap[sizeof(uint32_t)]; + for (size_t i = 0; i < xcda_array_size; ++i) { + xcda_array[i] = __builtin_bswap32(xcda_array[i]); + } +#endif + } + } else if (tokenizer_model == "rwkv") { + type = LLAMA_VOCAB_TYPE_RWKV; + + // default special tokens + special_bos_id = LLAMA_TOKEN_NULL; + special_eos_id = LLAMA_TOKEN_NULL; + special_unk_id = LLAMA_TOKEN_NULL; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = LLAMA_TOKEN_NULL; + } else { + throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); + } + + // for now, only BPE models have pre-tokenizers + if (type == LLAMA_VOCAB_TYPE_BPE) { + add_space_prefix = false; + clean_spaces = true; + if (tokenizer_pre.empty()) { + LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__); + LLAMA_LOG_WARN("%s: \n", __func__); + LLAMA_LOG_WARN("%s: ************************************ \n", __func__); + LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED! \n", __func__); + LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL \n", __func__); + LLAMA_LOG_WARN("%s: ************************************ \n", __func__); + LLAMA_LOG_WARN("%s: \n", __func__); + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } else if (tokenizer_pre == "default") { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } else if ( + tokenizer_pre == "llama3" || + tokenizer_pre == "llama-v3" || + tokenizer_pre == "llama-bpe"|| + tokenizer_pre == "falcon3") { + pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3; + ignore_merges = true; + add_bos = true; + } else if ( + tokenizer_pre == "deepseek-llm") { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM; + clean_spaces = false; + } else if ( + tokenizer_pre == "deepseek-coder") { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER; + clean_spaces = false; + } else if ( + tokenizer_pre == "deepseek-v3") { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM; + clean_spaces = false; + } else if ( + tokenizer_pre == "falcon") { + pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON; + } else if ( + tokenizer_pre == "mpt") { + pre_type = LLAMA_VOCAB_PRE_TYPE_MPT; + } else if ( + tokenizer_pre == "starcoder") { + pre_type = LLAMA_VOCAB_PRE_TYPE_STARCODER; + } else if ( + tokenizer_pre == "gpt-2" || + tokenizer_pre == "phi-2" || + tokenizer_pre == "jina-es" || + tokenizer_pre == "jina-de" || + tokenizer_pre == "gigachat" || + tokenizer_pre == "jina-v1-en" || + tokenizer_pre == "jina-v2-es" || + tokenizer_pre == "jina-v2-de" || + tokenizer_pre == "jina-v2-code" || + tokenizer_pre == "roberta-bpe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; + } else if ( + tokenizer_pre == "refact") { + pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT; + } else if ( + tokenizer_pre == "command-r") { + pre_type = LLAMA_VOCAB_PRE_TYPE_COMMAND_R; + clean_spaces = false; + } else if ( + tokenizer_pre == "qwen2" || + tokenizer_pre == "deepseek-r1-qwen") { + pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; + clean_spaces = false; + } else if ( + tokenizer_pre == "stablelm2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_STABLELM2; + } else if ( + tokenizer_pre == "olmo") { + pre_type = LLAMA_VOCAB_PRE_TYPE_OLMO; + } else if ( + tokenizer_pre == "dbrx") { + pre_type = LLAMA_VOCAB_PRE_TYPE_DBRX; + } else if ( + tokenizer_pre == "smaug-bpe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SMAUG; + } else if ( + tokenizer_pre == "poro-chat") { + pre_type = LLAMA_VOCAB_PRE_TYPE_PORO; + clean_spaces = false; + } else if ( + tokenizer_pre == "chatglm-bpe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_CHATGLM4; + special_bos_id = LLAMA_TOKEN_NULL; + } else if ( + tokenizer_pre == "viking") { + pre_type = LLAMA_VOCAB_PRE_TYPE_VIKING; + clean_spaces = false; + } else if ( + tokenizer_pre == "jais") { + pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS; + } else if ( + tokenizer_pre == "tekken") { + pre_type = LLAMA_VOCAB_PRE_TYPE_TEKKEN; + clean_spaces = false; + ignore_merges = true; + add_bos = true; + } else if ( + tokenizer_pre == "smollm") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SMOLLM; + clean_spaces = false; + } else if ( + tokenizer_pre == "codeshell") { + pre_type = LLAMA_VOCAB_PRE_TYPE_CODESHELL; + } else if ( + tokenizer_pre == "bloom") { + pre_type = LLAMA_VOCAB_PRE_TYPE_BLOOM; + } else if ( + tokenizer_pre == "gpt3-finnish") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH; + } else if ( + tokenizer_pre == "exaone") { + pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE; + } else if ( + tokenizer_pre == "chameleon") { + pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON; + add_bos = true; + clean_spaces = false; + } else if ( + tokenizer_pre == "minerva-7b") { + pre_type = LLAMA_VOCAB_PRE_TYPE_MINERVA; + } else if ( + tokenizer_pre == "megrez") { + pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; + } else { + throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); + } + } else if (type == LLAMA_VOCAB_TYPE_SPM) { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + add_space_prefix = true; + clean_spaces = false; + add_bos = true; + add_eos = false; + } else if (type == LLAMA_VOCAB_TYPE_WPM) { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + add_space_prefix = false; + clean_spaces = true; + add_bos = true; + add_eos = false; + } else if (type == LLAMA_VOCAB_TYPE_UGM) { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + add_bos = false; + add_eos = true; + } else if (type == LLAMA_VOCAB_TYPE_RWKV) { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + add_space_prefix = false; + clean_spaces = false; + add_bos = false; + add_eos = false; + } else { + pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } + + ml.get_key(LLM_KV_TOKENIZER_ADD_PREFIX, add_space_prefix, false); + ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, remove_extra_whitespaces, false); + } + + const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str()); + if (token_idx == -1) { + throw std::runtime_error("cannot find tokenizer vocab in model file\n"); + } + + const float * scores = nullptr; + const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); + if (score_idx != -1) { + scores = (const float * ) gguf_get_arr_data(ctx, score_idx); + } + + const int * toktypes = nullptr; + const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); + if (toktype_idx != -1) { + toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); + } + + uint32_t n_tokens = gguf_get_arr_n(ctx, token_idx); + id_to_token.resize(n_tokens); + + for (uint32_t i = 0; i < n_tokens; i++) { + std::string word = gguf_get_arr_str(ctx, token_idx, i); + if (word.empty()) { + LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i); + word = "[EMPTY_" + std::to_string(i) + "]"; + } + + token_to_id[word] = i; + max_token_len = std::max(max_token_len, (int) word.size()); + + auto & token_data = id_to_token[i]; + token_data.text = std::move(word); + token_data.score = scores ? scores[i] : 0.0f; + token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; + + if (toktypes) { //TODO: remove, required until per token attributes are available from GGUF file + switch(toktypes[i]) { + case LLAMA_TOKEN_TYPE_UNKNOWN: token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN; break; + case LLAMA_TOKEN_TYPE_UNUSED: token_data.attr = LLAMA_TOKEN_ATTR_UNUSED; break; + case LLAMA_TOKEN_TYPE_NORMAL: token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; break; + case LLAMA_TOKEN_TYPE_CONTROL: token_data.attr = LLAMA_TOKEN_ATTR_CONTROL; break; + case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break; + case LLAMA_TOKEN_TYPE_BYTE: token_data.attr = LLAMA_TOKEN_ATTR_BYTE; break; + case LLAMA_TOKEN_TYPE_UNDEFINED: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; + default: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; + } + } + } + GGML_ASSERT(id_to_token.size() == token_to_id.size()); + + init_tokenizer(type); + + // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' + if (type == LLAMA_VOCAB_TYPE_SPM) { + try { + linefeed_id = vocab.byte_to_token('\n'); + } catch (const std::exception & e) { + LLAMA_LOG_WARN("%s: SPM vocabulary, but newline token not found: %s! Using special_pad_id instead.", __func__, e.what()); + linefeed_id = special_pad_id; + } + } else if (type == LLAMA_VOCAB_TYPE_WPM) { + linefeed_id = special_pad_id; + } else if (type == LLAMA_VOCAB_TYPE_RWKV) { + const std::vector ids = tokenize("\n", false); + GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); + linefeed_id = ids[0]; + } else { + const std::vector ids = tokenize("\n", false); + + //GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); + if (ids.empty()) { + LLAMA_LOG_WARN("%s: model vocab missing newline token, using special_pad_id instead\n", __func__); + linefeed_id = special_pad_id; + } else { + linefeed_id = ids[0]; + } + } + + // special tokens + { + const std::vector> special_token_types = { + { LLM_KV_TOKENIZER_BOS_ID, special_bos_id }, + { LLM_KV_TOKENIZER_EOS_ID, special_eos_id }, + { LLM_KV_TOKENIZER_EOT_ID, special_eot_id }, + { LLM_KV_TOKENIZER_EOM_ID, special_eom_id }, + { LLM_KV_TOKENIZER_UNK_ID, special_unk_id }, + { LLM_KV_TOKENIZER_SEP_ID, special_sep_id }, + { LLM_KV_TOKENIZER_PAD_ID, special_pad_id }, + { LLM_KV_TOKENIZER_MASK_ID, special_mask_id }, + { LLM_KV_TOKENIZER_FIM_PRE_ID, special_fim_pre_id }, + { LLM_KV_TOKENIZER_FIM_SUF_ID, special_fim_suf_id }, + { LLM_KV_TOKENIZER_FIM_MID_ID, special_fim_mid_id }, + { LLM_KV_TOKENIZER_FIM_PAD_ID, special_fim_pad_id }, + { LLM_KV_TOKENIZER_FIM_REP_ID, special_fim_rep_id }, + { LLM_KV_TOKENIZER_FIM_SEP_ID, special_fim_sep_id }, + + // deprecated + { LLM_KV_TOKENIZER_PREFIX_ID, special_fim_pre_id }, + { LLM_KV_TOKENIZER_SUFFIX_ID, special_fim_suf_id }, + { LLM_KV_TOKENIZER_MIDDLE_ID, special_fim_mid_id }, + }; + + for (const auto & it : special_token_types) { + const std::string & key = kv(std::get<0>(it)); + int32_t & id = std::get<1>(it); + + uint32_t new_id; + if (!ml.get_key(std::get<0>(it), new_id, false)) { + continue; + } + if (new_id >= id_to_token.size()) { + LLAMA_LOG_WARN("%s: bad special token: '%s' = %u, using default id %d\n", + __func__, key.c_str(), new_id, id); + } else { + id = new_id; + } + } + + // Handle add_bos and add_eos + { + bool temp = true; + + if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) { + add_bos = temp; + } + if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) { + add_eos = temp; + } + } + + // auto-detect special tokens by text + // TODO: convert scripts should provide these tokens through the KV metadata LLM_KV_TOKENIZER_... + // for now, we apply this workaround to find the tokens based on their text + + for (const auto & t : token_to_id) { + // find EOT token: "<|eot_id|>", "<|im_end|>", "", etc. + if (special_eot_id == LLAMA_TOKEN_NULL) { + if (false + || t.first == "<|eot_id|>" + || t.first == "<|im_end|>" + || t.first == "<|end|>" + || t.first == "" + || t.first == "<|endoftext|>" + || t.first == "" + || t.first == "<|end▁of▁sentence|>" // DeepSeek + ) { + special_eot_id = t.second; + if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", + __func__, t.second, t.first.c_str()); + id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + } + } + } + + // find EOM token: "<|eom_id|>" + if (special_eom_id == LLAMA_TOKEN_NULL) { + if (false + || t.first == "<|eom_id|>" + ) { + special_eom_id = t.second; + if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", + __func__, t.second, t.first.c_str()); + id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + } + } + } + + // find FIM_PRE token: "<|fim_prefix|>", "", "
", etc.
+            if (special_fim_pre_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_prefix|>"  // Qwen
+                        || t.first == ""
+                        || t.first == "<|fim▁begin|>" // DeepSeek
+                        || t.first == "
"
+                        ) {
+                    special_fim_pre_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_SUF token: "<|fim_suffix|>", "", "", etc.
+            if (special_fim_suf_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_suffix|>" // Qwen
+                        || t.first == ""
+                        || t.first == "<|fim▁hole|>" // DeepSeek
+                        || t.first == ""
+                        ) {
+                    special_fim_suf_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_MID token: "<|fim_middle|>", "", "", etc.
+            if (special_fim_mid_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_middle|>" // Qwen
+                        || t.first == ""
+                        || t.first == "<|fim▁end|>"  // DeepSeek
+                        || t.first == ""
+                        ) {
+                    special_fim_mid_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_PAD token: "<|fim_pad|>", "", "", etc.
+            if (special_fim_pad_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_pad|>" // Qwen
+                        || t.first == ""
+                        || t.first == ""
+                        ) {
+                    special_fim_pad_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_REP token: "<|fim_repo|>", "", "", etc.
+            if (special_fim_rep_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_repo|>"  // Qwen
+                        || t.first == "<|repo_name|>"
+                        || t.first == ""
+                        || t.first == ""
+                        ) {
+                    special_fim_rep_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_SEP token: "<|file_sep|>"
+            if (special_fim_sep_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|file_sep|>" // Qwen
+                        ) {
+                    special_fim_sep_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+        }
+
+        // maintain a list of tokens that cause end-of-generation
+        // this is currently determined based on the token text, which is obviously not ideal
+        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
+        special_eog_ids.clear();
+
+        if (special_fim_pad_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_pad_id) == 0) {
+            special_eog_ids.insert(special_fim_pad_id);
+        }
+
+        if (special_fim_rep_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_rep_id) == 0) {
+            special_eog_ids.insert(special_fim_rep_id);
+        }
+
+        if (special_fim_sep_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_sep_id) == 0) {
+            special_eog_ids.insert(special_fim_sep_id);
+        }
+
+        for (const auto & t : token_to_id) {
+            if (false
+                    || t.first == "<|eot_id|>"
+                    || t.first == "<|im_end|>"
+                    || t.first == "<|end|>"
+                    || t.first == ""
+                    || t.first == "<|endoftext|>"
+                    || t.first == "<|eom_id|>"
+                    || t.first == ""
+               ) {
+                special_eog_ids.insert(t.second);
+                if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.second, t.first.c_str());
+                    id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                }
+            } else {
+                // token is control, but not marked as EOG -> print a debug log
+                if (id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && special_eog_ids.count(t.second) == 0) {
+                    LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
+                            __func__, t.second, t.first.c_str());
+                }
+            }
+        }
+
+        // sanity checks
+        if (special_eos_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eos_id) == 0) {
+            special_eog_ids.insert(special_eos_id);
+            LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (special_eot_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eot_id) == 0) {
+            special_eog_ids.insert(special_eot_id);
+            LLAMA_LOG_WARN("%s: special_eot_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (special_eom_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eom_id) == 0) {
+            special_eog_ids.insert(special_eom_id);
+            LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+    }
+
+    // build special tokens cache
+    {
+        for (llama_token id = 0; id < (llama_token) n_tokens; ++id) {
+            if (id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
+                cache_special_tokens.push_back(id);
+            }
+        }
+
+        std::sort(cache_special_tokens.begin(), cache_special_tokens.end(),
+            [&] (const llama_token a, const llama_token b) {
+                return id_to_token[a].text.size() > id_to_token[b].text.size();
+            }
+        );
+
+        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t) cache_special_tokens.size());
+    }
+
+    // build token to piece cache
+    {
+        size_t size_cache = 0;
+
+        std::vector cache(n_tokens);
+
+        for (uint32_t id = 0; id < n_tokens; ++id) {
+            cache[id] = token_to_piece_for_cache(id, true);
+
+            size_cache += cache[id].size();
+        }
+
+        std::swap(cache_token_to_piece, cache);
+
+        LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
+    }
+
+    // Handle per token attributes
+    //NOTE: Each model customizes per token attributes.
+    //NOTE: Per token attributes are missing from the GGUF file.
+    //TODO: Extract attributes from GGUF file.
+    {
+        auto _contains_any = [] (const std::string & str, const std::vector & substrs) -> bool {
+            for (const auto & substr : substrs) {
+                if (str.find(substr) < std::string::npos) {
+                    return true;
+                }
+            }
+            return false;
+        };
+
+        auto _set_tokenid_attr = [&] (const llama_token id, llama_token_attr attr, bool value) {
+            uint32_t current = id_to_token.at(id).attr;
+            current = value ? (current | attr) : (current & ~attr);
+            id_to_token[id].attr = (llama_token_attr) current;
+        };
+
+        auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
+            _set_tokenid_attr(token_to_id.at(token), attr, value);
+        };
+
+        std::string model_name;
+        std::string tokenizer_pre;
+
+        ml.get_key(LLM_KV_GENERAL_NAME,  model_name,    false);
+        ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
+
+        // model name to lowercase
+        std::transform(model_name.begin(), model_name.end(), model_name.begin(),
+            [] (const std::string::value_type x) {
+                return std::tolower(x);
+            }
+        );
+
+        // set attributes by model/tokenizer name
+        if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
+            _set_token_attr("", LLAMA_TOKEN_ATTR_LSTRIP, true);
+        } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
+            for (auto id : cache_special_tokens) {
+                _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (const auto * token : {""}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (const auto * token : {"", "", "<|endoftext|>"}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
+            }
+        }
+    }
+}
+
+enum llama_vocab_type llama_vocab::impl::get_type() const {
+    return type;
+}
+
+std::string llama_vocab::impl::type_name() const{
+    switch (type) {
+        case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
+        case LLAMA_VOCAB_TYPE_SPM:  return "SPM";
+        case LLAMA_VOCAB_TYPE_BPE:  return "BPE";
+        case LLAMA_VOCAB_TYPE_WPM:  return "WPM";
+        case LLAMA_VOCAB_TYPE_UGM:  return "UGM";
+        case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
+        default:                    return "unknown";
+    }
+}
+
+bool llama_vocab::impl::is_normal(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
+}
+
+bool llama_vocab::impl::is_unknown(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
+}
+
+bool llama_vocab::impl::is_control(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
+}
+
+bool llama_vocab::impl::is_byte(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
+}
+
+bool llama_vocab::impl::is_user_defined(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
+}
+
+bool llama_vocab::impl::is_unused(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED;
+}
+
+bool llama_vocab::impl::is_eog(llama_token id) const {
+    return id != LLAMA_TOKEN_NULL && special_eog_ids.count(id) > 0;
+}
+
+uint8_t llama_vocab::impl::token_to_byte(llama_token id) const {
+    GGML_ASSERT(get_type() != LLAMA_VOCAB_TYPE_NONE);
+    GGML_ASSERT(is_byte(id));
+    const auto & token_data = id_to_token.at(id);
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+        case LLAMA_VOCAB_TYPE_UGM: {
+            auto buf = token_data.text.substr(3, 2);
+            return strtol(buf.c_str(), NULL, 16);
+        }
+        case LLAMA_VOCAB_TYPE_BPE: {
+            GGML_ABORT("fatal error");
+        }
+        case LLAMA_VOCAB_TYPE_WPM: {
+            GGML_ABORT("fatal error");
+        }
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
+
+llama_token_attr llama_vocab::impl::token_get_attr(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token.at(id).attr;
+}
+
+void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) {
+    LLAMA_LOG_DEBUG("%s: initializing tokenizer for type %d\n", __func__, type);
+
+    switch (type) {
+        case LLAMA_VOCAB_TYPE_SPM:
+            tokenizer = std::make_unique(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_BPE:
+            tokenizer = std::make_unique(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_WPM:
+            tokenizer = std::make_unique(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_UGM:
+            tokenizer = std::make_unique(vocab, precompiled_charsmap);
+            break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            tokenizer = std::make_unique(vocab);
+            break;
+        default:
+            GGML_ABORT("unsupported vocab type");
+    }
+}
+
+//
+// (de-) tokenize
+//
+
 // #define PRETOKENIZERDEBUG
 
-static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer, bool parse_special) {
+void llama_vocab::impl::tokenizer_st_partition(std::forward_list & buffer, bool parse_special) const {
     // for each special token
-    for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
-        const auto & data = vocab.id_to_token[special_id];
-        const auto & special_token = data.text;
+    for (const llama_token special_id : cache_special_tokens) {
+        const auto & data = vocab.get_token_data(special_id);
+        const auto & text = data.text;
 
         if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) {
             // Ignore control and unknown tokens when parse_special == false
@@ -1263,7 +2167,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
 
             // if a fragment is text ( not yet processed )
             if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                auto & raw_text = fragment.raw_text;
+                const auto & raw_text = fragment.raw_text;
 
                 auto raw_text_base_offset = fragment.offset;
                 auto raw_text_base_length = fragment.length;
@@ -1273,13 +2177,13 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                     // find the first occurrence of a given special token in this fragment
                     //  passing offset argument only limit the "search area" but match coordinates
                     //  are still relative to the source full raw_text
-                    auto match = raw_text.find(special_token, raw_text_base_offset);
+                    auto match = raw_text.find(text, raw_text_base_offset);
 
                     // no occurrences found, stop processing this fragment for a given special token
                     if (match == std::string::npos) break;
 
                     // check if match is within bounds of offset <-> length
-                    if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break;
+                    if (match + text.length() > raw_text_base_offset + raw_text_base_length) break;
 
 #ifdef PRETOKENIZERDEBUG
                     LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
@@ -1314,9 +2218,9 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                     it++;
 
                     // right
-                    if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
-                        int64_t right_reminder_offset = match + special_token.length();
-                        int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
+                    if (match + text.length() < raw_text_base_offset + raw_text_base_length) {
+                        int64_t right_reminder_offset = match + text.length();
+                        int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + text.length());
 
                         if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) {
                             while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
@@ -1337,7 +2241,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                         if (source == 0) {
                             buffer.erase_after(buffer.before_begin());
                         } else {
-                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                            buffer.erase_after(std::next(buffer.begin(), (source - 1)));
                         }
 
                         // repeat for the right side
@@ -1351,7 +2255,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                         if (source == 0) {
                             buffer.erase_after(buffer.before_begin());
                         } else {
-                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                            buffer.erase_after(std::next(buffer.begin(), (source - 1)));
                         }
                         break;
                     }
@@ -1362,296 +2266,29 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
     }
 }
 
-std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) {
-    std::vector output;
-    std::forward_list fragment_buffer;
-
-    if (!raw_text.empty()) {
-        fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
-        tokenizer_st_partition(vocab, fragment_buffer, parse_special);
+// NOTE: avoid ever using this except for building the token_to_piece caches
+std::string llama_vocab::impl::token_to_piece_for_cache(llama_token token, bool special) const {
+    std::string piece;
+    piece.resize(piece.capacity());  // using string internal cache
+    const int n_chars = vocab.token_to_piece(token, &piece[0], piece.size(), 0, special);
+    if (n_chars < 0) {
+        piece.resize(-n_chars);
+        int check = vocab.token_to_piece(token, &piece[0], piece.size(), 0, special);
+        GGML_ASSERT(check == -n_chars);
+    }
+    else {
+        piece.resize(n_chars);
     }
 
-    switch (vocab.type) {
-        case LLAMA_VOCAB_TYPE_SPM:
-            {
-                // OG tokenizer behavior:
-                //
-                // tokenizer.encode('', add_special_tokens=True)  returns [1]
-                // tokenizer.encode('', add_special_tokens=False) returns []
-
-                bool is_prev_special = true;  // prefix with space if first token
-
-                if (add_special && vocab.tokenizer_add_bos) {
-                    GGML_ASSERT(vocab.special_bos_id != -1);
-                    output.push_back(vocab.special_bos_id);
-                    is_prev_special = true;
-                }
-
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-
-                        // prefix with space if previous is special
-                        if (vocab.tokenizer_add_space_prefix && is_prev_special) {
-                            raw_text = " " + raw_text;
-                        }
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        llm_tokenizer_spm tokenizer(vocab);
-                        llama_escape_whitespace(raw_text);
-                        tokenizer.tokenize(raw_text, output);
-                        is_prev_special = false;
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                        is_prev_special = true;
-                    }
-                }
-
-                if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
-                    LLAMA_LOG_WARN(
-                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
-                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
-                        "Are you sure this is what you want?\n", __FUNCTION__);
-                }
-
-                if (add_special && vocab.tokenizer_add_eos) {
-                    GGML_ASSERT(vocab.special_eos_id != -1);
-                    output.push_back(vocab.special_eos_id);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_BPE:
-            {
-                llm_tokenizer_bpe tokenizer(vocab);
-
-                if (add_special) {
-                    tokenizer.append_bos(output);
-                }
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        tokenizer.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        tokenizer.append(fragment.token, output);
-                    }
-                }
-
-                if (add_special) {
-                    tokenizer.append_eos(output);
-                    tokenizer.check_double_bos_eos(output);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_WPM:
-            {
-                if (add_special) {
-                    GGML_ASSERT(vocab.special_cls_id != -1);
-                    output.push_back(vocab.special_cls_id);
-                }
-
-                llm_tokenizer_wpm tokenizer(vocab);
-
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        tokenizer.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                    }
-                }
-
-                if (add_special) {
-                    GGML_ASSERT(vocab.special_sep_id != -1);
-                    output.push_back(vocab.special_sep_id);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_UGM:
-            {
-                llm_tokenizer_ugm tokenizer(vocab);
-
-                if (add_special && vocab.tokenizer_add_bos != 0) {
-                    GGML_ASSERT(vocab.special_bos_id != -1);
-                    output.push_back(vocab.special_bos_id);
-                }
-
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        tokenizer.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                    }
-                }
-
-                if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
-                    LLAMA_LOG_WARN(
-                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
-                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
-                        "Are you sure this is what you want?\n", __FUNCTION__);
-                }
-
-                if (add_special && vocab.tokenizer_add_eos == 1) {
-                    GGML_ASSERT(vocab.special_eos_id != -1);
-                    output.push_back(vocab.special_eos_id);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_RWKV:
-            {
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-
-                        llm_tokenizer_rwkv tokenizer(vocab);
-                        tokenizer.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                    }
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_NONE:
-            GGML_ABORT("fatal error");
-    }
-
-    return output;
+    return piece;
 }
 
-llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch) {
-    GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
-    static const char * hex = "0123456789ABCDEF";
-    switch (llama_vocab_get_type(vocab)) {
-        case LLAMA_VOCAB_TYPE_SPM:
-        case LLAMA_VOCAB_TYPE_UGM: {
-            const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
-            auto token = vocab.token_to_id.find(buf);
-            if (token != vocab.token_to_id.end()) {
-                return (*token).second;
-            }
-            // Try to fall back to just the byte as a string
-            const char buf2[2] = { (char)ch, 0 };
-            return vocab.token_to_id.at(buf2);
-        }
-        case LLAMA_VOCAB_TYPE_WPM:
-        case LLAMA_VOCAB_TYPE_BPE: {
-            return vocab.token_to_id.at(unicode_byte_to_utf8(ch));
-        }
-        default:
-            GGML_ABORT("fatal error");
-    }
+static void llama_escape_whitespace(std::string & text) {
+    replace_all(text, " ", "\xe2\x96\x81");
 }
 
-const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[token].text.c_str();
-}
-
-float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[token].score;
-}
-
-llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[token].attr;
-}
-
-bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
-    return token != -1 && (
-        token == llama_token_eos_impl(vocab) ||
-        token == llama_token_eot_impl(vocab) ||
-        token == llama_token_eom_impl(vocab)
-    );
-}
-
-bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
-    return llama_is_control_token(vocab, token);
-}
-
-llama_token llama_token_bos_impl(const struct llama_vocab & vocab) {
-    return vocab.special_bos_id;
-}
-
-llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eos_id;
-}
-
-llama_token llama_token_cls_impl(const struct llama_vocab & vocab) {
-    return vocab.special_cls_id;
-}
-
-llama_token llama_token_sep_impl(const struct llama_vocab & vocab) {
-    return vocab.special_sep_id;
-}
-
-llama_token llama_token_nl_impl(const struct llama_vocab & vocab) {
-    return vocab.linefeed_id;
-}
-
-llama_token llama_token_pad_impl(const struct llama_vocab & vocab) {
-    return vocab.special_pad_id;
-}
-
-bool llama_add_bos_token_impl(const struct llama_vocab & vocab) {
-    return vocab.tokenizer_add_bos;
-}
-
-bool llama_add_eos_token_impl(const struct llama_vocab & vocab) {
-    return vocab.tokenizer_add_eos;
-}
-
-llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) {
-    return vocab.special_prefix_id;
-}
-
-llama_token llama_token_middle_impl(const struct llama_vocab & vocab) {
-    return vocab.special_middle_id;
-}
-
-llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) {
-    return vocab.special_suffix_id;
-}
-
-llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eot_id;
-}
-
-llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eom_id;
-}
-
-int32_t llama_tokenize_impl(
-    const struct llama_vocab & vocab,
-                  const char * text,
-                     int32_t   text_len,
-                 llama_token * tokens,
-                     int32_t   n_tokens_max,
-                        bool   add_special,
-                        bool   parse_special) {
-    auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special);
-    if (n_tokens_max < (int) res.size()) {
-        // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
-        return -((int) res.size());
-    }
-
-    for (size_t i = 0; i < res.size(); i++) {
-        tokens[i] = res[i];
-    }
-
-    return res.size();
+static void llama_unescape_whitespace(std::string & word) {
+    replace_all(word, "\xe2\x96\x81", " ");
 }
 
 static std::string llama_decode_text(const std::string & text) {
@@ -1674,11 +2311,185 @@ static std::string llama_decode_text(const std::string & text) {
     return decoded_text;
 }
 
-// does not write null-terminator to buf
-int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) {
+std::vector llama_vocab::impl::tokenize(
+        const std::string & raw_text,
+        bool add_special,
+        bool parse_special) const {
+    GGML_ASSERT(tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+
+    std::vector output;
+    std::forward_list fragment_buffer;
+
+    if (!raw_text.empty()) {
+        fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
+        tokenizer_st_partition(fragment_buffer, parse_special);
+    }
+
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+            {
+                // OG tokenizer behavior:
+                //
+                // tokenizer.encode('', add_special_tokens=True)  returns [1]
+                // tokenizer.encode('', add_special_tokens=False) returns []
+
+                bool is_prev_special = true;  // prefix with space if first token
+
+                if (add_special && add_bos) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                    is_prev_special = true;
+                }
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text;
+
+                        // prefix with space if previous is special
+                        if (add_space_prefix && is_prev_special) {
+                            text = ' ';
+                        }
+
+                        text += fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        llama_escape_whitespace(text);
+                        llm_tokenizer_spm_session session(vocab);
+                        session.tokenize(text, output);
+                        is_prev_special = false;
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                        is_prev_special = true;
+                    }
+                }
+
+                if (add_special && add_bos && output.size() >= 2 && output[1] == special_bos_id) {
+                    LLAMA_LOG_WARN(
+                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                        "Are you sure this is what you want?\n", __FUNCTION__);
+                }
+
+                if (add_special && add_eos) {
+                    GGML_ASSERT(special_eos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_eos_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_BPE:
+            {
+                llm_tokenizer_bpe_session session(vocab, *static_cast(tokenizer.get()));
+                // it calls some other methods that are not exist in llm_tokenizer,
+                // here just cast it to bpe tokenizer object
+                if (add_special) {
+                    session.append_bos(output);
+                }
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        session.append(fragment.token, output);
+                    }
+                }
+
+                if (add_special) {
+                    session.append_eos(output);
+                    session.check_double_bos_eos(output);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_WPM:
+            {
+                if (add_special) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                }
+
+                llm_tokenizer_wpm_session session(vocab);
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+
+                if (add_special) {
+                    GGML_ASSERT(special_sep_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_sep_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_UGM:
+            {
+                if (add_special && add_bos) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                }
+                llm_tokenizer_ugm_session session(vocab, *static_cast(tokenizer.get()));
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+
+                if (add_special && add_bos && output.size() >= 2 && output[1] == special_bos_id) {
+                    LLAMA_LOG_WARN(
+                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                        "Are you sure this is what you want?\n", __FUNCTION__);
+                }
+
+                if (add_special && add_eos) {
+                    GGML_ASSERT(special_eos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_eos_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            {
+                llm_tokenizer_rwkv_session session(vocab, *static_cast(tokenizer.get()));
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_NONE:
+            GGML_ABORT("fatal error");
+    }
+
+    return output;
+}
+
+int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) const {
     // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
     static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
-    const llama_token_attr attr = llama_token_get_attr_impl(vocab, token);
+    const llama_token_attr attr = token_get_attr(token);
     if (!special && (attr & attr_special)) {
         return 0;
     }
@@ -1699,7 +2510,7 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
 
     // if we have a cache - use it
     {
-        const auto & cache = vocab.cache_token_to_piece;
+        const auto & cache = cache_token_to_piece;
 
         if (!cache.empty()) {
             const auto & result = cache.at(token);
@@ -1707,9 +2518,9 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
         }
     }
 
-    if (0 <= token && token < (int32_t) vocab.id_to_token.size()) {
-        const std::string & token_text = vocab.id_to_token[token].text;
-        switch (llama_vocab_get_type(vocab)) {
+    if (0 <= token && token < (int32_t) id_to_token.size()) {
+        const std::string & token_text = id_to_token[token].text;
+        switch (get_type()) {
             case LLAMA_VOCAB_TYPE_WPM:
             case LLAMA_VOCAB_TYPE_SPM:
             case LLAMA_VOCAB_TYPE_UGM: {
@@ -1717,12 +2528,14 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
                 // suppressing them like CONTROL tokens.
                 if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
                     return _try_copy(token_text.data(), token_text.size());
-                } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+                }
+                if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
                     std::string result = token_text;
                     llama_unescape_whitespace(result);
                     return _try_copy(result.data(), result.size());
-                } else if (attr & LLAMA_TOKEN_ATTR_BYTE) {
-                    char byte = (char) llama_token_to_byte(vocab, token);
+                }
+                if (attr & LLAMA_TOKEN_ATTR_BYTE) {
+                    char byte = (char) token_to_byte(token);
                     return _try_copy((char*) &byte, 1);
                 }
                 break;
@@ -1732,7 +2545,8 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
                 // suppressing them like CONTROL tokens.
                 if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
                     return _try_copy(token_text.data(), token_text.size());
-                } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
+                }
+                if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
                     std::string result = llama_decode_text(token_text);
                     return _try_copy(result.data(), result.size());
                 }
@@ -1757,37 +2571,46 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
     return 0;
 }
 
-int32_t llama_detokenize_impl(
-        const struct llama_vocab & vocab,
+const std::string & llama_vocab::impl::token_to_piece(llama_token token) const {
+    return cache_token_to_piece.at(token);
+}
+
+int32_t llama_vocab::impl::detokenize(
                const llama_token * tokens,
                          int32_t   n_tokens,
                             char * text,
                          int32_t   text_len_max,
                             bool   remove_special,
-                            bool   unparse_special) {
+                            bool   unparse_special) const {
+    if (type == LLAMA_VOCAB_TYPE_NONE) {
+        return 0;
+    }
+
+    GGML_ASSERT(tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+
     int32_t avail = text_len_max;
     int32_t total = 0;
 
     // remove the leading space
-    bool remove_space = vocab.tokenizer_add_space_prefix;
+    bool remove_space = add_space_prefix;
 
-    if (remove_special && vocab.tokenizer_add_bos) {
-        if (n_tokens > 0 && tokens[0] == vocab.special_bos_id) {
+    if (remove_special && add_bos) {
+        if (n_tokens > 0 && tokens[0] == special_bos_id) {
             remove_space = false;
             n_tokens--;
             tokens++;
         }
     }
 
-    if (remove_special && vocab.tokenizer_add_eos) {
-        if (n_tokens > 0 && tokens[n_tokens-1] == vocab.special_eos_id) {
+    if (remove_special && add_eos) {
+        if (n_tokens > 0 && tokens[n_tokens - 1] == special_eos_id) {
             n_tokens--;
         }
     }
 
     for (int32_t i = 0; i < n_tokens; ++i) {
         GGML_ASSERT(avail >= 0);
-        int32_t n_chars = llama_token_to_piece_impl(vocab, tokens[i], text, avail, remove_space, unparse_special);
+        int32_t n_chars = token_to_piece(tokens[i], text, avail, remove_space, unparse_special);
         remove_space = false;
         if (n_chars < 0) {
             avail = 0;
@@ -1803,7 +2626,7 @@ int32_t llama_detokenize_impl(
         return -total;
     }
 
-    if (vocab.tokenizer_clean_spaces) {
+    if (clean_spaces) {
         text -= total;  // restart text
 
         // first pass: characters ?!.,  //TODO: where do these characters come from?
@@ -1863,3 +2686,567 @@ int32_t llama_detokenize_impl(
 
     return total <= text_len_max ? total : -total;
 }
+
+void llama_vocab::impl::print_info() const {
+    LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, type_name().c_str());
+    LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, vocab.n_tokens());
+    LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (uint32_t) bpe_ranks.size());
+
+    // special tokens
+    if (special_bos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, special_bos_id,     id_to_token[special_bos_id].text.c_str() );  }
+    if (special_eos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, special_eos_id,     id_to_token[special_eos_id].text.c_str() );  }
+    if (special_eot_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, special_eot_id,     id_to_token[special_eot_id].text.c_str() );  }
+    if (special_eom_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, special_eom_id,     id_to_token[special_eom_id].text.c_str() );  }
+    if (special_unk_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, special_unk_id,     id_to_token[special_unk_id].text.c_str() );  }
+    if (special_sep_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, special_sep_id,     id_to_token[special_sep_id].text.c_str() );  }
+    if (special_pad_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, special_pad_id,     id_to_token[special_pad_id].text.c_str() );  }
+    if (special_mask_id != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, special_mask_id,    id_to_token[special_mask_id].text.c_str() ); }
+
+    if (linefeed_id != LLAMA_TOKEN_NULL)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, linefeed_id,        id_to_token[linefeed_id].text.c_str() ); }
+
+    if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, special_fim_pre_id, id_to_token[special_fim_pre_id].text.c_str() ); }
+    if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, special_fim_suf_id, id_to_token[special_fim_suf_id].text.c_str() ); }
+    if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, special_fim_mid_id, id_to_token[special_fim_mid_id].text.c_str() ); }
+    if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, special_fim_pad_id, id_to_token[special_fim_pad_id].text.c_str() ); }
+    if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, special_fim_rep_id, id_to_token[special_fim_rep_id].text.c_str() ); }
+    if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, special_fim_sep_id, id_to_token[special_fim_sep_id].text.c_str() ); }
+
+    for (const auto & id : special_eog_ids) {
+        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, id_to_token[id].text.c_str() );
+    }
+
+    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len);
+}
+
+llama_vocab::llama_vocab() : pimpl(new impl(*this)) {
+}
+
+llama_vocab::~llama_vocab() {
+}
+
+void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
+    pimpl->load(ml, kv);
+}
+
+enum llama_vocab_type llama_vocab::get_type() const {
+    return pimpl->type;
+}
+
+enum llama_vocab_pre_type llama_vocab::get_pre_type() const {
+    return pimpl->pre_type;
+}
+
+uint32_t llama_vocab::n_tokens() const {
+    return (uint32_t) pimpl->id_to_token.size();
+}
+
+uint32_t llama_vocab::n_token_types() const {
+    return (uint32_t) pimpl->n_token_types;
+}
+
+std::string llama_vocab::type_name() const{
+    return pimpl->type_name();
+}
+
+bool llama_vocab::is_normal(llama_token id) const {
+    return pimpl->is_normal(id);
+}
+
+bool llama_vocab::is_unknown(llama_token id) const {
+    return pimpl->is_unknown(id);
+}
+
+bool llama_vocab::is_control(llama_token id) const {
+    return pimpl->is_control(id);
+}
+
+bool llama_vocab::is_byte(llama_token id) const {
+    return pimpl->is_byte(id);
+}
+
+bool llama_vocab::is_user_defined(llama_token id) const {
+    return pimpl->is_user_defined(id);
+}
+
+bool llama_vocab::is_unused(llama_token id) const {
+    return pimpl->is_unused(id);
+}
+
+bool llama_vocab::is_eog(llama_token id) const {
+    return pimpl->is_eog(id);
+}
+
+uint8_t llama_vocab::token_to_byte(llama_token id) const {
+    return pimpl->token_to_byte(id);
+}
+
+llama_token llama_vocab::byte_to_token(uint8_t ch) const {
+    GGML_ASSERT(get_type() != LLAMA_VOCAB_TYPE_NONE);
+    static const char * hex = "0123456789ABCDEF";
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+        case LLAMA_VOCAB_TYPE_UGM: {
+            const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
+            auto token = pimpl->token_to_id.find(buf);
+            if (token != pimpl->token_to_id.end()) {
+                return (*token).second;
+            }
+            // Try to fall back to just the byte as a string
+            const char buf2[2] = { (char)ch, 0 };
+            return pimpl->token_to_id.at(buf2);
+        }
+        case LLAMA_VOCAB_TYPE_WPM:
+        case LLAMA_VOCAB_TYPE_BPE: {
+            return pimpl->token_to_id.at(unicode_byte_to_utf8(ch));
+        }
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
+
+llama_token llama_vocab::text_to_token(const std::string & text) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    auto it = pimpl->token_to_id.find(text);
+    if (it != pimpl->token_to_id.end()) {
+        return (*it).second;
+    }
+    return LLAMA_TOKEN_NULL;
+}
+
+const llama_vocab::token_data & llama_vocab::get_token_data(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id);
+}
+
+const char * llama_vocab::token_get_text(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id).text.c_str();
+}
+
+float llama_vocab::token_get_score(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id).score;
+}
+
+llama_token_attr llama_vocab::token_get_attr(llama_token id) const {
+    return pimpl->token_get_attr(id);
+}
+
+llama_token llama_vocab::token_bos() const {
+    return pimpl->special_bos_id;
+}
+
+llama_token llama_vocab::token_eos() const {
+    return pimpl->special_eos_id;
+}
+
+llama_token llama_vocab::token_eot() const {
+    return pimpl->special_eot_id;
+}
+
+llama_token llama_vocab::token_eom() const {
+    return pimpl->special_eom_id;
+}
+
+llama_token llama_vocab::token_unk() const {
+    return pimpl->special_unk_id;
+}
+
+llama_token llama_vocab::token_sep() const {
+    return pimpl->special_sep_id;
+}
+
+llama_token llama_vocab::token_nl() const {
+    return pimpl->linefeed_id;
+}
+
+llama_token llama_vocab::token_pad() const {
+    return pimpl->special_pad_id;
+}
+
+llama_token llama_vocab::token_prefix() const {
+    return pimpl->special_fim_pre_id;
+}
+
+llama_token llama_vocab::token_middle() const {
+    return pimpl->special_fim_mid_id;
+}
+
+llama_token llama_vocab::token_suffix() const {
+    return pimpl->special_fim_suf_id;
+}
+
+llama_token llama_vocab::token_fim_pre() const {
+    return pimpl->special_fim_pre_id;
+}
+
+llama_token llama_vocab::token_fim_suf() const {
+    return pimpl->special_fim_suf_id;
+}
+
+llama_token llama_vocab::token_fim_mid() const {
+    return pimpl->special_fim_mid_id;
+}
+
+llama_token llama_vocab::token_fim_pad() const {
+    return pimpl->special_fim_pad_id;
+}
+
+llama_token llama_vocab::token_fim_rep() const {
+    return pimpl->special_fim_rep_id;
+}
+
+llama_token llama_vocab::token_fim_sep() const {
+    return pimpl->special_fim_sep_id;
+}
+
+bool llama_vocab::get_add_space_prefix() const {
+    return pimpl->add_space_prefix;
+}
+
+bool llama_vocab::get_add_bos() const {
+    return pimpl->add_bos;
+}
+
+bool llama_vocab::get_add_eos() const {
+    return pimpl->add_eos;
+}
+
+bool llama_vocab::get_ignore_merges() const {
+    return pimpl->ignore_merges;
+}
+
+bool llama_vocab::get_clean_spaces() const {
+    return pimpl->clean_spaces;
+}
+
+bool llama_vocab::get_remove_extra_whitespaces() const {
+    return pimpl->remove_extra_whitespaces;
+}
+
+bool llama_vocab::get_escape_whitespaces() const {
+    return pimpl->escape_whitespaces;
+}
+
+bool llama_vocab::get_treat_whitespace_as_suffix() const {
+    return pimpl->treat_whitespace_as_suffix;
+}
+
+int llama_vocab::max_token_len() const {
+    return pimpl->max_token_len;
+}
+
+int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
+    GGML_ASSERT(token_left.find(' ')   == std::string::npos);
+    GGML_ASSERT(token_left.find('\n')  == std::string::npos);
+    GGML_ASSERT(token_right.find(' ')  == std::string::npos);
+    GGML_ASSERT(token_right.find('\n') == std::string::npos);
+
+    auto it = pimpl->bpe_ranks.find(std::make_pair(token_left, token_right));
+    if (it == pimpl->bpe_ranks.end()) {
+        return -1;
+    }
+
+    return it->second;
+}
+
+int32_t llama_vocab::tokenize(
+                  const char * text,
+                     int32_t   text_len,
+                 llama_token * tokens,
+                     int32_t   n_tokens_max,
+                        bool   add_special,
+                        bool   parse_special) const {
+    auto res = tokenize(std::string(text, text_len), add_special, parse_special);
+    if (n_tokens_max < (int) res.size()) {
+        // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
+        return -((int) res.size());
+    }
+
+    for (size_t i = 0; i < res.size(); i++) {
+        tokens[i] = res[i];
+    }
+
+    return res.size();
+}
+
+std::vector llama_vocab::tokenize(
+        const std::string & raw_text,
+        bool add_special,
+        bool parse_special) const {
+    return pimpl->tokenize(raw_text, add_special, parse_special);
+}
+
+const std::string & llama_vocab::token_to_piece(llama_token token) const {
+    return pimpl->token_to_piece(token);
+}
+
+int32_t llama_vocab::token_to_piece(llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) const {
+    return pimpl->token_to_piece(token, buf, length, lstrip, special);
+}
+
+int32_t llama_vocab::detokenize(
+               const llama_token * tokens,
+                         int32_t   n_tokens,
+                            char * text,
+                         int32_t   text_len_max,
+                            bool   remove_special,
+                            bool   unparse_special) const {
+    return pimpl->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
+}
+
+std::string llama_vocab::detokenize(const std::vector & tokens, bool special) const {
+    std::string text;
+    text.resize(std::max(text.capacity(), tokens.size()));
+    int32_t n_chars = detokenize(tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+    if (n_chars < 0) {
+        text.resize(-n_chars);
+        n_chars = detokenize(tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        GGML_ASSERT(n_chars <= (int32_t)text.size());  // whitespace trimming is performed after per-token detokenization
+    }
+
+    text.resize(n_chars);
+
+    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
+    return text;
+}
+
+void llama_vocab::print_info() const {
+    pimpl->print_info();
+}
+
+//
+// interface implementation
+//
+
+int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab) {
+    return vocab->n_tokens();
+}
+
+// deprecated
+int32_t llama_n_vocab(const struct llama_vocab * vocab) {
+    return llama_vocab_n_tokens(vocab);
+}
+
+enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab) {
+    return vocab->get_type();
+}
+
+const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_text(token);
+}
+
+float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_score(token);
+}
+
+enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_attr(token);
+}
+
+bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->is_eog(token);
+}
+
+bool llama_vocab_is_control(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->is_control(token);
+}
+
+llama_token llama_vocab_bos(const struct llama_vocab * vocab) {
+    return vocab->token_bos();
+}
+
+llama_token llama_vocab_eos(const struct llama_vocab * vocab) {
+    return vocab->token_eos();
+}
+
+llama_token llama_vocab_eot(const struct llama_vocab * vocab) {
+    return vocab->token_eot();
+}
+
+// deprecated
+llama_token llama_vocab_cls(const struct llama_vocab * vocab) {
+    return vocab->token_bos();
+}
+
+llama_token llama_vocab_sep(const struct llama_vocab * vocab) {
+    return vocab->token_sep();
+}
+
+llama_token llama_vocab_nl (const struct llama_vocab * vocab) {
+    return vocab->token_nl();
+}
+
+llama_token llama_vocab_pad(const struct llama_vocab * vocab) {
+    return vocab->token_pad();
+}
+
+bool llama_vocab_get_add_bos(const struct llama_vocab * vocab) {
+    return vocab->get_add_bos();
+}
+
+bool llama_vocab_get_add_eos(const struct llama_vocab * vocab) {
+    return vocab->get_add_eos();
+}
+
+llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab) {
+    return vocab->token_fim_pre();
+}
+
+llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab) {
+    return vocab->token_fim_suf();
+}
+
+llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab) {
+    return vocab->token_fim_mid();
+}
+
+llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab) {
+    return vocab->token_fim_pad();
+}
+
+llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab) {
+    return vocab->token_fim_rep();
+}
+
+llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab) {
+    return vocab->token_fim_sep();
+}
+
+// deprecated
+const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_text(vocab, token);
+}
+
+// deprecated
+float llama_token_get_score(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_score(vocab, token);
+}
+
+// deprecated
+enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_attr(vocab, token);
+}
+
+// deprecated
+bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_is_eog(vocab, token);
+}
+
+// deprecated
+bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_is_control(vocab, token);
+}
+
+// deprecated
+llama_token llama_token_bos(const struct llama_vocab * vocab) {
+    return llama_vocab_bos(vocab);
+}
+
+// deprecated
+llama_token llama_token_eos(const struct llama_vocab * vocab) {
+    return llama_vocab_eos(vocab);
+}
+
+// deprecated
+llama_token llama_token_eot(const struct llama_vocab * vocab) {
+    return llama_vocab_eot(vocab);
+}
+
+// deprecated
+llama_token llama_token_cls(const struct llama_vocab * vocab) {
+    //return llama_vocab_cls(vocab);
+    return llama_vocab_bos(vocab); // avoid deprecation warning
+}
+
+// deprecated
+llama_token llama_token_sep(const struct llama_vocab * vocab) {
+    return llama_vocab_sep(vocab);
+}
+
+// deprecated
+llama_token llama_token_nl (const struct llama_vocab * vocab) {
+    return llama_vocab_nl(vocab);
+}
+
+// deprecated
+llama_token llama_token_pad(const struct llama_vocab * vocab) {
+    return llama_vocab_pad(vocab);
+}
+
+// deprecated
+bool llama_add_bos_token(const struct llama_vocab * vocab) {
+    return llama_vocab_get_add_bos(vocab);
+}
+
+// deprecated
+bool llama_add_eos_token(const struct llama_vocab * vocab) {
+    return llama_vocab_get_add_eos(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_pre(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_pre(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_suf(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_suf(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_mid(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_mid(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_pad(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_pad(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_rep(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_rep(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_sep(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_sep(vocab);
+}
+
+//
+// tokenization
+//
+
+int32_t llama_tokenize(
+    const struct llama_vocab * vocab,
+                  const char * text,
+                     int32_t   text_len,
+                 llama_token * tokens,
+                     int32_t   n_tokens_max,
+                        bool   add_special,
+                        bool   parse_special) {
+    return vocab->tokenize(text, text_len, tokens, n_tokens_max, add_special, parse_special);
+}
+
+int32_t llama_token_to_piece(
+    const struct llama_vocab * vocab,
+                 llama_token   token,
+                        char * buf,
+                     int32_t   length,
+                     int32_t   lstrip,
+                        bool   special) {
+    return vocab->token_to_piece(token, buf, length, lstrip, special);
+}
+
+int32_t llama_detokenize(
+    const struct llama_vocab * vocab,
+           const llama_token * tokens,
+                     int32_t   n_tokens,
+                        char * text,
+                     int32_t   text_len_max,
+                        bool   remove_special,
+                        bool   unparse_special) {
+    return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
+}
+
diff --git a/src/llama-vocab.h b/src/llama-vocab.h
index dc4b5f12f..5ce355214 100644
--- a/src/llama-vocab.h
+++ b/src/llama-vocab.h
@@ -1,133 +1,125 @@
 #pragma once
 
-#include "llama-impl.h"
+#include "llama.h"
 
 #include 
 #include 
-#include 
-#include 
+#include 
+
+struct LLM_KV;
+struct llama_model_loader;
 
 struct llama_vocab {
-    using id    = llama_token;
-    using token = std::string;
-    using tattr = llama_token_attr;
-
     struct token_data {
-        token text;
-        float score;
-        tattr attr;
+        std::string      text;
+        float            score;
+        llama_token_attr attr;
     };
 
-    uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
+    llama_vocab();
+    ~llama_vocab();
 
-    enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
-    enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+    void load(llama_model_loader & ml, const LLM_KV & kv);
 
-    int max_token_len = 0; // used for optimizing longest token search
+    enum llama_vocab_type     get_type()     const;
+    enum llama_vocab_pre_type get_pre_type() const;
 
-    std::unordered_map token_to_id;
-    std::vector       id_to_token;
+    uint32_t n_tokens() const;
+    uint32_t n_token_types() const;
 
-    std::vector    cache_special_tokens;
-    std::vector cache_token_to_piece; // llama_token_to_piece(special = true);
+    std::string type_name() const;
 
-    std::map, int> bpe_ranks;
+    bool is_normal      (llama_token id) const;
+    bool is_unknown     (llama_token id) const;
+    bool is_control     (llama_token id) const;
+    bool is_byte        (llama_token id) const;
+    bool is_user_defined(llama_token id) const;
+    bool is_unused      (llama_token id) const;
+    bool is_eog         (llama_token id) const;
 
-    // default LLaMA special tokens
-    id special_bos_id  = 1;
-    id special_eos_id  = 2;
-    id special_unk_id  = 0;
-    id special_sep_id  = -1;
-    id special_pad_id  = -1;
-    id special_cls_id  = -1;
-    id special_mask_id = -1;
+    uint8_t     token_to_byte(llama_token id) const;
+    llama_token byte_to_token(uint8_t ch)     const;
 
-    id linefeed_id       = 13;
-    id special_prefix_id = -1;
-    id special_suffix_id = -1;
-    id special_middle_id = -1;
-    id special_eot_id    = -1; // TODO: move above after "eos_id", and here add "file separator" token
-    id special_eom_id    = -1;
+    llama_token text_to_token(const std::string & text) const;
 
-    // tokenizer flags
-    bool tokenizer_add_space_prefix = false;
-    bool tokenizer_add_bos          = false;
-    bool tokenizer_add_eos          = false;
-    bool tokenizer_ignore_merges    = false;
-    bool tokenizer_clean_spaces     = false;  // clean_up_tokenization_spaces
-    bool tokenizer_remove_extra_whitespaces   = false;
-    bool tokenizer_escape_whitespaces         = true;
-    bool tokenizer_treat_whitespace_as_suffix = false;
+    const token_data & get_token_data(llama_token id) const;
 
-    std::vector precompiled_charsmap;
+    const char *     token_get_text (llama_token id) const;
+    float            token_get_score(llama_token id) const;
+    llama_token_attr token_get_attr (llama_token id) const;
+
+    llama_token token_bos() const;
+    llama_token token_eos() const;
+    llama_token token_eot() const;
+    llama_token token_eom() const;
+    llama_token token_unk() const;
+    llama_token token_sep() const;
+    llama_token token_nl () const;
+    llama_token token_pad() const;
+
+    llama_token token_prefix() const;
+    llama_token token_middle() const;
+    llama_token token_suffix() const;
+
+    llama_token token_fim_pre() const;
+    llama_token token_fim_suf() const;
+    llama_token token_fim_mid() const;
+    llama_token token_fim_pad() const;
+    llama_token token_fim_rep() const;
+    llama_token token_fim_sep() const;
+
+    bool get_add_space_prefix          () const;
+    bool get_add_bos                   () const;
+    bool get_add_eos                   () const;
+    bool get_ignore_merges             () const;
+    bool get_clean_spaces              () const;
+    bool get_remove_extra_whitespaces  () const;
+    bool get_escape_whitespaces        () const;
+    bool get_treat_whitespace_as_suffix() const;
+
+    int max_token_len() const;
 
     int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
+
+    int32_t tokenize(
+                   const char * text,
+                      int32_t   text_len,
+                  llama_token * tokens,
+                      int32_t   n_tokens_max,
+                         bool   add_special,
+                         bool   parse_special) const;
+
+    std::vector tokenize(
+            const std::string & raw_text,
+                         bool   add_special,
+                         bool   parse_special = false) const;
+
+    // does not write null-terminator to buf
+    int32_t token_to_piece(
+                  llama_token   token,
+                         char * buf,
+                      int32_t   length,
+                      int32_t   lstrip,
+                         bool   special) const;
+
+    // use cached data
+    const std::string & token_to_piece(llama_token token) const;
+
+    int32_t detokenize(
+            const llama_token * tokens,
+                      int32_t   n_tokens,
+                         char * text,
+                      int32_t   text_len_max,
+                         bool   remove_special,
+                         bool   unparse_special) const;
+
+    std::string detokenize(
+            const std::vector & tokens,
+                                      bool   special) const;
+
+    void print_info() const;
+
+private:
+    struct impl;
+    std::unique_ptr pimpl;
 };
-
-//
-// internal API
-//
-
-// TODO: rename to llama_tokenize_impl
-// TODO: This should probably be in llama.h
-std::vector llama_tokenize_internal(
-        const llama_vocab & vocab,
-        std::string raw_text,
-        bool add_special,
-        bool parse_special = false);
-
-// TODO: move the API below as member functions of llama_vocab
-llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
-
-const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
-
-float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token);
-
-llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token);
-
-bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token);
-
-bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token);
-
-llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
-llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
-llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
-llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
-llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
-llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
-
-bool llama_add_bos_token_impl(const struct llama_vocab & vocab);
-bool llama_add_eos_token_impl(const struct llama_vocab & vocab);
-
-llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
-llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
-llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
-llama_token llama_token_eot_impl   (const struct llama_vocab & vocab);
-llama_token llama_token_eom_impl   (const struct llama_vocab & vocab);
-
-int32_t llama_tokenize_impl(
-        const struct llama_vocab & vocab,
-                      const char * text,
-                         int32_t   text_len,
-                     llama_token * tokens,
-                         int32_t   n_tokens_max,
-                            bool   add_special,
-                            bool   parse_special);
-
-// does not write null-terminator to buf
-int32_t llama_token_to_piece_impl(
-        const struct llama_vocab & vocab,
-                     llama_token   token,
-                            char * buf,
-                         int32_t   length,
-                         int32_t   lstrip,
-                            bool   special);
-
-int32_t llama_detokenize_impl(
-        const struct llama_vocab & vocab,
-               const llama_token * tokens,
-                         int32_t   n_tokens,
-                            char * text,
-                         int32_t   text_len_max,
-                            bool   remove_special,
-                            bool   unparse_special);
diff --git a/src/llama.cpp b/src/llama.cpp
index 39e20440e..192b20a27 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -1,8656 +1,76 @@
 #include "llama-impl.h"
+
+#include "llama-chat.h"
+#include "llama-mmap.h"
+#include "llama-context.h"
 #include "llama-vocab.h"
 #include "llama-sampling.h"
-
-#include "unicode.h"
+#include "llama-kv-cache.h"
+#include "llama-model-loader.h"
+#include "llama-model.h"
 
 #include "ggml.h"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
-
-#ifdef GGML_USE_RPC
-#  include "ggml-rpc.h"
-#endif
-
-#ifdef GGML_USE_CUDA
-#  include "ggml-cuda.h"
-#elif defined(GGML_USE_VULKAN)
-#  include "ggml-vulkan.h"
-#elif defined(GGML_USE_SYCL)
-#  include "ggml-sycl.h"
-#elif defined(GGML_USE_KOMPUTE)
-#   include "ggml-kompute.h"
-#elif defined(GGML_USE_CANN)
-#   include "ggml-cann.h"
-#endif
-
-#ifdef GGML_USE_BLAS
-#  include "ggml-blas.h"
-#endif
-
-#ifdef GGML_USE_METAL
-#  include "ggml-metal.h"
-#endif
-
-// TODO: replace with ggml API call
-#define QK_K 256
-
-#ifdef __has_include
-    #if __has_include()
-        #include 
-        #if defined(_POSIX_MAPPED_FILES)
-            #include 
-            #include 
-        #endif
-        #if defined(_POSIX_MEMLOCK_RANGE)
-            #include 
-        #endif
-    #endif
-#endif
-
-#if defined(_WIN32)
-    #define WIN32_LEAN_AND_MEAN
-    #ifndef NOMINMAX
-        #define NOMINMAX
-    #endif
-    #include 
-    #ifndef PATH_MAX
-        #define PATH_MAX MAX_PATH
-    #endif
-    #include 
-#endif
-
-#if __cplusplus >= 202000L
-    #define LU8(x) (const char*)(u8##x)
-#else
-    #define LU8(x) u8##x
-#endif
+#include "ggml-cpp.h"
 
 #include 
 #include 
 #include 
-#include 
 #include 
-#include 
-#include 
 #include 
-#include 
 #include 
 #include 
 #include 
 #include 
 #include 
-#include 
 #include 
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-// bump if necessary
-#define LLAMA_MAX_LAYERS  512
-#define LLAMA_MAX_EXPERTS 160  // DeepSeekV2
-
-//
-// helpers
-//
-
-// trim whitespace from the beginning and end of a string
-static std::string trim(const std::string & str) {
-    size_t start = 0;
-    size_t end = str.size();
-    while (start < end && isspace(str[start])) {
-        start += 1;
-    }
-    while (end > start && isspace(str[end - 1])) {
-        end -= 1;
-    }
-    return str.substr(start, end - start);
-}
-
-static bool is_float_close(float a, float b, float abs_tol) {
-    // Check for non-negative tolerance
-    if (abs_tol < 0.0) {
-        throw std::invalid_argument("Tolerance must be non-negative");
-    }
-
-    // Exact equality check
-    if (a == b) {
-        return true;
-    }
-
-    // Check for infinities
-    if (std::isinf(a) || std::isinf(b)) {
-        return false;
-    }
-
-    // Regular comparison using the provided absolute tolerance
-    return std::fabs(b - a) <= abs_tol;
-}
-
-static void zeros(std::ofstream & file, size_t n) {
-    char zero = 0;
-    for (size_t i = 0; i < n; ++i) {
-        file.write(&zero, 1);
-    }
-}
-
-LLAMA_ATTRIBUTE_FORMAT(1, 2)
-static std::string format(const char * fmt, ...) {
-    va_list ap;
-    va_list ap2;
-    va_start(ap, fmt);
-    va_copy(ap2, ap);
-    int size = vsnprintf(NULL, 0, fmt, ap);
-    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
-    std::vector buf(size + 1);
-    int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
-    GGML_ASSERT(size2 == size);
-    va_end(ap2);
-    va_end(ap);
-    return std::string(buf.data(), size);
-}
-
-//
-// gguf constants (sync with gguf.py)
-//
-
-enum llm_arch {
-    LLM_ARCH_LLAMA,
-    LLM_ARCH_FALCON,
-    LLM_ARCH_BAICHUAN,
-    LLM_ARCH_GROK,
-    LLM_ARCH_GPT2,
-    LLM_ARCH_GPTJ,
-    LLM_ARCH_GPTNEOX,
-    LLM_ARCH_MPT,
-    LLM_ARCH_STARCODER,
-    LLM_ARCH_REFACT,
-    LLM_ARCH_BERT,
-    LLM_ARCH_NOMIC_BERT,
-    LLM_ARCH_JINA_BERT_V2,
-    LLM_ARCH_BLOOM,
-    LLM_ARCH_STABLELM,
-    LLM_ARCH_QWEN,
-    LLM_ARCH_QWEN2,
-    LLM_ARCH_QWEN2MOE,
-    LLM_ARCH_PHI2,
-    LLM_ARCH_PHI3,
-    LLM_ARCH_PLAMO,
-    LLM_ARCH_CODESHELL,
-    LLM_ARCH_ORION,
-    LLM_ARCH_INTERNLM2,
-    LLM_ARCH_MINICPM,
-    LLM_ARCH_GEMMA,
-    LLM_ARCH_GEMMA2,
-    LLM_ARCH_STARCODER2,
-    LLM_ARCH_MAMBA,
-    LLM_ARCH_XVERSE,
-    LLM_ARCH_COMMAND_R,
-    LLM_ARCH_DBRX,
-    LLM_ARCH_OLMO,
-    LLM_ARCH_OPENELM,
-    LLM_ARCH_ARCTIC,
-    LLM_ARCH_DEEPSEEK2,
-    LLM_ARCH_CHATGLM,
-    LLM_ARCH_BITNET,
-    LLM_ARCH_T5,
-    LLM_ARCH_T5ENCODER,
-    LLM_ARCH_JAIS,
-    LLM_ARCH_NEMOTRON,
-    LLM_ARCH_EXAONE,
-    LLM_ARCH_RWKV6,
-    LLM_ARCH_UNKNOWN,
-};
-
-static const std::map LLM_ARCH_NAMES = {
-    { LLM_ARCH_LLAMA,           "llama"        },
-    { LLM_ARCH_FALCON,          "falcon"       },
-    { LLM_ARCH_GROK,            "grok"         },
-    { LLM_ARCH_GPT2,            "gpt2"         },
-    { LLM_ARCH_GPTJ,            "gptj"         },
-    { LLM_ARCH_GPTNEOX,         "gptneox"      },
-    { LLM_ARCH_MPT,             "mpt"          },
-    { LLM_ARCH_BAICHUAN,        "baichuan"     },
-    { LLM_ARCH_STARCODER,       "starcoder"    },
-    { LLM_ARCH_REFACT,          "refact"       },
-    { LLM_ARCH_BERT,            "bert"         },
-    { LLM_ARCH_NOMIC_BERT,      "nomic-bert"   },
-    { LLM_ARCH_JINA_BERT_V2,    "jina-bert-v2" },
-    { LLM_ARCH_BLOOM,           "bloom"        },
-    { LLM_ARCH_STABLELM,        "stablelm"     },
-    { LLM_ARCH_QWEN,            "qwen"         },
-    { LLM_ARCH_QWEN2,           "qwen2"        },
-    { LLM_ARCH_QWEN2MOE,        "qwen2moe"     },
-    { LLM_ARCH_PHI2,            "phi2"         },
-    { LLM_ARCH_PHI3,            "phi3"         },
-    { LLM_ARCH_PLAMO,           "plamo"        },
-    { LLM_ARCH_CODESHELL,       "codeshell"    },
-    { LLM_ARCH_ORION,           "orion"        },
-    { LLM_ARCH_INTERNLM2,       "internlm2"    },
-    { LLM_ARCH_MINICPM,         "minicpm"      },
-    { LLM_ARCH_GEMMA,           "gemma"        },
-    { LLM_ARCH_GEMMA2,          "gemma2"       },
-    { LLM_ARCH_STARCODER2,      "starcoder2"   },
-    { LLM_ARCH_MAMBA,           "mamba"        },
-    { LLM_ARCH_XVERSE,          "xverse"       },
-    { LLM_ARCH_COMMAND_R,       "command-r"    },
-    { LLM_ARCH_DBRX,            "dbrx"         },
-    { LLM_ARCH_OLMO,            "olmo"         },
-    { LLM_ARCH_OPENELM,         "openelm"      },
-    { LLM_ARCH_ARCTIC,          "arctic"       },
-    { LLM_ARCH_DEEPSEEK2,       "deepseek2"    },
-    { LLM_ARCH_CHATGLM,         "chatglm"      },
-    { LLM_ARCH_BITNET,          "bitnet"       },
-    { LLM_ARCH_T5,              "t5"           },
-    { LLM_ARCH_T5ENCODER,       "t5encoder"    },
-    { LLM_ARCH_JAIS,            "jais"         },
-    { LLM_ARCH_NEMOTRON,        "nemotron"     },
-    { LLM_ARCH_EXAONE,          "exaone"       },
-    { LLM_ARCH_RWKV6,           "rwkv6"        },
-    { LLM_ARCH_UNKNOWN,         "(unknown)"    },
-};
-
-enum llm_kv {
-    LLM_KV_GENERAL_TYPE,
-    LLM_KV_GENERAL_ARCHITECTURE,
-    LLM_KV_GENERAL_QUANTIZATION_VERSION,
-    LLM_KV_GENERAL_ALIGNMENT,
-    LLM_KV_GENERAL_NAME,
-    LLM_KV_GENERAL_AUTHOR,
-    LLM_KV_GENERAL_VERSION,
-    LLM_KV_GENERAL_URL,
-    LLM_KV_GENERAL_DESCRIPTION,
-    LLM_KV_GENERAL_LICENSE,
-    LLM_KV_GENERAL_SOURCE_URL,
-    LLM_KV_GENERAL_SOURCE_HF_REPO,
-
-    LLM_KV_VOCAB_SIZE,
-    LLM_KV_CONTEXT_LENGTH,
-    LLM_KV_EMBEDDING_LENGTH,
-    LLM_KV_BLOCK_COUNT,
-    LLM_KV_LEADING_DENSE_BLOCK_COUNT,
-    LLM_KV_FEED_FORWARD_LENGTH,
-    LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
-    LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,
-    LLM_KV_USE_PARALLEL_RESIDUAL,
-    LLM_KV_TENSOR_DATA_LAYOUT,
-    LLM_KV_EXPERT_COUNT,
-    LLM_KV_EXPERT_USED_COUNT,
-    LLM_KV_EXPERT_SHARED_COUNT,
-    LLM_KV_EXPERT_WEIGHTS_SCALE,
-    LLM_KV_POOLING_TYPE,
-    LLM_KV_LOGIT_SCALE,
-    LLM_KV_DECODER_START_TOKEN_ID,
-    LLM_KV_ATTN_LOGIT_SOFTCAPPING,
-    LLM_KV_FINAL_LOGIT_SOFTCAPPING,
-    LLM_KV_RESCALE_EVERY_N_LAYERS,
-    LLM_KV_TIME_MIX_EXTRA_DIM,
-    LLM_KV_TIME_DECAY_EXTRA_DIM,
-
-    LLM_KV_ATTENTION_HEAD_COUNT,
-    LLM_KV_ATTENTION_HEAD_COUNT_KV,
-    LLM_KV_ATTENTION_MAX_ALIBI_BIAS,
-    LLM_KV_ATTENTION_CLAMP_KQV,
-    LLM_KV_ATTENTION_KEY_LENGTH,
-    LLM_KV_ATTENTION_VALUE_LENGTH,
-    LLM_KV_ATTENTION_LAYERNORM_EPS,
-    LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
-    LLM_KV_ATTENTION_CAUSAL,
-    LLM_KV_ATTENTION_Q_LORA_RANK,
-    LLM_KV_ATTENTION_KV_LORA_RANK,
-    LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
-    LLM_KV_ATTENTION_SLIDING_WINDOW,
-
-    LLM_KV_ROPE_DIMENSION_COUNT,
-    LLM_KV_ROPE_FREQ_BASE,
-    LLM_KV_ROPE_SCALE_LINEAR,
-    LLM_KV_ROPE_SCALING_TYPE,
-    LLM_KV_ROPE_SCALING_FACTOR,
-    LLM_KV_ROPE_SCALING_ATTN_FACTOR,
-    LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
-    LLM_KV_ROPE_SCALING_FINETUNED,
-    LLM_KV_ROPE_SCALING_YARN_LOG_MUL,
-
-    LLM_KV_SPLIT_NO,
-    LLM_KV_SPLIT_COUNT,
-    LLM_KV_SPLIT_TENSORS_COUNT,
-
-    LLM_KV_SSM_INNER_SIZE,
-    LLM_KV_SSM_CONV_KERNEL,
-    LLM_KV_SSM_STATE_SIZE,
-    LLM_KV_SSM_TIME_STEP_RANK,
-    LLM_KV_SSM_DT_B_C_RMS,
-
-    LLM_KV_WKV_HEAD_SIZE,
-
-    LLM_KV_TOKENIZER_MODEL,
-    LLM_KV_TOKENIZER_PRE,
-    LLM_KV_TOKENIZER_LIST,
-    LLM_KV_TOKENIZER_TOKEN_TYPE,
-    LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT,
-    LLM_KV_TOKENIZER_SCORES,
-    LLM_KV_TOKENIZER_MERGES,
-    LLM_KV_TOKENIZER_BOS_ID,
-    LLM_KV_TOKENIZER_EOS_ID,
-    LLM_KV_TOKENIZER_UNK_ID,
-    LLM_KV_TOKENIZER_SEP_ID,
-    LLM_KV_TOKENIZER_PAD_ID,
-    LLM_KV_TOKENIZER_CLS_ID,
-    LLM_KV_TOKENIZER_MASK_ID,
-    LLM_KV_TOKENIZER_ADD_BOS,
-    LLM_KV_TOKENIZER_ADD_EOS,
-    LLM_KV_TOKENIZER_ADD_PREFIX,
-    LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
-    LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
-    LLM_KV_TOKENIZER_HF_JSON,
-    LLM_KV_TOKENIZER_RWKV,
-    LLM_KV_TOKENIZER_PREFIX_ID,
-    LLM_KV_TOKENIZER_SUFFIX_ID,
-    LLM_KV_TOKENIZER_MIDDLE_ID,
-    LLM_KV_TOKENIZER_EOT_ID,
-    LLM_KV_TOKENIZER_EOM_ID,
-
-    LLM_KV_ADAPTER_TYPE,
-    LLM_KV_ADAPTER_LORA_ALPHA,
-};
-
-static const std::map LLM_KV_NAMES = {
-    { LLM_KV_GENERAL_TYPE,                  "general.type"                          },
-    { LLM_KV_GENERAL_ARCHITECTURE,          "general.architecture"                  },
-    { LLM_KV_GENERAL_QUANTIZATION_VERSION,  "general.quantization_version"          },
-    { LLM_KV_GENERAL_ALIGNMENT,             "general.alignment"                     },
-    { LLM_KV_GENERAL_NAME,                  "general.name"                          },
-    { LLM_KV_GENERAL_AUTHOR,                "general.author"                        },
-    { LLM_KV_GENERAL_VERSION,               "general.version"                       },
-    { LLM_KV_GENERAL_URL,                   "general.url"                           },
-    { LLM_KV_GENERAL_DESCRIPTION,           "general.description"                   },
-    { LLM_KV_GENERAL_LICENSE,               "general.license"                       },
-    { LLM_KV_GENERAL_SOURCE_URL,            "general.source.url"                    },
-    { LLM_KV_GENERAL_SOURCE_HF_REPO,        "general.source.huggingface.repository" },
-
-    { LLM_KV_VOCAB_SIZE,                        "%s.vocab_size"                        },
-    { LLM_KV_CONTEXT_LENGTH,                    "%s.context_length"                    },
-    { LLM_KV_EMBEDDING_LENGTH,                  "%s.embedding_length"                  },
-    { LLM_KV_BLOCK_COUNT,                       "%s.block_count"                       },
-    { LLM_KV_LEADING_DENSE_BLOCK_COUNT,         "%s.leading_dense_block_count"         },
-    { LLM_KV_FEED_FORWARD_LENGTH,               "%s.feed_forward_length"               },
-    { LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        "%s.expert_feed_forward_length"        },
-    { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" },
-    { LLM_KV_USE_PARALLEL_RESIDUAL,             "%s.use_parallel_residual"             },
-    { LLM_KV_TENSOR_DATA_LAYOUT,                "%s.tensor_data_layout"                },
-    { LLM_KV_EXPERT_COUNT,                      "%s.expert_count"                      },
-    { LLM_KV_EXPERT_USED_COUNT,                 "%s.expert_used_count"                 },
-    { LLM_KV_EXPERT_SHARED_COUNT,               "%s.expert_shared_count"               },
-    { LLM_KV_EXPERT_WEIGHTS_SCALE,              "%s.expert_weights_scale"              },
-    { LLM_KV_POOLING_TYPE,                      "%s.pooling_type"                      },
-    { LLM_KV_LOGIT_SCALE,                       "%s.logit_scale"                       },
-    { LLM_KV_DECODER_START_TOKEN_ID,            "%s.decoder_start_token_id"            },
-    { LLM_KV_ATTN_LOGIT_SOFTCAPPING,            "%s.attn_logit_softcapping"            },
-    { LLM_KV_FINAL_LOGIT_SOFTCAPPING,           "%s.final_logit_softcapping"           },
-    { LLM_KV_RESCALE_EVERY_N_LAYERS,            "%s.rescale_every_n_layers"            },
-    { LLM_KV_TIME_MIX_EXTRA_DIM,                "%s.time_mix_extra_dim"                },
-    { LLM_KV_TIME_DECAY_EXTRA_DIM,              "%s.time_decay_extra_dim"              },
-
-    { LLM_KV_ATTENTION_HEAD_COUNT,             "%s.attention.head_count"             },
-    { LLM_KV_ATTENTION_HEAD_COUNT_KV,          "%s.attention.head_count_kv"          },
-    { LLM_KV_ATTENTION_MAX_ALIBI_BIAS,         "%s.attention.max_alibi_bias"         },
-    { LLM_KV_ATTENTION_CLAMP_KQV,              "%s.attention.clamp_kqv"              },
-    { LLM_KV_ATTENTION_KEY_LENGTH,             "%s.attention.key_length"             },
-    { LLM_KV_ATTENTION_VALUE_LENGTH,           "%s.attention.value_length"           },
-    { LLM_KV_ATTENTION_LAYERNORM_EPS,          "%s.attention.layer_norm_epsilon"     },
-    { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,      "%s.attention.layer_norm_rms_epsilon" },
-    { LLM_KV_ATTENTION_CAUSAL,                 "%s.attention.causal"                 },
-    { LLM_KV_ATTENTION_Q_LORA_RANK,            "%s.attention.q_lora_rank"            },
-    { LLM_KV_ATTENTION_KV_LORA_RANK,           "%s.attention.kv_lora_rank"           },
-    { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
-    { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },
-
-    { LLM_KV_ROPE_DIMENSION_COUNT,          "%s.rope.dimension_count"                 },
-    { LLM_KV_ROPE_FREQ_BASE,                "%s.rope.freq_base"                       },
-    { LLM_KV_ROPE_SCALE_LINEAR,             "%s.rope.scale_linear"                    },
-    { LLM_KV_ROPE_SCALING_TYPE,             "%s.rope.scaling.type"                    },
-    { LLM_KV_ROPE_SCALING_FACTOR,           "%s.rope.scaling.factor"                  },
-    { LLM_KV_ROPE_SCALING_ATTN_FACTOR,      "%s.rope.scaling.attn_factor"             },
-    { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,     "%s.rope.scaling.original_context_length" },
-    { LLM_KV_ROPE_SCALING_FINETUNED,        "%s.rope.scaling.finetuned"               },
-    { LLM_KV_ROPE_SCALING_YARN_LOG_MUL,     "%s.rope.scaling.yarn_log_multiplier"     },
-
-    { LLM_KV_SPLIT_NO,                      "split.no"            },
-    { LLM_KV_SPLIT_COUNT,                   "split.count"         },
-    { LLM_KV_SPLIT_TENSORS_COUNT,           "split.tensors.count" },
-
-    { LLM_KV_SSM_CONV_KERNEL,               "%s.ssm.conv_kernel"    },
-    { LLM_KV_SSM_INNER_SIZE,                "%s.ssm.inner_size"     },
-    { LLM_KV_SSM_STATE_SIZE,                "%s.ssm.state_size"     },
-    { LLM_KV_SSM_TIME_STEP_RANK,            "%s.ssm.time_step_rank" },
-    { LLM_KV_SSM_DT_B_C_RMS,                "%s.ssm.dt_b_c_rms" },
-
-    { LLM_KV_WKV_HEAD_SIZE,                 "%s.wkv.head_size" },
-
-    { LLM_KV_TOKENIZER_MODEL,                "tokenizer.ggml.model"                    },
-    { LLM_KV_TOKENIZER_PRE,                  "tokenizer.ggml.pre"                      },
-    { LLM_KV_TOKENIZER_LIST,                 "tokenizer.ggml.tokens"                   },
-    { LLM_KV_TOKENIZER_TOKEN_TYPE,           "tokenizer.ggml.token_type"               },
-    { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT,     "tokenizer.ggml.token_type_count"         },
-    { LLM_KV_TOKENIZER_SCORES,               "tokenizer.ggml.scores"                   },
-    { LLM_KV_TOKENIZER_MERGES,               "tokenizer.ggml.merges"                   },
-    { LLM_KV_TOKENIZER_BOS_ID,               "tokenizer.ggml.bos_token_id"             },
-    { LLM_KV_TOKENIZER_EOS_ID,               "tokenizer.ggml.eos_token_id"             },
-    { LLM_KV_TOKENIZER_UNK_ID,               "tokenizer.ggml.unknown_token_id"         },
-    { LLM_KV_TOKENIZER_SEP_ID,               "tokenizer.ggml.seperator_token_id"       },
-    { LLM_KV_TOKENIZER_PAD_ID,               "tokenizer.ggml.padding_token_id"         },
-    { LLM_KV_TOKENIZER_CLS_ID,               "tokenizer.ggml.cls_token_id"             },
-    { LLM_KV_TOKENIZER_MASK_ID,              "tokenizer.ggml.mask_token_id"            },
-    { LLM_KV_TOKENIZER_ADD_BOS,              "tokenizer.ggml.add_bos_token"            },
-    { LLM_KV_TOKENIZER_ADD_EOS,              "tokenizer.ggml.add_eos_token"            },
-    { LLM_KV_TOKENIZER_ADD_PREFIX,           "tokenizer.ggml.add_space_prefix"         },
-    { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,      "tokenizer.ggml.remove_extra_whitespaces" },
-    { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap"     },
-    { LLM_KV_TOKENIZER_HF_JSON,              "tokenizer.huggingface.json"              },
-    { LLM_KV_TOKENIZER_RWKV,                 "tokenizer.rwkv.world"                    },
-    { LLM_KV_TOKENIZER_PREFIX_ID,            "tokenizer.ggml.prefix_token_id"          },
-    { LLM_KV_TOKENIZER_SUFFIX_ID,            "tokenizer.ggml.suffix_token_id"          },
-    { LLM_KV_TOKENIZER_MIDDLE_ID,            "tokenizer.ggml.middle_token_id"          },
-    { LLM_KV_TOKENIZER_EOT_ID,               "tokenizer.ggml.eot_token_id"             },
-    { LLM_KV_TOKENIZER_EOM_ID,               "tokenizer.ggml.eom_token_id"             },
-
-    { LLM_KV_ADAPTER_TYPE,                  "adapter.type"       },
-    { LLM_KV_ADAPTER_LORA_ALPHA,            "adapter.lora.alpha" },
-};
-
-struct LLM_KV {
-    LLM_KV(llm_arch arch) : arch(arch) {}
-
-    llm_arch arch;
-
-    std::string operator()(llm_kv kv) const {
-        return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
-    }
-};
-
-enum llm_tensor {
-    LLM_TENSOR_TOKEN_EMBD,
-    LLM_TENSOR_TOKEN_EMBD_NORM,
-    LLM_TENSOR_TOKEN_TYPES,
-    LLM_TENSOR_POS_EMBD,
-    LLM_TENSOR_OUTPUT,
-    LLM_TENSOR_OUTPUT_NORM,
-    LLM_TENSOR_ROPE_FREQS,
-    LLM_TENSOR_ROPE_FACTORS_LONG,
-    LLM_TENSOR_ROPE_FACTORS_SHORT,
-    LLM_TENSOR_ATTN_Q,
-    LLM_TENSOR_ATTN_K,
-    LLM_TENSOR_ATTN_V,
-    LLM_TENSOR_ATTN_QKV,
-    LLM_TENSOR_ATTN_OUT,
-    LLM_TENSOR_ATTN_NORM,
-    LLM_TENSOR_ATTN_NORM_2,
-    LLM_TENSOR_ATTN_OUT_NORM,
-    LLM_TENSOR_ATTN_POST_NORM,
-    LLM_TENSOR_ATTN_ROT_EMBD,
-    LLM_TENSOR_FFN_GATE_INP,
-    LLM_TENSOR_FFN_GATE_INP_SHEXP,
-    LLM_TENSOR_FFN_NORM,
-    LLM_TENSOR_FFN_POST_NORM,
-    LLM_TENSOR_FFN_GATE,
-    LLM_TENSOR_FFN_DOWN,
-    LLM_TENSOR_FFN_UP,
-    LLM_TENSOR_FFN_ACT,
-    LLM_TENSOR_FFN_DOWN_EXP,  // split experts for backward compatibility
-    LLM_TENSOR_FFN_GATE_EXP,
-    LLM_TENSOR_FFN_UP_EXP,
-    LLM_TENSOR_FFN_NORM_EXPS,
-    LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
-    LLM_TENSOR_FFN_GATE_EXPS,
-    LLM_TENSOR_FFN_UP_EXPS,
-    LLM_TENSOR_FFN_DOWN_SHEXP,
-    LLM_TENSOR_FFN_GATE_SHEXP,
-    LLM_TENSOR_FFN_UP_SHEXP,
-    LLM_TENSOR_ATTN_Q_NORM,
-    LLM_TENSOR_ATTN_K_NORM,
-    LLM_TENSOR_LAYER_OUT_NORM,
-    LLM_TENSOR_SSM_IN,
-    LLM_TENSOR_SSM_CONV1D,
-    LLM_TENSOR_SSM_X,
-    LLM_TENSOR_SSM_DT,
-    LLM_TENSOR_SSM_A,
-    LLM_TENSOR_SSM_D,
-    LLM_TENSOR_SSM_OUT,
-    LLM_TENSOR_TIME_MIX_W1,
-    LLM_TENSOR_TIME_MIX_W2,
-    LLM_TENSOR_TIME_MIX_LERP_X,
-    LLM_TENSOR_TIME_MIX_LERP_W,
-    LLM_TENSOR_TIME_MIX_LERP_K,
-    LLM_TENSOR_TIME_MIX_LERP_V,
-    LLM_TENSOR_TIME_MIX_LERP_R,
-    LLM_TENSOR_TIME_MIX_LERP_G,
-    LLM_TENSOR_TIME_MIX_FIRST,
-    LLM_TENSOR_TIME_MIX_DECAY,
-    LLM_TENSOR_TIME_MIX_DECAY_W1,
-    LLM_TENSOR_TIME_MIX_DECAY_W2,
-    LLM_TENSOR_TIME_MIX_KEY,
-    LLM_TENSOR_TIME_MIX_VALUE,
-    LLM_TENSOR_TIME_MIX_RECEPTANCE,
-    LLM_TENSOR_TIME_MIX_GATE,
-    LLM_TENSOR_TIME_MIX_LN,
-    LLM_TENSOR_TIME_MIX_OUTPUT,
-    LLM_TENSOR_CHANNEL_MIX_LERP_K,
-    LLM_TENSOR_CHANNEL_MIX_LERP_R,
-    LLM_TENSOR_CHANNEL_MIX_KEY,
-    LLM_TENSOR_CHANNEL_MIX_RECEPTANCE,
-    LLM_TENSOR_CHANNEL_MIX_VALUE,
-    LLM_TENSOR_ATTN_Q_A,
-    LLM_TENSOR_ATTN_Q_B,
-    LLM_TENSOR_ATTN_KV_A_MQA,
-    LLM_TENSOR_ATTN_KV_B,
-    LLM_TENSOR_ATTN_Q_A_NORM,
-    LLM_TENSOR_ATTN_KV_A_NORM,
-    LLM_TENSOR_ATTN_SUB_NORM,
-    LLM_TENSOR_FFN_SUB_NORM,
-    LLM_TENSOR_DEC_ATTN_NORM,
-    LLM_TENSOR_DEC_ATTN_Q,
-    LLM_TENSOR_DEC_ATTN_K,
-    LLM_TENSOR_DEC_ATTN_V,
-    LLM_TENSOR_DEC_ATTN_OUT,
-    LLM_TENSOR_DEC_ATTN_REL_B,
-    LLM_TENSOR_DEC_CROSS_ATTN_NORM,
-    LLM_TENSOR_DEC_CROSS_ATTN_Q,
-    LLM_TENSOR_DEC_CROSS_ATTN_K,
-    LLM_TENSOR_DEC_CROSS_ATTN_V,
-    LLM_TENSOR_DEC_CROSS_ATTN_OUT,
-    LLM_TENSOR_DEC_CROSS_ATTN_REL_B,
-    LLM_TENSOR_DEC_FFN_NORM,
-    LLM_TENSOR_DEC_FFN_GATE,
-    LLM_TENSOR_DEC_FFN_DOWN,
-    LLM_TENSOR_DEC_FFN_UP,
-    LLM_TENSOR_DEC_OUTPUT_NORM,
-    LLM_TENSOR_ENC_ATTN_NORM,
-    LLM_TENSOR_ENC_ATTN_Q,
-    LLM_TENSOR_ENC_ATTN_K,
-    LLM_TENSOR_ENC_ATTN_V,
-    LLM_TENSOR_ENC_ATTN_OUT,
-    LLM_TENSOR_ENC_ATTN_REL_B,
-    LLM_TENSOR_ENC_FFN_NORM,
-    LLM_TENSOR_ENC_FFN_GATE,
-    LLM_TENSOR_ENC_FFN_DOWN,
-    LLM_TENSOR_ENC_FFN_UP,
-    LLM_TENSOR_ENC_OUTPUT_NORM,
-};
-
-static const std::map> LLM_TENSOR_NAMES = {
-    {
-        LLM_ARCH_LLAMA,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
-            { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
-            { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
-            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
-            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
-        },
-    },
-    {
-        LLM_ARCH_BAICHUAN,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_FALCON,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_NORM_2,     "blk.%d.attn_norm_2" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_GROK,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
-            { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
-            { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
-            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
-            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
-            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
-            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
-        },
-    },
-    {
-        LLM_ARCH_GPT2,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_POS_EMBD,        "position_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-        },
-    },
-    {
-        LLM_ARCH_GPTJ,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-        },
-    },
-    {
-        LLM_ARCH_GPTNEOX,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_MPT,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output"},
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_ACT,         "blk.%d.ffn.act" },
-            { LLM_TENSOR_POS_EMBD,        "position_embd" },
-            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm"},
-            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm"},
-        },
-    },
-    {
-        LLM_ARCH_STARCODER,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_POS_EMBD,        "position_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-        },
-    },
-    {
-        LLM_ARCH_REFACT,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_BERT,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
-            { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
-            { LLM_TENSOR_POS_EMBD,        "position_embd" },
-            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_NOMIC_BERT,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
-            { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
-            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_JINA_BERT_V2,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
-            { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
-            { LLM_TENSOR_ATTN_NORM_2,     "blk.%d.attn_norm_2" },
-            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_BLOOM,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-        },
-    },
-    {
-        LLM_ARCH_STABLELM,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
-            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
-        },
-    },
-    {
-        LLM_ARCH_QWEN,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_QWEN2,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_QWEN2MOE,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
-            { LLM_TENSOR_OUTPUT,             "output" },
-            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
-            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
-            { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
-            { LLM_TENSOR_FFN_GATE_SHEXP,     "blk.%d.ffn_gate_shexp" },
-            { LLM_TENSOR_FFN_DOWN_SHEXP,     "blk.%d.ffn_down_shexp" },
-            { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
-        },
-    },
-    {
-        LLM_ARCH_PHI2,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_PHI3,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
-            { LLM_TENSOR_OUTPUT,             "output" },
-            { LLM_TENSOR_ROPE_FACTORS_LONG,  "rope_factors_long" },
-            { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
-            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,           "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_DOWN,           "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_PLAMO,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_CODESHELL,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_ORION,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_INTERNLM2,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_MINICPM,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_GATE_EXP,    "blk.%d.ffn_gate.%d" },
-            { LLM_TENSOR_FFN_DOWN_EXP,    "blk.%d.ffn_down.%d" },
-            { LLM_TENSOR_FFN_UP_EXP,      "blk.%d.ffn_up.%d" },
-        },
-    },
-    {
-        LLM_ARCH_GEMMA,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_GEMMA2,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_POST_NORM,  "blk.%d.post_attention_norm" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_POST_NORM,   "blk.%d.post_ffw_norm" },
-        },
-    },
-    {
-        LLM_ARCH_STARCODER2,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_MAMBA,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_SSM_IN,          "blk.%d.ssm_in" },
-            { LLM_TENSOR_SSM_CONV1D,      "blk.%d.ssm_conv1d" },
-            { LLM_TENSOR_SSM_X,           "blk.%d.ssm_x" },
-            { LLM_TENSOR_SSM_DT,          "blk.%d.ssm_dt" },
-            { LLM_TENSOR_SSM_A,           "blk.%d.ssm_a" },
-            { LLM_TENSOR_SSM_D,           "blk.%d.ssm_d" },
-            { LLM_TENSOR_SSM_OUT,         "blk.%d.ssm_out" },
-        },
-    },
-    {
-        LLM_ARCH_XVERSE,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_COMMAND_R,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
-            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
-        },
-    },
-    {
-        LLM_ARCH_DBRX,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
-            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
-            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
-        },
-    },
-    {
-        LLM_ARCH_OLMO,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_OPENELM,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
-            { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_ARCTIC,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_NORM_EXPS,   "blk.%d.ffn_norm_exps" },
-            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
-            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
-        },
-    },
-    {
-        LLM_ARCH_DEEPSEEK2,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
-            { LLM_TENSOR_OUTPUT,             "output" },
-            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q_A_NORM,      "blk.%d.attn_q_a_norm" },
-            { LLM_TENSOR_ATTN_KV_A_NORM,     "blk.%d.attn_kv_a_norm" },
-            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_Q_A,           "blk.%d.attn_q_a" },
-            { LLM_TENSOR_ATTN_Q_B,           "blk.%d.attn_q_b" },
-            { LLM_TENSOR_ATTN_KV_A_MQA,      "blk.%d.attn_kv_a_mqa" },
-            { LLM_TENSOR_ATTN_KV_B,          "blk.%d.attn_kv_b" },
-            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,           "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_DOWN,           "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
-            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
-            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
-            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
-            { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
-            { LLM_TENSOR_FFN_GATE_SHEXP,     "blk.%d.ffn_gate_shexp" },
-            { LLM_TENSOR_FFN_DOWN_SHEXP,     "blk.%d.ffn_down_shexp" },
-            { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
-        },
-    },
-    {
-        LLM_ARCH_CHATGLM,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-        },
-    },
-    {
-        LLM_ARCH_BITNET,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
-            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_SUB_NORM,      "blk.%d.attn_sub_norm" },
-            { LLM_TENSOR_FFN_GATE,           "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,           "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_SUB_NORM,       "blk.%d.ffn_sub_norm" },
-        },
-    },
-    {
-        LLM_ARCH_T5,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,           "token_embd" },
-            { LLM_TENSOR_OUTPUT,               "output" },
-            { LLM_TENSOR_DEC_OUTPUT_NORM,      "dec.output_norm" },
-            { LLM_TENSOR_DEC_ATTN_NORM,        "dec.blk.%d.attn_norm" },
-            { LLM_TENSOR_DEC_ATTN_Q,           "dec.blk.%d.attn_q" },
-            { LLM_TENSOR_DEC_ATTN_K,           "dec.blk.%d.attn_k" },
-            { LLM_TENSOR_DEC_ATTN_V,           "dec.blk.%d.attn_v" },
-            { LLM_TENSOR_DEC_ATTN_OUT,         "dec.blk.%d.attn_o" },
-            { LLM_TENSOR_DEC_ATTN_REL_B,       "dec.blk.%d.attn_rel_b" },
-            { LLM_TENSOR_DEC_CROSS_ATTN_NORM,  "dec.blk.%d.cross_attn_norm" },
-            { LLM_TENSOR_DEC_CROSS_ATTN_Q,     "dec.blk.%d.cross_attn_q" },
-            { LLM_TENSOR_DEC_CROSS_ATTN_K,     "dec.blk.%d.cross_attn_k" },
-            { LLM_TENSOR_DEC_CROSS_ATTN_V,     "dec.blk.%d.cross_attn_v" },
-            { LLM_TENSOR_DEC_CROSS_ATTN_OUT,   "dec.blk.%d.cross_attn_o" },
-            { LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "dec.blk.%d.cross_attn_rel_b" },
-            { LLM_TENSOR_DEC_FFN_NORM,         "dec.blk.%d.ffn_norm" },
-            { LLM_TENSOR_DEC_FFN_GATE,         "dec.blk.%d.ffn_gate" },
-            { LLM_TENSOR_DEC_FFN_DOWN,         "dec.blk.%d.ffn_down" },
-            { LLM_TENSOR_DEC_FFN_UP,           "dec.blk.%d.ffn_up" },
-            { LLM_TENSOR_ENC_OUTPUT_NORM,      "enc.output_norm" },
-            { LLM_TENSOR_ENC_ATTN_NORM,        "enc.blk.%d.attn_norm" },
-            { LLM_TENSOR_ENC_ATTN_Q,           "enc.blk.%d.attn_q" },
-            { LLM_TENSOR_ENC_ATTN_K,           "enc.blk.%d.attn_k" },
-            { LLM_TENSOR_ENC_ATTN_V,           "enc.blk.%d.attn_v" },
-            { LLM_TENSOR_ENC_ATTN_OUT,         "enc.blk.%d.attn_o" },
-            { LLM_TENSOR_ENC_ATTN_REL_B,       "enc.blk.%d.attn_rel_b" },
-            { LLM_TENSOR_ENC_FFN_NORM,         "enc.blk.%d.ffn_norm" },
-            { LLM_TENSOR_ENC_FFN_GATE,         "enc.blk.%d.ffn_gate" },
-            { LLM_TENSOR_ENC_FFN_DOWN,         "enc.blk.%d.ffn_down" },
-            { LLM_TENSOR_ENC_FFN_UP,           "enc.blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_T5ENCODER,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,           "token_embd" },
-            { LLM_TENSOR_OUTPUT,               "output" },
-            { LLM_TENSOR_ENC_OUTPUT_NORM,      "enc.output_norm" },
-            { LLM_TENSOR_ENC_ATTN_NORM,        "enc.blk.%d.attn_norm" },
-            { LLM_TENSOR_ENC_ATTN_Q,           "enc.blk.%d.attn_q" },
-            { LLM_TENSOR_ENC_ATTN_K,           "enc.blk.%d.attn_k" },
-            { LLM_TENSOR_ENC_ATTN_V,           "enc.blk.%d.attn_v" },
-            { LLM_TENSOR_ENC_ATTN_OUT,         "enc.blk.%d.attn_o" },
-            { LLM_TENSOR_ENC_ATTN_REL_B,       "enc.blk.%d.attn_rel_b" },
-            { LLM_TENSOR_ENC_FFN_NORM,         "enc.blk.%d.ffn_norm" },
-            { LLM_TENSOR_ENC_FFN_GATE,         "enc.blk.%d.ffn_gate" },
-            { LLM_TENSOR_ENC_FFN_DOWN,         "enc.blk.%d.ffn_down" },
-            { LLM_TENSOR_ENC_FFN_UP,           "enc.blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_JAIS,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-        },
-    },
-    {
-        LLM_ARCH_NEMOTRON,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_EXAONE,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
-            { LLM_TENSOR_OUTPUT,          "output" },
-            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
-            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
-            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
-        },
-    },
-    {
-        LLM_ARCH_RWKV6,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,                "token_embd" },
-            { LLM_TENSOR_TOKEN_EMBD_NORM,           "token_embd_norm" },
-            { LLM_TENSOR_OUTPUT_NORM,               "output_norm" },
-            { LLM_TENSOR_OUTPUT,                    "output" },
-            { LLM_TENSOR_ATTN_NORM,                 "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_NORM_2,               "blk.%d.attn_norm_2" },
-            { LLM_TENSOR_TIME_MIX_W1,               "blk.%d.time_mix_w1" },
-            { LLM_TENSOR_TIME_MIX_W2,               "blk.%d.time_mix_w2" },
-            { LLM_TENSOR_TIME_MIX_LERP_X,           "blk.%d.time_mix_lerp_x" },
-            { LLM_TENSOR_TIME_MIX_LERP_W,           "blk.%d.time_mix_lerp_w" },
-            { LLM_TENSOR_TIME_MIX_LERP_K,           "blk.%d.time_mix_lerp_k" },
-            { LLM_TENSOR_TIME_MIX_LERP_V,           "blk.%d.time_mix_lerp_v" },
-            { LLM_TENSOR_TIME_MIX_LERP_R,           "blk.%d.time_mix_lerp_r" },
-            { LLM_TENSOR_TIME_MIX_LERP_G,           "blk.%d.time_mix_lerp_g" },
-            { LLM_TENSOR_TIME_MIX_FIRST,            "blk.%d.time_mix_first" },
-            { LLM_TENSOR_TIME_MIX_DECAY,            "blk.%d.time_mix_decay" },
-            { LLM_TENSOR_TIME_MIX_DECAY_W1,         "blk.%d.time_mix_decay_w1" },
-            { LLM_TENSOR_TIME_MIX_DECAY_W2,         "blk.%d.time_mix_decay_w2" },
-            { LLM_TENSOR_TIME_MIX_KEY,              "blk.%d.time_mix_key" },
-            { LLM_TENSOR_TIME_MIX_VALUE,            "blk.%d.time_mix_value" },
-            { LLM_TENSOR_TIME_MIX_RECEPTANCE,       "blk.%d.time_mix_receptance" },
-            { LLM_TENSOR_TIME_MIX_GATE,             "blk.%d.time_mix_gate" },
-            { LLM_TENSOR_TIME_MIX_LN,               "blk.%d.time_mix_ln" },
-            { LLM_TENSOR_TIME_MIX_OUTPUT,           "blk.%d.time_mix_output" },
-            { LLM_TENSOR_CHANNEL_MIX_LERP_K,        "blk.%d.channel_mix_lerp_k" },
-            { LLM_TENSOR_CHANNEL_MIX_LERP_R,        "blk.%d.channel_mix_lerp_r" },
-            { LLM_TENSOR_CHANNEL_MIX_KEY,           "blk.%d.channel_mix_key" },
-            { LLM_TENSOR_CHANNEL_MIX_VALUE,         "blk.%d.channel_mix_value" },
-            { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE,    "blk.%d.channel_mix_receptance" },
-        },
-    },
-    {
-        LLM_ARCH_UNKNOWN,
-        {
-            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
-        },
-    },
-};
-
-static llm_arch llm_arch_from_string(const std::string & name) {
-    for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT
-        if (kv.second == name) {
-            return kv.first;
-        }
-    }
-
-    return LLM_ARCH_UNKNOWN;
-}
-
-// helper to handle gguf constants
-// usage:
-//
-//   const auto tn = LLM_TN(LLM_ARCH_LLAMA);
-//
-//   std::string name = tn(LLM_TENSOR_OUTPUT);                     -> "output"
-//   std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias");         -> "token_embd.bias"
-//   std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3);     -> "blk.3.attn_norm.weight"
-//
-struct LLM_TN {
-    LLM_TN(llm_arch arch) : arch(arch) {}
-
-    llm_arch arch;
-
-    std::string operator()(llm_tensor tensor) const {
-        if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
-            return "__missing__";
-        }
-        return LLM_TENSOR_NAMES.at(arch).at(tensor);
-    }
-
-    std::string operator()(llm_tensor tensor, const std::string & suffix) const {
-        if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
-            return "__missing__";
-        }
-        return LLM_TENSOR_NAMES.at(arch).at(tensor) + "." + suffix;
-    }
-
-    std::string operator()(llm_tensor tensor, int bid) const {
-        if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
-            return "__missing__";
-        }
-        return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid);
-    }
-
-    std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const {
-        if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
-            return "__missing__";
-        }
-        return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid) + "." + suffix;
-    }
-
-    std::string operator()(llm_tensor tensor, const std::string & suffix, int bid, int xid) const {
-        if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
-            return "__missing__";
-        }
-        return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid, xid) + "." + suffix;
-    }
-};
-
-//
-// gguf helpers
-//
-
-static const std::map LLAMA_ROPE_SCALING_TYPES = {
-    { LLAMA_ROPE_SCALING_TYPE_NONE,   "none"   },
-    { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" },
-    { LLAMA_ROPE_SCALING_TYPE_YARN,   "yarn"   },
-};
-
-static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
-    for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
-        if (kv.second == name) {
-            return (llama_rope_scaling_type) kv.first;
-        }
-    }
-
-    return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
-}
-
-static std::string gguf_data_to_str(enum gguf_type type, const void * data, int i) {
-    switch (type) {
-        case GGUF_TYPE_UINT8:   return std::to_string(((const uint8_t  *)data)[i]);
-        case GGUF_TYPE_INT8:    return std::to_string(((const int8_t   *)data)[i]);
-        case GGUF_TYPE_UINT16:  return std::to_string(((const uint16_t *)data)[i]);
-        case GGUF_TYPE_INT16:   return std::to_string(((const int16_t  *)data)[i]);
-        case GGUF_TYPE_UINT32:  return std::to_string(((const uint32_t *)data)[i]);
-        case GGUF_TYPE_INT32:   return std::to_string(((const int32_t  *)data)[i]);
-        case GGUF_TYPE_UINT64:  return std::to_string(((const uint64_t *)data)[i]);
-        case GGUF_TYPE_INT64:   return std::to_string(((const int64_t  *)data)[i]);
-        case GGUF_TYPE_FLOAT32: return std::to_string(((const float    *)data)[i]);
-        case GGUF_TYPE_FLOAT64: return std::to_string(((const double   *)data)[i]);
-        case GGUF_TYPE_BOOL:    return ((const bool *)data)[i] ? "true" : "false";
-        default:                return format("unknown type %d", type);
-    }
-}
-
-static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
-    const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i);
-
-    switch (type) {
-        case GGUF_TYPE_STRING:
-            return gguf_get_val_str(ctx_gguf, i);
-        case GGUF_TYPE_ARRAY:
-            {
-                const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i);
-                int arr_n = gguf_get_arr_n(ctx_gguf, i);
-                const void * data = gguf_get_arr_data(ctx_gguf, i);
-                std::stringstream ss;
-                ss << "[";
-                for (int j = 0; j < arr_n; j++) {
-                    if (arr_type == GGUF_TYPE_STRING) {
-                        std::string val = gguf_get_arr_str(ctx_gguf, i, j);
-                        // escape quotes
-                        replace_all(val, "\\", "\\\\");
-                        replace_all(val, "\"", "\\\"");
-                        ss << '"' << val << '"';
-                    } else if (arr_type == GGUF_TYPE_ARRAY) {
-                        ss << "???";
-                    } else {
-                        ss << gguf_data_to_str(arr_type, data, j);
-                    }
-                    if (j < arr_n - 1) {
-                        ss << ", ";
-                    }
-                }
-                ss << "]";
-                return ss.str();
-            }
-        default:
-            return gguf_data_to_str(type, gguf_get_val_data(ctx_gguf, i), 0);
-    }
-}
-
-//
-// llama helpers
-//
-
-#if defined(_WIN32)
-static std::string llama_format_win_err(DWORD err) {
-    LPSTR buf;
-    size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
-                                 NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL);
-    if (!size) {
-        return "FormatMessageA failed";
-    }
-    std::string ret(buf, size);
-    LocalFree(buf);
-    return ret;
-}
-#endif
-
-template 
-struct no_init {
-    T value;
-    no_init() { /* do nothing */ }
-};
-
-struct llama_file {
-
-#if defined(_WIN32)
-    // use FILE * so we don't have to re-open the file to mmap
-    FILE * fp;
-    HANDLE fp_win32;
-    size_t size;
-
-private:
-    std::string GetErrorMessageWin32(DWORD error_code) const {
-        std::string ret;
-        LPSTR lpMsgBuf = NULL;
-        DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
-                                    NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL);
-        if (!bufLen) {
-            ret = format("Win32 error code: %s", error_code);
-        } else {
-            ret = lpMsgBuf;
-            LocalFree(lpMsgBuf);
-        }
-
-        return ret;
-    }
-
-public:
-
-    llama_file(const char * fname, const char * mode) {
-        fp = ggml_fopen(fname, mode);
-        if (fp == NULL) {
-            throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
-        }
-        fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp));
-        seek(0, SEEK_END);
-        size = tell();
-        seek(0, SEEK_SET);
-    }
-
-    size_t tell() const {
-        // SetFilePointerEx returns the current position when seeking relative 0 bytes
-        LARGE_INTEGER li;
-        li.QuadPart = 0;
-        BOOL ret = SetFilePointerEx(fp_win32, li, &li, FILE_CURRENT);
-        if (!ret) {
-            throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
-        }
-
-        return li.QuadPart;
-    }
-
-    void seek(size_t offset, int whence) const {
-        // no need to convert SEEK_* to FILE_*. The enums are the same.
-        // Still, keep static asserts to avoid failures in the future.
-        static_assert(SEEK_SET == FILE_BEGIN, "SEEK_SET != FILE_BEGIN");
-        static_assert(SEEK_CUR == FILE_CURRENT, "SEEK_CUR != FILE_CURRENT");
-        static_assert(SEEK_END == FILE_END, "SEEK_END != FILE_END");
-
-        LARGE_INTEGER li;
-        li.QuadPart = offset;
-        BOOL ret = SetFilePointerEx(fp_win32, li, NULL, whence);
-        if (!ret) {
-            throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
-        }
-    }
-
-    void read_raw(void * ptr, size_t len) const {
-        // On Win32 ReadFile is significant faster than fread which is again significant faster than std::fstream. Thus
-        // use the Win32 API to do file io instead of the C/C++ library functions.
-
-        // There are conditions under which ReadFile cannot read chunks >64MB.
-        // Thus split the operation into smaller chunks if len exceeds this limit.
-        size_t bytes_read = 0;
-        while (bytes_read < len) {
-            size_t chunk_size = std::min(len - bytes_read, 64*1024*1024);
-            DWORD chunk_read = 0;
-            BOOL result = ReadFile(fp_win32, reinterpret_cast(ptr) + bytes_read, chunk_size, &chunk_read, NULL);
-            if (!result) {
-                throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
-            }
-            if (chunk_read < chunk_size || chunk_read == 0) {
-                throw std::runtime_error("unexpectedly reached end of file");
-            }
-
-            bytes_read += chunk_read;
-        } ;
-    }
-
-    uint32_t read_u32() const {
-        uint32_t val;
-        read_raw(&val, sizeof(val));
-        return val;
-    }
-
-    void write_raw(const void * ptr, size_t len) const {
-        // There are conditions under which WriteFile cannot write chunks >64MB.
-        // Thus split the operation into smaller chunks if len exceeds this limit.
-        size_t bytes_written = 0;
-        while (bytes_written < len) {
-            size_t chunk_size = std::min(len - bytes_written, 64*1024*1024);
-            DWORD chunk_written = 0;
-            BOOL result = WriteFile(fp_win32, reinterpret_cast(ptr) + bytes_written, chunk_size, &chunk_written, NULL);
-            if (!result) {
-                throw std::runtime_error(format("write error: %s", GetErrorMessageWin32(GetLastError()).c_str()));
-            }
-            if (chunk_written < chunk_size || chunk_written == 0) {
-                throw std::runtime_error("unexpectedly failed to write bytes");
-            }
-
-            bytes_written += chunk_written;
-        }
-    }
-
-    void write_u32(std::uint32_t val) const {
-        write_raw(&val, sizeof(val));
-    }
-
-    ~llama_file() {
-        if (fp) {
-            std::fclose(fp);
-        }
-    }
-#else
-    // use FILE * so we don't have to re-open the file to mmap
-    FILE * fp;
-    size_t size;
-
-    llama_file(const char * fname, const char * mode) {
-        fp = ggml_fopen(fname, mode);
-        if (fp == NULL) {
-            throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
-        }
-        seek(0, SEEK_END);
-        size = tell();
-        seek(0, SEEK_SET);
-    }
-
-    size_t tell() const {
-#ifdef _WIN32
-        __int64 ret = _ftelli64(fp);
-#else
-        long ret = std::ftell(fp);
-#endif
-        if (ret == -1) {
-            throw std::runtime_error(format("ftell error: %s", strerror(errno)));
-        }
-
-        return (size_t) ret;
-    }
-
-    void seek(size_t offset, int whence) const {
-#ifdef _WIN32
-        int ret = _fseeki64(fp, (__int64) offset, whence);
-#else
-        int ret = std::fseek(fp, (long) offset, whence);
-#endif
-        if (ret != 0) {
-            throw std::runtime_error(format("seek error: %s", strerror(errno)));
-        }
-    }
-
-    void read_raw(void * ptr, size_t len) const {
-        if (len == 0) {
-            return;
-        }
-        errno = 0;
-        std::size_t ret = std::fread(ptr, len, 1, fp);
-        if (ferror(fp)) {
-            throw std::runtime_error(format("read error: %s", strerror(errno)));
-        }
-        if (ret != 1) {
-            throw std::runtime_error("unexpectedly reached end of file");
-        }
-    }
-
-    uint32_t read_u32() const {
-        uint32_t ret;
-        read_raw(&ret, sizeof(ret));
-        return ret;
-    }
-
-    void write_raw(const void * ptr, size_t len) const {
-        if (len == 0) {
-            return;
-        }
-        errno = 0;
-        size_t ret = std::fwrite(ptr, len, 1, fp);
-        if (ret != 1) {
-            throw std::runtime_error(format("write error: %s", strerror(errno)));
-        }
-    }
-
-    void write_u32(std::uint32_t val) const {
-        write_raw(&val, sizeof(val));
-    }
-
-    ~llama_file() {
-        if (fp) {
-            std::fclose(fp);
-        }
-    }
-#endif
-};
-using llama_files = std::vector>;
-
-struct llama_mmap {
-    void * addr;
-    size_t size;
-
-    llama_mmap(const llama_mmap &) = delete;
-
-#ifdef _POSIX_MAPPED_FILES
-    static constexpr bool SUPPORTED = true;
-
-    // list of mapped fragments (first_offset, last_offset)
-    std::vector> mapped_fragments;
-
-    llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */, bool numa = false) {
-        size = file->size;
-        int fd = fileno(file->fp);
-        int flags = MAP_SHARED;
-        // prefetch/readahead impairs performance on NUMA systems
-        if (numa)  { prefetch = 0; }
-#ifdef __linux__
-        // advise the kernel to read the file sequentially (increases readahead)
-        if (posix_fadvise(fd, 0, 0, POSIX_FADV_SEQUENTIAL)) {
-            LLAMA_LOG_WARN("warning: posix_fadvise(.., POSIX_FADV_SEQUENTIAL) failed: %s\n",
-                    strerror(errno));
-        }
-        if (prefetch) { flags |= MAP_POPULATE; }
-#endif
-        addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0);
-        if (addr == MAP_FAILED) { // NOLINT
-            throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
-        }
-
-        if (prefetch > 0) {
-            // advise the kernel to preload the mapped memory
-            if (posix_madvise(addr, std::min(file->size, prefetch), POSIX_MADV_WILLNEED)) {
-                LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n",
-                        strerror(errno));
-            }
-        }
-        if (numa) {
-            // advise the kernel not to use readahead
-            // (because the next page might not belong on the same node)
-            if (posix_madvise(addr, file->size, POSIX_MADV_RANDOM)) {
-                LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n",
-                        strerror(errno));
-            }
-        }
-
-        // initialize list of mapped_fragments
-        mapped_fragments.emplace_back(0, file->size);
-    }
-
-    static void align_range(size_t * first, size_t * last, size_t page_size) {
-        // align first to the next page
-        size_t offset_in_page = *first & (page_size - 1);
-        size_t offset_to_page = offset_in_page == 0 ? 0 : page_size - offset_in_page;
-        *first += offset_to_page;
-
-        // align last to the previous page
-        *last = *last & ~(page_size - 1);
-
-        if (*last <= *first) {
-            *last = *first;
-        }
-    }
-
-    // partially unmap the file in the range [first, last)
-    void unmap_fragment(size_t first, size_t last) {
-        // note: this function must not be called multiple times with overlapping ranges
-        // otherwise, there is a risk of invalidating addresses that have been repurposed for other mappings
-        int page_size = sysconf(_SC_PAGESIZE);
-        align_range(&first, &last, page_size);
-        size_t len = last - first;
-
-        if (len == 0) {
-            return;
-        }
-
-        GGML_ASSERT(first % page_size == 0);
-        GGML_ASSERT(last % page_size == 0);
-        GGML_ASSERT(last > first);
-
-        void * next_page_start = (uint8_t *) addr + first;
-
-        // unmap the range
-        if (munmap(next_page_start, len)) {
-            LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
-        }
-
-        // update the list of mapped fragments to avoid unmapping the same range again in the destructor
-        std::vector> new_mapped_fragments;
-        for (const auto & frag : mapped_fragments) {
-            if (frag.first < first && frag.second > last) {
-                // the range is in the middle of the fragment, split it
-                new_mapped_fragments.emplace_back(frag.first, first);
-                new_mapped_fragments.emplace_back(last, frag.second);
-            } else if (frag.first < first && frag.second > first) {
-                // the range starts in the middle of the fragment
-                new_mapped_fragments.emplace_back(frag.first, first);
-            } else if (frag.first < last && frag.second > last) {
-                // the range ends in the middle of the fragment
-                new_mapped_fragments.emplace_back(last, frag.second);
-            } else if (frag.first >= first && frag.second <= last) {
-                // the range covers the entire fragment
-            } else {
-                // the range is outside the fragment
-                new_mapped_fragments.push_back(frag);
-            }
-        }
-        mapped_fragments = std::move(new_mapped_fragments);
-    }
-
-    ~llama_mmap() {
-        for (const auto & frag : mapped_fragments) {
-            if (munmap((char *) addr + frag.first, frag.second - frag.first)) {
-                LLAMA_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
-            }
-        }
-    }
-#elif defined(_WIN32)
-    static constexpr bool SUPPORTED = true;
-
-    llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1, bool numa = false) {
-        GGML_UNUSED(numa);
-
-        size = file->size;
-
-        HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp));
-
-        HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
-
-        if (hMapping == NULL) {
-            DWORD error = GetLastError();
-            throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str()));
-        }
-
-        addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
-        DWORD error = GetLastError();
-        CloseHandle(hMapping);
-
-        if (addr == NULL) {
-            throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str()));
-        }
-
-        if (prefetch > 0) {
-#if _WIN32_WINNT >= 0x602
-            // PrefetchVirtualMemory is only present on Windows 8 and above, so we dynamically load it
-            BOOL (WINAPI *pPrefetchVirtualMemory) (HANDLE, ULONG_PTR, PWIN32_MEMORY_RANGE_ENTRY, ULONG);
-            HMODULE hKernel32 = GetModuleHandleW(L"kernel32.dll");
-
-            // may fail on pre-Windows 8 systems
-            pPrefetchVirtualMemory = reinterpret_cast (GetProcAddress(hKernel32, "PrefetchVirtualMemory"));
-
-            if (pPrefetchVirtualMemory) {
-                // advise the kernel to preload the mapped memory
-                WIN32_MEMORY_RANGE_ENTRY range;
-                range.VirtualAddress = addr;
-                range.NumberOfBytes = (SIZE_T) std::min(size, prefetch);
-                if (!pPrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) {
-                    LLAMA_LOG_WARN("warning: PrefetchVirtualMemory failed: %s\n",
-                            llama_format_win_err(GetLastError()).c_str());
-                }
-            }
-#else
-            throw std::runtime_error("PrefetchVirtualMemory unavailable");
-#endif
-        }
-    }
-
-    void unmap_fragment(size_t first, size_t last) {
-        // not supported
-        GGML_UNUSED(first);
-        GGML_UNUSED(last);
-    }
-
-    ~llama_mmap() {
-        if (!UnmapViewOfFile(addr)) {
-            LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n",
-                    llama_format_win_err(GetLastError()).c_str());
-        }
-    }
-#else
-    static constexpr bool SUPPORTED = false;
-
-    llama_mmap(struct llama_file * file, size_t prefetch = -1, bool numa = false) {
-        GGML_UNUSED(file);
-        GGML_UNUSED(prefetch);
-        GGML_UNUSED(numa);
-
-        throw std::runtime_error("mmap not supported");
-    }
-
-    void unmap_fragment(size_t first, size_t last) {
-        GGML_UNUSED(first);
-        GGML_UNUSED(last);
-
-        throw std::runtime_error("mmap not supported");
-    }
-#endif
-};
-using llama_mmaps = std::vector>;
-
-// Represents some region of memory being locked using mlock or VirtualLock;
-// will automatically unlock on destruction.
-struct llama_mlock {
-    void * addr = NULL;
-    size_t size = 0;
-
-    bool failed_already = false;
-
-    llama_mlock() {}
-    llama_mlock(const llama_mlock &) = delete;
-
-    ~llama_mlock() {
-        if (size) {
-            raw_unlock(addr, size);
-        }
-    }
-
-    void init(void * ptr) {
-        GGML_ASSERT(addr == NULL && size == 0); // NOLINT
-        addr = ptr;
-    }
-
-    void grow_to(size_t target_size) {
-        GGML_ASSERT(addr);
-        if (failed_already) {
-            return;
-        }
-        size_t granularity = lock_granularity();
-        target_size = (target_size + granularity - 1) & ~(granularity - 1);
-        if (target_size > size) {
-            if (raw_lock((uint8_t *) addr + size, target_size - size)) {
-                size = target_size;
-            } else {
-                failed_already = true;
-            }
-        }
-    }
-
-#ifdef _POSIX_MEMLOCK_RANGE
-    static constexpr bool SUPPORTED = true;
-
-    static size_t lock_granularity() {
-        return (size_t) sysconf(_SC_PAGESIZE);
-    }
-
-    #ifdef __APPLE__
-        #define MLOCK_SUGGESTION \
-            "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \
-            "decreasing 'vm.global_no_user_wire_amount'.  Also try increasing RLIMIT_MEMLOCK (ulimit -l).\n"
-    #else
-        #define MLOCK_SUGGESTION \
-            "Try increasing RLIMIT_MEMLOCK ('ulimit -l' as root).\n"
-    #endif
-
-    bool raw_lock(const void * addr, size_t size) const {
-        if (!mlock(addr, size)) {
-            return true;
-        }
-
-        char* errmsg = std::strerror(errno);
-        bool suggest = (errno == ENOMEM);
-
-        // Check if the resource limit is fine after all
-        struct rlimit lock_limit;
-        if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) {
-            suggest = false;
-        }
-        if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) {
-            suggest = false;
-        }
-
-        LLAMA_LOG_WARN("warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s",
-                size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : "");
-        return false;
-    }
-
-    #undef MLOCK_SUGGESTION
-
-    static void raw_unlock(void * addr, size_t size) {
-        if (munlock(addr, size)) {
-            LLAMA_LOG_WARN("warning: failed to munlock buffer: %s\n", std::strerror(errno));
-        }
-    }
-#elif defined(_WIN32)
-    static constexpr bool SUPPORTED = true;
-
-    static size_t lock_granularity() {
-        SYSTEM_INFO si;
-        GetSystemInfo(&si);
-        return (size_t) si.dwPageSize;
-    }
-
-    bool raw_lock(void * ptr, size_t len) const {
-        for (int tries = 1; ; tries++) {
-            if (VirtualLock(ptr, len)) {
-                return true;
-            }
-            if (tries == 2) {
-                LLAMA_LOG_WARN("warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n",
-                    len, size, llama_format_win_err(GetLastError()).c_str());
-                return false;
-            }
-
-            // It failed but this was only the first try; increase the working
-            // set size and try again.
-            SIZE_T min_ws_size, max_ws_size;
-            if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) {
-                LLAMA_LOG_WARN("warning: GetProcessWorkingSetSize failed: %s\n",
-                        llama_format_win_err(GetLastError()).c_str());
-                return false;
-            }
-            // Per MSDN: "The maximum number of pages that a process can lock
-            // is equal to the number of pages in its minimum working set minus
-            // a small overhead."
-            // Hopefully a megabyte is enough overhead:
-            size_t increment = len + 1048576;
-            // The minimum must be <= the maximum, so we need to increase both:
-            min_ws_size += increment;
-            max_ws_size += increment;
-            if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) {
-                LLAMA_LOG_WARN("warning: SetProcessWorkingSetSize failed: %s\n",
-                        llama_format_win_err(GetLastError()).c_str());
-                return false;
-            }
-        }
-    }
-
-    static void raw_unlock(void * ptr, size_t len) {
-        if (!VirtualUnlock(ptr, len)) {
-            LLAMA_LOG_WARN("warning: failed to VirtualUnlock buffer: %s\n",
-                    llama_format_win_err(GetLastError()).c_str());
-        }
-    }
-#else
-    static constexpr bool SUPPORTED = false;
-
-    static size_t lock_granularity() {
-        return (size_t) 65536;
-    }
-
-    bool raw_lock(const void * addr, size_t len) const {
-        LLAMA_LOG_WARN("warning: mlock not supported on this system\n");
-        return false;
-    }
-
-    static void raw_unlock(const void * addr, size_t len) {}
-#endif
-};
-using llama_mlocks = std::vector>;
-
-// NOTE: avoid ever using this except for building the token_to_piece caches
-static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
-    std::string piece;
-    piece.resize(piece.capacity());  // using string internal cache
-    const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
-    if (n_chars < 0) {
-        piece.resize(-n_chars);
-        int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special);
-        GGML_ASSERT(check == -n_chars);
-    }
-    else {
-        piece.resize(n_chars);
-    }
-
-    return piece;
-}
-
-static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer) {
-    ggml_backend_buffer_type_t buft = nullptr;
-
-#if defined(GGML_USE_CUDA)
-    // host buffers should only be used when data is expected to be copied to/from the GPU
-    if (host_buffer) {
-        buft = ggml_backend_cuda_host_buffer_type();
-    }
-#elif defined(GGML_USE_SYCL)
-    if (host_buffer) {
-        buft = ggml_backend_sycl_host_buffer_type();
-    }
-#elif defined(GGML_USE_CPU_HBM)
-    buft = ggml_backend_cpu_hbm_buffer_type();
-#elif defined(GGML_USE_VULKAN)
-    if (host_buffer) {
-        buft = ggml_backend_vk_host_buffer_type();
-    }
-#endif
-
-    if (buft == nullptr) {
-        buft = ggml_backend_cpu_buffer_type();
-    }
-    return buft;
-
-    GGML_UNUSED(host_buffer);
-}
-
-//
-// globals
-//
-
-struct llama_state {
-    llama_state() {
-#ifdef GGML_USE_METAL
-        ggml_backend_metal_log_set_callback(log_callback, log_callback_user_data);
-#elif defined(GGML_USE_CUDA)
-        ggml_backend_cuda_log_set_callback(log_callback, log_callback_user_data);
-#elif defined(GGML_USE_CANN)
-        ggml_backend_cann_log_set_callback(log_callback, log_callback_user_data);
-#endif
-    }
-
-    // We save the log callback globally
-    ggml_log_callback log_callback = llama_log_callback_default;
-    void * log_callback_user_data = nullptr;
-};
-
-static llama_state g_state;
-
-// available llama models
-enum e_model {
-    MODEL_UNKNOWN,
-    MODEL_14M,
-    MODEL_17M,
-    MODEL_22M,
-    MODEL_33M,
-    MODEL_60M,
-    MODEL_70M,
-    MODEL_80M,
-    MODEL_109M,
-    MODEL_137M,
-    MODEL_160M,
-    MODEL_220M,
-    MODEL_250M,
-    MODEL_270M,
-    MODEL_335M,
-    MODEL_410M,
-    MODEL_450M,
-    MODEL_770M,
-    MODEL_780M,
-    MODEL_0_5B,
-    MODEL_1B,
-    MODEL_1_3B,
-    MODEL_1_4B,
-    MODEL_1_6B,
-    MODEL_2B,
-    MODEL_2_8B,
-    MODEL_3B,
-    MODEL_4B,
-    MODEL_6B,
-    MODEL_6_9B,
-    MODEL_7B,
-    MODEL_8B,
-    MODEL_9B,
-    MODEL_11B,
-    MODEL_12B,
-    MODEL_13B,
-    MODEL_14B,
-    MODEL_15B,
-    MODEL_16B,
-    MODEL_20B,
-    MODEL_30B,
-    MODEL_34B,
-    MODEL_35B,
-    MODEL_40B,
-    MODEL_65B,
-    MODEL_70B,
-    MODEL_236B,
-    MODEL_314B,
-    MODEL_SMALL,
-    MODEL_MEDIUM,
-    MODEL_LARGE,
-    MODEL_XL,
-    MODEL_A2_7B,
-    MODEL_8x7B,
-    MODEL_8x22B,
-    MODEL_16x12B,
-    MODEL_10B_128x3_66B,
-    MODEL_57B_A14B,
-    MODEL_27B,
-};
-
-static const size_t kiB = 1024;
-static const size_t MiB = 1024*kiB;
-static const size_t GiB = 1024*MiB;
-
-struct llama_hparams {
-    bool vocab_only;
-    bool rope_finetuned;
-    bool use_par_res;
-
-    uint32_t n_vocab;
-    uint32_t n_ctx_train; // context size the model was trained on
-    uint32_t n_embd;
-    uint32_t n_layer;
-    uint32_t n_rot;
-    uint32_t n_swa = 0; // sliding window attention (SWA)
-    uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
-    uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
-    uint32_t n_expert = 0;
-    uint32_t n_expert_used = 0;
-    uint32_t n_vocab_type = 0; // for BERT-style token types
-    uint32_t n_rel_attn_bkts = 0;
-
-    std::array n_head_arr;
-    std::array n_head_kv_arr;
-    std::array n_ff_arr;
-
-    uint32_t n_layer_dense_lead = 0;
-    uint32_t n_lora_q = 0;
-    uint32_t n_lora_kv = 0;
-    uint32_t n_ff_exp = 0;
-    uint32_t n_ff_shexp = 0;
-    uint32_t n_expert_shared = 0;
-    float    expert_weights_scale = 0.0;
-
-    float f_norm_eps;
-    float f_norm_rms_eps;
-
-    float f_attn_logit_softcapping = 50.0f;
-    float f_final_logit_softcapping = 30.0f;
-
-    // for RWKV
-    uint32_t rescale_every_n_layers = 0;
-    uint32_t time_mix_extra_dim = 0;
-    uint32_t time_decay_extra_dim = 0;
-    uint32_t wkv_head_size = 0;
-
-    float    rope_attn_factor = 1.0f;
-    float    rope_freq_base_train;
-    float    rope_freq_scale_train;
-    uint32_t n_ctx_orig_yarn;
-    float    rope_yarn_log_mul;
-
-    // for State Space Models
-    uint32_t ssm_d_conv  = 0;
-    uint32_t ssm_d_inner = 0;
-    uint32_t ssm_d_state = 0;
-    uint32_t ssm_dt_rank = 0;
-    bool ssm_dt_b_c_rms = false;
-
-    float f_clamp_kqv      = 0.0f;
-    float f_max_alibi_bias = 0.0f;
-    float f_logit_scale    = 0.0f;
-
-    bool causal_attn   = true;
-    bool use_alibi     = false;
-    bool attn_soft_cap = false;
-
-    // needed by encoder-decoder models (e.g. T5, FLAN-T5)
-    // ref: https://github.com/ggerganov/llama.cpp/pull/8141
-    llama_token dec_start_token_id = -1;
-
-    enum llama_pooling_type      pooling_type            = LLAMA_POOLING_TYPE_NONE;
-    enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
-    enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
-
-    bool operator!=(const llama_hparams & other) const {
-        if (this->vocab_only    != other.vocab_only)    return true;
-        if (this->n_vocab       != other.n_vocab)       return true;
-        if (this->n_ctx_train   != other.n_ctx_train)   return true;
-        if (this->n_embd        != other.n_embd)        return true;
-        if (this->n_layer       != other.n_layer)       return true;
-        if (this->n_rot         != other.n_rot)         return true;
-        if (this->n_swa         != other.n_swa)         return true;
-        if (this->n_embd_head_k != other.n_embd_head_k) return true;
-        if (this->n_embd_head_v != other.n_embd_head_v) return true;
-        if (this->n_expert      != other.n_expert)      return true;
-        if (this->n_expert_used != other.n_expert_used) return true;
-
-        if (this->n_head_arr    != other.n_head_arr)    return true;
-        if (this->n_head_kv_arr != other.n_head_kv_arr) return true;
-        if (this->n_ff_arr      != other.n_ff_arr)      return true;
-
-        if (this->n_rel_attn_bkts    != other.n_rel_attn_bkts)    return true;
-        if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true;
-        if (this->n_lora_q           != other.n_lora_q)           return true;
-        if (this->n_lora_kv          != other.n_lora_kv)          return true;
-        if (this->n_ff_exp           != other.n_ff_exp)           return true;
-        if (this->n_ff_shexp         != other.n_ff_shexp)         return true;
-        if (this->n_expert_shared    != other.n_expert_shared)    return true;
-
-        if (this->rope_finetuned  != other.rope_finetuned)  return true;
-        if (this->n_ctx_orig_yarn != other.n_ctx_orig_yarn) return true;
-
-        if (this->ssm_d_conv  != other.ssm_d_conv)  return true;
-        if (this->ssm_d_inner != other.ssm_d_inner) return true;
-        if (this->ssm_d_state != other.ssm_d_state) return true;
-        if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
-        if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;
-
-        if (this->rescale_every_n_layers != other.rescale_every_n_layers) return true;
-        if (this->time_mix_extra_dim     != other.time_mix_extra_dim)     return true;
-        if (this->time_decay_extra_dim   != other.time_decay_extra_dim)   return true;
-        if (this->wkv_head_size          != other.wkv_head_size)          return true;
-
-        if (this->dec_start_token_id != other.dec_start_token_id) return true;
-
-        const float EPSILON = 1e-9f;
-
-        if (!is_float_close(this->f_norm_eps,            other.f_norm_eps,            EPSILON)) return true;
-        if (!is_float_close(this->f_norm_rms_eps,        other.f_norm_rms_eps,        EPSILON)) return true;
-        if (!is_float_close(this->rope_attn_factor,      other.rope_attn_factor,      EPSILON)) return true;
-        if (!is_float_close(this->rope_freq_base_train,  other.rope_freq_base_train,  EPSILON)) return true;
-        if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true;
-        if (!is_float_close(this->expert_weights_scale,  other.expert_weights_scale,  EPSILON)) return true;
-        if (!is_float_close(this->rope_yarn_log_mul,     other.rope_yarn_log_mul,     EPSILON)) return true;
-
-        return false;
-    }
-
-    uint32_t n_head(uint32_t il = 0) const {
-        if (il < n_layer) {
-            return n_head_arr[il];
-        }
-
-        GGML_ABORT("fatal error");
-    }
-
-    uint32_t n_head_kv(uint32_t il = 0) const {
-        if (il < n_layer) {
-            return n_head_kv_arr[il];
-        }
-
-        GGML_ABORT("fatal error");
-    }
-
-    uint32_t n_ff(uint32_t il = 0) const {
-        if (il < n_layer) {
-            return n_ff_arr[il];
-        }
-
-        GGML_ABORT("fatal error");
-    }
-
-    uint32_t n_gqa(uint32_t il = 0) const {
-        const uint32_t n_head    = this->n_head(il);
-        const uint32_t n_head_kv = this->n_head_kv(il);
-
-        if (n_head_kv == 0) {
-            return 0;
-        }
-
-        return n_head/n_head_kv;
-    }
-
-    uint32_t n_embd_k_gqa(uint32_t il = 0) const { // dimension of key embeddings across all k-v heads
-        const uint32_t n_head_kv = this->n_head_kv(il);
-
-        return n_embd_head_k * n_head_kv;
-    }
-
-    uint32_t n_embd_v_gqa(uint32_t il = 0) const { // dimension of value embeddings across all k-v heads
-        const uint32_t n_head_kv = this->n_head_kv(il);
-
-        return n_embd_head_v * n_head_kv;
-    }
-
-    uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings
-        // corresponds to Mamba's conv_states size or RWKV's token_shift states size
-        if (wkv_head_size != 0) {
-            // for RWKV models
-            return 2 * n_embd;
-        } else {
-            // TODO: maybe support other convolution strides than 1
-            // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-            return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
-        }
-    }
-
-    uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings
-        if (wkv_head_size != 0) {
-            // corresponds to RWKV's wkv_states size
-            return n_embd * wkv_head_size;
-        } else {
-            // corresponds to Mamba's ssm_states size
-            return ssm_d_state * ssm_d_inner;
-        }
-    }
-};
-
-static_assert(std::is_trivially_copyable::value, "llama_hparams must be trivially copyable");
-
-struct llama_cparams {
-    uint32_t n_ctx;           // context size used during inference
-    uint32_t n_batch;
-    uint32_t n_ubatch;
-    uint32_t n_seq_max;
-    int      n_threads;       // number of threads to use for generation
-    int      n_threads_batch; // number of threads to use for batch processing
-
-    float rope_freq_base;
-    float rope_freq_scale;
-
-    uint32_t n_ctx_orig_yarn;
-    // These hyperparameters are not exposed in GGUF, because all
-    // existing YaRN models use the same values for them.
-    float yarn_ext_factor;
-    float yarn_attn_factor;
-    float yarn_beta_fast;
-    float yarn_beta_slow;
-    float defrag_thold;
-
-    bool embeddings;
-    bool causal_attn;
-    bool offload_kqv;
-    bool flash_attn;
-
-    enum llama_pooling_type pooling_type;
-
-    ggml_backend_sched_eval_callback cb_eval;
-    void * cb_eval_user_data;
-};
-
-// TODO: separate into "llama_layer_enc" and "llama_layer_dec"
-struct llama_layer {
-    // normalization
-    struct ggml_tensor * attn_norm;
-    struct ggml_tensor * attn_norm_b;
-    struct ggml_tensor * attn_norm_2;
-    struct ggml_tensor * attn_norm_2_b;
-    struct ggml_tensor * attn_q_norm;
-    struct ggml_tensor * attn_q_norm_b;
-    struct ggml_tensor * attn_k_norm;
-    struct ggml_tensor * attn_k_norm_b;
-    struct ggml_tensor * attn_out_norm;
-    struct ggml_tensor * attn_out_norm_b;
-    struct ggml_tensor * attn_q_a_norm;
-    struct ggml_tensor * attn_kv_a_norm;
-    struct ggml_tensor * attn_sub_norm;
-    struct ggml_tensor * attn_post_norm;
-    struct ggml_tensor * ffn_sub_norm;
-    struct ggml_tensor * attn_norm_cross;
-    struct ggml_tensor * attn_norm_enc;
-
-    // attention
-    struct ggml_tensor * wq;
-    struct ggml_tensor * wk;
-    struct ggml_tensor * wv;
-    struct ggml_tensor * wo;
-    struct ggml_tensor * wqkv;
-    struct ggml_tensor * wq_a;
-    struct ggml_tensor * wq_b;
-    struct ggml_tensor * wkv_a_mqa;
-    struct ggml_tensor * wkv_b;
-    struct ggml_tensor * wq_cross;
-    struct ggml_tensor * wk_cross;
-    struct ggml_tensor * wv_cross;
-    struct ggml_tensor * wo_cross;
-    struct ggml_tensor * wq_enc;
-    struct ggml_tensor * wk_enc;
-    struct ggml_tensor * wv_enc;
-    struct ggml_tensor * wo_enc;
-
-    // attention bias
-    struct ggml_tensor * bq;
-    struct ggml_tensor * bk;
-    struct ggml_tensor * bv;
-    struct ggml_tensor * bo;
-    struct ggml_tensor * bqkv;
-
-    // relative position bias
-    struct ggml_tensor * attn_rel_b;
-    struct ggml_tensor * attn_rel_b_enc;
-    struct ggml_tensor * attn_rel_b_cross;
-
-    // normalization
-    struct ggml_tensor * ffn_norm;
-    struct ggml_tensor * ffn_norm_b;
-    struct ggml_tensor * ffn_post_norm;
-    struct ggml_tensor * layer_out_norm;
-    struct ggml_tensor * layer_out_norm_b;
-    struct ggml_tensor * ffn_norm_exps;
-    struct ggml_tensor * ffn_norm_enc;
-
-    // ff
-    struct ggml_tensor * ffn_gate; // w1
-    struct ggml_tensor * ffn_down; // w2
-    struct ggml_tensor * ffn_up;   // w3
-    struct ggml_tensor * ffn_gate_enc;
-    struct ggml_tensor * ffn_down_enc;
-    struct ggml_tensor * ffn_up_enc;
-
-    // ff MoE
-    struct ggml_tensor * ffn_gate_inp;
-    struct ggml_tensor * ffn_gate_exps;
-    struct ggml_tensor * ffn_down_exps;
-    struct ggml_tensor * ffn_up_exps ;
-
-    // ff shared expert (shexp)
-    struct ggml_tensor * ffn_gate_inp_shexp;
-    struct ggml_tensor * ffn_gate_shexp;
-    struct ggml_tensor * ffn_down_shexp;
-    struct ggml_tensor * ffn_up_shexp;
-
-    // ff bias
-    struct ggml_tensor * ffn_gate_b = nullptr;
-    struct ggml_tensor * ffn_down_b = nullptr; // b2
-    struct ggml_tensor * ffn_up_b   = nullptr; // b3
-    struct ggml_tensor * ffn_act;
-
-    // mamba proj
-    struct ggml_tensor * ssm_in;
-    struct ggml_tensor * ssm_x;
-    struct ggml_tensor * ssm_dt;
-    struct ggml_tensor * ssm_out;
-
-    // mamba
-    struct ggml_tensor * ssm_conv1d;
-    struct ggml_tensor * ssm_a;
-    struct ggml_tensor * ssm_d;
-
-    // mamba bias
-    struct ggml_tensor * ssm_conv1d_b;
-    struct ggml_tensor * ssm_dt_b;
-
-    // rwkv
-    struct ggml_tensor * time_mix_w1;
-    struct ggml_tensor * time_mix_w2;
-    struct ggml_tensor * time_mix_lerp_x;
-    struct ggml_tensor * time_mix_lerp_w;
-    struct ggml_tensor * time_mix_lerp_k;
-    struct ggml_tensor * time_mix_lerp_v;
-    struct ggml_tensor * time_mix_lerp_r;
-    struct ggml_tensor * time_mix_lerp_g;
-
-    struct ggml_tensor * time_mix_first;
-    struct ggml_tensor * time_mix_decay;
-    struct ggml_tensor * time_mix_decay_w1;
-    struct ggml_tensor * time_mix_decay_w2;
-    struct ggml_tensor * time_mix_key;
-    struct ggml_tensor * time_mix_value;
-    struct ggml_tensor * time_mix_receptance;
-    struct ggml_tensor * time_mix_gate;
-
-    struct ggml_tensor * time_mix_ln;
-    struct ggml_tensor * time_mix_ln_b;
-    struct ggml_tensor * time_mix_output;
-
-    struct ggml_tensor * channel_mix_lerp_k;
-    struct ggml_tensor * channel_mix_lerp_r;
-
-    struct ggml_tensor * channel_mix_key;
-    struct ggml_tensor * channel_mix_receptance;
-    struct ggml_tensor * channel_mix_value;
-
-    // long rope factors
-    struct ggml_tensor * rope_long  = nullptr;
-    struct ggml_tensor * rope_short = nullptr;
-    struct ggml_tensor * rope_freqs = nullptr;
-
-    // bitnet scale
-    struct ggml_tensor * wq_scale;
-    struct ggml_tensor * wk_scale;
-    struct ggml_tensor * wv_scale;
-    struct ggml_tensor * wo_scale;
-    struct ggml_tensor * ffn_gate_scale;
-    struct ggml_tensor * ffn_up_scale;
-    struct ggml_tensor * ffn_down_scale;
-};
-
-// very similar to llama_batch,
-// but has more metadata about sequences
-struct llama_ubatch {
-    bool equal_seqs;
-    // TODO: whole_seqs for embeddings?
-
-    uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
-    uint32_t n_seq_tokens; // tokens per sequence
-    uint32_t n_seqs;
-
-    llama_token  *  token;    // [n_tokens]
-    float        *  embd;     // [n_embd, n_tokens]
-    llama_pos    *  pos;      // [n_tokens]
-    int32_t      *  n_seq_id; // [n_seqs]
-    llama_seq_id ** seq_id;   // [n_seqs]
-    int8_t       *  output;   // [n_tokens]
-};
-
-struct llama_kv_cell {
-    llama_pos pos   = -1;
-    llama_pos delta = 0;
-    int32_t   src   = -1; // used by recurrent state models to copy states
-    int32_t   tail  = -1;
-
-    std::set seq_id;
-
-    bool has_seq_id(const llama_seq_id & id) const {
-        return seq_id.find(id) != seq_id.end();
-    }
-
-    bool is_empty() const {
-        return seq_id.empty();
-    }
-
-    bool is_same_seq(const llama_kv_cell & other) const {
-        return seq_id == other.seq_id;
-    }
-};
-
-// ring-buffer of cached KV data
-struct llama_kv_cache {
-    bool has_shift = false;
-    bool do_defrag = false;
-    bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
-    bool v_trans   = true;  // the value tensor is transposed
-
-    // Note: The value of head isn't only used to optimize searching
-    // for a free KV slot. llama_decode_internal also uses it, so it
-    // cannot be freely changed after a slot has been allocated.
-    uint32_t head = 0;
-    uint32_t size = 0;
-    uint32_t used = 0; // used cells (i.e. at least one seq_id)
-
-    // computed before each graph build
-    uint32_t n = 0;
-
-    ggml_type type_k = GGML_TYPE_F16;
-    ggml_type type_v = GGML_TYPE_F16;
-
-    std::vector cells;
-
-    std::vector k_l; // per layer
-    std::vector v_l;
-
-    std::vector ctxs;
-    std::vector bufs;
-
-    size_t total_size() const {
-        size_t size = 0;
-        for (ggml_backend_buffer_t buf : bufs) {
-            size += ggml_backend_buffer_get_size(buf);
-        }
-        return size;
-    }
-
-    ~llama_kv_cache() {
-        for (struct ggml_context * ctx : ctxs) {
-            ggml_free(ctx);
-        }
-        for (ggml_backend_buffer_t buf : bufs) {
-            ggml_backend_buffer_free(buf);
-        }
-    }
-};
-
-struct llama_control_vector {
-    std::vector tensors; // per layer
-    std::vector ctxs;
-    std::vector bufs;
-
-    int32_t layer_start = -1;
-    int32_t layer_end   = -1;
-
-    struct ggml_tensor * tensor_for(int il) const {
-        if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
-            return nullptr;
-        }
-        return tensors[il];
-    }
-
-    struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int  il) const {
-        ggml_tensor * layer_dir = tensor_for(il);
-        if (layer_dir != nullptr) {
-            cur = ggml_add(ctx, cur, layer_dir);
-        }
-        return cur;
-    }
-
-    ~llama_control_vector() {
-        for (struct ggml_context * ctx : ctxs) {
-            ggml_free(ctx);
-        }
-        for (ggml_backend_buffer_t buf : bufs) {
-            ggml_backend_buffer_free(buf);
-        }
-    }
-};
-
-struct llama_model {
-    e_model     type  = MODEL_UNKNOWN;
-    llm_arch    arch  = LLM_ARCH_UNKNOWN;
-    llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
-
-    std::string name = "n/a";
-
-    llama_hparams hparams = {};
-    llama_vocab   vocab;
-
-    struct ggml_tensor * tok_embd;
-    struct ggml_tensor * type_embd;
-    struct ggml_tensor * pos_embd;
-    struct ggml_tensor * tok_norm;
-    struct ggml_tensor * tok_norm_b;
-
-    struct ggml_tensor * output_norm;
-    struct ggml_tensor * output_norm_b;
-    struct ggml_tensor * output;
-    struct ggml_tensor * output_b;
-    struct ggml_tensor * output_norm_enc;
-
-    std::vector layers;
-
-    llama_split_mode split_mode;
-    int main_gpu;
-    int n_gpu_layers;
-
-    std::vector rpc_servers;
-
-    // gguf metadata
-    std::unordered_map gguf_kv;
-
-    // layer -> buffer type mapping
-    struct layer_buft {
-        layer_buft() : buft_matrix(nullptr), buft(nullptr) {}
-        layer_buft(ggml_backend_buffer_type_t matrix) : buft_matrix(matrix), buft(matrix) {}
-        layer_buft(ggml_backend_buffer_type_t matrix, ggml_backend_buffer_type_t other) : buft_matrix(matrix), buft(other) {}
-
-        ggml_backend_buffer_type_t buft_matrix; // matrices only - used by split buffers and backends that support only matrix multiplication
-        ggml_backend_buffer_type_t buft;        // everything else
-    };
-
-    layer_buft buft_input;
-    layer_buft buft_output;
-    std::vector buft_layer;
-
-    // contexts where the model tensors metadata is stored
-    std::vector ctxs;
-
-    // the model memory buffers for the tensor data
-    std::vector bufs;
-
-    // model memory mapped files
-    llama_mmaps mappings;
-
-    // objects representing data potentially being locked in memory
-    llama_mlocks mlock_bufs;
-    llama_mlocks mlock_mmaps;
-
-    // for quantize-stats only
-    std::vector> tensors_by_name;
-
-    int64_t t_load_us = 0;
-    int64_t t_start_us = 0;
-
-    // keep track of loaded lora adapters
-    std::set lora_adapters;
-
-    ~llama_model() {
-        for (struct ggml_context * ctx : ctxs) {
-            ggml_free(ctx);
-        }
-        for (ggml_backend_buffer_t buf : bufs) {
-#ifdef GGML_USE_CUDA
-            if (ggml_backend_buffer_get_type(buf) == ggml_backend_cpu_buffer_type()) {
-                ggml_backend_cuda_unregister_host_buffer(ggml_backend_buffer_get_base(buf));
-            }
-#endif
-            ggml_backend_buffer_free(buf);
-        }
-        while (!lora_adapters.empty()) {
-            llama_lora_adapter_free(*lora_adapters.begin());
-        }
-    }
-};
-
-struct llama_sbatch_seq {
-    int32_t n_seq_id;
-    llama_seq_id * seq_id;
-    size_t offset;
-    size_t length;
-
-    // helper for smoother batch API transition -- can be deprecated in the future
-    llama_seq_id all_seq_id; // used if seq_id == NULL
-};
-
-// sequence-length-aware batch splitting
-struct llama_sbatch {
-    // tokens left in this batch
-    size_t n_tokens;
-
-    size_t n_embd;
-
-    bool logits_all; // TODO: remove once lctx.logits_all is removed too
-
-    // sorted indices into the batch
-    std::vector ids;
-    // batch indices of the output
-    std::vector out_ids;
-    std::vector seq;
-    const llama_batch * batch = nullptr;
-
-    // buffers for the ubatch
-    std::vector    ubatch_token;
-    std::vector          ubatch_embd;
-    std::vector      ubatch_pos;
-    std::vector        ubatch_n_seq_id;
-    std::vector ubatch_seq_id;
-    std::vector         ubatch_output;
-
-    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false) {
-        // clear empty sequences
-        // the previous ubatch is assumed to be gone,
-        // so nothing should refer to values in these sequences anymore.
-        for (size_t i = seq.size(); i-- > 0;) {
-            if (seq[i].length == 0) {
-                seq.pop_back();
-            } else {
-                break;
-            }
-        }
-        ubatch_token.resize(!has_embd ? n_ubatch : 0);
-        ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
-        ubatch_pos.resize(n_ubatch);
-        ubatch_n_seq_id.resize(n_ubatch);
-        ubatch_seq_id.resize(n_ubatch);
-        ubatch_output.resize(n_ubatch);
-        llama_ubatch ubatch = {
-            /*equal_seqs   =*/ true,
-            /*n_tokens     =*/ 0,
-            /*n_seq_tokens =*/ 0,
-            /*n_seqs       =*/ 0,
-            /*token        =*/ !has_embd ? ubatch_token.data() : nullptr,
-            /*embd         =*/ has_embd  ? ubatch_embd.data()  : nullptr,
-            /*pos          =*/ ubatch_pos.data(),
-            /*n_seq_id     =*/ ubatch_n_seq_id.data(),
-            /*seq_id       =*/ ubatch_seq_id.data(),
-            /*output       =*/ ubatch_output.data(),
-        };
-        return ubatch;
-    }
-
-    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
-        GGML_ASSERT(batch != nullptr);
-        GGML_ASSERT(length <= seq.length);
-        // Can only add sequences of equal lengths to a batch,
-        // otherwise it isn't clear to which sequence a token belongs
-        GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
-        GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
-        // NOTE: loops are separated for cache-friendliness
-        if (batch->token) {
-            if (ubatch.equal_seqs) {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
-                }
-            } else {
-                // simple split
-                ubatch.token = batch->token + seq.offset;
-            }
-        } else {
-            ubatch.token = nullptr;
-        }
-        if (batch->embd) {
-            if (ubatch.equal_seqs) {
-                for (size_t i = 0; i < length; ++i) {
-                    memcpy(
-                        ubatch.embd + n_embd * (ubatch.n_tokens + i),
-                        batch->embd + n_embd * ids[seq.offset + i],
-                        n_embd * sizeof(float)
-                    );
-                }
-            } else {
-                // simple split
-                ubatch.embd = batch->embd + (n_embd * seq.offset);
-            }
-        } else {
-            ubatch.embd = nullptr;
-        }
-        // from here on, the else branches are deprecated;
-        // they are helpers for smoother batch API transition
-        if (batch->pos) {
-            if (ubatch.equal_seqs) {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
-                }
-            } else {
-                // simple split
-                ubatch.pos = batch->pos + seq.offset;
-            }
-        } else {
-            for (size_t i = 0; i < length; ++i) {
-                llama_pos bi = ids[seq.offset + i];
-                ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
-            }
-        }
-        if (ubatch.equal_seqs) {
-            ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
-            if (seq.seq_id) {
-                ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
-            } else {
-                GGML_ASSERT(seq.n_seq_id == 1);
-                ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
-            }
-        } else {
-            // simple split
-            if (batch->n_seq_id) {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.n_seq_id = batch->n_seq_id + seq.offset;
-                }
-            } else {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
-                }
-            }
-            if (batch->seq_id) {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.seq_id = batch->seq_id + seq.offset;
-                }
-            } else {
-                for (size_t i = 0; i < length; ++i) {
-                    ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id;
-                }
-            }
-        }
-        if (logits_all) {
-            for (size_t i = 0; i < length; ++i) {
-                ubatch.output[ubatch.n_tokens + i] = 1;
-                out_ids.push_back(ids[seq.offset + i]);
-            }
-        } else if (batch->logits) {
-            if (ubatch.equal_seqs) {
-                for (size_t i = 0; i < length; ++i) {
-                    size_t id = ids[seq.offset + i];
-                    int8_t is_output = batch->logits[id];
-                    ubatch.output[ubatch.n_tokens + i] = is_output;
-                    if (is_output) { out_ids.push_back(id); }
-                }
-            } else {
-                // simple split
-                ubatch.output = batch->logits + seq.offset;
-                for (size_t i = 0; i < length; ++i) {
-                    if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
-                }
-            }
-        } else {
-            // only get last output
-            for (size_t i = 0; i < length; ++i) {
-                size_t id = ids[seq.offset + i];
-                int8_t is_last = id == ids.size() - 1;
-                ubatch.output[ubatch.n_tokens + i] = is_last;
-                if (is_last) { out_ids.push_back(id); }
-            }
-        }
-        if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
-            ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
-        }
-        ubatch.n_tokens += length;
-        ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
-        seq.offset += length;
-        seq.length -= length;
-        n_tokens -= length;
-        GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
-    }
-
-    // simple split, unknown number of sequences of unequal lengths
-    llama_ubatch split_simple(size_t n_ubatch) {
-        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
-        ubatch.equal_seqs = false;
-        if (!seq.empty()) {
-            llama_sbatch_seq & s = seq[0];
-            size_t length = s.length < n_ubatch ? s.length : n_ubatch;
-            GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
-            add_seq_to_ubatch(ubatch, s, length);
-        }
-        return ubatch;
-    }
-
-    // make batches of equal-length sequences
-    llama_ubatch split_equal(size_t n_ubatch) {
-        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
-        if (!seq.empty()) {
-            size_t length = 0;
-            size_t n_tokens_in_ubatch = 0;
-            GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
-            // smallest first, because it's easier to split this way;
-            // starting from the end to pop in constant time.
-            for (size_t i = seq.size(); i-- > 0;) {
-                llama_sbatch_seq & s = seq[i];
-                GGML_ASSERT(s.length > 0);
-                if (length == 0) {
-                    length = s.length < n_ubatch ? s.length : n_ubatch;
-                }
-                add_seq_to_ubatch(ubatch, s, length);
-                n_tokens_in_ubatch += length;
-                // shared prompts can't be mixed with any of their sequences,
-                // so it's safer to compute them in their own ubatch
-                if (s.n_seq_id > 1) { break; }
-                // stop when there isn't enough space for another sequence
-                if (length + n_tokens_in_ubatch > n_ubatch) { break; }
-            }
-        }
-        return ubatch;
-    }
-
-    // sequence-wise split
-    llama_ubatch split_seq(size_t n_ubatch) {
-        n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-        llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
-        if (!seq.empty()) {
-            llama_sbatch_seq & s = seq[seq.size() - 1];
-            size_t length = s.length < n_ubatch ? s.length : n_ubatch;
-            GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
-            add_seq_to_ubatch(ubatch, s, length);
-        }
-        return ubatch;
-    }
-
-    void from_batch(const llama_batch & batch, const size_t n_embd, const bool simple_split = false, const bool logits_all = false) {
-        GGML_ASSERT(batch.n_tokens >= 0);
-        this->batch = &batch;
-        this->n_embd = n_embd;
-        this->logits_all = logits_all;
-
-        n_tokens = batch.n_tokens;
-        ids.resize(n_tokens);
-        out_ids.clear();
-        // TODO: reserve out_ids and seq
-
-        for (size_t i = 0; i < n_tokens; ++i) {
-            ids[i] = i;
-        }
-        if (simple_split) {
-            seq.resize(1);
-            llama_sbatch_seq & s = seq[0];
-            s.n_seq_id = 0;
-            s.seq_id = nullptr;
-            s.offset = 0;
-            s.length = n_tokens;
-            s.all_seq_id = batch.all_seq_id;
-            return;
-        }
-        std::sort(ids.begin(), ids.end(),
-            [&batch](size_t a, size_t b) {
-                int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
-                int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
-                // sort by seq_id, then by pos
-                if (n_seq_a == n_seq_b) {
-                    if (batch.seq_id) {
-                        for (int32_t i = 0; i < n_seq_a; ++i) {
-                            llama_seq_id seq_id_a = batch.seq_id[a][i];
-                            llama_seq_id seq_id_b = batch.seq_id[b][i];
-                            // smaller seq_ids go first
-                            if (seq_id_a != seq_id_b) {
-                                return seq_id_a < seq_id_b;
-                            }
-                        }
-                    }
-                    // when all else is equal, sort by pos
-                    if (batch.pos) {
-                        return batch.pos[a] < batch.pos[b];
-                    }
-                    // no pos, sort by id (assuming batch.all_pos_1 is positive)
-                    return a < b;
-                }
-                // shared prompts go first
-                return n_seq_a > n_seq_b;
-            }
-        );
-        // init seq
-        llama_sbatch_seq * last_seq = nullptr;
-
-        if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) {
-            for (size_t i = 0; i < n_tokens; ++i) {
-                const size_t bi = ids[i];
-                const int32_t n_seqs = batch.n_seq_id[bi];
-                llama_seq_id * seq_ids = batch.seq_id[bi];
-                if (last_seq != nullptr) {
-                    bool same = n_seqs == last_seq->n_seq_id;
-                    for (int32_t j = 0; same && j < n_seqs; ++j) {
-                        if (seq_ids[j] != last_seq->seq_id[j]) {
-                            same = false;
-                        }
-                    }
-                    if (same) {
-                        last_seq->length += 1;
-                        continue;
-                    }
-                }
-                llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id};
-                seq.push_back(new_seq);
-                last_seq = &seq.back();
-            }
-        } else {
-            llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
-            seq.push_back(new_seq);
-        }
-        // keep shared prompts first at the end, then sort by length descending.
-        std::sort(seq.begin(), seq.end(),
-            [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
-                if (a.n_seq_id == b.n_seq_id) {
-                    return a.length > b.length;
-                }
-                return a.n_seq_id < b.n_seq_id;
-            }
-        );
-    }
-};
-
-struct llama_context {
-    llama_context(const llama_model & model)
-        : model(model)
-        , t_start_us(model.t_start_us)
-        , t_load_us(model.t_load_us) {}
-
-    ~llama_context() {
-        ggml_backend_sched_free(sched);
-
-        for (ggml_backend_t backend : backends) {
-            ggml_backend_free(backend);
-        }
-
-        ggml_backend_buffer_free(buf_output);
-    }
-
-    const struct llama_model & model;
-
-    struct llama_cparams        cparams;
-    struct llama_sbatch         sbatch;
-    struct llama_kv_cache       kv_self;
-    struct llama_control_vector cvec;
-
-    std::unordered_map lora_adapters;
-
-    std::vector backends;
-#ifdef GGML_USE_METAL
-    ggml_backend_t backend_metal = nullptr;
-#endif
-#ifdef GGML_USE_BLAS
-    ggml_backend_t backend_blas = nullptr;
-#endif
-    ggml_backend_t backend_cpu = nullptr;
-
-    ggml_threadpool_t threadpool       = nullptr;
-    ggml_threadpool_t threadpool_batch = nullptr;
-
-    bool has_evaluated_once = false;
-
-    mutable int64_t t_start_us;
-    mutable int64_t t_load_us;
-    mutable int64_t t_p_eval_us = 0;
-    mutable int64_t t_eval_us   = 0;
-
-    mutable int64_t t_compute_start_us = 0;
-    mutable int64_t n_queued_tokens = 0;
-
-    mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
-    mutable int32_t n_eval   = 0; // number of eval calls
-
-    // host buffer for the model output (logits and embeddings)
-    ggml_backend_buffer_t buf_output = nullptr;
-
-    // decode output (2-dimensional array: [n_outputs][n_vocab])
-    size_t  logits_size = 0; // capacity (of floats) for logits
-    float * logits      = nullptr;
-
-    std::vector output_ids; // map batch token positions to ids of the logits and embd buffers
-    size_t  output_size = 0; // capacity (of tokens positions) for the output buffers
-    int32_t n_outputs   = 0; // number of actually-used outputs in the current ubatch or last logical batch
-
-    bool logits_all = false;
-
-    // embeddings output (2-dimensional array: [n_outputs][n_embd])
-    // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
-    size_t  embd_size = 0; // capacity (of floats) for embeddings
-    float * embd      = nullptr;
-
-    // sequence embeddings output (map of [n_embd] vectors)
-    // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
-    std::map> embd_seq;
-
-    // whether we are computing encoder output or decoder output
-    bool is_encoding = false;
-
-    // output of the encoder part of the encoder-decoder models
-    std::vector embd_enc;
-    std::vector> seq_ids_enc;
-
-    // memory buffers used to evaluate the model
-    std::vector buf_compute_meta;
-    ggml_backend_sched_t sched = nullptr;
-
-    ggml_abort_callback abort_callback      = nullptr;
-    void *              abort_callback_data = nullptr;
-
-    // input tensors
-    struct ggml_tensor * inp_tokens;      // I32 [n_batch]
-    struct ggml_tensor * inp_embd;        // F32 [n_embd, n_batch]
-    struct ggml_tensor * inp_pos;         // I32 [n_batch]
-    struct ggml_tensor * inp_out_ids;     // I32 [n_outputs]
-    struct ggml_tensor * inp_KQ_mask;     // F32 [kv_size, n_batch]
-    struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch]
-    struct ggml_tensor * inp_K_shift;     // I32 [kv_size]
-    struct ggml_tensor * inp_mean;        // F32 [n_batch, n_batch]
-    struct ggml_tensor * inp_cls;         // I32 [n_batch]
-    struct ggml_tensor * inp_s_copy;      // I32 [kv_size]
-    struct ggml_tensor * inp_s_mask;      // F32 [1, n_kv]
-    struct ggml_tensor * inp_s_seq;       // I32 [n_kv, n_batch]
-    struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
-    struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
-    struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
-};
-
-struct llama_lora_weight {
-    struct ggml_tensor * a = nullptr;
-    struct ggml_tensor * b = nullptr;
-    llama_lora_weight() = default;
-    llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b): a(a), b(b) {}
-};
-
-struct llama_lora_adapter {
-    struct llama_model * base_model;
-    // map tensor name to lora_a_b
-    std::unordered_map ab_map;
-    std::vector ctxs;
-    std::vector bufs;
-
-    float alpha;
-
-    llama_lora_adapter(struct llama_model * base_model): base_model(base_model) {
-        base_model->lora_adapters.insert(this);
-    }
-
-    llama_lora_weight * get_weight(struct ggml_tensor * w) {
-        std::string name(w->name);
-        auto pos = ab_map.find(name);
-        if (ab_map.find(name) != ab_map.end()) {
-            return &pos->second;
-        }
-        return nullptr;
-    }
-
-    ~llama_lora_adapter() {
-        for (struct ggml_context * ctx : ctxs) {
-            ggml_free(ctx);
-        }
-        for (ggml_backend_buffer_t buf : bufs) {
-            ggml_backend_buffer_free(buf);
-        }
-        auto pos = base_model->lora_adapters.find(this);
-        if (pos != base_model->lora_adapters.end()) {
-            base_model->lora_adapters.erase(pos);
-        }
-    }
-};
-
-static size_t llama_get_device_count(const llama_model & model) {
-    size_t count = 1;
-#if defined(GGML_USE_CUDA)
-    count = ggml_backend_cuda_get_device_count();
-#elif defined(GGML_USE_SYCL)
-    count = ggml_backend_sycl_get_device_count();
-#elif defined(GGML_USE_VULKAN)
-    count = ggml_backend_vk_get_device_count();
-#elif defined(GGML_USE_CANN)
-    return ggml_backend_cann_get_device_count();
-#endif
-#if defined(GGML_USE_RPC)
-    count += model.rpc_servers.size();
-#endif
-    return count;
-    GGML_UNUSED(model);
-}
-
-static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_model & model, int gpu) {
-    ggml_backend_buffer_type_t buft = nullptr;
-
-#ifdef GGML_USE_RPC
-    int rpc_count = (int)model.rpc_servers.size();
-#else
-    int rpc_count = 0;
-#endif
-    int local_gpu = gpu - rpc_count;
-#if defined(GGML_USE_RPC)
-    if (gpu < rpc_count) {
-        const char * endpoint = model.rpc_servers[gpu].c_str();
-        return ggml_backend_rpc_buffer_type(endpoint);
-    }
-#endif
-#if defined(GGML_USE_METAL)
-    buft = ggml_backend_metal_buffer_type();
-#elif defined(GGML_USE_CUDA)
-    buft = ggml_backend_cuda_buffer_type(local_gpu);
-#elif defined(GGML_USE_VULKAN)
-    buft = ggml_backend_vk_buffer_type(local_gpu);
-#elif defined(GGML_USE_SYCL)
-    buft = ggml_backend_sycl_buffer_type(local_gpu);
-#elif defined(GGML_USE_KOMPUTE)
-    buft = ggml_backend_kompute_buffer_type(local_gpu);
-    if (buft == nullptr) {
-        LLAMA_LOG_WARN("%s: cannot use GPU %d, check `vulkaninfo --summary`\n", __func__, local_gpu);
-    }
-#elif defined(GGML_USE_CANN)
-    buft = ggml_backend_cann_buffer_type(local_gpu);
-#endif
-
-    if (buft == nullptr) {
-        buft = llama_default_buffer_type_cpu(true);
-    }
-    return buft;
-    GGML_UNUSED(model);
-    GGML_UNUSED(local_gpu);
-}
-
-static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_model & model, int fallback_gpu, const float * tensor_split) {
-    ggml_backend_buffer_type_t buft = nullptr;
-
-#ifdef GGML_USE_CUDA
-    if (ggml_backend_cuda_get_device_count() > 1) {
-        buft = ggml_backend_cuda_split_buffer_type(tensor_split);
-    }
-#endif
-
-#ifdef GGML_USE_SYCL
-    if (ggml_backend_sycl_get_device_count() > 1) {
-        buft = ggml_backend_sycl_split_buffer_type(tensor_split);
-    }
-#endif
-
-    if (buft == nullptr) {
-        buft = llama_default_buffer_type_offload(model, fallback_gpu);
-    }
-    return buft;
-
-    GGML_UNUSED(tensor_split);
-}
-
-static size_t llama_get_device_memory(const llama_model & model, int device) {
-#ifdef GGML_USE_RPC
-    int rpc_count = (int)model.rpc_servers.size();
-#else
-    int rpc_count = 0;
-#endif
-    int local_device = device - rpc_count;
-#if defined(GGML_USE_RPC)
-    if (device < rpc_count) {
-        size_t total;
-        size_t free;
-        const char * endpoint = model.rpc_servers[device].c_str();
-        ggml_backend_rpc_get_device_memory(endpoint, &free, &total);
-        return free;
-    }
-#endif
-#if defined(GGML_USE_CUDA)
-    size_t total;
-    size_t free;
-    ggml_backend_cuda_get_device_memory(local_device, &free, &total);
-    return free;
-#elif defined(GGML_USE_SYCL)
-    size_t total;
-    size_t free;
-    ggml_backend_sycl_get_device_memory(local_device, &free, &total);
-    return free;
-#elif defined(GGML_USE_VULKAN)
-    size_t total;
-    size_t free;
-    ggml_backend_vk_get_device_memory(local_device, &free, &total);
-    return free;
-#elif defined(GGML_USE_CANN)
-    size_t total;
-    size_t free;
-    ggml_backend_cann_get_device_memory(local_device, &free, &total);
-    return free;
-#else
-    return 1;
-#endif
-    GGML_UNUSED(model);
-    GGML_UNUSED(local_device);
-}
-
-//
-// kv cache helpers
-//
-
-static bool llama_kv_cache_init(
-             struct llama_kv_cache & cache,
-               const llama_context * ctx,
-                         ggml_type   type_k,
-                         ggml_type   type_v,
-                          uint32_t   kv_size,
-                              bool   offload) {
-    const llama_model & model = ctx->model;
-    const llama_cparams & cparams = ctx->cparams;
-
-    const struct llama_hparams & hparams = model.hparams;
-
-    const int64_t  n_layer = hparams.n_layer;
-
-    cache.has_shift = false;
-
-    cache.recurrent = llama_model_is_recurrent(&model);
-    cache.v_trans   = !cache.recurrent && !cparams.flash_attn;
-
-    cache.head = 0;
-    cache.size = kv_size;
-    cache.used = 0;
-
-    cache.type_k = type_k;
-    cache.type_v = type_v;
-
-    cache.cells.clear();
-    cache.cells.resize(kv_size);
-
-    // count used buffer types
-    std::map buft_layer_count;
-    if (offload) {
-        for (int64_t i = 0; i < n_layer; ++i) {
-            buft_layer_count[model.buft_layer[i].buft]++;
-        }
-    } else {
-        buft_layer_count[llama_default_buffer_type_cpu(true)] = n_layer;
-    }
-
-    // create a context for each buffer type
-    std::map ctx_map;
-    for (auto & it : buft_layer_count) {
-        int n_layers = it.second;
-        struct ggml_init_params params = {
-            /*.mem_size   =*/ 2u*n_layers*ggml_tensor_overhead(),
-            /*.mem_buffer =*/ NULL,
-            /*.no_alloc   =*/ true,
-        };
-        ggml_context * ctx = ggml_init(params);
-        if (!ctx) {
-            LLAMA_LOG_ERROR("%s: failed to allocate context for kv cache\n", __func__);
-            return false;
-        }
-        ctx_map[it.first] = ctx;
-        cache.ctxs.push_back(ctx);
-    }
-
-    cache.k_l.reserve(n_layer);
-    cache.v_l.reserve(n_layer);
-
-    for (int i = 0; i < (int) n_layer; i++) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
-
-        struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
-        ggml_format_name(k, "cache_k_l%d", i);
-        ggml_format_name(v, "cache_v_l%d", i);
-        cache.k_l.push_back(k);
-        cache.v_l.push_back(v);
-    }
-
-    // allocate tensors and initialize the buffers to avoid NaNs in the padding
-    for (auto it : ctx_map) {
-        ggml_backend_buffer_type_t buft = it.first;
-        ggml_context * ctx = it.second;
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
-        if (!buf) {
-            LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__);
-            return false;
-        }
-        ggml_backend_buffer_clear(buf, 0);
-        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
-        cache.bufs.push_back(buf);
-    }
-
-    return true;
-}
-
-// find an empty slot of size "n_tokens" in the cache
-// updates the cache head
-// Note: On success, it's important that cache.head points
-// to the first cell of the slot.
-static bool llama_kv_cache_find_slot(
-           struct llama_kv_cache & cache,
-       const struct llama_ubatch & batch) {
-    const uint32_t n_tokens = batch.n_tokens;
-    const uint32_t n_seqs   = batch.n_seqs;
-    const uint32_t n_seq_tokens = batch.n_seq_tokens;
-
-    if (cache.recurrent) {
-        // For recurrent state architectures (like Mamba or RWKV),
-        // each cache cell can store the state for a whole sequence.
-        // A slot should be always be contiguous.
-
-        // can only process batches with an equal number of new tokens in each sequence
-        GGML_ASSERT(batch.equal_seqs);
-
-        int32_t min = cache.size - 1;
-        int32_t max = 0;
-
-        // everything should fit if all seq_ids are smaller than the max
-        for (uint32_t s = 0; s < n_seqs; ++s) {
-            const uint32_t n_seq_id = batch.n_seq_id[s];
-            for (uint32_t j = 0; j < n_seq_id; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
-
-                if (seq_id < 0 || (uint32_t) seq_id >= cache.size) {
-                    // too big seq_id
-                    // TODO: would it be possible to resize the cache instead?
-                    LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
-                    return false;
-                }
-                if (j > 0) {
-                    llama_kv_cell & seq = cache.cells[seq_id];
-                    if (seq.tail >= 0) {
-                        llama_kv_cell & cell = cache.cells[seq.tail];
-                        // clear cells from seq_ids that become shared
-                        // (should not normally happen, but let's handle it anyway)
-                        cell.seq_id.erase(seq_id);
-                        seq.tail = -1;
-                        if (cell.seq_id.empty()) {
-                            cell.pos = -1;
-                            cell.src = -1;
-                            cache.used -= 1;
-                        }
-                    }
-                }
-            }
-        }
-
-#ifndef NDEBUG
-        {
-            std::vector tails_verif;
-            tails_verif.assign(cache.size, -1);
-            for (uint32_t i = 0; i < cache.size; ++i) {
-                llama_kv_cell & cell = cache.cells[i];
-                for (llama_seq_id seq_id : cell.seq_id) {
-                    if (tails_verif[seq_id] != -1) {
-                        LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
-                    }
-                    tails_verif[seq_id] = i;
-                }
-            }
-            for (uint32_t i = 0; i < cache.size; ++i) {
-                if (tails_verif[i] != cache.cells[i].tail) {
-                    LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]);
-                }
-            }
-        }
-#endif
-
-        // find next empty cell
-        uint32_t next_empty_cell = cache.head;
-
-        for (uint32_t i = 0; i < cache.size; ++i) {
-            if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
-            llama_kv_cell & cell = cache.cells[next_empty_cell];
-            if (cell.is_empty()) { break; }
-            next_empty_cell += 1;
-        }
-
-        // find usable cell range
-        for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
-            llama_kv_cell & seq_meta = cache.cells[seq_id];
-            bool has_cell = false;
-            if (seq_meta.tail >= 0) {
-                llama_kv_cell & cell = cache.cells[seq_meta.tail];
-                GGML_ASSERT(cell.has_seq_id(seq_id));
-                // does this seq_id "own" the cell?
-                if (cell.seq_id.size() == 1) { has_cell = true; }
-            }
-            if (!has_cell) {
-                llama_kv_cell & empty_cell = cache.cells[next_empty_cell];
-                GGML_ASSERT(empty_cell.is_empty());
-                // copy old tail into the empty cell
-                if (seq_meta.tail >= 0) {
-                    llama_kv_cell & orig_cell = cache.cells[seq_meta.tail];
-                    empty_cell.pos = orig_cell.pos;
-                    empty_cell.src = orig_cell.src;
-                    orig_cell.seq_id.erase(seq_id);
-                    empty_cell.seq_id.insert(seq_id); // will be overwritten
-                }
-                seq_meta.tail = next_empty_cell;
-                // find next empty cell
-                if (s + 1 < n_seqs) {
-                    next_empty_cell += 1;
-                    for (uint32_t i = 0; i < cache.size; ++i) {
-                        if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; }
-                        llama_kv_cell & cell = cache.cells[next_empty_cell];
-                        if (cell.is_empty()) { break; }
-                        next_empty_cell += 1;
-                    }
-                }
-            }
-            if (min > seq_meta.tail) { min = seq_meta.tail; }
-            if (max < seq_meta.tail) { max = seq_meta.tail; }
-        }
-
-        // gather and re-order
-        for (uint32_t s = 0; s < n_seqs; ++s) {
-            int32_t dst_id = s + min;
-            int32_t src_id = cache.cells[batch.seq_id[s][0]].tail;
-            if (dst_id != src_id) {
-                llama_kv_cell & dst_cell = cache.cells[dst_id];
-                llama_kv_cell & src_cell = cache.cells[src_id];
-
-                std::swap(dst_cell.pos, src_cell.pos);
-                std::swap(dst_cell.src, src_cell.src);
-                std::swap(dst_cell.seq_id, src_cell.seq_id);
-
-                // swap tails (assuming they NEVER overlap)
-                for (const llama_seq_id seq_id : src_cell.seq_id) {
-                    cache.cells[seq_id].tail = src_id;
-                }
-                for (const llama_seq_id seq_id : dst_cell.seq_id) {
-                    cache.cells[seq_id].tail = dst_id;
-                }
-            }
-        }
-
-        // update the pos of the used seqs
-        for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1];
-            int32_t cell_id = s + min;
-            llama_kv_cell & cell = cache.cells[cell_id];
-
-            if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
-                // What should happen when the pos backtracks or skips a value?
-                // Clearing the state mid-batch would require special-casing which isn't done.
-                LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
-                    __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
-            }
-            cell.pos = last_pos;
-            cell.seq_id.clear();
-            for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[s][j];
-                cell.seq_id.insert(seq_id);
-                cache.cells[seq_id].tail = cell_id;
-            }
-        }
-
-        // allow getting the range of used cells, from head to head + n
-        cache.head = min;
-        cache.n    = max - min + 1;
-
-        // sanity check
-        return cache.n >= n_seqs;
-    }
-    // otherwise, one cell per token.
-
-    if (n_tokens > cache.size) {
-        LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
-        return false;
-    }
-
-    uint32_t n_tested = 0;
-
-    while (true) {
-        if (cache.head + n_tokens > cache.size) {
-            n_tested += cache.size - cache.head;
-            cache.head = 0;
-            continue;
-        }
-
-        bool found = true;
-        for (uint32_t i = 0; i < n_tokens; i++) {
-            if (cache.cells[cache.head + i].pos >= 0) {
-                found = false;
-                cache.head += i + 1;
-                n_tested   += i + 1;
-                break;
-            }
-        }
-
-        if (found) {
-            break;
-        }
-
-        if (n_tested >= cache.size) {
-            //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-            return false;
-        }
-    }
-
-    for (uint32_t s = 0; s < n_seqs; s++) {
-        for (uint32_t i = 0; i < n_seq_tokens; ++i) {
-            uint32_t k = s*n_seq_tokens + i;
-            cache.cells[cache.head + k].pos = batch.pos[k];
-
-            for (int32_t j = 0; j < batch.n_seq_id[s]; j++) {
-                cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]);
-            }
-        }
-    }
-
-    cache.used += n_tokens;
-
-    return true;
-}
-
-// find how many cells are currently in use
-static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
-    for (uint32_t i = cache.size; i > 0; --i) {
-        const llama_kv_cell & cell = cache.cells[i - 1];
-
-        if (cell.pos >= 0 && !cell.is_empty()) {
-            return i;
-        }
-    }
-
-    return 0;
-}
-
-static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
-    for (int32_t i = 0; i < (int32_t) cache.size; ++i) {
-        cache.cells[i].pos = -1;
-        cache.cells[i].seq_id.clear();
-        cache.cells[i].src = -1;
-        cache.cells[i].tail = -1;
-    }
-    cache.head = 0;
-    cache.used = 0;
-
-    for (auto & buf : cache.bufs) {
-        ggml_backend_buffer_clear(buf, 0);
-    }
-}
-
-static bool llama_kv_cache_seq_rm(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1) {
-    uint32_t new_head = cache.size;
-
-    if (p0 < 0) p0 = 0;
-    if (p1 < 0) p1 = std::numeric_limits::max();
-
-    // models like Mamba or RWKV can't have a state partially erased
-    if (cache.recurrent) {
-        if (seq_id >= (int64_t) cache.size) {
-            // could be fatal
-            return false;
-        }
-        if (0 <= seq_id) {
-            int32_t & tail_id = cache.cells[seq_id].tail;
-            if (tail_id >= 0) {
-                const llama_kv_cell & cell = cache.cells[tail_id];
-                // partial intersection is invalid
-                if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
-                    return false;
-                }
-                // invalidate tails which will be cleared
-                if (p0 <= cell.pos && cell.pos < p1) {
-                    tail_id = -1;
-                }
-            }
-        } else {
-            // seq_id is negative, then the range should include everything or nothing
-            if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) {
-                return false;
-            }
-        }
-    }
-
-    for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
-            if (seq_id < 0) {
-                cache.cells[i].seq_id.clear();
-            } else if (cache.cells[i].has_seq_id(seq_id)) {
-                cache.cells[i].seq_id.erase(seq_id);
-            } else {
-                continue;
-            }
-            if (cache.cells[i].is_empty()) {
-                // keep count of the number of used cells
-                if (cache.cells[i].pos >= 0) cache.used--;
-
-                cache.cells[i].pos = -1;
-                cache.cells[i].src = -1;
-                if (new_head == cache.size) new_head = i;
-            }
-        }
-    }
-
-    // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
-
-    return true;
-}
-
-static void llama_kv_cache_seq_cp(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id_src,
-                 llama_seq_id   seq_id_dst,
-                    llama_pos   p0,
-                    llama_pos   p1) {
-    if (p0 < 0) p0 = 0;
-    if (p1 < 0) p1 = std::numeric_limits::max();
-
-    if (cache.recurrent) {
-        if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
-            llama_kv_cell & tail_src = cache.cells[seq_id_src];
-            llama_kv_cell & tail_dst = cache.cells[seq_id_dst];
-            if (tail_dst.tail >= 0) {
-                // clear destination seq_id if it wasn't empty
-                llama_kv_cell & cell_dst = cache.cells[tail_dst.tail];
-
-                cell_dst.seq_id.erase(seq_id_dst);
-                tail_dst.tail = -1;
-                if (cell_dst.seq_id.empty()) {
-                    cell_dst.pos = -1;
-                    cell_dst.delta = -1;
-                    cell_dst.src = -1;
-                    cache.used -= 1;
-                }
-            }
-            if (tail_src.tail >= 0) {
-                llama_kv_cell & cell_src = cache.cells[tail_src.tail];
-
-                cell_src.seq_id.insert(seq_id_dst);
-                tail_dst.tail = tail_src.tail;
-            }
-        }
-
-        return;
-    }
-    // otherwise, this is the KV cache of a Transformer-like model
-
-    cache.head = 0;
-
-    for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
-            cache.cells[i].seq_id.insert(seq_id_dst);
-        }
-    }
-}
-
-static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
-    uint32_t new_head = cache.size;
-
-    for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.recurrent && (llama_seq_id) i != seq_id) {
-            cache.cells[i].tail = -1;
-        }
-        if (!cache.cells[i].has_seq_id(seq_id)) {
-            if (cache.cells[i].pos >= 0) cache.used--;
-            cache.cells[i].pos = -1;
-            cache.cells[i].src = -1;
-            cache.cells[i].seq_id.clear();
-            if (new_head == cache.size) new_head = i;
-        } else {
-            cache.cells[i].seq_id.clear();
-            cache.cells[i].seq_id.insert(seq_id);
-        }
-    }
-
-    // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
-}
-
-static void llama_kv_cache_seq_add(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1,
-                    llama_pos   delta) {
-    uint32_t new_head = cache.size;
-
-    if (p0 < 0) p0 = 0;
-    if (p1 < 0) p1 = std::numeric_limits::max();
-    // If there is no range then return early to avoid looping over the cache.
-    if (p0 == p1) return;
-
-    if (cache.recurrent) {
-        // for Mamba-like or RWKV models, only the pos needs to be shifted
-        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
-            const int32_t tail_id = cache.cells[seq_id].tail;
-            if (tail_id >= 0) {
-                llama_kv_cell & cell = cache.cells[tail_id];
-                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
-                    cell.pos += delta;
-                }
-            }
-        }
-        return;
-    }
-
-    for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
-            cache.has_shift = true;
-            cache.cells[i].pos   += delta;
-            cache.cells[i].delta += delta;
-
-            if (cache.cells[i].pos < 0) {
-                if (!cache.cells[i].is_empty()) {
-                    cache.used--;
-                }
-                cache.cells[i].pos = -1;
-                cache.cells[i].seq_id.clear();
-                if (new_head == cache.size) {
-                    new_head = i;
-                }
-            }
-        }
-    }
-
-    // If we freed up a slot, set head to it so searching can start there.
-    // Otherwise we just start the next search from the beginning.
-    cache.head = new_head != cache.size ? new_head : 0;
-}
-
-static void llama_kv_cache_seq_div(
-        struct llama_kv_cache & cache,
-                 llama_seq_id   seq_id,
-                    llama_pos   p0,
-                    llama_pos   p1,
-                          int   d) {
-    if (p0 < 0) p0 = 0;
-    if (p1 < 0) p1 = std::numeric_limits::max();
-    // If there is no range then return early to avoid looping over the cache.
-    if (p0 == p1) return;
-
-    if (cache.recurrent) {
-        // for Mamba-like or RWKV models, only the pos needs to be changed
-        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
-            const int32_t tail_id = cache.cells[seq_id].tail;
-            if (tail_id >= 0) {
-                llama_kv_cell & cell = cache.cells[tail_id];
-                if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
-                    cell.pos /= d;
-                }
-            }
-        }
-        return;
-    }
-
-    for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
-            cache.has_shift = true;
-
-            {
-                llama_pos p_old = cache.cells[i].pos;
-                cache.cells[i].pos   /= d;
-                cache.cells[i].delta += cache.cells[i].pos - p_old;
-            }
-        }
-    }
-}
-
-static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) {
-    llama_pos result = 0;
-
-    for (uint32_t i = 0; i < cache.size; ++i) {
-        if (cache.cells[i].has_seq_id(seq_id)) {
-            result = std::max(result, cache.cells[i].pos);
-        }
-    }
-
-    return result;
-}
-
-static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
-    if (!cache.recurrent) {
-        cache.do_defrag = true;
-    }
-}
-
-static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) {
-    // the FA kernels require padding to avoid extra runtime boundary checks
-    return cparams.flash_attn ? 256u : 32u;
-}
-
-//
-// model loading and saving
-//
-
-enum llama_fver {
-    GGUF_FILE_VERSION_V1 = 1,
-    GGUF_FILE_VERSION_V2 = 2,
-    GGUF_FILE_VERSION_V3 = 3,
-};
-
-static const char * llama_file_version_name(llama_fver version) {
-    switch (version) {
-        case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)";
-        case GGUF_FILE_VERSION_V2: return "GGUF V2";
-        case GGUF_FILE_VERSION_V3: return "GGUF V3 (latest)";
-    }
-
-    return "unknown";
-}
-
-static std::string llama_format_tensor_shape(const std::vector & ne) {
-    char buf[256];
-    snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0));
-    for (size_t i = 1; i < ne.size(); i++) {
-        snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i));
-    }
-    return buf;
-}
-
-static std::string llama_format_tensor_shape(const struct ggml_tensor * t) {
-    char buf[256];
-    snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]);
-    for (int i = 1; i < GGML_MAX_DIMS; i++) {
-        snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]);
-    }
-    return buf;
-}
-
-namespace GGUFMeta {
-    template 
-    struct GKV_Base_Type {
-        static constexpr gguf_type gt = gt_;
-
-        static T getter(const gguf_context * ctx, const int kid) {
-            return gfun(ctx, kid);
-        }
-    };
-
-    template struct GKV_Base;
-
-    template<> struct GKV_Base: GKV_Base_Type {};
-    template<> struct GKV_Base: GKV_Base_Type {};
-    template<> struct GKV_Base: GKV_Base_Type {};
-    template<> struct GKV_Base: GKV_Base_Type {};
-    template<> struct GKV_Base: GKV_Base_Type {};
-    template<> struct GKV_Base: GKV_Base_Type {};
-    template<> struct GKV_Base: GKV_Base_Type {};
-    template<> struct GKV_Base: GKV_Base_Type {};
-    template<> struct GKV_Base: GKV_Base_Type {};
-    template<> struct GKV_Base: GKV_Base_Type {};
-    template<> struct GKV_Base: GKV_Base_Type {};
-    template<> struct GKV_Base: GKV_Base_Type {};
-
-    template<> struct GKV_Base {
-        static constexpr gguf_type gt = GGUF_TYPE_STRING;
-
-        static std::string getter(const gguf_context * ctx, const int kid) {
-            return gguf_get_val_str(ctx, kid);
-        }
-    };
-
-    struct ArrayInfo {
-        const gguf_type gt;
-        const size_t length;
-        const void * data;
-    };
-
-    template<> struct GKV_Base {
-        public:
-        static constexpr gguf_type gt = GGUF_TYPE_ARRAY;
-        static ArrayInfo getter(const gguf_context *ctx, const int k) {
-            return ArrayInfo {
-                gguf_get_arr_type(ctx, k),
-                size_t(gguf_get_arr_n(ctx, k)),
-                gguf_get_arr_data(ctx, k),
-            };
-        }
-    };
-
-    template
-    class GKV : public GKV_Base {
-        GKV() = delete;
-
-        public:
-        static T get_kv(const gguf_context * ctx, const int k) {
-            const enum gguf_type kt = gguf_get_kv_type(ctx, k);
-
-            if (kt != GKV::gt) {
-                throw std::runtime_error(format("key %s has wrong type %s but expected type %s",
-                    gguf_get_key(ctx, k), gguf_type_name(kt), gguf_type_name(GKV::gt)));
-            }
-            return GKV::getter(ctx, k);
-        }
-
-        static const char * override_type_to_str(const llama_model_kv_override_type ty) {
-            switch (ty) {
-                case LLAMA_KV_OVERRIDE_TYPE_BOOL:  return "bool";
-                case LLAMA_KV_OVERRIDE_TYPE_INT:   return "int";
-                case LLAMA_KV_OVERRIDE_TYPE_FLOAT: return "float";
-                case LLAMA_KV_OVERRIDE_TYPE_STR:   return "str";
-            }
-            return "unknown";
-        }
-
-        static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override * ovrd) {
-            if (!ovrd) { return false; }
-            if (ovrd->tag == expected_type) {
-                LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ",
-                    __func__, override_type_to_str(ovrd->tag), ovrd->key);
-                switch (ovrd->tag) {
-                    case LLAMA_KV_OVERRIDE_TYPE_BOOL:  {
-                        LLAMA_LOG_INFO("%s\n", ovrd->val_bool ? "true" : "false");
-                    } break;
-                    case LLAMA_KV_OVERRIDE_TYPE_INT:   {
-                        LLAMA_LOG_INFO("%" PRId64 "\n", ovrd->val_i64);
-                    } break;
-                    case LLAMA_KV_OVERRIDE_TYPE_FLOAT: {
-                        LLAMA_LOG_INFO("%.6f\n", ovrd->val_f64);
-                    } break;
-                    case LLAMA_KV_OVERRIDE_TYPE_STR: {
-                        LLAMA_LOG_INFO("%s\n", ovrd->val_str);
-                    } break;
-                    default:
-                        // Shouldn't be possible to end up here, but just in case...
-                        throw std::runtime_error(
-                            format("Unsupported attempt to override %s type for metadata key %s\n",
-                                override_type_to_str(ovrd->tag), ovrd->key));
-                }
-                return true;
-            }
-            LLAMA_LOG_WARN("%s: Warning: Bad metadata override type for key '%s', expected %s but got %s\n",
-                __func__, ovrd->key, override_type_to_str(expected_type), override_type_to_str(ovrd->tag));
-            return false;
-        }
-
-        template
-        static typename std::enable_if::value, bool>::type
-        try_override(OT & target, const struct llama_model_kv_override * ovrd) {
-            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_BOOL, ovrd)) {
-                target = ovrd->val_bool;
-                return true;
-            }
-            return false;
-        }
-
-        template
-        static typename std::enable_if::value && std::is_integral::value, bool>::type
-        try_override(OT & target, const struct llama_model_kv_override * ovrd) {
-            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_INT, ovrd)) {
-                target = ovrd->val_i64;
-                return true;
-            }
-            return false;
-        }
-
-        template
-        static typename std::enable_if::value, bool>::type
-        try_override(T & target, const struct llama_model_kv_override * ovrd) {
-            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_FLOAT, ovrd)) {
-                target = ovrd->val_f64;
-                return true;
-            }
-            return false;
-        }
-
-        template
-        static typename std::enable_if::value, bool>::type
-        try_override(T & target, const struct llama_model_kv_override * ovrd) {
-            if (validate_override(LLAMA_KV_OVERRIDE_TYPE_STR, ovrd)) {
-                target = ovrd->val_str;
-                return true;
-            }
-            return false;
-        }
-
-        static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
-            if (try_override(target, ovrd)) {
-                return true;
-            }
-            if (k < 0) { return false; }
-            target = get_kv(ctx, k);
-            return true;
-        }
-
-        static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
-            return set(ctx, gguf_find_key(ctx, key), target, ovrd);
-        }
-
-        static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override * ovrd = nullptr) {
-            return set(ctx, key.c_str(), target, ovrd);
-        }
-    };
-}
-
-using llama_buf_map = std::unordered_map;
-
-static size_t llama_model_max_nodes(const llama_model & model) {
-    return std::max(8192, model.tensors_by_name.size()*5);
-}
-
-struct llama_model_loader {
-    int n_kv      = 0;
-    int n_tensors = 0;
-    int n_created = 0;
-
-    int64_t n_elements = 0;
-    size_t  n_bytes    = 0;
-
-    bool use_mmap = false;
-    bool check_tensors;
-
-    llama_files files;
-    llama_ftype ftype;
-    llama_fver  fver;
-
-    llama_mmaps mappings;
-
-    // Holds information on a model weight
-    struct llama_tensor_weight {
-        uint16_t  idx; // source file index
-        size_t   offs; // tensor data offset in the original file
-
-        ggml_tensor * tensor;
-
-        llama_tensor_weight(const llama_file * file, uint16_t idx, const char * name, const struct gguf_context * gguf_ctx, ggml_tensor * tensor) : idx(idx), tensor(tensor) {
-            const int tensor_idx = gguf_find_tensor(gguf_ctx, name);
-            offs = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, tensor_idx);
-
-            if (offs + ggml_nbytes(tensor) < offs || offs + ggml_nbytes(tensor) > file->size) {
-                throw std::runtime_error(format("tensor '%s' data is not within the file bounds, model is corrupted or incomplete", name));
-            }
-        }
-    };
-    std::vector weights;
-
-    std::unordered_map kv_overrides;
-
-    struct gguf_context * meta = NULL;
-    std::vector contexts;
-
-    std::string arch_name;
-    LLM_KV      llm_kv    = LLM_KV(LLM_ARCH_UNKNOWN);
-
-    llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
-        int trace = 0;
-        if (getenv("LLAMA_TRACE")) {
-            trace = atoi(getenv("LLAMA_TRACE"));
-        }
-
-        if (param_overrides_p != nullptr) {
-            for (const struct llama_model_kv_override * p = param_overrides_p; p->key[0] != 0; p++) {
-                kv_overrides.insert({std::string(p->key), *p});
-            }
-        }
-
-        struct ggml_context * ctx = NULL;
-        struct gguf_init_params params = {
-            /*.no_alloc = */ true,
-            /*.ctx      = */ &ctx,
-        };
-
-        meta = gguf_init_from_file(fname.c_str(), params);
-        if (!meta) {
-            throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str()));
-        }
-
-        get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
-        llm_kv = LLM_KV(llm_arch_from_string(arch_name));
-
-        files.emplace_back(new llama_file(fname.c_str(), "rb"));
-        contexts.emplace_back(ctx);
-
-        // Save tensors data offset of the main file.
-        // For subsidiary files, `meta` tensor data offset must not be used,
-        // so we build a unified tensors index for weights.
-        for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
-            weights.emplace_back(files.back().get(), 0, cur->name, meta, cur);
-        }
-        uint16_t n_split = 0;
-        get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false);
-
-        // Load additional GGML contexts
-        if (n_split > 1) {
-            uint16_t idx = 0;
-            get_key(llm_kv(LLM_KV_SPLIT_NO), idx);
-            if (idx != 0) {
-                throw std::runtime_error(format("illegal split file: %d, model must be loaded with the first split", idx));
-            }
-
-            char split_prefix[PATH_MAX] = {0};
-            if (!llama_split_prefix(split_prefix, sizeof(split_prefix), fname.c_str(), idx, n_split)) {
-                throw std::runtime_error(format("invalid split file: %s", fname.c_str()));
-            }
-
-            if (trace > 0) {
-                LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split);
-            }
-
-            char split_path[PATH_MAX] = {0};
-            for (idx = 1; idx < n_split; idx++) {
-                llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split);
-
-                struct gguf_init_params split_params = {
-                    /*.no_alloc = */ true,
-                    /*.ctx      = */ &ctx,
-                };
-                struct gguf_context * ctx_gguf = gguf_init_from_file(split_path, split_params);
-                if (!ctx_gguf) {
-                    throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path));
-                }
-
-                files.emplace_back(new llama_file(split_path, "rb"));
-                contexts.emplace_back(ctx);
-
-                // Save tensors data offset info of the shard.
-                for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
-                    weights.emplace_back(files.back().get(), idx, cur->name, ctx_gguf, cur);
-                }
-
-                gguf_free(ctx_gguf);
-            }
-
-            get_key(llm_kv(LLM_KV_SPLIT_TENSORS_COUNT), n_tensors);
-
-            // sanity check
-            {
-                const int n_tensors_loaded = (int) weights.size();
-                if (n_tensors != n_tensors_loaded) {
-                    throw std::runtime_error(format("corrupted model: %d tensors expected but %d found", n_tensors, n_tensors_loaded));
-                }
-            }
-
-            LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n",  __func__, n_split - 1);
-        }
-
-        n_kv      = gguf_get_n_kv(meta);
-        n_tensors = weights.size();
-
-        fver = (enum llama_fver) gguf_get_version(meta);
-
-        std::set tensor_names;
-        for (auto & w : weights) {
-            n_elements += ggml_nelements(w.tensor);
-            n_bytes    += ggml_nbytes(w.tensor);
-            // make sure there is no duplicated tensor names
-            const std::string name(w.tensor->name);
-            auto found = tensor_names.find(name);
-            if (found != tensor_names.end()) {
-                throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", w.tensor->name));
-            }
-            tensor_names.insert(name);
-        }
-
-        LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n",
-                __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver));
-
-        // determine file type based on the number of tensors for each quantization and print meta data
-        // TODO: make optional
-        {
-            std::map n_type;
-
-            uint32_t n_type_max = 0;
-            enum ggml_type type_max = GGML_TYPE_F32;
-
-            for (int i = 0; i < n_tensors; i++) {
-                const ggml_tensor * tensor = weights.at(i).tensor;
-                enum ggml_type type = tensor->type;
-
-                n_type[type]++;
-
-                if (n_type_max < n_type[type]) {
-                    n_type_max = n_type[type];
-                    type_max   = type;
-                }
-
-                if (trace > 0) {
-                    const uint16_t sid = weights.at(i).idx;
-                    LLAMA_LOG_INFO("%s: - tensor %4d, split %2d: %32s %-8s [ %s ]\n", __func__, i, sid, ggml_get_name(tensor), ggml_type_name(type), llama_format_tensor_shape(tensor).c_str());
-                }
-            }
-
-            switch (type_max) {
-                case GGML_TYPE_F32:     ftype = LLAMA_FTYPE_ALL_F32;        break;
-                case GGML_TYPE_F16:     ftype = LLAMA_FTYPE_MOSTLY_F16;     break;
-                case GGML_TYPE_BF16:    ftype = LLAMA_FTYPE_MOSTLY_BF16;    break;
-                case GGML_TYPE_Q4_0:    ftype = LLAMA_FTYPE_MOSTLY_Q4_0;    break;
-                case GGML_TYPE_Q4_1:    ftype = LLAMA_FTYPE_MOSTLY_Q4_1;    break;
-                case GGML_TYPE_Q5_0:    ftype = LLAMA_FTYPE_MOSTLY_Q5_0;    break;
-                case GGML_TYPE_Q5_1:    ftype = LLAMA_FTYPE_MOSTLY_Q5_1;    break;
-                case GGML_TYPE_Q8_0:    ftype = LLAMA_FTYPE_MOSTLY_Q8_0;    break;
-                case GGML_TYPE_Q2_K:    ftype = LLAMA_FTYPE_MOSTLY_Q2_K;    break;
-                case GGML_TYPE_Q3_K:    ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M;  break;
-                case GGML_TYPE_Q4_K:    ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M;  break;
-                case GGML_TYPE_Q5_K:    ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M;  break;
-                case GGML_TYPE_Q6_K:    ftype = LLAMA_FTYPE_MOSTLY_Q6_K;    break;
-                case GGML_TYPE_TQ1_0:   ftype = LLAMA_FTYPE_MOSTLY_TQ1_0;   break;
-                case GGML_TYPE_TQ2_0:   ftype = LLAMA_FTYPE_MOSTLY_TQ2_0;   break;
-                case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break;
-                case GGML_TYPE_IQ2_XS:  ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS;  break;
-                case GGML_TYPE_IQ2_S:   ftype = LLAMA_FTYPE_MOSTLY_IQ2_S;   break;
-                case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break;
-                case GGML_TYPE_IQ1_S:   ftype = LLAMA_FTYPE_MOSTLY_IQ1_S;   break;
-                case GGML_TYPE_IQ1_M:   ftype = LLAMA_FTYPE_MOSTLY_IQ1_M;   break;
-                case GGML_TYPE_IQ4_NL:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL;  break;
-                case GGML_TYPE_IQ4_XS:  ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS;  break;
-                case GGML_TYPE_IQ3_S:   ftype = LLAMA_FTYPE_MOSTLY_IQ3_S;   break;
-                case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break;
-                case GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break;
-                case GGML_TYPE_Q4_0_8_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_8_8; break;
-                default:
-                    {
-                        LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max));
-                        ftype = LLAMA_FTYPE_ALL_F32;
-                    } break;
-            }
-
-            // this is a way to mark that we have "guessed" the file type
-            ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED);
-
-            {
-                const int kid = gguf_find_key(meta, "general.file_type"); // TODO: use LLM_KV
-                if (kid >= 0) {
-                    ftype = (llama_ftype) gguf_get_val_u32(meta, kid);
-                }
-            }
-
-            LLAMA_LOG_INFO("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__);
-
-            for (int i = 0; i < n_kv; i++) {
-                const char * name           = gguf_get_key(meta, i);
-                const enum gguf_type type   = gguf_get_kv_type(meta, i);
-                const std::string type_name =
-                    type == GGUF_TYPE_ARRAY
-                    ? format("%s[%s,%d]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta, i)), gguf_get_arr_n(meta, i))
-                    : gguf_type_name(type);
-
-                std::string value          = gguf_kv_to_str(meta, i);
-                const size_t MAX_VALUE_LEN = 40;
-                if (value.size() > MAX_VALUE_LEN) {
-                    value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str());
-                }
-                replace_all(value, "\n", "\\n");
-
-                LLAMA_LOG_INFO("%s: - kv %3d: %42s %-16s = %s\n", __func__, i, name, type_name.c_str(), value.c_str());
-            }
-
-            // print type counts
-            for (auto & kv : n_type) {
-                if (kv.second == 0) {
-                    continue;
-                }
-
-                LLAMA_LOG_INFO("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second);
-            }
-        }
-
-        if (!llama_mmap::SUPPORTED) {
-            LLAMA_LOG_WARN("%s: mmap is not supported on this platform\n", __func__);
-            use_mmap = false;
-        }
-
-        this->use_mmap = use_mmap;
-        this->check_tensors = check_tensors;
-    }
-
-    ~llama_model_loader() {
-        if (meta) {
-            gguf_free(meta);
-        }
-        for (auto * ctx : contexts) {
-            ggml_free(ctx);
-        }
-    }
-
-    template
-    typename std::enable_if::value, bool>::type
-    get_arr_n(const std::string & key, T & result, const bool required = true) {
-        const int kid = gguf_find_key(meta, key.c_str());
-
-        if (kid < 0) {
-            if (required) {
-                throw std::runtime_error(format("key not found in model: %s", key.c_str()));
-            }
-            return false;
-        }
-
-        struct GGUFMeta::ArrayInfo arr_info =
-            GGUFMeta::GKV::get_kv(meta, kid);
-
-
-        result = arr_info.length;
-        return true;
-    }
-
-    template
-    typename std::enable_if::value, bool>::type
-    get_arr_n(const enum llm_kv kid, T & result, const bool required = true) {
-        return get_arr_n(llm_kv(kid), result, required);
-    }
-
-    template
-    bool get_arr(const std::string & key, std::vector & result, const bool required = true) {
-        const int kid = gguf_find_key(meta, key.c_str());
-
-        if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) {
-            if (required) {
-                throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
-            }
-            return false;
-        }
-
-        struct GGUFMeta::ArrayInfo arr_info =
-            GGUFMeta::GKV::get_kv(meta, kid);
-
-        switch (arr_info.gt) {
-            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break;
-            case GGUF_TYPE_INT32:   GGML_ASSERT(
-                                            (std::is_same::value) ||
-                                            (std::is_same::value));  break;
-            default:
-                throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
-        }
-
-        result.resize(arr_info.length);
-        result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
-
-        return true;
-    }
-
-    template
-    bool get_arr(const std::string & key, std::array & result, const bool required = true) {
-        const int kid = gguf_find_key(meta, key.c_str());
-
-        if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) {
-            if (required) {
-                throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
-            }
-            return false;
-        }
-
-        struct GGUFMeta::ArrayInfo arr_info =
-            GGUFMeta::GKV::get_kv(meta, kid);
-
-        switch (arr_info.gt) {
-            case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break;
-            case GGUF_TYPE_INT32:   GGML_ASSERT(
-                                            (std::is_same::value) ||
-                                            (std::is_same::value));  break;
-            default:
-                throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
-        }
-
-        if (arr_info.length > N_MAX) {
-            throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX));
-        }
-
-        std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin());
-
-        return true;
-    }
-
-    template
-    bool get_arr(const enum llm_kv kid, T & result, const bool required = true) {
-        return get_arr(llm_kv(kid), result, required);
-    }
-
-    template
-    bool get_key(const std::string & key, T & result, const bool required = true) {
-        auto it = kv_overrides.find(key);
-
-        const struct llama_model_kv_override * override =
-            it != kv_overrides.end() ? &it->second : nullptr;
-
-        const bool found = GGUFMeta::GKV::set(meta, key, result, override);
-
-        if (required && !found) {
-            throw std::runtime_error(format("key not found in model: %s", key.c_str()));
-        }
-
-        return found;
-    }
-
-    template
-    bool get_key(const enum llm_kv kid, T & result, const bool required = true) {
-        return get_key(llm_kv(kid), result, required);
-    }
-
-    // get array of n <= N_MAX elements, or a single element repeated n times
-    template
-    bool get_key_or_arr(const std::string & key, std::array & result, uint32_t n, const bool required = true) {
-        const int kid = gguf_find_key(meta, key.c_str());
-
-        if (kid < 0) {
-            if (required) {
-                throw std::runtime_error(format("key not found in model: %s", key.c_str()));
-            }
-            return false;
-        }
-
-        if (n > N_MAX) {
-            throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str()));
-        }
-
-        if (gguf_get_kv_type(meta, kid) == GGUF_TYPE_ARRAY) {
-            struct GGUFMeta::ArrayInfo arr_info =
-                GGUFMeta::GKV::get_kv(meta, kid);
-
-            if (n != arr_info.length) {
-                throw std::runtime_error(format("key %s has wrong array length; expected %u, got %u", key.c_str(), n, (uint32_t) arr_info.length));
-            }
-
-            return get_arr(key, result, required);
-        } else {
-            T value;
-
-            bool ok = get_key(key, value, required);
-            if (!ok) {
-                return false;
-            }
-
-            for (uint32_t i = 0; i < n; i++) {
-                result[i] = value;
-            }
-
-            return true;
-        }
-    }
-
-    template
-    bool get_key_or_arr(const enum llm_kv kid, T & result, uint32_t n, const bool required = true) {
-        return get_key_or_arr(llm_kv(kid), result, n, required);
-    }
-
-    std::string get_arch_name() const {
-        return arch_name;
-    }
-
-    enum llm_arch get_arch() const {
-        return llm_kv.arch;
-    }
-
-    const char * get_tensor_name(int i) const {
-        return weights.at(i).tensor->name;
-    }
-
-    const llama_tensor_weight * get_weight(const char * name) const {
-        for (const auto & weight : weights) {
-            if (strcmp(name, weight.tensor->name) == 0) {
-                return &weight;
-            }
-        }
-        return nullptr;
-    }
-
-    const llama_tensor_weight * get_weight(int i) const {
-        return get_weight(get_tensor_name(i));
-    }
-
-    const llama_tensor_weight & require_weight(const char * name) const {
-        const llama_tensor_weight * weight = get_weight(name);
-        if (!weight) {
-            throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name));
-        }
-        return *weight;
-    }
-
-    struct ggml_tensor * get_tensor_meta(const char * name) const {
-        const auto * weight = get_weight(name);
-        if (!weight) {
-            return nullptr;
-        }
-        return weight->tensor;
-    }
-
-    struct ggml_tensor * require_tensor_meta(const char * name) const {
-        struct ggml_tensor * tensor = get_tensor_meta(name);
-        if (!tensor) {
-            throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name));
-        }
-        return tensor;
-    }
-
-    struct ggml_tensor * get_tensor_meta(int i) const {
-        return get_tensor_meta(get_tensor_name(i));
-    }
-
-    struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, const struct ggml_tensor * cur, bool duplicated) {
-        struct ggml_tensor * tensor = ggml_dup_tensor(ctx, cur);
-        ggml_set_name(tensor, ggml_get_name(cur));
-
-        if (duplicated) {
-            size_data += ggml_nbytes(cur);
-        } else {
-            n_created++;
-        }
-
-        return tensor;
-    }
-
-    const struct ggml_tensor * check_tensor_dims(const std::string & name, const std::vector & ne, bool required) const {
-        const struct ggml_tensor * cur = get_tensor_meta(name.c_str());
-
-        if (cur == NULL) {
-            if (!required) {
-                return NULL;
-            }
-            throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str()));
-        }
-
-        {
-            bool is_ok = true;
-            for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
-                if ((i < ne.size() && ne[i] != cur->ne[i]) || (i >= ne.size() && cur->ne[i] != 1)) {
-                    is_ok = false;
-                    break;
-                }
-            }
-            if (!is_ok) {
-                throw std::runtime_error(
-                        format("%s: tensor '%s' has wrong shape; expected %s, got %s",
-                            __func__, name.c_str(),
-                            llama_format_tensor_shape(ne).c_str(),
-                            llama_format_tensor_shape(cur).c_str()));
-            }
-        }
-
-        return cur;
-    }
-
-    static const int TENSOR_NOT_REQUIRED = 1;
-    static const int TENSOR_DUPLICATED   = 2;
-
-    struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector & ne, int flags = 0) {
-        const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED));
-
-        if (cur == NULL) {
-            return NULL;
-        }
-
-        return create_tensor_for(ctx, cur, flags & TENSOR_DUPLICATED);
-    }
-
-    struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::vector & ne, size_t offset, bool required = true) {
-        const struct ggml_tensor * cur = check_tensor_dims(name, ne, required);
-
-        if (cur == NULL) {
-            return NULL;
-        }
-
-        if (cur->type != base->type) {
-            throw std::runtime_error(format("%s: tensor '%s' has wrong type; expected %s, got %s", __func__, name.c_str(), ggml_type_name(base->type), ggml_type_name(cur->type)));
-        }
-
-        std::array dims;
-        for (size_t i = 0; i < GGML_MAX_DIMS; ++i) {
-            dims[i] = i < ne.size() ? ne[i] : 1;
-        }
-
-        struct ggml_tensor * tensor = ggml_view_4d(ctx, base,
-                                        dims[0], dims[1], dims[2], dims[3],
-                                        cur->nb[1], cur->nb[2], cur->nb[3],
-                                        offset);
-
-        ggml_set_name(tensor, name.c_str());
-
-        n_created++;
-
-        return tensor;
-    }
-
-    void done_getting_tensors() const {
-        if (n_created != n_tensors) {
-            throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
-        }
-    }
-
-    void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr) {
-        if (use_mmap) {
-            mappings.reserve(files.size());
-            mmaps_used.reserve(files.size());
-            for (const auto & file : files) {
-                std::unique_ptr mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, ggml_is_numa()));
-                mmaps_used.emplace_back(mapping->size, 0);
-                if (mlock_mmaps) {
-                    std::unique_ptr mlock_mmap(new llama_mlock());
-                    mlock_mmap->init(mapping->addr);
-                    mlock_mmaps->emplace_back(std::move(mlock_mmap));
-                }
-                mappings.emplace_back(std::move(mapping));
-            }
-        }
-
-        // compute the total size of all tensors for progress reporting
-        for (auto & w : weights) {
-            size_data += ggml_nbytes(w.tensor);
-        }
-    }
-
-    void get_mapping_range(size_t * first, size_t * last, void ** addr, int idx, ggml_context * ctx) const {
-        GGML_ASSERT(!mappings.empty());
-        const auto & mapping = mappings.at(idx);
-
-        *first = mapping->size;
-        *last  = 0;
-        *addr = mapping->addr;
-        for (ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor; tensor = ggml_get_next_tensor(ctx, tensor)) {
-            try {
-                const auto * weight = get_weight(ggml_get_name(tensor));
-                if (!weight) {
-                    continue;
-                }
-                if (weight->idx != idx) {
-                    continue;
-                }
-                *first = std::min(*first, weight->offs);
-                *last  = std::max(*last,  weight->offs + ggml_nbytes(tensor));
-            } catch(...) {
-                // the tensor is not in the model
-            }
-        }
-    }
-
-    // for backwards compatibility, does not support ggml-backend
-    void load_data_for(struct ggml_tensor * cur) const {
-        const auto & w = require_weight(ggml_get_name(cur));
-
-        if (use_mmap) {
-            const auto & mapping = mappings.at(w.idx);
-            if (cur->data == nullptr) {
-                cur->data = (uint8_t *)mapping->addr + w.offs;
-            } else {
-                memcpy(cur->data, (uint8_t *)mapping->addr + w.offs, ggml_nbytes(cur));
-            }
-        } else {
-            GGML_ASSERT(cur->data != nullptr);
-            GGML_ASSERT(w.idx < files.size());
-            const auto & file = files.at(w.idx);
-            file->seek(w.offs, SEEK_SET);
-            file->read_raw(cur->data, ggml_nbytes(cur));
-        }
-
-        if (check_tensors && !ggml_validate_row_data(cur->type, cur->data, ggml_nbytes(cur))) {
-            throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
-        }
-    }
-
-    size_t size_done = 0;
-    size_t size_data = 0;
-    std::vector> mmaps_used;
-
-    // Returns false if cancelled by progress_callback
-    bool load_all_data(
-            struct ggml_context * ctx,
-            llama_buf_map & bufs_mmap,
-            llama_mlocks * lmlocks,
-            llama_progress_callback progress_callback,
-            void * progress_callback_user_data) {
-        GGML_ASSERT(size_data != 0 && "call init_mappings() first");
-
-        std::vector> read_buf;
-        std::vector>> validation_result;
-
-#if defined(GGML_USE_CUDA)
-        // 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives.
-        // NVMe raid configurations might require more / larger buffers.
-        constexpr size_t n_buffers = 4;
-        constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB
-
-        std::vector host_buffers;
-        std::vector host_ptrs;
-        std::vector events;
-        size_t buffer_idx = 0; // buffer to use for async loads
-
-        ggml_backend_t cuda_backend = nullptr;
-        if (!use_mmap && !check_tensors) {
-            // When not using mmaped io use async uploads from pinned memory to GPU memory.
-            // First determine if the CUDA backend is active, and if so, determine the device ID.
-            ggml_backend_buffer_t buf = bufs_mmap.count(0) ? bufs_mmap.at(0) : nullptr;
-            if (buf) {
-                ggml_backend_buffer_type_t buffer_type = ggml_backend_buffer_get_type(buf);
-                for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
-                    auto * cuda_buffer_type = ggml_backend_cuda_buffer_type(i);
-                    if (buffer_type == cuda_buffer_type) {
-                        cuda_backend = ggml_backend_cuda_init(i);
-                        break;
-                    }
-                }
-            }
-
-            // If the cuda backend is active create pinned memory buffers and events for synchronisation.
-            if (cuda_backend) {
-                for (size_t idx = 0; idx < n_buffers; ++idx) {
-                    host_buffers.emplace_back(ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buffer_size));
-                    host_ptrs.emplace_back(ggml_backend_buffer_get_base(host_buffers[idx]));
-                    events.emplace_back(ggml_backend_event_new(cuda_backend));
-                }
-            }
-        }
-#endif
-
-        for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) {
-            const auto * weight = get_weight(ggml_get_name(cur));
-            if (weight == nullptr) {
-                // this can happen with split experts models
-                continue;
-            }
-
-            if (progress_callback) {
-                if (!progress_callback((float) size_done / size_data, progress_callback_user_data)) {
-                    return false;
-                }
-            }
-
-            size_t n_size = ggml_nbytes(cur);
-
-            if (use_mmap) {
-                const auto & mapping = mappings.at(weight->idx);
-                ggml_backend_buffer_t buf_mmap = nullptr;
-                if (bufs_mmap.count(weight->idx)) {
-                    buf_mmap = bufs_mmap.at(weight->idx);
-                }
-                uint8_t * data = (uint8_t *) mapping->addr + weight->offs;
-
-                if (check_tensors) {
-                    validation_result.emplace_back(std::async(std::launch::async, [cur, data, n_size] {
-                        return std::make_pair(cur, ggml_validate_row_data(cur->type, data, n_size));
-                    }));
-                }
-
-                GGML_ASSERT(buf_mmap || cur->data); // either we have a buffer to allocate the tensor in, or it is already allocated
-                if (buf_mmap && cur->data == nullptr) {
-                    ggml_backend_tensor_alloc(buf_mmap, cur, data);
-                    if (lmlocks) {
-                        const auto & lmlock = lmlocks->at(weight->idx);
-                        lmlock->grow_to(weight->offs + n_size);
-                    }
-
-                    auto & mmap_used = mmaps_used[weight->idx];
-                    mmap_used.first  = std::min(mmap_used.first,  weight->offs);
-                    mmap_used.second = std::max(mmap_used.second, weight->offs + n_size);
-                } else {
-                    ggml_backend_tensor_set(cur, data, 0, n_size);
-                }
-            } else {
-                GGML_ASSERT(weight->idx < files.size());
-                const auto & file = files.at(weight->idx);
-                if (ggml_backend_buffer_is_host(cur->buffer)) {
-                    file->seek(weight->offs, SEEK_SET);
-                    file->read_raw(cur->data, n_size);
-                    if (check_tensors) {
-                        validation_result.emplace_back(std::async(std::launch::async, [cur, n_size] {
-                            return std::make_pair(cur, ggml_validate_row_data(cur->type, cur->data, n_size));
-                        }));
-                    }
-                } else {
-#if defined(GGML_USE_CUDA)
-                    // If cuda_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU.
-                    if (cuda_backend) {
-                        file->seek(weight->offs, SEEK_SET);
-
-                        size_t bytes_read = 0;
-
-                        while (bytes_read < n_size) {
-                            size_t read_iteration = std::min(buffer_size, n_size - bytes_read);
-
-                            ggml_backend_event_synchronize(events[buffer_idx]);
-                            file->read_raw(host_ptrs[buffer_idx], read_iteration);
-                            ggml_backend_tensor_set_async(cuda_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration);
-                            ggml_backend_event_record(events[buffer_idx]);
-
-                            bytes_read += read_iteration;
-                            ++buffer_idx;
-                            buffer_idx %= n_buffers;
-                        }
-                    }
-                    else
-#endif
-                    {
-                        read_buf.resize(n_size);
-                        file->seek(weight->offs, SEEK_SET);
-                        file->read_raw(read_buf.data(), n_size);
-                        ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size);
-                        if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) {
-                            throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur)));
-                        }
-                    }
-                }
-            }
-
-            size_done += n_size;
-        }
-
-#if defined(GGML_USE_CUDA)
-        // free temporary resources used for async cuda uploads
-        if (cuda_backend) {
-            for (size_t idx = 0; idx < n_buffers;++idx) {
-                ggml_backend_event_synchronize(events[idx]);
-                ggml_backend_event_free(events[idx]);
-                ggml_backend_buffer_free(host_buffers[idx]);
-            }
-            ggml_backend_free(cuda_backend);
-        }
-#endif
-
-        // check validation results
-        bool validation_failed = false;
-        for (auto & future : validation_result) {
-            auto result = future.get();
-            if (!result.second) {
-                LLAMA_LOG_ERROR("%s: tensor '%s' has invalid data\n", __func__, ggml_get_name(result.first));
-                validation_failed = true;
-            }
-        }
-        if (validation_failed) {
-            throw std::runtime_error("found tensors with invalid data");
-        }
-
-        // check if this is the last call and do final cleanup
-        if (size_done >= size_data) {
-            // unmap offloaded tensors and metadata
-            if (use_mmap) {
-                for (uint32_t idx = 0; idx < mappings.size(); idx++) {
-                    const auto & mmap_used = mmaps_used.at(idx);
-                    auto & mapping = mappings.at(idx);
-                    mapping->unmap_fragment(0, mmap_used.first);
-                    if (mmap_used.second != 0) {
-                        mapping->unmap_fragment(mmap_used.second, mapping->size);
-                    }
-                }
-            }
-            if (progress_callback) {
-                // Even though the model is done loading, we still honor
-                // cancellation since we need to free allocations.
-                return progress_callback(1.0f, progress_callback_user_data);
-            }
-        }
-
-        return true;
-    }
-};
-
-template<>
-bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
-    uint32_t tmp;
-    const bool found = get_key(kid, tmp, required);
-    if (found) {
-        result = (enum llama_pooling_type) tmp;
-    } else {
-        result = LLAMA_POOLING_TYPE_UNSPECIFIED;
-    }
-    return found;
-}
-
-
-//
-// load LLaMA models
-//
-
-static const char * llama_model_arch_name(llm_arch arch) {
-    auto it = LLM_ARCH_NAMES.find(arch);
-    if (it == LLM_ARCH_NAMES.end()) {
-        return "unknown";
-    }
-    return it->second;
-}
-
-static std::string llama_model_ftype_name(llama_ftype ftype) {
-    if (ftype & LLAMA_FTYPE_GUESSED) {
-        return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)";
-    }
-
-    switch (ftype) {
-        case LLAMA_FTYPE_ALL_F32:         return "all F32";
-        case LLAMA_FTYPE_MOSTLY_F16:      return "F16";
-        case LLAMA_FTYPE_MOSTLY_BF16:     return "BF16";
-        case LLAMA_FTYPE_MOSTLY_Q4_0:     return "Q4_0";
-        case LLAMA_FTYPE_MOSTLY_Q4_1:     return "Q4_1";
-        case LLAMA_FTYPE_MOSTLY_Q5_0:     return "Q5_0";
-        case LLAMA_FTYPE_MOSTLY_Q5_1:     return "Q5_1";
-        case LLAMA_FTYPE_MOSTLY_Q8_0:     return "Q8_0";
-        case LLAMA_FTYPE_MOSTLY_Q2_K:     return "Q2_K - Medium";
-        case LLAMA_FTYPE_MOSTLY_Q2_K_S:   return "Q2_K - Small";
-        case LLAMA_FTYPE_MOSTLY_Q3_K_S:   return "Q3_K - Small";
-        case LLAMA_FTYPE_MOSTLY_Q3_K_M:   return "Q3_K - Medium";
-        case LLAMA_FTYPE_MOSTLY_Q3_K_L:   return "Q3_K - Large";
-        case LLAMA_FTYPE_MOSTLY_Q4_K_S:   return "Q4_K - Small";
-        case LLAMA_FTYPE_MOSTLY_Q4_K_M:   return "Q4_K - Medium";
-        case LLAMA_FTYPE_MOSTLY_Q5_K_S:   return "Q5_K - Small";
-        case LLAMA_FTYPE_MOSTLY_Q5_K_M:   return "Q5_K - Medium";
-        case LLAMA_FTYPE_MOSTLY_Q6_K:     return "Q6_K";
-        case LLAMA_FTYPE_MOSTLY_TQ1_0:    return "TQ1_0 - 1.69 bpw ternary";
-        case LLAMA_FTYPE_MOSTLY_TQ2_0:    return "TQ2_0 - 2.06 bpw ternary";
-        case LLAMA_FTYPE_MOSTLY_IQ2_XXS:  return "IQ2_XXS - 2.0625 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ2_XS:   return "IQ2_XS - 2.3125 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ2_S:    return "IQ2_S - 2.5 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ2_M:    return "IQ2_M - 2.7 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ3_XS:   return "IQ3_XS - 3.3 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ3_XXS:  return "IQ3_XXS - 3.0625 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ1_S:    return "IQ1_S - 1.5625 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ1_M:    return "IQ1_M - 1.75 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ4_NL:   return "IQ4_NL - 4.5 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ4_XS:   return "IQ4_XS - 4.25 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ3_S:    return "IQ3_S - 3.4375 bpw";
-        case LLAMA_FTYPE_MOSTLY_IQ3_M:    return "IQ3_S mix - 3.66 bpw";
-        case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4";
-        case LLAMA_FTYPE_MOSTLY_Q4_0_4_8: return "Q4_0_4_8";
-        case LLAMA_FTYPE_MOSTLY_Q4_0_8_8: return "Q4_0_8_8";
-
-        default: return "unknown, may not work";
-    }
-}
-
-static const char * llama_model_type_name(e_model type) {
-    switch (type) {
-        case MODEL_14M:           return "14M";
-        case MODEL_17M:           return "17M";
-        case MODEL_22M:           return "22M";
-        case MODEL_33M:           return "33M";
-        case MODEL_60M:           return "60M";
-        case MODEL_70M:           return "70M";
-        case MODEL_80M:           return "80M";
-        case MODEL_109M:          return "109M";
-        case MODEL_137M:          return "137M";
-        case MODEL_160M:          return "160M";
-        case MODEL_220M:          return "220M";
-        case MODEL_250M:          return "250M";
-        case MODEL_270M:          return "270M";
-        case MODEL_335M:          return "335M";
-        case MODEL_410M:          return "410M";
-        case MODEL_450M:          return "450M";
-        case MODEL_770M:          return "770M";
-        case MODEL_780M:          return "780M";
-        case MODEL_0_5B:          return "0.5B";
-        case MODEL_1B:            return "1B";
-        case MODEL_1_3B:          return "1.3B";
-        case MODEL_1_4B:          return "1.4B";
-        case MODEL_1_6B:          return "1.6B";
-        case MODEL_2B:            return "2B";
-        case MODEL_2_8B:          return "2.8B";
-        case MODEL_3B:            return "3B";
-        case MODEL_4B:            return "4B";
-        case MODEL_6B:            return "6B";
-        case MODEL_6_9B:          return "6.9B";
-        case MODEL_7B:            return "7B";
-        case MODEL_8B:            return "8B";
-        case MODEL_9B:            return "9B";
-        case MODEL_11B:           return "11B";
-        case MODEL_12B:           return "12B";
-        case MODEL_13B:           return "13B";
-        case MODEL_14B:           return "14B";
-        case MODEL_15B:           return "15B";
-        case MODEL_16B:           return "16B";
-        case MODEL_20B:           return "20B";
-        case MODEL_30B:           return "30B";
-        case MODEL_34B:           return "34B";
-        case MODEL_35B:           return "35B";
-        case MODEL_40B:           return "40B";
-        case MODEL_65B:           return "65B";
-        case MODEL_70B:           return "70B";
-        case MODEL_236B:          return "236B";
-        case MODEL_314B:          return "314B";
-        case MODEL_SMALL:         return "0.1B";
-        case MODEL_MEDIUM:        return "0.4B";
-        case MODEL_LARGE:         return "0.8B";
-        case MODEL_XL:            return "1.5B";
-        case MODEL_A2_7B:         return "A2.7B";
-        case MODEL_8x7B:          return "8x7B";
-        case MODEL_8x22B:         return "8x22B";
-        case MODEL_16x12B:        return "16x12B";
-        case MODEL_10B_128x3_66B: return "10B+128x3.66B";
-        case MODEL_57B_A14B:      return "57B.A14B";
-        case MODEL_27B:           return "27B";
-        default:                  return "?B";
-    }
-}
-
-static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
-    switch (type) {
-        case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
-        case LLAMA_VOCAB_TYPE_SPM:  return "SPM";
-        case LLAMA_VOCAB_TYPE_BPE:  return "BPE";
-        case LLAMA_VOCAB_TYPE_WPM:  return "WPM";
-        case LLAMA_VOCAB_TYPE_UGM:  return "UGM";
-        case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
-        default:                    return "unknown";
-    }
-}
-
-static void llm_load_arch(llama_model_loader & ml, llama_model & model) {
-    model.arch = ml.get_arch();
-    if (model.arch == LLM_ARCH_UNKNOWN) {
-        throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'");
-    }
-}
-
-static void llm_load_hparams(
-        llama_model_loader & ml,
-        llama_model & model) {
-    auto & hparams = model.hparams;
-    const gguf_context * ctx = ml.meta;
-
-    // get metadata as string
-    for (int i = 0; i < gguf_get_n_kv(ctx); i++) {
-        enum gguf_type type = gguf_get_kv_type(ctx, i);
-        if (type == GGUF_TYPE_ARRAY) {
-            continue;
-        }
-        const char * name = gguf_get_key(ctx, i);
-        const std::string value = gguf_kv_to_str(ctx, i);
-        model.gguf_kv.emplace(name, value);
-    }
-
-    // get general kv
-    ml.get_key(LLM_KV_GENERAL_NAME, model.name, false);
-
-    // get hparams kv
-    ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab);
-
-    // everything past this point is not vocab-related
-    if (hparams.vocab_only) {
-        return;
-    }
-
-    ml.get_key(LLM_KV_CONTEXT_LENGTH,    hparams.n_ctx_train);
-    ml.get_key(LLM_KV_EMBEDDING_LENGTH,  hparams.n_embd);
-    ml.get_key(LLM_KV_BLOCK_COUNT,       hparams.n_layer);
-    ml.get_key(LLM_KV_EXPERT_COUNT,      hparams.n_expert,      false);
-    ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
-
-    GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS);
-    GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
-    if (hparams.n_expert > 0) {
-        GGML_ASSERT(hparams.n_expert_used > 0);
-    } else {
-        GGML_ASSERT(hparams.n_expert_used == 0);
-    }
-
-    // zero-out the per-layer hparams
-    std::fill(hparams.n_head_arr.begin(),    hparams.n_head_arr.end(),    0);
-    std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
-    std::fill(hparams.n_ff_arr.begin(),      hparams.n_ff_arr.end(),      0);
-
-    ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff_arr,   hparams.n_layer);
-    ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer);
-
-    // n_head_kv is optional, default to n_head
-    hparams.n_head_kv_arr = hparams.n_head_arr;
-
-    ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, hparams.n_layer, false);
-
-    bool rope_finetuned = false;
-    ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false);
-    hparams.rope_finetuned = rope_finetuned;
-
-    hparams.n_ctx_orig_yarn = hparams.n_ctx_train;
-    ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn, false);
-
-    // rope_freq_base (optional)
-    hparams.rope_freq_base_train = 10000.0f;
-    ml.get_key(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train, false);
-
-    std::string rope_scaling("linear");
-    ml.get_key(LLM_KV_ROPE_SCALING_TYPE, rope_scaling, false);
-    hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling);
-    GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED);
-
-    // rope_freq_scale (inverse of the kv) is optional
-    float ropescale = 0.0f;
-    if (!ml.get_key(LLM_KV_ROPE_SCALING_FACTOR, ropescale, false)) {
-        // try the old key name
-        ml.get_key(LLM_KV_ROPE_SCALE_LINEAR, ropescale, false);
-    }
-    hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
-
-    ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);
-
-    // non-transformer models do not have attention heads
-    if (hparams.n_head() > 0) {
-        // gpt-neox n_rot = rotary_pct * (n_embd / n_head)
-        // gpt-j n_rot = rotary_dim
-
-        hparams.n_embd_head_k = hparams.n_embd / hparams.n_head();
-        ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
-
-        hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
-        ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
-
-        // sanity check for n_rot (optional)
-        hparams.n_rot = hparams.n_embd_head_k;
-
-        ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
-
-        if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
-            if (hparams.n_rot != hparams.n_embd_head_k) {
-                throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
-            }
-        }
-    } else {
-        hparams.n_rot = 0;
-        hparams.n_embd_head_k = 0;
-        hparams.n_embd_head_v = 0;
-    }
-
-    // arch-specific KVs
-    switch (model.arch) {
-        case LLM_ARCH_LLAMA:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                if (hparams.n_expert == 8) {
-                    switch (hparams.n_layer) {
-                        case 32: model.type = e_model::MODEL_8x7B; break;
-                        case 56: model.type = e_model::MODEL_8x22B; break;
-                        default: model.type = e_model::MODEL_UNKNOWN;
-                    }
-                } else {
-                    switch (hparams.n_layer) {
-                        case 22: model.type = e_model::MODEL_1B; break;
-                        case 26: model.type = e_model::MODEL_3B; break;
-                        // granite uses a vocab with len 49152
-                        case 32: model.type = hparams.n_vocab == 49152 ? e_model::MODEL_3B : (hparams.n_vocab < 40000 ? e_model::MODEL_7B : e_model::MODEL_8B); break;
-                        case 36: model.type = e_model::MODEL_8B; break; // granite
-                        case 40: model.type = e_model::MODEL_13B; break;
-                        case 48: model.type = e_model::MODEL_34B; break;
-                        case 60: model.type = e_model::MODEL_30B; break;
-                        case 80: model.type = hparams.n_head() == hparams.n_head_kv() ? e_model::MODEL_65B : e_model::MODEL_70B; break;
-                        default: model.type = e_model::MODEL_UNKNOWN;
-                    }
-                }
-            } break;
-        case LLM_ARCH_MINICPM:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                switch (hparams.n_layer) {
-                    case 40: model.type = e_model::MODEL_2B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_GROK:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                switch (hparams.n_layer) {
-                    case 64: model.type = e_model::MODEL_314B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_FALCON:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 60: model.type = e_model::MODEL_40B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_BAICHUAN:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 40: model.type = e_model::MODEL_13B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-
-                if (model.type == e_model::MODEL_13B) {
-                    // TODO: become GGUF KV parameter
-                    hparams.f_max_alibi_bias = 8.0f;
-                }
-            } break;
-        case LLM_ARCH_STARCODER:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                switch (hparams.n_layer) {
-                    case 24: model.type = e_model::MODEL_1B; break;
-                    case 36: model.type = e_model::MODEL_3B; break;
-                    case 42: model.type = e_model::MODEL_7B; break;
-                    case 40: model.type = e_model::MODEL_15B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_REFACT:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_1B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-
-                // TODO: become GGUF KV parameter
-                hparams.f_max_alibi_bias = 8.0f;
-            } break;
-        case LLM_ARCH_BERT:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
-                ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
-                ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
-                ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type, false);
-
-                switch (hparams.n_layer) {
-                    case 3:
-                        model.type = e_model::MODEL_17M; break; // bge-micro
-                    case 6:
-                        model.type = e_model::MODEL_22M; break; // MiniLM-L6
-                    case 12:
-                        switch (hparams.n_embd) {
-                            case 384: model.type = e_model::MODEL_33M; break; // MiniLM-L12, bge-small
-                            case 768: model.type = e_model::MODEL_109M; break; // bge-base
-                        } break;
-                    case 24:
-                        model.type = e_model::MODEL_335M; break; // bge-large
-                }
-            } break;
-        case LLM_ARCH_JINA_BERT_V2:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
-                ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
-                ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
-                ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type);
-                hparams.f_max_alibi_bias = 8.0f;
-
-                switch (hparams.n_layer) {
-                    case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small
-                    case 12: model.type = e_model::MODEL_137M; break; // jina-embeddings-base
-                }
-            } break;
-        case LLM_ARCH_NOMIC_BERT:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
-                ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
-                ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
-                ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type);
-
-                if (hparams.n_layer == 12 && hparams.n_embd == 768) {
-                    model.type = e_model::MODEL_137M;
-                }
-            } break;
-        case LLM_ARCH_BLOOM:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-
-                switch (hparams.n_layer) {
-                    case 24: model.type = e_model::MODEL_1B; break;
-                    case 30:
-                        switch (hparams.n_embd) {
-                            case 2560: model.type = e_model::MODEL_3B; break;
-                            case 4096: model.type = e_model::MODEL_7B; break;
-                        } break;
-                }
-
-                // TODO: become GGUF KV parameter
-                hparams.f_max_alibi_bias = 8.0f;
-            } break;
-        case LLM_ARCH_MPT:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,  hparams.f_norm_eps);
-                ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV,      hparams.f_clamp_kqv, false);
-                ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
-
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 48: model.type = e_model::MODEL_30B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_STABLELM:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-
-                switch (hparams.n_layer) {
-                    case 24: model.type = e_model::MODEL_1B; break;
-                    case 32: model.type = e_model::MODEL_3B; break;
-                    case 40: model.type = e_model::MODEL_12B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-               }
-            } break;
-        case LLM_ARCH_QWEN:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 40: model.type = e_model::MODEL_13B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_QWEN2:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                switch (hparams.n_layer) {
-                    case 24: model.type = hparams.n_embd == 1024 ? e_model::MODEL_0_5B : e_model::MODEL_1B; break;
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 40: model.type = hparams.n_head() == 20 ? e_model::MODEL_4B : e_model::MODEL_13B; break;
-                    case 80: model.type = e_model::MODEL_70B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_QWEN2MOE:
-            {
-                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
-                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
-
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                switch (hparams.n_layer) {
-                    case 24: model.type = e_model::MODEL_A2_7B; break;
-                    case 28: model.type = e_model::MODEL_57B_A14B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_PHI2:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-
-                switch (hparams.n_layer) {
-                    case 24: model.type = e_model::MODEL_1B; break;
-                    case 32: model.type = e_model::MODEL_3B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_PHI3:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                switch (hparams.n_layer) {
-                    case 24: model.type = e_model::MODEL_1B; break;
-                    case 32: model.type = e_model::MODEL_3B; break;
-                    case 40: model.type = e_model::MODEL_14B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-
-                // for backward compatibility ; see: https://github.com/ggerganov/llama.cpp/pull/8931
-                if ((hparams.n_layer == 32 || hparams.n_layer == 40) && hparams.n_ctx_train == 4096) {
-                    // default value for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct
-                    hparams.n_swa = 2047;
-                } else if (hparams.n_layer == 32 && hparams.n_head_kv(0) == 32 && hparams.n_ctx_train == 131072) {
-                    // default value for Phi-3-mini-128k-instruct
-                    hparams.n_swa = 262144;
-                } else if (hparams.n_layer == 40 && hparams.n_ctx_train == 131072) {
-                    // default value for Phi-3-medium-128k-instruct
-                    hparams.n_swa = 131072;
-                }
-                bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
-                if (!found_swa && hparams.n_swa == 0) {
-                    throw std::runtime_error("invalid value for sliding_window");
-                }
-            } break;
-        case LLM_ARCH_PLAMO:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                switch (hparams.n_layer) {
-                    case 40: model.type = e_model::MODEL_13B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-               }
-            } break;
-        case LLM_ARCH_GPT2:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                switch (hparams.n_layer) {
-                    case 12: model.type = e_model::MODEL_SMALL; break;
-                    case 24: model.type = e_model::MODEL_MEDIUM; break;
-                    case 36: model.type = e_model::MODEL_LARGE; break;
-                    case 48: model.type = e_model::MODEL_XL; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_CODESHELL:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                switch (hparams.n_layer) {
-                    case 42: model.type = e_model::MODEL_7B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_ORION:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-
-                switch (hparams.n_layer) {
-                    case 40: model.type = e_model::MODEL_14B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_INTERNLM2:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 48: model.type = e_model::MODEL_20B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_GEMMA:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                switch (hparams.n_layer) {
-                    case 18: model.type = e_model::MODEL_2B; break;
-                    case 28: model.type = e_model::MODEL_7B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-               }
-            } break;
-        case LLM_ARCH_GEMMA2:
-            {
-                hparams.n_swa = 4096; // default value of gemma 2
-                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
-                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
-                hparams.attn_soft_cap = true;
-
-                switch (hparams.n_layer) {
-                    case 26: model.type = e_model::MODEL_2B; break;
-                    case 42: model.type = e_model::MODEL_9B; break;
-                    case 46: model.type = e_model::MODEL_27B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-               }
-            } break;
-        case LLM_ARCH_STARCODER2:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                switch (hparams.n_layer) {
-                    case 30: model.type = e_model::MODEL_3B; break;
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 40: model.type = e_model::MODEL_15B; break;
-                    case 52: model.type = e_model::MODEL_20B; break; // granite
-                    case 88: model.type = e_model::MODEL_34B; break; // granite
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_MAMBA:
-            {
-                ml.get_key(LLM_KV_SSM_CONV_KERNEL,    hparams.ssm_d_conv);
-                ml.get_key(LLM_KV_SSM_INNER_SIZE,     hparams.ssm_d_inner);
-                ml.get_key(LLM_KV_SSM_STATE_SIZE,     hparams.ssm_d_state);
-                ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
-                ml.get_key(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms, false);
-
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                switch (hparams.n_layer) {
-                    case 24:
-                        switch (hparams.n_embd) {
-                            case 768: model.type = e_model::MODEL_SMALL; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    case 48:
-                        switch (hparams.n_embd) {
-                            case 1024: model.type = e_model::MODEL_MEDIUM; break;
-                            case 1536: model.type = e_model::MODEL_LARGE; break;
-                            case 2048: model.type = e_model::MODEL_XL; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    case 64:
-                        switch (hparams.n_embd) {
-                            case 2560: model.type = e_model::MODEL_3B; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_XVERSE:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 40: model.type = e_model::MODEL_13B; break;
-                    case 80: model.type = e_model::MODEL_65B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_COMMAND_R:
-            {
-                ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                switch (hparams.n_layer) {
-                    case 40: model.type = e_model::MODEL_35B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_DBRX:
-        {
-            ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,  hparams.f_norm_eps);
-            ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV,      hparams.f_clamp_kqv);
-
-            switch (hparams.n_layer) {
-                case 40: model.type = e_model::MODEL_16x12B; break;
-                default: model.type = e_model::MODEL_UNKNOWN;
-            }
-        } break;
-        case LLM_ARCH_OLMO:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV,     hparams.f_clamp_kqv, false);
-
-                switch (hparams.n_layer) {
-                    case 22: model.type = e_model::MODEL_1B; break;
-                    case 32: model.type = e_model::MODEL_7B; break;
-                    case 80: model.type = e_model::MODEL_70B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_OPENELM:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                switch (hparams.n_layer) {
-                case 16: model.type = e_model::MODEL_270M; break;
-                case 20: model.type = e_model::MODEL_450M; break;
-                case 28: model.type = e_model::MODEL_1B; break;
-                case 36: model.type = e_model::MODEL_3B; break;
-                default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_GPTNEOX:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res);
-                switch (hparams.n_layer) {
-                    case 6:
-                        switch (hparams.n_ff()) {
-                            case 512: model.type = e_model::MODEL_14M; break;
-                            case 2048: model.type = e_model::MODEL_70M; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    case 12:
-                        switch (hparams.n_ff()) {
-                            case 3072: model.type = e_model::MODEL_160M; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    case 16:
-                        switch (hparams.n_ff()) {
-                            case 8192: model.type = e_model::MODEL_1B; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    case 24:
-                        switch (hparams.n_ff()) {
-                            case 4096: model.type = e_model::MODEL_410M; break;
-                            case 8192: model.type = e_model::MODEL_1_4B; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    case 32:
-                        switch (hparams.n_ff()) {
-                            case 10240: model.type = e_model::MODEL_2_8B; break;
-                            case 16384: model.type = e_model::MODEL_6_9B; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    case 36:
-                        switch (hparams.n_ff()) {
-                            case 20480: model.type = e_model::MODEL_12B; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    case 44:
-                        switch (hparams.n_ff()) {
-                            case 24576: model.type = e_model::MODEL_20B; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_ARCTIC:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                if (hparams.n_expert == 128) {
-                    switch (hparams.n_layer) {
-                        case 35: model.type = e_model::MODEL_10B_128x3_66B; break;
-                        default: model.type = e_model::MODEL_UNKNOWN;
-                    }
-                } else {
-                    model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_DEEPSEEK2:
-            {
-                bool is_lite = (hparams.n_layer == 27);
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
-                if (!is_lite) {
-                    ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q);
-                }
-                ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
-                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
-                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
-                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
-                ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul);
-
-                switch (hparams.n_layer) {
-                    case 27: model.type = e_model::MODEL_16B; break;
-                    case 60: model.type = e_model::MODEL_236B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_CHATGLM:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                switch (hparams.n_layer) {
-                    case 28: model.type = e_model::MODEL_6B; break;
-                    case 40: model.type = e_model::MODEL_9B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_BITNET:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                switch (hparams.n_layer) {
-                    case 26: model.type = e_model::MODEL_3B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_T5:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts);
-
-                uint32_t dec_start_token_id;
-                if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, dec_start_token_id, false)) {
-                    hparams.dec_start_token_id = dec_start_token_id;
-                }
-
-                switch (hparams.n_layer) {
-                    case 6:  model.type = e_model::MODEL_60M;  break; // t5-small
-                    case 8:  model.type = e_model::MODEL_80M;  break; // flan-t5-small
-                    case 12:
-                        switch (hparams.n_ff()) {
-                            case 3072: model.type = e_model::MODEL_220M; break; // t5-base
-                            case 2048: model.type = e_model::MODEL_250M; break; // flan-t5-base
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    case 24:
-                        switch (hparams.n_ff()) {
-                            case 4096:  model.type = e_model::MODEL_770M; break; // t5-large
-                            case 2816:  model.type = e_model::MODEL_780M; break; // flan-t5-large
-                            case 16384: model.type = e_model::MODEL_3B;   break; // t5-3b
-                            case 5120:  model.type = e_model::MODEL_3B;   break; // flan-t5-xl
-                            case 65536: model.type = e_model::MODEL_11B;  break; // t5-11b
-                            case 10240: model.type = e_model::MODEL_11B;  break; // flan-t5-xxl
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-               }
-            } break;
-        case LLM_ARCH_T5ENCODER:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts);
-                model.type = e_model::MODEL_UNKNOWN;
-            } break;
-        case LLM_ARCH_JAIS:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
-
-                switch (hparams.n_layer) {
-                    case 24: model.type = e_model::MODEL_1_3B; break;
-                    case 40: model.type = e_model::MODEL_13B; break;
-                    /* TODO: add variants */
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_NEMOTRON:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_4B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_EXAONE:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-
-                switch (hparams.n_layer) {
-                    case 32: model.type = e_model::MODEL_8B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        case LLM_ARCH_RWKV6:
-            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
-                ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
-                ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim);
-                ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim);
-                ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false);
-
-                switch (hparams.n_layer) {
-                    case 24: model.type = e_model::MODEL_1_6B; break;
-                    case 32:
-                        switch (hparams.n_embd) {
-                            case 2560: model.type = e_model::MODEL_3B; break;
-                            case 4096: model.type = e_model::MODEL_7B; break;
-                            default: model.type = e_model::MODEL_UNKNOWN;
-                        } break;
-                    case 61: model.type = e_model::MODEL_14B; break;
-                    default: model.type = e_model::MODEL_UNKNOWN;
-                }
-            } break;
-        default: (void)0;
-    }
-
-    model.ftype = ml.ftype;
-
-    if (hparams.f_max_alibi_bias > 0.0f) {
-        hparams.use_alibi = true;
-    }
-
-    hparams.rope_type = llama_rope_type(&model);
-}
-
-static void llm_load_vocab(
-        llama_model_loader & ml,
-        llama_model & model) {
-    auto & vocab = model.vocab;
-
-    struct gguf_context * ctx = ml.meta;
-
-    const auto kv = LLM_KV(model.arch);
-
-    // determine vocab type
-    {
-        std::string tokenizer_model;
-        std::string tokenizer_pre;
-
-        ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model);
-        ml.get_key(LLM_KV_TOKENIZER_PRE,   tokenizer_pre, false);
-
-        if (tokenizer_model == "no_vocab") {
-            vocab.type = LLAMA_VOCAB_TYPE_NONE;
-
-            // default special tokens
-            vocab.special_bos_id  = -1;
-            vocab.special_eos_id  = -1;
-            vocab.special_unk_id  = -1;
-            vocab.special_sep_id  = -1;
-            vocab.special_pad_id  = -1;
-            vocab.special_cls_id  = -1;
-            vocab.special_mask_id = -1;
-            vocab.linefeed_id     = -1;
-
-            return;
-        } else if (tokenizer_model == "llama") {
-            vocab.type = LLAMA_VOCAB_TYPE_SPM;
-
-            // default special tokens
-            vocab.special_bos_id  = 1;
-            vocab.special_eos_id  = 2;
-            vocab.special_unk_id  = 0;
-            vocab.special_sep_id  = -1;
-            vocab.special_pad_id  = -1;
-            vocab.special_cls_id  = -1;
-            vocab.special_mask_id = -1;
-        } else if (tokenizer_model == "bert") {
-            vocab.type = LLAMA_VOCAB_TYPE_WPM;
-
-            // default special tokens
-            vocab.special_bos_id  = -1;
-            vocab.special_eos_id  = -1;
-            vocab.special_unk_id  = 100;
-            vocab.special_sep_id  = 102;
-            vocab.special_pad_id  = 0;
-            vocab.special_cls_id  = 101;
-            vocab.special_mask_id = 103;
-        } else if (tokenizer_model == "gpt2") {
-            vocab.type = LLAMA_VOCAB_TYPE_BPE;
-
-            // read bpe merges and populate bpe ranks
-            const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
-            if (merges_keyidx == -1) {
-                throw std::runtime_error("cannot find tokenizer merges in model file\n");
-            }
-
-            const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
-            for (int i = 0; i < n_merges; i++) {
-                const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
-                GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
-
-                std::string first;
-                std::string second;
-
-                const size_t pos = word.find(' ', 1);
-
-                if (pos != std::string::npos) {
-                    first  = word.substr(0, pos);
-                    second = word.substr(pos + 1);
-                }
-
-                vocab.bpe_ranks.emplace(std::make_pair(first, second), i);
-            }
-
-            // default special tokens
-            vocab.special_bos_id  = 11;
-            vocab.special_eos_id  = 11;
-            vocab.special_unk_id  = -1;
-            vocab.special_sep_id  = -1;
-            vocab.special_pad_id  = -1;
-            vocab.special_cls_id  = -1;
-            vocab.special_mask_id = -1;
-        } else if (tokenizer_model == "t5") {
-            vocab.type = LLAMA_VOCAB_TYPE_UGM;
-
-            // default special tokens
-            vocab.special_bos_id  = -1;
-            vocab.special_eos_id  = 1;
-            vocab.special_unk_id  = 2;
-            vocab.special_sep_id  = -1;
-            vocab.special_pad_id  = 0;
-            vocab.special_cls_id  = -1;
-            vocab.special_mask_id = -1;
-
-            const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
-            if (precompiled_charsmap_keyidx != -1) {
-                size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
-                const char * precompiled_charsmap = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
-                vocab.precompiled_charsmap.assign(precompiled_charsmap, precompiled_charsmap + n_precompiled_charsmap);
-#ifdef IS_BIG_ENDIAN
-                // correct endiannes of data in precompiled_charsmap binary blob
-                uint32_t * xcda_blob_size = (uint32_t *) &vocab.precompiled_charsmap[0];
-                *xcda_blob_size = __builtin_bswap32(*xcda_blob_size);
-                assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap);
-                size_t xcda_array_size = *xcda_blob_size / sizeof(uint32_t);
-                uint32_t * xcda_array = (uint32_t *) &vocab.precompiled_charsmap[sizeof(uint32_t)];
-                for (size_t i = 0; i < xcda_array_size; ++i) {
-                    xcda_array[i] = __builtin_bswap32(xcda_array[i]);
-                }
-#endif
-            }
-        } else if (tokenizer_model == "rwkv") {
-            vocab.type = LLAMA_VOCAB_TYPE_RWKV;
-
-            // default special tokens
-            vocab.special_bos_id = -1;
-            vocab.special_eos_id = -1;
-            vocab.special_unk_id = -1;
-            vocab.special_sep_id = -1;
-            vocab.special_pad_id = -1;
-        } else {
-            throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
-        }
-
-        // for now, only BPE models have pre-tokenizers
-        if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
-            vocab.tokenizer_add_space_prefix = false;
-            vocab.tokenizer_clean_spaces = true;
-            if (tokenizer_pre.empty()) {
-                LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
-                LLAMA_LOG_WARN("%s:                                             \n", __func__);
-                LLAMA_LOG_WARN("%s: ************************************        \n", __func__);
-                LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED!        \n", __func__);
-                LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL             \n", __func__);
-                LLAMA_LOG_WARN("%s: ************************************        \n", __func__);
-                LLAMA_LOG_WARN("%s:                                             \n", __func__);
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            } else if (tokenizer_pre == "default") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            } else if (
-                    tokenizer_pre == "llama3"   ||
-                    tokenizer_pre == "llama-v3" ||
-                    tokenizer_pre == "llama-bpe") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
-                vocab.tokenizer_ignore_merges = true;
-                vocab.tokenizer_add_bos = true;
-            } else if (
-                    tokenizer_pre == "deepseek-llm") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                    tokenizer_pre == "deepseek-coder") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                    tokenizer_pre == "falcon") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON;
-            } else if (
-                    tokenizer_pre == "mpt") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_MPT;
-            } else if (
-                    tokenizer_pre == "starcoder") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STARCODER;
-            } else if (
-                    tokenizer_pre == "gpt-2"   ||
-                    tokenizer_pre == "phi-2"   ||
-                    tokenizer_pre == "jina-es" ||
-                    tokenizer_pre == "jina-de" ||
-                    tokenizer_pre == "jina-v2-es" ||
-                    tokenizer_pre == "jina-v2-de" ||
-                    tokenizer_pre == "jina-v2-code") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT2;
-            } else if (
-                    tokenizer_pre == "refact") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_REFACT;
-            } else if (
-                tokenizer_pre == "command-r") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "qwen2") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "stablelm2") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STABLELM2;
-            } else if (
-                tokenizer_pre == "olmo") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_OLMO;
-            } else if (
-                tokenizer_pre == "dbrx") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DBRX;
-            } else if (
-                tokenizer_pre == "smaug-bpe") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG;
-            } else if (
-                tokenizer_pre == "poro-chat") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "chatglm-bpe") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
-                vocab.special_bos_id  = -1;
-            } else if (
-                tokenizer_pre == "viking") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "jais") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
-            } else if (
-                tokenizer_pre == "tekken") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TEKKEN;
-                vocab.tokenizer_clean_spaces = false;
-                vocab.tokenizer_ignore_merges = true;
-                vocab.tokenizer_add_bos = true;
-            } else if (
-                tokenizer_pre == "smollm") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMOLLM;
-                vocab.tokenizer_clean_spaces = false;
-            } else if (
-                tokenizer_pre == "codeshell") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL;
-            } else if (
-                tokenizer_pre == "bloom") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_BLOOM;
-            } else if (
-                tokenizer_pre == "gpt3-finnish") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH;
-            } else if (
-                tokenizer_pre == "exaone") {
-                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_EXAONE;
-            } else {
-                throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
-            }
-        } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            vocab.tokenizer_add_space_prefix = true;
-            vocab.tokenizer_clean_spaces = false;
-            vocab.tokenizer_add_bos = true;
-            vocab.tokenizer_add_eos = false;
-        } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            vocab.tokenizer_add_space_prefix = false;
-            vocab.tokenizer_clean_spaces = true;
-            vocab.tokenizer_add_bos = true;
-            vocab.tokenizer_add_eos = false;
-        } else if (vocab.type == LLAMA_VOCAB_TYPE_UGM) {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            vocab.tokenizer_add_bos = false;
-            vocab.tokenizer_add_eos = true;
-        } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            vocab.tokenizer_add_space_prefix = false;
-            vocab.tokenizer_clean_spaces = false;
-            vocab.tokenizer_add_bos = false;
-            vocab.tokenizer_add_eos = false;
-        } else {
-            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-        }
-
-        ml.get_key(LLM_KV_TOKENIZER_ADD_PREFIX,      vocab.tokenizer_add_space_prefix,         false);
-        ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.tokenizer_remove_extra_whitespaces, false);
-    }
-
-    const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
-    if (token_idx == -1) {
-        throw std::runtime_error("cannot find tokenizer vocab in model file\n");
-    }
-
-    const float * scores = nullptr;
-    const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
-    if (score_idx != -1) {
-        scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
-    }
-
-    const int * toktypes = nullptr;
-    const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
-    if (toktype_idx != -1) {
-        toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
-    }
-
-    const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
-
-    vocab.n_vocab = n_vocab;
-    vocab.id_to_token.resize(n_vocab);
-
-    for (uint32_t i = 0; i < n_vocab; i++) {
-        std::string word = gguf_get_arr_str(ctx, token_idx, i);
-        GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
-
-        vocab.token_to_id[word] = i;
-        vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size());
-
-        auto & token_data = vocab.id_to_token[i];
-        token_data.text  = std::move(word);
-        token_data.score = scores ? scores[i] : 0.0f;
-        token_data.attr  = LLAMA_TOKEN_ATTR_NORMAL;
-
-        if (toktypes) {  //TODO: remove, required until per token attributes are available from GGUF file
-            switch(toktypes[i]) {
-                case LLAMA_TOKEN_TYPE_UNKNOWN:      token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN;      break;
-                case LLAMA_TOKEN_TYPE_UNUSED:       token_data.attr = LLAMA_TOKEN_ATTR_UNUSED;       break;
-                case LLAMA_TOKEN_TYPE_NORMAL:       token_data.attr = LLAMA_TOKEN_ATTR_NORMAL;       break;
-                case LLAMA_TOKEN_TYPE_CONTROL:      token_data.attr = LLAMA_TOKEN_ATTR_CONTROL;      break;
-                case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break;
-                case LLAMA_TOKEN_TYPE_BYTE:         token_data.attr = LLAMA_TOKEN_ATTR_BYTE;         break;
-                case LLAMA_TOKEN_TYPE_UNDEFINED:    token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED;    break;
-                default:                            token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED;    break;
-            }
-        }
-    }
-    GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
-
-    // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
-    if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
-        // For Fill-In-the-Middle (FIM)/infill models which where converted
-        // prior to support of FIM special tokens in GGUF, the following
-        // will allow those models to continue to work. The general names
-        // of the known models are currently CodeLlama (LLM_ARCH_LLAMA) and
-        // CodeGemma (LLM_ARCH_GEMMA). This can potentially be removed once
-        // new versions of these models have been published.
-        std::string gen_name;
-        ml.get_key(LLM_KV_GENERAL_NAME, gen_name, false);
-
-        std::transform(gen_name.begin(), gen_name.end(), gen_name.begin(),
-            [](unsigned char c){ return std::tolower(c); });
-
-        if (gen_name.find("code") != std::string::npos) {
-            if (model.arch == LLM_ARCH_LLAMA
-              && 32010 < vocab.id_to_token.size()
-              && vocab.id_to_token[32007].text.find("
") != std::string::npos
-              && vocab.id_to_token[32008].text.find("") != std::string::npos
-              && vocab.id_to_token[32009].text.find("") != std::string::npos
-              && vocab.id_to_token[32010].text.find("") != std::string::npos) {
-                vocab.special_prefix_id = 32007;
-                vocab.special_suffix_id = 32008;
-                vocab.special_middle_id = 32009;
-                vocab.special_eot_id    = 32010;
-            } else if (model.arch == LLM_ARCH_GEMMA
-              && 107 < vocab.id_to_token.size()
-              && vocab.id_to_token[67].text == "<|fim_prefix|>"
-              && vocab.id_to_token[69].text == "<|fim_suffix|>"
-              && vocab.id_to_token[68].text == "<|fim_middle|>"
-              && vocab.id_to_token[107].text == "") {
-                vocab.special_prefix_id = 67;
-                vocab.special_suffix_id = 69;
-                vocab.special_middle_id = 68;
-                // TODO: this is not EOT, it is "file separator" token, needs fix
-                //       https://huggingface.co/google/codegemma-7b-it/blob/9b1d9231388358c04d90bd003458f5070d97db44/tokenizer_config.json#L565-L572
-                //vocab.special_eot_id    = 70;
-                vocab.special_eot_id    = 107;
-            }
-        }
-        try {
-            vocab.linefeed_id = llama_byte_to_token_impl(vocab, '\n');
-        } catch (const std::exception & e) {
-            LLAMA_LOG_WARN("%s: SPM vocabulary, but newline token not found: %s! Using special_pad_id instead.", __func__, e.what());
-            vocab.linefeed_id = vocab.special_pad_id;
-        }
-    } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
-        vocab.linefeed_id = vocab.special_pad_id;
-    } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) {
-        const std::vector ids = llama_tokenize_internal(vocab, "\n", false);
-        GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
-        vocab.linefeed_id = ids[0];
-    } else {
-        const std::vector ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A
-        GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
-        vocab.linefeed_id = ids[0];
-    }
-
-    // special tokens
-    {
-        const std::vector> special_token_types = {
-            { LLM_KV_TOKENIZER_BOS_ID,    vocab.special_bos_id    },
-            { LLM_KV_TOKENIZER_EOS_ID,    vocab.special_eos_id    },
-            { LLM_KV_TOKENIZER_UNK_ID,    vocab.special_unk_id    },
-            { LLM_KV_TOKENIZER_SEP_ID,    vocab.special_sep_id    },
-            { LLM_KV_TOKENIZER_PAD_ID,    vocab.special_pad_id    },
-            { LLM_KV_TOKENIZER_CLS_ID,    vocab.special_cls_id    },
-            { LLM_KV_TOKENIZER_MASK_ID,   vocab.special_mask_id   },
-            { LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_prefix_id },
-            { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id },
-            { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id },
-            { LLM_KV_TOKENIZER_EOT_ID,    vocab.special_eot_id    },
-            { LLM_KV_TOKENIZER_EOM_ID,    vocab.special_eom_id    },
-        };
-
-        for (const auto & it : special_token_types) {
-            const std::string & key = kv(std::get<0>(it));
-            int32_t & id = std::get<1>(it);
-
-            uint32_t new_id;
-            if (!ml.get_key(std::get<0>(it), new_id, false)) {
-                continue;
-            }
-            if (new_id >= vocab.id_to_token.size()) {
-                LLAMA_LOG_WARN("%s: bad special token: '%s' = %ud, using default id %d\n",
-                    __func__, key.c_str(), new_id, id);
-            } else {
-                id = new_id;
-            }
-        }
-
-        // Handle add_bos_token and add_eos_token
-        {
-            bool temp = true;
-
-            if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
-                vocab.tokenizer_add_bos = temp;
-            }
-            if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
-                vocab.tokenizer_add_eos = temp;
-            }
-        }
-
-        // find EOT token: "<|eot_id|>", "<|im_end|>", "", etc.
-        //
-        // TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOT_ID
-        //       for now, we apply this workaround to find the EOT token based on its text
-        if (vocab.special_eot_id == -1) {
-            for (const auto & t : vocab.token_to_id) {
-                if (
-                        // TODO: gemma "" is exported as a normal token, so the following check does not work
-                        //       need to fix convert script
-                        //vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
-                        (t.first == "<|eot_id|>" ||
-                         t.first == "<|im_end|>" ||
-                         t.first == "<|end|>" ||
-                         t.first == "" ||
-                         t.first == "<|endoftext|>"
-                        )
-                   ) {
-                    vocab.special_eot_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                            __func__, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                    break;
-                }
-            }
-        }
-
-        // find EOM token: "<|eom_id|>"
-        //
-        // TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOM_ID
-        //       for now, we apply this workaround to find the EOM token based on its text
-        if (vocab.special_eom_id == -1) {
-            const auto & t = vocab.token_to_id.find("<|eom_id|>");
-            if (t != vocab.token_to_id.end()) {
-                vocab.special_eom_id = t->second;
-                if ((vocab.id_to_token[t->second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                    LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                        __func__, t->first.c_str());
-                    vocab.id_to_token[t->second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                }
-            }
-        }
-    }
-
-    // build special tokens cache
-    {
-        for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
-            if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
-                vocab.cache_special_tokens.push_back(id);
-            }
-        }
-
-        std::sort(vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
-            [&] (const llama_vocab::id a, const llama_vocab::id b) {
-                return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
-            }
-        );
-
-        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t)vocab.cache_special_tokens.size());
-    }
-
-    // build token to piece cache
-    {
-        size_t size_cache = 0;
-
-        std::vector cache_token_to_piece(n_vocab);
-
-        for (uint32_t id = 0; id < n_vocab; ++id) {
-            cache_token_to_piece[id] = llama_token_to_piece(&model, id, true);
-
-            size_cache += cache_token_to_piece[id].size();
-        }
-
-        std::swap(vocab.cache_token_to_piece, cache_token_to_piece);
-
-        LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
-    }
-
-    // Handle per token attributes
-    //NOTE: Each model customizes per token attributes.
-    //NOTE: Per token attributes are missing from the GGUF file.
-    //TODO: Extract attributes from GGUF file.
-    {
-        auto _contains_any = [] (const std::string &str, const std::vector &substrs) -> bool {
-            for (auto substr : substrs) {
-                if (str.find(substr) < std::string::npos) {
-                    return true;
-                }
-            }
-            return false;
-        };
-
-        auto _set_tokenid_attr = [&] (const llama_vocab::id id, llama_token_attr attr, bool value) {
-            uint32_t current = vocab.id_to_token.at(id).attr;
-            current = value ? (current | attr) : (current & ~attr);
-            vocab.id_to_token[id].attr = (llama_token_attr) current;
-        };
-
-        auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
-            _set_tokenid_attr(vocab.token_to_id.at(token), attr, value);
-        };
-
-        std::string model_name;
-        std::string tokenizer_pre;
-
-        ml.get_key(LLM_KV_GENERAL_NAME, model_name, false);
-        ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
-
-        // model name to lowercase
-        std::transform(model_name.begin(), model_name.end(), model_name.begin(),
-            [] (const std::string::value_type x) {
-                return std::tolower(x);
-            }
-        );
-
-        // set attributes by model/tokenizer name
-        if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
-            _set_token_attr("", LLAMA_TOKEN_ATTR_LSTRIP, true);
-        } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
-            for (auto id : vocab.cache_special_tokens) {
-                _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
-            }
-            for (auto token : {""}) {
-                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
-            }
-            for (auto token : {"", "", "<|endoftext|>"}) {
-                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
-            }
-        }
-    }
-}
-
-static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
-    const auto & hparams = model.hparams;
-    const auto & vocab   = model.vocab;
-
-    const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
-
-    auto print_f = [](const std::function & f, uint32_t n) {
-        bool is_var = false;
-
-        std::vector v;
-        for (uint32_t i = 0; i < n; ++i) {
-            v.push_back(f(i));
-            if (v[i] != v[0]) {
-                is_var = true;
-            }
-        }
-
-        std::stringstream ss;
-
-        if (is_var) {
-            ss << "[";
-            for (uint32_t i = 0; i < n; ++i) {
-                ss << v[i];
-                if (i < n - 1) {
-                    ss << ", ";
-                }
-            }
-            ss << "]";
-        } else {
-            ss << v[0];
-        }
-
-        return ss.str();
-    };
-
-    // hparams
-    LLAMA_LOG_INFO("%s: format           = %s\n",     __func__, llama_file_version_name(ml.fver));
-    LLAMA_LOG_INFO("%s: arch             = %s\n",     __func__, LLM_ARCH_NAMES.at(model.arch));
-    LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, llama_model_vocab_type_name(vocab.type));
-    LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, hparams.n_vocab);
-    LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (int) vocab.bpe_ranks.size());
-    LLAMA_LOG_INFO("%s: vocab_only       = %d\n",     __func__, hparams.vocab_only);
-
-    if (!hparams.vocab_only) {
-        LLAMA_LOG_INFO("%s: n_ctx_train      = %u\n",     __func__, hparams.n_ctx_train);
-        LLAMA_LOG_INFO("%s: n_embd           = %u\n",     __func__, hparams.n_embd);
-        LLAMA_LOG_INFO("%s: n_layer          = %u\n",     __func__, hparams.n_layer);
-        LLAMA_LOG_INFO("%s: n_head           = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_head(il);    }, hparams.n_layer).c_str());
-        LLAMA_LOG_INFO("%s: n_head_kv        = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
-        LLAMA_LOG_INFO("%s: n_rot            = %u\n",     __func__, hparams.n_rot);
-        LLAMA_LOG_INFO("%s: n_swa            = %u\n",     __func__, hparams.n_swa);
-        LLAMA_LOG_INFO("%s: n_embd_head_k    = %u\n",     __func__, hparams.n_embd_head_k);
-        LLAMA_LOG_INFO("%s: n_embd_head_v    = %u\n",     __func__, hparams.n_embd_head_v);
-        LLAMA_LOG_INFO("%s: n_gqa            = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il);        }, hparams.n_layer).c_str());
-        LLAMA_LOG_INFO("%s: n_embd_k_gqa     = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str());
-        LLAMA_LOG_INFO("%s: n_embd_v_gqa     = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str());
-        LLAMA_LOG_INFO("%s: f_norm_eps       = %.1e\n",   __func__, hparams.f_norm_eps);
-        LLAMA_LOG_INFO("%s: f_norm_rms_eps   = %.1e\n",   __func__, hparams.f_norm_rms_eps);
-        LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
-        LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias);
-        LLAMA_LOG_INFO("%s: f_logit_scale    = %.1e\n",   __func__, hparams.f_logit_scale);
-        LLAMA_LOG_INFO("%s: n_ff             = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str());
-        LLAMA_LOG_INFO("%s: n_expert         = %u\n",     __func__, hparams.n_expert);
-        LLAMA_LOG_INFO("%s: n_expert_used    = %u\n",     __func__, hparams.n_expert_used);
-        LLAMA_LOG_INFO("%s: causal attn      = %d\n",     __func__, hparams.causal_attn);
-        LLAMA_LOG_INFO("%s: pooling type     = %d\n",     __func__, hparams.pooling_type);
-        LLAMA_LOG_INFO("%s: rope type        = %d\n",     __func__, hparams.rope_type);
-        LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type);
-        LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
-        LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
-        LLAMA_LOG_INFO("%s: n_ctx_orig_yarn  = %u\n",     __func__, hparams.n_ctx_orig_yarn);
-        LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n",     __func__, hparams.rope_finetuned ? "yes" : "unknown");
-        LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n",     __func__, hparams.ssm_d_conv);
-        LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
-        LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n",     __func__, hparams.ssm_d_state);
-        LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n",     __func__, hparams.ssm_dt_rank);
-        LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms   = %d\n",     __func__, hparams.ssm_dt_b_c_rms);
-    }
-
-    LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, llama_model_type_name(model.type));
-    LLAMA_LOG_INFO("%s: model ftype      = %s\n",     __func__, llama_model_ftype_name(model.ftype).c_str());
-    if (ml.n_elements >= 1e12) {
-        LLAMA_LOG_INFO("%s: model params     = %.2f T\n", __func__, ml.n_elements*1e-12);
-    } else if (ml.n_elements >= 1e9) {
-        LLAMA_LOG_INFO("%s: model params     = %.2f B\n", __func__, ml.n_elements*1e-9);
-    } else if (ml.n_elements >= 1e6) {
-        LLAMA_LOG_INFO("%s: model params     = %.2f M\n", __func__, ml.n_elements*1e-6);
-    } else {
-        LLAMA_LOG_INFO("%s: model params     = %.2f K\n", __func__, ml.n_elements*1e-3);
-    }
-    if (ml.n_bytes < GiB) {
-        LLAMA_LOG_INFO("%s: model size       = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0,        ml.n_bytes*8.0/ml.n_elements);
-    } else {
-        LLAMA_LOG_INFO("%s: model size       = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
-    }
-
-    // general kv
-    LLAMA_LOG_INFO("%s: general.name     = %s\n",    __func__, model.name.c_str());
-
-    // special tokens
-    if (vocab.special_bos_id    != -1) { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, vocab.special_bos_id,  vocab.id_to_token[vocab.special_bos_id].text.c_str() );  }
-    if (vocab.special_eos_id    != -1) { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, vocab.special_eos_id,  vocab.id_to_token[vocab.special_eos_id].text.c_str() );  }
-    if (vocab.special_unk_id    != -1) { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, vocab.special_unk_id,  vocab.id_to_token[vocab.special_unk_id].text.c_str() );  }
-    if (vocab.special_sep_id    != -1) { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, vocab.special_sep_id,  vocab.id_to_token[vocab.special_sep_id].text.c_str() );  }
-    if (vocab.special_pad_id    != -1) { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, vocab.special_pad_id,  vocab.id_to_token[vocab.special_pad_id].text.c_str() );  }
-    if (vocab.special_cls_id    != -1) { LLAMA_LOG_INFO( "%s: CLS token        = %d '%s'\n", __func__, vocab.special_cls_id,  vocab.id_to_token[vocab.special_cls_id].text.c_str() );  }
-    if (vocab.special_mask_id   != -1) { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
-
-    if (vocab.linefeed_id       != -1) { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, vocab.linefeed_id,       vocab.id_to_token[vocab.linefeed_id].text.c_str() );       }
-    if (vocab.special_prefix_id != -1) { LLAMA_LOG_INFO( "%s: PRE token        = %d '%s'\n", __func__, vocab.special_prefix_id, vocab.id_to_token[vocab.special_prefix_id].text.c_str() ); }
-    if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token        = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); }
-    if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token        = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
-    if (vocab.special_eot_id    != -1) { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, vocab.special_eot_id,    vocab.id_to_token[vocab.special_eot_id].text.c_str() );    }
-
-    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
-
-    if (model.arch == LLM_ARCH_DEEPSEEK2) {
-        LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",     __func__, hparams.n_layer_dense_lead);
-        LLAMA_LOG_INFO("%s: n_lora_q             = %d\n",     __func__, hparams.n_lora_q);
-        LLAMA_LOG_INFO("%s: n_lora_kv            = %d\n",     __func__, hparams.n_lora_kv);
-        LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n",     __func__, hparams.n_ff_exp);
-        LLAMA_LOG_INFO("%s: n_expert_shared      = %d\n",     __func__, hparams.n_expert_shared);
-        LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n",   __func__, hparams.expert_weights_scale);
-        LLAMA_LOG_INFO("%s: rope_yarn_log_mul    = %.4f\n",   __func__, hparams.rope_yarn_log_mul);
-    }
-
-    if (model.arch == LLM_ARCH_QWEN2MOE) {
-        LLAMA_LOG_INFO("%s: n_ff_exp         = %d\n",     __func__, hparams.n_ff_exp);
-        LLAMA_LOG_INFO("%s: n_ff_shexp       = %d\n",     __func__, hparams.n_ff_shexp);
-    }
-}
-
-// Returns false if cancelled by progress_callback
-static bool llm_load_tensors(
-        llama_model_loader & ml,
-        llama_model & model,
-        int n_gpu_layers,
-        enum llama_split_mode split_mode,
-        int main_gpu,
-        const float * tensor_split,
-        bool use_mlock,
-        llama_progress_callback progress_callback,
-        void * progress_callback_user_data) {
-    model.t_start_us = ggml_time_us();
-
-    auto & hparams = model.hparams;
-
-    model.split_mode   = split_mode;
-    model.main_gpu     = main_gpu;
-    model.n_gpu_layers = n_gpu_layers;
-
-    const int n_layer     = hparams.n_layer;
-    const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
-    bool use_mmap_buffer = true;
-
-    // there is very little benefit to offloading the input layer, so always keep it on the CPU
-    model.buft_input = llama_default_buffer_type_cpu(true);
-    //model.buft_input = llama_default_buffer_type_offload(main_gpu);
-
-    model.buft_layer.resize(n_layer);
-
-    // assign cpu layers
-    for (int i = 0; i < i_gpu_start; ++i) {
-        model.buft_layer[i] = llama_default_buffer_type_cpu(true);
-    }
-
-    if (split_mode == LLAMA_SPLIT_MODE_LAYER) {
-        // calculate the split points
-        int device_count = llama_get_device_count(model);
-        bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + device_count, [](float x) { return x == 0.0f; });
-        std::vector splits(device_count);
-        if (all_zero) {
-            // default split, by free memory
-            for (int i = 0; i < device_count; ++i) {
-                splits[i] = llama_get_device_memory(model, i);
-            }
-        } else {
-            std::copy(tensor_split, tensor_split + device_count, splits.begin());
-        }
-
-        // sum and normalize the splits to get the split points
-        float split_sum = 0.0f;
-        for (int i = 0; i < device_count; ++i) {
-            split_sum += splits[i];
-            splits[i] = split_sum;
-        }
-        for (int i = 0; i < device_count; ++i) {
-            splits[i] /= split_sum;
-        }
-
-        // assign the repeating layers to the devices according to the splits
-        int act_gpu_layers = std::min(n_gpu_layers, (int)n_layer + 1);
-        for (int i = i_gpu_start; i < n_layer; ++i) {
-            int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(i - i_gpu_start)/act_gpu_layers) - splits.begin();
-            model.buft_layer[i] = llama_default_buffer_type_offload(model, layer_gpu);
-        }
-        // assign the output layer
-        if (n_gpu_layers > n_layer) {
-            int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(act_gpu_layers - 1)/act_gpu_layers) - splits.begin();
-            model.buft_output = llama_default_buffer_type_offload(model, layer_gpu);
-        } else {
-            model.buft_output = llama_default_buffer_type_cpu(true);
-        }
-    } else {
-        ggml_backend_buffer_type_t split_buft;
-        if (split_mode == LLAMA_SPLIT_MODE_ROW) {
-            split_buft = llama_default_buffer_type_split(model, main_gpu, tensor_split);
-        } else {
-            // LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_LAYER in backends where it is not supported
-            split_buft = llama_default_buffer_type_offload(model, main_gpu);
-        }
-        // assign the repeating layers
-        for (int i = i_gpu_start; i < n_layer; ++i) {
-            model.buft_layer[i] = {
-                split_buft,
-                llama_default_buffer_type_offload(model, main_gpu)
-            };
-        }
-        // assign the output layer
-        if (n_gpu_layers > n_layer) {
-            model.buft_output = {
-                split_buft,
-                llama_default_buffer_type_offload(model, main_gpu)
-            };
-        } else {
-            model.buft_output = llama_default_buffer_type_cpu(true);
-        }
-    }
-
-    // count used buffer types
-    std::map buft_layer_count;
-    buft_layer_count[model.buft_input.buft]++;
-    buft_layer_count[model.buft_input.buft_matrix]++;
-    buft_layer_count[model.buft_output.buft]++;
-    buft_layer_count[model.buft_output.buft_matrix]++;
-    for (int i = 0; i < n_layer; ++i) {
-        buft_layer_count[model.buft_layer[i].buft]++;
-        buft_layer_count[model.buft_layer[i].buft_matrix]++;
-    }
-
-    // create one context per buffer type
-    size_t ctx_size = ggml_tensor_overhead()*(ml.n_tensors + 1); // +1 for models where tok_embd is duplicated as output
-
-    // for moe merged tensors
-    ctx_size += ggml_tensor_overhead()*n_layer*3;
-
-    std::map ctx_map;
-    for (auto & it : buft_layer_count) {
-        struct ggml_init_params params = {
-            /*.mem_size   =*/ ctx_size,
-            /*.mem_buffer =*/ NULL,
-            /*.no_alloc   =*/ true,
-        };
-        ggml_context * ctx = ggml_init(params);
-        if (!ctx) {
-            throw std::runtime_error(format("failed to create context"));
-        }
-        ctx_map[it.first] = ctx;
-        model.ctxs.push_back(ctx);
-    }
-
-    LLAMA_LOG_INFO("%s: ggml ctx size = %7.2f MiB\n", __func__, model.ctxs.size()*ctx_size/1024.0/1024.0);
-
-    // create tensors for the weights
-    {
-        // note: cast to int64_t since we will use these for the tensor dimensions
-        const int64_t n_head        = hparams.n_head();
-        const int64_t n_head_kv     = hparams.n_head_kv();
-        const int64_t n_embd        = hparams.n_embd;
-        const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
-        const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
-        const int64_t n_embd_head_k = hparams.n_embd_head_k;
-        const int64_t n_embd_head_v = hparams.n_embd_head_v;
-        const int64_t n_ff          = hparams.n_ff();
-        const int64_t n_embd_gqa    = n_embd_v_gqa;
-        const int64_t n_vocab       = hparams.n_vocab;
-        const int64_t n_vocab_type  = hparams.n_vocab_type;
-        const int64_t n_rot         = hparams.n_rot;
-        const int64_t n_expert      = hparams.n_expert;
-        const int64_t n_expert_used = hparams.n_expert_used;
-        const int64_t n_ctx_train   = hparams.n_ctx_train;
-
-        if (n_expert > 0 && hparams.n_expert_used == 0) {
-            throw std::runtime_error("model has expert layers but no expert layers are used");
-        }
-
-        ggml_context * ctx_input        = ctx_map.at(model.buft_input.buft);
-        ggml_context * ctx_output       = ctx_map.at(model.buft_output.buft);
-        ggml_context * ctx_output_split = ctx_map.at(model.buft_output.buft_matrix);
-
-        auto ctx_for_layer       = [&](int i) { return ctx_map.at(model.buft_layer[i].buft); };
-        auto ctx_for_layer_split = [&](int i) { return ctx_map.at(model.buft_layer[i].buft_matrix); };
-
-        model.layers.resize(n_layer);
-
-        const auto tn = LLM_TN(model.arch);
-        switch (model.arch) {
-            case LLM_ARCH_LLAMA:
-            case LLM_ARCH_REFACT:
-            case LLM_ARCH_MINICPM:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        // if output is NULL, init from the input tok embed
-                        if (model.output == NULL) {
-                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
-
-                        // optional bias tensors
-                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-
-                        layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-
-                        if (n_expert == 0) {
-                            layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                            layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                            layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-
-                            // optional MLP bias
-                            layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                            layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                            layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        } else {
-                            layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
-
-                            layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                            if (layer.ffn_gate_exps) {
-                                layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert});
-                                layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert});
-                            } else {
-                                // merge split expert into a single tensor for compatibility with older models
-                                // requires disabling mmap
-                                use_mmap_buffer = false;
-
-                                ggml_type type_gate = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, 0).c_str())->type;
-                                ggml_type type_down = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, 0).c_str())->type;
-                                ggml_type type_up   = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_UP_EXP,   "weight", i, 0).c_str())->type;
-
-                                layer.ffn_gate_exps = ggml_new_tensor_3d(ctx_split, type_gate, n_embd,   n_ff, n_expert);
-                                layer.ffn_down_exps = ggml_new_tensor_3d(ctx_split, type_down,   n_ff, n_embd, n_expert);
-                                layer.ffn_up_exps   = ggml_new_tensor_3d(ctx_split, type_up,   n_embd,   n_ff, n_expert);
-
-                                ggml_set_name(layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i).c_str());
-                                ggml_set_name(layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i).c_str());
-                                ggml_set_name(layer.ffn_up_exps,   tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i).c_str());
-
-                                for (uint32_t x = 0; x < n_expert; ++x) {
-                                    // the individual experts are loaded into a view of the merged tensor
-                                    ml.create_tensor_as_view(ctx_split, layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), { n_embd, n_ff }, layer.ffn_gate_exps->nb[2]*x);
-                                    ml.create_tensor_as_view(ctx_split, layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd }, layer.ffn_down_exps->nb[2]*x);
-                                    ml.create_tensor_as_view(ctx_split, layer.ffn_up_exps,   tn(LLM_TENSOR_FFN_UP_EXP,   "weight", i, x), { n_embd, n_ff }, layer.ffn_up_exps->nb[2]*x);
-                                }
-                            }
-                        }
-                    }
-                } break;
-            case LLM_ARCH_GROK:
-                {
-                    if (n_expert == 0) {
-                        throw std::runtime_error("Grok model cannot have zero experts");
-                    }
-
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        // if output is NULL, init from the input tok embed
-                        if (model.output == NULL) {
-                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-
-                        layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd});
-
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-
-                        layer.ffn_gate_inp  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert});
-                        layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        if (layer.ffn_gate_exps) {
-                            layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert});
-                            layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert});
-                        } else {
-                            // merge split expert into a single tensor for compatibility with older models
-                            // requires disabling mmap
-                            use_mmap_buffer = false;
-
-                            ggml_type type_gate = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, 0).c_str())->type;
-                            ggml_type type_down = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, 0).c_str())->type;
-                            ggml_type type_up   = ml.require_tensor_meta(tn(LLM_TENSOR_FFN_UP_EXP,   "weight", i, 0).c_str())->type;
-
-                            layer.ffn_gate_exps = ggml_new_tensor_3d(ctx_split, type_gate, n_embd,   n_ff, n_expert);
-                            layer.ffn_down_exps = ggml_new_tensor_3d(ctx_split, type_down,   n_ff, n_embd, n_expert);
-                            layer.ffn_up_exps   = ggml_new_tensor_3d(ctx_split, type_up,   n_embd,   n_ff, n_expert);
-
-                            ggml_set_name(layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i).c_str());
-                            ggml_set_name(layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i).c_str());
-                            ggml_set_name(layer.ffn_up_exps,   tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i).c_str());
-
-                            for (uint32_t x = 0; x < n_expert; ++x) {
-                                // the individual experts are loaded into a view of the merged tensor
-                                ml.create_tensor_as_view(ctx_split, layer.ffn_gate_exps, tn(LLM_TENSOR_FFN_GATE_EXP, "weight", i, x), { n_embd, n_ff }, layer.ffn_gate_exps->nb[2]*x);
-                                ml.create_tensor_as_view(ctx_split, layer.ffn_down_exps, tn(LLM_TENSOR_FFN_DOWN_EXP, "weight", i, x), { n_ff, n_embd }, layer.ffn_down_exps->nb[2]*x);
-                                ml.create_tensor_as_view(ctx_split, layer.ffn_up_exps,   tn(LLM_TENSOR_FFN_UP_EXP,   "weight", i, x), { n_embd, n_ff }, layer.ffn_up_exps->nb[2]*x);
-                            }
-                        }
-
-                        layer.layer_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
-                    }
-                } break;
-            case LLM_ARCH_DBRX:
-            {
-                if (n_expert == 0) {
-                    throw std::runtime_error("DBRX model cannot have zero experts");
-                }
-
-                model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                // output
-                {
-                    model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                    model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                }
-
-                for (int i = 0; i < n_layer; ++i) {
-                    ggml_context * ctx_layer = ctx_for_layer(i);
-                    ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                    auto & layer = model.layers[i];
-
-                    layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-
-                    layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                    layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-
-                    layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd});
-
-                    layer.ffn_gate_inp  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert});
-                    layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff,   n_expert});
-                    layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff,   n_embd, n_expert});
-                    layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff,   n_expert});
-                }
-            } break;
-            case LLM_ARCH_BAICHUAN:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-                    }
-                } break;
-            case LLM_ARCH_FALCON:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-
-                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        if (!model.output) {
-                            model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU
-                        }
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
-
-                        layer.attn_norm_2   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-                    }
-                } break;
-            case LLM_ARCH_STARCODER:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-                    model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train});
-
-                    // output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        if (!model.output) {
-                            // needs to be on GPU
-                            model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
-
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
-
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
-
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
-
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
-
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
-
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i),   {n_embd, n_ff});
-                        layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i),     {n_ff});
-                    }
-                } break;
-            case LLM_ARCH_BERT:
-            case LLM_ARCH_NOMIC_BERT:
-                {
-                    model.tok_embd     = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab});
-                    model.type_embd    = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type});
-
-                    if (model.arch == LLM_ARCH_BERT) {
-                        model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,    "weight"), {n_embd, n_ctx_train});
-                    }
-
-                    model.tok_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
-                    model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd});
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        if (model.arch == LLM_ARCH_BERT) {
-                            layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                            layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i),   {n_embd});
-
-                            layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                            layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i),   {n_embd_gqa});
-
-                            layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                            layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i),   {n_embd_gqa});
-                        } else {
-                            layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        }
-
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {n_embd, n_embd});
-
-                        layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd});
-                        layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd});
-
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,        "weight", i), {n_embd, n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN,      "weight", i), {n_ff, n_embd});
-
-                        if (model.arch == LLM_ARCH_BERT) {
-                            layer.bo         = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
-                            layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff});
-                            layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
-                        } else {
-                            layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
-                        }
-
-                        layer.layer_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
-                        layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i),   {n_embd});
-                    }
-                } break;
-            case LLM_ARCH_JINA_BERT_V2:
-                {
-                    model.tok_embd  = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}); // word_embeddings
-                    model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); // token_type_embeddings
-
-                    model.tok_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm
-                    model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}); //LayerNorm bias
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i]; // JinaBertLayer
-
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
-                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd});
-
-                        layer.attn_q_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
-                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias",   i), {n_embd_gqa});
-
-                        layer.attn_k_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
-                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias",   i), {n_embd_gqa});
-
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); //output_dens
-                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd}); //output_dens
-
-                        layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); //output_norm
-                        layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias",   i), {n_embd});
-
-                        layer.attn_norm_2   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
-
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd});
-
-                        layer.layer_out_norm   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
-                        layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias",   i), {n_embd});
-                    }
-                } break;
-            case LLM_ARCH_BLOOM:
-                {
-                    model.tok_embd   = ml.create_tensor(ctx_input,  tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab});
-                    model.tok_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
-                    model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd});
-
-                    // output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias",   i), {n_embd});
-
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias",   i), {n_embd + 2*n_embd_gqa});
-
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd});
-
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias",   i), {n_embd});
-
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd});
-
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias",   i), {n_ff});
-                    }
-                } break;
-            case LLM_ARCH_MPT:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-                    model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    // output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        if (!model.output) {
-                            model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU
-                        }
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.attn_q_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.attn_k_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        // AWQ ScaleActivation layer
-                        layer.ffn_act = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-                } break;
-            case LLM_ARCH_STABLELM:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm =   ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});
-
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-
-                        // optional bias tensors, present in Stable LM 2 1.6B
-                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        // optional q and k layernorms, present in StableLM 2 12B
-                        layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head},    llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        // optional FFN norm, not present in StableLM 2 12B which uses parallel residual
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-                    }
-                } break;
-            case LLM_ARCH_QWEN:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd*3});
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff/2});
-                    }
-                } break;
-            case LLM_ARCH_QWEN2:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        // if output is NULL, init from the input tok embed
-                        if (model.output == NULL) {
-                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-
-                        // optional bias tensors
-                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd});
-                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa});
-                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa});
-
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-                    }
-                } break;
-            case LLM_ARCH_QWEN2MOE:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-
-                        // optional bias tensors
-                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd});
-                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa});
-                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa});
-
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-
-                        layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
-
-                        GGML_ASSERT(n_expert      > 0);
-                        GGML_ASSERT(n_expert_used > 0);
-
-                        // MoE branch
-                        const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
-
-                        layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert});
-                        layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert});
-                        layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert});
-
-                        // Shared expert branch
-                        const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
-
-                        layer.ffn_gate_inp_shexp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd});
-                        layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {    n_embd, n_ff_shexp});
-                        layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp,     n_embd});
-                        layer.ffn_up_shexp   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {    n_embd, n_ff_shexp});
-                    }
-                } break;
-            case LLM_ARCH_PHI2:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                        model.output_b      = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT,      "bias"),   {n_vocab});
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
-
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        if (layer.wqkv == nullptr) {
-                            layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
-                            layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd});
-
-                            layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
-                            layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i),   {n_embd_gqa});
-
-                            layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
-                            layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i),   {n_embd_gqa});
-                        }
-
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
-
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
-
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
-                    }
-                } break;
-            case LLM_ARCH_PHI3:
-                {
-                    const int64_t n_embd_head = n_embd / n_head;
-
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab });
-
-                    // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd });
-                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab });
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd });
-
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd });
-
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd });
-
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd });
-                        layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff });
-
-                        layer.rope_long  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight"), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                    }
-                } break;
-            case LLM_ARCH_PLAMO:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-                    }
-                } break;
-            case LLM_ARCH_GPT2:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-                    model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train});
-
-                    // output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd});
-
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
-
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
-
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
-
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
-
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
-                    }
-                } break;
-            case LLM_ARCH_CODESHELL:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
-
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
-
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
-
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
-
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
-
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i),   {n_embd, n_ff});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i),     {n_ff});
-                    }
-                } break;
-            case LLM_ARCH_ORION:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
-
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
-
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-                    }
-                } break;
-            case LLM_ARCH_INTERNLM2:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        // layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-                    }
-                } break;
-            case LLM_ARCH_GEMMA:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                    model.output      = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
-
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                    }
-                } break;
-            case LLM_ARCH_GEMMA2:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                    model.output      = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
-                        layer.attn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd});
-
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd});
-                    }
-                } break;
-            case LLM_ARCH_STARCODER2:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-
-                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        // if output is NULL, init from the input tok embed
-                        if (model.output == NULL) {
-                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
-
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
-
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-
-                        // optional bias tensors
-                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd});
-                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa});
-                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa});
-                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
-
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
-
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-
-                        // optional bias tensors
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP ,  "bias", i), {  n_ff});
-                    }
-                } break;
-            case LLM_ARCH_MAMBA:
-                {
-                    const int64_t d_conv  = hparams.ssm_d_conv;
-                    const int64_t d_inner = hparams.ssm_d_inner;
-                    const int64_t d_state = hparams.ssm_d_state;
-                    const int64_t dt_rank = hparams.ssm_dt_rank;
-
-                    // only an expansion factor of 2 is supported for now
-                    GGML_ASSERT(2 * n_embd == d_inner);
-
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-
-                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        // if output is NULL, init from the input tok embed, duplicated to allow offloading
-                        if (model.output == NULL) {
-                            model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        // norm
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-
-                        layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner});
-
-                        layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner});
-                        layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner});
-
-                        layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state});
-
-                        layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner});
-                        layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner});
-
-                        // no "weight" suffix for these
-                        layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner});
-                        layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner});
-
-                        // out_proj
-                        layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
-                    }
-                } break;
-            case LLM_ARCH_XVERSE:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-                    }
-                } break;
-            case LLM_ARCH_COMMAND_R:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        // init output from the input tok embed
-                        model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-
-                        if (n_layer >= 64){
-                            layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head});
-                            layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv});
-                        }
-
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-                    }
-                } break;
-            case LLM_ARCH_OLMO:  // adapted from LLM_ARCH_LLAMA with norm params removed
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        // if output is NULL, init from the input tok embed
-                        if (model.output == NULL) {
-                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-                    }
-                } break;
-            case LLM_ARCH_OPENELM:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        // init output from the input tok embed
-                        model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        const int64_t n_head      =   hparams.n_head(i);
-                        const int64_t n_head_qkv  = 2*hparams.n_head_kv(i) + n_head;
-                        const int64_t n_ff        =   hparams.n_ff(i);
-
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k});
-                        layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k});
-                        layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd});
-
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
-                    }
-                } break;
-            case LLM_ARCH_GPTNEOX:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
-
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
-
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
-
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
-
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
-
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
-                    }
-                } break;
-            case LLM_ARCH_ARCTIC:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        // if output is NULL, init from the input tok embed
-                        if (model.output == NULL) {
-                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_embd});
-
-                        layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
-                        layer.ffn_norm_exps = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd});
-                        layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff, n_expert}, false);
-                        layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert});
-                        layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert});
-                    }
-                } break;
-            case LLM_ARCH_DEEPSEEK2:
-                {
-                    const bool is_lite = (hparams.n_layer == 27);
-
-                    const int64_t n_embd_head_qk_rope = hparams.n_rot;
-                    const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
-
-                    const int64_t q_lora_rank  = hparams.n_lora_q;
-                    const int64_t kv_lora_rank = hparams.n_lora_kv;
-
-                    const int64_t n_ff_exp        = hparams.n_ff_exp;
-                    const int64_t n_expert_shared = hparams.n_expert_shared;
-
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        if (!is_lite) {
-                            layer.attn_q_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank});
-                        }
-
-                        layer.attn_kv_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank});
-
-                        if (!is_lite) {
-                            layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank});
-                            layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k});
-                        } else {
-                            layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa});
-                        }
-
-                        layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)});
-                        layer.wkv_b     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)});
-                        layer.wo        = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {              n_head * (                      n_embd_head_v), n_embd});
-
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-
-                        if (i < (int) hparams.n_layer_dense_lead) {
-                            layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                            layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                            layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-                        } else {
-                            layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
-
-                            GGML_ASSERT(n_expert      > 0);
-                            GGML_ASSERT(n_expert_used > 0);
-
-                            // MoE branch
-                            layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert});
-                            layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert});
-                            layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert});
-
-                            // Shared expert branch
-                            layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared});
-                            layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {        n_ff_exp * n_expert_shared, n_embd});
-                            layer.ffn_up_shexp   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared});
-                        }
-                    }
-                } break;
-            case LLM_ARCH_BITNET:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm     = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,     "weight", i), {n_embd});
-                        layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd});
-
-                        layer.wq       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wq_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.wk       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wk_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.wv       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.wo       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.wo_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_norm     = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM,     "weight", i), {n_embd});
-                        layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff});
-
-                        layer.ffn_gate       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
-                        layer.ffn_gate_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_up         = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
-                        layer.ffn_up_scale   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-                } break;
-            case LLM_ARCH_T5:
-                {
-                    const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
-
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm_enc = ml.create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm     = ml.create_tensor(ctx_output, tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd});
-
-                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        // if output is NULL, init from the input tok embed
-                        if (model.output == NULL) {
-                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm_enc  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd});
-                        layer.attn_rel_b_enc = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wq_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wk_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wv_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                        layer.wo_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
-
-                        layer.ffn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_gate_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up_enc   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff});
-
-                        layer.attn_norm  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_ATTN_NORM,  "weight", i), {n_embd});
-                        layer.attn_rel_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
-
-                        layer.attn_norm_cross  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM,  "weight", i), {n_embd});
-                        // this tensor seems to be unused in HF transformers implementation
-                        layer.attn_rel_b_cross = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wq_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wk_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wv_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                        layer.wo_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
-
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_gate = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_UP,   "weight", i), {n_embd,   n_ff});
-                    }
-                } break;
-            case LLM_ARCH_T5ENCODER:
-                {
-                    const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
-
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm_enc = ml.create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        // if output is NULL, init from the input tok embed
-                        if (model.output == NULL) {
-                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm_enc  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd});
-                        layer.attn_rel_b_enc = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wq_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wk_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wv_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                        layer.wo_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
-
-                        layer.ffn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_gate_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up_enc   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff});
-                    }
-                } break;
-            case LLM_ARCH_JAIS:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // Output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd});
-
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
-
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
-
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
-
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
-
-                        layer.ffn_gate   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE,   "weight", i), {n_embd, n_ff});
-                        layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE,   "bias", i),   {n_ff});
-
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
-                    }
-                } break;
-            case LLM_ARCH_CHATGLM:
-                {
-                    model.tok_embd   = ml.create_tensor(ctx_input,  tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-
-                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
-
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-
-                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff * 2});
-
-                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                    }
-                } break;
-            case LLM_ARCH_NEMOTRON:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm   = ml.create_tensor(ctx_output,   tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
-                        model.output        = ml.create_tensor(ctx_output_split,  tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});
-
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-
-                        // optional bias tensors
-                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd});
-
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-
-                        // optional MLP bias
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-                } break;
-            case LLM_ARCH_EXAONE:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // output
-                    {
-                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-                        ggml_context * ctx_split = ctx_for_layer_split(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
-
-                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
-                    }
-                } break;
-            case LLM_ARCH_RWKV6:
-                {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
-                    // Block 0, LN0
-                    model.tok_norm = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
-                    model.tok_norm_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd});
-
-                    // output
-                    model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                    model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
-                    model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
-
-                    const int time_mix_extra_dim = hparams.time_mix_extra_dim;
-                    const int time_decay_extra_dim = hparams.time_decay_extra_dim;
-                    const int head_size = hparams.wkv_head_size;
-                    const int attn_hidden_size = n_embd;
-                    const int ffn_size = hparams.n_ff_arr[0];
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        ggml_context * ctx_layer = ctx_for_layer(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
-
-                        layer.attn_norm_2   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd});
-                        layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd});
-
-                        layer.time_mix_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5});
-                        layer.time_mix_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5});
-
-                        layer.time_mix_lerp_x = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1});
-                        layer.time_mix_lerp_w = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1});
-                        layer.time_mix_lerp_k = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1});
-                        layer.time_mix_lerp_v = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1});
-                        layer.time_mix_lerp_r = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1});
-                        layer.time_mix_lerp_g = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1});
-
-                        layer.time_mix_first = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size});
-                        layer.time_mix_decay = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd});
-                        layer.time_mix_decay_w1 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim});
-                        layer.time_mix_decay_w2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size});
-                        layer.time_mix_key = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd});
-                        layer.time_mix_value = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd});
-                        layer.time_mix_receptance = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd});
-                        layer.time_mix_gate = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd});
-
-                        layer.time_mix_ln = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd});
-                        layer.time_mix_ln_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd});
-                        layer.time_mix_output = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size});
-
-                        layer.channel_mix_lerp_k = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1});
-                        layer.channel_mix_lerp_r = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1});
-
-                        layer.channel_mix_key = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size});
-                        layer.channel_mix_value = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd});
-                        layer.channel_mix_receptance = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd});
-                    }
-
-                } break;
-            default:
-                throw std::runtime_error("unknown architecture");
-        }
-    }
-
-    ml.done_getting_tensors();
-
-    ml.init_mappings(true, use_mlock ? &model.mlock_mmaps : nullptr);
-    model.mappings.reserve(ml.mappings.size());
-
-    // create the backend buffers
-    std::vector> ctx_bufs;
-    ctx_bufs.reserve(ctx_map.size());
-
-    // Ensure we have enough capacity for the maximum backend buffer we will potentially create
-    size_t n_max_backend_buffer = ctx_map.size() * ml.files.size();
-    model.bufs.reserve(n_max_backend_buffer);
-
-    for (auto & it : ctx_map) {
-        ggml_backend_buffer_type_t buft = it.first;
-        ggml_context * ctx              = it.second;
-
-        llama_buf_map bufs;
-        bufs.reserve(n_max_backend_buffer);
-
-        // only the mmap region containing the tensors in the model is mapped to the backend buffer
-        // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers
-        // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size
-        if (ml.use_mmap && use_mmap_buffer && buft == llama_default_buffer_type_cpu(true)) {
-            for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
-                void * addr = nullptr;
-                size_t first, last;
-                ml.get_mapping_range(&first, &last, &addr, idx, ctx);
-                if (first >= last) {
-                    continue;
-                }
-                ggml_backend_buffer_t buf = ggml_backend_cpu_buffer_from_ptr((char *) addr + first, last - first);
-                if (buf == nullptr) {
-                    throw std::runtime_error("unable to allocate backend CPU buffer");
-                }
-                model.bufs.push_back(buf);
-                bufs.emplace(idx, buf);
-#ifdef GGML_USE_CUDA
-                if (n_layer >= n_gpu_layers) {
-                    ggml_backend_cuda_register_host_buffer(
-                        ggml_backend_buffer_get_base(buf),
-                        ggml_backend_buffer_get_size(buf));
-                }
-#endif
-            }
-        }
-#ifdef GGML_USE_METAL
-        else if (ml.use_mmap && use_mmap_buffer && buft == ggml_backend_metal_buffer_type()) {
-            for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
-                const size_t max_size = ggml_get_max_tensor_size(ctx);
-                void * addr = nullptr;
-                size_t first, last;
-                ml.get_mapping_range(&first, &last, &addr, idx, ctx);
-                if (first >= last) {
-                    continue;
-                }
-                ggml_backend_buffer_t buf = ggml_backend_metal_buffer_from_ptr((char *) addr + first, last - first, max_size);
-                if (buf == nullptr) {
-                    throw std::runtime_error("unable to allocate backend metal buffer");
-                }
-                model.bufs.push_back(buf);
-                bufs.emplace(idx, buf);
-            }
-        }
-#endif
-        else {
-            ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
-            if (buf == nullptr) {
-                throw std::runtime_error("unable to allocate backend buffer");
-            }
-            model.bufs.push_back(buf);
-            if (use_mlock && ggml_backend_buffer_is_host(buf)) {
-                model.mlock_bufs.emplace_back(new llama_mlock);
-                auto & mlock_buf = model.mlock_bufs.back();
-                mlock_buf->init   (ggml_backend_buffer_get_base(buf));
-                mlock_buf->grow_to(ggml_backend_buffer_get_size(buf));
-            }
-            for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
-                bufs.emplace(idx, buf);
-            }
-        }
-
-        if (bufs.empty()) {
-            throw std::runtime_error("failed to allocate buffer");
-        }
-
-        for (auto & buf : bufs) {
-            // indicate that this buffer contains weights
-            // this is used by ggml_backend_sched to improve op scheduling -> ops that use a weight are preferably scheduled to the backend that contains the weight
-            ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
-        }
-
-        ctx_bufs.emplace_back(ctx, bufs);
-    }
-
-    if (llama_supports_gpu_offload()) {
-        const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
-
-        LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu);
-        if (n_gpu_layers > (int) hparams.n_layer) {
-            LLAMA_LOG_INFO("%s: offloading non-repeating layers to GPU\n", __func__);
-        }
-
-        const int max_backend_supported_layers = hparams.n_layer + 1;
-        const int max_offloadable_layers       = hparams.n_layer + 1;
-
-        LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
-    }
-
-    // print memory requirements
-    for (ggml_backend_buffer_t buf : model.bufs) {
-        LLAMA_LOG_INFO("%s: %10s buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0);
-    }
-
-    // populate tensors_by_name
-    for (ggml_context * ctx : model.ctxs) {
-        for (auto * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) {
-            model.tensors_by_name.emplace_back(ggml_get_name(cur), cur);
-        }
-    }
-
-    // load tensor data
-    for (auto & it : ctx_bufs) {
-        ggml_context * ctx = it.first;
-        auto & bufs = it.second;
-        if (!ml.load_all_data(ctx, bufs, use_mlock ? &model.mlock_mmaps : NULL, progress_callback, progress_callback_user_data)) {
-            return false;
-        }
-    }
-
-    if (use_mmap_buffer) {
-        for (auto & mapping : ml.mappings) {
-            model.mappings.emplace_back(std::move(mapping));
-        }
-    }
-
-    // loading time will be recalculate after the first eval, so
-    // we take page faults deferred by mmap() into consideration
-    model.t_load_us = ggml_time_us() - model.t_start_us;
-    return true;
-}
-
 // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
-static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
+static int llama_model_load(const std::string & fname, std::vector & splits, llama_model & model, llama_model_params & params) {
+    // loading time will be recalculated after the first eval, so
+    // we take page faults deferred by mmap() into consideration
+    model.t_load_us = 0;
+    time_meas tm(model.t_load_us);
+
+    model.t_start_us = tm.t_start_us;
+
     try {
-        llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);
+        llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
+
+        ml.print_info();
 
         model.hparams.vocab_only = params.vocab_only;
 
         try {
-            llm_load_arch(ml, model);
+            model.load_arch(ml);
         } catch(const std::exception & e) {
             throw std::runtime_error("error loading model architecture: " + std::string(e.what()));
         }
         try {
-            llm_load_hparams(ml, model);
+            model.load_hparams(ml);
         } catch(const std::exception & e) {
             throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
         }
         try {
-            llm_load_vocab(ml, model);
+            model.load_vocab(ml);
         } catch(const std::exception & e) {
             throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
         }
 
-        llm_load_print_meta(ml, model);
-
-        if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
-            model.hparams.n_vocab != model.vocab.id_to_token.size()) {
-            throw std::runtime_error("vocab size mismatch");
-        }
+        model.load_stats(ml);
+        model.print_info();
 
         if (params.vocab_only) {
             LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
             return 0;
         }
 
-#ifdef GGML_USE_KOMPUTE
-        if (params.n_gpu_layers > 0 && (
-            !(model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON)
-            || !(
-                model.ftype == LLAMA_FTYPE_ALL_F32 ||
-                model.ftype == LLAMA_FTYPE_MOSTLY_F16 ||
-                model.ftype == LLAMA_FTYPE_MOSTLY_BF16 ||
-                model.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ||
-                model.ftype == LLAMA_FTYPE_MOSTLY_Q4_1
-            )
-        )) {
-            // TODO(cebtenzzre): propagate this error outside of llama_load_model_from_file
-            LLAMA_LOG_WARN("%s: disabling Kompute due to unsupported model arch or quantization\n", __func__);
-            params.n_gpu_layers = 0;
-        }
-#endif
-
-        if (!llm_load_tensors(
-            ml, model, params.n_gpu_layers, params.split_mode,  params.main_gpu, params.tensor_split, params.use_mlock,
-            params.progress_callback, params.progress_callback_user_data
-        )) {
+        if (!model.load_tensors(ml)) {
             return -2;
         }
     } catch (const std::exception & err) {
@@ -8683,31 +103,52 @@ enum llm_ffn_gate_type {
 enum llm_norm_type {
     LLM_NORM,
     LLM_NORM_RMS,
+    LLM_NORM_GROUP,
 };
 
 static struct ggml_tensor * llm_build_inp_embd(
         struct ggml_context * ctx,
        struct llama_context & lctx,
         const llama_hparams & hparams,
-         const llama_ubatch & batch,
+         const llama_ubatch & ubatch,
          struct ggml_tensor * tok_embd,
          const llm_build_cb & cb) {
     const int64_t n_embd = hparams.n_embd;
 
     struct ggml_tensor * inpL;
 
-    if (batch.token) {
-        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
+    if (ubatch.token) {
+        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens);
         cb(lctx.inp_tokens, "inp_tokens", -1);
         ggml_set_input(lctx.inp_tokens);
 
         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
+
+        // apply lora for embedding tokens if needed
+        for (auto & it : lctx.lora) {
+            struct llama_adapter_lora_weight * lw = it.first->get_weight(tok_embd);
+            if (lw == nullptr) {
+                continue;
+            }
+            const float adapter_scale = it.second;
+            const float scale = lw->get_scale(it.first->alpha, adapter_scale);
+            struct ggml_tensor * inpL_delta = ggml_scale(ctx, ggml_mul_mat(
+                ctx, lw->b, // non-transposed lora_b
+                ggml_get_rows(ctx, lw->a, lctx.inp_tokens)
+            ), scale);
+            inpL = ggml_add(ctx, inpL, inpL_delta);
+        }
     } else {
-       lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
+        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
         inpL = lctx.inp_embd;
         ggml_set_input(lctx.inp_embd);
     }
 
+    // For Granite architecture
+    if (hparams.f_embedding_scale != 0.0f) {
+        inpL = ggml_scale(ctx, inpL, hparams.f_embedding_scale);
+    }
+
     cb(inpL, "inp_embd", -1);
 
     return inpL;
@@ -8764,17 +205,16 @@ static struct ggml_tensor * llm_build_lora_mm(
           struct ggml_tensor * w,
           struct ggml_tensor * cur) {
     struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
-    for (auto & it : lctx.lora_adapters) {
-        struct llama_lora_weight * lora = it.first->get_weight(w);
-        if (lora == nullptr) {
+    for (auto & it : lctx.lora) {
+        struct llama_adapter_lora_weight * lw = it.first->get_weight(w);
+        if (lw == nullptr) {
             continue;
         }
-        const float alpha = it.first->alpha;
-        const float rank  = (float) lora->b->ne[0];
-        const float scale = alpha ? it.second * alpha / rank : it.second;
+        const float adapter_scale = it.second;
+        const float scale = lw->get_scale(it.first->alpha, adapter_scale);
         struct ggml_tensor * ab_cur = ggml_mul_mat(
-            ctx0, lora->b,
-            ggml_mul_mat(ctx0, lora->a, cur)
+            ctx0, lw->b,
+            ggml_mul_mat(ctx0, lw->a, cur)
         );
         ab_cur = ggml_scale(ctx0, ab_cur, scale);
         res = ggml_add(ctx0, res, ab_cur);
@@ -8790,17 +230,17 @@ static struct ggml_tensor * llm_build_lora_mm_id(
           struct ggml_tensor * cur, // struct ggml_tensor * b
           struct ggml_tensor * ids) {
     struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
-    for (auto & it : lctx.lora_adapters) {
-        struct llama_lora_weight * lora = it.first->get_weight(w);
-        if (lora == nullptr) {
+    for (auto & it : lctx.lora) {
+        struct llama_adapter_lora_weight * lw = it.first->get_weight(w);
+        if (lw == nullptr) {
             continue;
         }
         const float alpha = it.first->alpha;
-        const float rank  = (float) lora->b->ne[0];
+        const float rank  = (float) lw->b->ne[0];
         const float scale = alpha ? it.second * alpha / rank : it.second;
         struct ggml_tensor * ab_cur = ggml_mul_mat_id(
-            ctx0, lora->b,
-            ggml_mul_mat_id(ctx0, lora->a, cur, ids),
+            ctx0, lw->b,
+            ggml_mul_mat_id(ctx0, lw->a, cur, ids),
             ids
         );
         ab_cur = ggml_scale(ctx0, ab_cur, scale);
@@ -8819,8 +259,14 @@ static struct ggml_tensor * llm_build_norm(
          const llm_build_cb & cb,
                         int   il) {
     switch (type) {
-        case LLM_NORM:     cur = ggml_norm    (ctx, cur, hparams.f_norm_eps);     break;
-        case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, hparams.f_norm_rms_eps); break;
+        case LLM_NORM:       cur = ggml_norm      (ctx, cur, hparams.f_norm_eps);     break;
+        case LLM_NORM_RMS:   cur = ggml_rms_norm  (ctx, cur, hparams.f_norm_rms_eps); break;
+        case LLM_NORM_GROUP:
+            {
+                cur = ggml_reshape_3d(ctx, cur, cur->ne[0], 1, cur->ne[1]);
+                cur = ggml_group_norm(ctx, cur, hparams.n_norm_groups, hparams.f_norm_group_eps);
+                cur = ggml_reshape_2d(ctx, cur, cur->ne[0],    cur->ne[2]);
+            } break;
     }
 
     if (mw || mb) {
@@ -8976,12 +422,14 @@ static struct ggml_tensor * llm_build_moe_ffn(
          struct ggml_tensor * up_exps,
          struct ggml_tensor * gate_exps,
          struct ggml_tensor * down_exps,
+         struct ggml_tensor * exp_probs_b,
                     int64_t   n_expert,
                     int64_t   n_expert_used,
             llm_ffn_op_type   type_op,
                        bool   norm_w,
                        bool   scale_w,
                       float   w_scale,
+llama_expert_gating_func_type gating_op,
          const llm_build_cb & cb,
                         int   il) {
     int64_t n_embd = cur->ne[0];
@@ -8990,11 +438,31 @@ static struct ggml_tensor * llm_build_moe_ffn(
     ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
     cb(logits, "ffn_moe_logits", il);
 
-    ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
+    ggml_tensor * probs = nullptr;
+    switch (gating_op) {
+        case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
+            {
+                probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
+            } break;
+        case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID:
+            {
+                probs = ggml_sigmoid(ctx, logits); // [n_expert, n_tokens]
+            } break;
+        default:
+            GGML_ABORT("fatal error");
+    }
     cb(probs, "ffn_moe_probs", il);
 
+    // add experts selection bias - introduced in DeepSeek V3
+    // leave probs unbiased as it's later used to get expert weights
+    ggml_tensor * selection_probs = probs;
+    if (exp_probs_b != nullptr) {
+        selection_probs = ggml_add(ctx, probs, exp_probs_b);
+        cb(selection_probs, "ffn_moe_probs_biased", il);
+    }
+
     // select experts
-    ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
+    ggml_tensor * selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
     cb(selected_experts->src[0], "ffn_moe_argsort", il);
     cb(selected_experts, "ffn_moe_topk", il);
 
@@ -9124,20 +592,16 @@ static struct ggml_tensor * llm_build_kqv(
         cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2) {
-            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
-        }
+        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
         cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_NEMOTRON || model.arch == LLM_ARCH_CHATGLM) {
-            // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
-            // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
-            ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-        }
+        // note: this op tends to require high floating point range
+        //       while for some models F16 is enough, for others it is not, so we default to F32 here
+        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
 
         if (model.arch == LLM_ARCH_GROK) {
             // need to do the following:
@@ -9146,9 +610,6 @@ static struct ggml_tensor * llm_build_kqv(
             // kq = 30 * tanh(kq / 30)
             // before the softmax below
 
-            //try from phi2
-            //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-
             kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
             kq = ggml_scale(ctx, kq, 30);
         }
@@ -9258,7 +719,7 @@ static struct ggml_tensor * llm_build_copy_mask_state(
     // FIXME: zero-out NANs?
     states = ggml_mul(ctx, states, state_mask);
 
-    // copy states which won't be changed further (between n_seqs and n_rs)
+    // copy states which won't be changed further (between n_seqs and n_kv)
     ggml_build_forward_expand(graph,
         ggml_cpy(ctx,
             ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*ggml_element_size(states)),
@@ -9272,7 +733,7 @@ static struct ggml_tensor * llm_build_copy_mask_state(
 static struct ggml_tensor * llm_build_mamba(
         struct ggml_context * ctx,
        struct llama_context & lctx,
-         const llama_ubatch & batch,
+         const llama_ubatch & ubatch,
          struct ggml_cgraph * graph,
          struct ggml_tensor * cur,
          struct ggml_tensor * state_copy,
@@ -9288,17 +749,17 @@ static struct ggml_tensor * llm_build_mamba(
     const int64_t d_inner = hparams.ssm_d_inner;
     const int64_t d_state = hparams.ssm_d_state;
     const int64_t dt_rank = hparams.ssm_dt_rank;
-    const int64_t n_seqs  = batch.n_seqs;
+    const int64_t n_seqs  = ubatch.n_seqs;
     // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
     const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
     // Use the same RMS norm as the final layer norm
     const float norm_rms_eps = hparams.f_norm_rms_eps;
 
-    const int64_t n_seq_tokens = batch.n_seq_tokens;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
     GGML_ASSERT(n_seqs != 0);
-    GGML_ASSERT(batch.equal_seqs);
-    GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
+    GGML_ASSERT(ubatch.equal_seqs);
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
     struct ggml_tensor * conv_states_all = kv.k_l[il];
     struct ggml_tensor * ssm_states_all  = kv.v_l[il];
@@ -9410,20 +871,24 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
         const struct llama_layer * layer,
         struct ggml_tensor * cur,
         struct ggml_tensor * x_prev,
-        struct ggml_tensor ** wkv_state) {
-    size_t n_embed      = cur->ne[0];
+        struct ggml_tensor ** wkv_state,
+        size_t wkv_head_size,
+        size_t head_count_kv) {
+    size_t n_embd       = cur->ne[0];
     size_t n_seq_tokens = cur->ne[1];
     size_t n_seqs       = cur->ne[2];
 
-    size_t head_size  = layer->time_mix_first->ne[0];
-    size_t head_count = layer->time_mix_first->ne[1];
+    size_t head_size  = wkv_head_size;
+    size_t head_count = n_embd / head_size;
 
     size_t n_tokens = n_seqs * n_seq_tokens;
 
+    bool is_qrwkv = layer->time_mix_first == nullptr;
+
     struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
 
-    sx  = ggml_reshape_2d(ctx, sx,  n_embed, n_tokens);
-    cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);
+    sx  = ggml_reshape_2d(ctx, sx,  n_embd, n_tokens);
+    cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
 
     struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_x), cur);
 
@@ -9448,69 +913,64 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
         xxx
     );
 
-    struct ggml_tensor *mw = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], 0);
-    struct ggml_tensor *mk = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * sizeof(float));
-    struct ggml_tensor *mv = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 2 * sizeof(float));
-    struct ggml_tensor *mr = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 3 * sizeof(float));
-    struct ggml_tensor *mg = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 4 * sizeof(float));
+    struct ggml_tensor *xw, *xk, *xv, *xr, *xg;
+    if (layer->time_mix_lerp_fused) {
+        // fusing these weights makes some performance improvement
+        sx  = ggml_reshape_3d(ctx, sx,  n_embd, 1, n_tokens);
+        cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
+        xxx = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xxx, layer->time_mix_lerp_fused), sx), cur);
+        xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+    } else {
+        // for backward compatibility
+        xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
 
-    struct ggml_tensor * xw = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mw, layer->time_mix_lerp_w),
-            sx
-        ),
-        cur
-    );
+        xw = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xw, layer->time_mix_lerp_w), sx), cur);
+        xk = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xk, layer->time_mix_lerp_k), sx), cur);
+        xv = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xv, layer->time_mix_lerp_v), sx), cur);
+        xr = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xr, layer->time_mix_lerp_r), sx), cur);
+        xg = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xg, layer->time_mix_lerp_g), sx), cur);
+    }
 
-    struct ggml_tensor * xk = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mk, layer->time_mix_lerp_k),
-            sx
-        ),
-        cur
-    );
+    struct ggml_tensor * r = llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr);
+    struct ggml_tensor * k = llm_build_lora_mm(lctx, ctx, layer->time_mix_key,        xk);
+    struct ggml_tensor * v = llm_build_lora_mm(lctx, ctx, layer->time_mix_value,      xv);
+    if (layer->time_mix_receptance_b) {
+        r = ggml_add(ctx, r, layer->time_mix_receptance_b);
+    }
+    if (layer->time_mix_key_b) {
+        k = ggml_add(ctx, k, layer->time_mix_key_b);
+    }
+    if (layer->time_mix_value_b) {
+        v = ggml_add(ctx, v, layer->time_mix_value_b);
+    }
 
-    struct ggml_tensor * xv = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mv, layer->time_mix_lerp_v),
-            sx
-        ),
-        cur
-    );
+    struct ggml_tensor * g = llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg);
+    if (is_qrwkv) {
+        g = ggml_sigmoid(ctx, g);
+    } else {
+        g = ggml_silu(ctx, g);
+    }
 
-    struct ggml_tensor * xr = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mr, layer->time_mix_lerp_r),
-            sx
-        ),
-        cur
-    );
+    if (head_count_kv != head_count) {
+        GGML_ASSERT(head_count % head_count_kv == 0);
+        k = ggml_reshape_4d(ctx, k, head_size, 1, head_count_kv, n_tokens);
+        v = ggml_reshape_4d(ctx, v, head_size, 1, head_count_kv, n_tokens);
+        struct ggml_tensor * tmp = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_size, head_count / head_count_kv, head_count_kv, n_tokens);
+        k = ggml_repeat(ctx, k, tmp);
+        v = ggml_repeat(ctx, v, tmp);
+    }
 
-    struct ggml_tensor * xg = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mg, layer->time_mix_lerp_g),
-            sx
-        ),
-        cur
-    );
-
-    struct ggml_tensor * r = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, 1,         head_count, n_tokens);
-    struct ggml_tensor * k = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key,        xk), 1,         head_size, head_count, n_tokens);
-    struct ggml_tensor * v = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value,      xv), head_size, 1,         head_count, n_tokens);
-    struct ggml_tensor * g = ggml_silu(
-        ctx,
-        llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg)
-    );
+    k = ggml_reshape_3d(ctx, k, head_size, head_count, n_tokens);
+    v = ggml_reshape_3d(ctx, v, head_size, head_count, n_tokens);
+    r = ggml_reshape_3d(ctx, r, head_size, head_count, n_tokens);
 
     struct ggml_tensor * w = ggml_mul_mat(
         ctx,
@@ -9521,30 +981,40 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
         )
     );
 
-    w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embed));
+    w = ggml_add(ctx, w, layer->time_mix_decay);
     w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
-    w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
+    w = ggml_reshape_3d(ctx, w, head_size, head_count, n_tokens);
 
-    k = ggml_transpose(ctx, k);
-    v = ggml_transpose(ctx, v);
-    r = ggml_transpose(ctx, r);
+    if (is_qrwkv) {
+        // k = k * (1 - w)
+        k = ggml_sub(ctx, k, ggml_mul(ctx, k, w));
+    }
 
-    struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
-    cur = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0);
-    *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_seqs, n_embed * n_tokens * sizeof(float));
+    struct ggml_tensor * wkv_output;
+    if (!layer->time_mix_first) {
+        wkv_output = ggml_gated_linear_attn(ctx, k, v, r, w, *wkv_state, pow(head_size, -0.5f));
+    } else {
+        wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
+    }
+    cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
+    *wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
 
-    // group norm with head_count groups
-    cur = ggml_reshape_3d(ctx, cur, n_embed / head_count, head_count, n_tokens);
-    cur = ggml_norm(ctx, cur, 64e-5f);
+    if (!is_qrwkv) {
+        // group norm with head_count groups
+        cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
+        cur = ggml_norm(ctx, cur, 64e-5f);
 
-    // Convert back to regular vectors.
-    cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);
-    cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
+        // Convert back to regular vectors.
+        cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
+        cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
+    } else {
+        cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
+    }
 
     cur = ggml_mul(ctx, cur, g);
     cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
 
-    return ggml_reshape_3d(ctx, cur, n_embed, n_seq_tokens, n_seqs);
+    return ggml_reshape_3d(ctx, cur, n_embd, n_seq_tokens, n_seqs);
 }
 
 static struct ggml_tensor * llm_build_rwkv6_channel_mix(
@@ -9574,7 +1044,7 @@ struct llm_build_context {
           llama_context  & lctx;
     const llama_hparams  & hparams;
     const llama_cparams  & cparams;
-    const llama_ubatch   & batch;
+    const llama_ubatch   & ubatch;
     const llama_kv_cache & kv_self;
 
     const int64_t n_embd;
@@ -9620,14 +1090,14 @@ struct llm_build_context {
     // TODO: consider making the entire interface noexcept
     llm_build_context(
         llama_context  & lctx,
-    const llama_ubatch & batch,
+    const llama_ubatch & ubatch,
     const llm_build_cb & cb,
                   bool   worst_case) :
         model            (lctx.model),
         lctx             (lctx),
         hparams          (model.hparams),
         cparams          (lctx.cparams),
-        batch            (batch),
+        ubatch           (ubatch),
         kv_self          (lctx.kv_self),
         n_embd           (hparams.n_embd),
         n_layer          (hparams.n_layer),
@@ -9649,7 +1119,7 @@ struct llm_build_context {
         beta_slow        (cparams.yarn_beta_slow),
         norm_eps         (hparams.f_norm_eps),
         norm_rms_eps     (hparams.f_norm_rms_eps),
-        n_tokens         (batch.n_tokens),
+        n_tokens         (ubatch.n_tokens),
         n_kv             (worst_case ? kv_self.size : kv_self.n),
         n_outputs        (worst_case ? n_tokens : lctx.n_outputs),
         n_outputs_enc    (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
@@ -9690,14 +1160,12 @@ struct llm_build_context {
     }
 
     void free() {
-        if (ctx0) {
-            ggml_free(ctx0);
-            ctx0 = nullptr;
-        }
+        ggml_free(ctx0);
+        ctx0 = nullptr;
     }
 
     struct ggml_cgraph * build_k_shift() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         GGML_ASSERT(kv_self.size == n_ctx);
 
@@ -9709,17 +1177,36 @@ struct llm_build_context {
             const int64_t n_head_kv = hparams.n_head_kv(il);
             const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
             struct ggml_tensor * rope_factors = build_rope_factors(il);
-            struct ggml_tensor * tmp =
-                // we rotate only the first n_rot dimensions
-                ggml_rope_ext_inplace(ctx0,
-                        ggml_view_3d(ctx0, kv_self.k_l[il],
-                            n_embd_head_k, n_head_kv, n_ctx,
-                            ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
-                            ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
-                            0),
+            struct ggml_tensor * k =
+                ggml_view_3d(ctx0, kv_self.k_l[il],
+                    n_embd_head_k, n_head_kv, n_ctx,
+                    ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
+                    ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+                    0);
+
+            struct ggml_tensor * tmp;
+            if (ggml_is_quantized(k->type)) {
+                // dequantize to f32 -> RoPE -> quantize back
+                tmp = ggml_cast(ctx0, k, GGML_TYPE_F32);
+                cb(tmp, "K_f32", il);
+                for (auto & backend : lctx.backends) {
+                    // Figure out which backend KV cache belongs to
+                    if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(kv_self.k_l[il]->buffer))) {
+                        ggml_backend_sched_set_tensor_backend(lctx.sched.get(), tmp, backend.get());
+                        break;
+                    }
+                }
+                tmp = ggml_rope_ext_inplace(ctx0, tmp,
                         lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
-
+                cb(tmp, "K_shifted_f32", il);
+                tmp = ggml_cpy(ctx0, tmp, k);
+            } else {
+                // we rotate only the first n_rot dimensions
+                tmp = ggml_rope_ext_inplace(ctx0, k,
+                        lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+            }
             cb(tmp, "K_shifted", il);
             ggml_build_forward_expand(gf, tmp);
         }
@@ -9728,7 +1215,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_defrag(const std::vector & ids) {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         for (uint32_t i = 0; i < ids.size(); ++i) {
             const uint32_t id = ids[i];
@@ -9877,8 +1364,8 @@ struct llm_build_context {
     struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
         // find result_norm tensor for input
         struct ggml_tensor * inp = nullptr;
-        for (int i = gf->n_nodes - 1; i >= 0; --i) {
-            inp = gf->nodes[i];
+        for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
+            inp = ggml_graph_node(gf, i);
             if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
                 break;
             } else {
@@ -9890,6 +1377,10 @@ struct llm_build_context {
         struct ggml_tensor * cur;
 
         switch (pooling_type) {
+            case LLAMA_POOLING_TYPE_NONE:
+                {
+                    cur = inp;
+                } break;
             case LLAMA_POOLING_TYPE_MEAN:
                 {
                     struct ggml_tensor * inp_mean = build_inp_mean();
@@ -9901,9 +1392,26 @@ struct llm_build_context {
                     struct ggml_tensor * inp_cls = build_inp_cls();
                     cur = ggml_get_rows(ctx0, inp, inp_cls);
                 } break;
-            case LLAMA_POOLING_TYPE_NONE:
+            case LLAMA_POOLING_TYPE_RANK:
                 {
-                    cur = inp;
+                    struct ggml_tensor * inp_cls = build_inp_cls();
+                    inp = ggml_get_rows(ctx0, inp, inp_cls);
+
+                    // classification head
+                    // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
+                    GGML_ASSERT(model.cls       != nullptr);
+                    GGML_ASSERT(model.cls_b     != nullptr);
+
+                    cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls, inp), model.cls_b);
+                    cur = ggml_tanh(ctx0, cur);
+
+                    // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+                    // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
+                    if (model.cls_out) {
+                        GGML_ASSERT(model.cls_out_b != nullptr);
+
+                        cur = ggml_add (ctx0, ggml_mul_mat(ctx0, model.cls_out, cur), model.cls_out_b);
+                    }
                 } break;
             default:
                 {
@@ -9966,7 +1474,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_llama() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -9978,7 +1486,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -9986,6 +1494,7 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
@@ -10038,7 +1547,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -10049,11 +1558,17 @@ struct llm_build_context {
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
 
+            // For Granite architecture
+            if (hparams.f_residual_scale) {
+                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            }
+
             struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);
 
             // feed-forward network
             if (model.layers[il].ffn_gate_inp == nullptr) {
+
                 cur = llm_build_norm(ctx0, ffn_inp, hparams,
                         model.layers[il].ffn_norm, NULL,
                         LLM_NORM_RMS, cb, il);
@@ -10078,13 +1593,20 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
+                        nullptr,
                         n_expert, n_expert_used,
                         LLM_FFN_SILU, true,
                         false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                         cb, il);
                 cb(cur, "ffn_moe_out", il);
             }
 
+            // For Granite architecture
+            if (hparams.f_residual_scale) {
+                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            }
+
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
 
@@ -10104,6 +1626,12 @@ struct llm_build_context {
 
         // lm_head
         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        // For Granite architecture
+        if (hparams.f_logit_scale) {
+            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
+        }
+
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -10111,8 +1639,11 @@ struct llm_build_context {
         return gf;
     }
 
-    struct ggml_cgraph * build_baichuan() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+    struct ggml_cgraph * build_deci() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10121,10 +1652,168 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = model.type == MODEL_7B ? build_inp_pos() : nullptr;
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+            const int64_t n_head_kv = hparams.n_head_kv(il);
+            const int64_t n_head    = hparams.n_head(il);
+
+            if (n_head == 0) {
+                // attention-free layer of Llama-3_1-Nemotron-51B
+                cur = inpL;
+            } else {
+                // norm
+                cur = llm_build_norm(ctx0, inpL, hparams,
+                        model.layers[il].attn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "attn_norm", il);
+            }
+
+            if (n_head > 0 && n_head_kv == 0) {
+                // "linear attention" of Llama-3_1-Nemotron-51B
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
+                cb(cur, "wo", il);
+            } else if (n_head > 0) {
+                // self-attention
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                struct ggml_tensor * rope_factors = build_rope_factors(il);
+
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            // For Granite architecture
+            if (hparams.f_residual_scale) {
+                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            }
+
+            // modified to support attention-free layer of Llama-3_1-Nemotron-51B
+            struct ggml_tensor * ffn_inp = cur;
+            if (n_head > 0) {
+                ffn_inp = ggml_add(ctx0, cur, inpSA);
+                cb(ffn_inp, "ffn_inp", il);
+            }
+
+            // feed-forward network
+            if (model.layers[il].ffn_gate_inp == nullptr) {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            // For Granite architecture
+            if (hparams.f_residual_scale) {
+                cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        // For Granite architecture
+        if (hparams.f_logit_scale) {
+            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale);
+        }
+
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    struct ggml_cgraph * build_baichuan() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? build_inp_pos() : nullptr;
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
@@ -10149,7 +1838,7 @@ struct llm_build_context {
                 cb(Vcur, "Vcur", il);
 
                 switch (model.type) {
-                    case MODEL_7B:
+                    case LLM_TYPE_7B:
                         Qcur = ggml_rope_ext(
                             ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                             n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -10161,7 +1850,7 @@ struct llm_build_context {
                             ext_factor, attn_factor, beta_fast, beta_slow
                         );
                         break;
-                    case MODEL_13B:
+                    case LLM_TYPE_13B:
                         Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens);
                         Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd/n_head, n_head, n_tokens);
                         break;
@@ -10227,7 +1916,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_xverse() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10236,7 +1925,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -10330,7 +2019,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_falcon() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -10340,7 +2029,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -10450,7 +2139,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_grok() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -10462,7 +2151,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // multiply by embedding_multiplier_scale of 78.38367176906169
         inpL = ggml_scale(ctx0, inpL, 78.38367176906169f);
@@ -10558,9 +2247,11 @@ struct llm_build_context {
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
                     model.layers[il].ffn_down_exps,
+                    nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_GELU, true,
                     false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     cb, il);
             cb(cur, "ffn_moe_out", il);
 
@@ -10607,7 +2298,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_dbrx() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -10620,7 +2311,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -10699,9 +2390,11 @@ struct llm_build_context {
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
                     model.layers[il].ffn_down_exps,
+                    nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_SILU, true,
                     false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     cb, il);
             cb(cur, "ffn_moe_out", il);
 
@@ -10733,7 +2426,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_starcoder() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -10742,7 +2435,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -10837,7 +2530,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_refact() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10845,7 +2538,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
@@ -10931,7 +2624,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_bert() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -10947,7 +2640,7 @@ struct llm_build_context {
         }
 
         // construct input embeddings (token, type, position)
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // token types are hardcoded to zero ("Sentence A")
         struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
@@ -11115,8 +2808,8 @@ struct llm_build_context {
             inpL = cur;
         }
 
-        // final output
         cur = inpL;
+
         cb(cur, "result_embd", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -11125,7 +2818,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_bloom() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -11134,7 +2827,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
@@ -11226,7 +2919,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_mpt() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -11236,7 +2929,7 @@ struct llm_build_context {
         struct ggml_tensor * pos;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
@@ -11374,7 +3067,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -11516,7 +3209,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11524,7 +3217,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -11628,7 +3321,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11637,7 +3330,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -11739,8 +3432,126 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_qwen2vl() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens * 4);
+        cb(lctx.inp_pos, "inp_pos", -1);
+        ggml_set_input(lctx.inp_pos);
+        struct ggml_tensor * inp_pos = lctx.inp_pos;
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+        int sections[4];
+        std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_multi(
+                    ctx0,
+                    ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                    n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_multi(
+                    ctx0,
+                    ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_qwen2moe() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -11752,7 +3563,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -11829,9 +3640,11 @@ struct llm_build_context {
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
                         model.layers[il].ffn_down_exps,
+                        nullptr,
                         n_expert, n_expert_used,
                         LLM_FFN_SILU, false,
                         false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                         cb, il);
             cb(cur, "ffn_moe_out", il);
 
@@ -11886,7 +3699,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_phi2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -11897,7 +3710,7 @@ struct llm_build_context {
         struct ggml_tensor * ffn_output;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -12007,7 +3820,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_phi3() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -12016,13 +3829,19 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
+        struct ggml_tensor * KQ_mask = nullptr;
+        if (hparams.n_swa == 0) {
+            // Phi-4 doesn't use sliding window attention
+            KQ_mask = build_inp_KQ_mask();
+        } else {
+            KQ_mask = build_inp_KQ_mask_swa();
+        }
 
         for (int il = 0; il < n_layer; ++il) {
             auto residual = inpL;
@@ -12034,7 +3853,7 @@ struct llm_build_context {
 
                 struct ggml_tensor* attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm,
-                    NULL,
+                    model.layers[il].attn_norm_b,
                     LLM_NORM_RMS, cb, il);
                 cb(attn_norm_output, "attn_norm", il);
 
@@ -12049,8 +3868,7 @@ struct llm_build_context {
                     Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
                     Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd)));
                     Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
-                }
-                else {
+                } else {
                     Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
                     Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
                     Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
@@ -12080,7 +3898,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask_swa, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -12094,14 +3912,12 @@ struct llm_build_context {
             residual = cur;
 
             cur = llm_build_norm(ctx0, cur, hparams,
-                model.layers[il].ffn_norm, NULL,
+                model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
                 LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            // FF
-            // special-case: the up and gate tensors are merged into a single tensor
-            // TOOD: support into llm_build_ffn
-            {
+            // feed-forward network
+            if (model.layers[il].ffn_gate_inp == nullptr) {
                 cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL, NULL,
                         NULL,                      NULL, NULL,
@@ -12109,6 +3925,20 @@ struct llm_build_context {
                         NULL,
                         LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                cur = llm_build_moe_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        cb, il);
+                cb(cur, "ffn_moe_out", il);
             }
 
             cur = ggml_add(ctx0, residual, cur);
@@ -12121,11 +3951,16 @@ struct llm_build_context {
 
         cur = llm_build_norm(ctx0, inpL, hparams,
             model.output_norm,
-            NULL,
+            model.output_norm_b,
             LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
 
         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        if (model.output_b != nullptr) {
+            cb(cur, "result_output_no_bias", -1);
+            cur = ggml_add(ctx0, cur, model.output_b);
+        }
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -12144,7 +3979,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -12239,7 +4074,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gpt2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -12249,7 +4084,7 @@ struct llm_build_context {
         struct ggml_tensor * pos;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -12344,7 +4179,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_codeshell() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -12354,7 +4189,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -12455,7 +4290,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_orion() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -12464,7 +4299,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -12573,7 +4408,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_internlm2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -12582,7 +4417,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -12690,26 +4525,23 @@ struct llm_build_context {
         return gf;
     }
 
-    // ref: https://arxiv.org/abs/2203.03466
-    //      https://github.com/ggerganov/llama.cpp/issues/5276#issuecomment-1925774738
-    // based on the original build_llama() function
-    struct ggml_cgraph * build_minicpm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+    struct ggml_cgraph * build_minicpm3() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-        GGML_ASSERT(n_embd_head == hparams.n_rot);
-
-        const int64_t n_embd = hparams.n_embd;
         //TODO: if the model varies, these parameters need to be read from the model
         const int64_t n_embd_base = 256;
         const float scale_embd  = 12.0f;
         const float scale_depth = 1.4f;
+        const float kq_scale = 1.0f / sqrtf(float(hparams.n_embd_head_k));
+
+        const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+        const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+        const uint32_t kv_lora_rank = hparams.n_lora_kv;
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // scale the input embeddings
         inpL = ggml_scale(ctx0, inpL, scale_embd);
@@ -12724,53 +4556,118 @@ struct llm_build_context {
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
+            struct ggml_tensor * rope_factors = build_rope_factors(il);
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm, NULL,
                     LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);
 
-            // self-attention
+            // self_attention
             {
-                // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
-                cb(Qcur, "Qcur", il);
-                if (model.layers[il].bq) {
-                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
-                    cb(Qcur, "Qcur", il);
-                }
+                struct ggml_tensor * q = NULL;
+                // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
+                q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
+                cb(q, "q", il);
 
-                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
-                cb(Kcur, "Kcur", il);
-                if (model.layers[il].bk) {
-                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
-                    cb(Kcur, "Kcur", il);
-                }
+                q = llm_build_norm(ctx0, q, hparams,
+                        model.layers[il].attn_q_a_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(q, "q", il);
 
-                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
-                cb(Vcur, "Vcur", il);
-                if (model.layers[il].bv) {
-                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
-                    cb(Vcur, "Vcur", il);
-                }
+                // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
+                q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
+                cb(q, "q", il);
 
-                Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos, nullptr,
+                // split into {n_head * n_embd_head_qk_nope, n_tokens}
+                struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        0);
+                cb(q_nope, "q_nope", il);
+
+                // and {n_head * n_embd_head_qk_rope, n_tokens}
+                struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        ggml_row_size(q->type, n_embd_head_qk_nope));
+                cb(q_pe, "q_pe", il);
+
+                // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
+                struct ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
+                cb(kv_pe_compresseed, "kv_pe_compresseed", il);
+
+                // split into {kv_lora_rank, n_tokens}
+                struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
+                        kv_pe_compresseed->nb[1],
+                        0);
+                cb(kv_compressed, "kv_compressed", il);
+
+                // and {n_embd_head_qk_rope, n_tokens}
+                struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
+                        kv_pe_compresseed->nb[1],
+                        kv_pe_compresseed->nb[1],
+                        ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
+                cb(k_pe, "k_pe", il);
+
+                kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+                kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
+                        model.layers[il].attn_kv_a_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(kv_compressed, "kv_compressed", il);
+
+                // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
+                struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
+                cb(kv, "kv", il);
+
+                // split into {n_head * n_embd_head_qk_nope, n_tokens}
+                struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
+                        ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        0);
+                cb(k_nope, "k_nope", il);
+
+                // and {n_head * n_embd_head_v, n_tokens}
+                struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope)));
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_cont(ctx0, v_states);
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
+                    ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
+                    0);
+                cb(v_states, "v_states", il);
+
+                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
+                q_pe = ggml_rope_ext(
+                    ctx0, q_pe, inp_pos, rope_factors,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
-                cb(Qcur, "Qcur", il);
+                cb(q_pe, "q_pe", il);
 
-                Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                // shared RoPE key
+                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
+                k_pe = ggml_rope_ext(
+                    ctx0, k_pe, inp_pos, rope_factors,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
-                cb(Kcur, "Kcur", il);
+                cb(k_pe, "k_pe", il);
+
+                struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
+                cb(q_states, "q_states", il);
+
+                struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
+                cb(k_states, "k_states", il);
 
                 cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        model.layers[il].wo, NULL,
+                        k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -12783,7 +4680,7 @@ struct llm_build_context {
             // scale_res - scale the hidden states for residual connection
             const float scale_res = scale_depth/sqrtf(float(n_layer));
             cur = ggml_scale(ctx0, cur, scale_res);
-            cb(cur, "hidden_scaled", -1);
+            cb(cur, "hidden_scaled", il);
 
             struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);
@@ -12806,7 +4703,7 @@ struct llm_build_context {
 
             // scale the hidden states for residual connection
             cur = ggml_scale(ctx0, cur, scale_res);
-            cb(cur, "hidden_scaled_ffn", -1);
+            cb(cur, "hidden_scaled_ffn", il);
 
             cur = ggml_add(ctx0, cur, ffn_inp);
             cur = lctx.cvec.apply_to(ctx0, cur, il);
@@ -12838,14 +4735,14 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gemma() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head_k = hparams.n_embd_head_k;
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
         cb(inpL, "inp_scaled", -1);
@@ -12946,14 +4843,14 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gemma2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head_k = hparams.n_embd_head_k;
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
         cb(inpL, "inp_scaled", -1);
@@ -12996,9 +4893,9 @@ struct llm_build_context {
 
                 // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
                 switch (model.type) {
-                    case e_model::MODEL_2B:
-                    case e_model::MODEL_9B:  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
-                    case e_model::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
+                    case LLM_TYPE_2B:
+                    case LLM_TYPE_9B:  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
+                    case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
                     default: GGML_ABORT("fatal error");
                 };
                 cb(Qcur, "Qcur_scaled", il);
@@ -13082,7 +4979,7 @@ struct llm_build_context {
 
 
     struct ggml_cgraph * build_starcoder2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -13091,7 +4988,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -13201,13 +5098,13 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_mamba() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
         // {n_embd, n_tokens}
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         struct ggml_tensor * state_copy = build_inp_s_copy();
         struct ggml_tensor * state_mask = build_inp_s_mask();
@@ -13219,7 +5116,7 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);
 
-            cur = llm_build_mamba(ctx0, lctx, batch, gf, cur,
+            cur = llm_build_mamba(ctx0, lctx, ubatch, gf, cur,
                     state_copy, state_mask,
                     kv_head, n_kv, cb, il);
 
@@ -13256,7 +5153,7 @@ struct llm_build_context {
 
     struct ggml_cgraph * build_command_r() {
 
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -13265,7 +5162,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -13403,6 +5300,137 @@ struct llm_build_context {
 
     }
 
+    struct ggml_cgraph * build_cohere2() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        const float f_logit_scale = hparams.f_logit_scale;
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        // cohere2 requires different mask for layers using sliding window (SWA)
+        struct ggml_tensor * KQ_mask     = build_inp_KQ_mask();
+        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
+
+        // sliding window switch pattern
+        const int32_t sliding_window_pattern = 4;
+
+        for (int il = 0; il < n_layer; ++il) {
+            // three layers sliding window attention (window size 4096) and ROPE
+            // fourth layer uses global attention without positional embeddings
+            const bool           is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
+            struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+            struct ggml_tensor * ffn_inp = cur;
+
+            // self-attention
+            {
+                // rope freq factors for 128k context
+                struct ggml_tensor * rope_factors = build_rope_factors(il);
+
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                if (is_sliding) {
+                    Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
+                                        beta_fast, beta_slow);
+                    cb(Qcur, "Qcur", il);
+
+                    Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                                        rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
+                                        attn_factor, beta_fast, beta_slow);
+                    cb(Kcur, "Kcur", il);
+                } else {
+                    // For non-sliding layers, just reshape without applying RoPE
+                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                    cb(Qcur, "Qcur", il);
+
+                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
+                                   KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur                              = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpL                             = ggml_get_rows(ctx0, inpL, inp_out_ids);
+                ffn_inp                          = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
+            }
+
+            struct ggml_tensor * attn_out = cur;
+
+            // feed-forward network
+            {
+                cur = llm_build_ffn(ctx0, lctx, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate,
+                                    NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,
+                                    cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            // add together residual + FFN + self-attention
+            cur = ggml_add(ctx0, cur, inpL);
+            cur = ggml_add(ctx0, cur, attn_out);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        if (f_logit_scale) {
+            cur = ggml_scale(ctx0, cur, f_logit_scale);
+        }
+
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     // ref: https://allenai.org/olmo
     // based on the original build_llama() function, changes:
     //   * non-parametric layer norm
@@ -13410,7 +5438,7 @@ struct llm_build_context {
     //   * removed bias
     //   * removed MoE
     struct ggml_cgraph * build_olmo() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -13422,7 +5450,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -13533,15 +5561,269 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_olmo2() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            cur = inpL;
+
+            // self_attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur_rope", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur_rope", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                    model.layers[il].attn_post_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_post_norm", il);
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = llm_build_ffn(ctx0, lctx, ffn_inp,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_out", il);
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                model.layers[il].ffn_post_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+            cb(cur, "ffn_post_norm", -1);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    // based on the build_qwen2moe() function, changes:
+    //   * removed shared experts
+    //   * removed bias
+    //   * added q, k norm
+    struct ggml_cgraph * build_olmoe() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self_attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur_rope", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur_rope", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // MoE branch
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_moe_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    nullptr,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU, false,
+                    false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                    cb, il);
+            cb(cur, "ffn_moe_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_openelm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -13659,7 +5941,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gptneox() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -13668,7 +5950,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -13801,7 +6083,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_arctic() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -13813,7 +6095,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -13900,9 +6182,11 @@ struct llm_build_context {
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
                     model.layers[il].ffn_down_exps,
+                    nullptr,
                     n_expert, n_expert_used,
                     LLM_FFN_SILU, true,
                     false, 0.0,
+                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     cb, il);
             cb(cur, "ffn_moe_out", il);
 
@@ -13932,8 +6216,165 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_deepseek() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                struct ggml_tensor * rope_factors = build_rope_factors(il);
+
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            if ((uint32_t) il < hparams.n_layer_dense_lead) {
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                ggml_tensor * moe_out =
+                        llm_build_moe_ffn(ctx0, lctx, cur,
+                            model.layers[il].ffn_gate_inp,
+                            model.layers[il].ffn_up_exps,
+                            model.layers[il].ffn_gate_exps,
+                            model.layers[il].ffn_down_exps,
+                            nullptr,
+                            n_expert, n_expert_used,
+                            LLM_FFN_SILU, false,
+                            false, hparams.expert_weights_scale,
+                            LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                            cb, il);
+                cb(moe_out, "ffn_moe_out", il);
+
+                // FFN shared expert
+                {
+                    ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur,
+                            model.layers[il].ffn_up_shexp,   NULL, NULL,
+                            model.layers[il].ffn_gate_shexp, NULL, NULL,
+                            model.layers[il].ffn_down_shexp, NULL, NULL,
+                            NULL,
+                            LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                    cb(ffn_shexp, "ffn_shexp", il);
+
+                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                    cb(cur, "ffn_out", il);
+                }
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_deepseek2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -13954,7 +6395,7 @@ struct llm_build_context {
         struct ggml_tensor * inpL;
 
         // {n_embd, n_tokens}
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -14055,7 +6496,7 @@ struct llm_build_context {
                     0);
                 cb(v_states, "v_states", il);
 
-                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
                 q_pe = ggml_rope_ext(
                     ctx0, q_pe, inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -14064,7 +6505,7 @@ struct llm_build_context {
                 cb(q_pe, "q_pe", il);
 
                 // shared RoPE key
-                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
                 k_pe = ggml_rope_ext(
                     ctx0, k_pe, inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -14115,9 +6556,11 @@ struct llm_build_context {
                             model.layers[il].ffn_up_exps,
                             model.layers[il].ffn_gate_exps,
                             model.layers[il].ffn_down_exps,
+                            model.layers[il].ffn_exp_probs_b,
                             n_expert, n_expert_used,
-                            LLM_FFN_SILU, false,
+                            LLM_FFN_SILU, hparams.expert_weights_norm,
                             true, hparams.expert_weights_scale,
+                            (enum llama_expert_gating_func_type) hparams.expert_gating_func,
                             cb, il);
                 cb(moe_out, "ffn_moe_out", il);
 
@@ -14161,7 +6604,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_bitnet() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -14169,7 +6612,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -14303,6 +6746,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
+        // FIXME: do not use model.tok_embd directly, duplicate as model.output
         cur = llm_build_lora_mm(lctx, ctx0, model.tok_embd, cur);
         cb(cur, "result_output", -1);
 
@@ -14310,8 +6754,8 @@ struct llm_build_context {
         return gf;
     }
 
-    struct ggml_cgraph * build_t5_encoder() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+    struct ggml_cgraph * build_t5_enc() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -14323,7 +6767,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         GGML_ASSERT(lctx.is_encoding);
         struct ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false);
@@ -14442,8 +6886,8 @@ struct llm_build_context {
         return gf;
     }
 
-    struct ggml_cgraph * build_t5_decoder() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+    struct ggml_cgraph * build_t5_dec() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -14455,7 +6899,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         GGML_ASSERT(!lctx.is_encoding);
         GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first");
@@ -14648,7 +7092,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_jais() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -14657,7 +7101,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
@@ -14740,7 +7184,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_chatglm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -14749,7 +7193,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -14854,7 +7298,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_nemotron() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -14863,7 +7307,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -14975,7 +7419,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_exaone() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -14987,7 +7431,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -15102,16 +7546,16 @@ struct llm_build_context {
     }
 
     ggml_cgraph * build_rwkv6() {
-        ggml_cgraph *gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // Token shift state dimensions should be 2 * n_emb
         GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2);
 
-        const int64_t n_seqs = batch.n_seqs;
-        const int64_t n_seq_tokens = batch.n_seq_tokens;
-        const int64_t n_tokens = batch.n_tokens;
+        const int64_t n_seqs = ubatch.n_seqs;
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t n_tokens = ubatch.n_tokens;
         GGML_ASSERT(n_seqs != 0);
-        GGML_ASSERT(batch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs);
         GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
 
         struct ggml_tensor * cur;
@@ -15119,7 +7563,7 @@ struct llm_build_context {
         struct ggml_tensor * state_copy = build_inp_s_copy();
         struct ggml_tensor * state_mask = build_inp_s_mask();
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
         inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
 
         for (int il = 0; il < n_layer; ++il) {
@@ -15147,7 +7591,7 @@ struct llm_build_context {
                 1
             );
 
-            cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states));
+            cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, n_embd / hparams.wkv_head_size));
             ggml_build_forward_expand(gf, cur);
             ggml_build_forward_expand(
                 gf,
@@ -15204,9 +7648,449 @@ struct llm_build_context {
         cur = ggml_get_rows(ctx0, cur, inp_out_ids);
 
         cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
+    ggml_cgraph * build_rwkv6qwen2() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+        GGML_ASSERT(n_embd == hparams.n_embd_k_s());
+
+        const int64_t n_seqs = ubatch.n_seqs;
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t n_tokens = ubatch.n_tokens;
+        GGML_ASSERT(n_seqs != 0);
+        GGML_ASSERT(ubatch.equal_seqs);
+        GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+        struct ggml_tensor * state_copy = build_inp_s_copy();
+        struct ggml_tensor * state_mask = build_inp_s_mask();
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        for (int il = 0; il < n_layer; ++il) {
+            const llama_layer * layer = &model.layers[il];
+
+            // (ab)using the KV cache to store the states
+            struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
+                    gf, kv_self.k_l[il], state_copy, state_mask,
+                    hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
+            struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
+                    gf, kv_self.v_l[il], state_copy, state_mask,
+                    hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);
+
+            cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
+            token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 1, n_seqs);
+
+            struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, cb, il);
+            struct ggml_tensor * x_prev = ggml_concat(
+                ctx0,
+                token_shift,
+                ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
+                1
+            );
+
+            struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    ggml_view_1d(ctx0, last_norm_att, n_embd * n_seqs, 0),
+                    ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
+                )
+            );
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, hparams.n_head_kv()));
+            ggml_build_forward_expand(gf, ffn_inp);
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    wkv_states,
+                    ggml_view_1d(
+                        ctx0,
+                        kv_self.v_l[il],
+                        hparams.n_embd_v_s() * n_seqs,
+                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
+                    )
+                )
+            );
+
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+
+        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    // ref: https://github.com/facebookresearch/chameleon
+    // based on the original build_llama() function, changes:
+    //   * qk-norm
+    //   * swin-norm
+    //   * removed bias
+    //   * removed MoE
+    struct ggml_cgraph * build_chameleon() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            if (hparams.swin_norm) {
+                cur = inpL;
+            } else {
+                cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+                cb(cur, "attn_norm", il);
+            }
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                if (model.layers[il].attn_q_norm) {
+                    Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens,
+                                ggml_element_size(Qcur) * n_embd_head,
+                                ggml_element_size(Qcur) * n_embd_head * n_head,
+                                0);
+                    cb(Qcur, "Qcur", il);
+
+                    Qcur = llm_build_norm(ctx0, Qcur, hparams,
+                                model.layers[il].attn_q_norm,
+                                model.layers[il].attn_q_norm_b,
+                                LLM_NORM, cb, il);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                if (model.layers[il].attn_k_norm) {
+                    Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens,
+                                ggml_element_size(Kcur) * n_embd_head,
+                                ggml_element_size(Kcur) * n_embd_head * n_head_kv,
+                                0);
+                    cb(Kcur, "Kcur", il);
+
+                    Kcur = llm_build_norm(ctx0, Kcur, hparams,
+                               model.layers[il].attn_k_norm,
+                               model.layers[il].attn_k_norm_b,
+                               LLM_NORM, cb, il);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, nullptr,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+
+                if (hparams.swin_norm) {
+                    cur = llm_build_norm(ctx0, cur, hparams,
+                        model.layers[il].attn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                }
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            if (!hparams.swin_norm) {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+            }
+
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_out", il);
+
+            if (hparams.swin_norm) {
+                cur = llm_build_norm(ctx0, cur, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output_with_img_logits", -1);
+
+        // TODO: this suppresses the output of image tokens, which is required to enable text-only outputs.
+        // Needs to be removed once image outputs are supported.
+        int img_token_end_idx = 8196;
+        int img_token_start_idx = 4;
+        int num_img_tokens = img_token_end_idx - img_token_start_idx;
+        // creates 1d tensor of size num_img_tokens and values -FLT_MAX,
+        // which ensures that text token values are always at least larger than image token values
+        struct ggml_tensor * img_logits = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, num_img_tokens);
+        img_logits = ggml_clamp(ctx0, img_logits, -FLT_MAX, -FLT_MAX);
+        cb(img_logits, "img_logits", -1);
+        cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    struct ggml_cgraph * build_wavtokenizer_dec() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        cur = ggml_cont(ctx0, ggml_transpose(ctx0, inpL));
+
+        cur = ggml_conv_1d_ph(ctx0, model.conv1d, cur, 1, 1);
+        cur = ggml_add(ctx0, cur, model.conv1d_b);
+
+        // posnet
+        for (uint32_t il = 0; il < hparams.posnet.n_layer; ++il) {
+            const auto & layer = model.layers[il].posnet;
+
+            inpL = cur;
+
+            switch (il) {
+                case 0:
+                case 1:
+                case 3:
+                case 4:
+                    {
+                        cur = llm_build_norm(ctx0, cur, hparams,
+                                layer.norm1,
+                                layer.norm1_b,
+                                LLM_NORM_GROUP, cb, 0);
+
+                        cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
+
+                        cur = ggml_conv_1d_ph(ctx0, layer.conv1, cur, 1, 1);
+                        cur = ggml_add(ctx0, cur, layer.conv1_b);
+
+                        cur = llm_build_norm(ctx0, cur, hparams,
+                                layer.norm2,
+                                layer.norm2_b,
+                                LLM_NORM_GROUP, cb, 0);
+
+                        cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);
+
+                        cur = ggml_conv_1d_ph(ctx0, layer.conv2, cur, 1, 1);
+                        cur = ggml_add(ctx0, cur, layer.conv2_b);
+
+                        cur = ggml_add(ctx0, cur, inpL);
+                    } break;
+                case 2:
+                    {
+                        cur = llm_build_norm(ctx0, cur, hparams,
+                                layer.attn_norm,
+                                layer.attn_norm_b,
+                                LLM_NORM_GROUP, cb, 0);
+
+                        struct ggml_tensor * q;
+                        struct ggml_tensor * k;
+                        struct ggml_tensor * v;
+
+                        q = ggml_conv_1d_ph(ctx0, layer.attn_q, cur, 1, 1);
+                        k = ggml_conv_1d_ph(ctx0, layer.attn_k, cur, 1, 1);
+                        v = ggml_conv_1d_ph(ctx0, layer.attn_v, cur, 1, 1);
+
+                        q = ggml_add(ctx0, q, layer.attn_q_b);
+                        k = ggml_add(ctx0, k, layer.attn_k_b);
+                        v = ggml_add(ctx0, v, layer.attn_v_b);
+
+                        q = ggml_cont(ctx0, ggml_transpose(ctx0, q));
+                        k = ggml_cont(ctx0, ggml_transpose(ctx0, k));
+
+                        struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+
+                        kq = ggml_soft_max_ext(ctx0, kq, nullptr, 1.0f/sqrtf(float(hparams.posnet.n_embd)), 0.0f);
+
+                        cur = ggml_mul_mat(ctx0, kq, v);
+
+                        cur = ggml_conv_1d_ph(ctx0, layer.attn_o, cur, 1, 1);
+                        cur = ggml_add(ctx0, cur, layer.attn_o_b);
+
+                        cur = ggml_add(ctx0, cur, inpL);
+                    } break;
+                case 5:
+                    {
+                        cur = llm_build_norm(ctx0, cur, hparams,
+                                layer.norm,
+                                layer.norm_b,
+                                LLM_NORM_GROUP, cb, 0);
+                    } break;
+                default: GGML_ABORT("unknown posnet layer");
+            };
+        }
+
+        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.tok_norm,
+                model.tok_norm_b,
+                LLM_NORM, cb, -1);
+
+        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+        inpL = cur;
+
+        // convnext
+        for (uint32_t il = 0; il < hparams.convnext.n_layer; ++il) {
+            const auto & layer = model.layers[il].convnext;
+
+            cur = inpL;
+
+            cur = ggml_conv_1d_dw_ph(ctx0, layer.dw, cur, 1, 1);
+            cur = ggml_add(ctx0, cur, layer.dw_b);
+
+            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                    layer.norm,
+                    layer.norm_b,
+                    LLM_NORM, cb, -1);
+
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    layer.pw1, layer.pw1_b, NULL,
+                    NULL,      NULL,        NULL,
+                    layer.pw2, layer.pw2_b, NULL,
+                    NULL,
+                    LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+
+            cur = ggml_mul(ctx0, cur, layer.gamma);
+
+            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+            inpL = ggml_add(ctx0, cur, inpL);
+        }
+
+        cur = inpL;
+
+        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, cb, -1);
+
+        // lm_head
         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
 
-        cb(cur, "result_output", -1);
+        cur = ggml_add(ctx0, cur, model.output_b);
+        cb(cur, "result_embd", -1);
+
         ggml_build_forward_expand(gf, cur);
 
         return gf;
@@ -15249,7 +8133,7 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
 
 static struct ggml_cgraph * llama_build_graph(
          llama_context & lctx,
-    const llama_ubatch & batch,
+    const llama_ubatch & ubatch,
                   bool   worst_case) {
     const auto & model = lctx.model;
 
@@ -15264,20 +8148,21 @@ static struct ggml_cgraph * llama_build_graph(
         if (!lctx.cparams.offload_kqv) {
             if (strcmp(name, "kqv_merged_cont") == 0) {
                 // all nodes between the KV store and the attention output are run on the CPU
-                ggml_backend_sched_set_tensor_backend(lctx.sched, cur, lctx.backend_cpu);
+                ggml_backend_sched_set_tensor_backend(lctx.sched.get(), cur, lctx.backend_cpu);
             }
         }
 
         // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
         // FIXME: fix in ggml_backend_sched
-        const bool full_offload = lctx.model.n_gpu_layers > (int)lctx.model.hparams.n_layer;
-        if (batch.n_tokens < 32 || full_offload) {
+        const bool full_offload = lctx.model.params.n_gpu_layers > (int) lctx.model.hparams.n_layer;
+        if (ubatch.n_tokens < 32 || full_offload) {
             if (il != -1 && strcmp(name, "norm") == 0) {
-                for (auto * backend : lctx.backends) {
-                    if (ggml_backend_supports_buft(backend, lctx.model.buft_layer[il].buft) &&
-                        (ggml_backend_supports_op(backend, cur) || ggml_backend_offload_op(backend, cur))) {
-                        ggml_backend_sched_set_tensor_backend(lctx.sched, cur, backend);
-                        break;
+                const auto & dev_layer = lctx.model.dev_layer(il);
+                for (auto & backend : lctx.backends) {
+                    if (ggml_backend_get_device(backend.get()) == dev_layer) {
+                        if (ggml_backend_supports_op(backend.get(), cur)) {
+                            ggml_backend_sched_set_tensor_backend(lctx.sched.get(), cur, backend.get());
+                        }
                     }
                 }
             }
@@ -15286,15 +8171,22 @@ static struct ggml_cgraph * llama_build_graph(
 
     struct ggml_cgraph * result = NULL;
 
-    struct llm_build_context llm(lctx, batch, cb, worst_case);
+    struct llm_build_context llm(lctx, ubatch, cb, worst_case);
 
     llm.init();
 
     switch (model.arch) {
         case LLM_ARCH_LLAMA:
+        case LLM_ARCH_MINICPM:
+        case LLM_ARCH_GRANITE:
+        case LLM_ARCH_GRANITE_MOE:
             {
                 result = llm.build_llama();
             } break;
+        case LLM_ARCH_DECI:
+            {
+                result = llm.build_deci();
+            } break;
         case LLM_ARCH_BAICHUAN:
             {
                 result = llm.build_baichuan();
@@ -15341,6 +8233,11 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_qwen2();
             } break;
+        case LLM_ARCH_QWEN2VL:
+            {
+                lctx.n_pos_per_token = 4;
+                result = llm.build_qwen2vl();
+            } break;
         case LLM_ARCH_QWEN2MOE:
             {
                 result = llm.build_qwen2moe();
@@ -15350,6 +8247,7 @@ static struct ggml_cgraph * llama_build_graph(
                 result = llm.build_phi2();
             } break;
         case LLM_ARCH_PHI3:
+        case LLM_ARCH_PHIMOE:
             {
                 result = llm.build_phi3();
             } break;
@@ -15373,9 +8271,9 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_internlm2();
             } break;
-        case LLM_ARCH_MINICPM:
+        case LLM_ARCH_MINICPM3:
             {
-                result = llm.build_minicpm();
+                result = llm.build_minicpm3();
             } break;
         case LLM_ARCH_GEMMA:
             {
@@ -15401,6 +8299,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_command_r();
             } break;
+        case LLM_ARCH_COHERE2:
+            {
+                result = llm.build_cohere2();
+            } break;
         case LLM_ARCH_DBRX:
             {
                 result = llm.build_dbrx();
@@ -15409,6 +8311,14 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_olmo();
             } break;
+        case LLM_ARCH_OLMO2:
+            {
+                result = llm.build_olmo2();
+            } break;
+        case LLM_ARCH_OLMOE:
+            {
+                result = llm.build_olmoe();
+            } break;
         case LLM_ARCH_OPENELM:
             {
                 result = llm.build_openelm();
@@ -15421,6 +8331,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_arctic();
             } break;
+        case LLM_ARCH_DEEPSEEK:
+            {
+                result = llm.build_deepseek();
+            } break;
         case LLM_ARCH_DEEPSEEK2:
             {
                 result = llm.build_deepseek2();
@@ -15436,14 +8350,14 @@ static struct ggml_cgraph * llama_build_graph(
         case LLM_ARCH_T5:
             {
                 if (lctx.is_encoding) {
-                    result = llm.build_t5_encoder();
+                    result = llm.build_t5_enc();
                 } else {
-                    result = llm.build_t5_decoder();
+                    result = llm.build_t5_dec();
                 }
             } break;
         case LLM_ARCH_T5ENCODER:
             {
-                result = llm.build_t5_encoder();
+                result = llm.build_t5_enc();
             } break;
         case LLM_ARCH_JAIS:
             {
@@ -15461,6 +8375,18 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_rwkv6();
             } break;
+        case LLM_ARCH_RWKV6QWEN2:
+            {
+                result = llm.build_rwkv6qwen2();
+            } break;
+        case LLM_ARCH_CHAMELEON:
+            {
+                result = llm.build_chameleon();
+            } break;
+        case LLM_ARCH_WAVTOKENIZER_DEC:
+            {
+                result = llm.build_wavtokenizer_dec();
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -15475,648 +8401,66 @@ static struct ggml_cgraph * llama_build_graph(
     return result;
 }
 
-static void llama_set_k_shift(llama_context & lctx) {
-    const int64_t kv_size = lctx.kv_self.size;
-
-    assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
-
-    int32_t * data = (int32_t *) lctx.inp_K_shift->data;
-
-    for (int i = 0; i < kv_size; ++i) {
-        data[i] = lctx.kv_self.cells[i].delta;
-    }
-}
-
-static void llama_set_s_copy(llama_context & lctx) {
-    const int64_t kv_size = lctx.kv_self.size;
-
-    assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
-
-    int32_t * data = (int32_t *) lctx.inp_s_copy->data;
-
-    for (int i = 0; i < kv_size; ++i) {
-        data[i] = lctx.kv_self.cells[i].src;
-    }
-}
-
-static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-    // TODO move to hparams if a T5 variant appears that uses a different value
-    const int64_t max_distance = 128;
-
-    if (bidirectional) {
-        n_buckets >>= 1;
-    }
-
-    const int64_t max_exact = n_buckets >> 1;
-
-    int32_t relative_position = x - y;
-    int32_t relative_bucket = 0;
-    if (bidirectional) {
-        relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
-    } else {
-        relative_position = -std::min(relative_position, 0);
-    }
-    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-    relative_position_if_large = std::min(relative_position_if_large, n_buckets - 1);
-    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-    return relative_bucket;
-}
-
-static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
-    //
-    // set input data
-    //
-
-    const auto & hparams = lctx.model.hparams;
-    const auto & cparams = lctx.cparams;
-    const auto & kv_self = lctx.kv_self;
-
-    if (batch.token) {
-        const int64_t n_tokens = batch.n_tokens;
-
-        ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
-    }
-
-    if (batch.embd) {
-        const int64_t n_embd   = hparams.n_embd;
-        const int64_t n_tokens = batch.n_tokens;
-
-        ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
-    }
-
-    if (batch.pos && lctx.inp_pos) {
-        const int64_t n_tokens = batch.n_tokens;
-
-        ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
-    }
-
-    if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
-        GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
-        const int64_t n_tokens = batch.n_tokens;
-
-        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
-        int32_t * data = (int32_t *) lctx.inp_out_ids->data;
-
-        if (lctx.n_outputs == n_tokens) {
-            for (int i = 0; i < n_tokens; ++i) {
-                data[i] = i;
-            }
-        } else if (batch.output) {
-            int32_t n_outputs = 0;
-            for (int i = 0; i < n_tokens; ++i) {
-                if (batch.output[i]) {
-                    data[n_outputs++] = i;
-                }
-            }
-            // the graph needs to have been passed the correct number of outputs
-            GGML_ASSERT(lctx.n_outputs == n_outputs);
-        } else if (lctx.n_outputs == 1) {
-            // only keep last output
-            data[0] = n_tokens - 1;
-        } else {
-            GGML_ASSERT(lctx.n_outputs == 0);
-        }
-    }
-
-    GGML_ASSERT(
-        // (!a || b) is a logical implication (a -> b)
-        // !hparams.causal_attn -> !cparams.causal_attn
-        (hparams.causal_attn || !cparams.causal_attn) &&
-        "causal attention is not supported by this model"
-    );
-
-    if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
-        // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
-        if (cparams.causal_attn && !lctx.is_encoding) {
-            const int64_t n_kv         = kv_self.n;
-            const int64_t n_tokens     = batch.n_tokens;
-            const int64_t n_seq_tokens = batch.n_seq_tokens;
-            const int64_t n_seqs       = batch.n_seqs;
-
-
-            float * data     = nullptr;
-            float * data_swa = nullptr;
-
-            if (lctx.inp_KQ_mask) {
-                GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
-                data = (float *) lctx.inp_KQ_mask->data;
-            }
-
-            if (lctx.inp_KQ_mask_swa) {
-                GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
-                data_swa = (float *) lctx.inp_KQ_mask_swa->data;
-            }
-
-            // For causal attention, use only the previous KV cells
-            // of the correct sequence for each token of the batch.
-            // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-            for (int h = 0; h < 1; ++h) {
-                for (int s = 0; s < n_seqs; ++s) {
-                    const llama_seq_id seq_id = batch.seq_id[s][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const llama_pos pos = batch.pos[s*n_seq_tokens + j];
-
-                        for (int i = 0; i < n_kv; ++i) {
-                            float f;
-                            if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
-                                f = -INFINITY;
-                            } else {
-                                if (hparams.use_alibi) {
-                                    f = -std::abs(kv_self.cells[i].pos - pos);
-                                } else {
-                                    f = 0.0f;
-                                }
-                            }
-
-                            if (data) {
-                                data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                            }
-
-                            // may need to cut off old tokens for sliding window
-                            if (data_swa) {
-                                if (pos - kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
-                                    f = -INFINITY;
-                                }
-                                data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                            }
-                        }
-                    }
-                }
-
-                if (data) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                        }
-                    }
-                }
-
-                if (data_swa) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                        }
-                    }
-                }
-            }
-        } else {
-            const int64_t n_tokens     = batch.n_tokens;
-            const int64_t n_seq_tokens = batch.n_seq_tokens;
-            const int64_t n_seqs       = batch.n_seqs;
-            // when using kv cache, the mask needs to match the kv cache size
-            const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
-
-            float * data = (float *) lctx.inp_KQ_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = batch.seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                for (int s = 0; s < batch.n_seq_id[s0]; ++s) {
-                                    if (batch.seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(batch.pos[ti] - batch.pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }
-
-                        for (int i = n_tokens; i < n_stride; ++i) {
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
-                        }
-                    }
-                }
-            }
-        }
-    }
-
-    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
-        const int64_t n_tokens     = batch.n_tokens;
-        const int64_t n_seq_tokens = batch.n_seq_tokens;
-        const int64_t n_seqs       = batch.n_seqs;
-
-        GGML_ASSERT(lctx.inp_mean);
-        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
-
-        float * data = (float *) lctx.inp_mean->data;
-        memset(lctx.inp_mean->data, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_mean));
-
-        std::vector sum(n_tokens, 0);
-
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
-
-            // TODO: adapt limits to n_seqs when batch.equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
-
-            sum[seq_id] += batch.n_seq_tokens;
-        }
-
-        std::vector div(n_tokens, 0.0f);
-        for (int i = 0; i < n_tokens; ++i) {
-            const uint64_t s = sum[i];
-            if (s > 0) {
-                div[i] = 1.0f/float(s);
-            }
-        }
-
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
-
-            for (int i = 0; i < n_seq_tokens; ++i) {
-                data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
-            }
-        }
-    }
-
-    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
-        const int64_t n_tokens     = batch.n_tokens;
-        const int64_t n_seq_tokens = batch.n_seq_tokens;
-        const int64_t n_seqs       = batch.n_seqs;
-
-        GGML_ASSERT(lctx.inp_cls);
-        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
-
-        uint32_t * data = (uint32_t *) lctx.inp_cls->data;
-        memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
-
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
-
-            // TODO: adapt limits to n_seqs when batch.equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS");
-
-            for (int i = 0; i < n_seq_tokens; ++i) {
-                const llama_pos pos = batch.pos[s*n_seq_tokens + i];
-
-                if (pos == 0) {
-                    data[seq_id] = s*n_seq_tokens + i;
-                }
-            }
-        }
-    }
-
-    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
-        const int64_t n_tokens     = batch.n_tokens;
-        const int64_t n_seq_tokens = batch.n_seq_tokens;
-        const int64_t n_seqs       = batch.n_seqs;
-
-        GGML_ASSERT(lctx.inp_cls);
-        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
-
-        uint32_t * data = (uint32_t *) lctx.inp_cls->data;
-        memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
-
-        std::vector last_pos(n_tokens, -1);
-        std::vector last_row(n_tokens, -1);
-
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
-
-            // TODO: adapt limits to n_seqs when batch.equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
-
-            for (int i = 0; i < n_seq_tokens; ++i) {
-                const llama_pos pos = batch.pos[s*n_seq_tokens + i];
-
-                if (pos >= last_pos[seq_id]) {
-                    last_pos[seq_id] = pos;
-                    last_row[seq_id] = s*n_seq_tokens + i;
-                }
-            }
-        }
-
-        for (int i = 0; i < n_tokens; ++i) {
-            if (last_row[i] >= 0) {
-                data[i] = last_row[i];
-            }
-        }
-    }
-
-    if (kv_self.recurrent) {
-        const int64_t n_kv = kv_self.n;
-
-        if (lctx.inp_s_mask) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
-            float * data = (float *) lctx.inp_s_mask->data;
-
-            // clear unused states
-            for (int i = 0; i < n_kv; ++i) {
-                uint32_t        cell_id = i + kv_self.head;
-                llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
-
-                data[i] = (float) (kv_cell.src >= 0);
-
-                // only clear once
-                if (kv_cell.src < 0) {
-                    kv_cell.src = cell_id;
-                }
-            }
-        }
-
-        if (lctx.inp_s_copy) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
-            int32_t * data = (int32_t *) lctx.inp_s_copy->data;
-
-            // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-            for (uint32_t i = 0; i < n_kv; ++i) {
-                const uint32_t  cell_id = i + kv_self.head;
-                llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
-
-                // prevent out-of-bound sources
-                if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
-                    kv_cell.src = cell_id;
-                }
-
-                data[i] = kv_cell.src;
-
-                // ensure copy only happens once
-                if (kv_cell.src != (int32_t) cell_id) {
-                    kv_cell.src = cell_id;
-                }
-            }
-        }
-    }
-
-    if (lctx.inp_pos_bucket) {
-        const int64_t n_tokens = batch.n_tokens;
-
-        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));
-        GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing
-
-        int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;
-
-        if (!lctx.is_encoding) {
-            const int64_t n_kv = kv_self.n;
-            for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    for (int i = 0; i < n_kv; ++i) {
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
-                    }
-                }
-            }
-        } else {
-            for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    for (int i = 0; i < n_tokens; ++i) {
-                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(batch.pos[i], batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
-                    }
-                }
-            }
-        }
-    }
-
-    if (!lctx.is_encoding && lctx.inp_embd_enc) {
-        assert(lctx.inp_embd_enc->type == GGML_TYPE_F32);
-        assert((size_t) ggml_nelements(lctx.inp_embd_enc) == lctx.embd_enc.size());
-
-        ggml_backend_tensor_set(lctx.inp_embd_enc, lctx.embd_enc.data(), 0, ggml_nbytes(lctx.inp_embd_enc));
-    }
-
-    if (!lctx.is_encoding && lctx.inp_KQ_mask_cross) {
-        const int64_t n_output_enc = lctx.embd_enc.size() / hparams.n_embd;
-        const int64_t n_tokens = batch.n_tokens;
-
-        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer));
-        GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing
-
-        float * data = (float *) lctx.inp_KQ_mask_cross->data;
-
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_output_enc; ++i) {
-                    float f = -INFINITY;
-                    for (int s = 0; s < batch.n_seq_id[j]; ++s) {
-                        const llama_seq_id seq_id = batch.seq_id[j][s];
-                        if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) {
-                            f = 0.0f;
-                        }
-                    }
-                    data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
-                }
-            }
-
-            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_output_enc; ++j) {
-                    data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY;
-                }
-            }
-        }
-    }
-}
-
-// Make sure enough space is available for outputs.
-// Returns max number of outputs for which space was reserved.
-static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
-    const auto & cparams = lctx.cparams;
-    const auto & hparams = lctx.model.hparams;
-
-    const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
-
-    const auto n_batch = cparams.n_batch;
-    const auto n_vocab = hparams.n_vocab;
-    const auto n_embd  = hparams.n_embd;
-
-    // TODO: use a per-batch flag for logits presence instead
-    const bool has_logits = !cparams.embeddings;
-    const bool has_embd   =  cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
-
-    const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
-    const size_t embd_size   = has_embd   ?  n_embd*n_outputs_max : 0;
-
-    if (lctx.output_ids.empty()) {
-        // init, never resized afterwards
-        lctx.output_ids.resize(n_batch);
-    }
-
-    const size_t prev_size = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output) : 0;
-    const size_t new_size  = (logits_size + embd_size) * sizeof(float);
-
-    // alloc only when more than the current capacity is required
-    // TODO: also consider shrinking the buffer
-    if (!lctx.buf_output || prev_size < new_size) {
-        if (lctx.buf_output) {
-#ifndef NDEBUG
-            // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
-            LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
-#endif
-            ggml_backend_buffer_free(lctx.buf_output);
-            lctx.buf_output = nullptr;
-            lctx.logits = nullptr;
-            lctx.embd = nullptr;
-        }
-
-        lctx.buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), new_size);
-        if (lctx.buf_output == nullptr) {
-            LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
-            return 0;
-        }
-    }
-
-    float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output);
-
-    lctx.logits = has_logits ? output_base               : nullptr;
-    lctx.embd   = has_embd   ? output_base + logits_size : nullptr;
-
-    lctx.output_size = n_outputs_max;
-    lctx.logits_size = logits_size;
-    lctx.embd_size   = embd_size;
-
-    // set all ids as invalid (negative)
-    std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1);
-
-    ggml_backend_buffer_clear(lctx.buf_output, 0);
-
-    lctx.n_outputs = 0;
-
-    return n_outputs_max;
-}
-
-// make the outputs have the same order they had in the user-provided batch
-static void llama_output_reorder(struct llama_context * ctx) {
-    std::vector & out_ids = ctx->sbatch.out_ids;
-    if (!out_ids.empty()) {
-        uint32_t n_vocab = ctx->model.hparams.n_vocab;
-        uint32_t n_embd  = ctx->model.hparams.n_embd;
-        int32_t n_outputs = ctx->n_outputs;
-        GGML_ASSERT((size_t) n_outputs == out_ids.size());
-        // TODO: is there something more efficient which also minimizes swaps?
-        // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
-        for (int32_t i = 0; i < n_outputs - 1; ++i) {
-            int32_t j_min = i;
-            for (int32_t j = i + 1; j < n_outputs; ++j) {
-                if (out_ids[j] < out_ids[j_min]) {
-                    j_min = j;
-                }
-            }
-            if (j_min == i) { continue; }
-            std::swap(out_ids[i], out_ids[j_min]);
-            if (ctx->logits_size > 0) {
-                for (uint32_t k = 0; k < n_vocab; k++) {
-                    std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]);
-                }
-            }
-            if (ctx->embd_size > 0) {
-                for (uint32_t k = 0; k < n_embd; k++) {
-                    std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]);
-                }
-            }
-        }
-        std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1);
-        for (int32_t i = 0; i < n_outputs; ++i) {
-            ctx->output_ids[out_ids[i]] = i;
-        }
-        out_ids.clear();
-    }
-}
-
-static void llama_graph_compute(
+// returns the result of ggml_backend_sched_graph_compute_async execution
+static enum ggml_status llama_graph_compute(
           llama_context & lctx,
             ggml_cgraph * gf,
                     int   n_threads,
         ggml_threadpool * threadpool) {
-#ifdef GGML_USE_METAL
-    if (ggml_backend_is_metal(lctx.backend_metal)) {
-        ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads);
-    }
-#endif
-
     if (lctx.backend_cpu != nullptr) {
-        ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
-        ggml_backend_cpu_set_threadpool(lctx.backend_cpu, threadpool);
-        ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data);
+        auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(lctx.backend_cpu));
+        auto * set_threadpool_fn = (decltype(ggml_backend_cpu_set_threadpool) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_set_threadpool");
+        set_threadpool_fn(lctx.backend_cpu, threadpool);
     }
-#ifdef GGML_USE_BLAS
-    if (lctx.backend_blas != nullptr) {
-        ggml_backend_blas_set_n_threads(lctx.backend_blas, n_threads);
-    }
-#endif
 
-    ggml_backend_sched_graph_compute_async(lctx.sched, gf);
+    // set the number of threads for all the backends
+    for (const auto & set_n_threads_fn : lctx.set_n_threads_fns) {
+        set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
+    }
+
+    auto status = ggml_backend_sched_graph_compute_async(lctx.sched.get(), gf);
+    if (status != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
+    }
 
     // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
+
+    return status;
 }
 
-// decode a batch of tokens by evaluating the transformer
-//
-//   - lctx:      llama context
-//   - batch:     batch to evaluate
-//
-// return 0 on success
-// return positive int on warning
-// return negative int on error
-//
-static int llama_decode_internal(
-         llama_context & lctx,
-           llama_batch   batch_all) { // TODO: rename back to batch
-
-    lctx.is_encoding = false;
-    const uint32_t n_tokens_all = batch_all.n_tokens;
-
-    if (n_tokens_all == 0) {
-        LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
-        return -1;
-    }
-
-    for (uint32_t i = 0; i < n_tokens_all; ++i) {
-        if (batch_all.token[i] < 0 || (uint32_t)batch_all.token[i] >= lctx.model.vocab.n_vocab) {
-            LLAMA_LOG_ERROR("%s: invalid token[%d] = %d", __func__, i, batch_all.token[i]);
-            return -1;
-        }
-    }
-
+static int llama_prepare_sbatch(
+        llama_context     & lctx,
+        const llama_batch & batch,
+        uint32_t          & n_outputs) {
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
-    GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT
-
-    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-
-    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
-
-    if (lctx.t_compute_start_us == 0) {
-        lctx.t_compute_start_us = ggml_time_us();
-    }
-    lctx.n_queued_tokens += n_tokens_all;
-
-    auto & kv_self = lctx.kv_self;
-
-    const int64_t n_embd  = hparams.n_embd;
-    const int64_t n_vocab = hparams.n_vocab;
-
-    uint32_t n_outputs = 0;
-    uint32_t n_outputs_prev = 0;
-
-    const auto n_ubatch = cparams.n_ubatch;
+    const uint32_t n_tokens_all = batch.n_tokens;
+    const  int64_t n_embd       = hparams.n_embd;
 
     // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+    if (batch.token) {
+        for (uint32_t i = 0; i < n_tokens_all; ++i) {
+            if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
+                return -1;
+            }
+        }
+    }
+    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
+    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+
+    lctx.n_queued_tokens += n_tokens_all;
     lctx.embd_seq.clear();
 
     // count outputs
-    if (batch_all.logits && !embd_pooled) {
+    if (batch.logits && !embd_pooled) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs += batch_all.logits[i] != 0;
+            n_outputs += batch.logits[i] != 0;
         }
     } else if (lctx.logits_all || embd_pooled) {
         n_outputs = n_tokens_all;
@@ -16125,8 +8469,8 @@ static int llama_decode_internal(
         n_outputs = 1;
     }
 
-    lctx.sbatch.from_batch(batch_all, n_embd,
-        /* simple_split */ !kv_self.recurrent,
+    lctx.sbatch.from_batch(batch, n_embd,
+        /* simple_split */ !lctx.kv_self.recurrent,
         /* logits_all   */ n_outputs == n_tokens_all);
 
     // reserve output buffer
@@ -16135,78 +8479,158 @@ static int llama_decode_internal(
         return -2;
     };
 
+    return 0;
+}
+
+static int llama_prepare_ubatch(
+        llama_context          & lctx,
+        llama_kv_slot_restorer & kv_slot_restorer,
+        llama_ubatch           & ubatch,
+        const uint32_t           n_outputs,
+        const uint32_t           n_tokens_all) {
+    GGML_ASSERT(lctx.sbatch.n_tokens > 0);
+
+    auto       & kv_self = lctx.kv_self;
+    const auto & cparams = lctx.cparams;
+    const auto & hparams = lctx.model.hparams;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    if (lctx.kv_self.recurrent) {
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
+        } else {
+            // recurrent model architectures are easier to implement
+            // with equal-length sequences
+            ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
+        }
+    } else {
+        ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
+    }
+
+    // count the outputs in this u_batch
+    {
+        int32_t n_outputs_new = 0;
+
+        if (n_outputs == n_tokens_all) {
+            n_outputs_new = ubatch.n_tokens;
+        } else {
+            GGML_ASSERT(ubatch.output);
+            for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                n_outputs_new += int32_t(ubatch.output[i] != 0);
+            }
+        }
+
+        // needs to happen before the graph is built
+        lctx.n_outputs = n_outputs_new;
+    }
+
+    // non-causal masks do not use the KV cache
+    if (hparams.causal_attn) {
+        llama_kv_cache_update(&lctx);
+
+        // if we have enough unused cells before the current head ->
+        //   better to start searching from the beginning of the cache, hoping to fill it
+        if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
+            kv_self.head = 0;
+        }
+
+        const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+        if (!slot) {
+            return 1;
+        }
+        kv_slot_restorer.save(slot);
+
+        if (!kv_self.recurrent) {
+            // a heuristic, to avoid attending the full cache if it is not yet utilized
+            // after enough generations, the benefit from this heuristic disappears
+            // if we start defragmenting the cache, the benefit from this will be more important
+            const uint32_t pad = llama_kv_cache_get_padding(cparams);
+            kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
+            //kv_self.n = llama_kv_cache_cell_max(kv_self);
+        }
+    }
+
+    return 0;
+}
+
+// decode a batch of tokens by evaluating the transformer
+// in case of unsuccessful decoding (error or warning),
+// the kv_cache state will be returned to its original state
+// (for non-recurrent models) or cleaned (for recurrent models)
+//
+//   - lctx:      llama context
+//   - inp_batch: batch to evaluate
+//
+// return 0 on success
+// return positive int on warning
+// return negative int on error
+//
+static int llama_decode_impl(
+         llama_context & lctx,
+           llama_batch   inp_batch) {
+
+    lctx.is_encoding = false;
+
+    if (inp_batch.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+        return -1;
+    }
+
+    // temporarily allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
+    const llama_batch & batch = batch_allocr.batch;
+
+    const auto & model   = lctx.model;
+    const auto & vocab   = model.vocab;
+    const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
+
+    if (lctx.t_compute_start_us == 0) {
+        lctx.t_compute_start_us = ggml_time_us();
+    }
+    auto & kv_self = lctx.kv_self;
+    llama_kv_slot_restorer kv_slot_restorer(kv_self);
+
+    const int64_t n_embd  = hparams.n_embd;
+    const int64_t n_vocab = vocab.n_tokens();
+
+    uint32_t n_outputs = 0;
+    uint32_t n_outputs_prev = 0;
+
+    {
+        const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
+        if (ret != 0) {
+            return ret;
+        }
+    }
+
     while (lctx.sbatch.n_tokens > 0) {
         llama_ubatch ubatch;
-        if (kv_self.recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = lctx.sbatch.split_seq(n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = lctx.sbatch.split_equal(n_ubatch);
-            }
-        } else {
-            ubatch = lctx.sbatch.split_simple(n_ubatch);
-        }
-        const uint32_t n_tokens = ubatch.n_tokens;
-
-        // count the outputs in this u_batch
         {
-            int32_t n_outputs_new = 0;
-
-            if (n_outputs == n_tokens_all) {
-                n_outputs_new = n_tokens;
-            } else {
-                GGML_ASSERT(ubatch.output);
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
-                }
+            const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
+            if (ret != 0) {
+                return ret;
             }
-
-            // needs to happen before the graph is built
-            lctx.n_outputs = n_outputs_new;
         }
 
-        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-        ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+        const int         n_threads  = ubatch.n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+        ggml_threadpool_t threadpool = ubatch.n_tokens == 1 ? lctx.threadpool   : lctx.threadpool_batch;
 
         GGML_ASSERT(n_threads > 0);
 
-        // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
-            llama_kv_cache_update(&lctx);
-
-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self.head > kv_self.used + 2*n_tokens) {
-                kv_self.head = 0;
-            }
-
-            if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
-                return 1;
-            }
-
-            if (!kv_self.recurrent) {
-                // a heuristic, to avoid attending the full cache if it is not yet utilized
-                // after enough generations, the benefit from this heuristic disappears
-                // if we start defragmenting the cache, the benefit from this will be more important
-                const uint32_t pad = llama_kv_cache_get_padding(cparams);
-                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
-                //kv_self.n = llama_kv_cache_cell_max(kv_self);
-            }
-        }
-
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
-        ggml_backend_sched_reset(lctx.sched);
-        ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
+        ggml_backend_sched_reset(lctx.sched.get());
+        ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
         ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
 
         // the output is always the last tensor in the graph
-        struct ggml_tensor * res  = gf->nodes[gf->n_nodes - 1];
-        struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
+        struct ggml_tensor * res  = ggml_graph_node(gf, -1);
+        struct ggml_tensor * embd = ggml_graph_node(gf, -2);
 
         if (lctx.n_outputs == 0) {
             // no output
@@ -16215,9 +8639,9 @@ static int llama_decode_internal(
         } else if (cparams.embeddings) {
             res  = nullptr; // do not extract logits for embedding case
             embd = nullptr;
-            for (int i = gf->n_nodes - 1; i >= 0; --i) {
-                if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
-                    embd = gf->nodes[i];
+            for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
+                if (strcmp(ggml_graph_node(gf, i)->name, "result_embd_pooled") == 0) {
+                    embd = ggml_graph_node(gf, i);
                     break;
                 }
             }
@@ -16226,17 +8650,30 @@ static int llama_decode_internal(
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
+
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
-        ggml_backend_sched_alloc_graph(lctx.sched, gf);
+        ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
 
         llama_set_inputs(lctx, ubatch);
 
-        llama_graph_compute(lctx, gf, n_threads, threadpool);
+        const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
+        if (compute_status != GGML_STATUS_SUCCESS) {
+            kv_slot_restorer.restore(kv_self);
+            switch (compute_status) {
+                case GGML_STATUS_ABORTED:
+                    return 2;
+                case GGML_STATUS_ALLOC_FAILED:
+                    return -2;
+                case GGML_STATUS_FAILED:
+                default:
+                    return -3;
+            }
+        }
 
         // update the kv ring buffer
         {
-            kv_self.head += n_tokens;
+            kv_self.head += ubatch.n_tokens;
 
             // Ensure kv cache head points to a valid index.
             if (kv_self.head >= kv_self.size) {
@@ -16251,7 +8688,7 @@ static int llama_decode_internal(
 
         // extract logits
         if (res) {
-            ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
+            ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched.get(), res);
             GGML_ASSERT(backend_res != nullptr);
             GGML_ASSERT(lctx.logits != nullptr);
 
@@ -16267,7 +8704,7 @@ static int llama_decode_internal(
 
         // extract embeddings
         if (embd) {
-            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched.get(), embd);
             GGML_ASSERT(backend_embd != nullptr);
 
             switch (cparams.pooling_type) {
@@ -16300,6 +8737,20 @@ static int llama_decode_internal(
                             ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                         }
                     } break;
+                case LLAMA_POOLING_TYPE_RANK:
+                    {
+                        // extract the rerank score - a single float per sequence
+                        auto & embd_seq_out = lctx.embd_seq;
+
+                        for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+                            const llama_seq_id seq_id = ubatch.seq_id[s][0];
+                            if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
+                                continue;
+                            }
+                            embd_seq_out[seq_id].resize(1);
+                            ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+                        }
+                    } break;
                 case LLAMA_POOLING_TYPE_UNSPECIFIED:
                     {
                         GGML_ABORT("unknown pooling type");
@@ -16348,7 +8799,7 @@ static int llama_decode_internal(
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
-    ggml_backend_sched_reset(lctx.sched);
+    ggml_backend_sched_reset(lctx.sched.get());
 
     return 0;
 }
@@ -16362,25 +8813,22 @@ static int llama_decode_internal(
 // return positive int on warning
 // return negative int on error
 //
-static int llama_encode_internal(
+static int llama_encode_impl(
          llama_context & lctx,
-           llama_batch   batch) {
+           llama_batch   inp_batch) {
 
     lctx.is_encoding = true;
 
-    const uint32_t n_tokens = batch.n_tokens;
-
-    if (n_tokens == 0) {
-        LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
+    if (inp_batch.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
-    for (uint32_t i = 0; i < n_tokens; ++i) {
-        if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= lctx.model.vocab.n_vocab) {
-            LLAMA_LOG_ERROR("%s: invalid token[%d] = %d", __func__, i, batch.token[i]);
-            return -1;
-        }
-    }
+    // temporary allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
+
+    const llama_batch & batch = batch_allocr.batch;
+    const uint32_t n_tokens = batch.n_tokens;
 
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
@@ -16388,6 +8836,15 @@ static int llama_encode_internal(
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
+    if (batch.token) {
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
+                return -1;
+            }
+        }
+    }
+
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
     GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
 
@@ -16421,8 +8878,8 @@ static int llama_encode_internal(
 
     GGML_ASSERT(n_threads > 0);
 
-    ggml_backend_sched_reset(lctx.sched);
-    ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
+    ggml_backend_sched_reset(lctx.sched.get());
+    ggml_backend_sched_set_eval_callback(lctx.sched.get(), lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
 
     ggml_cgraph * gf = llama_build_graph(lctx, ubatch, false);
 
@@ -16432,29 +8889,40 @@ static int llama_encode_internal(
     // there are two cases here
     if (llama_model_has_decoder(&lctx.model)) {
         // first case is an encoder-decoder T5 model where embeddings are passed to decoder
-        embd = gf->nodes[gf->n_nodes - 1];
+        embd = ggml_graph_node(gf, -1);
         GGML_ASSERT(strcmp(embd->name, "result_norm") == 0 && "missing result_output tensor");
     } else {
         // second case is an encoder-only T5 model
         if (cparams.embeddings) {
             // only output embeddings if required
-            embd = gf->nodes[gf->n_nodes - 1];
+            embd = ggml_graph_node(gf, -1);
             if (strcmp(embd->name, "result_embd_pooled") != 0) {
-                embd = gf->nodes[gf->n_nodes - 2];
+                embd = ggml_graph_node(gf, -2);
             }
             GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
         }
     }
 
-    ggml_backend_sched_alloc_graph(lctx.sched, gf);
+    ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
 
     llama_set_inputs(lctx, ubatch);
 
-    llama_graph_compute(lctx, gf, n_threads, threadpool);
+    const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
+    switch (compute_status) {
+        case GGML_STATUS_SUCCESS:
+            break;
+        case GGML_STATUS_ABORTED:
+            return 2;
+        case GGML_STATUS_ALLOC_FAILED:
+            return -2;
+        case GGML_STATUS_FAILED:
+        default:
+            return -3;
+    }
 
     // extract embeddings
     if (embd) {
-        ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+        ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched.get(), embd);
         GGML_ASSERT(backend_embd != nullptr);
 
         if (llama_model_has_decoder(&lctx.model)) {
@@ -16504,6 +8972,13 @@ static int llama_encode_internal(
                             ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
                         }
                     } break;
+                case LLAMA_POOLING_TYPE_RANK:
+                    {
+                        // TODO: this likely should be the same logic as in llama_decoder_internal, but better to
+                        //       wait for an encoder model that requires this pooling type in order to test it
+                        //       https://github.com/ggerganov/llama.cpp/pull/9510
+                        GGML_ABORT("RANK pooling not implemented yet");
+                    }
                 case LLAMA_POOLING_TYPE_UNSPECIFIED:
                     {
                         GGML_ABORT("unknown pooling type");
@@ -16514,13 +8989,13 @@ static int llama_encode_internal(
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
-    ggml_backend_sched_reset(lctx.sched);
+    ggml_backend_sched_reset(lctx.sched.get());
 
     return 0;
 }
 
 // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
-static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
+static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
     auto & kv_self = lctx.kv_self;
 
     const auto & hparams = lctx.model.hparams;
@@ -16540,9 +9015,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     // each move requires 6*n_layer tensors (see build_defrag)
     //   - source view, destination view, copy operation
     //   - x2 for keys and values
-    //const uint32_t max_moves = llama_model_max_nodes(model)/(6*n_layer);
+    //const uint32_t max_moves = model.max_nodes()/(6*n_layer);
     // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
-    const uint32_t max_moves = (llama_model_max_nodes(lctx.model) - 2*n_layer)/(6*n_layer);
+    const uint32_t max_moves = (lctx.model.max_nodes() - 2*n_layer)/(6*n_layer);
 
     // determine which KV cells to move where
     //
@@ -16728,7 +9203,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
 #else
     // ggml_graph defrag
 
-    ggml_backend_sched_reset(lctx.sched);
+    ggml_backend_sched_reset(lctx.sched.get());
 
     ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids);
 
@@ -16740,21 +9215,21 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     //LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
 }
 
-static void llama_kv_cache_update_internal(struct llama_context & lctx) {
+static void llama_kv_cache_update_impl(struct llama_context & lctx) {
     bool need_reserve = false;
 
-    // apply K-shift if needed
-    if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
-        if (lctx.model.arch == LLM_ARCH_DEEPSEEK2) { // not supported due to MLA
-            GGML_ABORT("Deepseek2 does not support K-shift");
+    if (lctx.kv_self.has_shift) {
+        if (!llama_kv_cache_can_shift(&lctx)) {
+            GGML_ABORT("The current context does not support K-shift");
         }
 
-        {
-            ggml_backend_sched_reset(lctx.sched);
+        // apply K-shift if needed
+        if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
+            ggml_backend_sched_reset(lctx.sched.get());
 
             ggml_cgraph * gf = llama_build_graph_k_shift(lctx);
 
-            ggml_backend_sched_alloc_graph(lctx.sched, gf);
+            ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);
 
             llama_set_k_shift(lctx);
 
@@ -16776,7 +9251,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
 
     // defragment the KV cache if needed
     if (lctx.kv_self.do_defrag) {
-        llama_kv_cache_defrag_internal(lctx);
+        llama_kv_cache_defrag_impl(lctx);
 
         need_reserve = true;
 
@@ -16789,1128 +9264,55 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
         // build worst-case graph
         uint32_t n_seqs = 1; // TODO: worst-case number of sequences
         uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
-        llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+        llama_token token = lctx.model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
         llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
         ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
 
         // initialize scheduler with the worst-case graph
-        ggml_backend_sched_reset(lctx.sched);
-        if (!ggml_backend_sched_reserve(lctx.sched, gf)) {
+        ggml_backend_sched_reset(lctx.sched.get());
+        if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
         }
     }
 }
 
-//
-// quantization
-//
-
-struct quantize_state_internal {
-    const llama_model                 & model;
-    const llama_model_quantize_params * params;
-
-    int n_attention_wv    = 0;
-    int n_ffn_down        = 0;
-    int n_ffn_gate        = 0;
-    int n_ffn_up          = 0;
-    int i_attention_wv    = 0;
-    int i_ffn_down        = 0;
-    int i_ffn_gate        = 0;
-    int i_ffn_up          = 0;
-
-    int n_k_quantized     = 0;
-    int n_fallback        = 0;
-
-    bool has_imatrix      = false;
-
-    // used to figure out if a model shares tok_embd with the output weight
-    bool has_output       = false;
-
-    quantize_state_internal(const llama_model & model, const llama_model_quantize_params * params)
-        : model(model)
-        , params(params)
-        {}
-};
-
-static void llama_tensor_dequantize_internal(
-    struct ggml_tensor * tensor, std::vector> & output, std::vector & workers,
-    const size_t nelements, const int nthread
-) {
-    if (output.size() < nelements) {
-        output.resize(nelements);
-    }
-    float * f32_output = (float *) output.data();
-
-    ggml_type_traits_t qtype;
-    if (ggml_is_quantized(tensor->type)) {
-        qtype = ggml_internal_get_type_traits(tensor->type);
-        if (qtype.to_float == NULL) {
-            throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type)));
-        }
-    } else if (tensor->type != GGML_TYPE_F16 &&
-               tensor->type != GGML_TYPE_BF16) {
-        throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type)));
-    }
-
-    if (nthread < 2) {
-        if (tensor->type == GGML_TYPE_F16) {
-            ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements);
-        } else if (tensor->type == GGML_TYPE_BF16) {
-            ggml_bf16_to_fp32_row((ggml_bf16_t *)tensor->data, f32_output, nelements);
-        } else if (ggml_is_quantized(tensor->type)) {
-            qtype.to_float(tensor->data, f32_output, nelements);
-        } else {
-            GGML_ABORT("fatal error"); // unreachable
-        }
-        return;
-    }
-
-    size_t block_size;
-    if (tensor->type == GGML_TYPE_F16 ||
-        tensor->type == GGML_TYPE_BF16) {
-        block_size = 1;
-    } else {
-        block_size = (size_t)ggml_blck_size(tensor->type);
-    }
-
-    size_t block_size_bytes = ggml_type_size(tensor->type);
-
-    GGML_ASSERT(nelements % block_size == 0);
-    size_t nblocks = nelements / block_size;
-    size_t blocks_per_thread = nblocks / nthread;
-    size_t spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count
-
-    size_t in_buff_offs = 0;
-    size_t out_buff_offs = 0;
-
-    for (int tnum = 0; tnum < nthread; tnum++) {
-        size_t thr_blocks = blocks_per_thread + (tnum == nthread - 1 ? spare_blocks : 0); // num blocks for this thread
-        size_t thr_elems = thr_blocks * block_size; // number of elements for this thread
-        size_t thr_block_bytes = thr_blocks * block_size_bytes; // number of input bytes for this thread
-
-        auto compute = [qtype] (ggml_type typ, uint8_t * inbuf, float * outbuf, int nels) {
-            if (typ == GGML_TYPE_F16) {
-                ggml_fp16_to_fp32_row((ggml_fp16_t *)inbuf, outbuf, nels);
-            } else if (typ == GGML_TYPE_BF16) {
-                ggml_bf16_to_fp32_row((ggml_bf16_t *)inbuf, outbuf, nels);
-            } else {
-                qtype.to_float(inbuf, outbuf, nels);
-            }
-        };
-        workers.emplace_back(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems);
-        in_buff_offs += thr_block_bytes;
-        out_buff_offs += thr_elems;
-    }
-    for (auto & w : workers) { w.join(); }
-    workers.clear();
-}
-
-static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
-    const std::string name = ggml_get_name(tensor);
-
-    // TODO: avoid hardcoded tensor names - use the TN_* constants
-    const llm_arch arch = qs.model.arch;
-    const auto       tn = LLM_TN(arch);
-
-    auto use_more_bits = [](int i_layer, int n_layers) -> bool {
-        return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2;
-    };
-    const int n_expert = std::max(1, (int)qs.model.hparams.n_expert);
-    auto layer_info = [n_expert] (int i_layer, int n_layer, const char * name) {
-        if (n_expert > 1) {
-            // Believe it or not, "experts" in the FFN of Mixtral-8x7B are not consecutive, but occasionally randomly
-            // sprinkled in the model. Hence, simply dividing i_ffn_down by n_expert does not work
-            // for getting the current layer as I initially thought, and we need to resort to parsing the
-            // tensor name.
-            if (sscanf(name, "blk.%d.", &i_layer) != 1) {
-                throw std::runtime_error(format("Failed to determine layer for tensor %s", name));
-            }
-            if (i_layer < 0 || i_layer >= n_layer) {
-                throw std::runtime_error(format("Bad layer %d for tensor %s. Must be in [0, %d)", i_layer, name, n_layer));
-            }
-        }
-        return std::make_pair(i_layer, n_layer);
-    };
-
-    // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings
-    // with the quantization of the output tensor
-    if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) {
-        if (qs.params->output_tensor_type < GGML_TYPE_COUNT) {
-            new_type = qs.params->output_tensor_type;
-        } else {
-            int nx = tensor->ne[0];
-            if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) {
-                new_type = GGML_TYPE_Q8_0;
-            }
-            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
-                     ftype == LLAMA_FTYPE_MOSTLY_IQ1_S   || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S  || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M   ||
-                     ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
-                new_type = GGML_TYPE_Q5_K;
-            }
-            else if (new_type != GGML_TYPE_Q8_0) {
-                new_type = GGML_TYPE_Q6_K;
-            }
-        }
-    } else if (name == "token_embd.weight") {
-        if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
-            new_type = qs.params->token_embedding_type;
-        } else {
-            if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS ||
-                ftype == LLAMA_FTYPE_MOSTLY_IQ1_S   || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
-                new_type = GGML_TYPE_Q2_K;
-            }
-            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) {
-                new_type = GGML_TYPE_IQ3_S;
-            }
-            else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
-                new_type = GGML_TYPE_IQ3_S;
-            }
-            else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 ||
-                     new_type == GGML_TYPE_Q4_0_8_8) {
-                new_type = GGML_TYPE_Q4_0;
-            }
-            else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) {
-                new_type = GGML_TYPE_Q4_K;
-            }
-        }
-    } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
-               ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M    || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
-        if (name.find("attn_v.weight") != std::string::npos) {
-            if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K;
-            else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
-            ++qs.i_attention_wv;
-        }
-        else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) {
-            new_type = GGML_TYPE_Q4_K;
-        }
-        else if (name.find("ffn_down") != std::string::npos) {
-            if (qs.i_ffn_down < qs.n_ffn_down/8) {
-                new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K;
-            }
-            ++qs.i_ffn_down;
-        }
-        else if (name.find("attn_output.weight") != std::string::npos) {
-            if (qs.model.hparams.n_expert == 8) {
-                new_type = GGML_TYPE_Q5_K;
-            } else {
-                if (ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) new_type = GGML_TYPE_IQ2_XXS;
-                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S;
-            }
-        }
-    } else if (name.find("attn_v.weight") != std::string::npos) {
-        if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
-            new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && qs.model.hparams.n_gqa() >= 4) {
-            new_type = GGML_TYPE_Q4_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
-            new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : !qs.has_imatrix ? GGML_TYPE_IQ3_S : GGML_TYPE_IQ3_XXS;
-        }
-        else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S) && qs.model.hparams.n_gqa() >= 4) {
-            new_type = GGML_TYPE_Q4_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
-            new_type = GGML_TYPE_Q4_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
-            new_type = qs.i_attention_wv < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
-        else if ((ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && qs.model.hparams.n_gqa() >= 4) {
-            new_type = GGML_TYPE_Q5_K;
-        }
-        else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
-                use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K;
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
-        if (qs.model.type == MODEL_70B) {
-            // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
-            // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
-            // nearly negligible increase in model size by quantizing this tensor with more bits:
-            if (new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K) new_type = GGML_TYPE_Q5_K;
-        }
-        if (qs.model.hparams.n_expert == 8) {
-            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
-            // TODO: explore better strategies
-            new_type = GGML_TYPE_Q8_0;
-        }
-        ++qs.i_attention_wv;
-    } else if (name.find("attn_k.weight") != std::string::npos) {
-        if (qs.model.hparams.n_expert == 8) {
-            // for the 8-expert model, bumping this to Q8_0 trades just ~128MB
-            // TODO: explore better strategies
-            new_type = GGML_TYPE_Q8_0;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
-            new_type = GGML_TYPE_IQ3_XXS;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
-            new_type = GGML_TYPE_IQ2_S;
-        }
-    } else if (name.find("attn_q.weight") != std::string::npos) {
-        if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) {
-            new_type = GGML_TYPE_IQ3_XXS;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
-            new_type = GGML_TYPE_IQ2_S;
-        }
-    } else if (name.find("ffn_down") != std::string::npos) {
-        auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str());
-        int i_layer = info.first, n_layer = info.second;
-        if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) {
-            if (i_layer < n_layer/8) new_type = GGML_TYPE_Q4_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS && !qs.has_imatrix) {
-            new_type = i_layer < n_layer/8 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
-            new_type = i_layer < n_layer/16 ? GGML_TYPE_Q5_K
-                     : arch != LLM_ARCH_FALCON || use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q4_K
-                     : GGML_TYPE_Q3_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M && (i_layer < n_layer/8 ||
-                    (qs.model.hparams.n_expert == 8 && use_more_bits(i_layer, n_layer)))) {
-            new_type = GGML_TYPE_Q4_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) {
-            new_type = arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
-            if (arch == LLM_ARCH_FALCON) {
-                new_type = i_layer < n_layer/16 ? GGML_TYPE_Q6_K :
-                           use_more_bits(i_layer, n_layer) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
-            } else {
-                if (use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
-            }
-        }
-        else if (i_layer < n_layer/8 && (ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) && !qs.has_imatrix) {
-            new_type = GGML_TYPE_Q5_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_layer, n_layer)) new_type = GGML_TYPE_Q6_K;
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) {
-            new_type = GGML_TYPE_Q5_K;
-        }
-        else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || ftype == LLAMA_FTYPE_MOSTLY_Q5_0)
-                && qs.has_imatrix && i_layer < n_layer/8) {
-            // Guard against craziness in the first few ffn_down layers that can happen even with imatrix for Q4_0/Q5_0.
-            // We only do it when an imatrix is provided because a) we want to make sure that one can always get the
-            // same quantization as before imatrix stuff, and b) Q4_1/Q5_1 do go crazy on ffn_down without an imatrix.
-            new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1;
-        }
-        ++qs.i_ffn_down;
-    } else if (name.find("attn_output.weight") != std::string::npos) {
-        if (arch != LLM_ARCH_FALCON) {
-            if (qs.model.hparams.n_expert == 8) {
-                if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K   || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
-                    ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ4_NL  ||
-                    ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S  ||
-                    ftype == LLAMA_FTYPE_MOSTLY_IQ3_M  || ftype == LLAMA_FTYPE_MOSTLY_IQ4_XS) {
-                    new_type = GGML_TYPE_Q5_K;
-                }
-            } else {
-                if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K   ) new_type = GGML_TYPE_Q3_K;
-                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) new_type = GGML_TYPE_IQ3_S;
-                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M ) new_type = GGML_TYPE_Q4_K;
-                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L ) new_type = GGML_TYPE_Q5_K;
-                else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M  ) new_type = GGML_TYPE_Q4_K;
-            }
-        } else {
-            if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K;
-        }
-    }
-    else if (name.find("attn_qkv.weight") != std::string::npos) {
-        if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) {
-            new_type = GGML_TYPE_Q4_K;
-        }
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
-        else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
-    }
-    else if (name.find("ffn_gate") != std::string::npos) {
-        auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str());
-        int i_layer = info.first, n_layer = info.second;
-        if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
-            new_type = GGML_TYPE_IQ3_XXS;
-        }
-        ++qs.i_ffn_gate;
-    }
-    else if (name.find("ffn_up") != std::string::npos) {
-        auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str());
-        int i_layer = info.first, n_layer = info.second;
-        if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) {
-            new_type = GGML_TYPE_IQ3_XXS;
-        }
-        ++qs.i_ffn_up;
-    }
-
-    //    if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
-    //}
-    // IK: let's remove this, else Q2_K is almost the same as Q3_K_S
-    //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) {
-    //    if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
-    //}
-    // This can be used to reduce the size of the Q5_K_S model.
-    // The associated PPL increase is fully in line with the size reduction
-    //else {
-    //    if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
-    //}
-    bool convert_incompatible_tensor = false;
-    if (new_type == GGML_TYPE_Q2_K    || new_type == GGML_TYPE_Q3_K    || new_type == GGML_TYPE_Q4_K   ||
-        new_type == GGML_TYPE_Q5_K    || new_type == GGML_TYPE_Q6_K    || new_type == GGML_TYPE_IQ4_XS ||
-        new_type == GGML_TYPE_IQ2_XS  || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S  ||
-        new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S   || new_type == GGML_TYPE_IQ3_S  ||
-        new_type == GGML_TYPE_IQ1_M) {
-        int nx = tensor->ne[0];
-        int ny = tensor->ne[1];
-        if (nx % QK_K != 0) {
-            LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type));
-            convert_incompatible_tensor = true;
-        } else {
-            ++qs.n_k_quantized;
-        }
-    }
-    if (convert_incompatible_tensor) {
-        switch (new_type) {
-            case GGML_TYPE_TQ1_0:
-            case GGML_TYPE_TQ2_0:  new_type = GGML_TYPE_Q4_0; break;  // TODO: use a symmetric type instead
-            case GGML_TYPE_IQ2_XXS:
-            case GGML_TYPE_IQ2_XS:
-            case GGML_TYPE_IQ2_S:
-            case GGML_TYPE_IQ3_XXS:
-            case GGML_TYPE_IQ3_S:
-            case GGML_TYPE_IQ1_S:
-            case GGML_TYPE_IQ1_M:
-            case GGML_TYPE_Q2_K:
-            case GGML_TYPE_Q3_K:
-            case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break;
-            case GGML_TYPE_Q4_K:   new_type = GGML_TYPE_Q5_0;   break;
-            case GGML_TYPE_Q5_K:   new_type = GGML_TYPE_Q5_1;   break;
-            case GGML_TYPE_Q6_K:   new_type = GGML_TYPE_Q8_0;   break;
-            default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
-        }
-        if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
-            new_type = GGML_TYPE_F16;
-        }
-        LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
-        ++qs.n_fallback;
-    }
-
-    return new_type;
-}
-
-static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector & workers, const int nthread) {
-    if (nthread < 2) {
-        // single-thread
-        size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix);
-        if (!ggml_validate_row_data(new_type, new_data, new_size)) {
-            throw std::runtime_error("quantized data validation failed");
-        }
-        return new_size;
-    }
-
-    std::mutex mutex;
-    int64_t counter = 0;
-    size_t new_size = 0;
-    bool valid = true;
-    auto compute = [&mutex, &counter, &new_size, &valid, new_type, f32_data, new_data, chunk_size,
-            nrows, n_per_row, imatrix]() {
-        const int64_t nrows_per_chunk = chunk_size / n_per_row;
-        size_t local_size = 0;
-        while (true) {
-            std::unique_lock lock(mutex);
-            int64_t first_row = counter; counter += nrows_per_chunk;
-            if (first_row >= nrows) {
-                if (local_size > 0) {
-                    new_size += local_size;
-                }
-                break;
-            }
-            lock.unlock();
-            const int64_t this_nrow = std::min(nrows - first_row, nrows_per_chunk);
-            size_t this_size = ggml_quantize_chunk(new_type, f32_data, new_data, first_row * n_per_row, this_nrow, n_per_row, imatrix);
-            local_size += this_size;
-
-            // validate the quantized data
-            const size_t row_size  = ggml_row_size(new_type, n_per_row);
-            void * this_data = (char *) new_data + first_row * row_size;
-            if (!ggml_validate_row_data(new_type, this_data, this_size)) {
-                std::unique_lock lock(mutex);
-                valid = false;
-                break;
-            }
-        }
-    };
-    for (int it = 0; it < nthread - 1; ++it) {
-        workers.emplace_back(compute);
-    }
-    compute();
-    for (auto & w : workers) { w.join(); }
-    workers.clear();
-    if (!valid) {
-        throw std::runtime_error("quantized data validation failed");
-    }
-    return new_size;
-}
-
-static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
-    ggml_type default_type;
-    llama_ftype ftype = params->ftype;
-
-    switch (params->ftype) {
-        case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break;
-        case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break;
-        case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break;
-        case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break;
-        case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break;
-        case LLAMA_FTYPE_MOSTLY_F16:  default_type = GGML_TYPE_F16;  break;
-        case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
-        case LLAMA_FTYPE_ALL_F32:     default_type = GGML_TYPE_F32;  break;
-
-        // K-quants
-        case LLAMA_FTYPE_MOSTLY_Q2_K_S:
-        case LLAMA_FTYPE_MOSTLY_Q2_K:    default_type = GGML_TYPE_Q2_K;    break;
-        case LLAMA_FTYPE_MOSTLY_IQ3_XS:  default_type = GGML_TYPE_IQ3_S;   break;
-        case LLAMA_FTYPE_MOSTLY_Q3_K_S:
-        case LLAMA_FTYPE_MOSTLY_Q3_K_M:
-        case LLAMA_FTYPE_MOSTLY_Q3_K_L:  default_type = GGML_TYPE_Q3_K;    break;
-        case LLAMA_FTYPE_MOSTLY_Q4_K_S:
-        case LLAMA_FTYPE_MOSTLY_Q4_K_M:  default_type = GGML_TYPE_Q4_K;    break;
-        case LLAMA_FTYPE_MOSTLY_Q5_K_S:
-        case LLAMA_FTYPE_MOSTLY_Q5_K_M:  default_type = GGML_TYPE_Q5_K;    break;
-        case LLAMA_FTYPE_MOSTLY_Q6_K:    default_type = GGML_TYPE_Q6_K;    break;
-        case LLAMA_FTYPE_MOSTLY_TQ1_0:   default_type = GGML_TYPE_TQ1_0;   break;
-        case LLAMA_FTYPE_MOSTLY_TQ2_0:   default_type = GGML_TYPE_TQ2_0;   break;
-        case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break;
-        case LLAMA_FTYPE_MOSTLY_IQ2_XS:  default_type = GGML_TYPE_IQ2_XS;  break;
-        case LLAMA_FTYPE_MOSTLY_IQ2_S:   default_type = GGML_TYPE_IQ2_XS;  break;
-        case LLAMA_FTYPE_MOSTLY_IQ2_M:   default_type = GGML_TYPE_IQ2_S;   break;
-        case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break;
-        case LLAMA_FTYPE_MOSTLY_IQ1_S:   default_type = GGML_TYPE_IQ1_S;   break;
-        case LLAMA_FTYPE_MOSTLY_IQ1_M:   default_type = GGML_TYPE_IQ1_M;   break;
-        case LLAMA_FTYPE_MOSTLY_IQ4_NL:  default_type = GGML_TYPE_IQ4_NL;  break;
-        case LLAMA_FTYPE_MOSTLY_IQ4_XS:  default_type = GGML_TYPE_IQ4_XS;  break;
-        case LLAMA_FTYPE_MOSTLY_IQ3_S:   default_type = GGML_TYPE_IQ3_S;   break;
-        case LLAMA_FTYPE_MOSTLY_IQ3_M:   default_type = GGML_TYPE_IQ3_S;   break;
-        case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: default_type = GGML_TYPE_Q4_0_4_4; break;
-        case LLAMA_FTYPE_MOSTLY_Q4_0_4_8: default_type = GGML_TYPE_Q4_0_4_8; break;
-        case LLAMA_FTYPE_MOSTLY_Q4_0_8_8: default_type = GGML_TYPE_Q4_0_8_8; break;
-
-        default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
-    }
-
-    int nthread = params->nthread;
-
-    if (nthread <= 0) {
-        nthread = std::thread::hardware_concurrency();
-    }
-
-    // mmap consistently increases speed Linux, and also increases speed on Windows with
-    // hot cache. It may cause a slowdown on macOS, possibly related to free memory.
-#if defined(__linux__) || defined(_WIN32)
-    constexpr bool use_mmap = true;
-#else
-    constexpr bool use_mmap = false;
-#endif
-
-    llama_model_kv_override * kv_overrides = nullptr;
-    if (params->kv_overrides) {
-        auto v = (std::vector*)params->kv_overrides;
-        kv_overrides = v->data();
-    }
-    llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides);
-    ml.init_mappings(false); // no prefetching
-
-    llama_model model;
-    llm_load_arch(ml, model);
-    llm_load_hparams(ml, model);
-
-    struct quantize_state_internal qs(model, params);
-
-    if (params->only_copy) {
-        ftype = model.ftype;
-    }
-    const std::unordered_map> * imatrix_data = nullptr;
-    if (params->imatrix) {
-        imatrix_data = static_cast>*>(params->imatrix);
-        if (imatrix_data) {
-            LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size()));
-            qs.has_imatrix = true;
-            // check imatrix for nans or infs
-            for (const auto & kv : *imatrix_data) {
-                for (float f : kv.second) {
-                    if (!std::isfinite(f)) {
-                        throw std::runtime_error(format("imatrix contains non-finite value %f\n", f));
-                    }
-                }
-            }
-        }
-    }
-
-    const size_t align = GGUF_DEFAULT_ALIGNMENT;
-    struct gguf_context * ctx_out = gguf_init_empty();
-
-    // copy the KV pairs from the input file
-    gguf_set_kv     (ctx_out, ml.meta);
-    gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
-    gguf_set_val_u32(ctx_out, "general.file_type", ftype); // TODO: use LLM_KV
-
-    // Remove split metadata
-    gguf_remove_key(ctx_out, ml.llm_kv(LLM_KV_SPLIT_NO).c_str());
-    gguf_remove_key(ctx_out, ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str());
-    gguf_remove_key(ctx_out, ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str());
-
-    if (params->kv_overrides) {
-        const std::vector & overrides = *(const std::vector *)params->kv_overrides;
-        for (auto & o : overrides) {
-            if (o.key[0] == 0) break;
-            if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) {
-                gguf_set_val_f32(ctx_out, o.key, o.val_f64);
-            } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) {
-                gguf_set_val_i32(ctx_out, o.key, o.val_i64);
-            } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) {
-                gguf_set_val_bool(ctx_out, o.key, o.val_bool);
-            } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) {
-                gguf_set_val_str(ctx_out, o.key, o.val_str);
-            } else {
-                LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key);
-            }
-        }
-    }
-
-    for (int i = 0; i < ml.n_tensors; ++i) {
-        const struct ggml_tensor * meta = ml.get_tensor_meta(i);
-
-        const std::string name = ggml_get_name(meta);
-
-        // TODO: avoid hardcoded tensor names - use the TN_* constants
-        if (name.find("attn_v.weight")   != std::string::npos ||
-            name.find("attn_qkv.weight") != std::string::npos ||
-            name.find("attn_kv_b.weight")!= std::string::npos) {
-            ++qs.n_attention_wv;
-        } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
-            qs.has_output = true;
-        }
-    }
-
-    qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
-
-    // sanity checks
-    {
-        const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
-        // attention layers have a non-zero number of kv heads
-        int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
-        if (llama_model_has_encoder(&model)) {
-            n_attn_layer *= 3;
-        }
-        GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
-    }
-
-    size_t total_size_org = 0;
-    size_t total_size_new = 0;
-
-    std::vector workers;
-    workers.reserve(nthread);
-
-    int idx = 0;
-
-    std::vector> read_data;
-    std::vector> work;
-    std::vector> f32_conv_buf;
-
-    uint16_t n_split = 1;
-    // Assume split index is continuous
-    if (params->keep_split) {
-        for (int i = 0; i < ml.n_tensors; ++i) {
-            n_split = std::max(uint16_t(ml.get_weight(i)->idx+1), n_split);
-        }
-    }
-    std::vector ctx_outs(n_split, NULL);
-    ctx_outs[0] = ctx_out;
-
-    // populate the original tensors so we get an initial meta data
-    for (int i = 0; i < ml.n_tensors; ++i) {
-        auto weight = ml.get_weight(i);
-        uint16_t i_split = params->keep_split ? weight->idx : 0;
-        struct ggml_tensor * tensor = weight->tensor;
-        if (ctx_outs[i_split] == NULL) {
-            ctx_outs[i_split] = gguf_init_empty();
-        }
-        gguf_add_tensor(ctx_outs[i_split], tensor);
-    }
-
-    // Set split info if needed
-    if (n_split > 1) {
-        for (size_t i = 0; i < ctx_outs.size(); ++i) {
-            gguf_set_val_u16(ctx_outs[i], ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
-            gguf_set_val_u16(ctx_outs[i], ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
-            gguf_set_val_i32(ctx_outs[i], ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), ml.n_tensors);
-        }
-    }
-
-    int cur_split = -1;
-    std::ofstream fout;
-    auto close_ofstream = [&]() {
-        // Write metadata and close file handler
-        if (fout.is_open()) {
-            fout.seekp(0);
-            std::vector data(gguf_get_meta_size(ctx_outs[cur_split]));
-            gguf_get_meta_data(ctx_outs[cur_split], data.data());
-            fout.write((const char *) data.data(), data.size());
-            fout.close();
-        }
-    };
-    auto new_ofstream = [&](int index) {
-        cur_split = index;
-        GGML_ASSERT(ctx_outs[cur_split] && "Find uninitialized gguf_context");
-        std::string fname = fname_out;
-        if (params->keep_split) {
-            char split_path[PATH_MAX] = {0};
-            llama_split_path(split_path, sizeof(split_path), fname_out.c_str(), cur_split, n_split);
-            fname = std::string(split_path);
-        }
-
-        fout = std::ofstream(fname, std::ios::binary);
-        fout.exceptions(std::ofstream::failbit); // fail fast on write errors
-        const size_t meta_size = gguf_get_meta_size(ctx_outs[cur_split]);
-        // placeholder for the meta data
-        ::zeros(fout, meta_size);
-    };
-
-    const auto tn = LLM_TN(model.arch);
-    new_ofstream(0);
-    for (int i = 0; i < ml.n_tensors; ++i) {
-        auto weight = ml.get_weight(i);
-        struct ggml_tensor * tensor = weight->tensor;
-        if (weight->idx != cur_split && params->keep_split) {
-            close_ofstream();
-            new_ofstream(weight->idx);
-        }
-
-        const std::string name = ggml_get_name(tensor);
-
-        if (!ml.use_mmap) {
-            if (read_data.size() < ggml_nbytes(tensor)) {
-                read_data.resize(ggml_nbytes(tensor));
-            }
-            tensor->data = read_data.data();
-        }
-        ml.load_data_for(tensor);
-
-        LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ",
-               ++idx, ml.n_tensors,
-               ggml_get_name(tensor),
-               llama_format_tensor_shape(tensor).c_str(),
-               ggml_type_name(tensor->type));
-
-        // This used to be a regex, but  has an extreme cost to compile times.
-        bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
-
-        // quantize only 2D and 3D tensors (experts)
-        quantize &= (ggml_n_dims(tensor) >= 2);
-
-        // do not quantize norm tensors
-        quantize &= name.find("_norm.weight") == std::string::npos;
-
-        quantize &= params->quantize_output_tensor || name != "output.weight";
-        quantize &= !params->only_copy;
-
-        // do not quantize expert gating tensors
-        // NOTE: can't use LLM_TN here because the layer number is not known
-        quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
-
-        // do not quantize positional embeddings and token types (BERT)
-        quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD,    "weight");
-        quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
-
-        // do not quantize Mamba's small yet 2D weights
-        // NOTE: can't use LLM_TN here because the layer number is not known
-        quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
-
-        // do not quantize RWKV's time_mix_first tensors
-        quantize &= name.find("time_mix_first.weight") == std::string::npos;
-        quantize &= name.find("time_mix_w1.weight") == std::string::npos;
-        quantize &= name.find("time_mix_w2.weight") == std::string::npos;
-
-        // do not quantize relative position bias (T5)
-        quantize &= name.find("attn_rel_b.weight") == std::string::npos;
-
-        enum ggml_type new_type;
-        void * new_data;
-        size_t new_size;
-
-        if (quantize) {
-            new_type = default_type;
-
-            // get more optimal quantization type based on the tensor shape, layer, etc.
-            if (!params->pure && ggml_is_quantized(default_type)) {
-                new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
-            }
-            if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) {
-                new_type = params->token_embedding_type;
-            }
-            if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) {
-                new_type = params->output_tensor_type;
-            }
-
-            // If we've decided to quantize to the same type the tensor is already
-            // in then there's nothing to do.
-            quantize = tensor->type != new_type;
-        }
-
-        if (!quantize) {
-            new_type = tensor->type;
-            new_data = tensor->data;
-            new_size = ggml_nbytes(tensor);
-            LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0);
-        } else {
-            const int64_t nelements = ggml_nelements(tensor);
-
-            const float * imatrix = nullptr;
-            if (imatrix_data) {
-                auto it = imatrix_data->find(tensor->name);
-                if (it == imatrix_data->end()) {
-                    LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
-                } else {
-                    if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) {
-                        imatrix = it->second.data();
-                    } else {
-                        LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__,
-                                int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name);
-
-                        // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix
-                        // this is a significant error and it may be good idea to abort the process if this happens,
-                        // since many people will miss the error and not realize that most of the model is being quantized without an imatrix
-                        // tok_embd should be ignored in this case, since it always causes this warning
-                        if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) {
-                            throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s",
-                                    int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name));
-                        }
-                    }
-                }
-            }
-            if ((new_type == GGML_TYPE_IQ2_XXS ||
-                 new_type == GGML_TYPE_IQ2_XS  ||
-                 new_type == GGML_TYPE_IQ2_S   ||
-                 new_type == GGML_TYPE_IQ1_S   ||
-                (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight"))  ||
-                (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) {
-                LLAMA_LOG_ERROR("\n\n============================================================\n");
-                LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name);
-                LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n");
-                LLAMA_LOG_ERROR("============================================================\n\n");
-                throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name));
-            }
-
-            float * f32_data;
-
-            if (tensor->type == GGML_TYPE_F32) {
-                f32_data = (float *) tensor->data;
-            } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) {
-                throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type)));
-            } else {
-                llama_tensor_dequantize_internal(tensor, f32_conv_buf, workers, nelements, nthread);
-                f32_data = (float *) f32_conv_buf.data();
-            }
-
-            int chunk_size_multiplier = 1;
-            if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 || new_type == GGML_TYPE_Q4_0_8_8) {
-                if ((new_type == GGML_TYPE_Q4_0_8_8) && (tensor->ne[1] % 8 != 0)) new_type = GGML_TYPE_Q4_0;
-                else if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_0;
-                if (new_type == GGML_TYPE_Q4_0_8_8) chunk_size_multiplier = 8;
-                else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8) chunk_size_multiplier = 4;
-            }
-
-            LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type));
-            fflush(stdout);
-
-            if (work.size() < (size_t)nelements * 4) {
-                work.resize(nelements * 4); // upper bound on size
-            }
-            new_data = work.data();
-
-            const int64_t n_per_row = tensor->ne[0];
-            const int64_t nrows = tensor->ne[1];
-
-            static const int64_t min_chunk_size = 32 * 512;
-            const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)) *
-                                       chunk_size_multiplier;
-
-            const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1];
-            const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
-            const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1;
-
-            // quantize each expert separately since they have different importance matrices
-            new_size = 0;
-            for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) {
-                const float * f32_data_03 = f32_data + i03 * nelements_matrix;
-                void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows;
-                const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;
-
-                new_size += llama_tensor_quantize_internal(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
-            }
-            LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
-        }
-        total_size_org += ggml_nbytes(tensor);
-        total_size_new += new_size;
-
-        // update the gguf meta data as we go
-        gguf_set_tensor_type(ctx_outs[cur_split], name.c_str(), new_type);
-        gguf_set_tensor_data(ctx_outs[cur_split], name.c_str(), new_data, new_size);
-
-        // write tensor data + padding
-        fout.write((const char *) new_data, new_size);
-        zeros(fout, GGML_PAD(new_size, align) - new_size);
-    }
-    close_ofstream();
-    for (auto & c:ctx_outs) {
-        gguf_free(c);
-    }
-
-    LLAMA_LOG_INFO("%s: model size  = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
-    LLAMA_LOG_INFO("%s: quant size  = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
-
-    if (qs.n_fallback > 0) {
-        LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n",
-                __func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback);
-    }
-}
-
-static void llama_lora_adapter_init_internal(struct llama_model * model, const char * path_lora, struct llama_lora_adapter & adapter) {
-    LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
-
-    ggml_context * ctx = nullptr;
-    struct gguf_init_params meta_gguf_params = {
-        /* .no_alloc = */ true,
-        /* .ctx      = */ &ctx,
-    };
-    struct gguf_context * ctx_gguf = gguf_init_from_file(path_lora, meta_gguf_params);
-    if (!ctx_gguf) {
-        throw std::runtime_error("failed to load lora adapter file from " + std::string(path_lora));
-    }
-
-    // check metadata
-    {
-        auto get_kv_str = [&](const std::string & key) -> std::string {
-            int id = gguf_find_key(ctx_gguf, key.c_str());
-            return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf, id));
-        };
-        auto get_kv_f32 = [&](const std::string & key) -> float {
-            int id = gguf_find_key(ctx_gguf, key.c_str());
-            return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf, id);
-        };
-        LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
-
-        auto general_type = get_kv_str(llm_kv(LLM_KV_GENERAL_TYPE));
-        if (general_type != "adapter") {
-            gguf_free(ctx_gguf);
-            throw std::runtime_error("expect general.type to be 'adapter', but got: " + general_type);
-        }
-
-        auto general_arch_str = get_kv_str(llm_kv(LLM_KV_GENERAL_ARCHITECTURE));
-        auto general_arch = llm_arch_from_string(general_arch_str);
-        if (general_arch != model->arch) {
-            gguf_free(ctx_gguf);
-            throw std::runtime_error("model arch and LoRA arch mismatch");
-        }
-
-        auto adapter_type = get_kv_str(llm_kv(LLM_KV_ADAPTER_TYPE));
-        if (adapter_type != "lora") {
-            gguf_free(ctx_gguf);
-            throw std::runtime_error("expect adapter.type to be 'lora', but got: " + adapter_type);
-        }
-
-        adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA));
-    }
-
-    int n_tensors = gguf_get_n_tensors(ctx_gguf);
-
-    // contexts for each buffer type
-    std::map ctx_map;
-    auto get_ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
-        auto it = ctx_map.find(buft);
-        if (it == ctx_map.end()) {
-            // add a new context
-            struct ggml_init_params params = {
-                /*.mem_size   =*/ n_tensors*ggml_tensor_overhead(),
-                /*.mem_buffer =*/ NULL,
-                /*.no_alloc   =*/ true,
-            };
-            ggml_context * buft_ctx = ggml_init(params);
-            ctx_map[buft] = buft_ctx;
-            return buft_ctx;
-        };
-        return it->second;
-    };
-
-    // bundle lora_a and lora_b into pairs
-    std::map ab_map;
-    auto str_endswith = [](const std::string & str, const std::string & suffix) {
-        return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
-    };
-    for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
-        std::string name(cur->name);
-        if (str_endswith(name, ".lora_a")) {
-            replace_all(name, ".lora_a", "");
-            if (ab_map.find(name) == ab_map.end()) {
-                ab_map[name] = llama_lora_weight(cur, nullptr);
-            } else {
-                ab_map[name].a = cur;
-            }
-        } else if (str_endswith(name, ".lora_b")) {
-            replace_all(name, ".lora_b", "");
-            if (ab_map.find(name) == ab_map.end()) {
-                ab_map[name] = llama_lora_weight(nullptr, cur);
-            } else {
-                ab_map[name].b = cur;
-            }
-        } else {
-            gguf_free(ctx_gguf);
-            ggml_free(ctx);
-            throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix");
-        }
-    }
-
-    // add tensors
-    for (auto & it : ab_map) {
-        const std::string & name = it.first;
-        llama_lora_weight & w = it.second;
-
-        if (!w.a || !w.b) {
-            gguf_free(ctx_gguf);
-            ggml_free(ctx);
-            throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component");
-        }
-
-        // device buft and device ctx
-        auto * model_tensor = llama_get_model_tensor(model, name.c_str());
-        if (!model_tensor) {
-            gguf_free(ctx_gguf);
-            ggml_free(ctx);
-            throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model");
-        }
-        struct ggml_context * dev_ctx = get_ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
-        // validate tensor shape
-        if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) {
-            gguf_free(ctx_gguf);
-            ggml_free(ctx);
-            throw std::runtime_error("tensor '" + name + "' has incorrect shape");
-        }
-        if (w.a->ne[1] != w.b->ne[0]) {
-            gguf_free(ctx_gguf);
-            ggml_free(ctx);
-            throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)");
-        }
-        // save tensor to adapter
-        struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
-        struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
-        ggml_set_name(tensor_a, w.a->name);
-        ggml_set_name(tensor_b, w.b->name);
-        adapter.ab_map[name] = llama_lora_weight(tensor_a, tensor_b);
-    }
-
-    // allocate tensors / buffers and zero
-    {
-        adapter.ctxs.reserve(ctx_map.size());
-        adapter.bufs.reserve(ctx_map.size());
-        for (auto it : ctx_map) {
-            ggml_backend_buffer_type_t buft = it.first;
-            ggml_context * ctx_dev = it.second;
-            ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx_dev, buft);
-            if (!buf) {
-                gguf_free(ctx_gguf);
-                ggml_free(ctx);
-                throw std::runtime_error("failed to allocate buffer for lora adapter\n");
-            }
-            LLAMA_LOG_INFO("%s: %10s LoRA buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
-            adapter.ctxs.push_back(ctx_dev);
-            adapter.bufs.push_back(buf);
-        }
-    }
-
-    // set tensor data
-    {
-        llama_file gguf_file(path_lora, "rb");
-        std::vector read_buf;
-        auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) {
-            size_t offs = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, gguf_find_tensor(ctx_gguf, orig->name));
-            size_t size = ggml_nbytes(orig);
-            read_buf.resize(size);
-            gguf_file.seek(offs, SEEK_SET);
-            gguf_file.read_raw(read_buf.data(), size);
-            ggml_backend_tensor_set(dev, read_buf.data(), 0, size);
-        };
-        for (auto & it : adapter.ab_map) {
-            auto orig = ab_map[it.first];
-            auto dev  = it.second;
-            set_tensor(orig.a, dev.a);
-            set_tensor(orig.b, dev.b);
-        }
-    }
-
-    LLAMA_LOG_INFO("%s: loaded %ld tensors from lora file\n", __func__, adapter.ab_map.size()*2);
-
-    // free ctx for reading gguf
-    gguf_free(ctx_gguf);
-    ggml_free(ctx);
-}
-
-int32_t llama_lora_adapter_set(
+int32_t llama_set_adapter_lora(
             struct llama_context * ctx,
-            struct llama_lora_adapter * adapter,
+            struct llama_adapter_lora * adapter,
             float scale) {
-    if (ctx->cparams.flash_attn) {
-        LLAMA_LOG_ERROR("%s: flash_attn is not compatible with LoRA\n", __func__);
-        return -1;
-    }
-    ctx->lora_adapters[adapter] = scale;
+    ctx->lora[adapter] = scale;
     return 0;
 }
 
-int32_t llama_lora_adapter_remove(
+int32_t llama_rm_adapter_lora(
             struct llama_context * ctx,
-            struct llama_lora_adapter * adapter) {
-    auto pos = ctx->lora_adapters.find(adapter);
-    if (pos != ctx->lora_adapters.end()) {
-        ctx->lora_adapters.erase(pos);
+            struct llama_adapter_lora * adapter) {
+    auto pos = ctx->lora.find(adapter);
+    if (pos != ctx->lora.end()) {
+        ctx->lora.erase(pos);
         return 0;
     }
+
     return -1;
 }
 
-void llama_lora_adapter_clear(struct llama_context * ctx) {
-    ctx->lora_adapters.clear();
+void llama_clear_adapter_lora(struct llama_context * ctx) {
+    ctx->lora.clear();
 }
 
-void llama_lora_adapter_free(struct llama_lora_adapter * adapter) {
-    delete adapter;
+int32_t llama_apply_adapter_cvec(
+        struct llama_context * ctx,
+                 const float * data,
+                      size_t   len,
+                     int32_t   n_embd,
+                     int32_t   il_start,
+                     int32_t   il_end) {
+    return ctx->cvec.apply(ctx->model, data, len, n_embd, il_start, il_end);
 }
 
 //
 // interface implementation
 //
-struct llama_model_params llama_model_default_params() {
-    struct llama_model_params result = {
-        /*.n_gpu_layers                =*/ 0,
-        /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
-        /*.main_gpu                    =*/ 0,
-        /*.tensor_split                =*/ nullptr,
-        /*.rpc_servers                 =*/ nullptr,
-        /*.progress_callback           =*/ nullptr,
-        /*.progress_callback_user_data =*/ nullptr,
-        /*.kv_overrides                =*/ nullptr,
-        /*.vocab_only                  =*/ false,
-        /*.use_mmap                    =*/ true,
-        /*.use_mlock                   =*/ false,
-        /*.check_tensors               =*/ false,
-    };
-
-#ifdef GGML_USE_METAL
-    // note: we usually have plenty of VRAM, so by default offload all layers to the GPU
-    result.n_gpu_layers = 999;
-#endif
-
-    return result;
-}
 
 struct llama_context_params llama_context_default_params() {
     struct llama_context_params result = {
@@ -17939,6 +9341,7 @@ struct llama_context_params llama_context_default_params() {
         /*.embeddings                  =*/ false,
         /*.offload_kqv                 =*/ true,
         /*.flash_attn                  =*/ false,
+        /*.no_perf                     =*/ true,
         /*.abort_callback              =*/ nullptr,
         /*.abort_callback_data         =*/ nullptr,
     };
@@ -17954,40 +9357,8 @@ struct llama_sampler_chain_params llama_sampler_chain_default_params() {
     return result;
 }
 
-struct llama_model_quantize_params llama_model_quantize_default_params() {
-    struct llama_model_quantize_params result = {
-        /*.nthread                     =*/ 0,
-        /*.ftype                       =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
-        /*.output_tensor_type          =*/ GGML_TYPE_COUNT,
-        /*.token_embedding_type        =*/ GGML_TYPE_COUNT,
-        /*.allow_requantize            =*/ false,
-        /*.quantize_output_tensor      =*/ true,
-        /*.only_copy                   =*/ false,
-        /*.pure                        =*/ false,
-        /*.keep_split                  =*/ false,
-        /*.imatrix                     =*/ nullptr,
-        /*.kv_overrides                =*/ nullptr,
-    };
-
-    return result;
-}
-
 size_t llama_max_devices(void) {
-#if defined(GGML_USE_RPC)
-    return GGML_RPC_MAX_SERVERS;
-#elif defined(GGML_USE_METAL)
-    return 1;
-#elif defined(GGML_USE_CUDA)
-    return GGML_CUDA_MAX_DEVICES;
-#elif defined(GGML_USE_SYCL)
-    return GGML_SYCL_MAX_DEVICES;
-#elif defined(GGML_USE_VULKAN)
-    return GGML_VK_MAX_DEVICES;
-#elif defined(GGML_USE_CANN)
-    return GGML_CANN_MAX_DEVICES;
-#else
-    return 1;
-#endif
+    return 16;
 }
 
 bool llama_supports_mmap(void) {
@@ -17999,13 +9370,12 @@ bool llama_supports_mlock(void) {
 }
 
 bool llama_supports_gpu_offload(void) {
-#if defined(GGML_USE_CUDA) || defined(GGML_USE_METAL)   || defined(GGML_USE_VULKAN) || \
-    defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE) || defined(GGML_USE_RPC)
-    // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
-    return true;
-#else
-    return false;
-#endif
+    return ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU) != nullptr ||
+           llama_supports_rpc();
+}
+
+bool llama_supports_rpc(void) {
+    return ggml_backend_reg_by_name("RPC") != nullptr;
 }
 
 void llama_backend_init(void) {
@@ -18021,23 +9391,14 @@ void llama_backend_init(void) {
 
 void llama_numa_init(enum ggml_numa_strategy numa) {
     if (numa != GGML_NUMA_STRATEGY_DISABLED) {
-        ggml_numa_init(numa);
+        auto * dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+        GGML_ASSERT(dev && "CPU backend is not loaded");
+        auto * reg = ggml_backend_dev_backend_reg(dev);
+        auto * numa_init_fn = (decltype(ggml_numa_init) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_numa_init");
+        numa_init_fn(numa);
     }
 }
 
-void llama_attach_threadpool(
-             struct llama_context * ctx,
-        ggml_threadpool_t   threadpool,
-        ggml_threadpool_t   threadpool_batch) {
-    ctx->threadpool       = threadpool;
-    ctx->threadpool_batch = threadpool_batch ? threadpool_batch : threadpool;
-}
-
-void llama_detach_threadpool(struct llama_context * ctx) {
-    ctx->threadpool       = nullptr;
-    ctx->threadpool_batch = nullptr;
-}
-
 void llama_backend_free(void) {
     ggml_quantize_free();
 }
@@ -18046,12 +9407,13 @@ int64_t llama_time_us(void) {
     return ggml_time_us();
 }
 
-struct llama_model * llama_load_model_from_file(
-        const char * path_model,
-        struct llama_model_params   params) {
+static struct llama_model * llama_model_load_from_file_impl(
+        const std::string & path_model,
+        std::vector & splits,
+        struct llama_model_params params) {
     ggml_time_init();
 
-    llama_model * model = new llama_model;
+    llama_model * model = new llama_model(params);
 
     unsigned cur_percentage = 0;
     if (params.progress_callback == NULL) {
@@ -18061,26 +9423,66 @@ struct llama_model * llama_load_model_from_file(
             unsigned percentage = (unsigned) (100 * progress);
             while (percentage > *cur_percentage_p) {
                 *cur_percentage_p = percentage;
-                LLAMA_LOG_INFO(".");
+                LLAMA_LOG_CONT(".");
                 if (percentage >= 100) {
-                    LLAMA_LOG_INFO("\n");
+                    LLAMA_LOG_CONT("\n");
                 }
             }
             return true;
         };
     }
-    if (params.rpc_servers != nullptr && params.rpc_servers[0] != '\0') {
-        // split the servers set them into model->rpc_servers
-        std::string servers(params.rpc_servers);
-        size_t pos = 0;
-        while ((pos = servers.find(",")) != std::string::npos) {
-            std::string server = servers.substr(0, pos);
-            model->rpc_servers.push_back(server);
-            servers.erase(0, pos + 1);
+
+    // create list of devices to use with this model
+    if (params.devices) {
+        for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) {
+            model->devices.push_back(*dev);
+        }
+    } else {
+        std::vector rpc_servers;
+        // use all available devices
+        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+            switch (ggml_backend_dev_type(dev)) {
+                case GGML_BACKEND_DEVICE_TYPE_CPU:
+                case GGML_BACKEND_DEVICE_TYPE_ACCEL:
+                    // skip CPU backends since they are handled separately
+                    break;
+
+                case GGML_BACKEND_DEVICE_TYPE_GPU:
+                    ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
+                    if (ggml_backend_reg_name(reg) == std::string("RPC")) {
+                        rpc_servers.push_back(dev);
+                    } else {
+                        model->devices.push_back(dev);
+                    }
+                    break;
+            }
+        }
+        // add RPC servers at the front of the list
+        if (!rpc_servers.empty()) {
+            model->devices.insert(model->devices.begin(), rpc_servers.begin(), rpc_servers.end());
         }
-        model->rpc_servers.push_back(servers);
     }
-    int status = llama_model_load(path_model, *model, params);
+
+    // if using single GPU mode, remove all except the main GPU
+    if (params.split_mode == LLAMA_SPLIT_MODE_NONE) {
+        if (params.main_gpu < 0 || params.main_gpu >= (int)model->devices.size()) {
+            LLAMA_LOG_ERROR("%s: invalid value for main_gpu: %d (available devices: %d)\n", __func__, params.main_gpu, (int)model->devices.size());
+            llama_model_free(model);
+            return nullptr;
+        }
+        ggml_backend_dev_t main_gpu = model->devices[params.main_gpu];
+        model->devices.clear();
+        model->devices.push_back(main_gpu);
+    }
+
+    for (auto * dev : model->devices) {
+        size_t free, total; // NOLINT
+        ggml_backend_dev_memory(dev, &free, &total);
+        LLAMA_LOG_INFO("%s: using device %s (%s) - %zu MiB free\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), free/1024/1024);
+    }
+
+    const int status = llama_model_load(path_model, splits, *model, params);
     GGML_ASSERT(status <= 0);
     if (status < 0) {
         if (status == -1) {
@@ -18088,18 +9490,44 @@ struct llama_model * llama_load_model_from_file(
         } else if (status == -2) {
             LLAMA_LOG_INFO("%s: cancelled model load\n", __func__);
         }
-        delete model;
+
+        llama_model_free(model);
         return nullptr;
     }
 
     return model;
 }
 
-void llama_free_model(struct llama_model * model) {
-    delete model;
+// deprecated
+struct llama_model * llama_load_model_from_file(
+        const char * path_model,
+        struct llama_model_params params) {
+    return llama_model_load_from_file(path_model, params);
 }
 
-struct llama_context * llama_new_context_with_model(
+struct llama_model * llama_model_load_from_file(
+        const char * path_model,
+        struct llama_model_params params) {
+    std::vector splits = {};
+    return llama_model_load_from_file_impl(path_model, splits, params);
+}
+
+struct llama_model * llama_model_load_from_splits(
+        const char ** paths,
+        size_t n_paths,
+        struct llama_model_params params) {
+    std::vector splits;
+    if (n_paths == 0) {
+        LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__);
+        return nullptr;
+    }
+    for (size_t i = 0; i < n_paths; ++i) {
+        splits.push_back(paths[i]);
+    }
+    return llama_model_load_from_file_impl(splits.front(), splits, params);
+}
+
+struct llama_context * llama_init_from_model(
                  struct llama_model * model,
         struct llama_context_params   params) {
 
@@ -18128,7 +9556,7 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }
 
-    if (params.type_v != GGML_TYPE_F16 && !params.flash_attn) {
+    if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
     }
@@ -18149,6 +9577,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.embeddings       = params.embeddings;
     cparams.offload_kqv      = params.offload_kqv;
     cparams.flash_attn       = params.flash_attn;
+    cparams.no_perf          = params.no_perf;
     cparams.pooling_type     = params.pooling_type;
 
     cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
@@ -18207,15 +9636,26 @@ struct llama_context * llama_new_context_with_model(
         cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
     }
 
-    LLAMA_LOG_INFO("%s: n_ctx      = %u\n",     __func__, cparams.n_ctx);
-    LLAMA_LOG_INFO("%s: n_batch    = %u\n",     __func__, cparams.n_batch);
-    LLAMA_LOG_INFO("%s: n_ubatch   = %u\n",     __func__, cparams.n_ubatch);
-    LLAMA_LOG_INFO("%s: flash_attn = %d\n",     __func__, cparams.flash_attn);
-    LLAMA_LOG_INFO("%s: freq_base  = %.1f\n",   __func__, cparams.rope_freq_base);
-    LLAMA_LOG_INFO("%s: freq_scale = %g\n",     __func__, cparams.rope_freq_scale);
+    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
-    ctx->abort_callback      = params.abort_callback;
-    ctx->abort_callback_data = params.abort_callback_data;
+    LLAMA_LOG_INFO("%s: n_seq_max     = %u\n",   __func__, cparams.n_seq_max);
+    LLAMA_LOG_INFO("%s: n_ctx         = %u\n",   __func__, cparams.n_ctx);
+    LLAMA_LOG_INFO("%s: n_ctx_per_seq = %u\n",   __func__, n_ctx_per_seq);
+    LLAMA_LOG_INFO("%s: n_batch       = %u\n",   __func__, cparams.n_batch);
+    LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
+    LLAMA_LOG_INFO("%s: flash_attn    = %d\n",   __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
+    LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
+
+    if (n_ctx_per_seq < hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
+                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+    }
+
+    if (n_ctx_per_seq > hparams.n_ctx_train) {
+        LLAMA_LOG_WARN("%s: n_ctx_pre_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
+                __func__, n_ctx_per_seq, hparams.n_ctx_train);
+    }
 
     ctx->logits_all = params.logits_all;
 
@@ -18239,154 +9679,55 @@ struct llama_context * llama_new_context_with_model(
     GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
 
     if (!hparams.vocab_only) {
-        // initialize backends
-#if defined(GGML_USE_RPC)
-        if (model->n_gpu_layers > 0) {
-            for (const auto & endpoint : model->rpc_servers) {
-                ggml_backend_t backend = ggml_backend_rpc_init(endpoint.c_str());
-                if (backend == nullptr) {
-                    LLAMA_LOG_ERROR("%s: failed to initialize RPC to '%s'\n", __func__, endpoint.c_str());
-                    llama_free(ctx);
-                    return nullptr;
-                }
-                ctx->backends.push_back(backend);
+        // GPU backends
+        for (auto * dev : model->devices) {
+            ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+            if (backend == nullptr) {
+                LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
+                llama_free(ctx);
+                return nullptr;
             }
+            ctx->backends.emplace_back(backend);
         }
-#endif
 
-#if defined(GGML_USE_METAL)
-        if (model->n_gpu_layers > 0) {
-            ctx->backend_metal = ggml_backend_metal_init();
-            if (ctx->backend_metal == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize Metal backend\n", __func__);
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(ctx->backend_metal);
-        }
-#elif defined(GGML_USE_CUDA)
-        if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
-            // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
-            ggml_backend_t backend = ggml_backend_cuda_init(model->main_gpu);
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, model->main_gpu);
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(backend);
-        } else {
-            // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
-            for (int device = 0; device < ggml_backend_cuda_get_device_count(); ++device) {
-                ggml_backend_t backend = ggml_backend_cuda_init(device);
+        // add ACCEL backends (such as BLAS)
+        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
+                ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
                 if (backend == nullptr) {
-                    LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, device);
+                    LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
                     llama_free(ctx);
                     return nullptr;
                 }
-                ctx->backends.push_back(backend);
+                ctx->backends.emplace_back(backend);
             }
         }
-#elif defined(GGML_USE_VULKAN)
-        if (model->split_mode == LLAMA_SPLIT_MODE_ROW) {
-            LLAMA_LOG_ERROR("%s: Row split not supported. Failed to initialize Vulkan backend\n", __func__);
-            llama_free(ctx);
-            return nullptr;
-        }
-        if (model->split_mode == LLAMA_SPLIT_MODE_NONE) {
-            ggml_backend_t backend = ggml_backend_vk_init(model->main_gpu);
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize Vulkan backend\n", __func__);
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(backend);
-        } else {
-            for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) {
-                ggml_backend_t backend = ggml_backend_vk_init(device);
-                if (backend == nullptr) {
-                    LLAMA_LOG_ERROR("%s: failed to initialize Vulkan%d backend\n", __func__, device);
-                    llama_free(ctx);
-                    return nullptr;
-                }
-                ctx->backends.push_back(backend);
-            }
-        }
-#elif defined(GGML_USE_SYCL)
-        // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
-        if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
-            ggml_backend_t backend = ggml_backend_sycl_init(model->main_gpu);
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d backend\n", __func__, model->main_gpu);
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(backend);
-        } else {
-            // LLAMA_SPLIT_LAYER requires a backend for each GPU
-            for (int i = 0; i < ggml_backend_sycl_get_device_count(); ++i) {
-                ggml_backend_t backend = ggml_backend_sycl_init(i);
-                if (backend == nullptr) {
-                    LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d for No.%d backend\n", __func__, i, i);
-                    llama_free(ctx);
-                    return nullptr;
-                }
-                ctx->backends.push_back(backend);
-            }
-        }
-#elif defined(GGML_USE_KOMPUTE)
-        if (model->n_gpu_layers > 0) {
-            auto * backend = ggml_backend_kompute_init(model->main_gpu);
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize Kompute backend\n", __func__);
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(backend);
-        }
-#elif defined(GGML_USE_CANN)
-    // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
-    // TODO: ggml_backend_cann is not support split tensor now, just leave code here.
-    if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
-        ggml_backend_t backend = ggml_backend_cann_init(model->main_gpu);
-        if (backend == nullptr) {
-            LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, model->main_gpu);
-            llama_free(ctx);
-            return nullptr;
-        }
-        ctx->backends.push_back(backend);
-    } else {
-        // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
-        // TODO: currently, CANN can't use multi-gpus, just leave code here for further cann version.
-        for (int32_t device = 0; device < ggml_backend_cann_get_device_count(); ++device) {
-            ggml_backend_t backend = ggml_backend_cann_init(device);
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, device);
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(backend);
-        }
-    }
-#endif
 
-#ifdef GGML_USE_BLAS
-        ctx->backend_blas = ggml_backend_blas_init();
-        if (ctx->backend_blas == nullptr) {
-            LLAMA_LOG_WARN("%s: failed to initialize BLAS backend\n", __func__);
-        } else {
-            ctx->backends.push_back(ctx->backend_blas);
-        }
-#endif
-
-        ctx->backend_cpu = ggml_backend_cpu_init();
+        // add CPU backend
+        ctx->backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
         if (ctx->backend_cpu == nullptr) {
             LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);
             llama_free(ctx);
             return nullptr;
         }
-        ctx->backends.push_back(ctx->backend_cpu);
+        ctx->backends.emplace_back(ctx->backend_cpu);
 
-        if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
+        // create a list of the set_n_threads functions in the backends
+        for (auto & backend : ctx->backends) {
+            ggml_backend_dev_t dev = ggml_backend_get_device(backend.get());
+            ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+            if (reg) {
+                auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+                if (ggml_backend_set_n_threads_fn) {
+                    ctx->set_n_threads_fns.emplace_back(backend.get(), ggml_backend_set_n_threads_fn);
+                }
+            }
+        }
+
+        llama_set_abort_callback(ctx, params.abort_callback, params.abort_callback_data);
+
+        if (!llama_kv_cache_init(ctx->kv_self, ctx->model, ctx->cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
             LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
             llama_free(ctx);
             return nullptr;
@@ -18405,7 +9746,7 @@ struct llama_context * llama_new_context_with_model(
             }
 
             LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                      (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
                 ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
                 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
         }
@@ -18420,63 +9761,100 @@ struct llama_context * llama_new_context_with_model(
             }
 
             LLAMA_LOG_INFO("%s: %10s  output buffer size = %8.2f MiB\n", __func__,
-                    ggml_backend_buffer_name(ctx->buf_output),
-                    ggml_backend_buffer_get_size(ctx->buf_output) / 1024.0 / 1024.0);
+                    ggml_backend_buffer_name(ctx->buf_output.get()),
+                    ggml_backend_buffer_get_size(ctx->buf_output.get()) / 1024.0 / 1024.0);
         }
 
         // scheduler and compute buffers
         {
             // buffer types used for the compute buffer of each backend
             std::vector backend_buft;
-            for (auto * backend : ctx->backends) {
-                if (ggml_backend_is_cpu(backend)) {
-                    // use host buffers for the CPU backend compute buffer
-                    backend_buft.push_back(llama_default_buffer_type_cpu(true));
-                } else {
-                    backend_buft.push_back(ggml_backend_get_default_buffer_type(backend));
+            std::vector backend_ptrs;
+            for (auto & backend : ctx->backends) {
+                auto * buft = ggml_backend_get_default_buffer_type(backend.get());
+                auto backend_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
+                if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model->devices.empty()) {
+                    // use the host buffer of the first device CPU for faster transfer of the intermediate state
+                    auto * dev = model->devices[0];
+                    auto * host_buft = ggml_backend_dev_host_buffer_type(dev);
+                    if (host_buft) {
+                        buft = host_buft;
+                    }
                 }
+                backend_buft.push_back(buft);
+                backend_ptrs.push_back(backend.get());
             }
 
-            const size_t max_nodes = llama_model_max_nodes(*model);
+            const size_t max_nodes = model->max_nodes();
 
             // buffer used to store the computation graph and the tensor meta data
             ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
 
+            // TODO: move these checks to ggml_backend_sched
             // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
             bool pipeline_parallel =
-                llama_get_device_count(*model) > 1 &&
-                model->n_gpu_layers > (int)model->hparams.n_layer &&
-                model->split_mode == LLAMA_SPLIT_MODE_LAYER &&
+                model->n_devices() > 1 &&
+                model->params.n_gpu_layers > (int)model->hparams.n_layer &&
+                model->params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
                 params.offload_kqv;
-#ifndef GGML_USE_CUDA
-            // pipeline parallelism requires support for async compute and events
-            // currently this is only implemented in the CUDA backend
-            pipeline_parallel = false;
-#endif
-            ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), max_nodes, pipeline_parallel);
 
+            // pipeline parallelism requires support for async compute and events in all devices
             if (pipeline_parallel) {
-                LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(ctx->sched));
+                for (auto & backend : ctx->backends) {
+                    auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get()));
+                    if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) {
+                        // ignore CPU backend
+                        continue;
+                    }
+                    auto * dev = ggml_backend_get_device(backend.get());
+                    ggml_backend_dev_props props;
+                    ggml_backend_dev_get_props(dev, &props);
+                    if (!props.caps.async || !props.caps.events) {
+                        // device does not support async compute or events
+                        pipeline_parallel = false;
+                        break;
+                    }
+                }
             }
 
-            // build worst-case graph
-            uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-            uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-            llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-            llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-            ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true);
+            ctx->sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
+
+            if (pipeline_parallel) {
+                LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(ctx->sched.get()));
+            }
 
             // initialize scheduler with the worst-case graph
-            if (!ggml_backend_sched_reserve(ctx->sched, gf)) {
+            uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+            uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+            llama_token token = ctx->model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+
+            llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+            ggml_cgraph * gf_pp = llama_build_graph(*ctx, ubatch_pp, true);
+
+            // reserve pp graph first so that buffers are only allocated once
+            ggml_backend_sched_reserve(ctx->sched.get(), gf_pp);
+            int n_splits_pp = ggml_backend_sched_get_n_splits(ctx->sched.get());
+            int n_nodes_pp = ggml_graph_n_nodes(gf_pp);
+
+            // reserve with tg graph to get the number of splits and nodes
+            llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+            ggml_cgraph * gf_tg = llama_build_graph(*ctx, ubatch_tg, true);
+            ggml_backend_sched_reserve(ctx->sched.get(), gf_tg);
+            int n_splits_tg = ggml_backend_sched_get_n_splits(ctx->sched.get());
+            int n_nodes_tg = ggml_graph_n_nodes(gf_tg);
+
+            // reserve again with pp graph to avoid ggml-alloc reallocations during inference
+            gf_pp = llama_build_graph(*ctx, ubatch_pp, true);
+            if (!ggml_backend_sched_reserve(ctx->sched.get(), gf_pp)) {
                 LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
                 llama_free(ctx);
                 return nullptr;
             }
 
-            for (size_t i = 0; i < ctx->backends.size(); i++) {
-                ggml_backend_t backend = ctx->backends[i];
+            for (size_t i = 0; i < backend_ptrs.size(); ++i) {
+                ggml_backend_t backend = backend_ptrs[i];
                 ggml_backend_buffer_type_t buft = backend_buft[i];
-                size_t size = ggml_backend_sched_get_buffer_size(ctx->sched, backend);
+                size_t size = ggml_backend_sched_get_buffer_size(ctx->sched.get(), backend);
                 if (size > 1) {
                     LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
                             ggml_backend_buft_name(buft),
@@ -18484,449 +9862,48 @@ struct llama_context * llama_new_context_with_model(
                 }
             }
 
-            // note: the number of splits during measure is higher than during inference due to the kv shift
-            int n_splits = ggml_backend_sched_get_n_splits(ctx->sched);
-            LLAMA_LOG_INFO("%s: graph nodes  = %d\n", __func__, gf->n_nodes);
-            LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits);
+            if (n_nodes_pp == n_nodes_tg) {
+                LLAMA_LOG_INFO("%s: graph nodes  = %d\n", __func__, n_nodes_pp);
+            } else {
+                LLAMA_LOG_INFO("%s: graph nodes  = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
+            }
+            if (n_splits_pp == n_splits_tg) {
+                LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
+            } else {
+                LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
+            }
         }
     }
 
     return ctx;
 }
 
-void llama_free(struct llama_context * ctx) {
-    delete ctx;
+struct llama_context * llama_new_context_with_model(
+                 struct llama_model * model,
+        struct llama_context_params   params) {
+    return llama_init_from_model(model, params);
 }
 
-uint32_t llama_n_ctx(const struct llama_context * ctx) {
-    return ctx->cparams.n_ctx;
-}
+//
+// kv cache
+//
 
-uint32_t llama_n_batch(const struct llama_context * ctx) {
-    return ctx->cparams.n_batch;
-}
-
-uint32_t llama_n_ubatch(const struct llama_context * ctx) {
-    return ctx->cparams.n_ubatch;
-}
-
-uint32_t llama_n_seq_max(const struct llama_context * ctx) {
-    return ctx->kv_self.size;
-}
-
-enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
-    return model->vocab.type;
-}
-
-int32_t llama_n_vocab(const struct llama_model * model) {
-    return model->hparams.n_vocab;
-}
-
-int32_t llama_n_ctx_train(const struct llama_model * model) {
-    return model->hparams.n_ctx_train;
-}
-
-int32_t llama_n_embd(const struct llama_model * model) {
-    return model->hparams.n_embd;
-}
-
-int32_t llama_n_layer(const struct llama_model * model) {
-    return model->hparams.n_layer;
-}
-
-const struct llama_model * llama_get_model(const struct llama_context * ctx) {
-    return &ctx->model;
-}
-
-enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
-    return ctx->cparams.pooling_type;
-}
-
-enum llama_rope_type llama_rope_type(const struct llama_model * model) {
-    switch (model->arch) {
-        // these models do not use RoPE
-        case LLM_ARCH_GPT2:
-        case LLM_ARCH_GPTJ:
-        case LLM_ARCH_MPT:
-        case LLM_ARCH_REFACT:
-        case LLM_ARCH_BLOOM:
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_JINA_BERT_V2:
-        case LLM_ARCH_T5:
-        case LLM_ARCH_T5ENCODER:
-        case LLM_ARCH_JAIS:
-        case LLM_ARCH_RWKV6:
-            return LLAMA_ROPE_TYPE_NONE;
-
-        // use what we call a normal RoPE, operating on pairs of consecutive head values
-        case LLM_ARCH_LLAMA:
-        case LLM_ARCH_BAICHUAN:
-        case LLM_ARCH_STARCODER:
-        case LLM_ARCH_PLAMO:
-        case LLM_ARCH_ORION:
-        case LLM_ARCH_INTERNLM2:
-        case LLM_ARCH_MINICPM:
-        case LLM_ARCH_XVERSE:
-        case LLM_ARCH_COMMAND_R:
-        case LLM_ARCH_OLMO:
-        case LLM_ARCH_ARCTIC:
-        case LLM_ARCH_DEEPSEEK2:
-        case LLM_ARCH_CHATGLM:
-            return LLAMA_ROPE_TYPE_NORM;
-
-        // the pairs of head values are offset by n_rot/2
-        case LLM_ARCH_FALCON:
-        case LLM_ARCH_GROK:
-        case LLM_ARCH_DBRX:
-        case LLM_ARCH_BERT:
-        case LLM_ARCH_NOMIC_BERT:
-        case LLM_ARCH_STABLELM:
-        case LLM_ARCH_BITNET:
-        case LLM_ARCH_QWEN:
-        case LLM_ARCH_QWEN2:
-        case LLM_ARCH_QWEN2MOE:
-        case LLM_ARCH_PHI2:
-        case LLM_ARCH_PHI3:
-        case LLM_ARCH_GEMMA:
-        case LLM_ARCH_GEMMA2:
-        case LLM_ARCH_STARCODER2:
-        case LLM_ARCH_OPENELM:
-        case LLM_ARCH_GPTNEOX:
-        case LLM_ARCH_CODESHELL:
-        case LLM_ARCH_NEMOTRON:
-        case LLM_ARCH_EXAONE:
-            return LLAMA_ROPE_TYPE_NEOX;
-
-        // all model arches should be listed explicitly here
-        case LLM_ARCH_UNKNOWN:
-            GGML_ABORT("unknown architecture");
-    }
-
-    return LLAMA_ROPE_TYPE_NONE;
-}
-
-float llama_rope_freq_scale_train(const struct llama_model * model) {
-    return model->hparams.rope_freq_scale_train;
-}
-
-int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size) {
-    const auto & it = model->gguf_kv.find(key);
-    if (it == model->gguf_kv.end()) {
-        if (buf_size > 0) {
-            buf[0] = '\0';
-        }
-        return -1;
-    }
-    return snprintf(buf, buf_size, "%s", it->second.c_str());
-}
-
-int32_t llama_model_meta_count(const struct llama_model * model) {
-    return (int)model->gguf_kv.size();
-}
-
-int32_t llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size) {
-    if (i < 0 || i >= (int)model->gguf_kv.size()) {
-        if (buf_size > 0) {
-            buf[0] = '\0';
-        }
-        return -1;
-    }
-    auto it = model->gguf_kv.begin();
-    std::advance(it, i);
-    return snprintf(buf, buf_size, "%s", it->first.c_str());
-}
-
-int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size) {
-    if (i < 0 || i >= (int)model->gguf_kv.size()) {
-        if (buf_size > 0) {
-            buf[0] = '\0';
-        }
-        return -1;
-    }
-    auto it = model->gguf_kv.begin();
-    std::advance(it, i);
-    return snprintf(buf, buf_size, "%s", it->second.c_str());
-}
-
-int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
-    return snprintf(buf, buf_size, "%s %s %s",
-            llama_model_arch_name(model->arch),
-            llama_model_type_name(model->type),
-            llama_model_ftype_name(model->ftype).c_str());
-}
-
-uint64_t llama_model_size(const struct llama_model * model) {
-    uint64_t size = 0;
-    for (const auto & it : model->tensors_by_name) {
-        size += ggml_nbytes(it.second);
-    }
-    return size;
-}
-
-uint64_t llama_model_n_params(const struct llama_model * model) {
-    uint64_t nparams = 0;
-    for (const auto & it : model->tensors_by_name) {
-        nparams += ggml_nelements(it.second);
-    }
-    return nparams;
-}
-
-struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name) {
-    auto it = std::find_if(model->tensors_by_name.begin(), model->tensors_by_name.end(),
-            [name](const std::pair & it) {
-                return it.first == name;
-            });
-    if (it == model->tensors_by_name.end()) {
-        return nullptr;
-    }
-    return it->second;
-}
-
-bool llama_model_has_encoder(const struct llama_model * model) {
-    switch (model->arch) {
-        case LLM_ARCH_T5:        return true;
-        case LLM_ARCH_T5ENCODER: return true;
-        default:                 return false;
-    }
-}
-
-bool llama_model_has_decoder(const struct llama_model * model) {
-    switch (model->arch) {
-        case LLM_ARCH_T5ENCODER: return false;
-        default:                 return true;
-    }
-}
-
-llama_token llama_model_decoder_start_token(const struct llama_model * model) {
-    return model->hparams.dec_start_token_id;
-}
-
-bool llama_model_is_recurrent(const struct llama_model * model) {
-    switch (model->arch) {
-        case LLM_ARCH_MAMBA:  return true;
-        case LLM_ARCH_RWKV6:  return true;
-        default:              return false;
-    }
-}
-
-uint32_t llama_model_quantize(
-        const char * fname_inp,
-        const char * fname_out,
-        const llama_model_quantize_params * params) {
-    try {
-        llama_model_quantize_internal(fname_inp, fname_out, params);
-        return 0;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what());
-        return 1;
-    }
-}
-
-struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) {
-    try {
-        struct llama_lora_adapter * adapter = new llama_lora_adapter(model);
-        llama_lora_adapter_init_internal(model, path_lora, *adapter);
-        return adapter;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
-        return nullptr;
-    }
-}
-
-static bool llama_control_vector_init(struct llama_control_vector & cvec, const llama_model & model) {
-    GGML_ASSERT(cvec.tensors.empty());
-    GGML_ASSERT(cvec.ctxs.empty());
-    GGML_ASSERT(cvec.bufs.empty());
-
-    // count layer buffer types
-    std::map buft_layer_count;
-    for (int64_t i = 0; i < model.hparams.n_layer; i++) {
-        buft_layer_count[model.buft_layer[i].buft]++;
-    }
-
-    // allocate contexts
-    std::map ctx_map;
-    for (auto & it : buft_layer_count) {
-        int n_layers = it.second;
-        struct ggml_init_params params = {
-            /*.mem_size   =*/ n_layers * ggml_tensor_overhead(),
-            /*.mem_buffer =*/ NULL,
-            /*.no_alloc   =*/ true,
-        };
-        ggml_context * ctx = ggml_init(params);
-        if (!ctx) {
-            LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__);
-            return 1;
-        }
-        ctx_map[it.first] = ctx;
-    }
-
-    // make tensors
-    cvec.tensors.reserve(model.hparams.n_layer);
-    cvec.tensors.push_back(nullptr); // there's never a tensor for layer 0
-    for (size_t il = 1; il < model.hparams.n_layer; il++) {
-        struct ggml_context * ctx = ctx_map.at(model.buft_layer[il].buft);
-        ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_embd);
-        cvec.tensors.push_back(tensor);
-    }
-
-    // allocate tensors / buffers and zero
-    cvec.ctxs.reserve(ctx_map.size());
-    cvec.bufs.reserve(ctx_map.size());
-    for (auto it : ctx_map) {
-        ggml_backend_buffer_type_t buft = it.first;
-        ggml_context * ctx = it.second;
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
-        if (!buf) {
-            LLAMA_LOG_ERROR("%s: failed to allocate buffer for control vector\n", __func__);
-            return false;
-        }
-        ggml_backend_buffer_clear(buf, 0);
-        cvec.ctxs.push_back(ctx);
-        cvec.bufs.push_back(buf);
-    }
-
-    return true;
-}
-
-int32_t llama_control_vector_apply(struct llama_context * lctx, const float * data, size_t len, int32_t n_embd, int32_t il_start, int32_t il_end) {
-    const llama_model & model = lctx->model;
-    llama_control_vector & cvec = lctx->cvec;
-
-    if (data == nullptr) {
-        // disable the current control vector (but leave allocated for later)
-        cvec.layer_start = -1;
-        cvec.layer_end   = -1;
-        return 0;
-    }
-
-    if (n_embd != (int) model.hparams.n_embd) {
-        LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
-        return 1;
-    }
-
-    if (cvec.tensors.empty()) {
-        if (!llama_control_vector_init(cvec, model)) {
-            return 1;
-        }
-    }
-
-    cvec.layer_start = il_start;
-    cvec.layer_end   = il_end;
-
-    for (size_t il = 1; il < model.hparams.n_layer; il++) {
-        assert(cvec.tensors[il] != nullptr);
-
-        const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present
-        if (off + n_embd <= len) {
-            ggml_backend_tensor_set(cvec.tensors[il], data + off, 0, n_embd * ggml_element_size(cvec.tensors[il]));
-        }
-    }
-
-    return 0;
-}
+// TODO: tmp bridges below until `struct llama_kv_cache` is exposed through the public API
 
 struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max) {
-    struct llama_kv_cache_view result = {
-        /*.n_cells            = */ 0,
-        /*.n_seq_max          = */ n_seq_max,
-        /*.token_count        = */ 0,
-        /*.used_cells         = */ llama_get_kv_cache_used_cells(ctx),
-        /*.max_contiguous     = */ 0,
-        /*.max_contiguous_idx = */ -1,
-        /*.cells              = */ nullptr,
-        /*.cells_sequences    = */ nullptr,
-    };
-    return result;
-}
-
-void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
-    if (view->cells != nullptr) {
-        free(view->cells);
-        view->cells = nullptr;
-    }
-    if (view->cells_sequences != nullptr) {
-        free(view->cells_sequences);
-        view->cells_sequences = nullptr;
-    }
+    return llama_kv_cache_view_init(ctx->kv_self, n_seq_max);
 }
 
 void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
-    if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) {
-        view->n_cells = int32_t(ctx->kv_self.size);
-        void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
-        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
-        view->cells = (struct llama_kv_cache_view_cell *)p;
-        p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
-        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
-        view->cells_sequences = (llama_seq_id *)p;
-    }
-
-    const std::vector & kv_cells = ctx->kv_self.cells;
-    llama_kv_cache_view_cell * c_curr = view->cells;
-    llama_seq_id * cs_curr = view->cells_sequences;
-    int32_t used_cells = 0;
-    int32_t token_count = 0;
-    int32_t curr_contig_idx = -1;
-    uint32_t max_contig = 0;
-    int32_t max_contig_idx = -1;
-
-    for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_seq_max) {
-        const size_t curr_size = kv_cells[i].seq_id.size();
-        token_count += curr_size;
-        c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
-
-        if (curr_size > 0) {
-            if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
-                max_contig = i - curr_contig_idx;
-                max_contig_idx = curr_contig_idx;
-            }
-            curr_contig_idx = -1;
-        } else if (curr_contig_idx < 0) {
-            curr_contig_idx = i;
-        }
-
-        int seq_idx = 0;
-        for (const llama_seq_id it : kv_cells[i].seq_id) {
-            if (seq_idx >= view->n_seq_max) {
-                break;
-            }
-            cs_curr[seq_idx] = it;
-            seq_idx++;
-        }
-        if (seq_idx != 0) {
-            used_cells++;
-        }
-        for (; seq_idx < view->n_seq_max; seq_idx++) {
-            cs_curr[seq_idx] = -1;
-        }
-    }
-    if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
-        max_contig_idx = curr_contig_idx;
-        max_contig = kv_cells.size() - curr_contig_idx;
-    }
-    view->max_contiguous = max_contig;
-    view->max_contiguous_idx = max_contig_idx;
-    view->token_count = token_count;
-    view->used_cells = used_cells;
-    if (uint32_t(used_cells) != ctx->kv_self.used) {
-        LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
-            __func__, ctx->kv_self.used, used_cells);
-    }
+    llama_kv_cache_view_update(view, ctx->kv_self);
 }
 
 int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx) {
-    int result = 0;
-
-    for (uint32_t i = 0; i < ctx->kv_self.size; i++) {
-        result += ctx->kv_self.cells[i].seq_id.size();
-    }
-
-    return result;
+    return llama_get_kv_cache_token_count(ctx->kv_self);
 }
 
 int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
-    return ctx->kv_self.used;
+    return llama_get_kv_cache_used_cells(ctx->kv_self);
 }
 
 void llama_kv_cache_clear(struct llama_context * ctx) {
@@ -18973,1074 +9950,20 @@ void llama_kv_cache_defrag(struct llama_context * ctx) {
 }
 
 void llama_kv_cache_update(struct llama_context * ctx) {
-    llama_kv_cache_update_internal(*ctx);
+    llama_kv_cache_update_impl(*ctx);
 }
 
-// deprecated
-size_t llama_get_state_size(struct llama_context * ctx) {
-    return llama_state_get_size(ctx);
+bool llama_kv_cache_can_shift(struct llama_context * ctx) {
+    return llama_kv_cache_can_shift(ctx->kv_self);
 }
 
-// deprecated
-size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
-    return llama_state_get_data(ctx, dst, -1);
-}
-
-// deprecated
-size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
-    return llama_state_set_data(ctx, src, -1);
-}
-
-// deprecated
-bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
-}
-
-// deprecated
-bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
-    return llama_state_save_file(ctx, path_session, tokens, n_token_count);
-}
-
-// TODO: replace all non-fatal assertions with returned errors or exceptions
-struct llama_data_write {
-    virtual void write(const void * src, size_t size) = 0;
-    virtual void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) = 0;
-    virtual size_t get_size_written() = 0;
-    virtual ~llama_data_write() = default;
-
-    void write_string(const std::string & str) {
-        uint32_t str_size = str.size();
-
-        write(&str_size,  sizeof(str_size));
-        write(str.data(), str_size);
-    }
-
-    void write_model_info(const struct llama_context * ctx) {
-        std::string arch_str = LLM_ARCH_NAMES.at(ctx->model.arch);
-        write_string(arch_str);
-        // TODO: add more model-specific info which should prevent loading the session file if not identical
-    }
-
-    //void write_rng(const std::mt19937 & rng) {
-    //    std::ostringstream rng_ss;
-    //    rng_ss << rng;
-
-    //    const std::string & rng_str = rng_ss.str();
-
-    //    write_string(rng_str);
-    //}
-
-    void write_output_ids(struct llama_context * ctx) {
-        llama_output_reorder(ctx);
-
-        const uint32_t n_outputs = ctx->n_outputs;
-
-        std::vector output_pos;
-
-        const size_t    n_batch = ctx->cparams.n_batch;
-        const auto & output_ids = ctx->output_ids;
-
-        GGML_ASSERT(n_outputs <= ctx->output_size);
-
-        output_pos.resize(n_outputs);
-
-        // build a more compact representation of the output ids
-        for (size_t i = 0; i < n_batch; ++i) {
-            // map an output id to a position in the batch
-            int32_t pos = output_ids[i];
-            if (pos >= 0) {
-                GGML_ASSERT((uint32_t) pos < n_outputs);
-                output_pos[pos] = i;
-            }
-        }
-
-        write(&n_outputs, sizeof(n_outputs));
-
-        if (n_outputs) {
-            write(output_pos.data(), n_outputs * sizeof(int32_t));
-        }
-    }
-
-    void write_logits(const struct llama_context * ctx) {
-        const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab);
-
-        write(&logits_size, sizeof(logits_size));
-
-        if (logits_size) {
-            write(ctx->logits, logits_size * sizeof(float));
-        }
-    }
-
-    void write_embeddings(const struct llama_context * ctx) {
-        const uint64_t embeddings_size = std::min((uint64_t) ctx->embd_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_embd);
-
-        write(&embeddings_size, sizeof(embeddings_size));
-
-        if (embeddings_size) {
-            write(ctx->embd, embeddings_size * sizeof(float));
-        }
-    }
-
-    void write_kv_cache_meta(const llama_kv_cache & kv_self, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) {
-
-        for (const auto & range : cell_ranges) {
-            for (uint32_t i = range.first; i < range.second; ++i) {
-                const auto & cell = kv_self.cells[i];
-                const llama_pos pos      = cell.pos;
-                const uint32_t  n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
-
-                write(&pos,      sizeof(pos));
-                write(&n_seq_id, sizeof(n_seq_id));
-
-                if (n_seq_id) {
-                    for (auto seq_id : cell.seq_id) {
-                        write(&seq_id, sizeof(seq_id));
-                    }
-                }
-            }
-        }
-    }
-
-    void write_kv_cache_data(const struct llama_context * ctx, const std::vector> & cell_ranges) {
-        const struct llama_kv_cache & kv_self = ctx->kv_self;
-        const struct llama_hparams & hparams = ctx->model.hparams;
-
-        const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
-        const uint32_t n_layer = hparams.n_layer;
-
-        write(&v_trans, sizeof(v_trans));
-        write(&n_layer, sizeof(n_layer));
-
-        std::vector tmp_buf;
-
-        // Iterate and write all the keys first, each row is a cell
-        // Get whole range at a time
-        for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-
-            // Write key type
-            const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
-            write(&k_type_i, sizeof(k_type_i));
-
-            // Write row size of key
-            const uint64_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
-            write(&k_size_row, sizeof(k_size_row));
-
-            // Read each range of cells of k_size length each into tmp_buf and write out
-            for (const auto & range : cell_ranges) {
-                const size_t range_size = range.second - range.first;
-                const size_t buf_size = range_size * k_size_row;
-                write_tensor_data(kv_self.k_l[il], range.first * k_size_row, buf_size);
-            }
-        }
-
-        if (!kv_self.v_trans) {
-            for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
-                // Write value type
-                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-                write(&v_type_i, sizeof(v_type_i));
-
-                // Write row size of value
-                const uint64_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
-                write(&v_size_row, sizeof(v_size_row));
-
-                // Read each range of cells of v_size length each into tmp_buf and write out
-                for (const auto & range : cell_ranges) {
-                    const size_t range_size = range.second - range.first;
-                    const size_t buf_size = range_size * v_size_row;
-                    write_tensor_data(kv_self.v_l[il], range.first * v_size_row, buf_size);
-                }
-            }
-        } else {
-            // When v is transposed, we also need the element size and get the element ranges from each row
-            const uint32_t kv_size = kv_self.size;
-            for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
-                // Write value type
-                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-                write(&v_type_i, sizeof(v_type_i));
-
-                // Write element size
-                const uint32_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
-                write(&v_size_el, sizeof(v_size_el));
-
-                // Write GQA embedding size
-                write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
-
-                // For each row, we get the element values of each cell
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    // Read each range of cells of v_size_el length each into tmp_buf and write out
-                    for (const auto & range : cell_ranges) {
-                        const size_t range_size = range.second - range.first;
-                        const size_t src_offset = (range.first + j * kv_size) * v_size_el;
-                        const size_t buf_size = range_size * v_size_el;
-                        write_tensor_data(kv_self.v_l[il], src_offset, buf_size);
-                    }
-                }
-            }
-        }
-    }
-
-    void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
-        const struct llama_kv_cache & kv_self = ctx->kv_self;
-        std::vector> cell_ranges; // ranges, from inclusive, to exclusive
-        uint32_t cell_count = 0;
-
-        // Count the number of cells with the specified seq_id
-        // Find all the ranges of cells with this seq id (or all, when -1)
-        uint32_t cell_range_begin = kv_self.size;
-        for (uint32_t i = 0; i < kv_self.size; ++i) {
-            const auto & cell = kv_self.cells[i];
-            if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
-                ++cell_count;
-                if (cell_range_begin == kv_self.size) {
-                    cell_range_begin = i;
-                }
-            } else {
-                if (cell_range_begin != kv_self.size) {
-                    cell_ranges.emplace_back(cell_range_begin, i);
-                    cell_range_begin = kv_self.size;
-                }
-            }
-        }
-        if (cell_range_begin != kv_self.size) {
-            cell_ranges.emplace_back(cell_range_begin, kv_self.size);
-        }
-
-        // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
-        uint32_t cell_count_check = 0;
-        for (const auto & range : cell_ranges) {
-            cell_count_check += range.second - range.first;
-        }
-        GGML_ASSERT(cell_count == cell_count_check);
-
-        write(&cell_count, sizeof(cell_count));
-
-        write_kv_cache_meta(kv_self, cell_ranges, seq_id);
-        write_kv_cache_data(ctx, cell_ranges);
-    }
-};
-
-struct llama_data_read {
-    virtual const uint8_t * read(size_t size) = 0;
-    virtual void read_to(void * dst, size_t size) = 0;
-    virtual size_t get_size_read() = 0;
-    virtual ~llama_data_read() = default;
-
-    void read_string(std::string & str) {
-        uint32_t str_size;
-        read_to(&str_size, sizeof(str_size));
-
-        str.assign((const char *) read(str_size), str_size);
-    }
-
-    // validate model information
-    void read_model_info(const struct llama_context * ctx) {
-        std::string cur_arch_str = LLM_ARCH_NAMES.at(ctx->model.arch);
-        std::string arch_str;
-        read_string(arch_str);
-        if (cur_arch_str != arch_str) {
-            throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
-        }
-        // TODO: add more info which needs to be identical but which is not verified otherwise
-    }
-
-    //void read_rng(std::mt19937 & rng) {
-    //    std::string rng_str;
-    //    read_string(rng_str);
-
-    //    std::istringstream rng_ss(rng_str);
-    //    rng_ss >> rng;
-
-    //    if (rng_ss.fail()) {
-    //        throw std::runtime_error("failed to load RNG state");
-    //    }
-    //}
-
-    void read_output_ids(struct llama_context * ctx) {
-        std::vector output_pos;
-
-        uint32_t n_outputs;
-        read_to(&n_outputs, sizeof(n_outputs));
-
-        if (n_outputs > llama_output_reserve(*ctx, n_outputs)) {
-            throw std::runtime_error("could not reserve outputs");
-        }
-
-        if (n_outputs) {
-            output_pos.resize(n_outputs);
-            read_to(output_pos.data(), n_outputs * sizeof(int32_t));
-
-            for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
-                int32_t id = output_pos[i];
-                if ((uint32_t) id >= ctx->cparams.n_batch) {
-                    throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->cparams.n_batch));
-                }
-                ctx->output_ids[id] = i;
-            }
-
-            ctx->n_outputs = n_outputs;
-        }
-    }
-
-    void read_logits(struct llama_context * ctx) {
-        uint64_t logits_size;
-        read_to(&logits_size, sizeof(logits_size));
-
-        if (ctx->logits_size < logits_size) {
-            throw std::runtime_error("logits buffer too small");
-        }
-
-        if (logits_size) {
-            read_to(ctx->logits, logits_size * sizeof(float));
-        }
-    }
-
-    void read_embeddings(struct llama_context * ctx) {
-        uint64_t embeddings_size;
-        read_to(&embeddings_size, sizeof(embeddings_size));
-
-        if (ctx->embd_size < embeddings_size) {
-            throw std::runtime_error("embeddings buffer too small");
-        }
-
-        if (embeddings_size) {
-            read_to(ctx->embd, embeddings_size * sizeof(float));
-        }
-    }
-
-    bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) {
-        struct llama_kv_cache & kv_self = ctx->kv_self;
-
-        if (dest_seq_id != -1) {
-            // single sequence
-
-            llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-
-            llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
-            batch.n_tokens = cell_count;
-            batch.n_seq_tokens = cell_count;
-            batch.n_seqs = 1;
-
-            for (uint32_t i = 0; i < cell_count; ++i) {
-                llama_pos pos;
-                uint32_t n_seq_id;
-
-                read_to(&pos, sizeof(pos));
-                read_to(&n_seq_id, sizeof(n_seq_id));
-
-                if (n_seq_id != 0) {
-                    LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
-                    return false;
-                }
-
-                batch.pos[i] = pos;
-            }
-            batch.n_seq_id[0] = 1;
-            batch.seq_id[0] = &dest_seq_id;
-            if (!llama_kv_cache_find_slot(kv_self, batch)) {
-                LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
-                return false;
-            }
-
-            // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
-            // Assume that this is one contiguous block of cells
-            GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
-            GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
-            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
-            GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
-            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
-        } else {
-            // whole KV cache restore
-
-            if (cell_count > kv_self.size) {
-                LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
-                return false;
-            }
-
-            llama_kv_cache_clear(kv_self);
-
-            for (uint32_t i = 0; i < cell_count; ++i) {
-                llama_kv_cell & cell = kv_self.cells[i];
-
-                llama_pos pos;
-                uint32_t  n_seq_id;
-
-                read_to(&pos,      sizeof(pos));
-                read_to(&n_seq_id, sizeof(n_seq_id));
-
-                cell.pos = pos;
-
-                for (uint32_t j = 0; j < n_seq_id; ++j) {
-                    llama_seq_id seq_id;
-                    read_to(&seq_id, sizeof(seq_id));
-
-                    if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
-                        LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
-                        return false;
-                    }
-
-                    cell.seq_id.insert(seq_id);
-
-                    if (kv_self.recurrent) {
-                        int32_t & tail = kv_self.cells[seq_id].tail;
-                        if (tail != -1) {
-                            LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
-                            return false;
-                        }
-                        tail = i;
-                    }
-                }
-            }
-
-            kv_self.head = 0;
-            kv_self.used = cell_count;
-        }
-
-        if (kv_self.recurrent) {
-            for (uint32_t i = 0; i < cell_count; ++i) {
-                uint32_t cell_id = kv_self.head + i;
-                // make sure the recurrent states will keep their restored state
-                kv_self.cells[cell_id].src = cell_id;
-            }
-        }
-
-        return true;
-    }
-
-    bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
-        const struct llama_hparams & hparams = ctx->model.hparams;
-        struct llama_kv_cache & kv_self = ctx->kv_self;
-        uint32_t v_trans;
-        uint32_t n_layer;
-        read_to(&v_trans, sizeof(v_trans));
-        read_to(&n_layer, sizeof(n_layer));
-
-        if (n_layer != hparams.n_layer) {
-            LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
-            return false;
-        }
-        if (cell_count > kv_self.size) {
-            LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
-            return false;
-        }
-        if (kv_self.v_trans != (bool) v_trans) {
-            LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
-            return false;
-        }
-
-        // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
-        for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-
-            // Read type of key
-            int32_t k_type_i_ref;
-            read_to(&k_type_i_ref, sizeof(k_type_i_ref));
-            const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
-            if (k_type_i != k_type_i_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
-                return false;
-            }
-
-            // Read row size of key
-            uint64_t k_size_row_ref;
-            read_to(&k_size_row_ref, sizeof(k_size_row_ref));
-            const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
-            if (k_size_row != k_size_row_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
-                return false;
-            }
-
-            if (cell_count) {
-                // Read and set the keys for the whole cell range
-                ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
-            }
-        }
-
-        if (!kv_self.v_trans) {
-            for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
-                // Read type of value
-                int32_t v_type_i_ref;
-                read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-                if (v_type_i != v_type_i_ref) {
-                    LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
-                    return false;
-                }
-
-                // Read row size of value
-                uint64_t v_size_row_ref;
-                read_to(&v_size_row_ref, sizeof(v_size_row_ref));
-                const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
-                if (v_size_row != v_size_row_ref) {
-                    LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
-                    return false;
-                }
-
-                if (cell_count) {
-                    // Read and set the values for the whole cell range
-                    ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
-                }
-            }
-        } else {
-            // For each layer, read the values for each cell (transposed)
-            for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
-
-                // Read type of value
-                int32_t v_type_i_ref;
-                read_to(&v_type_i_ref, sizeof(v_type_i_ref));
-                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-                if (v_type_i != v_type_i_ref) {
-                    LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
-                    return false;
-                }
-
-                // Read element size of value
-                uint32_t v_size_el_ref;
-                read_to(&v_size_el_ref, sizeof(v_size_el_ref));
-                const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
-                if (v_size_el != v_size_el_ref) {
-                    LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
-                    return false;
-                }
-
-                // Read GQA embedding size
-                uint32_t n_embd_v_gqa_ref;
-                read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
-                if (n_embd_v_gqa != n_embd_v_gqa_ref) {
-                    LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
-                    return false;
-                }
-
-                if (cell_count) {
-                    // For each row in the transposed matrix, read the values for the whole cell range
-                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                        const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
-                        ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
-                    }
-                }
-            }
-        }
-        return true;
-    }
-
-    void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
-        uint32_t cell_count;
-        read_to(&cell_count, sizeof(cell_count));
-
-        bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
-
-        if (!res) {
-            if (seq_id == -1) {
-                llama_kv_cache_clear(ctx);
-            } else {
-                llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
-            }
-            throw std::runtime_error("failed to restore kv cache");
-        }
-    }
-};
-
-struct llama_data_write_dummy : llama_data_write {
-    size_t size_written = 0;
-
-    llama_data_write_dummy() {}
-
-    void write(const void * /* src */, size_t size) override {
-        size_written += size;
-    }
-
-    void write_tensor_data(const struct ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
-        size_written += size;
-    }
-
-    size_t get_size_written() override {
-        return size_written;
-    }
-};
-
-struct llama_data_write_buffer : llama_data_write {
-    uint8_t * ptr;
-    size_t buf_size = 0;
-    size_t size_written = 0;
-
-    llama_data_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
-
-    void write(const void * src, size_t size) override {
-        if (size > buf_size) {
-            throw std::runtime_error("unexpectedly reached end of buffer");
-        }
-        memcpy(ptr, src, size);
-        ptr += size;
-        size_written += size;
-        buf_size -= size;
-    }
-
-    void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override {
-        if (size > buf_size) {
-            throw std::runtime_error("unexpectedly reached end of buffer");
-        }
-        ggml_backend_tensor_get(tensor, ptr, offset, size);
-        ptr += size;
-        size_written += size;
-        buf_size -= size;
-    }
-
-    size_t get_size_written() override {
-        return size_written;
-    }
-};
-
-struct llama_data_read_buffer : llama_data_read {
-    const uint8_t * ptr;
-    size_t buf_size = 0;
-    size_t size_read = 0;
-
-    llama_data_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
-
-    const uint8_t * read(size_t size) override {
-        const uint8_t * base_ptr = ptr;
-        if (size > buf_size) {
-            throw std::runtime_error("unexpectedly reached end of buffer");
-        }
-        ptr += size;
-        size_read += size;
-        buf_size -= size;
-        return base_ptr;
-    }
-
-    void read_to(void * dst, size_t size) override {
-        memcpy(dst, read(size), size);
-    }
-
-    size_t get_size_read() override {
-        return size_read;
-    }
-};
-
-struct llama_data_write_file : llama_data_write {
-    llama_file * file;
-    size_t size_written = 0;
-    std::vector temp_buffer;
-
-    llama_data_write_file(llama_file * f) : file(f) {}
-
-    void write(const void * src, size_t size) override {
-        file->write_raw(src, size);
-        size_written += size;
-    }
-
-    void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override {
-        temp_buffer.resize(size);
-        ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
-        write(temp_buffer.data(), temp_buffer.size());
-    }
-
-    size_t get_size_written() override {
-        return size_written;
-    }
-};
-
-struct llama_data_read_file : llama_data_read {
-    llama_file * file;
-    size_t size_read = 0;
-    std::vector temp_buffer;
-
-    llama_data_read_file(llama_file * f) : file(f) {}
-
-    void read_to(void * dst, size_t size) override {
-        file->read_raw(dst, size);
-        size_read += size;
-    }
-
-    const uint8_t * read(size_t size) override {
-        temp_buffer.resize(size);
-        read_to(temp_buffer.data(), size);
-        return temp_buffer.data();
-    }
-
-    size_t get_size_read() override {
-        return size_read;
-    }
-};
-
-/** copy state data into either a buffer or file depending on the passed in context
- *
- * file context:
- * llama_file file("/path", "wb");
- * llama_data_write_file data_ctx(&file);
- * llama_state_get_data_internal(ctx, data_ctx);
- *
- * buffer context:
- * std::vector buf(max_size, 0);
- * llama_data_write_buffer data_ctx(buf.data(), max_size);
- * llama_state_get_data_internal(ctx, data_ctx);
- *
-*/
-static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx) {
-    llama_synchronize(ctx);
-
-    data_ctx.write_model_info(ctx);
-
-    // copy outputs
-    data_ctx.write_output_ids(ctx);
-    data_ctx.write_logits(ctx);
-    data_ctx.write_embeddings(ctx);
-
-    data_ctx.write_kv_cache(ctx);
-
-    return data_ctx.get_size_written();
-}
-
-size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) {
-    llama_data_write_buffer data_ctx(dst, size);
-    try {
-        return llama_state_get_data_internal(ctx, data_ctx);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-// Returns the *actual* size of the state.
-// Intended to be used when saving to state to a buffer.
-size_t llama_state_get_size(struct llama_context * ctx) {
-    llama_data_write_dummy data_ctx;
-    try {
-        return llama_state_get_data_internal(ctx, data_ctx);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx) {
-    llama_synchronize(ctx);
-
-    data_ctx.read_model_info(ctx);
-
-    // set outputs
-    data_ctx.read_output_ids(ctx);
-    data_ctx.read_logits(ctx);
-    data_ctx.read_embeddings(ctx);
-
-    data_ctx.read_kv_cache(ctx);
-
-    return data_ctx.get_size_read();
-}
-
-// Sets the state reading from the specified source address
-size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) {
-    llama_data_read_buffer data_ctx(src, size);
-    try {
-        return llama_state_set_data_internal(ctx, data_ctx);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    llama_file file(path_session, "rb");
-
-    // sanity checks
-    {
-        const uint32_t magic   = file.read_u32();
-        const uint32_t version = file.read_u32();
-
-        if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
-            LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
-            return false;
-        }
-    }
-
-    // load the prompt
-    {
-        const uint32_t n_token_count = file.read_u32();
-
-        if (n_token_count > n_token_capacity) {
-            LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
-            return false;
-        }
-
-        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
-        *n_token_count_out = n_token_count;
-    }
-
-    // restore the context state
-    {
-        const size_t n_state_size_cur = file.size - file.tell();
-
-        llama_data_read_file data_ctx(&file);
-        const size_t n_read = llama_state_set_data_internal(ctx, data_ctx);
-
-        if (n_read != n_state_size_cur) {
-            LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
-            return false;
-        }
-    }
-    return true;
-}
-
-bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    try {
-        return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
-        return false;
-    }
-}
-
-static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
-    llama_file file(path_session, "wb");
-
-    file.write_u32(LLAMA_SESSION_MAGIC);
-    file.write_u32(LLAMA_SESSION_VERSION);
-
-    // save the prompt
-    file.write_u32((uint32_t) n_token_count);
-    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
-
-    // save the context state using stream saving
-    llama_data_write_file data_ctx(&file);
-    llama_state_get_data_internal(ctx, data_ctx);
-
-    return true;
-}
-
-bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
-    try {
-        return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
-        return false;
-    }
-}
-
-static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
-    llama_synchronize(ctx);
-
-    data_ctx.write_kv_cache(ctx, seq_id);
-
-    return data_ctx.get_size_written();
-}
-
-size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) {
-    llama_data_write_dummy data_ctx;
-    return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
-}
-
-size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
-    llama_data_write_buffer data_ctx(dst, size);
-    try {
-        return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
-    llama_synchronize(ctx);
-
-    data_ctx.read_kv_cache(ctx, dest_seq_id);
-
-    return data_ctx.get_size_read();
-}
-
-size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) {
-    llama_data_read_buffer data_ctx(src, size);
-    try {
-        return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
-    llama_file file(filepath, "wb");
-
-    file.write_u32(LLAMA_STATE_SEQ_MAGIC);
-    file.write_u32(LLAMA_STATE_SEQ_VERSION);
-
-    // save the prompt
-    file.write_u32((uint32_t) n_token_count);
-    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
-
-    // save the context state using stream saving
-    llama_data_write_file data_ctx(&file);
-    llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
-
-    const size_t res = file.tell();
-    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written());
-    return res;
-}
-
-static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    llama_file file(filepath, "rb");
-
-    // version checks
-    {
-        const uint32_t magic   = file.read_u32();
-        const uint32_t version = file.read_u32();
-
-        if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
-            LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
-            return 0;
-        }
-    }
-
-    // load the prompt
-    {
-        const uint32_t n_token_count = file.read_u32();
-
-        if (n_token_count > n_token_capacity) {
-            LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
-            return 0;
-        }
-
-        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
-        *n_token_count_out = n_token_count;
-    }
-
-    // restore the context state
-    {
-        const size_t state_size = file.size - file.tell();
-        llama_data_read_file data_ctx(&file);
-        const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
-        if (!nread) {
-            LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
-            return 0;
-        }
-        GGML_ASSERT(nread <= state_size);
-        GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
-    }
-
-    return file.tell();
-}
-
-size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
-    try {
-        return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    try {
-        return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-void llama_set_n_threads(struct llama_context * ctx, int32_t n_threads, int32_t n_threads_batch) {
-    ctx->cparams.n_threads       = n_threads;
-    ctx->cparams.n_threads_batch = n_threads_batch;
-}
-
-int32_t llama_n_threads(struct llama_context * ctx) {
-    return ctx->cparams.n_threads;
-}
-
-int32_t llama_n_threads_batch(struct llama_context * ctx) {
-    return ctx->cparams.n_threads_batch;
-}
-
-void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
-    ctx->abort_callback      = abort_callback;
-    ctx->abort_callback_data = abort_callback_data;
-}
-
-void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
-    ctx->cparams.embeddings = embeddings;
-}
-
-void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
-    ctx->cparams.causal_attn = causal_attn;
-}
-
-struct llama_batch llama_batch_get_one(
-             llama_token * tokens,
-                 int32_t   n_tokens,
-               llama_pos   pos_0,
-            llama_seq_id   seq_id) {
-    return {
-        /*n_tokens       =*/ n_tokens,
-        /*tokens         =*/ tokens,
-        /*embd           =*/ nullptr,
-        /*pos            =*/ nullptr,
-        /*n_seq_id       =*/ nullptr,
-        /*seq_id         =*/ nullptr,
-        /*logits         =*/ nullptr,
-        /*all_pos_0      =*/ pos_0,
-        /*all_pos_1      =*/ 1,
-        /*all_seq_id     =*/ seq_id,
-    };
-}
-
-struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
-    llama_batch batch = {
-        /*n_tokens       =*/ 0,
-        /*tokens         =*/ nullptr,
-        /*embd           =*/ nullptr,
-        /*pos            =*/ nullptr,
-        /*n_seq_id       =*/ nullptr,
-        /*seq_id         =*/ nullptr,
-        /*logits         =*/ nullptr,
-        /*all_pos_0      =*/ 0,
-        /*all_pos_1      =*/ 0,
-        /*all_seq_id     =*/ 0,
-    };
-
-    if (embd) {
-        batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
-    } else {
-        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
-    }
-
-    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc);
-    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc);
-    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
-    for (int i = 0; i < n_tokens_alloc; ++i) {
-        batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
-    }
-    batch.seq_id[n_tokens_alloc] = nullptr;
-
-    batch.logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens_alloc);
-
-    return batch;
-}
-
-void llama_batch_free(struct llama_batch batch) {
-    if (batch.token)    free(batch.token);
-    if (batch.embd)     free(batch.embd);
-    if (batch.pos)      free(batch.pos);
-    if (batch.n_seq_id) free(batch.n_seq_id);
-    if (batch.seq_id) {
-        for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
-            free(batch.seq_id[i]);
-        }
-        free(batch.seq_id);
-    }
-    if (batch.logits)   free(batch.logits);
-}
+///
 
 int32_t llama_encode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
-    const int ret = llama_encode_internal(*ctx, batch);
-    if (ret < 0) {
+    const int ret = llama_encode_impl(*ctx, batch);
+    if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
     }
 
@@ -20050,549 +9973,26 @@ int32_t llama_encode(
 int32_t llama_decode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
-    const int ret = llama_decode_internal(*ctx, batch);
-    if (ret < 0) {
+    const int ret = llama_decode_impl(*ctx, batch);
+    if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
 
     return ret;
 }
 
-void llama_synchronize(struct llama_context * ctx) {
-    ggml_backend_sched_synchronize(ctx->sched);
-
-    // FIXME: if multiple single tokens are evaluated without a synchronization,
-    // the stats will be added to the prompt evaluation stats
-    // this should only happen when using batch size 1 to evaluate a batch
-
-    // add the evaluation to the stats
-    if (ctx->n_queued_tokens == 1) {
-        ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us;
-        ctx->n_eval++;
-    } else if (ctx->n_queued_tokens > 1) {
-        ctx->t_p_eval_us += ggml_time_us() - ctx->t_compute_start_us;
-        ctx->n_p_eval += ctx->n_queued_tokens;
-    }
-
-    // get a more accurate load time, upon first eval
-    if (ctx->n_queued_tokens > 0 && !ctx->has_evaluated_once) {
-        ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
-        ctx->has_evaluated_once = true;
-    }
-
-    ctx->n_queued_tokens = 0;
-    ctx->t_compute_start_us = 0;
-}
-
-float * llama_get_logits(struct llama_context * ctx) {
-    llama_synchronize(ctx);
-
-    // reorder logits for backward compatibility
-    // TODO: maybe deprecate this
-    llama_output_reorder(ctx);
-
-    return ctx->logits;
-}
-
-float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
-    int32_t j = -1;
-    llama_synchronize(ctx);
-
-    try {
-        if (ctx->logits == nullptr) {
-            throw std::runtime_error("no logits");
-        }
-
-        if (i < 0) {
-            j = ctx->n_outputs + i;
-            if (j < 0) {
-                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
-            }
-        } else if ((size_t) i >= ctx->output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
-        } else {
-            j = ctx->output_ids[i];
-        }
-
-        if (j < 0) {
-            throw std::runtime_error(format("batch.logits[%d] != true", i));
-        }
-        if (j >= ctx->n_outputs) {
-            // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
-        }
-
-        return ctx->logits + j*ctx->model.hparams.n_vocab;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
-#ifndef NDEBUG
-        GGML_ABORT("fatal error");
-#else
-        return nullptr;
-#endif
-    }
-}
-
-float * llama_get_embeddings(struct llama_context * ctx) {
-    llama_synchronize(ctx);
-
-    // reorder embeddings for backward compatibility
-    // TODO: maybe deprecate this
-    llama_output_reorder(ctx);
-
-    return ctx->embd;
-}
-
-float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
-    int32_t j = -1;
-
-    llama_synchronize(ctx);
-
-    try {
-        if (ctx->embd == nullptr) {
-            throw std::runtime_error("no embeddings");
-        }
-
-        if (i < 0) {
-            j = ctx->n_outputs + i;
-            if (j < 0) {
-                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
-            }
-        } else if ((size_t) i >= ctx->output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
-        } else {
-            j = ctx->output_ids[i];
-        }
-
-        if (j < 0) {
-            throw std::runtime_error(format("batch.logits[%d] != true", i));
-        }
-        if (j >= ctx->n_outputs) {
-            // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
-        }
-
-        return ctx->embd + j*ctx->model.hparams.n_embd;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
-#ifndef NDEBUG
-        GGML_ABORT("fatal error");
-#else
-        return nullptr;
-#endif
-    }
-}
-
-float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
-    llama_synchronize(ctx);
-
-    auto it = ctx->embd_seq.find(seq_id);
-    if (it == ctx->embd_seq.end()) {
-        return nullptr;
-    }
-
-    return it->second.data();
-}
-
-//
-// vocab
-//
-
-const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
-    return llama_token_get_text_impl(model->vocab, token);
-}
-
-float llama_token_get_score(const struct llama_model * model, llama_token token) {
-    return llama_token_get_score_impl(model->vocab, token);
-}
-
-enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
-    return llama_token_get_attr_impl(model->vocab, token);
-}
-
-bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
-    return llama_token_is_eog_impl(model->vocab, token);
-}
-
-bool llama_token_is_control(const struct llama_model * model, llama_token token) {
-    return llama_token_is_control_impl(model->vocab, token);
-}
-
-llama_token llama_token_bos(const struct llama_model * model) {
-    return llama_token_bos_impl(model->vocab);
-}
-
-llama_token llama_token_eos(const struct llama_model * model) {
-    return llama_token_eos_impl(model->vocab);
-}
-
-llama_token llama_token_cls(const struct llama_model * model) {
-    return llama_token_cls_impl(model->vocab);
-}
-
-llama_token llama_token_sep(const struct llama_model * model) {
-    return llama_token_sep_impl(model->vocab);
-}
-
-llama_token llama_token_nl (const struct llama_model * model) {
-    return llama_token_nl_impl(model->vocab);
-}
-
-llama_token llama_token_pad(const struct llama_model * model) {
-    return llama_token_pad_impl(model->vocab);
-}
-
-bool llama_add_bos_token(const struct llama_model * model) {
-    return llama_add_bos_token_impl(model->vocab);
-}
-
-bool llama_add_eos_token(const struct llama_model * model) {
-    return llama_add_eos_token_impl(model->vocab);
-}
-
-llama_token llama_token_prefix(const struct llama_model * model) {
-    return llama_token_prefix_impl(model->vocab);
-}
-
-llama_token llama_token_middle(const struct llama_model * model) {
-    return llama_token_middle_impl(model->vocab);
-}
-
-llama_token llama_token_suffix(const struct llama_model * model) {
-    return llama_token_suffix_impl(model->vocab);
-}
-
-llama_token llama_token_eot(const struct llama_model * model) {
-    return llama_token_eot_impl(model->vocab);
-}
-
-//
-// tokenization
-//
-
-int32_t llama_tokenize(
-    const struct llama_model * model,
-                  const char * text,
-                     int32_t   text_len,
-                 llama_token * tokens,
-                     int32_t   n_tokens_max,
-                        bool   add_special,
-                        bool   parse_special) {
-    return llama_tokenize_impl(model->vocab, text, text_len, tokens, n_tokens_max, add_special, parse_special);
-}
-
-int32_t llama_token_to_piece(
-    const struct llama_model * model,
-                 llama_token   token,
-                        char * buf,
-                     int32_t   length,
-                     int32_t   lstrip,
-                        bool   special) {
-    return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
-}
-
-int32_t llama_detokenize(
-    const struct llama_model * model,
-           const llama_token * tokens,
-                     int32_t   n_tokens,
-                        char * text,
-                     int32_t   text_len_max,
-                        bool   remove_special,
-                        bool   unparse_special) {
-    return llama_detokenize_impl(model->vocab, tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
-}
-
 //
 // chat templates
 //
 
-// Simple version of "llama_apply_chat_template" that only works with strings
-// This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
-static int32_t llama_chat_apply_template_internal(
-    const std::string & tmpl,
-    const std::vector & chat,
-    std::string & dest, bool add_ass) {
-    // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
-    std::stringstream ss;
-    auto tmpl_contains = [&tmpl](std::string haystack) -> bool {
-        return tmpl.find(haystack) != std::string::npos;
-    };
-    if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) {
-        // chatml template
-        for (auto message : chat) {
-            ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
-        }
-        if (add_ass) {
-            ss << "<|im_start|>assistant\n";
-        }
-    } else if (tmpl == "llama2" || tmpl == "mistral" || tmpl_contains("[INST]")) {
-        // llama2 template and its variants
-        // [variant] support system message
-        bool support_system_message = tmpl_contains("<>") || tmpl == "mistral";
-        // [variant] space before + after response
-        bool space_around_response = tmpl_contains("' ' + eos_token");
-        // [variant] add BOS inside history
-        bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]");
-        // [variant] trim spaces from the input message
-        bool strip_message = tmpl_contains("content.strip()");
-        // construct the prompt
-        bool is_inside_turn = true; // skip BOS at the beginning
-        ss << "[INST] ";
-        for (auto message : chat) {
-            std::string content = strip_message ? trim(message->content) : message->content;
-            std::string role(message->role);
-            if (!is_inside_turn) {
-                is_inside_turn = true;
-                ss << (add_bos_inside_history ? "[INST] " : "[INST] ");
-            }
-            if (role == "system") {
-                if (support_system_message) {
-                    ss << "<>\n" << content << "\n<>\n\n";
-                } else {
-                    // if the model does not support system message, we still include it in the first message, but without <>
-                    ss << content << "\n";
-                }
-            } else if (role == "user") {
-                ss << content << " [/INST]";
-            } else {
-                ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "";
-                is_inside_turn = false;
-            }
-        }
-        // llama2 templates seem to not care about "add_generation_prompt"
-    } else if (tmpl == "phi3" || (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>"))) {
-        // Phi 3
-        for (auto message : chat) {
-            std::string role(message->role);
-            ss << "<|" << role << "|>\n" << message->content << "<|end|>\n";
-        }
-        if (add_ass) {
-            ss << "<|assistant|>\n";
-        }
-    } else if (tmpl == "zephyr" || tmpl_contains("<|user|>")) {
-        // zephyr template
-        for (auto message : chat) {
-            ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
-        }
-        if (add_ass) {
-            ss << "<|assistant|>\n";
-        }
-    } else if (tmpl == "monarch" || tmpl_contains("bos_token + message['role']")) {
-        // mlabonne/AlphaMonarch-7B template (the  is included inside history)
-        for (auto message : chat) {
-            std::string bos = (message == chat.front()) ? "" : ""; // skip BOS for first message
-            ss << bos << message->role << "\n" << message->content << "\n";
-        }
-        if (add_ass) {
-            ss << "assistant\n";
-        }
-    } else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl_contains("")) {
-        // google/gemma-7b-it
-        std::string system_prompt = "";
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
-                system_prompt = trim(message->content);
-                continue;
-            }
-            // in gemma, "assistant" is "model"
-            role = role == "assistant" ? "model" : message->role;
-            ss << "" << role << "\n";
-            if (!system_prompt.empty() && role != "model") {
-                ss << system_prompt << "\n\n";
-                system_prompt = "";
-            }
-            ss << trim(message->content) << "\n";
-        }
-        if (add_ass) {
-            ss << "model\n";
-        }
-    } else if (tmpl == "orion" || tmpl_contains("'\\n\\nAssistant: ' + eos_token")) {
-        // OrionStarAI/Orion-14B-Chat
-        std::string system_prompt = "";
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                // there is no system message support, we will merge it with user prompt
-                system_prompt = message->content;
-                continue;
-            } else if (role == "user") {
-                ss << "Human: ";
-                if (!system_prompt.empty()) {
-                    ss << system_prompt << "\n\n";
-                    system_prompt = "";
-                }
-                ss << message->content << "\n\nAssistant: ";
-            } else {
-                ss << message->content << "";
-            }
-        }
-    } else if (tmpl == "openchat" || tmpl_contains("GPT4 Correct ")) {
-        // openchat/openchat-3.5-0106,
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                ss << message->content << "<|end_of_turn|>";
-            } else {
-                role[0] = toupper(role[0]);
-                ss << "GPT4 Correct " << role << ": " << message->content << "<|end_of_turn|>";
-            }
-        }
-        if (add_ass) {
-            ss << "GPT4 Correct Assistant:";
-        }
-    } else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: "))) {
-        // eachadea/vicuna-13b-1.1 (and Orca variant)
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                // Orca-Vicuna variant uses a system prefix
-                if (tmpl == "vicuna-orca" || tmpl_contains("SYSTEM: ")) {
-                    ss << "SYSTEM: " << message->content << "\n";
-                } else {
-                    ss << message->content << "\n\n";
-                }
-            } else if (role == "user") {
-                ss << "USER: " << message->content << "\n";
-            } else if (role == "assistant") {
-                ss << "ASSISTANT: " << message->content << "\n";
-            }
-        }
-        if (add_ass) {
-            ss << "ASSISTANT:";
-        }
-    } else if (tmpl == "deepseek" || (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>"))) {
-        // deepseek-ai/deepseek-coder-33b-instruct
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                ss << message->content;
-            } else if (role == "user") {
-                ss << "### Instruction:\n" << message->content << "\n";
-            } else if (role == "assistant") {
-                ss << "### Response:\n" << message->content << "\n<|EOT|>\n";
-            }
-        }
-        if (add_ass) {
-            ss << "### Response:\n";
-        }
-    } else if (tmpl == "command-r" || (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>"))) {
-        // CohereForAI/c4ai-command-r-plus
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                ss << "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
-            } else if (role == "user") {
-                ss << "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
-            } else if (role == "assistant") {
-                ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
-            }
-        }
-        if (add_ass) {
-            ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>";
-        }
-    } else if (tmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) {
-        // Llama 3
-        for (auto message : chat) {
-            std::string role(message->role);
-            ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n" << trim(message->content) << "<|eot_id|>";
-        }
-        if (add_ass) {
-            ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
-        }
-    } else if (tmpl == "chatglm3" || tmpl_contains("[gMASK]sop")) {
-        // chatglm3-6b
-        ss << "[gMASK]" << "sop";
-        for (auto message : chat) {
-            std::string role(message->role);
-            ss << "<|" << role << "|>" << "\n " << message->content;
-        }
-        if (add_ass) {
-            ss << "<|assistant|>";
-        }
-    } else if (tmpl == "chatglm4" || tmpl_contains("[gMASK]")) {
-        ss << "[gMASK]" << "";
-        for (auto message : chat) {
-            std::string role(message->role);
-            ss << "<|" << role << "|>" << "\n" << message->content;
-        }
-        if (add_ass) {
-            ss << "<|assistant|>";
-        }
-    } else if (tmpl == "minicpm" || tmpl_contains(LU8("<用户>"))) {
-        // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "user") {
-                ss << LU8("<用户>");
-                ss << trim(message->content);
-                ss << "";
-            } else {
-                ss << trim(message->content);
-            }
-        }
-    } else if (tmpl == "deepseek2" || tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
-        // DeepSeek-V2
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                ss << message->content << "\n\n";
-            } else if (role == "user") {
-                ss << "User: " << message->content << "\n\n";
-            } else if (role == "assistant") {
-                ss << "Assistant: " << message->content << LU8("<|end▁of▁sentence|>");
-            }
-        }
-        if (add_ass) {
-            ss << "Assistant:";
-        }
-    } else if (tmpl == "exaone3" || (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]"))) {
-        // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
-        // EXAONE-3.0-7.8B-Instruct
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n";
-            } else if (role == "user") {
-                ss << "[|user|]" << trim(message->content) << "\n";
-            } else if (role == "assistant") {
-                ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n";
-            }
-        }
-        if (add_ass) {
-            ss << "[|assistant|]";
-        }
-    } else {
-        // template not supported
-        return -1;
-    }
-    dest = ss.str();
-    return dest.size();
-}
-
 int32_t llama_chat_apply_template(
-                const struct llama_model * model,
                               const char * tmpl,
          const struct llama_chat_message * chat,
                                   size_t   n_msg,
                                     bool   add_ass,
                                     char * buf,
                                  int32_t   length) {
-    std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
-    if (tmpl == nullptr) {
-        GGML_ASSERT(model != nullptr);
-        // load template from model
-        std::vector model_template(2048, 0); // longest known template is about 1200 bytes
-        std::string template_key = "tokenizer.chat_template";
-        int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
-        if (res < 0) {
-            // worst case: there is no information about template, we will use chatml by default
-            curr_tmpl = "chatml"; // see llama_chat_apply_template_internal
-        } else {
-            curr_tmpl = std::string(model_template.data(), model_template.size());
-        }
-    }
+    const std::string curr_tmpl(tmpl == nullptr ? "chatml" : tmpl);
 
     // format the chat to string
     std::vector chat_vec;
@@ -20602,7 +10002,11 @@ int32_t llama_chat_apply_template(
     }
 
     std::string formatted_chat;
-    int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
+    llm_chat_template detected_tmpl = llm_chat_detect_template(curr_tmpl);
+    if (detected_tmpl == LLM_CHAT_TEMPLATE_UNKNOWN) {
+        return -1;
+    }
+    int32_t res = llm_chat_apply_template(detected_tmpl, chat_vec, formatted_chat, add_ass);
     if (res < 0) {
         return res;
     }
@@ -20612,15 +10016,6 @@ int32_t llama_chat_apply_template(
     return res;
 }
 
-//
-// sampling
-//
-
-// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
-struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
-    return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
-}
-
 //
 // model split
 //
@@ -20633,16 +10028,16 @@ int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix,
     return 0;
 }
 
-int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) {
+int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count) {
     std::string str_split_path(split_path);
     char postfix[32];
     snprintf(postfix, 32, "-%05d-of-%05d.gguf", split_no + 1, split_count);
     std::string str_postfix(postfix);
 
-    // check if dest ends with postfix
+    // check if split_prefix ends with postfix
     int size_prefix = str_split_path.size() - str_postfix.size();
     if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) {
-        snprintf(dest, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
+        snprintf(split_prefix, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
         return size_prefix;
     }
 
@@ -20651,161 +10046,64 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int
 
 const char * llama_print_system_info(void) {
     static std::string s;
+    s.clear(); // Clear the string, since it's static, otherwise it will accumulate data from previous calls.
 
-    s  = "";
-    s += "AVX = "         + std::to_string(ggml_cpu_has_avx())         + " | ";
-    s += "AVX_VNNI = "    + std::to_string(ggml_cpu_has_avx_vnni())    + " | ";
-    s += "AVX2 = "        + std::to_string(ggml_cpu_has_avx2())        + " | ";
-    s += "AVX512 = "      + std::to_string(ggml_cpu_has_avx512())      + " | ";
-    s += "AVX512_VBMI = " + std::to_string(ggml_cpu_has_avx512_vbmi()) + " | ";
-    s += "AVX512_VNNI = " + std::to_string(ggml_cpu_has_avx512_vnni()) + " | ";
-    s += "AVX512_BF16 = " + std::to_string(ggml_cpu_has_avx512_bf16()) + " | ";
-    s += "FMA = "         + std::to_string(ggml_cpu_has_fma())         + " | ";
-    s += "NEON = "        + std::to_string(ggml_cpu_has_neon())        + " | ";
-    s += "SVE = "         + std::to_string(ggml_cpu_has_sve())         + " | ";
-    s += "ARM_FMA = "     + std::to_string(ggml_cpu_has_arm_fma())     + " | ";
-    s += "F16C = "        + std::to_string(ggml_cpu_has_f16c())        + " | ";
-    s += "FP16_VA = "     + std::to_string(ggml_cpu_has_fp16_va())     + " | ";
-    s += "WASM_SIMD = "   + std::to_string(ggml_cpu_has_wasm_simd())   + " | ";
-    s += "BLAS = "        + std::to_string(ggml_cpu_has_blas())        + " | ";
-    s += "SSE3 = "        + std::to_string(ggml_cpu_has_sse3())        + " | ";
-    s += "SSSE3 = "       + std::to_string(ggml_cpu_has_ssse3())       + " | ";
-    s += "VSX = "         + std::to_string(ggml_cpu_has_vsx())         + " | ";
-    s += "MATMUL_INT8 = " + std::to_string(ggml_cpu_has_matmul_int8()) + " | ";
-    s += "LLAMAFILE = "   + std::to_string(ggml_cpu_has_llamafile())   + " | ";
+
+    for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
+        auto * reg = ggml_backend_reg_get(i);
+        auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features");
+        if (get_features_fn) {
+            ggml_backend_feature * features = get_features_fn(reg);
+            s += ggml_backend_reg_name(reg);
+            s += " : ";
+            for (; features->name; features++) {
+                s += features->name;
+                s += " = ";
+                s += features->value;
+                s += " | ";
+            }
+        }
+    }
 
     return s.c_str();
 }
 
-void llama_perf_print(const void * ctx, enum llama_perf_type type) {
-    switch (type) {
-        case LLAMA_PERF_TYPE_CONTEXT:
-            {
-                const auto * p = (const struct llama_context *) ctx;
+//
+// perf
+//
 
-                const double t_start_ms   = 1e-3 * p->t_start_us;
-                const double t_end_ms     = 1.00 * ggml_time_ms();
-                const double t_load_ms    = 1e-3 * p->t_load_us;
-                const double t_p_eval_ms  = 1e-3 * p->t_p_eval_us;
-                const double t_eval_ms    = 1e-3 * p->t_eval_us;
+struct llama_perf_context_data llama_perf_context(const struct llama_context * ctx) {
+    struct llama_perf_context_data data = {};
 
-                const int32_t n_p_eval  = std::max(0, p->n_p_eval);
-                const int32_t n_eval    = std::max(1, p->n_eval);
-
-                LLAMA_LOG_INFO("%s:        load time = %10.2f ms\n", __func__, t_load_ms);
-                LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
-                        __func__, t_p_eval_ms, n_p_eval, t_p_eval_ms / n_p_eval, 1e3 / t_p_eval_ms * n_p_eval);
-                LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
-                        __func__, t_eval_ms, n_eval, t_eval_ms / n_eval, 1e3 / t_eval_ms * n_eval);
-                LLAMA_LOG_INFO("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - t_start_ms), (n_p_eval + n_eval));
-            } break;
-        case LLAMA_PERF_TYPE_SAMPLER_CHAIN:
-            {
-                const auto * smpl = (const struct llama_sampler *) ctx;
-                const auto * p = (const struct llama_sampler_chain *) smpl->ctx;
-
-                const double t_sampler_ms = 1e-3 * p->t_sample_us;
-
-                const int32_t n_sampler = std::max(0, p->n_sample);
-
-                LLAMA_LOG_INFO("%s:    sampling time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
-                        __func__, t_sampler_ms, n_sampler, t_sampler_ms / n_sampler, 1e3 / t_sampler_ms * n_sampler);
-            } break;
-        default:
-            GGML_ABORT("invalid perf type");
+    if (ctx == nullptr) {
+        return data;
     }
+
+    data.t_start_ms  = 1e-3 * ctx->t_start_us;
+    data.t_load_ms   = 1e-3 * ctx->t_load_us;
+    data.t_p_eval_ms = 1e-3 * ctx->t_p_eval_us;
+    data.t_eval_ms   = 1e-3 * ctx->t_eval_us;
+    data.n_p_eval    = std::max(1, ctx->n_p_eval);
+    data.n_eval      = std::max(1, ctx->n_eval);
+
+    return data;
 }
 
-void llama_perf_reset(void * ctx, enum llama_perf_type type) {
-    switch (type) {
-        case LLAMA_PERF_TYPE_CONTEXT:
-            {
-                auto * p = (struct llama_context *) ctx;
+void llama_perf_context_print(const struct llama_context * ctx) {
+    const auto data = llama_perf_context(ctx);
 
-                p->t_start_us  = ggml_time_us();
-                p->t_eval_us   = p->n_eval = 0;
-                p->t_p_eval_us = p->n_p_eval = 0;
-            } break;
-        case LLAMA_PERF_TYPE_SAMPLER_CHAIN:
-            {
-                auto * smpl = (struct llama_sampler *) ctx;
-                auto * p = (struct llama_sampler_chain *) smpl->ctx;
+    const double t_end_ms = 1e-3 * ggml_time_us();
 
-                p->t_sample_us = p->n_sample = 0;
-            } break;
-        default:
-            GGML_ABORT("invalid perf type");
-    }
+    LLAMA_LOG_INFO("%s:        load time = %10.2f ms\n", __func__, data.t_load_ms);
+    LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
+    LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
+    LLAMA_LOG_INFO("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
 }
 
-void llama_perf_dump_yaml(FILE * stream, const llama_context * ctx) {
-    fprintf(stream, "\n");
-    fprintf(stream, "###########\n");
-    fprintf(stream, "# Timings #\n");
-    fprintf(stream, "###########\n");
-    fprintf(stream, "\n");
-
-    fprintf(stream, "mst_eval: %.2f  # ms / token during generation\n",
-            1.0e-3 * ctx->t_eval_us / ctx->n_eval);
-    fprintf(stream, "mst_p_eval: %.2f  # ms / token during prompt processing\n",
-            1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval);
-    fprintf(stream, "n_eval: %d  # number of tokens generated (excluding the first one)\n", ctx->n_eval);
-    fprintf(stream, "n_p_eval: %d  # number of tokens processed in batches at the beginning\n", ctx->n_p_eval);
-    fprintf(stream, "t_eval_us: %" PRId64 "  # total microseconds spent generating tokens\n", ctx->t_eval_us);
-    fprintf(stream, "t_load_us: %" PRId64 "  # total microseconds spent loading the model\n", ctx->t_load_us);
-    fprintf(stream, "t_p_eval_us: %" PRId64 "  # total microseconds spent prompt processing\n", ctx->t_p_eval_us);
-    fprintf(stream, "ts_eval: %.2f  # tokens / second during generation\n",
-            1.0e6 * ctx->n_eval / ctx->t_eval_us);
-    fprintf(stream, "ts_p_eval: %.2f  # tokens / second during prompt processing\n",
-            1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us);
-}
-
-// For internal test use
-const std::vector> & llama_internal_get_tensor_map(
-    struct llama_context * ctx
-) {
-    return ctx->model.tensors_by_name;
-}
-
-void llama_log_set(ggml_log_callback log_callback, void * user_data) {
-    g_state.log_callback = log_callback ? log_callback : llama_log_callback_default;
-    g_state.log_callback_user_data = user_data;
-#ifdef GGML_USE_METAL
-    ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
-#elif defined(GGML_USE_CUDA)
-    ggml_backend_cuda_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
-#elif defined(GGML_USE_CANN)
-    ggml_backend_cann_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
-#endif
-}
-
-static void llama_log_internal_v(ggml_log_level level, const char * format, va_list args) {
-    va_list args_copy;
-    va_copy(args_copy, args);
-    char buffer[128];
-    int len = vsnprintf(buffer, 128, format, args);
-    if (len < 128) {
-        g_state.log_callback(level, buffer, g_state.log_callback_user_data);
-    } else {
-        char* buffer2 = new char[len+1];
-        vsnprintf(buffer2, len+1, format, args_copy);
-        buffer2[len] = 0;
-        g_state.log_callback(level, buffer2, g_state.log_callback_user_data);
-        delete[] buffer2;
-    }
-    va_end(args_copy);
-}
-
-void llama_log_internal(ggml_log_level level, const char * format, ...) {
-    va_list args;
-    va_start(args, format);
-    llama_log_internal_v(level, format, args);
-    va_end(args);
-}
-
-void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
-    (void) level;
-    (void) user_data;
-    fputs(text, stderr);
-    fflush(stderr);
+void llama_perf_context_reset(struct llama_context * ctx) {
+    ctx->t_start_us  = ggml_time_us();
+    ctx->t_eval_us   = ctx->n_eval = 0;
+    ctx->t_p_eval_us = ctx->n_p_eval = 0;
 }
diff --git a/src/unicode-data.cpp b/src/unicode-data.cpp
index 02bdf7823..04dcd7fcf 100644
--- a/src/unicode-data.cpp
+++ b/src/unicode-data.cpp
@@ -7,7 +7,7 @@
 #include 
 #include 
 
-const std::vector> unicode_ranges_flags = {  // start, flags // last=next_start-1
+const std::initializer_list> unicode_ranges_flags = {  // start, flags // last=next_start-1
 {0x000000, 0x0080},
 {0x000020, 0x0008},
 {0x000021, 0x0020},
@@ -2311,7 +2311,8 @@ const std::unordered_set unicode_set_whitespace = {
 0x003000,
 };
 
-const std::unordered_map unicode_map_lowercase = {
+// list is always in ascending order, to enable binary search
+const std::initializer_list> unicode_map_lowercase = {
 {0x000041, 0x000061},
 {0x000042, 0x000062},
 {0x000043, 0x000063},
@@ -3747,7 +3748,8 @@ const std::unordered_map unicode_map_lowercase = {
 {0x01E921, 0x01E943},
 };
 
-const std::unordered_map unicode_map_uppercase = {
+// list is always in ascending order, to enable binary search
+const std::initializer_list> unicode_map_uppercase = {
 {0x000061, 0x000041},
 {0x000062, 0x000042},
 {0x000063, 0x000043},
@@ -5200,7 +5202,7 @@ const std::unordered_map unicode_map_uppercase = {
 {0x01E943, 0x01E921},
 };
 
-const std::vector unicode_ranges_nfd = {  // start, last, nfd
+const std::initializer_list unicode_ranges_nfd = {  // start, last, nfd
 {0x000000, 0x000000, 0x000000},
 {0x0000C0, 0x0000C5, 0x000041},
 {0x0000C7, 0x0000C7, 0x000043},
diff --git a/src/unicode-data.h b/src/unicode-data.h
index e27fe1770..f6973ebd2 100644
--- a/src/unicode-data.h
+++ b/src/unicode-data.h
@@ -13,8 +13,8 @@ struct range_nfd {
 
 static const uint32_t MAX_CODEPOINTS = 0x110000;
 
-extern const std::vector> unicode_ranges_flags;
+extern const std::initializer_list> unicode_ranges_flags;
 extern const std::unordered_set unicode_set_whitespace;
-extern const std::unordered_map unicode_map_lowercase;
-extern const std::unordered_map unicode_map_uppercase;
-extern const std::vector unicode_ranges_nfd;
+extern const std::initializer_list> unicode_map_lowercase;
+extern const std::initializer_list> unicode_map_uppercase;
+extern const std::initializer_list unicode_ranges_nfd;
diff --git a/src/unicode.cpp b/src/unicode.cpp
index 46650bff0..89180da41 100644
--- a/src/unicode.cpp
+++ b/src/unicode.cpp
@@ -5,19 +5,19 @@
 #include "unicode.h"
 #include "unicode-data.h"
 
+#include 
 #include 
+#include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
 #include 
 #include 
-#include 
 #include 
 #include 
-#include 
-#include 
 
 size_t unicode_len_utf8(char src) {
     const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
@@ -70,15 +70,15 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
     throw std::invalid_argument("failed to convert utf8 to codepoint");
 }
 
-//static std::vector unicode_cpt_to_utf16(uint32_t cp) {
+//static std::vector unicode_cpt_to_utf16(uint32_t cpt) {
 //    std::vector result;
-//    if (/* 0x0000 <= cp && */ cp <= 0xffff) {
-//        result.emplace_back(cp);
+//    if (/* 0x0000 <= cpt && */ cpt <= 0xffff) {
+//        result.emplace_back(cpt);
 //        return result;
 //    }
-//    if (0x10000 <= cp && cp <= 0x10ffff) {
-//        result.emplace_back(0xd800 | ((cp - 0x10000) >> 10));
-//        result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff));
+//    if (0x10000 <= cpt && cpt <= 0x10ffff) {
+//        result.emplace_back(0xd800 | ((cpt - 0x10000) >> 10));
+//        result.emplace_back(0xdc00 | ((cpt - 0x10000) & 0x03ff));
 //        return result;
 //    }
 //    throw std::invalid_argument("failed to convert codepoint to utf16");
@@ -119,14 +119,14 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
 //    return result;
 //}
 
-static std::vector unicode_cpt_flags_array() {
-    std::vector cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
+static std::vector unicode_cpt_flags_array() {
+    std::vector cpt_flags(MAX_CODEPOINTS, unicode_cpt_flags::UNDEFINED);
 
-    assert (unicode_ranges_flags.front().first == 0);
-    assert (unicode_ranges_flags.back().first == MAX_CODEPOINTS);
+    assert (unicode_ranges_flags.begin()[0].first == 0);
+    assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS);
     for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
-        const auto range_ini = unicode_ranges_flags[i-1];  // codepoint_ini, flags
-        const auto range_end = unicode_ranges_flags[i];    // codepoint_end, flags
+        const auto range_ini = unicode_ranges_flags.begin()[i-1];  // codepoint_ini, flags
+        const auto range_end = unicode_ranges_flags.begin()[i];    // codepoint_end, flags
         for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
             cpt_flags[cpt] = range_ini.second;
         }
@@ -200,7 +200,18 @@ static std::unordered_map unicode_utf8_to_byte_map() {
 }
 
 static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
+#if defined(__clang__)
+    // disable C++17 deprecation warning for std::codecvt_utf8
+#    pragma clang diagnostic push
+#    pragma clang diagnostic ignored "-Wdeprecated-declarations"
+#endif
+
     std::wstring_convert> conv;
+
+#if defined(__clang__)
+#    pragma clang diagnostic pop
+#endif
+
     return conv.from_bytes(s);
 }
 
@@ -241,8 +252,8 @@ static std::vector unicode_regex_split_custom_gpt2(const std::string & t
             return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
         };
 
-        auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
-            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
+        auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
         };
 
         size_t _prev_end = offset_ini;
@@ -359,8 +370,8 @@ static std::vector unicode_regex_split_custom_llama3(const std::string &
             return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
         };
 
-        auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
-            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
+        auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
         };
 
         size_t _prev_end = offset_ini;
@@ -560,29 +571,29 @@ static std::vector unicode_regex_split_custom(const std::string & text,
 // interface
 //
 
-std::string unicode_cpt_to_utf8(uint32_t cp) {
+std::string unicode_cpt_to_utf8(uint32_t cpt) {
     std::string result;
 
-    if (/* 0x00 <= cp && */ cp <= 0x7f) {
-        result.push_back(cp);
+    if (/* 0x00 <= cpt && */ cpt <= 0x7f) {
+        result.push_back(cpt);
         return result;
     }
-    if (0x80 <= cp && cp <= 0x7ff) {
-        result.push_back(0xc0 | ((cp >> 6) & 0x1f));
-        result.push_back(0x80 | (cp & 0x3f));
+    if (0x80 <= cpt && cpt <= 0x7ff) {
+        result.push_back(0xc0 | ((cpt >> 6) & 0x1f));
+        result.push_back(0x80 | (cpt & 0x3f));
         return result;
     }
-    if (0x800 <= cp && cp <= 0xffff) {
-        result.push_back(0xe0 | ((cp >> 12) & 0x0f));
-        result.push_back(0x80 | ((cp >> 6) & 0x3f));
-        result.push_back(0x80 | (cp & 0x3f));
+    if (0x800 <= cpt && cpt <= 0xffff) {
+        result.push_back(0xe0 | ((cpt >> 12) & 0x0f));
+        result.push_back(0x80 | ((cpt >> 6) & 0x3f));
+        result.push_back(0x80 | (cpt & 0x3f));
         return result;
     }
-    if (0x10000 <= cp && cp <= 0x10ffff) {
-        result.push_back(0xf0 | ((cp >> 18) & 0x07));
-        result.push_back(0x80 | ((cp >> 12) & 0x3f));
-        result.push_back(0x80 | ((cp >> 6) & 0x3f));
-        result.push_back(0x80 | (cp & 0x3f));
+    if (0x10000 <= cpt && cpt <= 0x10ffff) {
+        result.push_back(0xf0 | ((cpt >> 18) & 0x07));
+        result.push_back(0x80 | ((cpt >> 12) & 0x3f));
+        result.push_back(0x80 | ((cpt >> 6) & 0x3f));
+        result.push_back(0x80 | (cpt & 0x3f));
         return result;
     }
 
@@ -596,7 +607,7 @@ std::vector unicode_cpts_normalize_nfd(const std::vector & c
     std::vector result(cpts.size());
     for (size_t i = 0; i < cpts.size(); ++i) {
         const uint32_t cpt = cpts[i];
-        auto it = std::upper_bound(unicode_ranges_nfd.cbegin(), unicode_ranges_nfd.cend(), cpt, comp) - 1;
+        auto it = std::upper_bound(unicode_ranges_nfd.begin(), unicode_ranges_nfd.end(), cpt, comp) - 1;
         result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
     }
     return result;
@@ -612,19 +623,19 @@ std::vector unicode_cpts_from_utf8(const std::string & utf8) {
     return result;
 }
 
-codepoint_flags unicode_cpt_flags(const uint32_t cp) {
-    static const codepoint_flags undef(codepoint_flags::UNDEFINED);
+unicode_cpt_flags unicode_cpt_flags_from_cpt(const uint32_t cpt) {
+    static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
     static const auto cpt_flags = unicode_cpt_flags_array();
-    return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
+    return cpt < cpt_flags.size() ? cpt_flags[cpt] : undef;
 }
 
-codepoint_flags unicode_cpt_flags(const std::string & utf8) {
-    static const codepoint_flags undef(codepoint_flags::UNDEFINED);
+unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8) {
+    static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
     if (utf8.empty()) {
         return undef;  // undefined
     }
     size_t offset = 0;
-    return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
+    return unicode_cpt_flags_from_cpt(unicode_cpt_from_utf8(utf8, offset));
 }
 
 std::string unicode_byte_to_utf8(uint8_t byte) {
@@ -637,34 +648,47 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8) {
     return map.at(utf8);
 }
 
-uint32_t unicode_tolower(uint32_t cp) {
-    auto it = unicode_map_lowercase.find(cp);
-    return it == unicode_map_lowercase.end() ? cp : it->second;
+uint32_t unicode_tolower(uint32_t cpt) {
+    // binary search
+    auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cpt,
+        [](const std::pair & pair, uint32_t value) {
+            return pair.first < value;
+        });
+    if (it != unicode_map_lowercase.end() && it->first == cpt) {
+        return it->second;
+    }
+    return cpt;  // Return the original code point if no lowercase mapping is found
 }
 
 std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs) {
     // unicode categories
     static const std::map k_ucat_enum = {
-        { "\\p{N}", codepoint_flags::NUMBER },
-        { "\\p{L}", codepoint_flags::LETTER },
-        { "\\p{P}", codepoint_flags::PUNCTUATION },
+        { "\\p{N}", unicode_cpt_flags::NUMBER },
+        { "\\p{L}", unicode_cpt_flags::LETTER },
+        { "\\p{P}", unicode_cpt_flags::PUNCTUATION },
+        { "\\p{M}", unicode_cpt_flags::ACCENT_MARK },
+        { "\\p{S}", unicode_cpt_flags::SYMBOL },
     };
 
     static const std::map k_ucat_cpt = {
-        { codepoint_flags::NUMBER,        0xD1 },
-        { codepoint_flags::LETTER,        0xD2 },
-        { codepoint_flags::PUNCTUATION,   0xD3 },
+        { unicode_cpt_flags::NUMBER,      0xD1 },
+        { unicode_cpt_flags::LETTER,      0xD2 },
+        { unicode_cpt_flags::PUNCTUATION, 0xD3 },
+        { unicode_cpt_flags::ACCENT_MARK, 0xD4 },
+        { unicode_cpt_flags::SYMBOL,      0xD5 },
     };
 
     static const std::map k_ucat_map = {
-        { codepoint_flags::NUMBER,        "\x30-\x39" }, // 0-9
-        { codepoint_flags::LETTER,        "\x41-\x5A\x61-\x7A" }, // A-Za-z
-        { codepoint_flags::PUNCTUATION,   "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
+        { unicode_cpt_flags::NUMBER,      "\x30-\x39" }, // 0-9
+        { unicode_cpt_flags::LETTER,      "\x41-\x5A\x61-\x7A" }, // A-Za-z
+        { unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
+        { unicode_cpt_flags::ACCENT_MARK, "" }, // no sub-128 codepoints
+        { unicode_cpt_flags::SYMBOL,      "\\\x24\\\x2B\x3C-\x3E\x5E\x60\\\x7C" }, // $+<=>^`|
     };
 
     // compute collapsed codepoints only if needed by at least one regex
     bool need_collapse = false;
-    for (auto & regex_expr : regex_exprs) {
+    for (const auto & regex_expr : regex_exprs) {
         // search for unicode categories
         for (const auto & ucat : k_ucat_enum) {
             if (std::string::npos != regex_expr.find(ucat.first)) {
@@ -690,7 +714,7 @@ std::vector unicode_regex_split(const std::string & text, const std
                 continue;
             }
 
-            const auto flags = unicode_cpt_flags(cpts[i]);
+            const auto flags = unicode_cpt_flags_from_cpt(cpts[i]);
 
             if (flags.is_whitespace) {
                 //NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
@@ -706,7 +730,7 @@ std::vector unicode_regex_split(const std::string & text, const std
 
     std::vector bpe_offsets = { cpts.size() };
 
-    for (auto & regex_expr : regex_exprs) {
+    for (const auto & regex_expr : regex_exprs) {
         // first, see if we have an efficient custom regex implementation
         auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
 
@@ -720,7 +744,7 @@ std::vector unicode_regex_split(const std::string & text, const std
             // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
             // with the corresponding collapsed representation
             bool use_collapsed = false;
-            for (auto & ucat : k_ucat_enum) {
+            for (const auto & ucat : k_ucat_enum) {
                 if (std::string::npos != regex_expr.find(ucat.first)) {
                     use_collapsed = true;
                     break;
@@ -786,7 +810,7 @@ std::vector unicode_regex_split(const std::string & text, const std
                 // std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
                 std::wstring wtext(cpts.begin(), cpts.end());
                 for (size_t i = 0; i < wtext.size(); ++i) {
-                    if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
+                    if (wtext[i] > 0x7F && unicode_cpt_flags_from_cpt(wtext[i]).is_whitespace) {
                         wtext[i] = 0x0B;
                     }
                 }
diff --git a/src/unicode.h b/src/unicode.h
index 008532a24..c27098df7 100644
--- a/src/unicode.h
+++ b/src/unicode.h
@@ -4,9 +4,7 @@
 #include 
 #include 
 
-// TODO: prefix all symbols with "llama_"
-
-struct codepoint_flags {
+struct unicode_cpt_flags {
     enum {
         UNDEFINED       = 0x0001,
         NUMBER          = 0x0002,  // regex: \p{N}
@@ -35,7 +33,7 @@ struct codepoint_flags {
     uint16_t is_nfd         : 1;
 
     // decode from uint16
-    inline codepoint_flags(const uint16_t flags=0) {
+    inline unicode_cpt_flags(const uint16_t flags = 0) {
         *reinterpret_cast(this) = flags;
     }
 
@@ -50,18 +48,19 @@ struct codepoint_flags {
 
 size_t unicode_len_utf8(char src);
 
-std::string unicode_cpt_to_utf8(uint32_t cp);
-uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
+std::string unicode_cpt_to_utf8  (uint32_t cpt);
+uint32_t    unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
+
 std::vector unicode_cpts_from_utf8(const std::string & utf8);
 
 std::vector unicode_cpts_normalize_nfd(const std::vector & cpts);
 
-codepoint_flags unicode_cpt_flags(const uint32_t cp);
-codepoint_flags unicode_cpt_flags(const std::string & utf8);
+unicode_cpt_flags unicode_cpt_flags_from_cpt (uint32_t cpt);
+unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8);
 
 std::string unicode_byte_to_utf8(uint8_t byte);
-uint8_t unicode_utf8_to_byte(const std::string & utf8);
+uint8_t     unicode_utf8_to_byte(const std::string & utf8);
 
-uint32_t unicode_tolower(uint32_t cp);
+uint32_t unicode_tolower(uint32_t cpt);
 
 std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs);
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 30e71cfd4..40f83ff0d 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -1,3 +1,5 @@
+llama_add_compile_flags()
+
 function(llama_test target)
     include(CMakeParseArguments)
     set(options)
@@ -84,54 +86,67 @@ llama_test(test-tokenizer-0 NAME test-tokenizer-0-qwen2             ARGS ${CMAKE
 llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact            ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
 llama_test(test-tokenizer-0 NAME test-tokenizer-0-starcoder         ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
 
-# build test-tokenizer-1-bpe target once and add many tests
-add_executable(test-tokenizer-1-bpe test-tokenizer-1-bpe.cpp)
-target_link_libraries(test-tokenizer-1-bpe PRIVATE common)
-install(TARGETS test-tokenizer-1-bpe RUNTIME)
 
-# TODO: disabled due to slowness
-#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-aquila    ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
-#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-falcon    ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
-#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-gpt-2     ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-2.gguf)
-#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-gpt-neox  ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-neox.gguf)
-#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-llama-bpe ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf --ignore-merges)
-#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-mpt       ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
-#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-refact    ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
-#llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
+if (NOT WIN32)
+    # these tests are disabled on Windows because they use internal functions not exported with LLAMA_API
+    llama_target_and_test(test-sampling.cpp)
+    llama_target_and_test(test-grammar-parser.cpp)
+    llama_target_and_test(test-grammar-integration.cpp)
+    llama_target_and_test(test-llama-grammar.cpp)
+    llama_target_and_test(test-chat.cpp)
+    # TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8
+    if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
+        llama_target_and_test(test-json-schema-to-grammar.cpp   WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..)
+        target_include_directories(test-json-schema-to-grammar PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../examples/server)
+    endif()
 
-# build test-tokenizer-1-spm target once and add many tests
-add_executable(test-tokenizer-1-spm test-tokenizer-1-spm.cpp)
-target_link_libraries(test-tokenizer-1-spm PRIVATE common)
-install(TARGETS test-tokenizer-1-spm RUNTIME)
 
-llama_test(test-tokenizer-1-spm  NAME test-tokenizer-1-llama-spm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-spm.gguf)
-#llama_test(test-tokenizer-1-spm  NAME test-tokenizer-1-baichuan  ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf)
+    # build test-tokenizer-1-bpe target once and add many tests
+    add_executable(test-tokenizer-1-bpe test-tokenizer-1-bpe.cpp)
+    target_link_libraries(test-tokenizer-1-bpe PRIVATE common)
+    install(TARGETS test-tokenizer-1-bpe RUNTIME)
 
-# llama_target_and_test(test-double-float.cpp) # SLOW
+    # TODO: disabled due to slowness
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-aquila    ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-falcon    ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-gpt-2     ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-2.gguf)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-gpt-neox  ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-neox.gguf)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-llama-bpe ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf --ignore-merges)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-mpt       ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-refact    ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
+
+    # build test-tokenizer-1-spm target once and add many tests
+    add_executable(test-tokenizer-1-spm test-tokenizer-1-spm.cpp)
+    target_link_libraries(test-tokenizer-1-spm PRIVATE common)
+    install(TARGETS test-tokenizer-1-spm RUNTIME)
+
+    llama_test(test-tokenizer-1-spm  NAME test-tokenizer-1-llama-spm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-spm.gguf)
+    #llama_test(test-tokenizer-1-spm  NAME test-tokenizer-1-baichuan  ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf)
+
+    # llama_target_and_test(test-double-float.cpp) # SLOW
+endif()
+
+llama_target_and_test(test-log.cpp)
 llama_target_and_test(test-arg-parser.cpp)
-llama_target_and_test(test-quantize-fns.cpp)
-llama_target_and_test(test-quantize-perf.cpp)
-llama_target_and_test(test-sampling.cpp)
 llama_target_and_test(test-chat-template.cpp)
 
-llama_target_and_test(test-grammar-parser.cpp)
-llama_target_and_test(test-llama-grammar.cpp)
-llama_target_and_test(test-grammar-integration.cpp)
-llama_target_and_test(test-grad0.cpp)
 # llama_target_and_test(test-opt.cpp) # SLOW
+llama_target_and_test(test-gguf.cpp)
 llama_target_and_test(test-backend-ops.cpp)
 
-llama_target_and_test(test-rope.cpp)
-
 llama_target_and_test(test-model-load-cancel.cpp  LABEL "model")
 llama_target_and_test(test-autorelease.cpp        LABEL "model")
 
-# TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8
-if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
-    llama_target_and_test(test-json-schema-to-grammar.cpp   WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..)
-    target_include_directories(test-json-schema-to-grammar PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../examples/server)
+if (NOT GGML_BACKEND_DL)
+    # these tests use the backends directly and cannot be built with dynamic loading
+    llama_target_and_test(test-barrier.cpp)
+    llama_target_and_test(test-quantize-fns.cpp)
+    llama_target_and_test(test-quantize-perf.cpp)
+    llama_target_and_test(test-rope.cpp)
 endif()
 
+
 # dummy executable - not installed
 get_filename_component(TEST_TARGET test-c.c NAME_WE)
 add_executable(${TEST_TARGET} test-c.c)
diff --git a/tests/run-json-schema-to-grammar.mjs b/tests/run-json-schema-to-grammar.mjs
index 71bf62ed3..b20ac1d6b 100644
--- a/tests/run-json-schema-to-grammar.mjs
+++ b/tests/run-json-schema-to-grammar.mjs
@@ -1,5 +1,5 @@
 import { readFileSync } from "fs"
-import { SchemaConverter } from "../examples/server/public/json-schema-to-grammar.mjs"
+import { SchemaConverter } from "../examples/server/public_legacy/json-schema-to-grammar.mjs"
 
 const [, , file] = process.argv
 const url = `file://${file}`
diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp
index 8852bfc7e..69604b87c 100644
--- a/tests/test-arg-parser.cpp
+++ b/tests/test-arg-parser.cpp
@@ -1,19 +1,43 @@
+#include "arg.h"
+#include "common.h"
+
 #include 
 #include 
 #include 
+#include 
 
 #undef NDEBUG
 #include 
 
-#include "common.h"
-
 int main(void) {
-    gpt_params params;
+    common_params params;
 
     printf("test-arg-parser: make sure there is no duplicated arguments in any examples\n\n");
     for (int ex = 0; ex < LLAMA_EXAMPLE_COUNT; ex++) {
         try {
-            gpt_params_parser_init(params, (enum llama_example)ex);
+            auto ctx_arg = common_params_parser_init(params, (enum llama_example)ex);
+            std::unordered_set seen_args;
+            std::unordered_set seen_env_vars;
+            for (const auto & opt : ctx_arg.options) {
+                // check for args duplications
+                for (const auto & arg : opt.args) {
+                    if (seen_args.find(arg) == seen_args.end()) {
+                        seen_args.insert(arg);
+                    } else {
+                        fprintf(stderr, "test-arg-parser: found different handlers for the same argument: %s", arg);
+                        exit(1);
+                    }
+                }
+                // check for env var duplications
+                if (opt.env) {
+                    if (seen_env_vars.find(opt.env) == seen_env_vars.end()) {
+                        seen_env_vars.insert(opt.env);
+                    } else {
+                        fprintf(stderr, "test-arg-parser: found different handlers for the same env var: %s", opt.env);
+                        exit(1);
+                    }
+                }
+            }
         } catch (std::exception & e) {
             printf("%s\n", e.what());
             assert(false);
@@ -29,40 +53,51 @@ int main(void) {
     };
 
     std::vector argv;
-    auto options = gpt_params_parser_init(params, LLAMA_EXAMPLE_COMMON);
 
     printf("test-arg-parser: test invalid usage\n\n");
 
+    // missing value
     argv = {"binary_name", "-m"};
-    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
 
+    // wrong value (int)
     argv = {"binary_name", "-ngl", "hello"};
-    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
 
+    // wrong value (enum)
     argv = {"binary_name", "-sm", "hello"};
-    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
+
+    // non-existence arg in specific example (--draft cannot be used outside llama-speculative)
+    argv = {"binary_name", "--draft", "123"};
+    assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_EMBEDDING));
 
 
     printf("test-arg-parser: test valid usage\n\n");
 
     argv = {"binary_name", "-m", "model_file.gguf"};
-    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
     assert(params.model == "model_file.gguf");
 
     argv = {"binary_name", "-t", "1234"};
-    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
     assert(params.cpuparams.n_threads == 1234);
 
     argv = {"binary_name", "--verbose"};
-    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
-    assert(params.verbosity == 1);
+    assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
+    assert(params.verbosity > 1);
 
     argv = {"binary_name", "-m", "abc.gguf", "--predict", "6789", "--batch-size", "9090"};
-    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
     assert(params.model == "abc.gguf");
     assert(params.n_predict == 6789);
     assert(params.n_batch == 9090);
 
+    // --draft cannot be used outside llama-speculative
+    argv = {"binary_name", "--draft", "123"};
+    assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_SPECULATIVE));
+    assert(params.speculative.n_max == 123);
+
 // skip this part on windows, because setenv is not supported
 #ifdef _WIN32
     printf("test-arg-parser: skip on windows build\n");
@@ -71,12 +106,12 @@ int main(void) {
 
     setenv("LLAMA_ARG_THREADS", "blah", true);
     argv = {"binary_name"};
-    assert(false == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(false == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
 
     setenv("LLAMA_ARG_MODEL", "blah.gguf", true);
     setenv("LLAMA_ARG_THREADS", "1010", true);
     argv = {"binary_name"};
-    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
     assert(params.model == "blah.gguf");
     assert(params.cpuparams.n_threads == 1010);
 
@@ -86,7 +121,7 @@ int main(void) {
     setenv("LLAMA_ARG_MODEL", "blah.gguf", true);
     setenv("LLAMA_ARG_THREADS", "1010", true);
     argv = {"binary_name", "-m", "overwritten.gguf"};
-    assert(true == gpt_params_parse(argv.size(), list_str_to_char(argv).data(), params, options));
+    assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON));
     assert(params.model == "overwritten.gguf");
     assert(params.cpuparams.n_threads == 1010);
 #endif // _WIN32
diff --git a/tests/test-autorelease.cpp b/tests/test-autorelease.cpp
index 57fa00011..35b09aaea 100644
--- a/tests/test-autorelease.cpp
+++ b/tests/test-autorelease.cpp
@@ -13,10 +13,10 @@ int main(int argc, char ** argv) {
 
     std::thread([&model_path]() {
         llama_backend_init();
-        auto * model = llama_load_model_from_file(model_path, llama_model_default_params());
-        auto * ctx = llama_new_context_with_model(model, llama_context_default_params());
+        auto * model = llama_model_load_from_file(model_path, llama_model_default_params());
+        auto * ctx = llama_init_from_model(model, llama_context_default_params());
         llama_free(ctx);
-        llama_free_model(model);
+        llama_model_free(model);
         llama_backend_free();
     }).join();
 
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 635de01d7..4c5c4dd9c 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1,6 +1,6 @@
 // This file defines tests for various GGML ops and backends.
 // For the forward pass it asserts that the results of multiple backends computing the same GGML ops are consistent.
-// For the backwards pass it asserts that the gradients from backpropagation are consistent
+// For the backward pass it asserts that the gradients from backpropagation are consistent
 // with the gradients obtained via the method of finite differences ("grad" mode, this is optional).
 // It is also possible to check the performance ("perf" mode).
 //
@@ -25,70 +25,58 @@
 #include 
 #include 
 #include 
-#include 
 #include 
 #include 
 #include 
 #include 
 #include 
 #include 
+#include 
 #include 
 
 static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
-    // static RNG initialization (revisit if n_threads stops being constant)
-    static const size_t n_threads = std::thread::hardware_concurrency();
-    static std::vector generators = []() {
-        std::random_device rd;
-        std::vector vec;
-        vec.reserve(n_threads);
-        //for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(1234 + i); } // fixed seed
-        for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); }
-        return vec;
-    }();
+    size_t nels = ggml_nelements(tensor);
+    std::vector data(nels);
+    {
+        // parallel initialization
+        static const size_t n_threads = std::thread::hardware_concurrency();
+        // static RNG initialization (revisit if n_threads stops being constant)
+        static std::vector generators = []() {
+            std::random_device rd;
+            std::vector vec;
+            vec.reserve(n_threads);
+            //for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(1234 + i); } // fixed seed
+            for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); }
+            return vec;
+        }();
 
-    size_t size = ggml_nelements(tensor);
-    std::vector data(size);
+        auto init_thread = [&](size_t ith, size_t start, size_t end) {
+            std::uniform_real_distribution distribution(min, max);
+            auto & gen = generators[ith];
+            for (size_t i = start; i < end; i++) {
+                data[i] = distribution(gen);
+            }
+        };
 
-    auto init_thread = [&](size_t ith, size_t start, size_t end) {
-        std::uniform_real_distribution distribution(min, max);
-        for (size_t i = start; i < end; i++) {
-            data[i] = distribution(generators[ith]);
+        std::vector> tasks;
+        tasks.reserve(n_threads);
+        for (size_t i = 0; i < n_threads; i++) {
+            size_t start =     i*nels/n_threads;
+            size_t end   = (i+1)*nels/n_threads;
+            tasks.push_back(std::async(std::launch::async, init_thread, i, start, end));
         }
-    };
-
-    std::vector threads;
-    threads.reserve(n_threads);
-    for (size_t i = 0; i < n_threads; i++) {
-        size_t start =     i*size/n_threads;
-        size_t end   = (i+1)*size/n_threads;
-        threads.emplace_back(init_thread, i, start, end);
-    }
-    for (auto & t : threads) {
-        t.join();
-    }
-
-#if 0
-    const char * val_str = getenv("GGML_TEST_EPS");
-    float val = 1e-9f;
-    if (val_str != nullptr) {
-        val = std::stof(val_str);
-        printf("GGML_TEST_EPS=%e\n", val);
-    }
-
-    // test quantization with very small values that may result in nan scales due to division by zero
-    if (ggml_is_quantized(tensor->type)) {
-        for (int i = 0; i < 256; i++) {
-            data[i] = val;
+        for (auto & t : tasks) {
+            t.get();
         }
     }
-#endif
 
     if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
-        ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
+        ggml_backend_tensor_set(tensor, data.data(), 0, nels * sizeof(float));
     } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16) {
-        GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
-        std::vector dataq(ggml_row_size(tensor->type, size));
-        std::vector imatrix(tensor->ne[0], 1.0f); // dummy importance matrix
+        GGML_ASSERT(nels % ggml_blck_size(tensor->type) == 0);
+
+         // dummy importance matrix
+        std::vector imatrix(tensor->ne[0], 1.0f);
         const float * im = imatrix.data();
         if (!ggml_quantize_requires_imatrix(tensor->type)) {
             // when the imatrix is optional, we want to test both quantization with and without imatrix
@@ -98,19 +86,40 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
             }
         }
 
-        ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size/tensor->ne[0], tensor->ne[0], im);
-        GGML_ASSERT(ggml_validate_row_data(tensor->type, dataq.data(), dataq.size()));
-        // TODO: other cases
-        //#pragma omp parallel for
-        //for (int i = 0; i < tensor->ne[1]; i++) {
-        //    ggml_quantize_chunk(tensor->type, data.data(), dataq.data(),
-        //        i * tensor->ne[0], 1, tensor->ne[0], im);
-        //}
+        std::vector dataq(ggml_row_size(tensor->type, nels));
+        {
+            // parallel quantization by block
+            size_t blck_size = ggml_blck_size(tensor->type);
+            size_t n_blocks = nels / blck_size;
 
+            auto quantize_thread = [&](size_t start, size_t end) {
+                ggml_quantize_chunk(tensor->type, data.data(), dataq.data(),
+                    start * blck_size, end - start, blck_size, im);
+            };
+
+            const size_t min_blocks_per_thread = 1;
+            const size_t n_threads = std::min(std::thread::hardware_concurrency()/2,
+                                                      std::max(1, n_blocks / min_blocks_per_thread));
+            std::vector> tasks;
+            tasks.reserve(n_threads);
+            for (size_t i = 0; i < n_threads; i++) {
+                size_t start =     i*n_blocks/n_threads;
+                size_t end   = (i+1)*n_blocks/n_threads;
+                tasks.push_back(std::async(std::launch::async, quantize_thread, start, end));
+            }
+            for (auto & t : tasks) {
+                t.get();
+            }
+        }
         ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());
     } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) {
         // This is going to create some weird integers though.
         ggml_backend_tensor_set(tensor, data.data(), 0, ggml_nbytes(tensor));
+    } else if (tensor->type == GGML_TYPE_I64) {
+        // Integers with a size of 8 bytes can be set by mirroring the float data, the specific values are again not really meaningful.
+        const size_t nbytes_half = ggml_nbytes(tensor)/2;
+        ggml_backend_tensor_set(tensor, data.data(), 0*nbytes_half, nbytes_half);
+        ggml_backend_tensor_set(tensor, data.data(), 1*nbytes_half, nbytes_half);
     } else {
         GGML_ABORT("fatal error");
     }
@@ -123,7 +132,7 @@ static std::vector tensor_to_float(const ggml_tensor * t) {
     std::vector buf(ggml_nbytes(t));
     ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t));
 
-    ggml_type_traits_t tt = ggml_internal_get_type_traits(t->type);
+    const auto * tt = ggml_get_type_traits(t->type);
     size_t bs = ggml_blck_size(t->type);
     std::vector vq(ggml_blck_size(t->type));
     bool quantized = ggml_is_quantized(t->type);
@@ -140,6 +149,8 @@ static std::vector tensor_to_float(const ggml_tensor * t) {
                         tv.push_back(ggml_bf16_to_fp32(*(ggml_bf16_t*)&buf[i]));
                     } else if (t->type == GGML_TYPE_F32) {
                         tv.push_back(*(float *) &buf[i]);
+                    } else if (t->type == GGML_TYPE_I64) {
+                        tv.push_back((float)*(int64_t *) &buf[i]);
                     } else if (t->type == GGML_TYPE_I32) {
                         tv.push_back((float)*(int32_t *) &buf[i]);
                     } else if (t->type == GGML_TYPE_I16) {
@@ -147,7 +158,7 @@ static std::vector tensor_to_float(const ggml_tensor * t) {
                     } else if (t->type == GGML_TYPE_I8) {
                         tv.push_back((float)*(int8_t *) &buf[i]);
                     } else if (quantized) {
-                        tt.to_float(&buf[i], vq.data(), bs);
+                        tt->to_float(&buf[i], vq.data(), bs);
                         tv.insert(tv.end(), vq.begin(), vq.end());
                     } else {
                         GGML_ABORT("fatal error");
@@ -160,60 +171,6 @@ static std::vector tensor_to_float(const ggml_tensor * t) {
     return tv;
 }
 
-/*
-static double cosine_similarity(const float * v1, const float * v2, size_t n) {
-    double dot = 0.0;
-    double mag1 = 0.0;
-    double mag2 = 0.0;
-
-    for (size_t i = 0; i < n; i++) {
-        if (std::isnan(v1[i]) || std::isnan(v2[i])) {
-            return -1.0f;
-        }
-        if (std::isinf(v1[i]) && std::isinf(v2[i])) {
-            continue;
-        }
-        dot  += v1[i]*v2[i];
-        mag1 += v1[i]*v1[i];
-        mag2 += v2[i]*v2[i];
-    }
-
-    return dot/sqrt(mag1*mag2);
-}
-
-static float distance(const float * v1, const float * v2, size_t n) {
-    double d = 0.0;
-
-    for (size_t i = 0; i < n; i++) {
-        if (std::isnan(v1[i]) || std::isnan(v2[i])) {
-            return INFINITY;
-        }
-        if (std::isinf(v1[i]) && std::isinf(v2[i])) {
-            continue;
-        }
-        d += (v1[i] - v2[i])*(v1[i] - v2[i]);
-    }
-
-    return sqrt(d);
-}
-
-static float vec_len(const float * v, size_t n) {
-    double d = 0.0;
-
-    for (size_t i = 0; i < n; i++) {
-        if (std::isnan(v[i])) {
-            return INFINITY;
-        }
-        if (std::isinf(v[i])) {
-            continue;
-        }
-        d += v[i]*v[i];
-    }
-
-    return sqrt(d);
-}
-*/
-
 // normalized mean squared error = mse(a, b) / mse(a, 0)
 static double nmse(const float * a, const float * b, size_t n) {
     double mse_a_b = 0.0;
@@ -264,7 +221,6 @@ static double mean_abs_asymm(const float * a, const float * b, const size_t n, c
 }
 
 // utils for printing the variables of the test cases
-#define VAR_TO_STR(x) (#x "=" + var_to_str(x))
 
 template
 static std::string var_to_str(const T & x) {
@@ -297,10 +253,6 @@ static std::string var_to_str(const std::array & x) {
     return s;
 }
 
-//static std::string var_to_str(ggml_unary_op unary_op) {
-//    return ggml_unary_op_name(unary_op);
-//}
-
 static std::string var_to_str(ggml_type type) {
     return ggml_type_name(type);
 }
@@ -313,6 +265,8 @@ static std::string var_to_str(ggml_op_pool pool) {
     }
 }
 
+#define VAR_TO_STR(x) (#x "=" + var_to_str(x))
+
 #define VARS_TO_STR1(a) VAR_TO_STR(a)
 #define VARS_TO_STR2(a, b) VAR_TO_STR(a) + "," + VAR_TO_STR(b)
 #define VARS_TO_STR3(a, b, c) VAR_TO_STR(a) + "," + VARS_TO_STR2(b, c)
@@ -370,13 +324,13 @@ struct test_case {
         return 1e-4;
     }
 
-    virtual float grad_eps(){
+    virtual float grad_eps() {
         return 1e-1f;
     }
 
     // If false, estimate gradient with 2 points, neglects 3rd order derivative and higher.
     // If true,  estimate gradient with 4 points, neglects 5th order derivative and higher.
-    virtual bool grad_precise(){
+    virtual bool grad_precise() {
         return false;
     }
 
@@ -409,6 +363,11 @@ struct test_case {
         return size;
     }
 
+    virtual uint64_t op_flops(ggml_tensor * t) {
+        GGML_UNUSED(t);
+        return 0;
+    }
+
     ggml_cgraph * gf = nullptr;
     ggml_cgraph * gb = nullptr;
 
@@ -519,7 +478,7 @@ struct test_case {
 
         // add sentinels as graph nodes so that they are checked in the callback
         for (ggml_tensor * sentinel : sentinels) {
-            gf->nodes[gf->n_nodes++] = sentinel;
+            ggml_graph_add_node(gf, sentinel);
         }
 
         // randomize tensors
@@ -651,12 +610,11 @@ struct test_case {
         }
 
         // align while also leaving some margin for variations in parameters
-        int align = 20;
+        int align = 8;
         int last = (len + align - 1) / align * align;
         if (last - len < 5) {
             last += align;
         }
-        last = std::max(last, 60);
         printf("%*s", last - len, "");
 
         // allocate
@@ -677,11 +635,28 @@ struct test_case {
         // warmup run
         ggml_backend_graph_compute(backend, gf);
 
+        // determine number of runs
+        int n_runs;
+        bool is_cpu = ggml_backend_dev_type(ggml_backend_get_device(backend)) == GGML_BACKEND_DEVICE_TYPE_CPU;
+        if (op_flops(out) > 0) {
+            // based on flops
+            const uint64_t GFLOP = 1000 * 1000 * 1000;
+            const uint64_t target_flops_cpu =   8ULL * GFLOP;
+            const uint64_t target_flops_gpu = 100ULL * GFLOP;
+            uint64_t target_flops = is_cpu ? target_flops_cpu : target_flops_gpu;
+            n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1;
+        } else {
+            // based on memory size
+            const size_t GB = 1ULL << 30;
+            const size_t target_size_cpu =  8 * GB;
+            const size_t target_size_gpu = 32 * GB;
+            size_t target_size = is_cpu ? target_size_cpu : target_size_gpu;
+            n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;
+        }
+
         // duplicate the op
-        size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
-        int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1;
         for (int i = 1; i < n_runs; i++) {
-            gf->nodes[gf->n_nodes++] = out;
+            ggml_graph_add_node(gf, out);
         }
 
         // calculate memory
@@ -696,27 +671,56 @@ struct test_case {
             }
             return size;
         };
-        for (int i = 0; i < gf->n_nodes; i++) {
-            if (ggml_is_view_op(gf->nodes[i]->op) || gf->nodes[i] == out) {
+        for (int i = 0; i < ggml_graph_n_nodes(gf); ++i) {
+            if (ggml_is_view_op(ggml_graph_node(gf, i)->op) || ggml_graph_node(gf, i) == out) {
                 continue;
             }
-            mem += tensor_op_size(gf->nodes[i]);
+            mem += tensor_op_size(ggml_graph_node(gf, i));
         }
 
         // run
-        ggml_backend_synchronize(backend);
+        int64_t total_time_us = 0;
+        int64_t total_mem = 0;
+        int total_runs = 0;
+        do {
+            int64_t start_time = ggml_time_us();
+            ggml_backend_graph_compute(backend, gf);
+            int64_t end_time = ggml_time_us();
 
-        int64_t start_time = ggml_time_us();
-        ggml_backend_graph_compute(backend, gf);
-        ggml_backend_synchronize(backend);
-        int64_t end_time = ggml_time_us();
-        double time_us = end_time - start_time;
+            total_time_us += end_time - start_time;
+            total_mem += mem;
+            total_runs += n_runs;
+        } while (total_time_us < 1000*1000); // run for at least 1 second
 
-        printf("    %5d runs - %8.2f us/run - %8zu kB/run - \033[1;34m%7.2f GB/s\033[0m\n",
-            n_runs,
-            time_us / n_runs,
-            op_size(out) / 1024,
-            mem / (time_us/1e6) / 1024.0 / 1024.0 / 1024.0);
+        printf("    %8d runs - %8.2f us/run - ",
+            total_runs,
+            (double)total_time_us / total_runs);
+
+        if (op_flops(out) > 0) {
+            double flops_per_sec = (op_flops(out) * total_runs) / (total_time_us / 1e6);
+            auto format_flops = [](double flops) -> std::string {
+                char buf[256];
+                if (flops >= 1e12) {
+                    snprintf(buf, sizeof(buf), "%6.2f TFLOP", flops / 1e12);
+                } else if (flops >= 1e9) {
+                    snprintf(buf, sizeof(buf), "%6.2f GFLOP", flops / 1e9);
+                } else if (flops >= 1e6) {
+                    snprintf(buf, sizeof(buf), "%6.2f MFLOP", flops / 1e6);
+                } else {
+                    snprintf(buf, sizeof(buf), "%6.2f KFLOP", flops / 1e3);
+                }
+                return buf;
+            };
+            printf("%s/run - \033[1;34m%sS\033[0m",
+                format_flops(op_flops(out)).c_str(),
+                format_flops(flops_per_sec).c_str());
+
+        } else {
+            printf("%8zu kB/run - \033[1;34m%7.2f GB/s\033[0m",
+                op_size(out) / 1024,
+                total_mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0);
+        }
+        printf("\n");
 
         ggml_backend_buffer_free(buf);
 
@@ -742,7 +746,7 @@ struct test_case {
 
         ggml_tensor * out = build_graph(ctx);
 
-        if (op_name != nullptr && op_desc(out) != op_name) {
+        if ((op_name != nullptr && op_desc(out) != op_name) || out->op == GGML_OP_OPT_STEP_ADAMW) {
             //printf("  %s: skipping\n", op_desc(out).c_str());
             ggml_free(ctx);
             return true;
@@ -751,11 +755,6 @@ struct test_case {
         printf("  %s(%s): ", op_desc(out).c_str(), vars().c_str());
         fflush(stdout);
 
-        if (out->grad == nullptr) {
-            printf("backwards pass not supported \n");
-            ggml_free(ctx);
-            return true;
-        }
         if (out->type != GGML_TYPE_F32) {
             ggml_free(ctx);
             printf("not supported [%s->type != FP32]\n", out->name);
@@ -764,18 +763,26 @@ struct test_case {
 
         // check if the backend supports the ops
         bool supported = true;
+        bool any_params = false;
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (!ggml_backend_supports_op(backend, t)) {
                 printf("not supported [%s] ", ggml_backend_name(backend));
                 supported = false;
                 break;
             }
-            if ((t->flags & GGML_TENSOR_FLAG_PARAM) && t->type != GGML_TYPE_F32) {
-                printf("not supported [%s->type != FP32] ", t->name);
-                supported = false;
-                break;
+            if ((t->flags & GGML_TENSOR_FLAG_PARAM)) {
+                any_params = true;
+                if (t->type != GGML_TYPE_F32) {
+                    printf("not supported [%s->type != FP32] ", t->name);
+                    supported = false;
+                    break;
+                }
             }
         }
+        if (!any_params) {
+            printf("not supported [%s] \n", op_desc(out).c_str());
+            supported = false;
+        }
         if (!supported) {
             printf("\n");
             ggml_free(ctx);
@@ -799,18 +806,18 @@ struct test_case {
             out = ggml_sum(ctx, out);
             ggml_set_name(out, "sum_of_out");
         }
+        ggml_set_loss(out);
 
         ggml_build_forward_expand(gf, out);
         ggml_graph_cpy(gf, gb);
-        ggml_build_backward_expand(ctx, gf, gb, false);
+        ggml_build_backward_expand(ctx, ctx, gb, false);
         if (expect.size() != 1 || expect[0] != 0.0f) {
-            GGML_ASSERT(gb->n_nodes > gf->n_nodes);
+            GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf));
             for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-                GGML_ASSERT(!(t->flags & GGML_TENSOR_FLAG_PARAM) || t->grad->op != GGML_OP_NONE);
+                GGML_ASSERT(!(t->flags & GGML_TENSOR_FLAG_PARAM) || ggml_graph_get_grad(gb, t)->op != GGML_OP_NONE);
             }
         }
 
-        // TODO: refactor so that this check is only needed once
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (!ggml_backend_supports_op(backend, t)) {
                 printf("not supported [%s] ", ggml_backend_name(backend));
@@ -837,22 +844,11 @@ struct test_case {
             return false;
         }
 
-        // randomize tensors
-        initialize_tensors(ctx);
 
-        for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
-            if (!t->grad) {
-                continue;
-            }
+        initialize_tensors(ctx); // Randomizes all tensors (including gradients).
+        ggml_graph_reset(gb);    // Sets gradients to 1 if loss, 0 otherwise.
 
-            std::vector tmp(ggml_nelements(t->grad));
-            ggml_backend_tensor_set(t->grad, tmp.data(), 0, ggml_nbytes(t->grad));
-        }
-
-        // build graphs
-        const float onef = 1.0f;
         ggml_backend_graph_compute(backend, gf);
-        ggml_backend_tensor_set(out->grad, &onef, 0, ggml_nbytes(out->grad));
         ggml_backend_graph_compute(backend, gb);
 
         bool ok = true;
@@ -864,7 +860,13 @@ struct test_case {
             const char * bn = ggml_backend_name(backend);
             const int64_t ne = ggml_nelements(t);
 
-            std::vector ga = tensor_to_float(t->grad);
+            std::vector ga;
+            struct ggml_tensor * grad = ggml_graph_get_grad(gb, t);
+            if (grad) {
+                ga = tensor_to_float(grad);
+            } else {
+                ga.resize(ne); // default value is 0.0f
+            }
 
             for (int64_t i = 0; i < ne; ++i) { // gradient algebraic
                 // check for nans
@@ -996,7 +998,7 @@ struct test_example : public test_case {
     }
     // In order to also check the gradients for your op, add calls like ggml_set_param(ctx, a)
     // immediately after you create the tensors.
-    // This is optional and only makes sense if a backwards pass has actually been implemented for the new op.
+    // This is optional and only makes sense if a backward pass has actually been implemented for the new op.
 };
 
 
@@ -1128,6 +1130,144 @@ struct test_get_rows : public test_case {
     }
 };
 
+// GGML_OP_GET_ROWS_BACK
+struct test_get_rows_back : public test_case {
+    const ggml_type type;
+    const int n; // cols
+    const int m; // rows
+    const int r; // rows to get
+    const int b; // batch size
+    const bool v; // view (non-contiguous src1)
+
+    std::string vars() override {
+        return VARS_TO_STR6(type, n, m, r, b, v);
+    }
+
+    test_get_rows_back(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, bool v = false)
+        : type(type), n(n), m(m), r(r), b(b), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * in_forward = ggml_new_tensor_3d(ctx, type, n, m, b);
+        ggml_set_name(in_forward, "in_forward");
+
+        ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
+        ggml_set_name(rows, "rows");
+        if (v) {
+            rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0);
+            ggml_set_name(rows, "view_of_rows");
+        }
+
+        ggml_tensor * grad = ggml_new_tensor_3d(ctx, type, n, r, b);
+        ggml_set_name(grad, "grad");
+
+        ggml_tensor * out = ggml_get_rows_back(ctx, grad, rows, in_forward);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                if (ggml_is_view_op(t->op)) { continue; }
+                // rows
+                std::vector data(r*b);
+                for (int i = 0; i < r*b; i++) {
+                    data[i] = rand() % m;
+                }
+                ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int));
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
+// GGML_OP_ARGMAX
+struct test_argmax : public test_case {
+    const ggml_type type;
+    const std::array ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_argmax(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 100, 1, 1})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_argmax(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        std::random_device rd;
+        std::default_random_engine rng(rd());
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_F32) {
+                // initialize with unique values to avoid ties
+                for (int64_t r = 0; r < ggml_nrows(t); r++) {
+                    std::vector data(t->ne[0]);
+                    for (int i = 0; i < t->ne[0]; i++) {
+                        data[i] = i;
+                    }
+                    std::shuffle(data.begin(), data.end(), rng);
+                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
+                }
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+
+    double max_nmse_err() override {
+        return 0.0;
+    }
+};
+
+// GGML_OP_COUNT_EQUAL
+struct test_count_equal : public test_case {
+    const ggml_type type;
+    const std::array ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_count_equal(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {4, 500, 1, 1})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * a_argmax = ggml_argmax(ctx, a);
+        ggml_set_name(a_argmax, "a_argmax");
+
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(b, "b");
+
+        ggml_tensor * b_argmax = ggml_argmax(ctx, a);
+        ggml_set_name(b_argmax, "b_argmax");
+
+        ggml_tensor * out = ggml_count_equal(ctx, a_argmax, b_argmax);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    double max_nmse_err() override {
+        return 0.0;
+    }
+};
+
 // GGML_OP_REPEAT
 struct test_repeat : public test_case {
     const ggml_type type;
@@ -1162,6 +1302,59 @@ struct test_repeat : public test_case {
     }
 };
 
+// GGML_OP_REPEAT_BACK
+struct test_repeat_back : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    const std::array nr;
+    const bool v; // whether src is a noncontiguous view
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne, nr, v);
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) * 2;
+    }
+
+    test_repeat_back(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {8, 6, 4, 2},
+            std::array nr = {2, 2, 2, 2},
+            bool v = false)
+        : type(type), ne(ne), nr(nr), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * src = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
+        ggml_set_name(src, "src");
+
+        if (v) {
+            GGML_ASSERT(ne[0] % 2 == 0);
+            GGML_ASSERT(ne[1] % 2 == 0);
+            GGML_ASSERT(ne[2] % 2 == 0);
+            GGML_ASSERT(ne[3] % 2 == 0);
+            GGML_ASSERT(nr[0] % 2 == 0 || nr[0] == 1);
+            GGML_ASSERT(nr[1] % 2 == 0 || nr[1] == 1);
+            GGML_ASSERT(nr[2] % 2 == 0 || nr[2] == 1);
+            GGML_ASSERT(nr[3] % 2 == 0 || nr[3] == 1);
+
+            const int64_t ne00 = nr[0] == 1 ? src->ne[0] : src->ne[0] / 2;
+            const int64_t ne01 = nr[1] == 1 ? src->ne[1] : src->ne[1] / 2;
+            const int64_t ne02 = nr[2] == 1 ? src->ne[2] : src->ne[2] / 2;
+            const int64_t ne03 = nr[3] == 1 ? src->ne[3] : src->ne[3] / 2;
+
+            src = ggml_view_4d(ctx, src, ne00, ne01, ne02, ne03, src->nb[1], src->nb[2], src->nb[3], 0);
+        }
+
+        ggml_tensor * target = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(target, "target");
+
+        ggml_tensor * out = ggml_repeat_back(ctx, src, target);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_DUP
 struct test_dup : public test_case {
     const ggml_type type;
@@ -1235,7 +1428,7 @@ struct test_set : public test_case {
             offset += ((ne_dst[i] - ne[i])/2)*dst->nb[i];
         }
         ggml_tensor * out = ggml_set(ctx, dst, src,
-            // The backwards pass requires setting a contiguous region:
+            // The backward pass requires setting a contiguous region:
             src->nb[1], src->nb[2], src->nb[3], offset);
         ggml_set_name(out, "out");
 
@@ -1347,7 +1540,7 @@ struct test_bin_bcast : public test_case {
         ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_name(b, "b");
 
-        // The backwards pass supports broadcasting only for GGML_ADD:
+        // The backward pass supports broadcasting only for GGML_ADD:
         const bool grad_supported = op == ggml_add || ggml_are_same_shape(a, b);
         if (grad_supported) {
             ggml_set_param(ctx, a);
@@ -1444,6 +1637,39 @@ struct test_scale : public test_case {
     }
 };
 
+// GGML_OP_SILU_BACK
+struct test_silu_back : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    float eps;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, eps);
+    }
+
+    test_silu_back(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {64, 5, 4, 3},
+            float eps = 1e-6f)
+        : type(type), ne(ne), eps(eps) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * grad = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(grad, "grad");
+
+        ggml_tensor * out = ggml_silu_back(ctx, a, grad);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
 // GGML_OP_NORM
 struct test_norm : public test_case {
     const ggml_type type;
@@ -1496,11 +1722,56 @@ struct test_rms_norm : public test_case {
         return out;
     }
 
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -10.f, 10.f);
+        }
+    }
+
+    float grad_eps() override {
+        return 1.0f;
+    }
+
     bool grad_precise() override {
         return true;
     }
 };
 
+// GGML_OP_RMS_NORM_BACK
+struct test_rms_norm_back : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    float eps;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, eps);
+    }
+
+    test_rms_norm_back(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {64, 5, 4, 3},
+            float eps = 1e-6f)
+        : type(type), ne(ne), eps(eps) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(b, "b");
+
+        ggml_tensor * out = ggml_rms_norm_back(ctx, a, b, eps);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -10.f, 10.f);
+        }
+    }
+};
+
 // GGML_OP_SSM_CONV
 struct test_ssm_conv : public test_case {
     const ggml_type type;
@@ -1553,6 +1824,65 @@ struct test_ssm_scan : public test_case {
     }
 };
 
+// GGML_OP_RWKV_WKV6
+struct test_rwkv_wkv6 : public test_case {
+    const ggml_type type;
+
+    const int64_t head_count;
+    const int64_t head_size;
+    const int64_t n_seq_tokens;
+    const int64_t n_seqs;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
+    }
+
+    test_rwkv_wkv6(ggml_type type = GGML_TYPE_F32,
+            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
+        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int64_t n_tokens = n_seq_tokens * n_seqs;
+        ggml_tensor * r   = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * k   = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * v   = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * tf  = ggml_new_tensor(ctx, type, 2, std::vector{ head_size, head_count }.data());
+        ggml_tensor * td  = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * s   = ggml_new_tensor(ctx, type, 2, std::vector{ head_size * head_size * head_count, n_seqs }.data());
+        ggml_tensor * out = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, s);
+        return out;
+    }
+};
+
+// GGML_OP_GATED_LINEAR_ATTN
+struct test_gla : public test_case {
+    const ggml_type type;
+
+    const int64_t head_count;
+    const int64_t head_size;
+    const int64_t n_seq_tokens;
+    const int64_t n_seqs;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
+    }
+
+    test_gla(ggml_type type = GGML_TYPE_F32,
+            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
+        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int64_t n_tokens = n_seq_tokens * n_seqs;
+        ggml_tensor * q   = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * k   = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * v   = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * g   = ggml_new_tensor(ctx, type, 3, std::vector{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * s   = ggml_new_tensor(ctx, type, 2, std::vector{ head_size * head_size * head_count, n_seqs }.data());
+        ggml_tensor * out = ggml_gated_linear_attn(ctx, k, v, q, g, s, pow(head_size, -0.5));
+        return out;
+    }
+};
+
 // GGML_OP_MUL_MAT
 struct test_mul_mat : public test_case {
     const ggml_type type_a;
@@ -1560,40 +1890,76 @@ struct test_mul_mat : public test_case {
     const int64_t m;
     const int64_t n;
     const int64_t k;
-    const std::array bs; // dims 3 and 4
-    const std::array nr; // repeat in dims 3 and 4
+    const std::array bs;  // dims 3 and 4
+    const std::array nr;  // repeat in dims 3 and 4
+    const std::array per; // permutation of dimensions
 
     std::string vars() override {
-        return VARS_TO_STR7(type_a, type_b, m, n, k, bs, nr);
+        return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, per);
     }
 
     double max_nmse_err() override {
         return 5e-4;
     }
 
-    size_t op_size(ggml_tensor * t) override {
-        size_t a = ggml_nbytes(t->src[0]) * n * nr[0] * nr[1];
-        size_t b = ggml_nbytes(t->src[1]) * m;
-        size_t c  = ggml_nbytes(t);
-        return a + b + c;
+    int64_t grad_nmax() override {
+        return 20000;
+    }
 
+    uint64_t op_flops(ggml_tensor * t) override {
         GGML_UNUSED(t);
+        return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1];
     }
 
     test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
             int64_t m = 32, int64_t n = 32, int64_t k = 32,
             std::array bs = {10, 10},
-            std::array nr = {2, 2})
-        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr) {}
+            std::array nr = {2, 2},
+            std::array per = {0, 1, 2, 3})
+        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         // C^T = A * B^T: (k, m) * (k, n) => (m, n)
-        ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0]      , bs[1]);
-        ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
-        ggml_set_param(ctx, a);
-        ggml_set_param(ctx, b);
-        ggml_set_name(a, "a");
-        ggml_set_name(b, "b");
+        ggml_tensor * a;
+        ggml_tensor * b;
+
+        const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3);
+        if (npermuted > 0) {
+            GGML_ASSERT(npermuted == 2);
+            GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0);
+            GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0);
+
+            // Create tensors with the permuted dimensions, then permute them back to the dimensions given by m,n,k.
+            const int64_t ne_a[4] = {k, m, bs[0],       bs[1]};
+            const int64_t ne_b[4] = {k, n, bs[0]*nr[0], bs[1]*nr[1]};
+
+            a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]);
+            b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
+            if (!ggml_is_quantized(type_a)) {
+                if (bs[1] == 1 && nr[1] == 1) {
+                    ggml_set_param(ctx, a);
+                }
+                ggml_set_param(ctx, b);
+            }
+            ggml_set_name(a, "a");
+            ggml_set_name(b, "b");
+
+            a = ggml_permute(ctx, a, per[0], per[1], per[2], per[3]);
+            b = ggml_permute(ctx, b, per[0], per[1], per[2], per[3]);
+            ggml_set_name(a, "a_permuted");
+            ggml_set_name(b, "b_permuted");
+        } else {
+            a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0],       bs[1]);
+            b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
+            if (!ggml_is_quantized(type_a)) {
+                if (bs[1] == 1 && nr[1] == 1) {
+                    ggml_set_param(ctx, a);
+                }
+                ggml_set_param(ctx, b);
+            }
+            ggml_set_name(a, "a");
+            ggml_set_name(b, "b");
+        }
 
         ggml_tensor * out = ggml_mul_mat(ctx, a, b);
         ggml_set_name(out, "out");
@@ -1621,13 +1987,9 @@ struct test_mul_mat_id : public test_case {
         return 5e-4;
     }
 
-    size_t op_size(ggml_tensor * t) override {
-        size_t a = ggml_nbytes(t->src[2]) * n;
-        size_t b = ggml_nbytes(t->src[1]) * m;
-        size_t c  = ggml_nbytes(t);
-        return a + b + c;
-
+    uint64_t op_flops(ggml_tensor * t) override {
         GGML_UNUSED(t);
+        return 2 * m * k * n * n_used;
     }
 
     test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
@@ -1681,6 +2043,52 @@ struct test_mul_mat_id : public test_case {
     }
 };
 
+// GGML_OP_OUT_PROD
+struct test_out_prod : public test_case {
+    const ggml_type type_a;
+    const ggml_type type_b;
+    const int64_t m;
+    const int64_t n;
+    const int64_t k;
+    const std::array bs; // dims 3 and 4
+    const std::array nr; // repeat in dims 3 and 4
+    const bool trans_b;
+
+    std::string vars() override {
+        return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, trans_b);
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
+    test_out_prod(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
+            int64_t m = 32, int64_t n = 32, int64_t k = 32,
+            std::array bs = {10, 10},
+            std::array nr = {2, 2},
+            bool trans_b = false)
+        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), trans_b(trans_b) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, m, k, bs[0], bs[1]);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * b;
+        if (trans_b) {
+            b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
+            b = ggml_transpose(ctx, b);
+        } else {
+            b = ggml_new_tensor_4d(ctx, type_b, n, k, bs[0]*nr[0], bs[1]*nr[1]);
+        }
+        ggml_set_name(b, "b");
+
+        ggml_tensor * out = ggml_out_prod(ctx, a, b);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_SQR
 struct test_sqr : public test_case {
     const ggml_type type;
@@ -1776,7 +2184,7 @@ struct test_log : public test_case {
 
     void initialize_tensors(ggml_context * ctx) override {
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            // log(1) == 0, cluster values there to keep the sum low for better precision in the backwards pass:
+            // log(1) == 0, cluster values there to keep the sum low for better precision in the backward pass:
             init_tensor_uniform(t, 0.9f, 1.1f);
         }
     }
@@ -1939,11 +2347,12 @@ struct test_soft_max : public test_case {
     const ggml_type type;
     const std::array ne;
     const bool mask;
+    const ggml_type m_prec;
     const float scale;
     const float max_bias;
 
     std::string vars() override {
-        return VARS_TO_STR5(type, ne, mask, scale, max_bias);
+        return VARS_TO_STR6(type, ne, mask, m_prec, scale, max_bias);
     }
 
     // the 1024 test with bias occasionally fails:
@@ -1955,9 +2364,10 @@ struct test_soft_max : public test_case {
     test_soft_max(ggml_type type = GGML_TYPE_F32,
             std::array ne = {10, 5, 4, 3},
             bool mask = false,
+            ggml_type m_prec = GGML_TYPE_F32,
             float scale = 1.0f,
             float max_bias = 0.0f)
-        : type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {}
+        : type(type), ne(ne), mask(mask), m_prec(m_prec), scale(scale), max_bias(max_bias) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
@@ -1966,7 +2376,7 @@ struct test_soft_max : public test_case {
 
         ggml_tensor * mask = nullptr;
         if (this->mask) {
-            mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
+            mask = ggml_new_tensor_2d(ctx, m_prec, ne[0], ne[1]);
             ggml_set_name(mask, "mask");
         }
 
@@ -1981,8 +2391,38 @@ struct test_soft_max : public test_case {
     }
 };
 
+// GGML_OP_SOFT_MAX_BACK
+struct test_soft_max_back : public test_case {
+    const ggml_type type;
+    const std::array ne;
+    const float scale;
+    const float max_bias;
 
-// GGML_OP_ROPE
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne, scale, max_bias);
+    }
+
+    test_soft_max_back(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 5, 4, 3},
+            float scale = 1.0f,
+            float max_bias = 0.0f)
+        : type(type), ne(ne), scale(scale), max_bias(max_bias) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_soft_max_ext_back(ctx, a, b, scale, max_bias);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_ROPE + GGML_OP_ROPE_BACK
 struct test_rope : public test_case {
     const ggml_type type;
     const std::array ne_a;
@@ -1994,33 +2434,48 @@ struct test_rope : public test_case {
     float af; // attn_factor
     bool ff;
     int v; // view (1 : non-contiguous a)
+    bool forward;
 
     std::string vars() override {
+        // forward can be inferred from the op, does not need to be printed
         return VARS_TO_STR10(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v);
     }
 
     test_rope(ggml_type type = GGML_TYPE_F32,
             std::array ne_a = {10, 5, 3, 1},
-            int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f, float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0)
-        : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v) {}
+            int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f,
+            float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true)
+        : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a;
         if (v & 1) {
             auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
             a = ggml_new_tensor(ctx, type, 4, ne.data());
-            ggml_set_param(ctx, a);
+            if (forward) {
+                ggml_set_param(ctx, a);
+            }
             ggml_set_name(a, "a");
 
             a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
             ggml_set_name(a, "view_of_a");
         } else {
             a = ggml_new_tensor(ctx, type, 4, ne_a.data());
-            ggml_set_param(ctx, a);
+            if (forward) {
+                ggml_set_param(ctx, a);
+            }
             ggml_set_name(a, "a");
         }
 
-        ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
+        const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+        const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+        ggml_tensor * pos;
+        if (is_mrope || is_vision) {
+            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2] * 4);
+        } else {
+            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
+        }
         ggml_set_name(pos, "pos");
 
         ggml_tensor * freq = nullptr;
@@ -2029,7 +2484,32 @@ struct test_rope : public test_case {
             ggml_set_name(freq, "freq");
         }
 
-        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+        ggml_tensor * out;
+        if (is_mrope) {
+            if (is_vision) {
+                GGML_ASSERT(n_dims/4 > 0);
+                int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate
+                if (forward) {
+                    out = ggml_rope_multi     (ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                } else {
+                    out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                }
+            } else {
+                GGML_ASSERT(n_dims/3 > 0);
+                int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
+                if (forward) {
+                    out = ggml_rope_multi     (ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                } else {
+                    out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                }
+            }
+        } else {
+            if (forward) {
+                out = ggml_rope_ext     (ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+            } else {
+                out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+            }
+        }
         ggml_set_name(out, "out");
 
         return out;
@@ -2039,11 +2519,12 @@ struct test_rope : public test_case {
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (t->type == GGML_TYPE_I32) {
                 // pos
-                std::vector data(ne_a[2]);
-                for (int i = 0; i < ne_a[2]; i++) {
+                const int num_pos_ids = (mode & GGML_ROPE_TYPE_MROPE) ? ne_a[2] * 4 : ne_a[2];
+                std::vector data(num_pos_ids);
+                for (int i = 0; i < num_pos_ids; i++) {
                     data[i] = rand() % n_ctx;
                 }
-                ggml_backend_tensor_set(t, data.data(), 0, ne_a[2] * sizeof(int));
+                ggml_backend_tensor_set(t, data.data(), 0, num_pos_ids * sizeof(int));
             } else {
                 if (t->ne[0] == n_dims/2) {
                     // frequency factors in the range [0.9f, 1.1f]
@@ -2343,6 +2824,35 @@ struct test_sum_rows : public test_case {
     }
 };
 
+// GGML_OP_MEAN
+struct test_mean : public test_case {
+    const ggml_type type;
+    const std::array ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_mean(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_param(ctx, a);
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_mean(ctx, a);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    float grad_eps() override {
+        return 0.1f * ne[0]*ne[1]*ne[2]*ne[3];
+    }
+};
+
 // GGML_OP_UPSCALE
 struct test_upscale : public test_case {
     const ggml_type type;
@@ -2487,6 +2997,33 @@ struct test_pad : public test_case {
     }
 };
 
+// GGML_OP_PAD_REFLECT_1D
+struct test_pad_reflect_1d : public test_case {
+    const ggml_type type;
+    const std::array ne_a;
+    const int pad_0;
+    const int pad_1;
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne_a, pad_0, pad_1);
+    }
+
+    test_pad_reflect_1d(ggml_type type = GGML_TYPE_F32,
+            std::array ne_a = {512, 34, 2, 1},
+            int pad_0 = 10, int pad_1 = 9)
+        : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1)  {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 2, ne_a.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_pad_reflect_1d(ctx, a, pad_0, pad_1);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_ARANGE
 struct test_arange : public test_case {
     const ggml_type type;
@@ -2576,29 +3113,51 @@ struct test_flash_attn_ext : public test_case {
     const float logit_softcap; // Gemma 2
 
     const ggml_type type_KV;
+    std::array permute;
 
     std::string vars() override {
-        return VARS_TO_STR8(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV);
+        return VARS_TO_STR9(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
     }
 
     double max_nmse_err() override {
         return 5e-4;
     }
 
+    uint64_t op_flops(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        // Just counting matmul costs:
+        // Q*K^T is nb x hs x kv, P*V is nb x kv x hs, per head
+        return 2 * 2 * nh * nb * hs * kv;
+    }
+
     test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8,
-                        bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
-        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV) {}
+                        bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16,
+                        std::array permute = {0, 1, 2, 3})
+        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
 
-        ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs_padded, nb, nh, 1);
+        auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
+            int64_t ne[4] = {ne0, ne1, ne2, ne3};
+            int64_t ne_perm[4];
+            for (int i = 0; i < 4; ++i) {
+                ne_perm[permute[i]] = ne[i];
+            }
+            ggml_tensor * t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]);
+            if (permute != std::array{0, 1, 2, 3}) {
+                t = ggml_permute(ctx, t, permute[0], permute[1], permute[2], permute[3]);
+            }
+            return t;
+        };
+
+        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh, 1);
         ggml_set_name(q, "q");
 
-        ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV,       hs_padded, kv, nh, 1);
+        ggml_tensor * k = create_permuted(type_KV,       hs_padded, kv, nh, 1);
         ggml_set_name(k, "k");
 
-        ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV,       hs_padded, kv, nh, 1);
+        ggml_tensor * v = create_permuted(type_KV,       hs_padded, kv, nh, 1);
         ggml_set_name(v, "v");
 
         ggml_tensor * m = nullptr;
@@ -2666,6 +3225,87 @@ struct test_cross_entropy_loss : public test_case {
     }
 };
 
+// GGML_OP_CROSS_ENTROPY_LOSS_BACK
+struct test_cross_entropy_loss_back : public test_case {
+    const ggml_type type;
+    const std::array ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_cross_entropy_loss_back(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * grad = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
+        ggml_set_name(grad, "grad");
+
+        ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(logits, "logits");
+
+        ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(labels, "labels");
+
+        // Ensure labels add up to 1:
+        labels = ggml_soft_max(ctx, labels);
+        ggml_set_name(labels, "labels_normalized");
+
+        ggml_tensor * out = ggml_cross_entropy_loss_back(ctx, grad, logits, labels);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
+// GGML_OP_OPT_STEP_ADAMW
+struct test_opt_step_adamw : public test_case {
+    const ggml_type type;
+    const std::array ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_opt_step_adamw(ggml_type type = GGML_TYPE_F32,
+            std::array ne = {10, 5, 4, 3})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
+        ggml_set_param(ctx, a); // Despite tensor a having gradients the output tensor will not.
+        ggml_set_name(a, "a");
+
+        ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
+        ggml_set_name(grad, "grad");
+
+        ggml_tensor * grad_m = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
+        ggml_set_name(grad_m, "grad_m");
+
+        ggml_tensor * grad_v = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
+        ggml_set_name(grad_v, "grad_v");
+
+        ggml_tensor * adamw_params = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 7);
+        ggml_set_name(adamw_params, "adamw_params");
+
+        ggml_tensor * out = ggml_opt_step_adamw(ctx, a, grad, grad_m, grad_v, adamw_params);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, 0.0f, 1.0f); // grad_v and adamw_params need non-negative values.
+        }
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
 enum llm_norm_type {
     LLM_NORM,
     LLM_NORM_RMS,
@@ -3054,47 +3694,48 @@ struct test_falcon : public test_llm {
 // ###########################################
 // ## Section 3: GGML Op Test Instantiation ##
 // ###########################################
+static const ggml_type all_types[] = {
+    GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16,
+    GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
+    GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
+    GGML_TYPE_Q8_0,
+    GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
+    GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
+    GGML_TYPE_Q6_K,
+    // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
+    GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
+    GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
+    GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
+};
 
+static const ggml_type base_types[] = {
+    GGML_TYPE_F32, GGML_TYPE_F16,
+    GGML_TYPE_Q8_0, // for I8MM tests
+    GGML_TYPE_Q4_0,
+    GGML_TYPE_Q4_1, // for I8MM tests
+    GGML_TYPE_Q4_K,
+    GGML_TYPE_IQ2_XXS
+};
 
-static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
+static const ggml_type other_types[] = {
+    GGML_TYPE_Q4_1,
+    GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
+    GGML_TYPE_Q8_0,
+    GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
+    GGML_TYPE_Q5_K,
+    GGML_TYPE_Q6_K,
+    // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
+    GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
+    GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
+    GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
+    GGML_TYPE_BF16,
+};
+
+// Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low
+static std::vector> make_test_cases_eval() {
     std::vector> test_cases;
     std::default_random_engine rng(0);
 
-    const ggml_type all_types[] = {
-        GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16,
-        GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
-        GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
-        GGML_TYPE_Q8_0,
-        GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
-        GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
-        GGML_TYPE_Q6_K,
-        // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
-        GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
-        GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
-        GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
-    };
-
-    const ggml_type base_types[] = {
-        GGML_TYPE_F32, GGML_TYPE_F16,
-        GGML_TYPE_Q4_0,
-        GGML_TYPE_Q4_K,
-        GGML_TYPE_IQ2_XXS
-    };
-
-    const ggml_type other_types[] = {
-        GGML_TYPE_Q4_1,
-        GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
-        GGML_TYPE_Q8_0,
-        GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
-        GGML_TYPE_Q5_K,
-        GGML_TYPE_Q6_K,
-        // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
-        GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
-        GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
-        GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
-        GGML_TYPE_BF16,
-    };
-
     // unary ops
     for (int v : {0, 1}) {
         for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
@@ -3117,6 +3758,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         }
     }
 
+    test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_F32, 1, 8, 2, 1, false));
+    for (ggml_type type : all_types) {
+        for (bool v : {false, true}) {
+            test_cases.emplace_back(new test_get_rows_back(type, 256, 5, 4, 1, v));
+        }
+    }
+    for (bool v : {false, true}) {
+        test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_I32, 256, 5, 4, 1, v));
+    }
+
     for (ggml_type type_input : {GGML_TYPE_F32}) {
         for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
             for (int k0 : {1, 3}) {
@@ -3135,13 +3786,49 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         }
     }
 
-    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
-    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
-    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
-    // test cases for 1D im2col
+    // im2col 1D
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
+    for (int s0 : {1, 3}) {
+        for (int p0 : {0, 3}) {
+            for (int d0 : {1, 3}) {
+                test_cases.emplace_back(new test_im2col(
+                    GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 2, 2, 1}, {3, 2, 2, 1},
+                    s0, 0, p0, 0, d0, 0, false));
+            }
+        }
+    }
+
+    // im2col 2D
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
+    for (int s0 : {1, 3}) {
+        for (int s1 : {1, 3}) {
+            for (int p0 : {0, 3}) {
+                for (int p1 : {0, 3}) {
+                    for (int d0 : {1, 3}) {
+                        for (int d1 : {1, 3}) {
+                            test_cases.emplace_back(new test_im2col(
+                                GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 2, 2}, {3, 3, 2, 2},
+                                s0, s1, p0, p1, d0, d1, true));
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    // extra tests for im2col 2D
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 1024}, {3, 3, 2, 1024}, 1, 1, 1, 1, 1, 1, true));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2048}, {3, 3, 1, 2048}, 1, 1, 1, 1, 1, 1, true));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2048}, {3, 3, 2, 2048}, 1, 1, 1, 1, 1, 1, true));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
+    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
 
     // sycl backend will limit task global_range < MAX_INT
     // test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
@@ -3159,14 +3846,34 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
 
+    test_cases.emplace_back(new test_count_equal());
 
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 1, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 2}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_I32, {10, 5, 4, 3}, {2, 1, 1, 1}));
-    test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, 3}, {1, 1, 1, 2}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32,    1, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100,  10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 12, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {5438,  3, 1, 1}));
+
+    for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 2, 1, 1}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 2, 1}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 2}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_I32, {10, 5, 4, ne3}, {2, 1, 1, 1}));
+        test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, ne3}, {1, 1, 1, 2}));
+    }
+
+    for (bool view : {false, true}) {
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I16, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
+    }
 
     test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
     test_cases.emplace_back(new test_dup(GGML_TYPE_F16));
@@ -3176,17 +3883,27 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_dup(GGML_TYPE_F16, {10, 10, 5, 1}, {0, 2, 1, 3})); // dup by rows
     test_cases.emplace_back(new test_dup(GGML_TYPE_F32, {10, 10, 5, 1}, {1, 0, 2, 3}));
     test_cases.emplace_back(new test_dup(GGML_TYPE_F16, {10, 10, 5, 1}, {1, 0, 2, 3})); // dup dst not-contiguous
-    test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {0, 2, 1, 3}));
-    test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {1, 2, 0, 3}));
+    test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10,  8, 3, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10,  8, 3, 1}, {1, 2, 0, 3}));
 
     for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
         test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim));
     }
 
+    for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
+        test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim));
+    }
+
     for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
         for (ggml_type type_dst : all_types) {
-           test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
-           test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
+            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
+            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
+        }
+    }
+    for (ggml_type type_dst : {GGML_TYPE_F32}) {
+        for (ggml_type type_src : all_types) {
+            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
+            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
         }
     }
     for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
@@ -3245,10 +3962,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     test_cases.emplace_back(new test_add1());
     test_cases.emplace_back(new test_scale());
+    test_cases.emplace_back(new test_silu_back());
 
-    for (float eps : {1e-6f, 1e-5f, 1e-3f, 1e-1f}) {
-        test_cases.emplace_back(new test_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
-        test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+    for (float eps : {0.0f, 1e-7f, 1e-4f, 1e-1f}) {
+        test_cases.emplace_back(new test_norm         (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+        test_cases.emplace_back(new test_rms_norm     (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+        test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }
 
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
@@ -3257,24 +3976,66 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
 
+    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
+    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));
+    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
+    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
+
+    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 1, 1));
+    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 1));
+    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
+    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
+
+    for (ggml_type type_a : all_types) {
+        for (int i = 1; i < 10; ++i) {
+            test_cases.emplace_back(new test_mul_mat(type_a,    GGML_TYPE_F32, 16,  i, 256, { 1,  1}, {1, 1}));
+        }
+    }
+
 #if 1
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1,  1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10,  1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10,  1}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));
+            // test cases without permutation
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {1, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {1, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {1, 1}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {2, 2}));
 
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1,  1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10,  1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10,  1}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 2}));
+
+            // test cases with permutation
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+        }
+    }
+    for (ggml_type type_a : other_types) {
+        for (ggml_type type_b : {GGML_TYPE_F32}) {
+            if (ggml_blck_size(type_a) != 256) {
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), {1,  1}, {1, 1}));
+            }
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1,  1}, {1, 1}));
         }
     }
 #else
@@ -3296,15 +4057,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     }
 #endif
 
-    for (ggml_type type_a : other_types) {
-        for (ggml_type type_b : {GGML_TYPE_F32}) {
-            if (ggml_blck_size(type_a) != 256) {
-                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), {1,  1}, {1, 1}));
-            }
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1,  1}, {1, 1}));
-        }
-    }
-
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 2,  128, { 8,  1}, {1, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  83, 2,  128, { 8,  1}, {4, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 2,   64, { 8,  1}, {4, 1}));
@@ -3339,7 +4091,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
             for (int n_mats : {4}) {
                 for (int n_used : {2}) {
                     for (bool b : {false}) {
-                        for (int n : {1}) {
+                        for (int n : {1, 32}) {
                             int m = 512;
                             int k = 256;
                             test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
@@ -3350,6 +4102,24 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         }
     }
 
+    for (ggml_type type_a : base_types) {
+        for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+            for (int n : {1, 16}) {
+                for (int k : {1, 16}) {
+                    for (int bs2 : {1, 3}) {
+                        for (int bs3 : {1, 3}) {
+                            for (int nr2 : {1, 2}) {
+                                for (int nr3 : {1, 2}) {
+                                    test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, n, k, {bs2, bs3}, {nr2, nr3}));
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
     test_cases.emplace_back(new test_sqr());
     test_cases.emplace_back(new test_sqrt());
     test_cases.emplace_back(new test_log());
@@ -3382,19 +4152,41 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
             for (float scale : {1.0f, 0.1f}) {
                 for (int64_t ne0 : {16, 1024}) {
                     for (int64_t ne1 : {16, 1024}) {
-                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, scale, max_bias));
-                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias));
+                        if (mask) {
+                            for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, m_prec, scale, max_bias));
+                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, scale, max_bias));
+                            }
+                        } else {
+                            /* The precision of mask here doesn't matter as boolean mask is false */
+                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
+                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
+                        }
                     }
                 }
             }
         }
     }
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F32,  0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F16,  0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32,  0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16,  0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32,  0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16,  0.1f, 8.0f));
 
-    {
+    for (float max_bias : {0.0f, 8.0f}) {
+        for (float scale : {1.0f, 0.1f}) {
+            for (int64_t ne0 : {16, 1024}) {
+                for (int64_t ne1 : {16, 1024}) {
+                    test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, scale, max_bias));
+                    test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, scale, max_bias));
+                }
+            }
+        }
+    }
+
+    for (bool fw : {true, false}) { // fw == forward
         bool all = true;
 
         for (float v : { 0, 1 }) {
@@ -3403,23 +4195,29 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
                     for (float af : { 1.0f, 1.4245f }) {
                         for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
                             for (bool ff : {false, true}) { // freq_factors
-                                test_cases.emplace_back(new test_rope(type, {128,  32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 7B
+                                test_cases.emplace_back(new test_rope(type, {128,  32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B
 
                                 if (all) {
-                                    test_cases.emplace_back(new test_rope(type, {128,  40, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 13B
-                                    test_cases.emplace_back(new test_rope(type, {128,  52, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 30B
-                                    test_cases.emplace_back(new test_rope(type, {128,  64, 2, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 65B
+                                    test_cases.emplace_back(new test_rope(type, {128,  40, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 13B
+                                    test_cases.emplace_back(new test_rope(type, {128,  52, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 30B
+                                    test_cases.emplace_back(new test_rope(type, {128,  64, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 65B
                                 }
 
                                 if (all) {
-                                    test_cases.emplace_back(new test_rope(type, { 64,   1, 2, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
-                                    test_cases.emplace_back(new test_rope(type, { 64,  71, 2, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
-                                    test_cases.emplace_back(new test_rope(type, { 64,   8, 2, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
-                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  20, 2, 512, fs, ef, af, ff, v)); // neox (stablelm)
-                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  32, 2, 512, fs, ef, af, ff, v)); // neox (phi-2)
+                                    test_cases.emplace_back(new test_rope(type, { 64,   1, 2, 1},  64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
+                                    test_cases.emplace_back(new test_rope(type, { 64,  71, 2, 1},  64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
+                                    test_cases.emplace_back(new test_rope(type, { 64,   8, 2, 1},  64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
+                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
+                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 2, 1},  32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
                                 }
 
-                                test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
+                                if (all) {
+                                    test_cases.emplace_back(new test_rope(type, {128,  12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)
+                                    test_cases.emplace_back(new test_rope(type, {128,  28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
+                                    test_cases.emplace_back(new test_rope(type, { 80,  16, 2, 1},  80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
+                                }
+
+                                test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1},  64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
                             }
                         }
 
@@ -3445,12 +4243,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     test_cases.emplace_back(new test_sum());
     test_cases.emplace_back(new test_sum_rows());
+    test_cases.emplace_back(new test_mean());
     test_cases.emplace_back(new test_upscale());
     test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
     test_cases.emplace_back(new test_upscale_ext());
-    test_cases.emplace_back(new test_group_norm());
+    test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
+    test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
     test_cases.emplace_back(new test_acc());
     test_cases.emplace_back(new test_pad());
+    test_cases.emplace_back(new test_pad_reflect_1d());
     test_cases.emplace_back(new test_arange());
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());
@@ -3463,9 +4264,13 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
                     if (hs != 128 && logit_softcap != 0.0f) continue;
                     for (int nh : { 32, }) {
                         for (int kv : { 512, 1024, }) {
-                            for (int nb : { 1, 2, 4, 8, }) {
-                                for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+                            for (int nb : { 1, 3, 32, 35, }) {
+                                for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
                                     test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
+                                    // run fewer test cases permuted
+                                    if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                        test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
+                                    }
                                 }
                             }
                         }
@@ -3475,7 +4280,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         }
     }
 
-    test_cases.emplace_back(new test_cross_entropy_loss());
+    test_cases.emplace_back(new test_cross_entropy_loss     (GGML_TYPE_F32, {   10, 5, 4, 3}));
+    test_cases.emplace_back(new test_cross_entropy_loss     (GGML_TYPE_F32, {30000, 1, 1, 1}));
+    test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, {   10, 5, 4, 3}));
+    test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, {30000, 1, 1, 1}));
+
+    test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
 
     // these tests are disabled to save execution time, but they can be handy for debugging
 #if 0
@@ -3485,21 +4295,63 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_falcon(2));
 #endif
 
-    // run tests
-    if (mode == MODE_GRAD) {
-        size_t n_ok = 0;
-        for (auto & test : test_cases) {
-            if (test->eval_grad(backend, op_name)) {
-                n_ok++;
+    return test_cases;
+}
+
+// Test cases for performance evaluation: should be representative of real-world use cases
+static std::vector> make_test_cases_perf() {
+    std::vector> test_cases;
+
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1,   1, 1, 1}));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
+
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
+    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
+
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
+
+    for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
+        for (ggml_type type_a : all_types) {
+            for (ggml_type type_b : {GGML_TYPE_F32}) {
+                test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, bs, 14336, {1,  1}, {1, 1}));
             }
         }
-        printf("  %zu/%zu tests passed\n", n_ok, test_cases.size());
-
-        return n_ok == test_cases.size();
     }
 
+    for (int K : {3, 5}) {
+        for (int IC : {256, 2560}) {
+            for (int IW_IH : {32, 64, 256}) {
+                if (IC == 2560 && IW_IH == 256) {
+                    // too big
+                    continue;
+                }
+                test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {IW_IH, IW_IH, IC, 1}, {K, K, IC, 1}, 1, 1, 1, 1, 1, 1, true));
+            }
+        }
+    }
+
+    return test_cases;
+}
+
+static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
     if (mode == MODE_TEST) {
-        ggml_backend_t backend_cpu = ggml_backend_cpu_init();
+        auto test_cases = make_test_cases_eval();
+        ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL);
+        if (backend_cpu == NULL) {
+            printf("  Failed to initialize CPU backend\n");
+            return false;
+        }
 
         size_t n_ok = 0;
         for (auto & test : test_cases) {
@@ -3514,7 +4366,21 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         return n_ok == test_cases.size();
     }
 
+    if (mode == MODE_GRAD) {
+        auto test_cases = make_test_cases_eval();
+        size_t n_ok = 0;
+        for (auto & test : test_cases) {
+            if (test->eval_grad(backend, op_name)) {
+                n_ok++;
+            }
+        }
+        printf("  %zu/%zu tests passed\n", n_ok, test_cases.size());
+
+        return n_ok == test_cases.size();
+    }
+
     if (mode == MODE_PERF) {
+        auto test_cases = make_test_cases_perf();
         for (auto & test : test_cases) {
             test->eval_perf(backend, op_name);
         }
@@ -3528,9 +4394,9 @@ static void usage(char ** argv) {
     printf("Usage: %s [mode] [-o op] [-b backend]\n", argv[0]);
     printf("    valid modes:\n");
     printf("      - test (default, compare with CPU backend for correctness)\n");
-    printf("      - perf (performance evaluation)\n");
     printf("      - grad (compare gradients from backpropagation with method of finite differences)\n");
-    printf("    op names are as given by ggml_op_desc() (e.g. GGML_ADD)\n");
+    printf("      - perf (performance evaluation)\n");
+    printf("    op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc)\n");
 }
 
 int main(int argc, char ** argv) {
@@ -3565,31 +4431,45 @@ int main(int argc, char ** argv) {
         }
     }
 
-    // enumerate backends
-    printf("Testing %zu backends\n\n", ggml_backend_reg_get_count());
+    // load and enumerate backends
+    ggml_backend_load_all();
+
+    printf("Testing %zu devices\n\n", ggml_backend_dev_count());
 
     size_t n_ok = 0;
 
-    for (size_t i = 0; i < ggml_backend_reg_get_count(); i++) {
-        printf("Backend %zu/%zu (%s)\n", i + 1, ggml_backend_reg_get_count(), ggml_backend_reg_get_name(i));
+    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
+        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
 
-        if (backend_filter != NULL && strcmp(backend_filter, ggml_backend_reg_get_name(i)) != 0) {
+        printf("Backend %zu/%zu: %s\n", i + 1, ggml_backend_dev_count(), ggml_backend_dev_name(dev));
+
+        if (backend_filter != NULL && strcmp(backend_filter, ggml_backend_dev_name(dev)) != 0) {
             printf("  Skipping\n");
             n_ok++;
             continue;
         }
 
-        ggml_backend_t backend = ggml_backend_reg_init_backend(i, NULL);
-        GGML_ASSERT(backend != NULL);
-
-        if (backend_filter == NULL && ggml_backend_is_cpu(backend) && mode != MODE_GRAD) {
+        if (backend_filter == NULL && ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && mode != MODE_GRAD) {
             printf("  Skipping CPU backend\n");
-            ggml_backend_free(backend);
             n_ok++;
             continue;
         }
 
-        printf("  Backend name: %s\n", ggml_backend_name(backend));
+        ggml_backend_t backend = ggml_backend_dev_init(dev, NULL);
+        GGML_ASSERT(backend != NULL);
+
+        ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
+        auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+        if (ggml_backend_set_n_threads_fn) {
+            // TODO: better value for n_threads
+            ggml_backend_set_n_threads_fn(backend, std::thread::hardware_concurrency());
+        }
+
+        printf("  Device description: %s\n", ggml_backend_dev_description(dev));
+        size_t free, total; // NOLINT
+        ggml_backend_dev_memory(dev, &free, &total);
+        printf("  Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024);
+        printf("\n");
 
         bool ok = test_backend(backend, mode, op_name_filter);
 
@@ -3606,15 +4486,15 @@ int main(int argc, char ** argv) {
         ggml_backend_free(backend);
     }
 
-    printf("%zu/%zu backends passed\n", n_ok, ggml_backend_reg_get_count());
+    ggml_quantize_free();
 
-    if (n_ok != ggml_backend_reg_get_count()) {
+    printf("%zu/%zu backends passed\n", n_ok, ggml_backend_dev_count());
+
+    if (n_ok != ggml_backend_dev_count()) {
         printf("\033[1;31mFAIL\033[0m\n");
         return 1;
     }
 
-    ggml_quantize_free();
-
     printf("\033[1;32mOK\033[0m\n");
     return 0;
 }
diff --git a/tests/test-barrier.cpp b/tests/test-barrier.cpp
new file mode 100644
index 000000000..d85bf912b
--- /dev/null
+++ b/tests/test-barrier.cpp
@@ -0,0 +1,94 @@
+#include "ggml.h"
+#include "ggml-cpu.h"
+#include "ggml-backend.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#define MAX_NARGS 2
+
+int main(int argc, char *argv[]) {
+
+    int n_threads = 4;
+    int n_rounds  = 100;
+
+    if (argc > 1) {
+        n_threads = std::atoi(argv[1]);
+    }
+
+    if (argc > 2) {
+        n_rounds  = std::atoi(argv[2]);
+    }
+
+    struct ggml_init_params params = {
+        /* .mem_size   = */ 1024*1024*1024,
+        /* .mem_buffer = */ NULL,
+        /* .no_alloc   = */ false,
+    };
+
+    struct ggml_context * ctx = ggml_init(params);
+
+    // Create graph
+    struct ggml_cgraph * gf = ggml_new_graph(ctx);
+
+    // Lots of small, parallel ops where barriers in between will dominate
+    struct ggml_tensor * out = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,  64);
+    for (int i = 0; i < 1000; i++) {
+        struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 64, 128);
+        out = ggml_mul_mat(ctx, a, out);
+
+        struct ggml_tensor * d = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 128, 64);
+        out = ggml_mul_mat(ctx, d, out);
+    }
+
+    ggml_build_forward_expand(gf, out);
+    int n_nodes = ggml_graph_n_nodes(gf);
+
+    // Create threadpool
+    struct ggml_threadpool_params tpp  = ggml_threadpool_params_default(n_threads);
+    struct ggml_threadpool* threadpool = ggml_threadpool_new(&tpp);
+    if (!threadpool) {
+        fprintf(stderr, "threadpool create failed : n_threads %d\n", n_threads);
+        exit(1);
+    }
+
+    // Create compute plan
+    struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads, threadpool);
+
+    std::vector work_data(cplan.work_size);
+    cplan.work_data = work_data.data();
+
+    std::cerr << "graph-compute with"
+              << "\n n_threads: " << n_threads
+              << "\n   n_nodes: " << n_nodes
+              << "\n  n_rounds: " << n_rounds
+              << "\n";
+    // ggml_graph_print(gf);
+
+    // Warmup
+    ggml_graph_compute(gf, &cplan);
+
+    auto t0 = std::chrono::high_resolution_clock::now();
+
+    for (int i=0; i < n_rounds; i++) {
+        ggml_graph_compute(gf, &cplan);
+    }
+
+    auto t1 = std::chrono::high_resolution_clock::now();
+
+    auto usec = std::chrono::duration_cast(t1-t0).count();
+    auto nsec = std::chrono::duration_cast(t1-t0).count();
+    std::cerr << "graph-compute took " << usec << " usec "
+              << "\n " << (float) usec / n_rounds << " usec per-iter"
+              << "\n " << (float) nsec / (n_rounds * n_nodes) << " nsec per-node"
+              << "\n";
+
+    ggml_threadpool_free(threadpool);
+    ggml_free(ctx);
+
+    return 0;
+}
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index a8222caee..4563f9dcb 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -7,9 +7,19 @@
 
 #include "llama.h"
 #include "common.h"
+#include "chat-template.hpp"
+
+static std::string normalize_newlines(const std::string & s) {
+#ifdef _WIN32
+  static const std::regex nl_regex("\r\n");
+  return std::regex_replace(s, nl_regex, "\n");
+#else
+  return s;
+#endif
+}
 
 int main(void) {
-    llama_chat_message conversation[] = {
+    std::vector conversation {
         {"system", "You are a helpful assistant"},
         {"user", "Hello"},
         {"assistant", "Hi there"},
@@ -17,161 +27,353 @@ int main(void) {
         {"assistant", "   I am an assistant   "},
         {"user", "Another question"},
     };
-    size_t message_count = 6;
-    std::vector templates = {
-        // teknium/OpenHermes-2.5-Mistral-7B
-        "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
-        // mistralai/Mistral-7B-Instruct-v0.2
-        "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
-        // TheBloke/FusionNet_34Bx2_MoE-AWQ
-        "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <>\\\\n' + messages[idx]['content'] + '\\\\n<>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' '  + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}",
-        // bofenghuang/vigogne-2-70b-chat
-        "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\\\n' + system_message + '\\\\n<>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\\\n' + content.strip() + '\\\\n<>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
-        // mlabonne/AlphaMonarch-7B
-        "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
-        // google/gemma-7b-it
-        "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}",
-        // OrionStarAI/Orion-14B-Chat
-        "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
-        // openchat/openchat-3.5-0106
-        // The included chat_template differs from the author's suggestions here: https://huggingface.co/openchat/openchat_3.5/discussions/5#65448109b4a3f3a2f486fd9d
-        // So we match against the included template but implement the suggested version.
-        "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}",
-        // deepseek-ai/deepseek-coder-33b-instruct
-        "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n    {%- if message['role'] == 'system' -%}\n        {%- set ns.found = true -%}\n    {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n    {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n    {%- else %}\n        {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n        {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}",
-        // eachadea/vicuna-13b-1.1
-        // No template included in tokenizer_config.json, so this template likely needs to be manually set.
-        "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{- '' + message['content'] + '\n\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}",
-        // Orca-Vicuna
-        // No template included in tokenizer_config.json, so this template likely needs to be manually set.
-        "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{-'SYSTEM: ' + message['content'] + '\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}",
-        // CohereForAI/c4ai-command-r-plus
-        "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>'  + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
-        // Llama-3
-        "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
-        //Phi-3-mini
-        "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
-        //Phi-3-small
-        "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
-        //Phi-3-medium
-        "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
-        //Phi-3-vision
-        "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
-        // ChatGLM3
-        "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
-        // ChatGLM4
-        u8"[gMASK]{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
-        // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
-        u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + ''}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
-        // DeepSeek-V2
-        "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
+    struct TestCase {
+        std::string name;
+        std::string template_str;
+        std::string expected_output;
+        std::string expected_output_jinja;
+        std::string bos_token = "";
+        std::string eos_token = "";
+        bool supported_with_jinja = true;
     };
-    std::vector expected_output = {
-        // teknium/OpenHermes-2.5-Mistral-7B
-        "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n   I am an assistant   <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
-        // mistralai/Mistral-7B-Instruct-v0.2
-        "[INST] You are a helpful assistant\nHello [/INST]Hi there[INST] Who are you [/INST]   I am an assistant   [INST] Another question [/INST]",
-        // TheBloke/FusionNet_34Bx2_MoE-AWQ
-        "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST]    I am an assistant    [INST] Another question [/INST]",
-        // bofenghuang/vigogne-2-70b-chat
-        "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]",
-        // mlabonne/AlphaMonarch-7B
-        "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n   I am an assistant   \nuser\nAnother question\nassistant\n",
-        // google/gemma-7b-it
-        "user\nYou are a helpful assistant\n\nHello\nmodel\nHi there\nuser\nWho are you\nmodel\nI am an assistant\nuser\nAnother question\nmodel\n",
-        // OrionStarAI/Orion-14B-Chat
-        "Human: You are a helpful assistant\n\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant:    I am an assistant   Human: Another question\n\nAssistant: ",
-        // openchat/openchat-3.5-0106
-        "You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant:    I am an assistant   <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:",
-        // deepseek-ai/deepseek-coder-33b-instruct
-        "You are a helpful assistant### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\n   I am an assistant   \n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n",
-        // eachadea/vicuna-13b-1.1
-        "You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT:    I am an assistant   \nUSER: Another question\nASSISTANT:",
-        // Orca-Vicuna
-        "SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT:    I am an assistant   \nUSER: Another question\nASSISTANT:",
-        // CohereForAI/c4ai-command-r-plus
-        "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
-        // Llama 3
-        "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi there<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI am an assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nAnother question<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
-        //Phi-3-mini
-        "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n   I am an assistant   <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
-        //Phi-3-small
-        "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n   I am an assistant   <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
-        //Phi-3-medium
-        "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n   I am an assistant   <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
-        //Phi-3-vision
-        "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n   I am an assistant   <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
-        // ChatGLM3
-        "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n    I am an assistant   <|user|>\n Another question<|assistant|>",
-        // ChatGLM4
-        "[gMASK]<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n   I am an assistant   <|user|>\nAnother question<|assistant|>",
-        // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
-        u8"You are a helpful assistant<用户>HelloHi there<用户>Who are youI am an assistant<用户>Another question",
-        // DeepSeek-V2
-        u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant:    I am an assistant   <|end▁of▁sentence|>User: Another question\n\nAssistant:",
+    std::vector test_cases {
+        {
+            /* .name= */ "teknium/OpenHermes-2.5-Mistral-7B",
+            /* .template_str= */ "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}",
+            /* .expected_output= */ "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n   I am an assistant   <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt)",
+            /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
+            /* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there[INST] Who are you [/INST]   I am an assistant   [INST] Another question [/INST]",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "TheBloke/FusionNet_34Bx2_MoE-AWQ",
+            /* .template_str= */ "{%- for idx in range(0, messages|length) -%}\n{%- if messages[idx]['role'] == 'user' -%}\n{%- if idx > 1 -%}\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\n{%- else -%}\n{{- messages[idx]['content'] + ' [/INST]' -}}\n{%- endif -%}\n{% elif messages[idx]['role'] == 'system' %}\n{{- '[INST] <>\\n' + messages[idx]['content'] + '\\n<>\\n\\n' -}}\n{%- elif messages[idx]['role'] == 'assistant' -%}\n{{- ' '  + messages[idx]['content'] + ' ' + eos_token -}}\n{% endif %}\n{% endfor %}",
+            /* .expected_output= */       "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST]Hi there[INST] Who are you [/INST]   I am an assistant   [INST] Another question [/INST]",
+            /* .expected_output_jinja= */ "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST]    I am an assistant    [INST] Another question [/INST]",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "bofenghuang/vigogne-2-70b-chat",
+            /* .template_str= */ "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
+            /* .expected_output= */       "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST]Hi there[INST] Who are you [/INST]I am an assistant[INST] Another question [/INST]",
+            /* .expected_output_jinja= */ "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "mlabonne/AlphaMonarch-7B",
+            /* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
+            /* .expected_output= */          "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n   I am an assistant   \nuser\nAnother question\nassistant\n",
+            /* .expected_output_jinja= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n   I am an assistant   \nuser\nAnother question\nassistant\n",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "google/gemma-7b-it",
+            /* .template_str= */ "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}",
+            /* .expected_output= */       "user\nYou are a helpful assistant\n\nHello\nmodel\nHi there\nuser\nWho are you\nmodel\nI am an assistant\nuser\nAnother question\nmodel\n",
+            /* .expected_output_jinja= */ "user\nYou are a helpful assistant\nHello\nmodel\nHi there\nuser\nWho are you\nmodel\nI am an assistant\nuser\nAnother question\nmodel\n",
+        },
+        {
+            /* .name= */ "OrionStarAI/Orion-14B-Chat",
+            /* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
+            /* .expected_output= */       "Human: You are a helpful assistant\n\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant:    I am an assistant   Human: Another question\n\nAssistant: ",
+            /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant:    I am an assistant   Human: Another question\n\nAssistant: ",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "openchat/openchat-3.5-0106",
+            // The included chat_template differs from the author's suggestions here: https://huggingface.co/openchat/openchat_3.5/discussions/5#65448109b4a3f3a2f486fd9d
+            // So we match against the included template but implement the suggested version.
+            /* .template_str= */ "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}",
+            /* .expected_output= */                            "You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant:    I am an assistant   <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:",
+            /* .expected_output_jinja= */ "GPT4 Correct System: You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant:    I am an assistant   <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:",
+        },
+        {
+            /* .name= */ "deepseek-ai/deepseek-coder-33b-instruct",
+            /* .template_str= */ "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n    {%- if message['role'] == 'system' -%}\n        {%- set ns.found = true -%}\n    {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n    {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n    {%- else %}\n        {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n        {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}",
+            /* .expected_output= */ "You are a helpful assistant### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\n   I am an assistant   \n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n",
+            /* .expected_output_jinja= */ "",
+        },
+        {
+            /* .name= */ "eachadea/vicuna-13b-1.1",
+            // No template included in tokenizer_config.json, so this template likely needs to be manually set.
+            /* .template_str= */ "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{- '' + message['content'] + '\n\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}",
+            /* .expected_output= */ "You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT:    I am an assistant   \nUSER: Another question\nASSISTANT:",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "Orca-Vicuna",
+            // No template included in tokenizer_config.json, so this template likely needs to be manually set.
+            /* .template_str= */ "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{-'SYSTEM: ' + message['content'] + '\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}",
+            /* .expected_output= */ "SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT:    I am an assistant   \nUSER: Another question\nASSISTANT:",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "CohereForAI/c4ai-command-r-plus",
+            /* .template_str= */ "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>'  + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
+            /* .expected_output= */ "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
+            /* .expected_output_jinja= */ "",
+        },
+        {
+            /* .name= */ "Llama-3",
+            /* .template_str= */ "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}",
+            /* .expected_output= */ "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi there<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI am an assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nAnother question<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+            /* .expected_output_jinja= */ "",
+        },
+        {
+            /* .name= */ "Phi-3-mini",
+            /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
+            /* .expected_output= */     "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n   I am an assistant   <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+            /* .expected_output_jinja= */ "<|user|>\nYou are a helpful assistant\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n   I am an assistant   <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+        },
+        {
+            /* .name= */ "Phi-3-small",
+            /* .template_str= */ "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
+            /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n   I am an assistant   <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+            /* .expected_output_jinja= */ "",
+        },
+        {
+            /* .name= */ "Phi-3-medium",
+            /* .template_str= */ "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
+            /* .expected_output= */     "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n   I am an assistant   <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+            /* .expected_output_jinja= */ "<|user|>\nYou are a helpful assistant\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n   I am an assistant   <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+        },
+        {
+            /* .name= */ "Phi-3-vision",
+            /* .template_str= */ "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
+            /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n   I am an assistant   <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "ChatGLM3",
+            /* .template_str= */ "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
+            /* .expected_output= */       "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n    I am an assistant   <|user|>\n Another question<|assistant|>",
+            /* .expected_output_jinja= */ "[gMASK]sop<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n   I am an assistant   <|user|>\nAnother question<|assistant|>",
+        },
+        {
+            /* .name= */ "ChatGLM4",
+            /* .template_str= */ u8"[gMASK]{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
+            /* .expected_output= */ "[gMASK]<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n   I am an assistant   <|user|>\nAnother question<|assistant|>",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "MiniCPM-3B-OpenHermes-2.5-v2-GGUF",
+            /* .template_str= */ u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + ''}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
+            /* .expected_output= */ u8"You are a helpful assistant<用户>HelloHi there<用户>Who are youI am an assistant<用户>Another question",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "DeepSeek-V2",
+            /* .template_str= */ "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
+            /* .expected_output= */ u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant:    I am an assistant   <|end▁of▁sentence|>User: Another question\n\nAssistant:",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "<|end▁of▁sentence|>",
+        },
+        {
+            /* .name= */ "ibm-granite/granite-3.0-8b-instruct",
+            /* .template_str= */ "{%- if tools %}\n    {{- '<|start_of_role|>available_tools<|end_of_role|>\n' }}\n    {%- for tool in tools %}\n    {{- tool | tojson(indent=4) }}\n    {%- if not loop.last %}\n        {{- '\n\n' }}\n    {%- endif %}\n    {%- endfor %}\n    {{- '<|end_of_text|>\n' }}\n{%- endif %}\n{%- for message in messages %}\n    {%- if message['role'] == 'system' %}\n    {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n    {%- elif message['role'] == 'user' %}\n    {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n    {%- elif message['role'] == 'assistant' %}\n    {{- '<|start_of_role|>assistant<|end_of_role|>'  + message['content'] + '<|end_of_text|>\n' }}\n    {%- elif message['role'] == 'assistant_tool_call' %}\n    {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\n' }}\n    {%- elif message['role'] == 'tool_response' %}\n    {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n    {%- endif %}\n    {%- if loop.last and add_generation_prompt %}\n    {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n    {%- endif %}\n{%- endfor %}",
+            /* .expected_output= */       "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>   I am an assistant   <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>\n",
+            /* .expected_output_jinja= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>   I am an assistant   <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>",
+        },
+        {
+            /* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (mistralai 'v1' template with a system prompt)",
+            /* .template_str= */ "{%- if messages[0]['role'] == 'system' %}\n    {%- set system_message = messages[0]['content'] %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n        {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n    {%- endif %}\n    {%- if message['role'] == 'user' %}\n        {%- if loop.first and system_message is defined %}\n            {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n        {%- else %}\n            {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n        {%- endif %}\n    {%- elif message['role'] == 'assistant' %}\n        {{- ' ' + message['content'] + eos_token}}\n    {%- else %}\n        {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n    {%- endif %}\n{%- endfor %}\n",
+            /* .expected_output= */ " [INST] You are a helpful assistant\n\nHello [/INST] Hi there [INST] Who are you [/INST]    I am an assistant    [INST] Another question [/INST]",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "Mistral-Large-Instruct-2407 (mistralai 'v3' template; modified to have system prompt at start)",
+            /* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n    {%- set system_message = messages[0][\"content\"] %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n    {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n        {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n            {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n        {%- endif %}\n        {%- set ns.index = ns.index + 1 %}\n    {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n    {%- if message[\"role\"] == \"user\" %}\n        {%- if tools is not none and (message == user_messages[-1]) %}\n            {{- \"[AVAILABLE_TOOLS] [\" }}\n            {%- for tool in tools %}\n                {%- set tool = tool.function %}\n                {{- '{\"type\": \"function\", \"function\": {' }}\n                {%- for key, val in tool.items() if key != \"return\" %}\n                    {%- if val is string %}\n                        {{- '\"' + key + '\": \"' + val + '\"' }}\n                    {%- else %}\n                        {{- '\"' + key + '\": ' + val|tojson }}\n                    {%- endif %}\n                    {%- if not loop.last %}\n                        {{- \", \" }}\n                    {%- endif %}\n                {%- endfor %}\n                {{- \"}}\" }}\n                {%- if not loop.last %}\n                    {{- \", \" }}\n                {%- else %}\n                    {{- \"]\" }}\n                {%- endif %}\n            {%- endfor %}\n            {{- \"[/AVAILABLE_TOOLS]\" }}\n            {%- endif %}\n        {%- if loop.last and system_message is defined %}\n            {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n        {%- else %}\n            {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n        {%- endif %}\n    {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n        {{- \"[TOOL_CALLS] [\" }}\n        {%- for tool_call in message.tool_calls %}\n            {%- set out = tool_call.function|tojson %}\n            {{- out[:-1] }}\n            {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n                {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n            {%- endif %}\n            {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n            {%- if not loop.last %}\n                {{- \", \" }}\n            {%- else %}\n                {{- \"]\" + eos_token }}\n            {%- endif %}\n        {%- endfor %}\n    {%- elif message[\"role\"] == \"assistant\" %}\n        {{- \" \" + message[\"content\"]|trim + eos_token}}\n    {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n        {%- if message.content is defined and message.content.content is defined %}\n            {%- set content = message.content.content %}\n        {%- else %}\n            {%- set content = message.content %}\n        {%- endif %}\n        {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n        {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n            {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n        {%- endif %}\n        {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n    {%- else %}\n        {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n    {%- endif %}\n{%- endfor %}\n",
+            /* .expected_output= */       "[INST] You are a helpful assistant\n\nHello[/INST] Hi there[INST] Who are you[/INST] I am an assistant[INST] Another question[/INST]",
+            /* .expected_output_jinja= */ "[INST] Hello[/INST] Hi there[INST] Who are you[/INST] I am an assistant[INST] You are a helpful assistant\n\nAnother question[/INST]",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template; modified to have system prompt at start)",
+            /* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n    {%- set system_message = messages[0][\"content\"] %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n    {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n    {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n        {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n            {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n        {%- endif %}\n        {%- set ns.index = ns.index + 1 %}\n    {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n    {%- if message[\"role\"] == \"user\" %}\n        {%- if tools is not none and (message == user_messages[-1]) %}\n            {{- \"[AVAILABLE_TOOLS][\" }}\n            {%- for tool in tools %}\n                {%- set tool = tool.function %}\n                {{- '{\"type\": \"function\", \"function\": {' }}\n                {%- for key, val in tool.items() if key != \"return\" %}\n                    {%- if val is string %}\n                        {{- '\"' + key + '\": \"' + val + '\"' }}\n                    {%- else %}\n                        {{- '\"' + key + '\": ' + val|tojson }}\n                    {%- endif %}\n                    {%- if not loop.last %}\n                        {{- \", \" }}\n                    {%- endif %}\n                {%- endfor %}\n                {{- \"}}\" }}\n                {%- if not loop.last %}\n                    {{- \", \" }}\n                {%- else %}\n                    {{- \"]\" }}\n                {%- endif %}\n            {%- endfor %}\n            {{- \"[/AVAILABLE_TOOLS]\" }}\n            {%- endif %}\n        {%- if loop.last and system_message is defined %}\n            {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n        {%- else %}\n            {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n        {%- endif %}\n    {%- elif (message.tool_calls is defined and message.tool_calls is not none) %}\n        {{- \"[TOOL_CALLS][\" }}\n        {%- for tool_call in message.tool_calls %}\n            {%- set out = tool_call.function|tojson %}\n            {{- out[:-1] }}\n            {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n                {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n            {%- endif %}\n            {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n            {%- if not loop.last %}\n                {{- \", \" }}\n            {%- else %}\n                {{- \"]\" + eos_token }}\n            {%- endif %}\n        {%- endfor %}\n    {%- elif message[\"role\"] == \"assistant\" %}\n        {{- message[\"content\"] + eos_token}}\n    {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n        {%- if message.content is defined and message.content.content is defined %}\n            {%- set content = message.content.content %}\n        {%- else %}\n            {%- set content = message.content %}\n        {%- endif %}\n        {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n        {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n            {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n        {%- endif %}\n        {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n    {%- else %}\n        {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n    {%- endif %}\n{%- endfor %}\n",
+            /* .expected_output= */       "[INST]You are a helpful assistant\n\nHello[/INST]Hi there[INST]Who are you[/INST]   I am an assistant   [INST]Another question[/INST]",
+            /* .expected_output_jinja= */ "[INST]Hello[/INST]Hi there[INST]Who are you[/INST]   I am an assistant   [INST]You are a helpful assistant\n\nAnother question[/INST]",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "mistralai/Mistral-Large-Instruct-2411 (mistralai 'v7' template)",
+            /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'system' %}{{ '[SYSTEM_PROMPT] ' + message['content'] + '[/SYSTEM_PROMPT]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token }}{% else %}{{ raise_exception('Only user, system and assistant roles are supported!') }}{% endif %}{% endfor %}",
+            /* .expected_output= */ "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT][INST] Hello[/INST] Hi there[INST] Who are you[/INST]    I am an assistant   [INST] Another question[/INST]",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "ai-sage/GigaChat-20B-A3B-instruct",
+            /* .template_str= */ "{% if messages[0]['role'] == 'system' -%}\n    {%- set loop_messages = messages[1:] -%}\n    {%- set system_message = bos_token + messages[0]['content'] + additional_special_tokens[1] -%}\n{%- else -%}\n    {%- set loop_messages = messages -%}\n    {%- set system_message = bos_token + '' -%}\n{%- endif -%}\n{%- for message in loop_messages %}\n    {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}\n    {% endif %}\n    \n    {%- if loop.index0 == 0 -%}\n        {{ system_message -}}\n    {%- endif -%}\n    {%- if message['role'] == 'user' -%}\n        {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n        {{ 'available functions' + additional_special_tokens[0] + additional_special_tokens[2] + additional_special_tokens[3]  + additional_special_tokens[1] -}}\n    {%- endif -%}\n    {%- if message['role'] == 'assistant' -%}\n        {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n    {%- endif -%}\n    {%- if loop.last and add_generation_prompt -%}\n        {{ 'assistant' + additional_special_tokens[0] -}}\n    {%- endif -%}\n{%- endfor %}",
+            /* .expected_output= */ "You are a helpful assistant<|message_sep|>user<|role_sep|>Hello<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>Hi there<|message_sep|>user<|role_sep|>Who are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>   I am an assistant   <|message_sep|>user<|role_sep|>Another question<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+            /* .supported_with_jinja= */ false, // Requires additional_special_tokens as extra context
+        },
+        {
+            /* .name= */ "Infinigence/Megrez-3B-Instruct",
+            /* .template_str= */ u8"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}",
+            /* .expected_output= */ "<|role_start|>system<|role_end|>You are a helpful assistant<|turn_end|><|role_start|>user<|role_end|>Hello<|turn_end|><|role_start|>assistant<|role_end|>Hi there<|turn_end|><|role_start|>user<|role_end|>Who are you<|turn_end|><|role_start|>assistant<|role_end|>   I am an assistant   <|turn_end|><|role_start|>user<|role_end|>Another question<|turn_end|><|role_start|>assistant<|role_end|>",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
+        {
+            /* .name= */ "phi-4",
+            /* .template_str= */ "{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|><|im_start|>assistant<|im_sep|>'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>'}}{% endif %}{% endfor %}",
+            /* .expected_output= */ "<|im_start|>system<|im_sep|>You are a helpful assistant<|im_end|><|im_start|>user<|im_sep|>Hello<|im_end|><|im_start|>assistant<|im_sep|>Hi there<|im_end|><|im_start|>user<|im_sep|>Who are you<|im_end|><|im_start|>assistant<|im_sep|>   I am an assistant   <|im_end|><|im_start|>user<|im_sep|>Another question<|im_end|><|im_start|>assistant<|im_sep|>",
+            /* .expected_output_jinja= */ "",
+            /* .bos_token= */ "",
+            /* .eos_token= */ "",
+        },
     };
     std::vector formatted_chat(1024);
     int32_t res;
 
-    // test invalid chat template
-    res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
-    assert(res < 0);
+    // list all supported templates
+    std::vector supported_tmpl;
+    res = llama_chat_builtin_templates(nullptr, 0);
+    assert(res > 0);
+    supported_tmpl.resize(res);
+    res = llama_chat_builtin_templates(supported_tmpl.data(), supported_tmpl.size());
+    printf("Built-in chat templates:\n");
+    for (auto tmpl : supported_tmpl) {
+        printf("  %s\n", tmpl);
+    }
 
-    for (size_t i = 0; i < templates.size(); i++) {
-        std::string custom_template = templates[i];
-        std::string expected = expected_output[i];
+    // test invalid chat template
+    res = llama_chat_apply_template("INVALID TEMPLATE", conversation.data(), conversation.size(), true, formatted_chat.data(), formatted_chat.size());
+    assert(res < 0);
+    const auto add_generation_prompt = true;
+
+    for (const auto & test_case : test_cases) {
+        printf("\n\n=== %s ===\n\n", test_case.name.c_str());
         formatted_chat.resize(1024);
         res = llama_chat_apply_template(
-            nullptr,
-            custom_template.c_str(),
-            conversation,
-            message_count,
-            true,
+            test_case.template_str.c_str(),
+            conversation.data(),
+            conversation.size(),
+            add_generation_prompt,
             formatted_chat.data(),
             formatted_chat.size()
         );
         formatted_chat.resize(res);
         std::string output(formatted_chat.data(), formatted_chat.size());
-        printf("%s\n", output.c_str());
-        printf("-------------------------\n");
-        assert(output == expected);
+        if (output != test_case.expected_output) {
+            printf("Expected:\n%s\n", test_case.expected_output.c_str());
+            printf("-------------------------\n");
+            printf("Actual:\n%s\n", output.c_str());
+            fflush(stdout);
+            assert(output == test_case.expected_output);
+        }
     }
 
+    json messages = json::array();
+    for (const auto & msg : conversation) {
+        messages.push_back({
+            {"role", msg.role},
+            {"content", msg.content},
+        });
+    }
+    for (const auto & test_case : test_cases) {
+        if (!test_case.supported_with_jinja) {
+            continue;
+        }
+        printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str());
+        try {
+            minja::chat_template tmpl(test_case.template_str, test_case.bos_token, test_case.eos_token);
+            auto output = normalize_newlines(tmpl.apply(messages, json(), add_generation_prompt));
+            auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja);
+            if (output != expected_output) {
+                printf("Expected:\n%s\n", expected_output.c_str());
+                printf("-------------------------\n");
+                printf("Actual:\n%s\n", output.c_str());
+                fflush(stdout);
+                assert(output == expected_output);
+            }
+        } catch (const std::exception & e) {
+            printf("ERROR: %s\n", e.what());
+            assert(false);
+        }
+    }
 
     // test llama_chat_format_single for system message
     printf("\n\n=== llama_chat_format_single (system message) ===\n\n");
-    std::vector chat2;
-    llama_chat_msg sys_msg{"system", "You are a helpful assistant"};
+    std::vector chat2;
+    common_chat_msg sys_msg{"system", "You are a helpful assistant", {}};
 
-    auto fmt_sys = [&](std::string tmpl) {
-        auto output = llama_chat_format_single(nullptr, tmpl, chat2, sys_msg, false);
-        printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str());
+    auto fmt_sys = [&](std::string tmpl_str) {
+        minja::chat_template tmpl(tmpl_str, "", "");
+        auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
+        printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
         return output;
     };
     assert(fmt_sys("chatml") == "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n");
+    assert(fmt_sys("mistral-v1") == " [INST] You are a helpful assistant\n\n");
+    assert(fmt_sys("mistral-v3") == "[INST] You are a helpful assistant\n\n");
+    assert(fmt_sys("mistral-v3-tekken") == "[INST]You are a helpful assistant\n\n");
+    assert(fmt_sys("mistral-v7") == "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT]");
     assert(fmt_sys("llama2") == "[INST] You are a helpful assistant\n");
+    assert(fmt_sys("llama2-sys") == "[INST] <>\nYou are a helpful assistant\n<>\n\n");
+    assert(fmt_sys("mistral") == "[INST] You are a helpful assistant\n"); // for old pre-v1 templates
     assert(fmt_sys("gemma")  == ""); // for gemma, system message is merged with user message
     assert(fmt_sys("llama3") == "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>");
+    assert(fmt_sys("gigachat") == "You are a helpful assistant<|message_sep|>");
 
 
     // test llama_chat_format_single for user message
     printf("\n\n=== llama_chat_format_single (user message) ===\n\n");
-    chat2.push_back({"system", "You are a helpful assistant"});
-    chat2.push_back({"user", "Hello"});
-    chat2.push_back({"assistant", "I am assistant"});
-    llama_chat_msg new_msg{"user", "How are you"};
+    chat2.push_back({"system", "You are a helpful assistant", {}});
+    chat2.push_back({"user", "Hello", {}});
+    chat2.push_back({"assistant", "I am assistant", {}});
+    common_chat_msg new_msg{"user", "How are you", {}};
 
-    auto fmt_single = [&](std::string tmpl) {
-        auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true);
-        printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str());
+    auto fmt_single = [&](std::string tmpl_str) {
+        minja::chat_template tmpl(tmpl_str, "", "");
+        auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false);
+        printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
         return output;
     };
     assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
+    assert(fmt_single("mistral-v1") == " [INST] How are you [/INST]");
+    assert(fmt_single("mistral-v3") == "[INST] How are you[/INST]");
+    assert(fmt_single("mistral-v3-tekken") == "[INST]How are you[/INST]");
+    assert(fmt_single("mistral-v7") == "[INST] How are you[/INST]");
     assert(fmt_single("llama2") == "[INST] How are you [/INST]");
+    assert(fmt_single("mistral") == "[INST] How are you [/INST]"); // for old pre-v1 templates
     assert(fmt_single("gemma")  == "\nuser\nHow are you\nmodel\n");
     assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
+    assert(fmt_single("gigachat") == "user<|role_sep|>How are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>");
 
     return 0;
 }
diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp
new file mode 100644
index 000000000..ccc65d87a
--- /dev/null
+++ b/tests/test-chat.cpp
@@ -0,0 +1,521 @@
+//  Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
+//
+//  Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
+//  e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
+//
+//    cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
+//
+#include 
+#include 
+#include 
+#include 
+
+#include "chat-template.hpp"
+#include "chat.hpp"
+#include "llama-grammar.h"
+#include "unicode.h"
+
+using json = nlohmann::ordered_json;
+
+static common_chat_msg msg_from_json(const json & message) {
+    common_chat_msg ret{
+        "assistant",
+        "",
+        {},
+    };
+    if (message.contains("content") && !message.at("content").is_null()) {
+        ret.content = message.at("content").get();
+    }
+    auto has_tool_calls = message.contains("tool_calls");
+    if (has_tool_calls) {
+        for (const auto & tc : message.at("tool_calls")) {
+            const auto & arguments = tc.at("function").at("arguments");
+            ret.tool_calls.push_back({
+                tc.at("function").at("name").get(),
+                arguments.is_string() ? arguments.get() : arguments.dump(),
+                tc.contains("id") ? tc.at("id").get() : "",
+            });
+        }
+    }
+    return ret;
+}
+
+template  static void assert_equals(const T & expected, const T & actual) {
+    if (expected != actual) {
+        std::cerr << "Expected: " << expected << std::endl;
+        std::cerr << "Actual: " << actual << std::endl;
+        std::cerr << std::flush;
+        throw std::runtime_error("Test failed");
+    }
+}
+
+static std::string read_file(const std::string & path) {
+    std::cerr << "# Reading: " << path << std::endl << std::flush;
+    std::ifstream fs(path, std::ios_base::binary);
+    if (!fs.is_open()) {
+        fs = std::ifstream("../" + path, std::ios_base::binary);
+        if (!fs.is_open()) {
+            throw std::runtime_error("Failed to open file: " + path);
+        }
+    }
+    fs.seekg(0, std::ios_base::end);
+    auto size = fs.tellg();
+    fs.seekg(0);
+    std::string out;
+    out.resize(static_cast(size));
+    fs.read(&out[0], static_cast(size));
+    return out;
+}
+
+static std::unique_ptr build_grammar(const std::string & grammar_str) {
+    return std::unique_ptr(
+        llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0));
+}
+
+// TODO: extract to common helper (copied from test-grammar-integration.cpp)
+static bool match_string(const std::string & input, llama_grammar * grammar) {
+    const auto cpts = unicode_cpts_from_utf8(input);
+
+    auto & stacks_cur = llama_grammar_get_stacks(grammar);
+
+    for (const auto & cpt : cpts) {
+        llama_grammar_accept(grammar, cpt);
+
+        if (stacks_cur.empty()) {
+            // no stacks means that the grammar failed to match at this point
+            return false;
+        }
+    }
+
+    for (const auto & stack : stacks_cur) {
+        if (stack.empty()) {
+            // An empty stack means that the grammar has been completed
+            return true;
+        }
+    }
+
+    return false;
+}
+
+// Dumps `{"a": 1}` as `"{\"a\": 1}"`, unlike nlohmann::json::dump which would dump it as `"{\"a\":1}"`.
+static std::string dump(const json & j) {
+    return minja::Value(j).dump(-1, /* to_json= */ true);
+}
+
+static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
+    assert_equals(expected.role, actual.role);
+    assert_equals(expected.content, actual.content);
+    assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
+    for (size_t i = 0; i < expected.tool_calls.size(); i++) {
+        const auto & expected_tool_call = expected.tool_calls[i];
+        const auto & actual_tool_call   = actual.tool_calls[i];
+        assert_equals(expected_tool_call.name, actual_tool_call.name);
+        assert_equals(dump(json::parse(expected_tool_call.arguments)), dump(json::parse(actual_tool_call.arguments)));
+        assert_equals(expected_tool_call.id, actual_tool_call.id);
+    }
+}
+
+const auto special_function_tool = json::parse(R"({
+  "type": "function",
+  "function": {
+    "name": "special_function",
+    "description": "I'm special",
+    "parameters": {
+      "type": "object",
+      "properties": {
+        "arg1": {
+          "type": "integer",
+          "description": "The arg."
+        }
+      },
+      "required": ["arg1"]
+    }
+  }
+})");
+const auto python_tool           = json::parse(R"({
+  "type": "function",
+  "function": {
+    "name": "python",
+    "description": "an ipython interpreter",
+    "parameters": {
+      "type": "object",
+      "properties": {
+        "code": {
+          "type": "string",
+          "description": "Python code to execute."
+        }
+      },
+      "required": ["code"]
+    }
+  }
+})");
+const auto code_interpreter_tool = json::parse(R"({
+  "type": "function",
+  "function": {
+    "name": "code_interpreter",
+    "description": "an ipython interpreter",
+    "parameters": {
+      "type": "object",
+      "properties": {
+        "code": {
+          "type": "string",
+          "description": "Python code to execute."
+        }
+      },
+      "required": ["code"]
+    }
+  }
+})");
+const json tools                 = { special_function_tool, python_tool };
+const json llama_3_1_tools       = { special_function_tool, code_interpreter_tool };
+
+struct delta_data {
+    std::string        delta;
+    std::string        grammar;
+    common_chat_format format;
+};
+
+static delta_data init_delta(const common_chat_template & tmpl, const std::vector & end_tokens,
+                             const json & user_message, const json & delta_message, const json & tools,
+                             const json & tool_choice) {
+    common_chat_inputs inputs;
+    inputs.parallel_tool_calls = true;
+    inputs.messages            = json::array();
+    inputs.messages.push_back(user_message);
+    inputs.tools       = tools;
+    inputs.tool_choice = tool_choice;
+    auto params_prefix = common_chat_params_init(tmpl, inputs);
+
+    inputs.messages.push_back(delta_message);
+    inputs.add_generation_prompt = false;
+    auto params_full             = common_chat_params_init(tmpl, inputs);
+
+    std::string prefix = params_prefix.prompt;
+    std::string full   = params_full.prompt;
+
+    // Check full starts with prefix
+    if (full.find(prefix) != 0) {
+        fprintf(stderr, "Full:\n%s\n\nPrefix:\n%s\n\n", full.c_str(), prefix.c_str());
+        throw std::runtime_error("Full message does not start with prefix");
+    }
+
+    if (full == prefix) {
+        throw std::runtime_error("Full message is the same as the prefix");
+    }
+
+    auto delta = full.substr(prefix.size());
+
+    // Strip end tokens
+    for (const auto & end_token : end_tokens) {
+        // rfind to find the last occurrence
+        auto pos = delta.rfind(end_token);
+        if (pos != std::string::npos) {
+            delta = delta.substr(0, pos);
+            break;
+        }
+    }
+    return { delta, params_full.grammar, params_full.format };
+}
+
+/*
+  Applies the template to 1 user message w/ add_generation_prompt=true, then w/ the test message w/ add_generation_prompt=false,
+  gets the diff, removes any end tokens and parses the result w/ the grammar, checking that
+  the parsed message is the same as the test_message
+*/
+static void test_template(const common_chat_template & tmpl, const std::vector & end_tokens,
+                          const json & test_message, const json & tools = {}, const std::string & expected_delta = "",
+                          bool skip_grammar_test = false, bool skip_parser_test = false) {
+    common_chat_msg expected_msg = msg_from_json(test_message);
+
+    auto user_message = json{
+        { "role",    "user"          },
+        { "content", "Hello, world!" }
+    };
+
+    for (const auto & tool_choice : json({ "auto", "required" })) {
+        auto data = init_delta(tmpl, end_tokens, user_message, test_message, tools, tool_choice);
+        if (!expected_delta.empty()) {
+            assert_equals(expected_delta, data.delta);
+        }
+
+        if (!skip_parser_test) {
+            const auto msg = common_chat_parse(data.delta, data.format);
+            assert_msg_equals(expected_msg, msg);
+        }
+
+        if (!expected_msg.tool_calls.empty()) {
+            GGML_ASSERT(!data.grammar.empty());
+        }
+        if (!data.grammar.empty()) {
+            auto grammar = build_grammar(data.grammar);
+            if (!grammar) {
+                throw std::runtime_error("Failed to build grammar");
+            }
+            // TODO: exercice lazy grammars + triggers here, instead of skipping the test
+            if (!skip_grammar_test) {
+                if (!match_string(data.delta, grammar.get())) {
+                    throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta +
+                                             "\n\nGrammar: " + data.grammar);
+                }
+            }
+        }
+    }
+}
+
+static void test_template_output_parsers() {
+    auto text_message = json{
+        { "role",    "assistant"     },
+        { "content", "Hello, world!" },
+    };
+    auto tool_call_message = json{
+        { "role",       "assistant"                },
+        { "content",    {}                         },
+        { "tool_calls", json{ {
+                            { "type", "function" },
+                            { "function", { { "name", "special_function" }, { "arguments", "{\"arg1\": 1}" } } },
+                        } } }
+    };
+    auto tool_call_message_with_id                   = json::parse(tool_call_message.dump());
+    tool_call_message_with_id["tool_calls"][0]["id"] = "123456789";
+
+    auto python_tool_call_message = json{
+        { "role",       "assistant"                },
+        { "content",    {}                         },
+        { "tool_calls", json{ {
+                            { "type", "function" },
+                            { "function",
+                              {
+                                  { "name", "python" },
+                                  { "arguments",
+                                    {
+                                        { "code", "print('hey')" },
+                                    } },
+                              } },
+                        } } }
+    };
+    auto code_interpreter_tool_call_message = json{
+        { "role",       "assistant"                },
+        { "content",    {}                         },
+        { "tool_calls", json{ {
+                            { "type", "function" },
+                            { "function",
+                              {
+                                  { "name", "code_interpreter" },
+                                  { "arguments",
+                                    {
+                                        { "code", "print('hey')" },
+                                    } },
+                              } },
+                        } } }
+    };
+
+    common_chat_inputs inputs_no_tools;
+    inputs_no_tools.messages = {
+        { { "role", "user" }, { "content", "Hey" } }
+    };
+
+    common_chat_inputs inputs_tools = inputs_no_tools;
+    inputs_tools.tools              = json::array();
+    inputs_tools.tools.push_back(special_function_tool);
+
+    common_chat_inputs inputs_tools_builtin = inputs_no_tools;
+    inputs_tools_builtin.tools              = json::array();
+    inputs_tools_builtin.tools.push_back(python_tool);
+
+    {
+        const common_chat_template tmpl(read_file("models/templates/google-gemma-2-2b-it.jinja"), "", "");
+        std::vector   end_tokens{ "" };
+
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_params_init(tmpl, inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_GENERIC,
+                      common_chat_params_init(
+                          common_chat_template(read_file("models/templates/microsoft-Phi-3.5-mini-instruct.jinja"),
+                                               "", ""),
+                          inputs_tools)
+                          .format);
+
+        // Generic tool calls doesn't generate / parse content-only messages symmetrically.
+
+        assert_msg_equals(msg_from_json(text_message),
+                          common_chat_parse("{\n"
+                                            "  \"response\": \"Hello, world!\"\n"
+                                            "}",
+                                            common_chat_params_init(tmpl, inputs_tools).format));
+        test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
+                      "{\n"
+                      "  \"tool_calls\": [\n"
+                      "    {\n"
+                      "      \"name\": \"special_function\",\n"
+                      "      \"arguments\": {\n"
+                      "        \"arg1\": 1\n"
+                      "      },\n"
+                      "      \"id\": \"123456789\"\n"
+                      "    }\n"
+                      "  ]\n"
+                      "}");
+    }
+    {
+        const common_chat_template tmpl(read_file("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"), "",
+                                        "");
+        std::vector   end_tokens{ "" };
+
+        assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
+
+        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
+        test_template(
+            tmpl, end_tokens, tool_call_message_with_id, tools,
+            "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]",
+            /* skip_grammar_test= */ true);
+    }
+    {
+        const common_chat_template tmpl(
+            read_file("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"), "", "");
+        std::vector end_tokens{ "<|im_end|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(
+            COMMON_CHAT_FORMAT_HERMES_2_PRO,
+            common_chat_params_init(
+                common_chat_template(read_file("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja"),
+                                     "", ""),
+                inputs_tools)
+                .format);
+        assert_equals(
+            COMMON_CHAT_FORMAT_HERMES_2_PRO,
+            common_chat_params_init(
+                common_chat_template(read_file("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja"), "", ""),
+                inputs_tools)
+                .format);
+
+        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
+        test_template(tmpl, end_tokens, tool_call_message, tools,
+                      "\n"
+                      "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                      "");
+        test_template(tmpl, end_tokens, python_tool_call_message, tools,
+                      "\n"
+                      "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
+                      "");
+    }
+    {
+        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"), "",
+                                        "");
+        std::vector   end_tokens{ "<|eom_id|>", "<|eot_id|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
+                      common_chat_params_init(tmpl, inputs_tools_builtin).format);
+        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
+                      common_chat_params_init(
+                          common_chat_template(read_file("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja"),
+                                               "", ""),
+                          inputs_tools_builtin)
+                          .format);
+
+        // test_template(tmpl, end_tokens, text_message, tools, R"(?)", /* skip_grammar_test= */ true);
+        test_template(tmpl, end_tokens, code_interpreter_tool_call_message, llama_3_1_tools,
+                      "<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
+        test_template(tmpl, end_tokens, python_tool_call_message, tools,
+                      "<|python_tag|>python.call(code=\"print('hey')\")");
+        test_template(tmpl, end_tokens, tool_call_message, tools,
+                      "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
+    }
+    {
+        const common_chat_template tmpl(read_file("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"), "",
+                                        "");
+        std::vector   end_tokens{ "<|eom_id|>", "<|eot_id|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
+
+        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
+        test_template(tmpl, end_tokens, tool_call_message, tools,
+                      "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
+    }
+    {
+        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.1.jinja"), "",
+                                        "");
+        std::vector   end_tokens{ "<|eom_id|>", "<|eot_id|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
+                      common_chat_params_init(tmpl, inputs_tools).format);
+
+        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
+        test_template(tmpl, end_tokens, tool_call_message, tools,
+                      "{\"arg1\": 1}");
+    }
+    {
+        const common_chat_template tmpl(read_file("models/templates/meetkai-functionary-medium-v3.2.jinja"), "",
+                                        "");
+        std::vector   end_tokens{ "<|eom_id|>", "<|eot_id|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_params_init(tmpl, inputs_tools).format);
+
+        test_template(tmpl, end_tokens, text_message, {},
+                      "all\n"
+                      "Hello, world!",
+                      /* skip_grammar_test= */ true);
+        test_template(tmpl, end_tokens, tool_call_message, tools,
+                      "special_function\n"
+                      "{\"arg1\": 1}");
+    }
+    {
+        const common_chat_template tmpl(read_file("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"), "",
+                                        "");
+        std::vector   end_tokens{ "<|eot_id|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
+
+        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
+        test_template(tmpl, end_tokens, tool_call_message, tools,
+                      " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
+    }
+    {
+        const common_chat_template tmpl(read_file("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"),
+                                        "", "");
+        std::vector   end_tokens{ "<|end▁of▁sentence|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
+
+        test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* skip_grammar_test= */ true);
+        test_template(tmpl, end_tokens, tool_call_message, tools,
+                      "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
+                      "```json\n"
+                      "{\"arg1\": 1}\n"
+                      "```<|tool▁call▁end|>");
+    }
+}
+
+int main(int argc, char ** argv) {
+#ifndef _WIN32
+    if (argc > 1) {
+        common_chat_inputs inputs;
+        inputs.messages = {
+            { { "role", "user" }, { "content", "Hey" } }
+        };
+        inputs.tools = json::array({ special_function_tool });
+
+        std::cout << "| Template | Format |\n";
+        std::cout << "|----------|--------|\n";
+
+        for (int i = 1; i < argc; i++) {
+            std::string path = argv[i];
+            if (path.rfind(".jinja") != path.size() - 6) {
+                std::cerr << "Skipping non-jinja file: " << path << std::endl;
+                continue;
+            }
+            common_chat_template tmpl(read_file(path), "", "");
+            auto                 parts = string_split(path, "/");
+            auto                 name  = parts[parts.size() - 1];
+            std::cout << "| " << name << " | " << common_chat_format_name(common_chat_params_init(tmpl, inputs).format)
+                      << " |\n";
+        }
+    } else
+#endif
+    {
+        test_template_output_parsers();
+        std::cout << "\n[chat] All tests passed!" << std::endl;
+    }
+    return 0;
+}
diff --git a/tests/test-gguf.cpp b/tests/test-gguf.cpp
new file mode 100644
index 000000000..6ed696328
--- /dev/null
+++ b/tests/test-gguf.cpp
@@ -0,0 +1,1338 @@
+#include "ggml.h"
+#include "ggml-backend.h"
+#include "../ggml/src/ggml-impl.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+constexpr int offset_has_kv      = 1000;
+constexpr int offset_has_tensors = 2000;
+constexpr int offset_has_data    = 3000;
+
+enum handcrafted_file_type {
+    HANDCRAFTED_HEADER_BAD_MAGIC           =  10,
+    HANDCRAFTED_HEADER_BAD_VERSION_1       =  20,
+    HANDCRAFTED_HEADER_BAD_VERSION_FUTURE  =  30,
+    HANDCRAFTED_HEADER_BAD_N_TENSORS       =  40,
+    HANDCRAFTED_HEADER_BAD_N_KV            =  50,
+    HANDCRAFTED_HEADER_EMPTY               = 800,
+
+    HANDCRAFTED_KV_BAD_KEY_SIZE            =  10 + offset_has_kv,
+    HANDCRAFTED_KV_BAD_TYPE                =  20 + offset_has_kv,
+    // HANDCRAFTED_KV_BAD_VALUE_SIZE          =  30 + offset_has_kv, // removed because it can result in allocations > 1 TB (default sanitizer limit)
+    HANDCRAFTED_KV_DUPLICATE_KEY           =  40 + offset_has_kv,
+    HANDCRAFTED_KV_BAD_ALIGN               =  50 + offset_has_kv,
+    HANDCRAFTED_KV_SUCCESS                 = 800 + offset_has_kv,
+
+    HANDCRAFTED_TENSORS_BAD_NAME_SIZE      =  10 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_N_DIMS         =  20 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_SHAPE          =  30 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_NE_TOO_BIG         =  40 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_TYPE           =  50 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_OFFSET         =  60 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_DUPLICATE_NAME     =  70 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_BAD_ALIGN          =  75 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_INCONSISTENT_ALIGN =  80 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_SUCCESS            = 800 + offset_has_tensors,
+    HANDCRAFTED_TENSORS_CUSTOM_ALIGN       = 810 + offset_has_tensors,
+
+    HANDCRAFTED_DATA_NOT_ENOUGH_DATA       =  10 + offset_has_data,
+    HANDCRAFTED_DATA_BAD_ALIGN             =  15 + offset_has_data,
+    HANDCRAFTED_DATA_INCONSISTENT_ALIGN    =  20 + offset_has_data,
+    HANDCRAFTED_DATA_SUCCESS               = 800 + offset_has_data,
+    HANDCRAFTED_DATA_CUSTOM_ALIGN          = 810 + offset_has_data,
+};
+
+static std::string handcrafted_file_type_name(const enum handcrafted_file_type hft) {
+    switch (hft) {
+        case HANDCRAFTED_HEADER_BAD_MAGIC:           return "HEADER_BAD_MAGIC";
+        case HANDCRAFTED_HEADER_BAD_VERSION_1:       return "HEADER_BAD_VERSION_1";
+        case HANDCRAFTED_HEADER_BAD_VERSION_FUTURE:  return "HEADER_BAD_VERSION_FUTURE";
+        case HANDCRAFTED_HEADER_BAD_N_KV:            return "HEADER_BAD_N_KV";
+        case HANDCRAFTED_HEADER_BAD_N_TENSORS:       return "HEADER_BAD_N_TENSORS";
+        case HANDCRAFTED_HEADER_EMPTY:               return "HEADER_EMPTY";
+
+        case HANDCRAFTED_KV_BAD_KEY_SIZE:            return "KV_BAD_KEY_SIZE";
+        case HANDCRAFTED_KV_BAD_TYPE:                return "KV_BAD_TYPE";
+        case HANDCRAFTED_KV_DUPLICATE_KEY:           return "KV_DUPLICATE_KEY";
+        case HANDCRAFTED_KV_BAD_ALIGN:               return "KV_BAD_ALIGN";
+        case HANDCRAFTED_KV_SUCCESS:                 return "KV_RANDOM_KV";
+
+        case HANDCRAFTED_TENSORS_BAD_NAME_SIZE:      return "TENSORS_BAD_NAME_SIZE";
+        case HANDCRAFTED_TENSORS_BAD_N_DIMS:         return "TENSORS_BAD_N_DIMS";
+        case HANDCRAFTED_TENSORS_BAD_SHAPE:          return "TENSORS_BAD_SHAPE";
+        case HANDCRAFTED_TENSORS_NE_TOO_BIG:         return "TENSORS_NE_TOO_BIG";
+        case HANDCRAFTED_TENSORS_BAD_TYPE:           return "TENSORS_BAD_TYPE";
+        case HANDCRAFTED_TENSORS_BAD_OFFSET:         return "TENSORS_BAD_OFFSET";
+        case HANDCRAFTED_TENSORS_DUPLICATE_NAME:     return "TENSORS_DUPLICATE_NAME";
+        case HANDCRAFTED_TENSORS_BAD_ALIGN:          return "TENSORS_BAD_ALIGN";
+        case HANDCRAFTED_TENSORS_INCONSISTENT_ALIGN: return "TENSORS_INCONSISTENT_ALIGN";
+        case HANDCRAFTED_TENSORS_SUCCESS:            return "TENSORS_SUCCESS";
+        case HANDCRAFTED_TENSORS_CUSTOM_ALIGN:       return "TENSORS_CUSTOM_ALIGN";
+
+        case HANDCRAFTED_DATA_NOT_ENOUGH_DATA:       return "DATA_NOT_ENOUGH_DATA";
+        case HANDCRAFTED_DATA_BAD_ALIGN:             return "DATA_BAD_ALIGN";
+        case HANDCRAFTED_DATA_INCONSISTENT_ALIGN:    return "DATA_INCONSISTENT_ALIGN";
+        case HANDCRAFTED_DATA_SUCCESS:               return "DATA_SUCCESS";
+        case HANDCRAFTED_DATA_CUSTOM_ALIGN:          return "DATA_CUSTOM_ALIGN";
+    }
+    GGML_ABORT("fatal error");
+}
+
+static bool expect_context_not_null(const enum handcrafted_file_type hft) {
+    if (hft < offset_has_kv) {
+        return hft >= HANDCRAFTED_HEADER_EMPTY;
+    }
+    if (hft < offset_has_tensors) {
+        return hft >= HANDCRAFTED_KV_SUCCESS;
+    }
+    if (hft < offset_has_data) {
+        return hft >= HANDCRAFTED_TENSORS_SUCCESS;
+    }
+    return hft >= HANDCRAFTED_DATA_SUCCESS;
+}
+
+typedef std::pair> tensor_config_t;
+
+static std::vector get_tensor_configs(std::mt19937 & rng) {
+    std::vector tensor_configs;
+    tensor_configs.reserve(100);
+
+    for (int i = 0; i < 100; ++i) {
+        const enum ggml_type type = ggml_type(rng() % GGML_TYPE_COUNT);
+        if (ggml_type_size(type) == 0) {
+            continue;
+        }
+
+        std::array shape = {1, 1, 1, 1};
+        shape[0] = (1 + rng() % 10) * ggml_blck_size(type);
+        const int n_dims = 1 + rng() % GGML_MAX_DIMS;
+        for (int i = 1; i < n_dims; ++i) {
+            shape[i] = 1 + rng() % 10;
+        }
+
+        tensor_configs.push_back(std::make_pair(type, shape));
+    }
+
+    return tensor_configs;
+}
+
+static std::vector> get_kv_types(std::mt19937 rng) {
+    std::vector> kv_types;
+    kv_types.reserve(100);
+
+    for (int i = 0; i < 100; ++i) {
+        const gguf_type type = gguf_type(rng() % GGUF_TYPE_COUNT);
+
+        if (type == GGUF_TYPE_ARRAY) {
+            const gguf_type type_arr = gguf_type(rng() % GGUF_TYPE_COUNT);
+            if (type_arr == GGUF_TYPE_ARRAY) {
+                continue;
+            }
+            kv_types.push_back(std::make_pair(type, type_arr));
+            continue;
+        }
+
+        kv_types.push_back(std::make_pair(type, gguf_type(-1)));
+    }
+    std::shuffle(kv_types.begin(), kv_types.end(), rng);
+
+    return kv_types;
+}
+
+template 
+static void helper_write(FILE * file, const T & val) {
+    GGML_ASSERT(fwrite(&val, 1, sizeof(val), file) == sizeof(val));
+}
+
+static void helper_write(FILE * file, const void * data, const size_t nbytes) {
+    GGML_ASSERT(fwrite(data, 1, nbytes, file) == nbytes);
+}
+
+static FILE * get_handcrafted_file(const unsigned int seed, const enum handcrafted_file_type hft, const int extra_bytes = 0) {
+    FILE * file = tmpfile();
+
+    if (!file) {
+        return file;
+    }
+
+    std::mt19937 rng(seed);
+    uint32_t alignment = GGUF_DEFAULT_ALIGNMENT;
+
+    if (hft == HANDCRAFTED_HEADER_BAD_MAGIC) {
+        const char bad_magic[4] = {'F', 'U', 'G', 'G'};
+        helper_write(file, bad_magic, sizeof(bad_magic));
+    } else {
+        helper_write(file, GGUF_MAGIC, 4);
+    }
+
+    if (hft == HANDCRAFTED_HEADER_BAD_VERSION_1) {
+        const uint32_t version = 1;
+        helper_write(file, version);
+    } else if (hft == HANDCRAFTED_HEADER_BAD_VERSION_FUTURE) {
+        const uint32_t version = GGUF_VERSION + 1;
+        helper_write(file, version);
+    } else {
+        const uint32_t version = GGUF_VERSION;
+        helper_write(file, version);
+    }
+
+    std::vector tensor_configs;
+    if (hft >= offset_has_tensors) {
+        tensor_configs = get_tensor_configs(rng);
+    }
+
+    if (hft == HANDCRAFTED_HEADER_BAD_N_TENSORS) {
+        const uint64_t n_tensors = -1;
+        helper_write(file, n_tensors);
+    } else {
+        const uint64_t n_tensors = tensor_configs.size();
+        helper_write(file, n_tensors);
+    }
+
+    std::vector> kv_types;
+    if (hft >= offset_has_kv) {
+        kv_types = get_kv_types(rng);
+    }
+    {
+        uint64_t n_kv = kv_types.size();
+        if (hft == HANDCRAFTED_KV_BAD_ALIGN      ||
+            hft == HANDCRAFTED_TENSORS_BAD_ALIGN || hft == HANDCRAFTED_TENSORS_CUSTOM_ALIGN ||
+            hft == HANDCRAFTED_DATA_BAD_ALIGN    || hft == HANDCRAFTED_DATA_CUSTOM_ALIGN) {
+
+            n_kv += 1;
+        } else if (hft == HANDCRAFTED_HEADER_BAD_N_KV) {
+            n_kv = -1;
+        }
+        helper_write(file, n_kv);
+    }
+
+    if (hft < offset_has_kv) {
+        while (ftell(file) % alignment != 0) {
+            const char pad = 0;
+            helper_write(file, pad);
+        }
+
+        for (int i = 0; i < extra_bytes; ++i) {
+            const char tmp = 0;
+            helper_write(file, tmp);
+        }
+        rewind(file);
+        return file;
+    }
+
+    for (int i = 0; i < int(kv_types.size()); ++i) {
+        const enum gguf_type type     = gguf_type(hft == HANDCRAFTED_KV_BAD_TYPE ? GGUF_TYPE_COUNT : kv_types[i].first);
+        const enum gguf_type type_arr = gguf_type(hft == HANDCRAFTED_KV_BAD_TYPE ? GGUF_TYPE_COUNT : kv_types[i].second);
+
+        const std::string key = "my_key_" + std::to_string((hft == HANDCRAFTED_KV_DUPLICATE_KEY ? i/2 : i));
+
+        if (hft == HANDCRAFTED_KV_BAD_KEY_SIZE) {
+            const uint64_t n = -1;
+            helper_write(file, n);
+        } else {
+            const uint64_t n = key.length();
+            helper_write(file, n);
+        }
+        helper_write(file, key.data(), key.length());
+
+        {
+            const int32_t type32 = int32_t(type);
+            helper_write(file, type32);
+        }
+
+        uint32_t data[16];
+        for (int j = 0; j < 16; ++j) {
+            data[j] = rng();
+            if (type == GGUF_TYPE_STRING || type_arr == GGUF_TYPE_STRING) {
+                data[j] |= 0x01010101; // avoid random null-termination of string
+            }
+        }
+
+        if (type == GGUF_TYPE_STRING) {
+            const uint64_t n = rng() % sizeof(data);
+            helper_write(file, n);
+            helper_write(file, data, n);
+            continue;
+        }
+
+        if (type == GGUF_TYPE_ARRAY) {
+            {
+                const int32_t type32 = int32_t(type_arr);
+                helper_write(file, type32);
+            }
+            if (type_arr == GGUF_TYPE_STRING) {
+                const uint64_t nstr = rng() % (16 + 1);
+                helper_write(file, nstr);
+                for (uint64_t istr = 0; istr < nstr; ++istr) {
+                    const uint64_t n = rng() % (sizeof(uint32_t) + 1);
+                    helper_write(file, n);
+                    helper_write(file, &data[istr], n);
+                }
+                continue;
+            }
+            const size_t type_size = gguf_type_size(type_arr);
+            const uint64_t n = (rng() % sizeof(data)) / type_size;
+            helper_write(file, n);
+            helper_write(file, &data, n*type_size);
+            continue;
+        }
+
+        helper_write(file, data, hft == HANDCRAFTED_KV_BAD_TYPE ? 1 : gguf_type_size(type));
+    }
+
+    if (hft == HANDCRAFTED_KV_BAD_ALIGN      ||
+        hft == HANDCRAFTED_TENSORS_BAD_ALIGN || hft == HANDCRAFTED_TENSORS_CUSTOM_ALIGN ||
+        hft == HANDCRAFTED_DATA_BAD_ALIGN    || hft == HANDCRAFTED_DATA_CUSTOM_ALIGN) {
+
+        const uint64_t n = strlen(GGUF_KEY_GENERAL_ALIGNMENT);
+        helper_write(file, n);
+        helper_write(file, GGUF_KEY_GENERAL_ALIGNMENT, n);
+
+        const int32_t type = gguf_type(GGUF_TYPE_UINT32);
+        helper_write(file, type);
+
+        alignment = expect_context_not_null(hft) ? 1 : 13;
+        helper_write(file, alignment);
+    }
+
+    if (hft < offset_has_tensors) {
+        while (ftell(file) % alignment != 0) {
+            const char pad = 0;
+            helper_write(file, pad);
+        }
+
+        for (int i = 0; i < extra_bytes; ++i) {
+            const char tmp = 0;
+            helper_write(file, tmp);
+        }
+        rewind(file);
+        return file;
+    }
+
+    if (hft == HANDCRAFTED_TENSORS_INCONSISTENT_ALIGN || hft == HANDCRAFTED_DATA_INCONSISTENT_ALIGN) {
+        alignment = 1;
+    }
+
+    uint64_t offset = 0;
+    for (int i = 0; i < int(tensor_configs.size()); ++i) {
+        const ggml_type                          type  = tensor_configs[i].first;
+        const std::array shape = tensor_configs[i].second;
+
+        std::string name = "my_tensor";
+        if (hft != HANDCRAFTED_TENSORS_DUPLICATE_NAME) {
+            name += "_" + std::to_string(i);
+        }
+        if (hft == HANDCRAFTED_TENSORS_BAD_NAME_SIZE) {
+            name += "_with_a_very_long_name_which_is_longer_than_what_is_allowed_for_ggml_tensors";
+            GGML_ASSERT(name.length() >= GGML_MAX_NAME);
+        }
+        {
+            const uint64_t n = name.length();
+            helper_write(file, n);
+        }
+        helper_write(file, name.data(), name.length());
+
+        uint32_t n_dims = hft == HANDCRAFTED_TENSORS_NE_TOO_BIG ? 2 : 1;
+        for (int i = GGML_MAX_DIMS-1; i >= 1; --i) {
+            if (shape[i] != 1) {
+                n_dims = i + 1;
+                break;
+            }
+        }
+        if (hft == HANDCRAFTED_TENSORS_BAD_N_DIMS) {
+            const uint32_t n_dims_bad = GGML_MAX_DIMS + 1;
+            helper_write(file, n_dims_bad);
+        } else {
+            helper_write(file, n_dims);
+        }
+
+        if (hft == HANDCRAFTED_TENSORS_BAD_SHAPE) {
+            for (uint32_t j = 0; j < n_dims; ++j) {
+                const int64_t bad_dim = -1;
+                helper_write(file, bad_dim);
+            }
+        } else if (hft == HANDCRAFTED_TENSORS_NE_TOO_BIG){
+            for (uint32_t j = 0; j < n_dims; ++j) {
+                const int64_t big_dim = 4*int64_t(INT32_MAX);
+                helper_write(file, big_dim);
+            }
+        } else {
+            helper_write(file, shape.data(), n_dims*sizeof(int64_t));
+        }
+
+        {
+            const int32_t type32 = hft == HANDCRAFTED_TENSORS_BAD_TYPE ? GGML_TYPE_COUNT : int32_t(type);
+            helper_write(file, type32);
+        }
+
+        if (hft == HANDCRAFTED_TENSORS_BAD_OFFSET) {
+            const uint64_t bad_offset = -1;
+            helper_write(file, bad_offset);
+        } else {
+            helper_write(file, offset);
+        }
+
+        int64_t ne = shape[0];
+        for (uint32_t i = 1; i < n_dims; ++i) {
+            ne *= shape[i];
+        }
+        offset += GGML_PAD(ggml_row_size(type, ne), alignment);
+    }
+
+    while (ftell(file) % alignment != 0) {
+        const char pad = 0;
+        helper_write(file, pad);
+    }
+
+    if (hft >= offset_has_data) {
+        rng.seed(seed + 1);
+        uint64_t nbytes = offset;
+        if (hft == HANDCRAFTED_DATA_NOT_ENOUGH_DATA) {
+            nbytes -= 1;
+        }
+        for (uint64_t i = 0; i < nbytes; ++i) {
+            const uint8_t random_byte = i % 256;
+            helper_write(file, random_byte);
+        }
+    }
+
+    for (int i = 0; i < extra_bytes; ++i) {
+        const char tmp = 0;
+        helper_write(file, tmp);
+    }
+    rewind(file);
+    return file;
+}
+
+static bool handcrafted_check_header(const gguf_context * gguf_ctx, const unsigned int seed, const bool has_kv, const bool has_tensors, const bool alignment_defined) {
+    if (!gguf_ctx) {
+        return false;
+    }
+
+    std::mt19937 rng(seed);
+
+    std::vector tensor_configs;
+    if (has_tensors) {
+        tensor_configs = get_tensor_configs(rng);
+    }
+    std::vector> kv_types;
+    if (has_kv) {
+        kv_types = get_kv_types(rng);
+    }
+
+    bool ok = true;
+
+    if (gguf_get_version(gguf_ctx) != GGUF_VERSION) {
+        ok = false;
+    }
+    if (gguf_get_n_tensors(gguf_ctx) != int(tensor_configs.size())) {
+        ok = false;
+    }
+    if (gguf_get_n_kv(gguf_ctx) != int(alignment_defined ? kv_types.size() + 1 : kv_types.size())) {
+        ok = false;
+    }
+
+    return ok;
+}
+
+static bool handcrafted_check_kv(const gguf_context * gguf_ctx, const unsigned int seed, const bool has_tensors, const bool alignment_defined) {
+    if (!gguf_ctx) {
+        return false;
+    }
+
+    std::mt19937 rng(seed);
+
+    std::vector tensor_configs;
+    if (has_tensors) {
+        tensor_configs = get_tensor_configs(rng);
+    }
+
+    std::vector> kv_types = get_kv_types(rng);
+
+    bool ok = true;
+
+    for (int i = 0; i < int(kv_types.size()); ++i) {
+        const enum gguf_type type     = gguf_type(kv_types[i].first);
+        const enum gguf_type type_arr = gguf_type(kv_types[i].second);
+
+        const std::string key = "my_key_" + std::to_string(i);
+
+        uint32_t data[16];
+        for (int j = 0; j < 16; ++j) {
+            data[j] = rng();
+            if (type == GGUF_TYPE_STRING || type_arr == GGUF_TYPE_STRING) {
+                data[j] |= 0x01010101; // avoid random null-termination of string
+            }
+        }
+
+        const char * data8 = reinterpret_cast(data);
+        const int id = gguf_find_key(gguf_ctx, key.c_str());
+
+        if (type == GGUF_TYPE_STRING) {
+            const char * str = gguf_get_val_str(gguf_ctx, id);
+            const uint64_t n = strlen(str);
+            const uint64_t n_expected = rng() % sizeof(data);
+            if (n != n_expected) {
+                ok = false;
+                continue;
+            }
+            if (!std::equal(str, str + n, data8)) {
+                ok = false;
+            }
+            continue;
+        }
+
+        if (type == GGUF_TYPE_ARRAY) {
+            const size_t type_size = gguf_type_size(type_arr);
+            const uint64_t arr_n = gguf_get_arr_n(gguf_ctx, id);
+
+            if (type_arr == GGUF_TYPE_STRING) {
+                const uint64_t nstr_expected = rng() % (16 + 1);
+                if (arr_n != nstr_expected) {
+                    ok = false;
+                    continue;
+                }
+                for (uint64_t istr = 0; istr < nstr_expected; ++istr) {
+                    const char * str = gguf_get_arr_str(gguf_ctx, id, istr);
+                    const uint64_t n = strlen(str);
+                    const uint64_t n_expected = rng() % (sizeof(uint32_t) + 1);
+
+                    if (n != n_expected) {
+                        ok = false;
+                        continue;
+                    }
+                    const char * str_expected = reinterpret_cast(&data[istr]);
+                    if (strncmp(str, str_expected, n) != 0) {
+                        ok = false;
+                        continue;
+                    }
+                }
+                continue;
+            }
+
+            const uint64_t arr_n_expected = (rng() % sizeof(data)) / type_size;
+            if (arr_n != arr_n_expected) {
+                ok = false;
+                continue;
+            }
+
+            const char * data_gguf = reinterpret_cast(gguf_get_arr_data(gguf_ctx, id));
+
+            if (type_arr == GGUF_TYPE_BOOL) {
+                for (size_t arr_i = 0; arr_i < arr_n; ++arr_i) {
+                    if (bool(data8[arr_i]) != bool(data_gguf[arr_i])) {
+                        ok = false;
+                    }
+                }
+                continue;
+            }
+
+            if (!std::equal(data8, data8 + arr_n*type_size, data_gguf)) {
+                ok = false;
+            }
+            continue;
+        }
+
+        const char * data_gguf = reinterpret_cast(gguf_get_val_data(gguf_ctx, id));
+
+        if (type == GGUF_TYPE_BOOL) {
+            if (bool(*data8) != bool(*data_gguf)) {
+                ok = false;
+            }
+            continue;
+        }
+
+        if (!std::equal(data8, data8 + gguf_type_size(type), data_gguf)) {
+            ok = false;
+        }
+    }
+
+    const uint32_t expected_alignment = alignment_defined ? 1 : GGUF_DEFAULT_ALIGNMENT;
+    if (gguf_get_alignment(gguf_ctx) != expected_alignment) {
+        ok = false;
+    }
+
+    return ok;
+}
+
+static bool handcrafted_check_tensors(const gguf_context * gguf_ctx, const unsigned int seed) {
+    if (!gguf_ctx) {
+        return false;
+    }
+
+    std::mt19937 rng(seed);
+
+    std::vector tensor_configs = get_tensor_configs(rng);
+
+    // Call get_kv_types to get the same RNG state:
+    get_kv_types(rng);
+
+    bool ok = true;
+
+    const int id_alignment = gguf_find_key(gguf_ctx, GGUF_KEY_GENERAL_ALIGNMENT);
+    const uint32_t alignment = id_alignment >= 0 ? gguf_get_val_u32(gguf_ctx, id_alignment) : GGUF_DEFAULT_ALIGNMENT;
+
+    uint64_t expected_offset = 0;
+    for (int i = 0; i < int(tensor_configs.size()); ++i) {
+        const ggml_type                          type  = tensor_configs[i].first;
+        const std::array shape = tensor_configs[i].second;
+
+        const std::string name = "my_tensor_" + std::to_string(i);
+        const int id = gguf_find_tensor(gguf_ctx, name.c_str());
+
+        if (id >= 0) {
+            if (std::string(gguf_get_tensor_name(gguf_ctx, id)) != name) {
+                ok = false;
+            }
+
+            if (gguf_get_tensor_type(gguf_ctx, id) != type) {
+                ok = false;
+            }
+        } else {
+            ok = false;
+            continue;
+        }
+
+        const size_t offset = gguf_get_tensor_offset(gguf_ctx, id);
+
+        if (offset != expected_offset) {
+            ok = false;
+        }
+
+        int64_t ne = shape[0];
+        for (size_t j = 1; j < GGML_MAX_DIMS; ++j) {
+            ne *= shape[j];
+        }
+        expected_offset += GGML_PAD(ggml_row_size(type, ne), alignment);
+    }
+
+    return ok;
+}
+
+static bool handcrafted_check_tensor_data(const gguf_context * gguf_ctx, const unsigned int seed, FILE * file) {
+    if (!gguf_ctx) {
+        return false;
+    }
+
+    std::mt19937 rng(seed);
+
+    std::vector tensor_configs = get_tensor_configs(rng);
+
+    bool ok = true;
+
+    for (int i = 0; i < int(tensor_configs.size()); ++i) {
+        const ggml_type                          type  = tensor_configs[i].first;
+        const std::array shape = tensor_configs[i].second;
+
+        int64_t ne = shape[0];
+        for (size_t j = 1; j < GGML_MAX_DIMS; ++j) {
+            ne *= shape[j];
+        }
+        const size_t size = ggml_row_size(type, ne);
+
+        const std::string name = "my_tensor_" + std::to_string(i);
+        const size_t offset = gguf_get_tensor_offset(gguf_ctx, gguf_find_tensor(gguf_ctx, name.c_str()));
+
+        std::vector data(size);
+        GGML_ASSERT(fseek(file, gguf_get_data_offset(gguf_ctx) + offset, SEEK_SET) == 0);
+        GGML_ASSERT(fread(data.data(), 1, data.size(), file) == data.size());
+
+        for (size_t j = 0; j < size; ++j) {
+            const uint8_t expected_byte = (j + offset) % 256;
+            if (data[j] != expected_byte) {
+                ok = false;
+            }
+        }
+    }
+
+    return ok;
+}
+
+static std::pair test_handcrafted_file(const unsigned int seed) {
+    int npass = 0;
+    int ntest = 0;
+
+    const std::vector hfts = {
+        HANDCRAFTED_HEADER_BAD_MAGIC,
+        HANDCRAFTED_HEADER_BAD_VERSION_1,
+        HANDCRAFTED_HEADER_BAD_VERSION_FUTURE,
+        HANDCRAFTED_HEADER_BAD_N_KV,
+        HANDCRAFTED_HEADER_BAD_N_TENSORS,
+        HANDCRAFTED_HEADER_EMPTY,
+
+        HANDCRAFTED_KV_BAD_KEY_SIZE,
+        HANDCRAFTED_KV_BAD_TYPE,
+        HANDCRAFTED_KV_DUPLICATE_KEY,
+        HANDCRAFTED_KV_BAD_ALIGN,
+        HANDCRAFTED_KV_SUCCESS,
+
+        HANDCRAFTED_TENSORS_BAD_NAME_SIZE,
+        HANDCRAFTED_TENSORS_BAD_N_DIMS,
+        HANDCRAFTED_TENSORS_BAD_SHAPE,
+        HANDCRAFTED_TENSORS_NE_TOO_BIG,
+        HANDCRAFTED_TENSORS_BAD_TYPE,
+        HANDCRAFTED_TENSORS_BAD_OFFSET,
+        HANDCRAFTED_TENSORS_DUPLICATE_NAME,
+        HANDCRAFTED_TENSORS_BAD_ALIGN,
+        HANDCRAFTED_TENSORS_INCONSISTENT_ALIGN,
+        HANDCRAFTED_TENSORS_SUCCESS,
+        HANDCRAFTED_TENSORS_CUSTOM_ALIGN,
+
+        HANDCRAFTED_DATA_NOT_ENOUGH_DATA,
+        HANDCRAFTED_DATA_BAD_ALIGN,
+        HANDCRAFTED_DATA_INCONSISTENT_ALIGN,
+        HANDCRAFTED_DATA_SUCCESS,
+        HANDCRAFTED_DATA_CUSTOM_ALIGN,
+    };
+
+    for (enum handcrafted_file_type hft : hfts) {
+        printf("%s: handcrafted_file_type=%s\n", __func__, handcrafted_file_type_name(hft).c_str());
+        FILE * file = get_handcrafted_file(seed, hft);
+
+#ifdef _WIN32
+        if (!file) {
+            printf("%s: failed to create tmpfile(), needs elevated privileges on Windows");
+            printf("%s: skipping tests");
+            continue;
+        }
+#else
+        GGML_ASSERT(file);
+#endif // _WIN32
+
+        struct ggml_context * ctx = nullptr;
+        struct gguf_init_params gguf_params = {
+            /*no_alloc =*/ false,
+            /*ctx      =*/ hft >= offset_has_data ? &ctx : nullptr,
+        };
+
+        struct gguf_context * gguf_ctx = gguf_init_from_file_impl(file, gguf_params);
+
+        if (expect_context_not_null(hft)) {
+            printf("%s:   - context_not_null: ", __func__);
+        } else {
+            printf("%s:   - context_null: ", __func__);
+        }
+        if (bool(gguf_ctx) == expect_context_not_null(hft)) {
+            printf("\033[1;32mOK\033[0m\n");
+            npass++;
+        } else {
+            printf("\033[1;31mFAIL\033[0m\n");
+        }
+        ntest++;
+
+        if (hft >= offset_has_data && !expect_context_not_null(hft)) {
+            printf("%s:   - no_dangling_ggml_context_pointer: ", __func__);
+            if (ctx) {
+                printf("\033[1;31mFAIL\033[0m\n");
+            } else {
+                printf("\033[1;32mOK\033[0m\n");
+                npass++;
+            }
+            ntest++;
+        }
+
+        const bool alignment_defined = hft == HANDCRAFTED_TENSORS_CUSTOM_ALIGN || hft == HANDCRAFTED_DATA_CUSTOM_ALIGN;
+
+        if (expect_context_not_null(hft)) {
+            printf("%s:   - check_header: ", __func__);
+            if (handcrafted_check_header(gguf_ctx, seed, hft >= offset_has_kv, hft >= offset_has_tensors, alignment_defined)) {
+                printf("\033[1;32mOK\033[0m\n");
+                npass++;
+            } else {
+                printf("\033[1;31mFAIL\033[0m\n");
+            }
+            ntest++;
+        }
+
+        if (expect_context_not_null(hft) && hft >= offset_has_kv) {
+            printf("%s:   - check_kv: ", __func__);
+            if (handcrafted_check_kv(gguf_ctx, seed, hft >= offset_has_tensors, alignment_defined)) {
+                printf("\033[1;32mOK\033[0m\n");
+                npass++;
+            } else {
+                printf("\033[1;31mFAIL\033[0m\n");
+            }
+            ntest++;
+        }
+
+        if (expect_context_not_null(hft) && hft >= offset_has_tensors) {
+            printf("%s:   - check_tensors: ", __func__);
+            if (handcrafted_check_tensors(gguf_ctx, seed)) {
+                printf("\033[1;32mOK\033[0m\n");
+                npass++;
+            } else {
+                printf("\033[1;31mFAIL\033[0m\n");
+            }
+            ntest++;
+        }
+
+        if (expect_context_not_null(hft) && hft >= offset_has_data) {
+            printf("%s:   - check_tensor_data: ", __func__);
+            if (handcrafted_check_tensor_data(gguf_ctx, seed, file)) {
+                printf("\033[1;32mOK\033[0m\n");
+                npass++;
+            } else {
+                printf("\033[1;31mFAIL\033[0m\n");
+            }
+            ntest++;
+        }
+
+        fclose(file);
+        if (gguf_ctx) {
+            ggml_free(ctx);
+            gguf_free(gguf_ctx);
+        }
+        printf("\n");
+    }
+
+
+    return std::make_pair(npass, ntest);
+}
+
+struct random_gguf_context_result {
+    struct gguf_context * gguf_ctx;
+    struct ggml_context * ctx;
+    ggml_backend_buffer_t buffer;
+};
+
+static struct random_gguf_context_result get_random_gguf_context(ggml_backend_t backend, const unsigned int seed) {
+    std::mt19937 rng(seed);
+
+    struct gguf_context * gguf_ctx = gguf_init_empty();
+
+    for (int i = 0; i < 256; ++i) {
+        const std::string key = "my_key_" + std::to_string(rng() % 1024);
+        const enum gguf_type type = gguf_type(rng() % GGUF_TYPE_COUNT);
+
+        switch (type) {
+            case GGUF_TYPE_UINT8:   gguf_set_val_u8  (gguf_ctx, key.c_str(), rng() % (1 <<  7));             break;
+            case GGUF_TYPE_INT8:    gguf_set_val_i8  (gguf_ctx, key.c_str(), rng() % (1 <<  7) - (1 <<  6)); break;
+            case GGUF_TYPE_UINT16:  gguf_set_val_u16 (gguf_ctx, key.c_str(), rng() % (1 << 15));             break;
+            case GGUF_TYPE_INT16:   gguf_set_val_i16 (gguf_ctx, key.c_str(), rng() % (1 << 15) - (1 << 14)); break;
+            case GGUF_TYPE_UINT32:  gguf_set_val_u32 (gguf_ctx, key.c_str(), rng());                         break;
+            case GGUF_TYPE_INT32:   gguf_set_val_i32 (gguf_ctx, key.c_str(), rng()             - (1 << 30)); break;
+            case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (gguf_ctx, key.c_str(), rng() % 1024      - 512);       break;
+            case GGUF_TYPE_BOOL:    gguf_set_val_bool(gguf_ctx, key.c_str(), rng() % 2 == 0);                break;
+            case GGUF_TYPE_STRING:  gguf_set_val_str (gguf_ctx, key.c_str(), std::to_string(rng()).c_str()); break;
+            case GGUF_TYPE_UINT64:  gguf_set_val_u64 (gguf_ctx, key.c_str(), rng());                         break;
+            case GGUF_TYPE_INT64:   gguf_set_val_i64 (gguf_ctx, key.c_str(), rng()             - (1 << 30)); break;
+            case GGUF_TYPE_FLOAT64: gguf_set_val_f32 (gguf_ctx, key.c_str(), rng() % 1024      - 512);       break;
+            case GGUF_TYPE_ARRAY: {
+                const enum gguf_type type_arr = gguf_type(rng() % GGUF_TYPE_COUNT);
+                const uint64_t ne = rng() % 1024;
+
+                switch (type_arr) {
+                    case GGUF_TYPE_UINT8:
+                    case GGUF_TYPE_INT8:
+                    case GGUF_TYPE_UINT16:
+                    case GGUF_TYPE_INT16:
+                    case GGUF_TYPE_UINT32:
+                    case GGUF_TYPE_INT32:
+                    case GGUF_TYPE_FLOAT32:
+                    case GGUF_TYPE_BOOL:
+                    case GGUF_TYPE_UINT64:
+                    case GGUF_TYPE_INT64:
+                    case GGUF_TYPE_FLOAT64: {
+                        const size_t nbytes = ne*gguf_type_size(type_arr);
+                        std::vector random_data((nbytes + sizeof(uint32_t) - 1) / sizeof(uint32_t));
+                        for (size_t j = 0; j < random_data.size(); ++j) {
+                            random_data[j] = rng();
+                            if (type_arr == GGUF_TYPE_BOOL) {
+                                random_data[j] &= 0x01010101; // the sanitizer complains if booleans are not 0 or 1
+                            }
+                        }
+                        gguf_set_arr_data(gguf_ctx, key.c_str(), type_arr, random_data.data(), ne);
+                    } break;
+                    case GGUF_TYPE_STRING: {
+                        std::vector  data_cpp(ne);
+                        std::vector data_c(ne);
+                        for (size_t j = 0; j < data_cpp.size(); ++j) {
+                            data_cpp[j] = std::to_string(rng());
+                            data_c[j]   = data_cpp[j].c_str();
+                        }
+                        gguf_set_arr_str(gguf_ctx, key.c_str(), data_c.data(), ne);
+                    } break;
+                    case GGUF_TYPE_ARRAY: {
+                        break; // not supported
+                    }
+                    case GGUF_TYPE_COUNT:
+                    default: {
+                        GGML_ABORT("fatal error");
+                    }
+                }
+            } break;
+            case GGUF_TYPE_COUNT:
+            default: {
+                GGML_ABORT("fatal error");
+            }
+        }
+    }
+
+    struct ggml_init_params ggml_params = {
+        /*.mem_size   =*/ 256*ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
+    };
+    struct ggml_context * ctx = ggml_init(ggml_params);
+
+    for (int i = 0; i < 256; ++i) {
+        const std::string name = "my_tensor_" + std::to_string(i);
+        const enum ggml_type type = ggml_type(rng() % GGML_TYPE_COUNT);
+        const size_t type_size = ggml_type_size(type);
+
+        if (type_size == 0) {
+            continue;
+        }
+
+        const int n_dims = 1 + rng() % GGML_MAX_DIMS;
+        int64_t ne[GGML_MAX_DIMS];
+        ne[0] = (1 + rng() % 10) * ggml_blck_size(type);
+        for (int j = 1; j < n_dims; ++j) {
+            ne[j] = 1 + rng() % 10;
+        }
+
+        struct ggml_tensor * tensor = ggml_new_tensor(ctx, type, n_dims, ne);
+        ggml_set_name(tensor, name.c_str());
+    }
+
+    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
+    for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
+        const size_t nbytes = ggml_nbytes(t);
+        std::vector random_data((nbytes + sizeof(uint32_t) - 1) / sizeof(uint32_t));
+        for (size_t j = 0; j < random_data.size(); ++j) {
+            random_data[j] = rng();
+        }
+        ggml_backend_tensor_set(t, random_data.data(), 0, nbytes);
+
+        gguf_add_tensor(gguf_ctx, t);
+    }
+
+    return {gguf_ctx, ctx, buf};
+}
+
+static bool all_kv_in_other(const gguf_context * ctx, const gguf_context * other) {
+    bool ok = true;
+
+    const int n_kv = gguf_get_n_kv(ctx);
+    for (int id = 0; id < n_kv; ++id) {
+        const char * name = gguf_get_key(ctx, id);
+
+        const int idx_other = gguf_find_key(other, name);
+        if (idx_other < 0) {
+            ok = false;
+            continue;
+        }
+
+        const gguf_type type = gguf_get_kv_type(ctx, id);
+        if (type != gguf_get_kv_type(other, idx_other)) {
+            ok = false;
+            continue;
+        }
+
+        if (type == GGUF_TYPE_ARRAY) {
+            const size_t arr_n = gguf_get_arr_n(ctx, id);
+            if (arr_n != gguf_get_arr_n(other, idx_other)) {
+                ok = false;
+                continue;
+            }
+
+            const gguf_type type_arr = gguf_get_arr_type(ctx, id);
+            if (type_arr != gguf_get_arr_type(other, idx_other)) {
+                ok = false;
+                continue;
+            }
+
+            if (type_arr == GGUF_TYPE_BOOL) {
+                const int8_t * data       = reinterpret_cast(gguf_get_arr_data(ctx,   id));
+                const int8_t * data_other = reinterpret_cast(gguf_get_arr_data(other, idx_other));
+                for (size_t arr_i = 0; arr_i < arr_n; ++arr_i) {
+                    if (bool(data[arr_i]) != bool(data_other[arr_i])) {
+                        ok = false;
+                    }
+                }
+                continue;
+            }
+
+            if (type_arr == GGUF_TYPE_STRING) {
+                for (size_t arr_i = 0; arr_i < arr_n; ++arr_i) {
+                    const std::string str       = gguf_get_arr_str(ctx,   id,       arr_i);
+                    const std::string str_other = gguf_get_arr_str(other, idx_other, arr_i);
+                    if (str != str_other) {
+                        ok = false;
+                    }
+                }
+                continue;
+            }
+
+            const int8_t * data       = reinterpret_cast(gguf_get_arr_data(ctx,   id));
+            const int8_t * data_other = reinterpret_cast(gguf_get_arr_data(other, idx_other));
+            if (!std::equal(data, data + arr_n*gguf_type_size(type_arr), data_other)) {
+                ok = false;
+            }
+            continue;
+        }
+
+        if (type == GGUF_TYPE_STRING) {
+            const std::string str       = gguf_get_val_str(ctx,   id);
+            const std::string str_other = gguf_get_val_str(other, idx_other);
+            if (str != str_other) {
+                ok = false;
+            }
+            continue;
+        }
+
+        const char * data       = reinterpret_cast(gguf_get_val_data(ctx,   id));
+        const char * data_other = reinterpret_cast(gguf_get_val_data(other, idx_other));
+        if (!std::equal(data, data + gguf_type_size(type), data_other)) {
+            ok = false;
+        }
+    }
+
+    return ok;
+}
+
+static bool all_tensors_in_other(const gguf_context * ctx, const gguf_context * other) {
+    bool ok = true;
+
+    const int n_tensors = gguf_get_n_tensors(ctx);
+    for (int id = 0; id < n_tensors; ++id) {
+        const std::string name = gguf_get_tensor_name(ctx, id);
+
+        const int idx_other = gguf_find_tensor(other, name.c_str());
+        if (id != idx_other) {
+            ok = false;
+            if (idx_other < 0) {
+                continue;
+            }
+        }
+
+        const ggml_type type = gguf_get_tensor_type(ctx, id);
+        if (type != gguf_get_tensor_type(other, id)) {
+            ok = false;
+        }
+
+        const size_t offset = gguf_get_tensor_offset(ctx, id);
+        if (offset != gguf_get_tensor_offset(other, id)) {
+            ok = false;
+        }
+    }
+
+    return ok;
+}
+
+static bool same_tensor_data(const struct ggml_context * orig, const struct ggml_context * read) {
+    bool ok = true;
+
+    struct ggml_tensor * t_orig = ggml_get_first_tensor(orig);
+    struct ggml_tensor * t_read = ggml_get_first_tensor(read);
+
+    if (std::string(t_read->name) != "GGUF tensor data binary blob") {
+        return false;
+    }
+    t_read = ggml_get_next_tensor(read, t_read);
+
+    while (t_orig) {
+        if (!t_read) {
+            ok = false;
+            break;
+        }
+
+        const size_t nbytes = ggml_nbytes(t_orig);
+        if (ggml_nbytes(t_read) != nbytes) {
+            ok = false;
+            break;
+        }
+        std::vector data_orig(nbytes);
+        ggml_backend_tensor_get(t_orig, data_orig.data(), 0, nbytes);
+        if (!std::equal(data_orig.data(), data_orig.data() + nbytes, reinterpret_cast(t_read->data))) {
+            ok = false;
+        }
+
+        t_orig = ggml_get_next_tensor(orig, t_orig);
+        t_read = ggml_get_next_tensor(read, t_read);
+    }
+    if (t_read) {
+        ok = false;
+    }
+
+    return ok;
+}
+
+static std::pair test_roundtrip(ggml_backend_dev_t dev, const unsigned int seed, const bool only_meta) {
+    ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+    printf("%s: device=%s, backend=%s, only_meta=%s\n",
+        __func__, ggml_backend_dev_description(dev), ggml_backend_name(backend), only_meta ? "yes" : "no");
+
+    int npass = 0;
+    int ntest = 0;
+
+    struct gguf_context * gguf_ctx_0;
+    struct ggml_context * ctx_0;
+    ggml_backend_buffer_t bbuf;
+    {
+        struct random_gguf_context_result result = get_random_gguf_context(backend, seed);
+        gguf_ctx_0 = result.gguf_ctx;
+        ctx_0      = result.ctx;
+        bbuf       = result.buffer;
+    }
+
+    FILE * file = tmpfile();
+
+#ifdef _WIN32
+    if (!file) {
+        printf("%s: failed to create tmpfile(), needs elevated privileges on Windows");
+        printf("%s: skipping tests");
+        return std::make_pair(0, 0);
+    }
+#else
+    GGML_ASSERT(file);
+#endif // _WIN32
+
+    {
+        std::vector buf;
+        gguf_write_to_buf(gguf_ctx_0, buf, only_meta);
+        GGML_ASSERT(fwrite(buf.data(), 1, buf.size(), file) == buf.size());
+        rewind(file);
+    }
+
+    struct ggml_context * ctx_1 = nullptr;
+    struct gguf_init_params gguf_params = {
+        /*no_alloc =*/ false,
+        /*ctx      =*/ only_meta ? nullptr : &ctx_1,
+    };
+    struct gguf_context * gguf_ctx_1 = gguf_init_from_file_impl(file, gguf_params);
+
+    printf("%s: same_version: ", __func__);
+    if (gguf_get_version(gguf_ctx_0) == gguf_get_version(gguf_ctx_1)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    printf("%s: same_n_kv: ", __func__);
+    if (gguf_get_n_kv(gguf_ctx_0) == gguf_get_n_kv(gguf_ctx_1)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    printf("%s: same_n_tensors: ", __func__);
+    if (gguf_get_n_tensors(gguf_ctx_0) == gguf_get_n_tensors(gguf_ctx_1)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    printf("%s: all_orig_kv_in_read: ", __func__);
+    if (all_kv_in_other(gguf_ctx_0, gguf_ctx_1)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    printf("%s: all_read_kv_in_orig: ", __func__);
+    if (all_kv_in_other(gguf_ctx_1, gguf_ctx_0)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    printf("%s: all_orig_tensors_in_read: ", __func__);
+    if (all_tensors_in_other(gguf_ctx_0, gguf_ctx_1)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    printf("%s: all_read_tensors_in_orig: ", __func__);
+    if (all_tensors_in_other(gguf_ctx_1, gguf_ctx_0)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    if (!only_meta) {
+        printf("%s: same_tensor_data: ", __func__);
+        if (same_tensor_data(ctx_0, ctx_1)) {
+            printf("\033[1;32mOK\033[0m\n");
+            npass++;
+        } else {
+            printf("\033[1;31mFAIL\033[0m\n");
+        }
+        ntest++;
+    }
+
+    ggml_backend_buffer_free(bbuf);
+    ggml_free(ctx_0);
+    ggml_free(ctx_1);
+    gguf_free(gguf_ctx_0);
+    gguf_free(gguf_ctx_1);
+    ggml_backend_free(backend);
+    fclose(file);
+
+    printf("\n");
+    return std::make_pair(npass, ntest);
+}
+
+static std::pair test_gguf_set_kv(ggml_backend_dev_t dev, const unsigned int seed) {
+    ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
+    printf("%s: device=%s, backend=%s\n", __func__, ggml_backend_dev_description(dev), ggml_backend_name(backend));
+
+    int npass = 0;
+    int ntest = 0;
+
+    struct gguf_context * gguf_ctx_0;
+    struct ggml_context * ctx_0;
+    ggml_backend_buffer_t bbuf_0;
+    {
+        struct random_gguf_context_result result = get_random_gguf_context(backend, seed);
+        gguf_ctx_0 = result.gguf_ctx;
+        ctx_0      = result.ctx;
+        bbuf_0     = result.buffer;
+    }
+
+    struct gguf_context * gguf_ctx_1;
+    struct ggml_context * ctx_1;
+    ggml_backend_buffer_t bbuf_1;
+    {
+        struct random_gguf_context_result result = get_random_gguf_context(backend, seed + 1);
+        gguf_ctx_1 = result.gguf_ctx;
+        ctx_1      = result.ctx;
+        bbuf_1     = result.buffer;
+    }
+
+    struct gguf_context * gguf_ctx_2 = gguf_init_empty();
+
+    gguf_set_kv(gguf_ctx_1, gguf_ctx_0);
+    gguf_set_kv(gguf_ctx_2, gguf_ctx_0);
+
+    printf("%s: same_n_kv: ", __func__);
+    if (gguf_get_n_kv(gguf_ctx_0) == gguf_get_n_kv(gguf_ctx_2)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    printf("%s: all_kv_0_in_1: ", __func__);
+    if (all_kv_in_other(gguf_ctx_0, gguf_ctx_1)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    printf("%s: all_kv_0_in_2: ", __func__);
+    if (all_kv_in_other(gguf_ctx_0, gguf_ctx_2)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    gguf_set_kv(gguf_ctx_0, gguf_ctx_1);
+
+    printf("%s: same_n_kv_after_double_copy: ", __func__);
+    if (gguf_get_n_kv(gguf_ctx_0) == gguf_get_n_kv(gguf_ctx_1)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    printf("%s: all_kv_1_in_0_after_double_copy: ", __func__);
+    if (all_kv_in_other(gguf_ctx_1, gguf_ctx_0)) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    ggml_backend_buffer_free(bbuf_0);
+    ggml_backend_buffer_free(bbuf_1);
+    ggml_free(ctx_0);
+    ggml_free(ctx_1);
+    gguf_free(gguf_ctx_0);
+    gguf_free(gguf_ctx_1);
+    gguf_free(gguf_ctx_2);
+    ggml_backend_free(backend);
+
+    printf("\n");
+    return std::make_pair(npass, ntest);
+}
+
+static void print_usage() {
+    printf("usage: test-gguf [seed]\n");
+    printf("  if no seed is unspecified then a random seed is used\n");
+}
+
+int main(int argc, char ** argv) {
+    if (argc > 2) {
+        print_usage();
+        return 1;
+    }
+
+    std::random_device rd;
+    const unsigned int seed = argc < 2 ? rd() : std::stoi(argv[1]);
+
+    // Initialize ggml backends early so the prints aren't interleaved with the test results:
+    ggml_backend_dev_count();
+    fprintf(stderr, "\n");
+
+    int npass = 0;
+    int ntest = 0;
+    {
+        std::pair result = test_handcrafted_file(seed);
+        npass += result.first;
+        ntest += result.second;
+    }
+
+    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+
+        for (bool only_meta : {true, false}) {
+            std::pair result = test_roundtrip(dev, seed, only_meta);
+            npass += result.first;
+            ntest += result.second;
+        }
+
+        {
+            std::pair result = test_gguf_set_kv(dev, seed);
+            npass += result.first;
+            ntest += result.second;
+        }
+    }
+
+    printf("%d/%d tests passed\n", npass, ntest);
+    if (npass != ntest) {
+        printf("\033[1;31mFAIL\033[0m\n");
+        return 1;
+    }
+    printf("\033[1;32mOK\033[0m\n");
+    return 0;
+}
diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp
deleted file mode 100644
index 1834c11d8..000000000
--- a/tests/test-grad0.cpp
+++ /dev/null
@@ -1,1679 +0,0 @@
-#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
-#include "ggml.h"
-
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-#include 
-
-#if defined(_MSC_VER)
-#pragma warning(disable: 4244 4267) // possible loss of data
-#endif
-
-#if defined(__GNUC__)
-#pragma GCC diagnostic ignored "-Wdouble-promotion"
-#endif
-
-#define MAX_NARGS 3
-
-#undef MIN
-#undef MAX
-#define MIN(a, b) ((a) < (b) ? (a) : (b))
-#define MAX(a, b) ((a) > (b) ? (a) : (b))
-
-#define GGML_SILU_FP16
-
-//
-// logging
-//
-
-#if (GGML_DEBUG >= 1)
-#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG(...)
-#endif
-
-#if (GGML_DEBUG >= 5)
-#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG_5(...)
-#endif
-
-#if (GGML_DEBUG >= 10)
-#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG_10(...)
-#endif
-
-#define GGML_PRINT(...) printf(__VA_ARGS__)
-
-static float frand(void) {
-    return (float)rand()/(float)RAND_MAX;
-}
-
-static int irand(int n) {
-    if (n == 0) return 0;
-    return rand()%n;
-}
-
-static void get_random_dims(int64_t * dims, int ndims) {
-    dims[0] = dims[1] = dims[2] = dims[3] = 1;
-
-    for (int i = 0; i < ndims; i++) {
-        dims[i] = 1 + irand(4);
-    }
-}
-
-static struct ggml_tensor * get_random_tensor_f32(
-        struct ggml_context * ctx0,
-        int ndims,
-        int64_t ne[],
-        float fmin,
-        float fmax) {
-    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne);
-
-    switch (ndims) {
-        case 1:
-            for (int i0 = 0; i0 < ne[0]; i0++) {
-                ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin;
-            }
-            break;
-        case 2:
-            for (int i1 = 0; i1 < ne[1]; i1++) {
-                for (int i0 = 0; i0 < ne[0]; i0++) {
-                    ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
-                }
-            }
-            break;
-        case 3:
-            for (int i2 = 0; i2 < ne[2]; i2++) {
-                for (int i1 = 0; i1 < ne[1]; i1++) {
-                    for (int i0 = 0; i0 < ne[0]; i0++) {
-                        ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
-                    }
-                }
-            }
-            break;
-        case 4:
-            for (int i3 = 0; i3 < ne[3]; i3++) {
-                for (int i2 = 0; i2 < ne[2]; i2++) {
-                    for (int i1 = 0; i1 < ne[1]; i1++) {
-                        for (int i0 = 0; i0 < ne[0]; i0++) {
-                            ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
-                        }
-                    }
-                }
-            }
-            break;
-        default:
-            assert(false);
-    }
-
-    return result;
-}
-
-static struct ggml_tensor * get_random_tensor_f16(
-        struct ggml_context * ctx0,
-        int ndims,
-        int64_t ne[],
-        float fmin,
-        float fmax) {
-    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F16, ndims, ne);
-
-    switch (ndims) {
-        case 1:
-            for (int i0 = 0; i0 < ne[0]; i0++) {
-                ((ggml_fp16_t *)result->data)[i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
-            }
-            break;
-        case 2:
-            for (int i1 = 0; i1 < ne[1]; i1++) {
-                for (int i0 = 0; i0 < ne[0]; i0++) {
-                    ((ggml_fp16_t *)result->data)[i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
-                }
-            }
-            break;
-        case 3:
-            for (int i2 = 0; i2 < ne[2]; i2++) {
-                for (int i1 = 0; i1 < ne[1]; i1++) {
-                    for (int i0 = 0; i0 < ne[0]; i0++) {
-                        ((ggml_fp16_t *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
-                    }
-                }
-            }
-            break;
-        case 4:
-            for (int i3 = 0; i3 < ne[3]; i3++) {
-                for (int i2 = 0; i2 < ne[2]; i2++) {
-                    for (int i1 = 0; i1 < ne[1]; i1++) {
-                        for (int i0 = 0; i0 < ne[0]; i0++) {
-                            ((ggml_fp16_t *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
-                        }
-                    }
-                }
-            }
-            break;
-        default:
-            assert(false);
-    }
-
-    return result;
-}
-
-static struct ggml_tensor * get_random_tensor_i32(
-        struct ggml_context * ctx0,
-        int ndims,
-        int64_t ne[],
-        int32_t imin,
-        int32_t imax) {
-    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_I32, ndims, ne);
-
-    switch (ndims) {
-        case 1:
-            for (int i0 = 0; i0 < ne[0]; i0++) {
-                ((int32_t *)result->data)[i0] = irand(imax - imin) + imin;
-            }
-            break;
-        case 2:
-            for (int i1 = 0; i1 < ne[1]; i1++) {
-                for (int i0 = 0; i0 < ne[0]; i0++) {
-                    ((int32_t *)result->data)[i1*ne[0] + i0] = irand(imax - imin) + imin;
-                }
-            }
-            break;
-        case 3:
-            for (int i2 = 0; i2 < ne[2]; i2++) {
-                for (int i1 = 0; i1 < ne[1]; i1++) {
-                    for (int i0 = 0; i0 < ne[0]; i0++) {
-                        ((int32_t *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = irand(imax - imin) + imin;
-                    }
-                }
-            }
-            break;
-        case 4:
-            for (int i3 = 0; i3 < ne[3]; i3++) {
-                for (int i2 = 0; i2 < ne[2]; i2++) {
-                    for (int i1 = 0; i1 < ne[1]; i1++) {
-                        for (int i0 = 0; i0 < ne[0]; i0++) {
-                            ((int32_t *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = irand(imax - imin) + imin;
-                        }
-                    }
-                }
-            }
-            break;
-        default:
-            assert(false);
-    }
-
-    return result;
-}
-
-static bool check_gradient(
-        const char * op_name,
-        struct ggml_context * ctx0,
-        struct ggml_tensor * x[],
-        struct ggml_tensor * f,
-        int ndims,
-        int nargs,
-        float eps,
-        float max_error_abs,
-        float max_error_rel,
-        std::vector expected_vals) {
-
-    static int n_threads = -1;
-    if (n_threads < 0) {
-        n_threads = GGML_DEFAULT_N_THREADS;
-
-        const char *env = getenv("GGML_N_THREADS");
-        if (env) {
-            n_threads = atoi(env);
-        }
-
-        printf("GGML_N_THREADS = %d\n", n_threads);
-    }
-
-    struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
-    struct ggml_cgraph * gb = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true);
-    ggml_build_forward_expand(gf, f);
-    ggml_graph_cpy(gf, gb);
-    ggml_build_backward_expand(ctx0, gf, gb, false);
-
-    ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
-
-    ggml_graph_reset  (gf);
-    ggml_set_f32      (f->grad, 1.0f);
-
-    ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
-
-    // ggml_graph_dump_dot(gf, NULL, "test-grad0-forward.dot");
-    // ggml_graph_dump_dot(gb, gf,  "test-grad0-backward.dot");
-
-    for (int i = 0; i < nargs; ++i) {
-        bool all_g0_bad = true;
-        const int nelements = ggml_nelements(x[i]);
-        for (int k = 0; k < nelements; ++k) {
-            // Calculate gradient numerically:
-            const float x0 = ggml_get_f32_1d(x[i], k);
-            const float xm = x0 - eps;
-            const float xp = x0 + eps;
-            ggml_set_f32_1d(x[i], k, xp);
-
-            ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
-
-            const double f0 = ggml_get_f32_1d(f, 0);
-
-            ggml_set_f32_1d(x[i], k, xm);
-
-            ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
-
-            const double f1 = ggml_get_f32_1d(f, 0);
-            const double g0 = (f0 - f1)/(2.0*(double) eps);
-
-            // The numerical calculation of the gradient fails around noncontinuities (e.g. 0 for ReLU).
-            // In such cases, provide a vector of expected values and skip the comparison for failed calculations.
-            if (!expected_vals.empty()) {
-                bool matches_any = false;
-                for (const double & ev : expected_vals) {
-                    const double error_abs = std::fabs(g0 - ev);
-                    if (error_abs > max_error_abs) {
-                        continue;
-                    }
-                    const double error_rel = g0 != 0.0 ? fabs(g0 - ev)/fabs(g0) : 0.0;
-                    if (error_rel > max_error_rel) {
-                        continue;
-                    }
-                    matches_any = true;
-                    break;
-                }
-                if (!matches_any) {
-                    continue;
-                }
-            }
-            all_g0_bad = false;
-
-            ggml_set_f32_1d(x[i], k, x0);
-
-            // compute gradient using backward graph
-            ggml_graph_reset  (gf);
-            ggml_set_f32      (f->grad, 1.0f);
-
-            ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
-
-            const double g1 = ggml_get_f32_1d(x[i]->grad, k);
-
-            const double error_abs = fabs(g0 - g1);
-            const double error_rel = g0 != 0.0 ? fabs(g0 - g1)/fabs(g0) : 0.0;
-
-            if (error_abs > max_error_abs || error_rel > max_error_rel) {
-                printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n",
-                            op_name, ndims, i, k, x0, xm, xp, f0, f1, g0, g1, eps, error_abs, error_rel);
-                //assert(false);
-                return false;
-            }
-        }
-        if (all_g0_bad) {
-            printf("%s: numerical calculation of the gradient failed for all values\n", op_name);
-            return false;
-        }
-    }
-
-    return true;
-}
-
-// TODO: clean-up this ..
-static bool check_mat_mul(
-        const struct ggml_tensor * y,
-        const struct ggml_tensor * x0,
-        const struct ggml_tensor * x1) {
-    float * dst  = (float *) y->data;
-    float * src0 = (float *) x0->data;
-    float * src1 = (float *) x1->data;
-
-    const int nc = x0->ne[1];
-    const int nr = x1->ne[1];
-    const int nk = x0->ne[0];
-
-    GGML_PRINT_DEBUG("check_mat_mul: nc=%d, nr=%d, nk=%d\n", nc, nr, nk);
-
-    GGML_PRINT_DEBUG("x0:\n");
-    for (int j = 0; j < x0->ne[1]; ++j) {
-        for (int i = 0; i < x0->ne[0]; ++i) {
-            GGML_PRINT_DEBUG("%6.3f ", src0[j*nk + i]);
-        }
-        GGML_PRINT_DEBUG("\n");
-    }
-    GGML_PRINT_DEBUG("\n");
-
-    GGML_PRINT_DEBUG("x1:\n");
-    for (int j = 0; j < x1->ne[1]; ++j) {
-        for (int i = 0; i < x1->ne[0]; ++i) {
-            GGML_PRINT_DEBUG("%6.3f ", src1[j*nk + i]);
-        }
-        GGML_PRINT_DEBUG("\n");
-    }
-    GGML_PRINT_DEBUG("\n");
-
-    GGML_PRINT_DEBUG("y: n_dims = %d, (%lld, %lld)\n", y->n_dims, y->ne[0], y->ne[1]);
-    for (int j = 0; j < y->ne[1]; ++j) {
-        for (int i = 0; i < y->ne[0]; ++i) {
-            GGML_PRINT_DEBUG("%6.3f ", dst[j*nr + i]);
-        }
-        GGML_PRINT_DEBUG("\n");
-    }
-
-    for (int i = 0; i < nr; ++i) {
-        for (int j = 0; j < nc; ++j) {
-            float sum = 0.0f;
-
-            for (int k = 0; k < nk; ++k) {
-                sum += src0[j*nk + k]*src1[i*nk + k];
-            }
-
-            if (fabsf(dst[i*nc + j] - sum) > 1e-5f) {
-                fprintf(stderr, "check_mat_mul: dst[%d] = %f, sum = %f\n", i*nc + j, dst[i*nc + j], sum);
-                assert(false);
-                return false;
-            }
-        }
-    }
-
-    return true;
-}
-
-#define NUM_PERMUTATIONS (4*3*2*1)
-
-int main(int argc, const char ** argv) {
-    struct ggml_init_params params = {
-        /* .mem_size   = */ 256*1024*1024,
-        /* .mem_buffer = */ NULL,
-        /* .no_alloc   = */ false,
-    };
-
-    int64_t ne[4];
-
-    int all_permutations[4 * NUM_PERMUTATIONS];
-    {
-        int count = 0;
-        for (int ax0=0; ax0<4; ++ax0) {
-            for (int ax1=0; ax1<4; ++ax1) {
-                if (ax1 == ax0) continue;
-                for (int ax2=0; ax2<4; ++ax2) {
-                    if (ax2 == ax0) continue;
-                    if (ax2 == ax1) continue;
-                    for (int ax3=0; ax3<4; ++ax3) {
-                        if (ax3 == ax0) continue;
-                        if (ax3 == ax1) continue;
-                        if (ax3 == ax2) continue;
-                        assert(count < NUM_PERMUTATIONS);
-                        all_permutations[count*4+0] = ax0;
-                        all_permutations[count*4+1] = ax1;
-                        all_permutations[count*4+2] = ax2;
-                        all_permutations[count*4+3] = ax3;
-                        ++count;
-                    }
-                }
-            }
-        }
-    }
-
-    unsigned seed_iter = 1;
-
-    // original loop: 1000
-    int niter = 4;
-    const char *env = getenv("GGML_NLOOP");
-    if (env != NULL) {
-        niter = atoi(env);
-    }
-    if (argc > 1) {
-        niter = atoi(argv[1]);
-    }
-    for (int iter = 0; iter < niter; ++iter) {
-        srand(seed_iter);
-        seed_iter = rand();
-        unsigned seed = rand();
-
-        printf("test-grad0: iter:%d/%d\n", (iter+1), niter);
-        struct ggml_context * ctx0 = ggml_init(params);
-
-        get_random_dims(ne, 4);
-
-        struct ggml_tensor * x[MAX_NARGS];
-
-        // add f32
-        {
-            srand(seed);
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
-
-                check_gradient("add f32", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f, {});
-            }
-        }
-
-        // add f16
-        {
-            srand(seed);
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f16(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
-
-                check_gradient("add f16", ctx0, x, f, ndims, nargs, 1e-1f, 2e-1f, 2e-1f, {});
-            }
-        }
-
-        // sub
-        {
-            srand(seed);
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sub(ctx0, x[0], x[1]));
-
-                check_gradient("sub", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
-            }
-        }
-
-        // mul
-        {
-            srand(seed);
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_mul(ctx0, x[0], x[1]));
-
-                check_gradient("mul", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // div
-        {
-            srand(seed);
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, 0.5f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_div(ctx0, x[0], x[1]));
-
-                check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-1f, 1e-1f, {});
-            }
-        }
-
-        // sqr
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, x[0]));
-
-                check_gradient("sqr", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // sqrt
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqrt(ctx0, x[0]));
-
-                check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, 2e-2f, 1e-1f, {});
-            }
-        }
-
-        // log
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_log(ctx0, x[0]));
-
-                check_gradient("log", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f, {});
-            }
-        }
-
-        // sum
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, x[0]);
-
-                check_gradient("sum", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
-            }
-        }
-
-
-        // sum_rows
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sum_rows(ctx0, x[0])));
-
-                check_gradient("sum_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {});
-            }
-        }
-
-        // mean, not yet fully implemented
-        if(0)
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_mean(ctx0, x[0]));
-
-                check_gradient("mean", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
-            }
-        }
-
-        // argmax
-        if (0)
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_argmax(ctx0, x[0]));
-
-                check_gradient("argmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
-            }
-        }
-
-        // repeat
-        {
-            srand(seed);
-            int64_t ne2[4];
-            get_random_dims(ne2, 4);
-
-            ne2[0] = ne[0] * ne2[0];
-            ne2[1] = ne[1] * ne2[1];
-            ne2[2] = 1;
-            ne2[3] = 1;
-
-            const int nargs = 1;
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[1], ggml_repeat(ctx0, x[0], x[1]))));
-
-                check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {});
-            }
-        }
-
-        // repeat back
-        {
-            srand(seed);
-            int64_t ne2[4];
-            get_random_dims(ne2, 4);
-
-            ne2[0] = ne[0] * ne2[0];
-            ne2[1] = ne[1] * ne2[1];
-            ne2[2] = 1;
-            ne2[3] = 1;
-
-            const int nargs = 1;
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[0], ggml_repeat_back(ctx0, x[1], x[0]))));
-
-                check_gradient("repeat back", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {});
-            }
-        }
-
-        // abs
-        {
-           const int nargs = 1;
-
-           for (int ndims = 1; ndims <= 4; ++ndims) {
-               for (int i = 0; i < nargs; ++i) {
-                   x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                   ggml_set_param(ctx0, x[i]);
-               }
-
-               struct ggml_tensor * f = ggml_sum(ctx0, ggml_abs(ctx0, x[0]));
-
-               check_gradient("abs", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-3f, {-1.0, 1.0});
-           }
-        }
-
-        // sgn
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_sgn(ctx0, x[0]));
-
-                check_gradient("sgn", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {0.0});
-            }
-        }
-
-        // neg
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_neg(ctx0, x[0]));
-
-                check_gradient("neg", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
-            }
-        }
-
-        // step
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_step(ctx0, x[0]));
-
-                check_gradient("step", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {0.0});
-            }
-        }
-
-        // tanh, not yet fully implemented
-        if(0)
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_tanh(ctx0, x[0]));
-
-                check_gradient("tanh", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
-            }
-        }
-
-        // mul_mat
-        {
-            srand(seed);
-            const int nargs = 2;
-
-            for (int ndims = 2; ndims <= 4; ++ndims) {
-                int max_nrep = (ndims >= 3) ? 2 : 1;
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                for (int nrep2 = 1; nrep2 < max_nrep; ++nrep2) {
-                    for (int nrep3 = 1; nrep3 < max_nrep; ++nrep3) {
-                        {
-                            int64_t ne2[4];
-                            get_random_dims(ne2, 4);
-                            ne2[0] = ne[0];
-                            ne2[2] = nrep2 * ne[2];
-                            ne2[3] = nrep3 * ne[3];
-                            x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-                        }
-
-                        ggml_set_param(ctx0, x[0]);
-                        ggml_set_param(ctx0, x[1]);
-
-                        struct ggml_tensor * m = ggml_mul_mat(ctx0, x[1], x[0]);
-                        struct ggml_tensor * f = ggml_sum(ctx0, m);
-
-                        GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims);
-
-                        check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-                        if (ndims == 2) {
-                            // check_mat_mul does not support ndims > 2
-                            check_mat_mul(m, x[1], x[0]);
-                        }
-                    }
-                }
-            }
-        }
-
-        // elu, not yet fully implemented
-        if(0)
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_elu(ctx0, x[0]));
-
-                check_gradient("elu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
-            }
-        }
-
-        // relu
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_relu(ctx0, x[0]));
-
-                check_gradient("relu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {0.0, 1.0});
-            }
-        }
-
-        // gelu, not yet fully implemented
-        if(0)
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor* f = ggml_sum(ctx0, ggml_gelu(ctx0, x[0]));
-
-                check_gradient("gelu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {});
-            }
-        }
-
-        // silu
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_silu(ctx0, x[0]));
-
-#ifdef GGML_SILU_FP16
-                // due to GGML_SILU_FP16 the finite difference method will be slightly wrong -> increase error bounds.
-                check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 0.5, INFINITY, {});
-#else
-                check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-#endif
-            }
-        }
-
-        // rms_norm
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f));
-
-                check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY, {});
-            }
-        }
-
-        // scale
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-
-                const float s = -1.0f + 2.0f*frand();
-
-                ggml_set_param(ctx0, x[0]);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_scale(ctx0, x[0], s));
-
-                check_gradient("scale", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // cpy f32
-        {
-            srand(seed);
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-                // x[1] is overwritten by x[0], so the gradients don't propagate to x[1]
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
-
-                check_gradient("cpy f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // cpy f16
-        {
-            srand(seed);
-            const int nargs = 2;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                for (int i = 0; i < nargs; ++i) {
-                    x[i] = get_random_tensor_f16(ctx0, ndims, ne, -1.0f, 1.0f);
-                    ggml_set_param(ctx0, x[i]);
-                }
-                // x[1] is overwritten by x[0], so the gradients don't propagate to x[1]
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
-
-                check_gradient("cpy f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY, {});
-            }
-        }
-
-        // reshape (1d->nd)
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                int64_t ne2[4];
-                ne2[0] = 1;
-                ne2[1] = 1;
-                ne2[2] = 1;
-                ne2[3] = 1;
-                for (int i = 0; i < ndims; ++i) {
-                    ne2[0] *= ne[i];
-                }
-                x[0] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
-                x[1] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1]));
-                check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // reshape (nd->1d)
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            for (int ndims = 1; ndims <= 2; ++ndims) {
-                int64_t ne2[4];
-                ne2[0] = 1;
-                ne2[1] = 1;
-                ne2[2] = 1;
-                ne2[3] = 1;
-                for (int i = 0; i < ndims; ++i) {
-                    ne2[0] *= ne[i];
-                }
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1]));
-                check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // acc 1d
-        {
-            srand(seed);
-            int64_t ne2[4] = { 1, 1, 1, 1 };
-
-            const int nargs = 2;
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                get_random_dims(ne2, 1);
-                while ((ne2[0] > ne[0]) || (ne2[0] > ggml_nelements(x[0]))) {
-                    get_random_dims(ne2, 1);
-                }
-
-                x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[1]);
-
-                const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1]));
-                const int offset = irand(max_offset) * ggml_element_size(x[0]);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
-
-                check_gradient("acc 1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // acc 2d
-        {
-            srand(seed);
-            int64_t ne2[4]         = { 1, 1, 1, 1 };
-            int64_t max_offsets[4] = { 0, 0, 0, 0 };
-            int64_t offsets[4]     = { 0, 0, 0, 0 };
-
-            const int nargs = 2;
-            for (int ndims = 2; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                get_random_dims(ne2, 2);
-                while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[0]*ne2[1] > ggml_nelements(x[0]))) {
-                    get_random_dims(ne2, 2);
-                }
-
-                x[1] = get_random_tensor_f32(ctx0, 2, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[1]);
-
-                max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
-                max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]);
-                offsets[0] = irand(max_offsets[0]) * x[0]->nb[0];
-                offsets[1] = irand(max_offsets[1]) * x[0]->nb[1];
-                const int offset = offsets[0] + offsets[1];
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
-
-                check_gradient("acc 2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // acc 3d
-        {
-            srand(seed);
-            int64_t ne2[4]         = { 1, 1, 1, 1 };
-            int64_t max_offsets[4] = { 0, 0, 0, 0 };
-            int64_t offsets[4]     = { 0, 0, 0, 0 };
-
-            const int nargs = 2;
-            for (int ndims = 3; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                get_random_dims(ne2, 3);
-                while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[2] > ne[2]) || (ne2[0]*ne2[1]*ne2[2] > ggml_nelements(x[0]))) {
-                    get_random_dims(ne2, 3);
-                }
-
-                x[1] = get_random_tensor_f32(ctx0, 3, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[1]);
-
-                max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
-                max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]);
-                max_offsets[2] = MAX(0, x[0]->ne[2] - x[1]->ne[2]);
-                offsets[0] = irand(max_offsets[0]) * x[0]->nb[0];
-                offsets[1] = irand(max_offsets[1]) * x[0]->nb[1];
-                offsets[2] = irand(max_offsets[2]) * x[0]->nb[2];
-                const int offset = offsets[0] + offsets[1] + offsets[2];
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
-
-                check_gradient("acc 3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // acc 4d
-        {
-            srand(seed);
-            int64_t ne2[4]         = { 1, 1, 1, 1 };
-            int64_t max_offsets[4] = { 0, 0, 0, 0 };
-            int64_t offsets[4]     = { 0, 0, 0, 0 };
-
-            const int nargs = 2;
-            for (int ndims = 4; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                get_random_dims(ne2, 4);
-                while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[2] > ne[2]) || (ne2[3] > ne[3]) || (ne2[0]*ne2[1]*ne2[2]*ne2[3] > ggml_nelements(x[0]))) {
-                    get_random_dims(ne2, 4);
-                }
-
-                x[1] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[1]);
-
-                max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
-                max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]);
-                max_offsets[2] = MAX(0, x[0]->ne[2] - x[1]->ne[2]);
-                max_offsets[3] = MAX(0, x[0]->ne[3] - x[1]->ne[3]);
-                offsets[0] = irand(max_offsets[0]) * x[0]->nb[0];
-                offsets[1] = irand(max_offsets[1]) * x[0]->nb[1];
-                offsets[2] = irand(max_offsets[2]) * x[0]->nb[2];
-                offsets[3] = irand(max_offsets[3]) * x[0]->nb[3];
-                const int offset = offsets[0] + offsets[1] + offsets[2] + offsets[3];
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset));
-
-                check_gradient("acc 4d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // set_1d
-        {
-            srand(seed);
-            int64_t ne2[4];
-
-            const int nargs = 2;
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                get_random_dims(ne2, 1);
-                while ((ne2[0] > ne[0]) || (ne2[0] > ggml_nelements(x[0]))) {
-                    get_random_dims(ne2, 1);
-                }
-
-                x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[1]);
-
-                const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1]));
-                const int offset = irand(max_offset) * ggml_element_size(x[0]);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_1d(ctx0, x[0], x[1], offset));
-
-                check_gradient("set_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // set_2d
-        {
-            srand(seed);
-            int64_t ne2[4];
-            int64_t max_offsets[4] = { 0, 0, 0, 0 };
-            int64_t offsets[4]     = { 0, 0, 0, 0 };
-
-            const int nargs = 1;
-            for (int ndims = 2; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[0]);
-
-                get_random_dims(ne2, 2);
-                while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[0]*ne2[1] > ggml_nelements(x[0]))) {
-                    get_random_dims(ne2, 2);
-                }
-
-                x[1] = get_random_tensor_f32(ctx0, 2, ne2, -1.0f, 1.0f);
-                ggml_set_param(ctx0, x[1]);
-
-                max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
-                max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]);
-                offsets[0] = irand(max_offsets[0]) * x[0]->nb[0];
-                offsets[1] = irand(max_offsets[1]) * x[0]->nb[1];
-                const int offset = offsets[0] + offsets[1];
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_2d(ctx0, x[0], x[1], x[1]->nb[1], offset));
-
-                check_gradient("set_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // view_1d
-        {
-            srand(seed);
-            const int nargs = 1;
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-
-                ggml_set_param(ctx0, x[0]);
-
-                const int k0 = irand(ggml_nelements(x[0]));
-                const int k1 = irand(ggml_nelements(x[0]));
-                const int i0 = MIN(k0, k1);
-                const int i1 = MAX(k0, k1);
-
-                const int offset = i0 * sizeof(float);
-                const int nelem  = i1 - i0;
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_1d(ctx0, x[0], nelem, offset));
-
-                check_gradient("view_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // view_2d
-        {
-            srand(seed);
-            int64_t ne2[4];
-            int64_t nb2[4];
-
-            const int nargs = 1;
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-
-                get_random_dims(ne2, 2);
-                while (ne2[0]*ne2[1] > ggml_nelements(x[0])) {
-                    get_random_dims(ne2, 2);
-                }
-                const int count = ne2[0]*ne2[1];
-
-                nb2[0] = sizeof(float);
-                nb2[1] = nb2[0]*ne2[0];
-
-                ggml_set_param(ctx0, x[0]);
-
-                const int max_offset = ggml_nelements(x[0]) - count;
-                const int offset = irand(max_offset+1) * sizeof(float);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_2d(ctx0, x[0], ne2[0], ne2[1], nb2[1], offset));
-
-                check_gradient("view_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // view_3d
-        {
-            srand(seed);
-            int64_t ne2[4] = {1,1,1,1};
-            int64_t nb2[4] = {0,0,0,0};
-
-            const int nargs = 1;
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
-
-                get_random_dims(ne2, 3);
-                while (ne2[0]*ne2[1]*ne2[2] > ggml_nelements(x[0])) {
-                    get_random_dims(ne2, 3);
-                }
-                const int count = ne2[0]*ne2[1]*ne2[2];
-
-                nb2[0] = sizeof(float);
-                nb2[1] = nb2[0]*ne2[0];
-                nb2[2] = nb2[1]*ne2[1];
-
-                ggml_set_param(ctx0, x[0]);
-
-                const int max_offset = ggml_nelements(x[0]) - count;
-                const int offset = irand(max_offset+1) * sizeof(float);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_3d(ctx0, x[0], ne2[0], ne2[1], ne2[2], nb2[1], nb2[2], offset));
-
-                check_gradient("view_3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // permute
-        {
-            srand(seed);
-            int64_t ne2[4];
-
-            const int nargs = 1;
-            for (int ndims = 1; ndims <= 4; ++ndims)
-            {
-                // ggml_permute will set axes of dimensions below n_dims to 1.
-                // to make ggml_permute work correctly on all axes,
-                // the input tensor needs maximal n_dim of 4.
-                for (int i=0; i finite differences should not work
-                // instead use sum(log(soft_max()*(1-eps)+eps)); use eps to avoid log(0)
-                struct ggml_tensor * f = ggml_sum(ctx0,
-                                            ggml_log(ctx0,
-                                                ggml_add1(ctx0,
-                                                    ggml_scale(ctx0,
-                                                        ggml_soft_max(ctx0, x[0]),
-                                                        1.0f - eps),
-                                                    ggml_new_f32(ctx0, eps))));
-
-                check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY, {});
-                // NOTE: softmax forward is computed using f16 table lookup instead of using actual expf, but backward assumes actual expf.
-                // this may result in different gradients too finite differences.
-                // when this test reports errors, first try to replace the table lookup with actual expf and test again to see if just that was the cause.
-                // if only the table lookup causes gradients to differ this is acceptable.
-            }
-        }
-
-        // cross_entropy_loss
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            int64_t ne2[4];
-            get_random_dims(ne2, 4);
-
-            for (int ndims = 1; ndims <= 4; ++ndims) {
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-                x[1] = get_random_tensor_f32(ctx0, ndims, ne2, 0.0f, 1.0f);
-                // the second argument to cross_entropy_loss must sum up to 1 for each row
-                int nr = ggml_nrows(x[1]);
-                int nc = ggml_nelements(x[1]) / nr;
-                for (int ir = 0; ir < nr; ++ir) {
-                    float sum = 0;
-                    for (int ic = 0; ic < nc; ++ic) {
-                        sum += ((float *) x[1]->data)[ic + ir*nc];
-                    }
-                    for (int ic = 0; ic < nc; ++ic) {
-                        ((float *) x[1]->data)[ic + ir*nc] /= sum;
-                    }
-                }
-                ggml_set_param(ctx0, x[0]);
-
-                struct ggml_tensor * f = ggml_cross_entropy_loss(ctx0, x[0], x[1]);
-
-                check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // rope f32
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            int64_t ne2[4];
-            get_random_dims(ne2, 4);
-            ne2[0] += ne2[0] % 2;
-            int n_rot = ne2[0];
-
-            for (int ndims = 3; ndims <= 4; ++ndims) {
-                for (int mode = 0; mode < 4; ++mode) {
-                    for (int n_past = 1; n_past < ne2[2]; ++n_past) {
-                        x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
-
-                        struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]);
-                        for (int i = 0; i < ne2[2]; ++i) {
-                            ((int32_t *) p->data)[i] = n_past + i;
-                        }
-
-                        ggml_set_param(ctx0, x[0]);
-
-                        const bool skip_past = (mode & 1);
-                        if (skip_past) {
-                            // we have no past, so this would have to work on uninitialized memory.
-                            // we only test the gradients here;
-                            // skip_past should have no influence on gradient computation.
-                            // so when other modes work, we assume that this does as well.
-                            continue;
-                        }
-
-                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode));
-
-                        GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
-                        check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY, {});
-                    }
-                }
-            }
-        }
-
-        // rope f16
-        {
-            srand(seed);
-            const int nargs = 1;
-
-            int64_t ne2[4];
-            get_random_dims(ne2, 4);
-            ne2[0] += ne2[0] % 2;
-            int n_rot = ne2[0];
-
-            for (int ndims = 3; ndims <= 4; ++ndims) {
-                for (int mode = 0; mode < 4; ++mode) {
-                    for (int n_past = 1; n_past < ne2[2]; ++n_past) {
-                        x[0] = get_random_tensor_f16(ctx0, ndims, ne2, -1.0f, 1.0f);
-
-                        struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]);
-                        for (int i = 0; i < ne2[2]; ++i) {
-                            ((int32_t *) p->data)[i] = n_past + i;
-                        }
-
-                        ggml_set_param(ctx0, x[0]);
-
-                        const bool skip_past = (mode & 1);
-                        if (skip_past) {
-                            // we have no past, so this would have to work on uninitialized memory.
-                            // we only test the gradients here;
-                            // skip_past should have no influence on gradient computation.
-                            // so when other modes work, we assume that this does as well.
-                            continue;
-                        }
-
-                        struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode));
-
-                        GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
-                        check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY, {});
-                    }
-                }
-            }
-        }
-
-        // im2col f32
-        {
-            srand(seed);
-            const int nargs = 1;
-            const int ndims = 4;
-
-            for (const bool is_2D : {false, true}) {
-                int64_t ne0[ndims];
-                int64_t ne1[ndims];
-                get_random_dims(ne0, ndims);
-                get_random_dims(ne1, ndims);
-
-                // // Ensure that the output is not zero-sized:
-                ne1[0] += 8;
-                ne1[1] += 8;
-
-                if (is_2D) {
-                    ne1[2] = ne0[2];
-                } else {
-                    ne1[1] = ne0[1];
-                    ne0[3] = 1;
-                    ne1[3] = 1;
-                }
-
-                // The order of arguments is swapped because the first tensor is only used for its shape.
-                x[1] = get_random_tensor_f16(ctx0, ndims, ne0, -1.0f, 1.0f);
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne1, -1.0f, 1.0f);
-
-                ggml_set_param(ctx0, x[0]);
-
-                const int s0 =         1 + irand(2);
-                const int s1 = is_2D ? 1 + irand(2) : 0;
-                const int p0 =         0 + irand(2);
-                const int p1 = is_2D ? 0 + irand(2) : 0;
-                const int d0 =         1 + irand(2);
-                const int d1 = is_2D ? 1 + irand(2) : 0;
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_im2col(ctx0, x[1], x[0], s0, s1, p0, p1, d0, d1, is_2D, GGML_TYPE_F32));
-
-                GGML_PRINT_DEBUG("im2col f32: is_2D=%s, s0=%d, s1=%d, p0=%d, p1=%d, d0=%d, d1=%d\n", is_2D ? "yes" : "no", s0, s1, p0, p1, d0, d1);
-                check_gradient("im2col f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY, {});
-            }
-        }
-
-        // pool_2d f32
-        {
-            srand(seed);
-            const int nargs = 1;
-            const int ndims = 4;
-
-            for (const enum ggml_op_pool op : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
-                int64_t ne0[ndims];
-                get_random_dims(ne0, ndims);
-
-                ne0[0] += 8;
-                ne0[1] += 8;
-
-                x[0] = get_random_tensor_f32(ctx0, ndims, ne0, -1.0f, 1.0f);
-
-                ggml_set_param(ctx0, x[0]);
-
-                const int k0 = 2 + irand(2);
-                const int k1 = 2 + irand(2);
-                const int s0 = 2 + irand(2);
-                const int s1 = 2 + irand(2);
-                const int p0 = 0 + irand(2);
-                const int p1 = 0 + irand(2);
-
-                struct ggml_tensor * f = ggml_sum(ctx0, ggml_pool_2d(ctx0, x[0], op, k0, k1, s0, s1, p0, p1));
-
-                GGML_PRINT_DEBUG("ggml_pool_2d f32: op=%s k0=%d, k1=%d, s0=%d, s1=%d, p0=%d, p1=%d\n",
-                                 op == GGML_OP_POOL_MAX ? "max" : "avg", k0, k1, s0, s1, p0, p1);
-                std::vector expected_vals;
-                if (op == GGML_OP_POOL_MAX) {
-                    expected_vals.push_back(0.0);
-                    expected_vals.push_back(1.0);
-                }
-                check_gradient("ggml_pool_2d f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, expected_vals);
-            }
-        }
-
-        // flash_attn f32
-        // TODO: adapt to ggml_flash_attn_ext() changes
-        //{
-        //    srand(seed);
-        //    const int nargs = 3;
-
-        //    int64_t ne2[4];
-
-        //    get_random_dims(ne2, 4);
-        //    int64_t D = ne2[0];
-        //    int64_t N = ne2[1];
-        //    int64_t M = ne2[2] + N;
-        //    int64_t B = ne2[3];
-
-        //    for (int masked = 0; masked <= 1; ++masked) {
-        //        for (int ndims = 2; ndims <= 4; ++ndims) {
-        //            int max_nrep = (ndims >= 3) ? 2 : 1;
-        //            for (int nrep = 1; nrep < max_nrep; ++nrep) {
-        //                int64_t neq[4] = { D, N, B*nrep, ne[3] };
-        //                int64_t nek[4] = { D, M, B, ne[3] };
-        //                int64_t nev[4] = { M, D, B, ne[3] };
-        //                if (ndims == 2) {
-        //                    neq[2] = 1; neq[3] = 1;
-        //                    nek[2] = 1; nek[3] = 1;
-        //                    nev[2] = 1; nev[3] = 1;
-        //                } else if (ndims == 3) {
-        //                    neq[3] = 1;
-        //                    nek[3] = 1;
-        //                    nev[3] = 1;
-        //                }
-        //                x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
-        //                x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
-        //                x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
-        //                ggml_set_param(ctx0, x[0]);
-        //                ggml_set_param(ctx0, x[1]);
-        //                ggml_set_param(ctx0, x[2]);
-
-        //                struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
-
-        //                check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY, {});
-        //            }
-        //        }
-        //    }
-        //}
-
-        ggml_free(ctx0);
-    }
-
-    return 0;
-}
diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp
index 5cc0cdb04..288e08f51 100644
--- a/tests/test-grammar-integration.cpp
+++ b/tests/test-grammar-integration.cpp
@@ -13,7 +13,7 @@
 using json = nlohmann::ordered_json;
 
 static llama_grammar * build_grammar(const std::string & grammar_str) {
-    return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
+    return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0);
 }
 
 static bool test_build_grammar_fails(const std::string & grammar_str) {
@@ -32,13 +32,10 @@ static bool test_build_grammar_fails(const std::string & grammar_str) {
 static bool match_string(const std::string & input, llama_grammar * grammar) {
     const auto cpts = unicode_cpts_from_utf8(input);
 
-    const llama_grammar_rules  & rules      = llama_grammar_get_rules (grammar);
-          llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
+    auto & stacks_cur = llama_grammar_get_stacks(grammar);
 
     for (const auto & cpt : cpts) {
-        const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
-
-        llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
+        llama_grammar_accept(grammar, cpt);
 
         if (stacks_cur.empty()) {
             // no stacks means that the grammar failed to match at this point
@@ -63,7 +60,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
     auto * grammar = build_grammar(grammar_str);
 
     // Save the original grammar stacks so that we can reset after every new string we want to test
-    const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar);
+    const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar); // copy
 
     llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
 
diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp
index 3a89598a8..9d2db91f5 100755
--- a/tests/test-json-schema-to-grammar.cpp
+++ b/tests/test-json-schema-to-grammar.cpp
@@ -696,7 +696,7 @@ static void test_all(const std::string & lang, std::function grammar_rules(parsed_grammar.c_rules());
 
-    grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
-    if (grammar == nullptr)
-    {
+    llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+    if (grammar == nullptr) {
         throw std::runtime_error("Failed to initialize llama_grammar");
     }
 
diff --git a/tests/test-log.cpp b/tests/test-log.cpp
new file mode 100644
index 000000000..306f28c61
--- /dev/null
+++ b/tests/test-log.cpp
@@ -0,0 +1,39 @@
+#include "log.h"
+
+#include 
+#include 
+
+int main() {
+    const int n_thread = 8;
+
+    std::thread threads[n_thread];
+    for (int i = 0; i < n_thread; i++) {
+        threads[i] = std::thread([i]() {
+            const int n_msg = 1000;
+
+            for (int j = 0; j < n_msg; j++) {
+                const int log_type = std::rand() % 4;
+
+                switch (log_type) {
+                    case 0: LOG_INF("Thread %d: %d\n", i, j); break;
+                    case 1: LOG_WRN("Thread %d: %d\n", i, j); break;
+                    case 2: LOG_ERR("Thread %d: %d\n", i, j); break;
+                    case 3: LOG_DBG("Thread %d: %d\n", i, j); break;
+                    default:
+                        break;
+                }
+
+                if (rand () % 10 < 5) {
+                    common_log_set_timestamps(common_log_main(), rand() % 2);
+                    common_log_set_prefix    (common_log_main(), rand() % 2);
+                }
+            }
+        });
+    }
+
+    for (int i = 0; i < n_thread; i++) {
+        threads[i].join();
+    }
+
+    return 0;
+}
diff --git a/tests/test-lora-conversion-inference.sh b/tests/test-lora-conversion-inference.sh
index fe90ce0d1..1d1f4886c 100755
--- a/tests/test-lora-conversion-inference.sh
+++ b/tests/test-lora-conversion-inference.sh
@@ -10,11 +10,16 @@ declare -a params=(
 
 MODELS_REPO=lora-tests
 MODELS_REPO_URL=https://huggingface.co/ggml-org/$MODELS_REPO
+COMMIT=c26d5fb85b4070a9e9c4e65d132c783b98086890
 
 # Clone the Hugging Face repository if the directory does not exist
 if [ ! -d "$MODELS_REPO" ]; then
     echo "Cloning the Hugging Face repository..."
     git clone $MODELS_REPO_URL --depth 1
+    cd $MODELS_REPO
+    git fetch --depth=1 origin $COMMIT
+    git reset --hard $COMMIT
+    cd -
 else
     echo "Repository already exists. Skipping clone."
 fi
@@ -75,18 +80,18 @@ run_conversion_and_inference_lora() {
     # Run inference
     echo -e "\n\n---------------------------\n\n"
     echo "Running llama-cli without lora for $model_name with hidden_size $hidden_size..."
-    OUTPUT_BASE=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
+    OUTPUT_BASE=$(./llama-cli -no-cnv -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
         -p "$EXPECTED_BASE_FIRST_WORD" -n 50 --seed 42 --temp 0)
 
     echo -e "\n\n---------------------------\n\n"
     echo "Running llama-cli with hot lora for $model_name with hidden_size $hidden_size..."
-    OUTPUT_LORA_HOT=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
+    OUTPUT_LORA_HOT=$(./llama-cli -no-cnv -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
         --lora $MODELS_REPO/$model_name/hidden_size=$hidden_size/lora/Lora-F32-LoRA.gguf \
         -p "$EXPECTED_LORA_FIRST_WORD" -n 50 --seed 42 --temp 0)
 
     echo -e "\n\n---------------------------\n\n"
     echo "Running llama-cli with merged lora for $model_name with hidden_size $hidden_size..."
-    OUTPUT_LORA_MERGED=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32-lora-merged.gguf \
+    OUTPUT_LORA_MERGED=$(./llama-cli -no-cnv -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32-lora-merged.gguf \
         -p "$EXPECTED_LORA_FIRST_WORD" -n 50 --seed 42 --temp 0)
 
     # Remove any initial white space
diff --git a/tests/test-model-load-cancel.cpp b/tests/test-model-load-cancel.cpp
index 858535c3c..9095826fa 100644
--- a/tests/test-model-load-cancel.cpp
+++ b/tests/test-model-load-cancel.cpp
@@ -21,7 +21,7 @@ int main(int argc, char *argv[] ) {
         (void) ctx;
         return progress > 0.50;
     };
-    auto * model = llama_load_model_from_file(model_path, params);
+    auto * model = llama_model_load_from_file(model_path, params);
     llama_backend_free();
     return model == nullptr ? EXIT_SUCCESS : EXIT_FAILURE;
 }
diff --git a/tests/test-opt.cpp b/tests/test-opt.cpp
index 546ca230b..f90c92b4b 100644
--- a/tests/test-opt.cpp
+++ b/tests/test-opt.cpp
@@ -1,181 +1,892 @@
 #include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+#include "ggml-cpu.h"
+#include "ggml-opt.h"
 
 #include 
-#include 
-#include 
-#include 
+#include 
+#include 
+#include 
+#include 
+#include 
 
-#define MAX_NARGS 2
-
-#if defined(__GNUC__)
-#pragma GCC diagnostic ignored "-Wdouble-promotion"
-#endif
-
-//
-// logging
-//
-#define GGML_DEBUG 0
-#if (GGML_DEBUG >= 1)
-#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG(...)
-#endif
-
-#if (GGML_DEBUG >= 5)
-#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG_5(...)
-#endif
-
-#if (GGML_DEBUG >= 10)
-#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG_10(...)
-#endif
-
-#define GGML_PRINT(...) printf(__VA_ARGS__)
-
-
-static float frand(void) {
-    return (float)rand()/(float)RAND_MAX;
+static bool almost_equal(const double a, const double b, const double atol) {
+    return fabs(a - b) < atol;
 }
 
-static struct ggml_tensor * get_random_tensor(
-    struct ggml_context * ctx0, int ndims, int64_t ne[], float fmin, float fmax
-) {
-    struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne);
+constexpr int64_t ne_datapoint = 2;
+constexpr int64_t ne_label     = 1;
+constexpr int64_t ndata        = 6;
 
-    switch (ndims) {
-        case 1:
-            for (int i0 = 0; i0 < ne[0]; i0++) {
-                ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin;
+struct helper_ctx_data {
+    std::vector   datasets_supervised;
+    std::vector data_batch;
+    std::vector labels_batch;
+
+    ggml_opt_dataset_t       dataset_unsupervised;
+    struct ggml_context    * ctx_static;
+    struct ggml_context    * ctx_compute;
+    struct ggml_opt_params   opt_params;
+    ggml_opt_context_t       opt_ctx;
+    struct ggml_tensor     * inputs;
+    struct ggml_tensor     * weights;
+    struct ggml_tensor     * outputs;
+    ggml_backend_buffer_t    buf;
+    ggml_opt_result_t        result;
+    ggml_opt_result_t        result2;
+};
+
+// These default values make it easier to check optimization results vs. expected values.
+static ggml_opt_optimizer_params helper_get_test_opt_pars(void * userdata) {
+    ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(userdata);
+    result.adamw.alpha = 1.0f;
+    result.adamw.beta1 = 0.0f;
+    result.adamw.beta2 = 0.0f;
+    result.adamw.eps   = 0.0f;
+    return result;
+}
+
+static helper_ctx_data helper_get_ctx_data(
+        ggml_backend_sched_t    backend_sched,
+        ggml_backend_t          backend,
+        const bool              init_opt_ctx       = true,
+        const bool              optimizer_defaults = true,
+        int64_t                 nbatch_logical     = 1,
+        int64_t                 nbatch_physical    = 1,
+        enum ggml_opt_loss_type loss_type          = GGML_OPT_LOSS_TYPE_SUM) {
+    std::vector datasets(ndata);
+    for (int64_t ndata_shard = 1; ndata_shard <= ndata; ++ndata_shard) {
+        ggml_opt_dataset_t dataset = ggml_opt_dataset_init(ne_datapoint, ne_label, ndata, ndata_shard);
+
+        float * data   = ggml_get_data_f32(ggml_opt_dataset_data(  dataset));
+        float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset));
+
+        for (int64_t idata = 0; idata < ndata; ++idata) {
+            for (int64_t id = 0; id < ne_datapoint; ++id) {
+                data[  idata*ne_datapoint + id] =     16*idata + id;
             }
-            break;
-        case 2:
-            for (int i1 = 0; i1 < ne[1]; i1++) {
-                for (int i0 = 0; i0 < ne[0]; i0++) {
-                    ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
-                }
+            for (int64_t il = 0; il < ne_label;     ++il) {
+                labels[idata*ne_label     + il] = 16*(16*idata + il);
             }
-            break;
-        case 3:
-            for (int i2 = 0; i2 < ne[2]; i2++) {
-                for (int i1 = 0; i1 < ne[1]; i1++) {
-                    for (int i0 = 0; i0 < ne[0]; i0++) {
-                        ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+        }
+
+        datasets[ndata_shard-1] = dataset;
+    }
+
+    ggml_opt_dataset_t dataset_unsupervised = ggml_opt_dataset_init(1, 0, ndata, /*ndata_shard =*/ 1);
+
+    float * data = ggml_get_data_f32(ggml_opt_dataset_data(dataset_unsupervised));
+
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        data[idata] = idata;
+    }
+
+    struct ggml_context * ctx_static;
+    struct ggml_context * ctx_compute;
+    {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ (2*ndata + 2)*ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        ctx_static = ggml_init(params);
+    }
+    {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ GGML_DEFAULT_GRAPH_SIZE*ggml_tensor_overhead() + 3*ggml_graph_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        ctx_compute = ggml_init(params);
+    }
+
+    std::vector   data_batch(ndata);
+    std::vector labels_batch(ndata);
+    for (int64_t ndata_batch = 1; ndata_batch <= ndata; ++ndata_batch) {
+        data_batch[ndata_batch-1]   = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, ndata_batch*ne_datapoint);
+        labels_batch[ndata_batch-1] = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, ndata_batch*ne_label);
+    }
+
+    struct ggml_tensor * inputs = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, nbatch_physical);
+    ggml_set_name(inputs, "inputs");
+
+    struct ggml_tensor * weights = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
+    ggml_set_name(weights, "weights");
+    ggml_set_param(ctx_static, weights);
+
+    struct ggml_tensor * intermediary = ggml_add(ctx_compute, inputs, weights);
+
+    struct ggml_tensor * outputs = ggml_scale(ctx_compute, intermediary, 1.0f);
+    ggml_set_name(outputs, "outputs");
+
+    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx_static, backend);
+    const float w0 = float(ndata)/2;
+    ggml_backend_tensor_set(weights, &w0, 0, sizeof(float));
+
+    GGML_ASSERT(nbatch_logical % nbatch_physical == 0);
+    const int32_t opt_period = nbatch_logical / nbatch_physical;
+
+    struct ggml_opt_params opt_params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type);
+    opt_params.opt_period = opt_period;
+    if (!optimizer_defaults) {
+        opt_params.get_opt_pars = helper_get_test_opt_pars;
+    }
+    ggml_opt_context_t opt_ctx = init_opt_ctx ? ggml_opt_init(opt_params) : nullptr;
+
+    ggml_opt_result_t result  = ggml_opt_result_init();
+    ggml_opt_result_t result2 = ggml_opt_result_init();
+
+    return {datasets, data_batch, labels_batch, dataset_unsupervised, ctx_static, ctx_compute, opt_params, opt_ctx, inputs, weights, outputs, buf, result, result2};
+}
+
+static void helper_free_ctx_data(struct helper_ctx_data ctx_data) {
+    ggml_opt_result_free(ctx_data.result);
+    ggml_opt_result_free(ctx_data.result2);
+    ggml_opt_free(ctx_data.opt_ctx);
+    ggml_backend_buffer_free(ctx_data.buf);
+    ggml_free(ctx_data.ctx_static);
+    ggml_free(ctx_data.ctx_compute);
+    for (ggml_opt_dataset_t dataset : ctx_data.datasets_supervised) {
+        ggml_opt_dataset_free(dataset);
+    }
+    ggml_opt_dataset_free(ctx_data.dataset_unsupervised);
+}
+
+static void helper_after_test(
+        const char * func, const bool high_level, const std::string options,
+        const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {
+    printf("  %s(high_level=%s%s, subtest=%s): ",
+           func, high_level ? "yes" : "no", options.c_str(), subtest.c_str());
+    if (subtest_ok) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+}
+
+static std::pair test_dataset(ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool shuffle) {
+    int ntest = 0;
+    int npass = 0;
+
+    struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend);
+
+    for (int64_t ndata_shard = 1; ndata_shard <= ndata; ++ndata_shard) {
+        ggml_opt_dataset_t dataset = cd.datasets_supervised[ndata_shard-1];
+
+        if (shuffle) {
+            ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1);
+        }
+
+        for (int64_t ndata_batch = 1; ndata_batch <= ndata; ++ndata_batch) {
+            if (ndata_batch % ndata_shard != 0) {
+                continue;
+            }
+            bool subtest_ok = true;
+
+            struct ggml_tensor *   data_batch =   cd.data_batch[ndata_batch-1];
+            struct ggml_tensor * labels_batch = cd.labels_batch[ndata_batch-1];
+
+            std::vector   data(ggml_nelements(  data_batch));
+            std::vector labels(ggml_nelements(labels_batch));
+
+            std::vector idata_shuffled;
+            const int64_t nbatches = ndata / ndata_batch;
+            for (int64_t ibatch = 0; ibatch < nbatches; ++ibatch) {
+                ggml_opt_dataset_get_batch(dataset, data_batch, labels_batch, ibatch);
+
+                ggml_backend_tensor_get(  data_batch,   data.data(), 0, ggml_nbytes(  data_batch));
+                ggml_backend_tensor_get(labels_batch, labels.data(), 0, ggml_nbytes(labels_batch));
+
+                for (int64_t idata_batch = 0; idata_batch < ndata_batch; ++idata_batch) {
+                    const int64_t idata = ibatch*ndata_batch + idata_batch;
+                    const int64_t idata_found = data[idata_batch*ne_datapoint] / 16;
+                    subtest_ok = subtest_ok && (shuffle || idata_found == idata);
+                    idata_shuffled.push_back(idata_found);
+
+                    for (int64_t id = 0; id < ne_datapoint; ++id) {
+                        if (data[  idata_batch*ne_datapoint + id] != 16*idata_found + id) {
+                            subtest_ok = false;
+                        }
                     }
-                }
-            }
-            break;
-        case 4:
-            for (int i3 = 0; i3 < ne[3]; i3++) {
-                for (int i2 = 0; i2 < ne[2]; i2++) {
-                    for (int i1 = 0; i1 < ne[1]; i1++) {
-                        for (int i0 = 0; i0 < ne[0]; i0++) {
-                            ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin;
+                    for (int64_t il = 0; il < ne_label;     ++il) {
+                        if (labels[idata_batch*ne_label     + il] != 16*(16*idata_found + il)) {
+                            subtest_ok = false;
                         }
                     }
                 }
             }
-            break;
-        default:
-            assert(false);
+
+            if (!shuffle || ndata % ndata_batch == 0) {
+                const int ndata_max = (ndata / ndata_batch) * ndata_batch;
+
+                for (int64_t idata = 0; subtest_ok && idata < ndata_max; ++idata) {
+                    int ninstances = 0;
+                    for (int64_t id : idata_shuffled) {
+                        ninstances += id == idata;
+                    }
+                    if (ninstances != 1) {
+                        subtest_ok = false;
+                    }
+                }
+            }
+
+            printf("  %s(shuffle=%s, ndata_shard=%" PRId64 ", ndata_batch=%" PRId64 "): ",
+                   __func__, shuffle ? "yes" : "no", ndata_shard, ndata_batch);
+            if (subtest_ok) {
+                printf("\033[1;32mOK\033[0m\n");
+                npass++;
+            } else {
+                printf("\033[1;31mFAIL\033[0m\n");
+            }
+            ntest++;
+        }
     }
 
+    helper_free_ctx_data(cd);
+
+    return std::make_pair(npass, ntest);
+}
+
+static std::pair test_grad(ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
+    int ntest = 0;
+    int npass = 0;
+
+    struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false,
+    /*nbatch_logical =*/ 999999, /*nbatch_physical =*/ 1);
+
+    std::vector grad_history(ndata);
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        grad_history[idata] = NAN;
+    }
+
+    for (int idata = 0; idata < ndata; ++idata) {
+        const float idataf = idata;
+        ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
+        ggml_opt_forward_backward(cd.opt_ctx, cd.result);
+        ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, sizeof(float));
+    }
+
+    {
+        bool subtest_ok = true;
+        for (int idata = 0; idata < ndata; ++idata) {
+            if (grad_history[idata] != idata + 1) {
+                subtest_ok = false;
+            }
+        }
+        printf("  %s(): ", __func__);
+        if (subtest_ok) {
+            printf("\033[1;32mOK\033[0m\n");
+            npass++;
+        } else {
+            printf("\033[1;31mFAIL\033[0m\n");
+        }
+        ntest++;
+    }
+
+    helper_free_ctx_data(cd);
+
+    return std::make_pair(npass, ntest);
+}
+
+static void helper_after_test_forward_backward(
+        const char * func, const bool high_level, const bool shuffle,
+        const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {
+    std::string options = ", shuffle=";
+    options += shuffle ? "yes" : "no";
+    helper_after_test(func, high_level, options, subtest, subtest_ok, ntest, npass);
+}
+
+static std::pair test_forward_backward(
+        ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool high_level, const bool shuffle) {
+    int ntest = 0;
+    int npass = 0;
+
+    struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false);
+    struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx);
+
+    std::vector loss_history(ndata);
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        loss_history[idata] = NAN;
+    }
+
+    {
+        int64_t ndata;
+        ggml_opt_result_ndata(cd.result, &ndata);
+        double loss;
+        double loss_unc;
+        ggml_opt_result_loss(cd.result, &loss, &loss_unc);
+        double accuracy;
+        double accuracy_unc;
+        ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
+        const bool subtest_ok = ndata == 0 && loss == 0.0 && std::isnan(loss_unc) && std::isnan(accuracy) && std::isnan(accuracy_unc);
+        helper_after_test_forward_backward(__func__, high_level, shuffle, "results_initial", subtest_ok, ntest, npass);
+    }
+
+    if (high_level) {
+        ggml_opt_dataset_t dataset = cd.dataset_unsupervised;
+        if (shuffle) {
+            ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1);
+        }
+        ggml_opt_epoch(cd.opt_ctx, dataset, nullptr, cd.result, 0, nullptr, nullptr);
+    } else {
+        for (int idata = 0; idata < ndata; ++idata) {
+            const float idataf = idata;
+            ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
+            ggml_opt_forward(cd.opt_ctx, cd.result);
+            ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
+        }
+    }
+
+    {
+        float weights;
+        ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
+        const bool subtest_ok = weights == ndata/2;
+        helper_after_test_forward_backward(__func__, high_level, shuffle, "weights_after_forward", subtest_ok, ntest, npass);
+    }
+    {
+        int64_t ndata;
+        ggml_opt_result_ndata(cd.result, &ndata);
+        bool subtest_ok = ndata == 6;
+
+        double loss;
+        double loss_unc;
+        ggml_opt_result_loss(cd.result, &loss, &loss_unc);
+        subtest_ok = subtest_ok && loss == 33.0 && almost_equal(loss_unc, sqrt(3.5), 1e-10);
+
+        double accuracy;
+        double accuracy_unc;
+        ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
+        subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
+
+        helper_after_test_forward_backward(__func__, high_level, shuffle, "results_after_forward", subtest_ok, ntest, npass);
+    }
+
+    float w0;
+    ggml_backend_tensor_get(cd.weights, &w0, 0, sizeof(float));
+    for (int i = 0; i < 10; ++i) {
+        ggml_opt_forward_backward(cd.opt_ctx, nullptr);
+    }
+    ggml_backend_tensor_set(cd.weights, &w0, 0, sizeof(float));
+
+    ggml_opt_reset(cd.opt_ctx, /*optimizer =*/ false);
+    ggml_opt_result_reset(cd.result);
+
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        loss_history[idata] = NAN;
+    }
+
+    if (high_level) {
+        ggml_opt_dataset_t dataset = cd.dataset_unsupervised;
+        if (shuffle) {
+            ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1);
+        }
+        ggml_opt_epoch(cd.opt_ctx, dataset, cd.result, nullptr, ndata, nullptr, nullptr);
+    } else {
+        for (int idata = 0; idata < ndata; ++idata) {
+            const float idataf = idata;
+            ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
+            ggml_opt_forward_backward(cd.opt_ctx, cd.result);
+            ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
+        }
+    }
+
+    {
+        float weights;
+        ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
+        const bool subtest_ok = weights == -ndata/2;
+        helper_after_test_forward_backward(__func__, high_level, shuffle, "weights_after_forward_backward", subtest_ok, ntest, npass);
+    }
+    {
+        int64_t ndata;
+        ggml_opt_result_ndata(cd.result, &ndata);
+        bool subtest_ok = ndata == 6;
+
+        double loss;
+        double loss_unc;
+        ggml_opt_result_loss(cd.result, &loss, &loss_unc);
+        subtest_ok = subtest_ok && loss == 18.0 && (shuffle || loss_unc == 0.0);
+
+        double accuracy;
+        double accuracy_unc;
+        ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
+        subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
+
+        helper_after_test_forward_backward(__func__, high_level, shuffle, "result_after_forward_backward", subtest_ok, ntest, npass);
+    }
+
+    helper_free_ctx_data(cd);
+
+    return std::make_pair(npass, ntest);
+}
+
+static std::pair test_epoch_vs_fit(ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
+    int ntest = 0;
+    int npass = 0;
+
+    float weights_epoch;
+    float weights_fit;
+
+    {
+        struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true);
+        ggml_opt_dataset_t dataset = cd.dataset_unsupervised;
+
+        ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1);
+        ggml_opt_epoch(cd.opt_ctx, dataset, cd.result, nullptr, ndata, nullptr, nullptr);
+
+        ggml_backend_tensor_get(cd.weights, &weights_epoch, 0, ggml_nbytes(cd.weights));
+        helper_free_ctx_data(cd);
+    }
+    {
+        struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ false);
+        ggml_opt_dataset_t dataset = cd.dataset_unsupervised;
+
+        ggml_opt_fit(backend_sched, cd.ctx_compute, cd.inputs, cd.outputs, dataset,
+            GGML_OPT_LOSS_TYPE_SUM, ggml_opt_get_default_optimizer_params, 1, 1, 0.0f, true);
+
+        ggml_backend_tensor_get(cd.weights, &weights_fit, 0, ggml_nbytes(cd.weights));
+        helper_free_ctx_data(cd);
+    }
+
+    const bool subtest_ok = weights_epoch == weights_fit;
+
+    printf("  %s(): ", __func__);
+    if (subtest_ok) {
+        printf("\033[1;32mOK\033[0m\n");
+        npass++;
+    } else {
+        printf("\033[1;31mFAIL\033[0m\n");
+    }
+    ntest++;
+
+    return std::make_pair(npass, ntest);
+}
+
+static void helper_after_test_idata_split(
+        const char * func, const bool high_level, const int epoch,
+        const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {
+    std::string options = ", epoch=";
+    options += std::to_string(epoch);
+    helper_after_test(func, high_level, options, subtest, subtest_ok, ntest, npass);
+}
+
+static std::pair test_idata_split(ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool high_level) {
+    int ntest = 0;
+    int npass = 0;
+
+    struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false);
+    struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx);
+    const int idata_split = ndata * 2/3;
+
+    std::vector loss_history(ndata);
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        loss_history[idata] = NAN;
+    }
+
+    for (int epoch = 1; epoch <= 4; ++epoch) {
+        if (high_level) {
+            ggml_opt_epoch(cd.opt_ctx, cd.dataset_unsupervised, cd.result, cd.result2, idata_split, nullptr, nullptr);
+        } else {
+            int idata = 0;
+            for (; idata < idata_split; ++idata) {
+                const float idataf = idata;
+                ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
+                ggml_opt_forward_backward(cd.opt_ctx, cd.result);
+                ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
+            }
+            for (; idata < ndata; ++idata) {
+                const float idataf = idata;
+                ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
+                ggml_opt_forward(cd.opt_ctx, cd.result2);
+                ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
+            }
+        }
+
+        {
+            float weights;
+            ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
+            const bool subtest_ok = weights == ndata/2 - epoch*idata_split;
+            helper_after_test_idata_split(__func__, high_level, epoch, "weights", subtest_ok, ntest, npass);
+        }
+        {
+            int64_t ndata_result;
+            ggml_opt_result_ndata(cd.result, &ndata_result);
+            bool subtest_ok = ndata_result == idata_split;
+
+            double loss;
+            double loss_unc;
+            ggml_opt_result_loss(cd.result, &loss, &loss_unc);
+            subtest_ok = subtest_ok && loss == 28.0 - epoch*16.0 && loss_unc == 0.0;
+
+            double accuracy;
+            double accuracy_unc;
+            ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
+            subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
+
+            helper_after_test_idata_split(__func__, high_level, epoch, "results_backward", subtest_ok, ntest, npass);
+        }
+        {
+            int64_t ndata_result;
+            ggml_opt_result_ndata(cd.result2, &ndata_result);
+            bool subtest_ok = ndata_result == ndata - idata_split;
+
+            double loss;
+            double loss_unc;
+            ggml_opt_result_loss(cd.result2, &loss, &loss_unc);
+            subtest_ok = subtest_ok && loss == 15.0 - epoch*8 && almost_equal(loss_unc, sqrt(0.5), 1e-10);
+
+            double accuracy;
+            double accuracy_unc;
+            ggml_opt_result_accuracy(cd.result2, &accuracy, &accuracy_unc);
+            subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
+
+            helper_after_test_idata_split(__func__, high_level, epoch, "results_forward", subtest_ok, ntest, npass);
+        }
+
+        ggml_opt_result_reset(cd.result);
+        ggml_opt_result_reset(cd.result2);
+    }
+
+    helper_free_ctx_data(cd);
+
+    return std::make_pair(npass, ntest);
+}
+
+static void helper_after_test_gradient_accumulation(
+        const char * func, const int nbatch_physical, const enum ggml_opt_loss_type loss_type, const int epoch,
+        const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {
+    std::string options = ", nbatch_physical=";
+    options += std::to_string(nbatch_physical);
+    options += ", loss_type=";
+    options += loss_type == GGML_OPT_LOSS_TYPE_MEAN ? "mean" : "sum";
+    options += ", epoch=";
+    options += std::to_string(epoch);
+    helper_after_test(func, false, options, subtest, subtest_ok, ntest, npass);
+}
+
+static std::pair test_gradient_accumulation(
+        ggml_backend_sched_t backend_sched, ggml_backend_t backend, const int32_t nbatch_physical, const enum ggml_opt_loss_type loss_type) {
+    int ntest = 0;
+    int npass = 0;
+
+    struct helper_ctx_data cd = helper_get_ctx_data(
+        backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false, /*nbatch_logical =*/ 6, nbatch_physical, loss_type);
+    struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx);
+
+    std::vector grad_history(ndata);
+    for (int64_t idata = 0; idata < ndata; ++idata) {
+        grad_history[idata] = NAN;
+    }
+
+    for (int epoch = 1; epoch <= 4; ++epoch) {
+        if (nbatch_physical == 1) {
+            for (int idata = 0; idata < ndata; ++idata) {
+                const float idataf = idata;
+                ggml_backend_tensor_set(cd.inputs, &idataf, 0, 1*sizeof(float));
+                ggml_opt_forward_backward(cd.opt_ctx, cd.result);
+                ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, 1*sizeof(float));
+            }
+        } else if (nbatch_physical == 2) {
+            for (int idata = 0; idata < ndata; idata += 2) {
+                const float idataf[2] = {float(idata + 0), float(idata + 1)};
+                ggml_backend_tensor_set(cd.inputs, idataf, 0, 2*sizeof(float));
+                ggml_opt_forward_backward(cd.opt_ctx, cd.result);
+
+                grad_history[idata + 0] = 0.0f;
+                ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata + 1, 0, 1*sizeof(float));
+            }
+        } else {
+            GGML_ASSERT(false);
+        }
+
+        {
+            GGML_ASSERT(ndata == 6);
+            constexpr double atol = 1e-6;
+            bool subtest_ok = true;
+            if (loss_type == GGML_OPT_LOSS_TYPE_SUM) {
+                if (nbatch_physical == 1) {
+                    subtest_ok = subtest_ok && almost_equal(grad_history[0], 1.0, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[2], 3.0, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[4], 5.0, atol);
+                } else {
+                    subtest_ok = subtest_ok && almost_equal(grad_history[0], 0.0, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[2], 0.0, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[4], 0.0, atol);
+                }
+                subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0, atol);
+                subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0, atol);
+                subtest_ok = subtest_ok && almost_equal(grad_history[5], 0.0, atol);
+            } else if (loss_type == GGML_OPT_LOSS_TYPE_MEAN) {
+                if (nbatch_physical == 1) {
+                    subtest_ok = subtest_ok && almost_equal(grad_history[0], 1.0/ndata, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[2], 3.0/ndata, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[4], 5.0/ndata, atol);
+                } else {
+                    subtest_ok = subtest_ok && almost_equal(grad_history[0], 0.0/ndata, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[2], 0.0/ndata, atol);
+                    subtest_ok = subtest_ok && almost_equal(grad_history[4], 0.0/ndata, atol);
+                }
+                subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0/ndata, atol);
+                subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0/ndata, atol);
+                subtest_ok = subtest_ok && almost_equal(grad_history[5], 0.0/ndata, atol);
+            } else {
+                GGML_ASSERT(false);
+            }
+            helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "grads", subtest_ok, ntest, npass);
+        }
+        {
+            float weights;
+            ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
+            const bool subtest_ok = weights == (ndata/2) - epoch;
+            helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "weights", subtest_ok, ntest, npass);
+        }
+        {
+            int64_t ndata_result;
+            ggml_opt_result_ndata(cd.result, &ndata_result);
+            bool subtest_ok = ndata_result == ndata/nbatch_physical;
+
+            double loss;
+            ggml_opt_result_loss(cd.result, &loss, /*loss_unc =*/ nullptr);
+            if (loss_type == GGML_OPT_LOSS_TYPE_SUM) {
+                subtest_ok = subtest_ok && loss == (39.0 - epoch*6.0);
+            } else if (loss_type == GGML_OPT_LOSS_TYPE_MEAN) {
+                subtest_ok = subtest_ok && almost_equal(loss, (39.0 - epoch*6.0) / ndata, 1e-6);
+            } else {
+                GGML_ASSERT(false);
+            }
+
+            double accuracy;
+            double accuracy_unc;
+            ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
+            subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);
+
+            helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "results", subtest_ok, ntest, npass);
+        }
+
+        ggml_opt_result_reset(cd.result);
+    }
+
+    helper_free_ctx_data(cd);
+
+    return std::make_pair(npass, ntest);
+}
+
+static ggml_opt_optimizer_params helper_get_regression_opt_pars(void * userdata) {
+    ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(userdata);
+    result.adamw.alpha = 0.1f;
     return result;
 }
 
-int main(void) {
-    struct ggml_init_params params = {
-        /* .mem_size   = */ 1024*1024*1024,
-        /* .mem_buffer = */ NULL,
-        /* .no_alloc   = */ false,
-    };
+static std::pair test_regression(ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
+    int ntest = 0;
+    int npass = 0;
 
-    struct ggml_context * ctx = ggml_init(params);
+    // Test for simple regression with f(x) = a*x + b
 
-    int64_t ne1[4] = {4, 128, 1, 1};
-    int64_t ne2[4] = {4, 256, 1, 1};
-    int64_t ne3[4] = {128, 256, 1, 1};
+    constexpr int64_t ndata_regression = 201;
+    constexpr float a_true = 1.2f;
+    constexpr float b_true = 3.4f;
 
-    struct ggml_tensor * a = get_random_tensor(ctx, 2, ne1, -1, +1);
-    struct ggml_tensor * b = get_random_tensor(ctx, 2, ne2, -1, +1);
-    ggml_set_param(ctx, a);
-    ggml_set_param(ctx, b);
+    std::mt19937 gen(12345);
+    std::normal_distribution nd{0.0f, 0.1f};
 
-    struct ggml_tensor * c = get_random_tensor(ctx, 2, ne3, -1, +1);
+    ggml_opt_dataset_t dataset = ggml_opt_dataset_init(1, 1, ndata_regression, ndata_regression);
 
-    struct ggml_tensor * ab = ggml_mul_mat(ctx, a, b);
-    struct ggml_tensor * d  = ggml_sub(ctx, c, ab);
-    struct ggml_tensor * e  = ggml_sum(ctx, ggml_sqr(ctx, d));
+    float * data   = ggml_get_data_f32(ggml_opt_dataset_data(  dataset));
+    float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset));
 
-    struct ggml_cgraph * ge = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true);
-    ggml_build_forward_expand(ge, e);
-    ggml_graph_reset(ge);
+    constexpr float x_min = -100.0f;
+    constexpr float x_max =  100.0f;
 
-    ggml_graph_compute_with_ctx(ctx, ge, /*n_threads*/ 1);
+    for (int64_t idata = 0; idata < ndata_regression; ++idata) {
+        const float x = x_min + (x_max - x_min) * idata/(ndata_regression-1);
+        const float y = a_true*x + b_true + nd(gen);
 
-    const float fe = ggml_get_f32_1d(e, 0);
-    printf("%s: e = %.4f\n", __func__, fe);
+        data[idata]   = x;
+        labels[idata] = y;
+    }
 
-    struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
+    struct ggml_context * ctx_static;
+    struct ggml_context * ctx_compute;
+    {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ 3*ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        ctx_static = ggml_init(params);
+    }
+    {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ GGML_DEFAULT_GRAPH_SIZE*ggml_tensor_overhead() + 3*ggml_graph_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        ctx_compute = ggml_init(params);
+    }
 
-    ggml_opt(ctx, opt_params, e);
+    // The first dimension is the dimension of the datapoints, the second dimension is the number of datapoints.
+    struct ggml_tensor * x = ggml_new_tensor_2d(ctx_static, GGML_TYPE_F32, 1, ndata_regression);
+    ggml_set_name(x, "x");
 
-    ggml_graph_reset(ge);
+    struct ggml_tensor * a = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
+    ggml_set_name(a, "a");
+    ggml_set_param(ctx_static, a);
 
-    ggml_graph_compute_with_ctx(ctx, ge, /*n_threads*/ 1);
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
+    ggml_set_name(b, "b");
+    ggml_set_param(ctx_static, b);
 
-    const float fe_opt = ggml_get_f32_1d(e, 0);
-    printf("%s: original  e = %.4f\n", __func__, fe);
-    printf("%s: optimized e = %.4f\n", __func__, fe_opt);
+    struct ggml_tensor * f = ggml_add(ctx_compute, ggml_mul(ctx_compute, x, a), b);
+    ggml_set_name(f, "f");
+    ggml_set_param(ctx_static, f);
 
-    const bool success = (fe_opt <= fe);
-    assert(success);
+    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx_static, backend);
+    const float a0 = 1.0f;
+    const float b0 = 3.0f;
+    ggml_backend_tensor_set(a, &a0, 0, sizeof(float));
+    ggml_backend_tensor_set(b, &b0, 0, sizeof(float));
 
-    ggml_free(ctx);
-    return success ? 0 : -1;
+    ggml_opt_fit(backend_sched, ctx_compute, x, f, dataset, GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR,
+        helper_get_regression_opt_pars, 100, ndata_regression, 0.0f, true);
+
+    {
+        float a_fit;
+        ggml_backend_tensor_get(a, &a_fit, 0, sizeof(float));
+        float b_fit;
+        ggml_backend_tensor_get(b, &b_fit, 0, sizeof(float));
+        const bool subtest_ok = almost_equal(a_fit, a_true, 1e-2) && almost_equal(b_fit, b_true, 1e-2);
+        printf("  %s(subtest=weights): ", __func__);
+        if (subtest_ok) {
+            printf("\033[1;32mOK\033[0m\n");
+            npass++;
+        } else {
+            printf("\033[1;31mFAIL\033[0m\n");
+        }
+        ntest++;
+    }
+
+    ggml_backend_buffer_free(buf);
+    ggml_free(ctx_static);
+    ggml_opt_dataset_free(dataset);
+
+    return std::make_pair(npass, ntest);
 }
-// int64_t ne1[4] = {4, 128, 1, 1};
-// int64_t ne2[4] = {4, 256, 1, 1};;
-// int64_t ne3[4] = {128, 256, 1, 1};
-// main: original  e = 25890.9375
-// main: optimized e = 10094.7031
 
-// int64_t ne1[4] = {8, 128, 1, 1};
-// int64_t ne2[4] = {8, 256, 1, 1};;
-// int64_t ne3[4] = {128, 256, 1, 1};
-// main: original  e = 39429.5078
-// main: optimized e = 9275.8936
+static std::pair test_backend(ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
+    int npass = 0;
+    int ntest = 0;
 
-// int64_t ne1[4] = {16, 128, 1, 1};
-// int64_t ne2[4] = {16, 256, 1, 1};;
-// int64_t ne3[4] = {128, 256, 1, 1};
-// main: original  e = 68371.1328
-// main: optimized e = 7854.4502
+    for (bool shuffle : {false, true}) {
+        std::pair partial = test_dataset(backend_sched, backend, shuffle);
+        npass += partial.first;
+        ntest += partial.second;
+    }
+    {
+        std::pair partial = test_grad(backend_sched, backend);
+        npass += partial.first;
+        ntest += partial.second;
+    }
+    for (bool high_level : {false, true}){
+        for (bool shuffle : {false, true}) {
+            if (!high_level && shuffle) {
+                continue;
+            }
 
+            std::pair partial = test_forward_backward(backend_sched, backend, high_level, shuffle);
+            npass += partial.first;
+            ntest += partial.second;
+        }
+    }
+    {
+        std::pair partial = test_epoch_vs_fit(backend_sched, backend);
+        npass += partial.first;
+        ntest += partial.second;
+    }
+    for (bool high_level : {false, true}){
+        std::pair partial = test_idata_split(backend_sched, backend, high_level);
+        npass += partial.first;
+        ntest += partial.second;
+    }
+    for (int32_t nbatch_physical : {2, 1}) {
+        for (enum ggml_opt_loss_type loss_type : {GGML_OPT_LOSS_TYPE_SUM, GGML_OPT_LOSS_TYPE_MEAN}) {
+            std::pair partial = test_gradient_accumulation(backend_sched, backend, nbatch_physical, loss_type);
+            npass += partial.first;
+            ntest += partial.second;
+        }
+    }
+    {
+        std::pair partial = test_regression(backend_sched, backend);
+        npass += partial.first;
+        ntest += partial.second;
+    }
 
-// int64_t ne1[4] = {32, 128, 1, 1};
-// int64_t ne2[4] = {32, 256, 1, 1};;
-// int64_t ne3[4] = {128, 256, 1, 1};
-// main: original  e = 126061.1953
-// main: optimized e = 5451.0166
+    return std::make_pair(npass, ntest);
+}
 
-// int64_t ne1[4] = {4, 1024, 1, 1};
-// int64_t ne2[4] = {4, 2048, 1, 1};;
-// int64_t ne3[4] = {1024, 2048, 1, 1};
-// main: original  e = 1620817.8750
-// main: optimized e = 698387.6875
+int main(void) {
+    const size_t dev_count = ggml_backend_dev_count();
+    printf("Testing %zu devices\n\n", dev_count);
+    size_t n_ok = 0;
 
-// another run on M1
-// int64_t ne1[4] = {4, 1024, 1, 1};
-// int64_t ne2[4] = {4, 2048, 1, 1};;
-// int64_t ne3[4] = {1024, 2048, 1, 1};
-// main: original  e = 1629595.6250
-// main: optimized e = 698169.1250
+    std::vector devs;
+    std::vector     backends;
 
-// int64_t ne1[4] = {32, 1024, 1, 1};
-// int64_t ne2[4] = {32, 2048, 1, 1};;
-// int64_t ne3[4] = {1024, 2048, 1, 1};
-// main: original  e = 8146770.5000
-// main: optimized e = 651119.1250
+    for (size_t i = 0; i < dev_count; ++i) {
+        devs.push_back(ggml_backend_dev_get(i));
+
+        ggml_backend_t backend = ggml_backend_dev_init(devs[i], NULL);
+        GGML_ASSERT(backend != NULL);
+
+        if (ggml_backend_is_cpu(backend)) {
+            ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency() / 2);
+        }
+
+        backends.push_back(backend);
+    }
+
+    for (size_t i = 0; i < dev_count; ++i) {
+        // Put the backend to be tested in front so that it's prioritized:
+        std::vector backends_modded = {backends[i]};
+        backends_modded.insert(backends_modded.end(), backends.begin(), backends.end());
+
+        ggml_backend_sched_t backend_sched = ggml_backend_sched_new(
+            backends_modded.data(), nullptr, backends_modded.size(), GGML_DEFAULT_GRAPH_SIZE, false);
+
+        printf("Backend %zu/%zu: %s\n", i + 1, dev_count, ggml_backend_dev_name(devs[i]));
+        printf("  Device description: %s\n", ggml_backend_dev_description(devs[i]));
+        size_t free, total; // NOLINT
+        ggml_backend_dev_memory(devs[i], &free, &total);
+        printf("  Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024);
+        printf("\n");
+
+        std::pair result = test_backend(backend_sched, backends[i]);
+
+        printf("  %d/%d tests passed\n", result.first, result.second);
+        printf("  Backend %s: ", ggml_backend_name(backends[i]));
+        if (result.first == result.second) {
+            printf("\033[1;32mOK\033[0m\n");
+            n_ok++;
+        } else {
+            printf("\033[1;31mFAIL\033[0m\n");
+        }
+
+        printf("\n");
+
+        ggml_backend_sched_free(backend_sched);
+    }
+
+    for (ggml_backend_t backend : backends) {
+        ggml_backend_free(backend);
+    }
+
+    printf("%zu/%zu backends passed\n", n_ok, dev_count);
+    if (n_ok != dev_count) {
+        printf("\033[1;31mFAIL\033[0m\n");
+        return 1;
+    }
+    printf("\033[1;32mOK\033[0m\n");
+    return 0;
+}
diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp
index ccf5721a3..c77c8ed13 100644
--- a/tests/test-quantize-fns.cpp
+++ b/tests/test-quantize-fns.cpp
@@ -1,6 +1,7 @@
 // Unit tests for quantization specific functions - quantize, dequantize and dot product
 
 #include "ggml.h"
+#include "ggml-cpu.h"
 
 #undef NDEBUG
 #include 
@@ -44,26 +45,27 @@ static float array_rmse(const float * a1, const float * a2, size_t n) {
 }
 
 // Total quantization error on test data
-static float total_quantization_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data) {
+static float total_quantization_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data) {
     std::vector tmp_q(2*test_size);
     std::vector tmp_out(test_size);
 
-    qfns.from_float(test_data, tmp_q.data(), test_size);
-    qfns.to_float(tmp_q.data(), tmp_out.data(), test_size);
+    qfns_cpu->from_float(test_data, tmp_q.data(), test_size);
+    qfns->to_float(tmp_q.data(), tmp_out.data(), test_size);
     return array_rmse(test_data, tmp_out.data(), test_size);
 }
 
 // Total quantization error on test data
-static float reference_quantization_error(ggml_type_traits_t & qfns, size_t test_size, const float * test_data) {
+static float reference_quantization_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data) {
     std::vector tmp_q(2*test_size);
     std::vector tmp_out(test_size);
     std::vector tmp_out_ref(test_size);
 
-    qfns.from_float(test_data, tmp_q.data(), test_size);
-    qfns.to_float(tmp_q.data(), tmp_out.data(), test_size);
+    // FIXME: why is done twice?
+    qfns_cpu->from_float(test_data, tmp_q.data(), test_size);
+    qfns->to_float(tmp_q.data(), tmp_out.data(), test_size);
 
-    qfns.from_float_ref(test_data, tmp_q.data(), test_size);
-    qfns.to_float(tmp_q.data(), tmp_out_ref.data(), test_size);
+    qfns->from_float_ref(test_data, tmp_q.data(), test_size);
+    qfns->to_float(tmp_q.data(), tmp_out_ref.data(), test_size);
 
     return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size);
 }
@@ -77,19 +79,19 @@ static float dot_product(const float * a1, const float * a2, size_t test_size) {
 }
 
 // Total dot product error
-static float dot_product_error(
-    ggml_type_traits_t & qfns, size_t test_size, const float * test_data1, const float *test_data2
-) {
+static float dot_product_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data1, const float * test_data2) {
+    GGML_UNUSED(qfns);
+
     std::vector tmp_q1(2*test_size);
     std::vector tmp_q2(2*test_size);
 
-    auto vdot = ggml_internal_get_type_traits(qfns.vec_dot_type);
+    const auto * vdot = ggml_get_type_traits_cpu(qfns_cpu->vec_dot_type);
 
-    qfns.from_float(test_data1, tmp_q1.data(), test_size);
-    vdot.from_float(test_data2, tmp_q2.data(), test_size);
+    qfns_cpu->from_float(test_data1, tmp_q1.data(), test_size);
+    vdot->from_float(test_data2, tmp_q2.data(), test_size);
 
     float result = INFINITY;
-    qfns.vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, 1);
+    qfns_cpu->vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, 1);
 
     const float dot_ref = dot_product(test_data1, test_data2, test_size);
 
@@ -131,10 +133,11 @@ int main(int argc, char * argv[]) {
 
     for (int i = 0; i < GGML_TYPE_COUNT; i++) {
         ggml_type type = (ggml_type) i;
-        ggml_type_traits_t qfns = ggml_internal_get_type_traits(type);
+        const auto * qfns = ggml_get_type_traits(type);
+        const auto * qfns_cpu = ggml_get_type_traits_cpu(type);
 
         // deprecated - skip
-        if (qfns.blck_size == 0) {
+        if (qfns->blck_size == 0) {
             continue;
         }
 
@@ -143,8 +146,8 @@ int main(int argc, char * argv[]) {
         printf("Testing %s\n", ggml_type_name((ggml_type) i));
         ggml_quantize_init(ei);
 
-        if (qfns.from_float && qfns.to_float) {
-            const float total_error = total_quantization_error(qfns, test_size, test_data.data());
+        if (qfns_cpu->from_float && qfns->to_float) {
+            const float total_error = total_quantization_error(qfns, qfns_cpu, test_size, test_data.data());
             const float max_quantization_error =
                 type == GGML_TYPE_TQ1_0   ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
                 type == GGML_TYPE_TQ2_0   ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
@@ -159,14 +162,14 @@ int main(int argc, char * argv[]) {
                 printf("%5s absolute quantization error:    %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error);
             }
 
-            const float reference_error = reference_quantization_error(qfns, test_size, test_data.data());
+            const float reference_error = reference_quantization_error(qfns, qfns_cpu, test_size, test_data.data());
             failed = !(reference_error < MAX_QUANTIZATION_REFERENCE_ERROR);
             num_failed += failed;
             if (failed || verbose) {
                 printf("%5s reference implementation error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], reference_error);
             }
 
-            const float vec_dot_error = dot_product_error(qfns, test_size, test_data.data(), test_data2.data());
+            const float vec_dot_error = dot_product_error(qfns, qfns_cpu, test_size, test_data.data(), test_data2.data());
             const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
                                             type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
                                           ? MAX_DOT_PRODUCT_ERROR_LOWBIT
diff --git a/tests/test-quantize-perf.cpp b/tests/test-quantize-perf.cpp
index 24e066053..288288493 100644
--- a/tests/test-quantize-perf.cpp
+++ b/tests/test-quantize-perf.cpp
@@ -1,12 +1,12 @@
 // Benchmark quantization specific functions on synthetic data
 
 #include "ggml.h"
+#include "ggml-cpu.h"
 
 #undef NDEBUG
 #include 
 #include 
 #include 
-#include 
 #include 
 #include 
 #include 
@@ -122,9 +122,10 @@ static void usage(char * argv[]) {
     printf("  --type TYPE           set test type as");
     for (int i = 0; i < GGML_TYPE_COUNT; i++) {
         ggml_type type = (ggml_type) i;
-        ggml_type_traits_t qfns = ggml_internal_get_type_traits(type);
+        const auto * qfns     = ggml_get_type_traits(type);
+        const auto * qfns_cpu = ggml_get_type_traits_cpu(type);
         if (ggml_type_name(type) != NULL) {
-            if (qfns.from_float && qfns.to_float) {
+            if (qfns_cpu->from_float && qfns->to_float) {
                 printf(" %s", ggml_type_name(type));
             }
         }
@@ -270,12 +271,13 @@ int main(int argc, char * argv[]) {
 
     for (int i = 0; i < GGML_TYPE_COUNT; i++) {
         ggml_type type = (ggml_type) i;
-        ggml_type_traits_t qfns = ggml_internal_get_type_traits(type);
+        const auto * qfns = ggml_get_type_traits(type);
+        const auto * qfns_cpu = ggml_get_type_traits_cpu(type);
         if (!params.include_types.empty() && ggml_type_name(type) && std::find(params.include_types.begin(), params.include_types.end(), ggml_type_name(type)) == params.include_types.end()) {
             continue;
         }
 
-        if (qfns.from_float && qfns.to_float) {
+        if (qfns_cpu->from_float && qfns->to_float) {
             printf("%s\n", ggml_type_name(type));
 
             ggml_quantize_init(type);
@@ -285,7 +287,7 @@ int main(int argc, char * argv[]) {
                 for (size_t size : params.test_sizes) {
                     printf("    %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
                     auto quantize_fn = [&](void) -> float {
-                        qfns.from_float_ref(test_data1, test_q1, size);
+                        qfns->from_float_ref(test_data1, test_q1, size);
                         return test_q1[0];
                     };
                     size_t quantized_size = ggml_row_size(type, size);
@@ -299,7 +301,7 @@ int main(int argc, char * argv[]) {
                 for (size_t size : params.test_sizes) {
                     printf("    %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
                     auto quantize_fn = [&](void) -> float {
-                        qfns.from_float(test_data1, test_q1, size);
+                        qfns_cpu->from_float(test_data1, test_q1, size);
                         return test_q1[0];
                     };
                     size_t quantized_size = ggml_row_size(type, size);
@@ -310,11 +312,11 @@ int main(int argc, char * argv[]) {
 
             if (params.op_dequantize_row_q) {
                 printf("  dequantize_row_q\n");
-                qfns.from_float(test_data1, test_q1, largest);
+                qfns_cpu->from_float(test_data1, test_q1, largest);
                 for (size_t size : params.test_sizes) {
                     printf("    %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
                     auto quantize_fn = [&](void) -> float {
-                        qfns.to_float(test_q1, test_out, size);
+                        qfns->to_float(test_q1, test_out, size);
                         return test_out[0];
                     };
                     size_t quantized_size = ggml_row_size(type, size);
@@ -328,8 +330,8 @@ int main(int argc, char * argv[]) {
                 for (size_t size : params.test_sizes) {
                     printf("    %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
                     auto quantize_fn = [&](void) -> float {
-                        auto vdot = ggml_internal_get_type_traits(qfns.vec_dot_type);
-                        vdot.from_float(test_data1, test_q1, size);
+                        const auto * vdot = ggml_get_type_traits_cpu(qfns_cpu->vec_dot_type);
+                        vdot->from_float(test_data1, test_q1, size);
                         return test_q1[0];
                     };
                     size_t quantized_size = ggml_row_size(type, size);
@@ -340,13 +342,13 @@ int main(int argc, char * argv[]) {
 
             if (params.op_vec_dot_q) {
                 printf("  vec_dot_q\n");
-                qfns.from_float(test_data1, test_q1, largest);
-                qfns.from_float(test_data2, test_q2, largest);
+                qfns_cpu->from_float(test_data1, test_q1, largest);
+                qfns_cpu->from_float(test_data2, test_q2, largest);
                 for (size_t size : params.test_sizes) {
                     printf("    %zu values (%.2f MB)\n", size, 4*size/(float)(1024*1024));
                     auto quantize_fn = [&](void) -> float {
                         float result;
-                        qfns.vec_dot(size, &result, 0, test_q1, 0, test_q2, 0, 1);
+                        qfns_cpu->vec_dot(size, &result, 0, test_q1, 0, test_q2, 0, 1);
                         return result;
                     };
                     size_t quantized_size = ggml_row_size(type, size);
diff --git a/tests/test-rope.cpp b/tests/test-rope.cpp
index 246bb227d..322b8bb99 100644
--- a/tests/test-rope.cpp
+++ b/tests/test-rope.cpp
@@ -1,4 +1,5 @@
 #include "ggml.h"
+#include "ggml-cpu.h"
 
 #include 
 #include 
@@ -137,7 +138,7 @@ int main(int /*argc*/, const char ** /*argv*/) {
     struct ggml_tensor * x;
 
     // rope f32
-    for (int m = 0; m < 3; ++m) {
+    for (int m = 0; m < 5; ++m) {
         const int ndims = 4;
 
         const int64_t n_rot = 128;
@@ -146,28 +147,69 @@ int main(int /*argc*/, const char ** /*argv*/) {
         const int n_past_0 = 100;
         const int n_past_2 = 33;
 
-        struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
-        struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
-        struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
-
-        for (int i = 0; i < ne[2]; ++i) {
-            ((int32_t *) p0->data)[i] = n_past_0 + i;
-            ((int32_t *) p1->data)[i] = n_past_2 - n_past_0;
-            ((int32_t *) p2->data)[i] = n_past_2 + i;
-        }
-
-        // test mode 0, 2, 4 (standard, GPT-NeoX, GLM)
-        const int mode = m == 0 ? 0 : m == 1 ? 2 : 4;
-
+        struct ggml_tensor * r0;
+        struct ggml_tensor * r1;
+        struct ggml_tensor * r2;
         x = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
+        int mode = -1;
 
-        // 100, 101, 102, ..., 172
-        struct ggml_tensor * r0 = ggml_rope(ctx0, x,  p0, n_rot, mode);
-        // -67, -67, -67, ..., -67
-        struct ggml_tensor * r1 = ggml_rope(ctx0, r0, p1, n_rot, mode); // "context swap", i.e. forget n_past_0 - n_past_2 tokens
+        if (m < 3) {
+            struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
+            struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
+            struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
 
-        //  33,  34,  35, ..., 105
-        struct ggml_tensor * r2 = ggml_rope(ctx0, x,  p2, n_rot, mode);
+            for (int i = 0; i < ne[2]; ++i) {
+                ((int32_t *) p0->data)[i] = n_past_0 + i;
+                ((int32_t *) p1->data)[i] = n_past_2 - n_past_0;
+                ((int32_t *) p2->data)[i] = n_past_2 + i;
+            }
+            // test mode 0, 2, 4 (standard, GPT-NeoX, GLM)
+            mode = m == 0 ? 0 : m == 1 ? 2 : 4;
+
+            // 100, 101, 102, ..., 172
+            r0 = ggml_rope(ctx0, x,  p0, n_rot, mode);
+            // -67, -67, -67, ..., -67
+            r1 = ggml_rope(ctx0, r0, p1, n_rot, mode); // "context swap", i.e. forget n_past_0 - n_past_2 tokens
+
+            //  33,  34,  35, ..., 105
+            r2 = ggml_rope(ctx0, x,  p2, n_rot, mode);
+        } else {
+            // testing multi-dimension rope position embedding mode
+            struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2] * 4);
+            struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2] * 4);
+            struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2] * 4);
+
+            int sections[4] = {16, 24, 24, 0};
+            mode = (m == 3) ? GGML_ROPE_TYPE_MROPE : GGML_ROPE_TYPE_VISION;
+
+            for (int i = 0; i < ne[2]; ++i) {
+                for (int j = 0; j < 4; ++j) {
+                    ((int32_t *) p0->data)[i + ne[2] * j] = n_past_0 + i + j;
+                    ((int32_t *) p1->data)[i + ne[2] * j] = n_past_2 - n_past_0;
+                    ((int32_t *) p2->data)[i + ne[2] * j] = n_past_2 + i + j;
+                }
+            }
+
+            // [[100, 101, 102, ..., 172],
+            // [101, 102, 103, ..., 173],
+            // [102, 103, 104, ..., 174]]
+            r0 = ggml_rope_multi(
+                ctx0, x, p0, nullptr,
+                n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
+            // [[-67, -67, -67, ..., -67]
+            // [-67, -67, -67, ..., -67]
+            // [-67, -67, -67, ..., -67]]
+            r1 = ggml_rope_multi(
+                ctx0, r0, p1, nullptr,
+                n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
+
+            //  [[33,  34,  35, ..., 105]
+            //  [34,  35,  36, ..., 106]
+            //  [35,  36,  37, ..., 107]]
+            r2 = ggml_rope_multi(
+                ctx0, x, p2, nullptr,
+                n_rot, sections, mode, 32768, 1000000, 1, 0, 1, 32, 1);
+        }
 
         ggml_cgraph * gf = ggml_new_graph(ctx0);
 
diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
index 37400c179..61bd67850 100644
--- a/tests/test-sampling.cpp
+++ b/tests/test-sampling.cpp
@@ -1,6 +1,5 @@
 #include "ggml.h"
 #include "llama.h"
-#include "llama-sampling.h"
 
 #ifdef NDEBUG
 #undef NDEBUG
@@ -11,6 +10,8 @@
 #include 
 #include 
 
+extern struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector>& seq_breakers);
+
 static void dump(const llama_token_data_array * cur_p) {
     for (size_t i = 0; i < cur_p->size; i++) {
         printf("%d: %f (%f)\n", cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
@@ -19,181 +20,187 @@ static void dump(const llama_token_data_array * cur_p) {
 
 #define DUMP(__cur_p) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__cur_p)); printf("-\n"); } while(0)
 
-#define APPLY(__cnstr, __cur_p) do { \
-    auto * cnstr = (__cnstr); \
-    llama_sampler_apply(cnstr, (__cur_p)); \
-    llama_sampler_free(cnstr); \
-} while(0)
+struct sampler_tester {
+    sampler_tester(size_t n_vocab) {
+        cur.reserve(n_vocab);
+        for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
+            const float logit = logf(token_id);
+            cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
+        }
 
-static void test_top_k(const std::vector & probs, const std::vector & expected_probs, int k) {
-    const size_t n_vocab = probs.size();
+        cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false };
+    }
+
+    sampler_tester(const std::vector & probs, const std::vector & probs_expected) : probs_expected(probs_expected) {
+        cur.reserve(probs.size());
+        for (llama_token token_id = 0; token_id < (llama_token)probs.size(); token_id++) {
+            const float logit = logf(probs[token_id]);
+            cur.emplace_back(llama_token_data{token_id, logit, probs[token_id]});
+        }
+
+        cur_p = llama_token_data_array { cur.data(), cur.size(), -1, false };
+    }
+
+    void apply(llama_sampler * sampler) {
+        llama_sampler_apply(sampler, &cur_p);
+        llama_sampler_free(sampler);
+    }
+
+    void check() {
+        GGML_ASSERT(cur_p.size == probs_expected.size());
+        for (size_t i = 0; i < cur_p.size; i++) {
+            GGML_ASSERT(fabs(cur_p.data[i].p - probs_expected[i]) < 1e-5);
+        }
+    }
+
+    llama_token_data_array cur_p;
+
+private:
+    const std::vector probs_expected;
 
     std::vector cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
+};
 
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
-    APPLY(llama_sampler_init_softmax(), &cur_p);
-    DUMP(&cur_p);
-    APPLY(llama_sampler_init_top_k(k), &cur_p);
-    DUMP(&cur_p);
+static void test_temp(const std::vector & probs, const std::vector & probs_expected, float temp) {
+    sampler_tester tester(probs, probs_expected);
 
-    GGML_ASSERT(cur_p.size == expected_probs.size());
-    for (size_t i = 0; i < cur_p.size; i++) {
-        GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-5);
-    }
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_temp(temp));
+    tester.apply(llama_sampler_init_dist(0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
 }
 
-static void test_top_p(const std::vector & probs, const std::vector & expected_probs, float p) {
-    const size_t n_vocab = probs.size();
+static void test_temp_ext(const std::vector & probs, const std::vector & probs_expected, float temp, float delta, float exponent) {
+    sampler_tester tester(probs, probs_expected);
 
-    std::vector cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_temp_ext(temp, delta, exponent));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
 
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
-    APPLY(llama_sampler_init_softmax(), &cur_p);
-    DUMP(&cur_p);
-    APPLY(llama_sampler_init_top_p(p, 1), &cur_p);
-    DUMP(&cur_p);
-
-    GGML_ASSERT(cur_p.size == expected_probs.size());
-    for (size_t i = 0; i < cur_p.size; i++) {
-        GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
-    }
+    tester.check();
 }
 
-static void test_tfs(const std::vector & probs, const std::vector & expected_probs, float z) {
-    const size_t n_vocab = probs.size();
+static void test_top_k(const std::vector & probs, const std::vector & probs_expected, int k) {
+    sampler_tester tester(probs, probs_expected);
 
-    std::vector cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_top_k(k));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
 
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
-    DUMP(&cur_p);
-    APPLY(llama_sampler_init_tail_free(z, 1), &cur_p);
-    DUMP(&cur_p);
-
-    GGML_ASSERT(cur_p.size == expected_probs.size());
-    for (size_t i = 0; i < cur_p.size; i++) {
-        GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
-    }
+    tester.check();
 }
 
-static void test_min_p(const std::vector & probs, const std::vector & expected_probs, float p) {
-    const size_t n_vocab = probs.size();
+static void test_top_p(const std::vector & probs, const std::vector & probs_expected, float p) {
+    sampler_tester tester(probs, probs_expected);
 
-    std::vector cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_top_p(p, 1));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
 
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
-    DUMP(&cur_p);
-    APPLY(llama_sampler_init_min_p(p, 1), &cur_p);
-    DUMP(&cur_p);
-    APPLY(llama_sampler_init_softmax(), &cur_p);
-
-    GGML_ASSERT(cur_p.size == expected_probs.size());
-    for (size_t i = 0; i < cur_p.size; i++) {
-        GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
-    }
+    tester.check();
 }
 
-static void test_typical(const std::vector & probs, const std::vector & expected_probs, float p) {
-    const size_t n_vocab = probs.size();
+static void test_min_p(const std::vector & probs, const std::vector & probs_expected, float p) {
+    sampler_tester tester(probs, probs_expected);
 
-    std::vector cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_min_p(p, 1));
+    tester.apply(llama_sampler_init_dist (0));
+    DUMP(&tester.cur_p);
 
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
-    DUMP(&cur_p);
-    APPLY(llama_sampler_init_typical(p, 1), &cur_p);
-    DUMP(&cur_p);
+    tester.check();
+}
 
-    GGML_ASSERT(cur_p.size == expected_probs.size());
-    for (size_t i = 0; i < cur_p.size; i++) {
-        GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
-    }
+static void test_xtc(const std::vector & probs, const std::vector & probs_expected, float p, float t) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_xtc(p, t, 0, 0));
+    DUMP(&tester.cur_p);
+
+    tester.check();
+}
+
+static void test_typical(const std::vector & probs, const std::vector & probs_expected, float p) {
+    sampler_tester tester(probs, probs_expected);
+
+    DUMP(&tester.cur_p);
+    tester.apply(llama_sampler_init_typical(p, 1));
+    DUMP(&tester.cur_p);
+
+    tester.check();
 }
 
 static void test_penalties(
     const std::vector & probs, const std::vector & last_tokens,
-    const std::vector & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence
+    const std::vector & probs_expected, float repeat_penalty, float alpha_frequency, float alpha_presence
 ) {
-    GGML_ASSERT(probs.size() == expected_probs.size());
+    GGML_ASSERT(probs.size() == probs_expected.size());
 
-    const size_t n_vocab = probs.size();
+    sampler_tester tester(probs, probs_expected);
 
-    std::vector cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(probs[token_id]);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
-
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
-
-    auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
+    auto * sampler = llama_sampler_init_penalties(last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
 
     for (size_t i = 0; i < last_tokens.size(); i++) {
         llama_sampler_accept(sampler, last_tokens[i]);
     }
 
-    APPLY(llama_sampler_init_softmax(), &cur_p);
-    DUMP(&cur_p);
-    APPLY(sampler, &cur_p);
-    APPLY(llama_sampler_init_softmax(), &cur_p);
-    DUMP(&cur_p);
+    DUMP(&tester.cur_p);
+    tester.apply(sampler);
+    tester.apply(llama_sampler_init_dist(0));
+    DUMP(&tester.cur_p);
 
-    GGML_ASSERT(cur_p.size == expected_probs.size());
-    for (size_t i = 0; i < cur_p.size; i++) {
-        GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3);
+    tester.check();
+}
+
+static void test_dry(
+    const std::vector & probs, const std::vector & last_tokens,
+    const std::vector & expected_probs, float dry_multiplier, float dry_base,
+    int dry_allowed_length, int dry_penalty_last_n,
+    const std::vector> & seq_breakers
+) {
+    GGML_ASSERT(probs.size() == expected_probs.size());
+
+    sampler_tester tester(probs, expected_probs);
+
+    auto * sampler = llama_sampler_init_dry_testing(1024, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers);
+
+    for (size_t i = 0; i < last_tokens.size(); i++) {
+        llama_sampler_accept(sampler, last_tokens[i]);
     }
+
+    DUMP(&tester.cur_p);
+    tester.apply(sampler);
+    tester.apply(llama_sampler_init_dist(0));
+    DUMP(&tester.cur_p);
+    tester.check();
 }
 
 static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
 ) {
-    std::vector cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
-        const float logit = logf(token_id);
-        cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
-    }
-
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+    sampler_tester tester(n_vocab);
 
           llama_token min_token_id = 0;
     const llama_token max_token_id = n_vocab-1;
 
     for (auto s : samplers_sequence) {
         switch (s){
-            case 'k': APPLY(llama_sampler_init_top_k(top_k), &cur_p); break;
-            case 'f': GGML_ABORT("tail_free test not implemented");
+            case 'k': tester.apply(llama_sampler_init_top_k(top_k)); break;
             case 'y': GGML_ABORT("typical test not implemented");
-            case 'p': APPLY(llama_sampler_init_top_p(top_p, 1), &cur_p); break;
-            case 'm': APPLY(llama_sampler_init_min_p(min_p, 1), &cur_p); break;
+            case 'p': tester.apply(llama_sampler_init_top_p(top_p, 1)); break;
+            case 'm': tester.apply(llama_sampler_init_min_p(min_p, 1)); break;
             case 't': GGML_ABORT("temperature test not implemented");
             default : GGML_ABORT("Unknown sampler");
         }
 
-        APPLY(llama_sampler_init_softmax(), &cur_p); // make sure tokens are sorted for tests
+        tester.apply(llama_sampler_init_dist(0));
+
+        auto & cur_p = tester.cur_p;
 
         const int size = cur_p.size;
 
@@ -245,22 +252,66 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler
         }
     }
 
-    printf("Sampler queue %3s OK with n_vocab=%05ld top_k=%05d top_p=%f min_p=%f\n",
+    printf("Sampler queue %3s OK with n_vocab=%05zu top_k=%05d top_p=%f min_p=%f\n",
            samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
 }
 
+static void bench(llama_sampler * cnstr, const char * cnstr_name, const std::vector & data, int n_iter) {
+    std::vector cur(data.size());
+    std::copy(data.begin(), data.end(), cur.begin());
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+    llama_sampler_apply(cnstr, &cur_p);
+    llama_sampler_reset(cnstr);
+    const int64_t t_start = ggml_time_us();
+    for (int i = 0; i < n_iter; i++) {
+        std::copy(data.begin(), data.end(), cur.begin());
+        llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+        llama_sampler_apply(cnstr, &cur_p);
+        llama_sampler_reset(cnstr);
+    }
+    const int64_t t_end = ggml_time_us();
+    llama_sampler_free(cnstr);
+    printf("%-43s: %8.3f us/iter\n", cnstr_name, (t_end - t_start) / (float)n_iter);
+}
+
+#define BENCH(__cnstr, __data, __n_iter) bench((__cnstr), #__cnstr, (__data), (__n_iter))
+
+static void test_perf() {
+    const int n_vocab = 1 << 17;
+
+    std::vector data;
+
+    data.reserve(n_vocab);
+    for (int i = 0; i < n_vocab; i++) {
+        const float logit = 2.0f*((double)(rand())/RAND_MAX - 0.5);
+        data.emplace_back(llama_token_data{i, logit, 0.0f});
+    }
+
+    BENCH(llama_sampler_init_top_k  (40),                     data, 32);
+    BENCH(llama_sampler_init_top_p  (0.8f, 1),                data, 32);
+    BENCH(llama_sampler_init_min_p  (0.2f, 1),                data, 32);
+    BENCH(llama_sampler_init_typical(0.5f, 1),                data, 32);
+    BENCH(llama_sampler_init_xtc    (1.0f, 0.1f, 1, 1),       data, 32);
+}
+
 int main(void) {
     ggml_time_init();
 
-    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
-    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3);
+    test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
+    test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
+
+    test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f);
+    test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f);
+
+    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
+    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
     test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
     test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
 
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 0.8f);
-    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 0);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f}, 0.7f);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 0.8f);
+    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
 
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.00f);
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.24f);
@@ -271,9 +322,13 @@ int main(void) {
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f},                                  0.76f);
     test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f},                                  1.00f);
 
-    test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
-    test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
-    test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);
+    printf("XTC should:\n");
+    test_xtc({0.4f, 0.3f, 0.2f, 0.1f},   {0.1f},                                0.99f, 0.09f);
+    test_xtc({0.4f, 0.3f, 0.2f, 0.1f},   {0.2f, 0.1f},                          0.99f, 0.19f);
+    test_xtc({0.4f, 0.3f, 0.2f, 0.1f},   {0.3f, 0.2f, 0.1f},                    0.99f, 0.29f);
+
+    printf("XTC should not:\n");
+    test_xtc({0.4f, 0.3f, 0.2f, 0.1f},   {0.4f, 0.3f, 0.2f, 0.1f},              0.99f, 0.39f);
 
     test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
     test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
@@ -286,6 +341,13 @@ int main(void) {
     test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2},       {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
     test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
 
+
+    test_dry({0.25f, 0.25f, 0.25f, 0.25f}, {0, 1}, {0.25f, 0.25f, 0.25f, 0.25f}, 1.0f, 1.1f, 2, 4, {});
+    test_dry({0.25f, 0.25f, 0.25f, 0.25f}, {0, 1, 2, 0, 1}, {0.296923f, 0.296923f, 0.296923f, 0.109232f}, 1.0f, 1.1f, 2, 5, {});
+    test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 2, 6, {{3}});
+    test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {});
+    test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
+
     test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
     test_sampler_queue(10000, "k",     1, 1.0f, 1.0f);
     test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);
@@ -316,5 +378,7 @@ int main(void) {
 
     printf("OK\n");
 
+    test_perf();
+
     return 0;
 }
diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp
index d3d21331b..59dda4877 100644
--- a/tests/test-tokenizer-0.cpp
+++ b/tests/test-tokenizer-0.cpp
@@ -7,6 +7,7 @@
 #include 
 #include 
 #include 
+#include 
 
 //static const std::map> & k_tests() {
 //    static std::map> _k_tests = {
@@ -151,7 +152,7 @@ int main(int argc, char **argv) {
 
         mparams.vocab_only = true;
 
-        model = llama_load_model_from_file(fname.c_str(), mparams);
+        model = llama_model_load_from_file(fname.c_str(), mparams);
 
         if (model == NULL) {
             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
@@ -160,11 +161,11 @@ int main(int argc, char **argv) {
 
         auto cparams = llama_context_default_params();
 
-        ctx = llama_new_context_with_model(model, cparams);
+        ctx = llama_init_from_model(model, cparams);
 
         if (ctx == NULL) {
             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
-            llama_free_model(model);
+            llama_model_free(model);
             return 1;
         }
     }
@@ -194,45 +195,64 @@ int main(int argc, char **argv) {
 
     const bool add_special = false;
 
-    for (const auto & test_kv : k_tests) {
-        const std::vector res = llama_tokenize(ctx, test_kv.first, add_special, false);
+    // multi-threaded tokenization
+    const int nthread = std::thread::hardware_concurrency();
+    std::vector threads(nthread);
 
-        printf("\n");
-        printf("src: '%s'\n", test_kv.first.c_str());
-        printf("res: '%s'\n", llama_detokenize(ctx, res).c_str());
-        printf("tok: ");
-        for (const auto & tok : res) {
-            printf("%d ", tok);
-        }
-        printf("\n");
+    for (int i = 0; i < nthread; i++) {
+        threads[i] = std::thread([&, i]() {
+            for (const auto & test_kv : k_tests) {
+                const std::vector res = common_tokenize(ctx, test_kv.first, add_special, false);
 
-        bool correct = res.size() == test_kv.second.size();
-        for (int i = 0; i < (int) res.size() && correct; ++i) {
-            if (test_kv.second[i] != res[i]) {
-                correct = false;
+                // here only print the result of the first thread
+                // because the other threads are running the same tests
+                if (i != 0) {
+                    continue;
+                }
+
+                printf("\n");
+                printf("src: '%s'\n", test_kv.first.c_str());
+                printf("res: '%s'\n", common_detokenize(ctx, res).c_str());
+                printf("tok: ");
+                for (const auto & tok : res) {
+                    printf("%d ", tok);
+                }
+                printf("\n");
+
+                bool correct = res.size() == test_kv.second.size();
+                for (int i = 0; i < (int) res.size() && correct; ++i) {
+                    if (test_kv.second[i] != res[i]) {
+                        correct = false;
+                    }
+                }
+
+                if (!correct) {
+                    fprintf(stderr, "%s : failed test:    '%s'\n", __func__, test_kv.first.c_str());
+                    fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
+                        common_detokenize(ctx, res).c_str(),
+                        common_detokenize(ctx, test_kv.second).c_str());
+                    fprintf(stderr, "%s : expected tokens: ", __func__);
+                    for (const auto & t : test_kv.second) {
+                        fprintf(stderr, "%6d '%s', ", t, common_token_to_piece(ctx, t).c_str());
+                    }
+                    fprintf(stderr, "\n");
+                    fprintf(stderr, "%s : got tokens:      ", __func__);
+                    for (const auto & t : res) {
+                        fprintf(stderr, "%6d '%s', ", t, common_token_to_piece(ctx, t).c_str());
+                    }
+                    fprintf(stderr, "\n");
+
+                    success = false;
+                }
             }
-        }
-
-        if (!correct) {
-            fprintf(stderr, "%s : failed test:    '%s'\n", __func__, test_kv.first.c_str());
-            fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
-                llama_detokenize(ctx, res).c_str(),
-                llama_detokenize(ctx, test_kv.second).c_str());
-            fprintf(stderr, "%s : expected tokens: ", __func__);
-            for (const auto & t : test_kv.second) {
-                fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
-            }
-            fprintf(stderr, "\n");
-            fprintf(stderr, "%s : got tokens:      ", __func__);
-            for (const auto & t : res) {
-                fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
-            }
-            fprintf(stderr, "\n");
-
-            success = false;
-        }
+        });
     }
 
+    for (int i = 0; i < nthread; i++) {
+        threads[i].join();
+    }
+
+    // single threaded tokenization
     if (!fname_text.empty()) {
         fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str());
 
@@ -253,7 +273,7 @@ int main(int argc, char **argv) {
         {
             const auto t_start = ggml_time_us();
 
-            res = llama_tokenize(ctx, text, add_special, false);
+            res = common_tokenize(ctx, text, add_special, false);
 
             const auto t_end = ggml_time_us();
 
@@ -280,7 +300,7 @@ int main(int argc, char **argv) {
         fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str());
     }
 
-    llama_free_model(model);
+    llama_model_free(model);
     llama_free(ctx);
 
     llama_backend_free();
diff --git a/tests/test-tokenizer-1-bpe.cpp b/tests/test-tokenizer-1-bpe.cpp
index 9498387e0..55425d88a 100644
--- a/tests/test-tokenizer-1-bpe.cpp
+++ b/tests/test-tokenizer-1-bpe.cpp
@@ -46,7 +46,7 @@ int main(int argc, char **argv) {
 
         mparams.vocab_only = true;
 
-        model = llama_load_model_from_file(fname.c_str(), mparams);
+        model = llama_model_load_from_file(fname.c_str(), mparams);
 
         if (model == NULL) {
             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
@@ -55,17 +55,19 @@ int main(int argc, char **argv) {
 
         auto cparams = llama_context_default_params();
 
-        ctx = llama_new_context_with_model(model, cparams);
+        ctx = llama_init_from_model(model, cparams);
 
         if (ctx == NULL) {
             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
-            llama_free_model(model);
+            llama_model_free(model);
             return 1;
         }
     }
 
-    //GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_BPE);
-    if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_BPE) {
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    //GGML_ASSERT(llama_vocab_type(vocab) == LLAMA_VOCAB_TYPE_BPE);
+    if (llama_vocab_type(vocab) != LLAMA_VOCAB_TYPE_BPE) {
         return 99;
     }
 
@@ -75,13 +77,13 @@ int main(int argc, char **argv) {
     atexit([]() { console::cleanup(); });
 #endif
 
-    const int n_vocab = llama_n_vocab(model);
+    const int n_vocab = llama_vocab_n_tokens(vocab);
 
     for (int i = 0; i < n_vocab; ++i) {
-        std::string str = llama_detokenize(ctx, std::vector(1, i));
+        std::string str = common_detokenize(ctx, std::vector(1, i));
         try {
             auto cps = unicode_cpts_from_utf8(str);
-            std::vector tokens = llama_tokenize(ctx, str, false, true);
+            std::vector tokens = common_tokenize(ctx, str, false, true);
             if (ignore_merges && tokens.size() > 1) {
                 fprintf(stderr,
                         "%s : error: token %d detokenizes to '%s'(%zu) but "
@@ -94,7 +96,7 @@ int main(int argc, char **argv) {
                 fprintf(stderr, "]\n");
                 return 2;
             }
-            std::string check = llama_detokenize(ctx, tokens);
+            std::string check = common_detokenize(ctx, tokens);
             if (check != str) {
                 fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n",
                     __func__, i, str.c_str(), str.length(), check.c_str(), check.length());
@@ -123,8 +125,8 @@ int main(int argc, char **argv) {
                     }
 
                     std::string str = unicode_cpt_to_utf8(cp);
-                    std::vector tokens = llama_tokenize(ctx, str, false);
-                    std::string check = llama_detokenize(ctx, tokens);
+                    std::vector tokens = common_tokenize(ctx, str, false);
+                    std::string check = common_detokenize(ctx, tokens);
                     if (cp != 9601 && str != check) {
                         fprintf(stderr, "error: codepoint 0x%x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
                                 cp, check.c_str(), check.length(), str.c_str(), str.length());
@@ -143,7 +145,7 @@ int main(int argc, char **argv) {
         }
     }
 
-    llama_free_model(model);
+    llama_model_free(model);
     llama_free(ctx);
 
     llama_backend_free();
diff --git a/tests/test-tokenizer-1-spm.cpp b/tests/test-tokenizer-1-spm.cpp
index 7ca9e2ca6..9e7b77f31 100644
--- a/tests/test-tokenizer-1-spm.cpp
+++ b/tests/test-tokenizer-1-spm.cpp
@@ -34,7 +34,7 @@ int main(int argc, char ** argv) {
 
         mparams.vocab_only = true;
 
-        model = llama_load_model_from_file(fname.c_str(), mparams);
+        model = llama_model_load_from_file(fname.c_str(), mparams);
 
         if (model == NULL) {
             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
@@ -43,17 +43,19 @@ int main(int argc, char ** argv) {
 
         auto cparams = llama_context_default_params();
 
-        ctx = llama_new_context_with_model(model, cparams);
+        ctx = llama_init_from_model(model, cparams);
 
         if (ctx == NULL) {
             fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
-            llama_free_model(model);
+            llama_model_free(model);
             return 1;
         }
     }
 
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
     //GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
-    if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_SPM) {
+    if (llama_vocab_type(vocab) != LLAMA_VOCAB_TYPE_SPM) {
         return 99;
     }
 
@@ -63,12 +65,12 @@ int main(int argc, char ** argv) {
     atexit([]() { console::cleanup(); });
 #endif
 
-    const int n_vocab = llama_n_vocab(model);
+    const int n_vocab = llama_vocab_n_tokens(vocab);
 
     for (int i = 0; i < n_vocab; ++i) {
-        std::string str = llama_detokenize(ctx, std::vector(1, i), true);
-        std::vector tokens = llama_tokenize(ctx, str, false, true);
-        std::string check = llama_detokenize(ctx, tokens);
+        std::string str = common_detokenize(ctx, std::vector(1, i), true);
+        std::vector tokens = common_tokenize(ctx, str, false, true);
+        std::string check = common_detokenize(ctx, tokens);
         if (check != str) {
             fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n",
                 __func__, i, str.c_str(), str.length(), check.c_str(), check.length());
@@ -93,8 +95,8 @@ int main(int argc, char ** argv) {
                     }
 
                     std::string str = unicode_cpt_to_utf8(cp);
-                    std::vector tokens = llama_tokenize(ctx, str, false, true);
-                    std::string check = llama_detokenize(ctx, tokens);
+                    std::vector tokens = common_tokenize(ctx, str, false, true);
+                    std::string check = common_detokenize(ctx, tokens);
                     if (cp != 9601 && str != check) {
                         fprintf(stderr, "error: codepoint 0x%x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
                                 cp, check.c_str(), check.length(), str.c_str(), str.length());
@@ -113,7 +115,7 @@ int main(int argc, char ** argv) {
         }
     }
 
-    llama_free_model(model);
+    llama_model_free(model);
     llama_free(ctx);
 
     llama_backend_free();
diff --git a/tests/test-tokenizer-random.py b/tests/test-tokenizer-random.py
index 9ebe6c891..c6cdcb554 100644
--- a/tests/test-tokenizer-random.py
+++ b/tests/test-tokenizer-random.py
@@ -76,7 +76,7 @@ class LibLlamaModel:
         self.ffi = libllama.ffi
         if isinstance(mparams, dict):
             mparams = libllama.model_default_params(**mparams)
-        self.model = self.lib.llama_load_model_from_file(path_model.encode(), mparams)
+        self.model = self.lib.llama_model_load_from_file(path_model.encode(), mparams)
         if not self.model:
             raise RuntimeError("error: failed to load model '%s'" % path_model)
         if isinstance(cparams, dict):
@@ -92,7 +92,7 @@ class LibLlamaModel:
         if self.ctx:
             self.lib.llama_free(self.ctx)
         if self.model:
-            self.lib.llama_free_model(self.model)
+            self.lib.llama_model_free(self.model)
         self.ctx = None
         self.model = None
         self.lib = None